# coding=utf-8 # Copyright (c) 2025, Qwerky AI, Inc. All rights reserved. # # Licensed under the Qwerky Distilled Model License Agreement (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # See the LICENSE file in this repository # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch QwerkyLlamaMambaHybrid model for inference only.""" from typing import Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import CrossEntropyLoss from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.modeling_outputs import CausalLMOutput from mamba_ssm.ops.triton.layer_norm import RMSNorm from mamba_ssm.modules.mha import MHA from mamba_ssm.utils.generation import GenerationMixin as MambaGenerationMixin from transformers.activations import ACT2FN # Import Mamba dependencies import math import torch.nn.functional as F from einops import rearrange, repeat from mamba_ssm.ops.selective_scan_interface import selective_scan_fn try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update except ImportError: causal_conv1d_fn, causal_conv1d_update = None, None try: from mamba_ssm.ops.triton.selective_state_update import selective_state_update except ImportError: selective_state_update = None from .configuration_qwerky_llama_mamba_hybrid import QwerkyLlamaMambaHybridConfig logger = logging.get_logger(__name__) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) # Mamba class implementation (included directly for standalone HuggingFace repo) class Mamba(nn.Module): def __init__( self, d_model, d_inner, d_xb, d_state=16, d_conv=4, expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, repeat_kv_before_conv=True, conv_bias=True, proj_x_bias=False, proj_z_bias=False, out_proj_bias=False, use_fast_path=True, layer_idx=None, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.d_model = d_model self.d_xb = d_xb self.d_state = d_state self.d_conv = d_conv self.expand = expand self.d_inner = ( d_inner if d_inner is not None else int(self.expand * self.d_model) ) self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank self.use_fast_path = use_fast_path self.layer_idx = layer_idx self.repeat_kv_before_conv = repeat_kv_before_conv if self.repeat_kv_before_conv: self.conv1d = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, bias=conv_bias, kernel_size=d_conv, groups=self.d_inner, padding=d_conv - 1, **factory_kwargs, ) else: self.conv1d = nn.Conv1d( in_channels=self.d_xb, out_channels=self.d_xb, bias=conv_bias, kernel_size=d_conv, groups=self.d_xb, padding=d_conv - 1, **factory_kwargs, ) self.activation = "silu" self.act = nn.SiLU() self.num_xb_head = self.d_xb // self.d_state self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head self.in_proj = nn.Linear( self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=False, **factory_kwargs, ) self.dt_proj = nn.Linear( self.dt_rank, self.d_inner, bias=True, **factory_kwargs ) # Initialize special dt projection to preserve variance at initialization dt_init_std = self.dt_rank**-0.5 * dt_scale if dt_init == "constant": nn.init.constant_(self.dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): self.dt_proj.bias.copy_(inv_dt) # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit self.dt_proj.bias._no_reinit = True # S4D real initialization A = repeat( torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=self.d_inner, ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True # D "skip" parameter self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 self.D._no_weight_decay = True self.out_proj = nn.Linear( self.d_inner, self.d_model, bias=out_proj_bias, **factory_kwargs ) def forward(self, hidden_states, inference_params=None): """ hidden_states: (B, L, D) Returns: same shape as hidden_states """ batch, seqlen, dim = hidden_states.shape conv_state, ssm_state = None, None if inference_params is not None: conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) if inference_params.seqlen_offset > 0: # The states are updated inplace out, _, _ = self.step(hidden_states, conv_state, ssm_state) return out A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # Optimize: Ensure input is contiguous for better performance if not hidden_states.is_contiguous(): hidden_states = hidden_states.contiguous() zxbcdt = self.in_proj(hidden_states) z, x, B, C, dt = torch.split( zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1, ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") B = rearrange( B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state ) B = repeat_kv(B, self.repeat_group) # B, n_group, L, H B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange( C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state ).contiguous() dt = self.dt_proj(dt) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: # b d l x = rearrange( x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state ) x = repeat_kv(x, self.repeat_group) x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") # Compute short convolution # Optimize: Only update state if we need it for next step (during generation) # During prompt processing, we can skip state update if not needed need_state_update = conv_state is not None if need_state_update: # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. # Update state (B D W) conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) if causal_conv1d_fn is None: x = self.act(self.conv1d(x)[..., :seqlen]) else: assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( x=x, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, ) if not self.repeat_kv_before_conv: x = rearrange( x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state ) x = repeat_kv(x, self.repeat_group) x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") assert self.activation in ["silu", "swish"] # Optimize: Only return last_state if we need to update ssm_state return_last_state = ssm_state is not None y = selective_scan_fn( x, dt, A, B, C, self.D.float(), z=z, delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=return_last_state, ) if return_last_state: y, last_state = y # ssm_state.copy_(last_state.unsqueeze(-2)) ssm_state.copy_( rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head) ) y = rearrange(y, "b d l -> b l d") out = self.out_proj(y) return out def step(self, hidden_states, conv_state, ssm_state): dtype = hidden_states.dtype assert hidden_states.shape[1] == 1, ( "Only support decoding with 1 token at a time for now" ) hidden_states_input = hidden_states.squeeze(1) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) zxbcdt = self.in_proj(hidden_states_input) z, x, B, C, dt = torch.split( zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1, ) B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group) C = rearrange( C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state ).contiguous() dt = self.dt_proj(dt) # B, d_inner if self.repeat_kv_before_conv: x = rearrange( x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state ) x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group) x = rearrange(x, "b n_group dstate -> b (n_group dstate)") # Conv step if causal_conv1d_update is None: # Update state (B D W) conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) conv_state[:, :, -1] = x x = torch.sum( conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 ) # (B D) if self.conv1d.bias is not None: x = x + self.conv1d.bias x = self.act(x).to(dtype=dtype) else: x = causal_conv1d_update( x, conv_state, rearrange(self.conv1d.weight, "d 1 w -> d w"), self.conv1d.bias, self.activation, ) if not self.repeat_kv_before_conv: x = rearrange( x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state ) x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group) x = rearrange(x, "b n_group dstate -> b (n_group dstate)") x = rearrange(x, "b (h d) -> b h d", h=self.num_C_head) dt = rearrange(dt, "b (h d) -> b h d", h=self.num_C_head) A = rearrange(A, "(h d) n -> h d n", h=self.num_C_head) D = rearrange(self.D, "(h d) -> h d", h=self.num_C_head) z = rearrange(z, "b (h d) -> b h d", h=self.num_C_head) dt_bias = rearrange(self.dt_proj.bias, "(h d) -> h d", h=self.num_C_head) # SSM step assert selective_state_update is not None y = selective_state_update( ssm_state, x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=True ) y = rearrange(y, "b h d -> b (h d)") out = self.out_proj(y) return out.unsqueeze(1), conv_state, ssm_state def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): device = self.out_proj.weight.device conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype if self.repeat_kv_before_conv: conv_state = torch.zeros( batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype ) else: conv_state = torch.zeros( batch_size, self.d_xb, self.d_conv, device=device, dtype=conv_dtype ) ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype ssm_state = torch.zeros( batch_size, self.num_C_head, self.d_inner // self.num_C_head, self.d_state, device=device, dtype=ssm_dtype, ) return conv_state, ssm_state def _get_states_from_cache( self, inference_params, batch_size, initialize_states=False ): assert self.layer_idx is not None if self.layer_idx not in inference_params.key_value_memory_dict: if self.repeat_kv_before_conv: conv_state = torch.zeros( batch_size, self.d_inner, self.d_conv, device=self.conv1d.weight.device, dtype=self.conv1d.weight.dtype, ) else: conv_state = torch.zeros( batch_size, self.d_xb, self.d_conv, device=self.conv1d.weight.device, dtype=self.conv1d.weight.dtype, ) ssm_state = torch.zeros( batch_size, self.num_C_head, self.d_inner // self.num_C_head, self.d_state, device=self.dt_proj.weight.device, dtype=self.dt_proj.weight.dtype, ) inference_params.key_value_memory_dict[self.layer_idx] = ( conv_state, ssm_state, ) else: conv_state, ssm_state = inference_params.key_value_memory_dict[ self.layer_idx ] if initialize_states: conv_state.zero_() ssm_state.zero_() return conv_state, ssm_state class MLP(nn.Module): def __init__(self, d_model, intermediate_size, hidden_act, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.hidden_size = d_model self.intermediate_size = intermediate_size self.gate_proj = nn.Linear( self.hidden_size, self.intermediate_size, bias=False, **factory_kwargs ) self.up_proj = nn.Linear( self.hidden_size, self.intermediate_size, bias=False, **factory_kwargs ) self.down_proj = nn.Linear( self.intermediate_size, self.hidden_size, bias=False, **factory_kwargs ) self.act_fn = ACT2FN[hidden_act] def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) class MHADecoderLayer(nn.Module): def __init__( self, config: QwerkyLlamaMambaHybridConfig, layer_idx: int, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super(MHADecoderLayer, self).__init__() self.layer_idx = layer_idx self.mha = MHA( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, num_heads_kv=config.num_key_value_heads, layer_idx=layer_idx, mlp_dim=0, qkv_proj_bias=False, out_proj_bias=False, rotary_emb_dim=config.hidden_size // config.num_attention_heads, rotary_emb_base=config.rope_theta, causal=True, device=device, dtype=dtype, ) self.mlp = MLP( config.hidden_size, config.intermediate_size, config.hidden_act, **factory_kwargs, ) self.input_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs ) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs ) self.residual_in_fp32 = True def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.mha.allocate_inference_cache( batch_size, max_seqlen, dtype=dtype, **kwargs ) def forward( self, hidden_states: torch.Tensor, inference_params=None, *args, **kwargs ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.mha(hidden_states, inference_params) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class MambaDecoderLayer(nn.Module): def __init__( self, config: QwerkyLlamaMambaHybridConfig, layer_idx: int, device=None, dtype=None ): factory_kwargs = {"device": device, "dtype": dtype} super(MambaDecoderLayer, self).__init__() self.layer_idx = layer_idx # Create Mamba layer with config parameters self.mamba = Mamba( d_model=config.d_model, d_inner=config.d_inner, d_xb=config.d_xb, layer_idx=layer_idx, **config.ssm_cfg, **factory_kwargs, ) self.mlp = MLP( config.d_model, config.intermediate_size, config.hidden_act, **factory_kwargs, ) self.input_layernorm = RMSNorm( config.d_model, eps=config.rms_norm_eps, **factory_kwargs ) self.post_attention_layernorm = RMSNorm( config.d_model, eps=config.rms_norm_eps, **factory_kwargs ) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.mamba.allocate_inference_cache( batch_size, max_seqlen, dtype=dtype, **kwargs ) def forward( self, hidden_states: torch.Tensor, inference_params=None, *args, **kwargs ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.mamba(hidden_states, inference_params=inference_params) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states def merge_projections_for_layers(checkpoint, layer_indices): """Merge q_proj, k_proj, v_proj into in_proj for attention layers.""" for layer_idx in layer_indices: q_proj_key = f"model.layers.{layer_idx}.self_attn.q_proj.weight" k_proj_key = f"model.layers.{layer_idx}.self_attn.k_proj.weight" v_proj_key = f"model.layers.{layer_idx}.self_attn.v_proj.weight" o_proj_key = f"model.layers.{layer_idx}.self_attn.o_proj.weight" if ( q_proj_key in checkpoint and k_proj_key in checkpoint and v_proj_key in checkpoint ): q_proj_weight = checkpoint[q_proj_key] k_proj_weight = checkpoint[k_proj_key] v_proj_weight = checkpoint[v_proj_key] in_proj_weight = torch.cat( [q_proj_weight, k_proj_weight, v_proj_weight], dim=0 ) in_proj_key = f"model.layers.{layer_idx}.mha.in_proj.weight" checkpoint[in_proj_key] = in_proj_weight del checkpoint[q_proj_key] del checkpoint[k_proj_key] del checkpoint[v_proj_key] if o_proj_key in checkpoint: out_proj_key = f"model.layers.{layer_idx}.mha.out_proj.weight" checkpoint[out_proj_key] = checkpoint[o_proj_key] del checkpoint[o_proj_key] return checkpoint class QwerkyLlamaMambaHybridPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = QwerkyLlamaMambaHybridConfig base_model_prefix = "model" supports_gradient_checkpointing = False _no_split_modules = ["MambaDecoderLayer", "MHADecoderLayer"] _supports_flash_attn_2 = True def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) class QwerkyLlamaMambaHybridModel(QwerkyLlamaMambaHybridPreTrainedModel): """ The bare QwerkyLlamaMambaHybrid Model transformer outputting raw hidden-states without any specific head on top. """ def __init__(self, config: QwerkyLlamaMambaHybridConfig, **kwargs): super().__init__(config, **kwargs) self.config = config self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList( [ MHADecoderLayer(config, layer_idx, device=None, dtype=None) if layer_idx in config.attn_layers else MambaDecoderLayer(config, layer_idx, device=None, dtype=None) for layer_idx in range(config.num_hidden_layers) ] ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Register hook to transform state dict keys before loading # This merges q_proj/k_proj/v_proj into mha.in_proj.weight for attention layers self._register_load_state_dict_pre_hook(self.load_hook) self.post_init() def load_hook(self, state_dict, prefix, *args): """Transform state dict keys: merge q_proj/k_proj/v_proj into mha.in_proj.weight for attention layers.""" if self.config.attn_layers: merge_projections_for_layers(state_dict, self.config.attn_layers) def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, inference_params=None, num_last_tokens: int = 0, **kwargs, ): if input_ids is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time" ) if input_ids is None and inputs_embeds is None: raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds # Optimize: Ensure hidden_states is contiguous for better memory access patterns if not hidden_states.is_contiguous(): hidden_states = hidden_states.contiguous() for layer in self.layers: hidden_states = layer( hidden_states, inference_params=inference_params, **kwargs ) # Optimize: Keep hidden_states contiguous between layers if not hidden_states.is_contiguous(): hidden_states = hidden_states.contiguous() hidden_states = self.norm(hidden_states) if num_last_tokens > 0: hidden_states = hidden_states[:, -num_last_tokens:] return hidden_states def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): """Allocate inference cache for all layers.""" return { i: layer.allocate_inference_cache( batch_size, max_seqlen, dtype=dtype, **kwargs ) for i, layer in enumerate(self.layers) } class QwerkyLlamaMambaHybridForCausalLM( QwerkyLlamaMambaHybridPreTrainedModel, MambaGenerationMixin ): """ The QwerkyLlamaMambaHybrid Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). """ _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: QwerkyLlamaMambaHybridConfig, **kwargs): super().__init__(config, **kwargs) self.model = QwerkyLlamaMambaHybridModel(config, **kwargs) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Tie weights if configured if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight # Cache device to avoid repeated next(self.parameters()).device calls self._cached_device = None self.post_init() def get_input_embeddings(self): return self.model.get_input_embeddings() def set_input_embeddings(self, value): self.model.set_input_embeddings(value) def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, inference_params=None, num_last_tokens: int = 0, **kwargs, ) -> Union[Tuple, CausalLMOutput]: """ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ # Optimize TTFT: During prefill (prompt processing), only compute logits for the last token # This saves computation in lm_head since we only need the last token's logits to generate the first token # Conditions: not training (labels is None), in prefill phase (seqlen_offset == 0 or None), and num_last_tokens not explicitly set is_prefill = ( labels is None # Not in training mode and ( inference_params is None or getattr(inference_params, "seqlen_offset", 0) == 0 ) # Prefill phase and num_last_tokens == 0 # Not explicitly set by caller ) if is_prefill: num_last_tokens = 1 hidden_states = self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, position_ids=position_ids, inference_params=inference_params, num_last_tokens=num_last_tokens, **kwargs, ) logits = self.lm_head(hidden_states) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) return CausalLMOutput( loss=loss, logits=logits, ) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): """Allocate inference cache for all layers.""" return self.model.allocate_inference_cache( batch_size, max_seqlen, dtype=dtype, **kwargs ) def generate( self, input_ids, max_length=1024, top_k=50, top_p=1.0, min_p=0.0, temperature=1.0, repetition_penalty=1.0, return_dict_in_generate=False, output_scores=False, **kwargs, ): """ Generate sequences using the model. Supports all standard Transformers generation parameters including: - do_sample, temperature, top_k, top_p, repetition_penalty - attention_mask, pad_token_id, eos_token_id - max_new_tokens, use_cache, and more """ # Ensure input_ids is properly shaped (2D: batch_size, seq_len) if input_ids.dim() == 1: input_ids = input_ids.unsqueeze(0) # Add batch dimension # Optimize: Cache device to avoid repeated next(self.parameters()).device calls if self._cached_device is None: self._cached_device = next(self.parameters()).device device = self._cached_device # Ensure input_ids is on the correct device and dtype for generation # MambaGenerationMixin expects input_ids to match the model's device if input_ids.device != device: input_ids = input_ids.to(device) # Ensure input_ids is long/int64 dtype (required for token IDs) if input_ids.dtype != torch.long: input_ids = input_ids.long() # Get batch_size early for cache pre-allocation batch_size = input_ids.shape[0] if kwargs is not None: max_new_tokens = kwargs.pop("max_new_tokens", None) if max_new_tokens is not None: max_length = max_new_tokens + input_ids.shape[1] do_sample = kwargs.pop("do_sample", True) if not do_sample: top_k, top_p, min_p = 1, 0.0, 0.0 cg = kwargs.pop("cg", True) eos_token_id = kwargs.pop("eos_token_id", self.config.eos_token_id) # Convert eos_token_id to tensor to ensure compatibility with mamba_ssm tensor comparisons if eos_token_id is not None: if isinstance(eos_token_id, (list, tuple)): eos_token_id = torch.tensor( eos_token_id, dtype=torch.long, device=device ) else: eos_token_id = torch.tensor( [eos_token_id], dtype=torch.long, device=device ) attention_mask = kwargs.pop("attention_mask", None) pad_token_id = kwargs.pop( "pad_token_id", getattr(self.config, "pad_token_id", None) ) # Optimize: Handle attention_mask more efficiently # Skip expensive filtering if attention_mask is None or all ones if attention_mask is not None: # Fast path: Check if all sequences are fully valid (all ones) if attention_mask.all(): # No filtering needed, just ensure contiguous input_ids = input_ids.contiguous() else: # Vectorized filtering: get sequence lengths and max length seq_lengths = attention_mask.sum(dim=1) # (batch_size,) max_seq_len = seq_lengths.max().item() min_seq_len = seq_lengths.min().item() original_seq_len = input_ids.shape[1] # Fast path: if all sequences are the same length, just slice if min_seq_len == max_seq_len and max_seq_len <= original_seq_len: input_ids = input_ids[:, :max_seq_len].contiguous() else: # Fully vectorized approach: create padded tensor and copy sequences batch_size = input_ids.shape[0] dtype = input_ids.dtype pad_value = pad_token_id if pad_token_id is not None else 0 # Create output tensor filled with pad_value (single vectorized operation) input_ids_filtered = torch.full( (batch_size, max_seq_len), pad_value, dtype=dtype, device=device, ) # Only copy up to the original sequence length to avoid out-of-bounds access copy_len = min(max_seq_len, original_seq_len) if copy_len > 0: # Create a mask for valid positions (vectorized) # Shape: (batch_size, copy_len) - True where we should copy from input_ids valid_mask = torch.arange( copy_len, device=device ).unsqueeze(0) < seq_lengths.unsqueeze(1) # Copy valid positions using PyTorch masking operations # Use .contiguous() to ensure proper memory layout input_ids_slice = input_ids[:, :copy_len].contiguous() input_ids_filtered_slice = input_ids_filtered[:, :copy_len] # Use torch.where for safe vectorized copying # valid_mask broadcasts automatically: (batch_size, copy_len) -> (batch_size, copy_len) input_ids_filtered[:, :copy_len] = torch.where( valid_mask, input_ids_slice, input_ids_filtered_slice ) input_ids = input_ids_filtered.contiguous() # Use repetition_penalty from parameter or kwargs (supported by decode function) repetition_penalty = kwargs.pop("repetition_penalty", repetition_penalty) # Extract other parameters that might be passed but not used by MambaGenerationMixin # These are popped from kwargs to avoid passing them to the parent generate() method use_cache = kwargs.pop( "use_cache", None ) # Not supported by MambaGenerationMixin no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", None) length_penalty = kwargs.pop("length_penalty", None) num_return_sequences = kwargs.pop("num_return_sequences", None) num_beams = kwargs.pop("num_beams", None) low_memory = kwargs.pop("low_memory", None) stopping_criteria = kwargs.pop("stopping_criteria", None) # Optimize TTFT: Pre-allocate inference cache before generation starts # This avoids allocation overhead during the first forward pass # Calculate max_seqlen: use max_length (which includes prompt + generation length) max_seqlen = max_length # Pre-allocate cache - this allocates memory upfront, reducing latency during generation # The cache will be used by MambaGenerationMixin internally # Note: We pre-allocate even if it's not directly passed, as it warms up memory allocator try: # Get model dtype for cache allocation model_dtype = next(self.parameters()).dtype # Pre-allocate cache - this is a warm-up allocation that helps with memory timing _ = self.allocate_inference_cache( batch_size=batch_size, max_seqlen=max_seqlen, dtype=model_dtype, ) except Exception: # If allocation fails, continue without pre-allocation # This shouldn't happen, but we don't want to break generation pass return super().generate( input_ids=input_ids, max_length=max_length, cg=cg, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, repetition_penalty=repetition_penalty, return_dict_in_generate=return_dict_in_generate, output_scores=output_scores, eos_token_id=eos_token_id, **kwargs, )