//! Model submodules (GPT, embedding encoders, ONNX session handling) with their
//! public types re-exported, plus the shared logit-sampling utilities used when
//! decoding tokens.

mod gpt;
mod embedding;
mod session;

pub use gpt::{GptModel, GptConfig};
pub use embedding::{SpeakerEncoder, EmotionEncoder, SemanticEncoder};
pub use session::{OnnxSession, ModelCache};

/// Strategy used to pick the next token from a logit distribution.
#[derive(Debug, Clone)]
pub enum SamplingStrategy {
    /// Always pick the highest-scoring token (argmax).
    Greedy,
    /// Sample from the `k` highest-scoring tokens.
    TopK { k: usize },
    /// Nucleus sampling: sample from the smallest set of tokens whose
    /// cumulative probability reaches `p`.
    TopP { p: f32 },
    /// Top-k filtering followed by nucleus (top-p) filtering.
    TopKP { k: usize, p: f32 },
    /// Sample from the full distribution after scaling logits by `1 / temp`.
    Temperature { temp: f32 },
}

impl Default for SamplingStrategy {
    fn default() -> Self {
        SamplingStrategy::TopKP { k: 50, p: 0.95 }
    }
}

/// Select a token index from `logits` according to `strategy`.
///
/// The top-k / top-p variants softmax over the retained logits and draw from the
/// resulting categorical distribution; `Greedy` is a plain argmax.
pub fn sample_from_logits(logits: &[f32], strategy: &SamplingStrategy) -> usize {
    // Degenerate input: fall back to token 0 rather than panicking below
    // (matches the greedy fallback).
    if logits.is_empty() {
        return 0;
    }

    match strategy {
        SamplingStrategy::Greedy => {
            // Argmax over the raw logits.
            logits
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(i, _)| i)
                .unwrap_or(0)
        }
        SamplingStrategy::TopK { k } => {
            // Keep only the k highest-scoring tokens (assumes k >= 1).
            let mut indexed: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
            indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
            indexed.truncate(*k);

            // Softmax over the surviving logits (subtract the max for numerical stability).
            let max_logit = indexed[0].1;
            let exp_sum: f32 = indexed.iter().map(|(_, l)| (l - max_logit).exp()).sum();
            let probs: Vec<f32> = indexed
                .iter()
                .map(|(_, l)| (l - max_logit).exp() / exp_sum)
                .collect();

            sample_categorical(&indexed.iter().map(|(i, _)| *i).collect::<Vec<_>>(), &probs)
        }
        SamplingStrategy::TopP { p } => {
            // Sort all tokens by descending logit.
            let mut indexed: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
            indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());

            // Softmax over the sorted logits (subtract the max for numerical stability).
            let max_logit = indexed[0].1;
            let exp_sum: f32 = indexed.iter().map(|(_, l)| (l - max_logit).exp()).sum();
            let probs: Vec<f32> = indexed
                .iter()
                .map(|(_, l)| (l - max_logit).exp() / exp_sum)
                .collect();

            // Find the smallest prefix (the nucleus) whose cumulative probability reaches p.
            let mut cumsum = 0.0;
            let mut nucleus_size = probs.len();
            for (i, prob) in probs.iter().enumerate() {
                cumsum += prob;
                if cumsum >= *p {
                    nucleus_size = i + 1;
                    break;
                }
            }

            // Renormalize within the nucleus and sample from it.
            let nucleus_sum: f32 = probs[..nucleus_size].iter().sum();
            let nucleus_probs: Vec<f32> = probs[..nucleus_size]
                .iter()
                .map(|prob| prob / nucleus_sum)
                .collect();

            sample_categorical(
                &indexed[..nucleus_size]
                    .iter()
                    .map(|(i, _)| *i)
                    .collect::<Vec<_>>(),
                &nucleus_probs,
            )
        }
        SamplingStrategy::TopKP { k, p } => {
            // Top-k filter first: keep only the k highest-scoring tokens (assumes k >= 1).
            let mut indexed: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
            indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
            indexed.truncate(*k);

            // Softmax over the surviving logits (subtract the max for numerical stability).
            let max_logit = indexed[0].1;
            let exp_sum: f32 = indexed.iter().map(|(_, l)| (l - max_logit).exp()).sum();
            let probs: Vec<f32> = indexed
                .iter()
                .map(|(_, l)| (l - max_logit).exp() / exp_sum)
                .collect();

            // Then top-p: find the smallest prefix whose cumulative probability reaches p.
            let mut cumsum = 0.0;
            let mut nucleus_size = probs.len();
            for (i, prob) in probs.iter().enumerate() {
                cumsum += prob;
                if cumsum >= *p {
                    nucleus_size = i + 1;
                    break;
                }
            }

            // Renormalize within the nucleus and sample from it.
            let nucleus_sum: f32 = probs[..nucleus_size].iter().sum();
            let nucleus_probs: Vec<f32> = probs[..nucleus_size]
                .iter()
                .map(|prob| prob / nucleus_sum)
                .collect();

            sample_categorical(
                &indexed[..nucleus_size]
                    .iter()
                    .map(|(i, _)| *i)
                    .collect::<Vec<_>>(),
                &nucleus_probs,
            )
        }
        SamplingStrategy::Temperature { temp } => {
            // Scale logits by 1/temp (assumes temp > 0), then softmax over the full vocabulary.
            let scaled: Vec<f32> = logits.iter().map(|l| l / temp).collect();
            let max_logit = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp_sum: f32 = scaled.iter().map(|l| (l - max_logit).exp()).sum();
            let probs: Vec<f32> = scaled
                .iter()
                .map(|l| (l - max_logit).exp() / exp_sum)
                .collect();

            sample_categorical(&(0..probs.len()).collect::<Vec<_>>(), &probs)
        }
    }
}

/// Draw one entry from `indices` according to the categorical distribution `probs`,
/// using inverse-CDF sampling against a single uniform draw.
fn sample_categorical(indices: &[usize], probs: &[f32]) -> usize {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    let r: f32 = rng.gen();

    // Walk the cumulative distribution until it passes the uniform draw.
    let mut cumsum = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if r <= cumsum {
            return indices[i];
        }
    }

    // Floating-point rounding can leave the cumulative sum just below 1.0;
    // fall back to the last candidate in that case.
    indices[indices.len() - 1]
}

/// Penalize tokens already present in `previous_tokens`: positive logits are divided
/// by `penalty` and negative logits are multiplied by it, so a `penalty` above 1.0
/// makes repeated tokens less likely.
pub fn apply_repetition_penalty(logits: &mut [f32], previous_tokens: &[usize], penalty: f32) {
    for &token in previous_tokens {
        if token < logits.len() {
            if logits[token] > 0.0 {
                logits[token] /= penalty;
            } else {
                logits[token] *= penalty;
            }
        }
    }
}

/// Numerically stable softmax: the maximum logit is subtracted before exponentiating.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp_sum: f32 = logits.iter().map(|l| (l - max_logit).exp()).sum();
    logits
        .iter()
        .map(|l| (l - max_logit).exp() / exp_sum)
        .collect()
}

/// Numerically stable log-softmax, computed as `x - max(x) - ln(sum(exp(x - max(x))))`.
pub fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp_sum: f32 = logits.iter().map(|l| (l - max_logit).exp()).sum();
    let log_sum = exp_sum.ln();
    logits.iter().map(|l| l - max_logit - log_sum).collect()
}
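
// Minimal smoke tests sketching how the sampling helpers above are meant to be used.
// They only exercise behavior visible in this module (argmax for `Greedy`, softmax
// normalization, the repetition-penalty sign rule, and top-k restriction) and assume
// nothing beyond the `rand` crate already used by `sample_categorical`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn greedy_picks_argmax() {
        let logits = vec![0.1, 2.0, -1.0, 0.5];
        assert_eq!(sample_from_logits(&logits, &SamplingStrategy::Greedy), 1);
    }

    #[test]
    fn softmax_sums_to_one() {
        let probs = softmax(&[1.0, 2.0, 3.0]);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn repetition_penalty_follows_sign_rule() {
        let mut logits = vec![1.0, -1.0, 0.5];
        apply_repetition_penalty(&mut logits, &[0, 1], 2.0);
        assert_eq!(logits[0], 0.5); // positive logit divided by the penalty
        assert_eq!(logits[1], -2.0); // negative logit multiplied by the penalty
        assert_eq!(logits[2], 0.5); // token not in the history is untouched
    }

    #[test]
    fn top_k_only_returns_high_scoring_tokens() {
        let logits = vec![10.0, -50.0, 9.0, -50.0];
        for _ in 0..20 {
            let idx = sample_from_logits(&logits, &SamplingStrategy::TopK { k: 2 });
            assert!(idx == 0 || idx == 2);
        }
    }
}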