import os
def decode_sequence(vocab, seq):
    # Decode a batch of token-id sequences into sentences using a plain
    # id -> word vocabulary; id 0 is treated as end-of-sequence/padding.
    N, T = seq.size()
    sents = []
    for n in range(N):
        words = []
        for t in range(T):
            ix = int(seq[n, t])  # cast the 0-dim tensor to a plain int for safe lookup
            if ix == 0:
                break
            words.append(vocab[ix])
        sent = ' '.join(words)
        sents.append(sent)
    return sents
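
# Minimal usage sketch for decode_sequence (assumptions: `vocab` maps integer
# token ids to word strings and id 0 marks end-of-sequence/padding; the toy
# vocabulary and id tensor below are invented for illustration only).
def _demo_decode_sequence():
    import torch
    toy_vocab = {1: 'a', 2: 'dog', 3: 'runs'}
    toy_seq = torch.tensor([[1, 2, 3, 0, 0],
                            [2, 3, 0, 0, 0]])
    # expected result with the toy inputs: ['a dog runs', 'dog runs']
    return decode_sequence(toy_vocab, toy_seq)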

def decode_sequence_bert(tokenizer, seq, sep_token_id):
    # Decode a batch of BERT token-id sequences into sentences, stopping at
    # the [SEP] token and detokenizing WordPiece tokens back into a string.
    N, T = seq.size()
    seq = seq.data.cpu().numpy()
    sents = []
    for n in range(N):
        words = []
        for t in range(T):
            ix = seq[n, t]
            if ix == sep_token_id:
                break
            words.append(tokenizer.ids_to_tokens[ix])
        sent = tokenizer.convert_tokens_to_string(words)
        sents.append(sent)
    return sents
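
# Hedged usage sketch for decode_sequence_bert. This assumes the `transformers`
# package and the 'bert-base-uncased' checkpoint are available; the input ids
# are produced by the tokenizer itself rather than hard-coded. Note that the
# decoded string keeps the leading [CLS] token because decoding only stops at [SEP].
if __name__ == '__main__':
    import torch
    from transformers import BertTokenizer

    print(_demo_decode_sequence())

    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    ids = bert_tokenizer.encode('a dog runs in the park')  # [CLS] ... [SEP]
    example_seq = torch.tensor([ids])
    print(decode_sequence_bert(bert_tokenizer, example_seq, bert_tokenizer.sep_token_id))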