|
|
|
|
|
import json |
|
|
from transformers import PreTrainedTokenizer |
|
|
|
|
|
class ShivikM1Tokenizer(PreTrainedTokenizer):
    """Whitespace-splitting tokenizer backed by a JSON vocabulary file.

    Tokenization is plain ``str.split()``; each whitespace-delimited token is
    mapped to an id through the ``vocab.json`` mapping. Unknown tokens fall
    back to the id of ``"<unk>"`` (or 0 if ``"<unk>"`` is absent from the
    vocabulary). ``merges_file`` is accepted and stored for compatibility with
    the declared ``vocab_files_names`` but is not used by this implementation.
    """

    # File names expected by `save_pretrained` / `from_pretrained`.
    vocab_files_names = {
        "vocab_file": "vocab.json",
        "merges_file": "merges.txt",
    }

    def __init__(self, vocab_file=None, merges_file=None, **kwargs):
        """Load the vocabulary and initialize the base tokenizer.

        Args:
            vocab_file: Path to a JSON file mapping token -> id. Required.
            merges_file: Path to a merges file; stored but unused here.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``
                (special tokens, model_max_length, etc.).

        Raises:
            ValueError: If ``vocab_file`` is not provided.
        """
        if vocab_file is None:
            raise ValueError("vocab_file must be provided.")

        # BUGFIX: the vocabulary must be loaded BEFORE super().__init__().
        # PreTrainedTokenizer.__init__ calls get_vocab()/vocab_size while
        # registering special/added tokens, which would raise AttributeError
        # if self.encoder did not exist yet.
        with open(vocab_file, "r", encoding="utf-8") as f:
            self.encoder = json.load(f)

        # Reverse mapping id -> token for decoding.
        self.decoder = {v: k for k, v in self.encoder.items()}

        self.vocab_file = vocab_file
        self.merges_file = merges_file

        super().__init__(**kwargs)

    def get_vocab(self):
        """Return a copy of the token -> id mapping (includes base vocab only)."""
        return dict(self.encoder)

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary."""
        return len(self.encoder)

    def _tokenize(self, text):
        """Split *text* into tokens on any whitespace run."""
        return text.split()

    def _convert_token_to_id(self, token):
        """Map *token* to its id; unknown tokens use "<unk>"'s id, else 0."""
        return self.encoder.get(token, self.encoder.get("<unk>", 0))

    def _convert_id_to_token(self, idx):
        """Map *idx* back to its token; unknown ids become "<unk>"."""
        return self.decoder.get(idx, "<unk>")

    def convert_tokens_to_string(self, tokens):
        """Join tokens with single spaces (inverse of whitespace splitting)."""
        return " ".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids):
        """Return *token_ids* unchanged — this tokenizer adds no special tokens."""
        return token_ids