//! Character-level text tokenizer with special-token handling, batch
//! padding, and attention-mask helpers.

use crate::{Error, Result};
use std::collections::HashMap;
use std::path::Path;
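/// Tokenizer configuration: model file location, vocabulary size, and the
/// ids assigned to the special tokens.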
#[derive(Debug, Clone)]
pub struct TokenizerConfig {
    pub model_path: String,
    pub vocab_size: usize,
    pub bos_id: i64,
    pub eos_id: i64,
    pub unk_id: i64,
    pub pad_id: i64,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            model_path: "models/bpe.model".to_string(),
            vocab_size: 6681,
            bos_id: 1,
            eos_id: 2,
            unk_id: 0,
            pad_id: 3,
        }
    }
}
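/// Character-level tokenizer mapping printable ASCII and CJK ideographs to
/// ids, with `<unk>`, `<s>`, `</s>`, and `<pad>` special tokens.
///
/// A minimal usage sketch (it mirrors the unit tests below; marked `ignore`
/// because the crate path depends on where this module is mounted):
///
/// ```ignore
/// let tokenizer = TextTokenizer::new(TokenizerConfig::default())?;
/// let ids = tokenizer.encode("Hello")?; // [BOS, H, e, l, l, o, EOS]
/// let text = tokenizer.decode(&ids)?;
/// assert_eq!(text, "Hello");
/// ```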
#[derive(Debug)]
pub struct TextTokenizer {
    config: TokenizerConfig,
    token_to_id: HashMap<String, i64>,
    id_to_token: HashMap<i64, String>,
    char_vocab: HashMap<char, i64>,
}

impl TextTokenizer {
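    /// Creates a tokenizer from `config`, building the vocabulary in three
    /// blocks: special tokens, printable ASCII, then CJK ideographs until
    /// `vocab_size` ids have been assigned.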
    pub fn new(config: TokenizerConfig) -> Result<Self> {
        let mut token_to_id = HashMap::new();
        let mut id_to_token = HashMap::new();
        let mut char_vocab = HashMap::new();

        // Register the special tokens at their configured ids.
        token_to_id.insert("<unk>".to_string(), config.unk_id);
        token_to_id.insert("<s>".to_string(), config.bos_id);
        token_to_id.insert("</s>".to_string(), config.eos_id);
        token_to_id.insert("<pad>".to_string(), config.pad_id);

        id_to_token.insert(config.unk_id, "<unk>".to_string());
        id_to_token.insert(config.bos_id, "<s>".to_string());
        id_to_token.insert(config.eos_id, "</s>".to_string());
        id_to_token.insert(config.pad_id, "<pad>".to_string());

        // Printable ASCII (U+0020 space through U+007E tilde).
        let mut next_id = 4i64;
        for c in ' '..='~' {
            char_vocab.insert(c, next_id);
            token_to_id.insert(c.to_string(), next_id);
            id_to_token.insert(next_id, c.to_string());
            next_id += 1;
        }

        // CJK Unified Ideographs (U+4E00..=U+9FFF), stopping once the
        // configured vocabulary size is reached.
        for code_point in 0x4E00u32..=0x9FFF {
            if let Some(c) = char::from_u32(code_point) {
                char_vocab.insert(c, next_id);
                token_to_id.insert(c.to_string(), next_id);
                id_to_token.insert(next_id, c.to_string());
                next_id += 1;

                if next_id >= config.vocab_size as i64 {
                    break;
                }
            }
        }

        Ok(Self {
            config,
            token_to_id,
            id_to_token,
            char_vocab,
        })
    }
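    /// Creates a tokenizer whose config points at `path`, returning
    /// `Error::FileNotFound` if the file does not exist.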
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path = path.as_ref();
        if !path.exists() {
            return Err(Error::FileNotFound(path.display().to_string()));
        }

        // The model file is recorded in the config but not parsed here; the
        // tokenizer falls back to its built-in character vocabulary.
        let config = TokenizerConfig {
            model_path: path.display().to_string(),
            ..Default::default()
        };

        Self::new(config)
    }
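    /// Encodes `text` as `[BOS, chars..., EOS]`; characters outside the
    /// vocabulary become `<unk>`.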
    pub fn encode(&self, text: &str) -> Result<Vec<i64>> {
        let mut tokens = Vec::new();

        tokens.push(self.config.bos_id);

        for ch in text.chars() {
            if let Some(&id) = self.char_vocab.get(&ch) {
                tokens.push(id);
            } else if let Some(&id) = self.token_to_id.get(&ch.to_string()) {
                tokens.push(id);
            } else {
                // Out-of-vocabulary characters map to <unk>.
                tokens.push(self.config.unk_id);
            }
        }

        tokens.push(self.config.eos_id);

        Ok(tokens)
    }
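    /// Encodes `text` without adding BOS/EOS markers.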
    pub fn encode_without_special(&self, text: &str) -> Result<Vec<i64>> {
        let mut tokens = Vec::new();

        for ch in text.chars() {
            if let Some(&id) = self.char_vocab.get(&ch) {
                tokens.push(id);
            } else if let Some(&id) = self.token_to_id.get(&ch.to_string()) {
                tokens.push(id);
            } else {
                tokens.push(self.config.unk_id);
            }
        }

        Ok(tokens)
    }
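    /// Decodes `tokens` back into text, skipping BOS/EOS/PAD ids.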
    pub fn decode(&self, tokens: &[i64]) -> Result<String> {
        let mut text = String::new();

        for &token_id in tokens {
            // Skip structural special tokens.
            if token_id == self.config.bos_id
                || token_id == self.config.eos_id
                || token_id == self.config.pad_id
            {
                continue;
            }

            if let Some(token) = self.id_to_token.get(&token_id) {
                text.push_str(token);
            } else {
                // Ids outside the vocabulary decode to a visible placeholder.
                text.push('?');
            }
        }

        Ok(text)
    }
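    /// Total vocabulary size, including special tokens.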
    pub fn vocab_size(&self) -> usize {
        self.config.vocab_size
    }

    /// Beginning-of-sequence token id.
    pub fn bos_id(&self) -> i64 {
        self.config.bos_id
    }

    /// End-of-sequence token id.
    pub fn eos_id(&self) -> i64 {
        self.config.eos_id
    }

    /// Unknown-token id.
    pub fn unk_id(&self) -> i64 {
        self.config.unk_id
    }

    /// Padding token id.
    pub fn pad_id(&self) -> i64 {
        self.config.pad_id
    }
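    /// Pads (or truncates) each sequence to `max_len`, defaulting to the
    /// length of the longest sequence in the batch.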
    pub fn pad_sequences(&self, sequences: &[Vec<i64>], max_len: Option<usize>) -> Vec<Vec<i64>> {
        let max_length = max_len
            .unwrap_or_else(|| sequences.iter().map(|s| s.len()).max().unwrap_or(0));

        sequences
            .iter()
            .map(|seq| {
                let mut padded = seq.clone();
                while padded.len() < max_length {
                    padded.push(self.config.pad_id);
                }
                padded.truncate(max_length);
                padded
            })
            .collect()
    }
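    /// Returns a mask with 1 for real tokens and 0 for padding.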
    pub fn create_attention_mask(&self, tokens: &[i64]) -> Vec<i64> {
        tokens
            .iter()
            .map(|&t| if t == self.config.pad_id { 0 } else { 1 })
            .collect()
    }
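    /// Encodes each text in `texts`, with BOS/EOS markers.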
    pub fn batch_encode(&self, texts: &[&str]) -> Result<Vec<Vec<i64>>> {
        texts.iter().map(|text| self.encode(text)).collect()
    }
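    /// Encodes a batch and pads all sequences to a common length.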
    pub fn batch_encode_padded(
        &self,
        texts: &[&str],
        max_len: Option<usize>,
    ) -> Result<Vec<Vec<i64>>> {
        let encoded: Vec<Vec<i64>> = self.batch_encode(texts)?;
        Ok(self.pad_sequences(&encoded, max_len))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenizer_creation() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();
        assert!(tokenizer.vocab_size() > 0);
    }

    #[test]
    fn test_encode_decode() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        let text = "Hello world";
        let tokens = tokenizer.encode(text).unwrap();

        // Encoded sequences are wrapped in BOS/EOS markers.
        assert_eq!(tokens[0], tokenizer.bos_id());
        assert_eq!(*tokens.last().unwrap(), tokenizer.eos_id());

        let decoded = tokenizer.decode(&tokens).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_encode_chinese() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        let text = "你好";
        let tokens = tokenizer.encode(text).unwrap();

        // BOS + two characters + EOS.
        assert_eq!(tokens.len(), 4);
    }

    #[test]
    fn test_pad_sequences() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        let seq1 = vec![1, 2, 3];
        let seq2 = vec![1, 2, 3, 4, 5];

        let padded = tokenizer.pad_sequences(&[seq1, seq2], None);

        assert_eq!(padded[0].len(), 5);
        assert_eq!(padded[1].len(), 5);
        assert_eq!(padded[0][3], tokenizer.pad_id());
    }

    #[test]
    fn test_attention_mask() {
        let config = TokenizerConfig::default();
        let tokenizer = TextTokenizer::new(config).unwrap();

        let tokens = vec![1, 2, tokenizer.pad_id(), tokenizer.pad_id()];
        let mask = tokenizer.create_attention_mask(&tokens);

        assert_eq!(mask, vec![1, 1, 0, 0]);
    }
}