# coding=utf-8
# Copyright (c) 2025, Qwerky AI, Inc. All rights reserved.
#
# Licensed under the Qwerky Distilled Model License Agreement (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# See the LICENSE file in this repository
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch QwerkyLlamaMambaHybrid model for inference only."""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.modeling_outputs import CausalLMOutput
from mamba_ssm.ops.triton.layer_norm import RMSNorm
from mamba_ssm.modules.mha import MHA
from mamba_ssm.utils.generation import GenerationMixin as MambaGenerationMixin
from transformers.activations import ACT2FN
# Import Mamba dependencies
import math
import torch.nn.functional as F
from einops import rearrange, repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
from .configuration_qwerky_llama_mamba_hybrid import QwerkyLlamaMambaHybridConfig
logger = logging.get_logger(__name__)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# Mamba class implementation (included directly for standalone HuggingFace repo)
class Mamba(nn.Module):
def __init__(
self,
d_model,
d_inner,
d_xb,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
repeat_kv_before_conv=True,
conv_bias=True,
proj_x_bias=False,
proj_z_bias=False,
out_proj_bias=False,
use_fast_path=True,
layer_idx=None,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_xb = d_xb
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = (
d_inner if d_inner is not None else int(self.expand * self.d_model)
)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.repeat_kv_before_conv = repeat_kv_before_conv
if self.repeat_kv_before_conv:
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1,
**factory_kwargs,
)
else:
self.conv1d = nn.Conv1d(
in_channels=self.d_xb,
out_channels=self.d_xb,
bias=conv_bias,
kernel_size=d_conv,
groups=self.d_xb,
padding=d_conv - 1,
**factory_kwargs,
)
self.activation = "silu"
self.act = nn.SiLU()
self.num_xb_head = self.d_xb // self.d_state
self.num_C_head = self.d_inner // self.d_state
self.repeat_group = self.num_C_head // self.num_xb_head
self.in_proj = nn.Linear(
self.d_model,
2 * self.d_xb + 2 * self.d_inner + self.dt_rank,
bias=False,
**factory_kwargs,
)
self.dt_proj = nn.Linear(
self.dt_rank, self.d_inner, bias=True, **factory_kwargs
)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = self.dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(self.d_inner, **factory_kwargs)
* (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
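        # Why this works: exp(inv_dt) = exp(dt) - 1, so F.softplus(inv_dt) = log(1 + exp(inv_dt)) = dt;
        # the bias therefore reproduces the sampled dt, which lies in [dt_min, dt_max].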
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_proj.bias._no_reinit = True
# S4D real initialization
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
# D "skip" parameter
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
self.D._no_weight_decay = True
self.out_proj = nn.Linear(
self.d_inner, self.d_model, bias=out_proj_bias, **factory_kwargs
)
def forward(self, hidden_states, inference_params=None):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
batch, seqlen, dim = hidden_states.shape
conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# Optimize: Ensure input is contiguous for better performance
if not hidden_states.is_contiguous():
hidden_states = hidden_states.contiguous()
zxbcdt = self.in_proj(hidden_states)
z, x, B, C, dt = torch.split(
zxbcdt,
[self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank],
dim=-1,
)
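        # Split layout of the in_proj output along the last dim:
        #   z: d_inner (gate), x: d_xb, B: d_xb, C: d_inner, dt: dt_rank.
        # x and B carry num_xb_head heads and are repeated to num_C_head heads below.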
x = rearrange(x, "b l d -> b d l")
z = rearrange(z, "b l d -> b d l")
B = rearrange(
B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state
)
        B = repeat_kv(B, self.repeat_group)  # (b, num_C_head, l, d_state)
B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous()
C = rearrange(
C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state
).contiguous()
dt = self.dt_proj(dt) # B, L, d_inner
dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L
if self.repeat_kv_before_conv:
# b d l
x = rearrange(
x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state
)
x = repeat_kv(x, self.repeat_group)
x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
# Compute short convolution
# Optimize: Only update state if we need it for next step (during generation)
# During prompt processing, we can skip state update if not needed
need_state_update = conv_state is not None
if need_state_update:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
# Update state (B D W)
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
if not self.repeat_kv_before_conv:
x = rearrange(
x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state
)
x = repeat_kv(x, self.repeat_group)
x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
assert self.activation in ["silu", "swish"]
# Optimize: Only return last_state if we need to update ssm_state
return_last_state = ssm_state is not None
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=return_last_state,
)
if return_last_state:
y, last_state = y
# ssm_state.copy_(last_state.unsqueeze(-2))
ssm_state.copy_(
rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)
)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
return out
def step(self, hidden_states, conv_state, ssm_state):
dtype = hidden_states.dtype
assert hidden_states.shape[1] == 1, (
"Only support decoding with 1 token at a time for now"
)
hidden_states_input = hidden_states.squeeze(1)
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
zxbcdt = self.in_proj(hidden_states_input)
z, x, B, C, dt = torch.split(
zxbcdt,
[self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank],
dim=-1,
)
B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group)
C = rearrange(
C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state
).contiguous()
dt = self.dt_proj(dt) # B, d_inner
if self.repeat_kv_before_conv:
x = rearrange(
x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state
)
x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group)
x = rearrange(x, "b n_group dstate -> b (n_group dstate)")
# Conv step
if causal_conv1d_update is None:
# Update state (B D W)
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
conv_state[:, :, -1] = x
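            # conv_state now holds the most recent d_conv inputs; the elementwise product with the
            # depthwise kernel followed by a sum over the window performs one causal-conv step
            # (pure-PyTorch fallback used when causal_conv1d is not installed).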
x = torch.sum(
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
) # (B D)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
x = self.act(x).to(dtype=dtype)
else:
x = causal_conv1d_update(
x,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)
if not self.repeat_kv_before_conv:
x = rearrange(
x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state
)
x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group)
x = rearrange(x, "b n_group dstate -> b (n_group dstate)")
x = rearrange(x, "b (h d) -> b h d", h=self.num_C_head)
dt = rearrange(dt, "b (h d) -> b h d", h=self.num_C_head)
A = rearrange(A, "(h d) n -> h d n", h=self.num_C_head)
D = rearrange(self.D, "(h d) -> h d", h=self.num_C_head)
z = rearrange(z, "b (h d) -> b h d", h=self.num_C_head)
dt_bias = rearrange(self.dt_proj.bias, "(h d) -> h d", h=self.num_C_head)
# SSM step
assert selective_state_update is not None
y = selective_state_update(
ssm_state, x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=True
)
y = rearrange(y, "b h d -> b (h d)")
out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
device = self.out_proj.weight.device
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
if self.repeat_kv_before_conv:
conv_state = torch.zeros(
batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype
)
else:
conv_state = torch.zeros(
batch_size, self.d_xb, self.d_conv, device=device, dtype=conv_dtype
)
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
ssm_state = torch.zeros(
batch_size,
self.num_C_head,
self.d_inner // self.num_C_head,
self.d_state,
device=device,
dtype=ssm_dtype,
)
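        # Shapes: conv_state is (batch, d_inner, d_conv) when repeat_kv_before_conv else
        # (batch, d_xb, d_conv); ssm_state is (batch, num_C_head, d_inner // num_C_head, d_state).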
return conv_state, ssm_state
def _get_states_from_cache(
self, inference_params, batch_size, initialize_states=False
):
assert self.layer_idx is not None
if self.layer_idx not in inference_params.key_value_memory_dict:
if self.repeat_kv_before_conv:
conv_state = torch.zeros(
batch_size,
self.d_inner,
self.d_conv,
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
else:
conv_state = torch.zeros(
batch_size,
self.d_xb,
self.d_conv,
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.num_C_head,
self.d_inner // self.num_C_head,
self.d_state,
device=self.dt_proj.weight.device,
dtype=self.dt_proj.weight.dtype,
)
inference_params.key_value_memory_dict[self.layer_idx] = (
conv_state,
ssm_state,
)
else:
conv_state, ssm_state = inference_params.key_value_memory_dict[
self.layer_idx
]
if initialize_states:
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state
class MLP(nn.Module):
def __init__(self, d_model, intermediate_size, hidden_act, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.hidden_size = d_model
self.intermediate_size = intermediate_size
self.gate_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=False, **factory_kwargs
)
self.up_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=False, **factory_kwargs
)
self.down_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=False, **factory_kwargs
)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class MHADecoderLayer(nn.Module):
def __init__(
self,
config: QwerkyLlamaMambaHybridConfig,
layer_idx: int,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super(MHADecoderLayer, self).__init__()
self.layer_idx = layer_idx
self.mha = MHA(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
num_heads_kv=config.num_key_value_heads,
layer_idx=layer_idx,
mlp_dim=0,
qkv_proj_bias=False,
out_proj_bias=False,
rotary_emb_dim=config.hidden_size // config.num_attention_heads,
rotary_emb_base=config.rope_theta,
causal=True,
device=device,
dtype=dtype,
)
self.mlp = MLP(
config.hidden_size,
config.intermediate_size,
config.hidden_act,
**factory_kwargs,
)
self.input_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs
)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs
)
self.residual_in_fp32 = True
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mha.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)
def forward(
self, hidden_states: torch.Tensor, inference_params=None, *args, **kwargs
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.mha(hidden_states, inference_params)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class MambaDecoderLayer(nn.Module):
def __init__(
self, config: QwerkyLlamaMambaHybridConfig, layer_idx: int, device=None, dtype=None
):
factory_kwargs = {"device": device, "dtype": dtype}
super(MambaDecoderLayer, self).__init__()
self.layer_idx = layer_idx
# Create Mamba layer with config parameters
self.mamba = Mamba(
d_model=config.d_model,
d_inner=config.d_inner,
d_xb=config.d_xb,
layer_idx=layer_idx,
**config.ssm_cfg,
**factory_kwargs,
)
self.mlp = MLP(
config.d_model,
config.intermediate_size,
config.hidden_act,
**factory_kwargs,
)
self.input_layernorm = RMSNorm(
config.d_model, eps=config.rms_norm_eps, **factory_kwargs
)
self.post_attention_layernorm = RMSNorm(
config.d_model, eps=config.rms_norm_eps, **factory_kwargs
)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mamba.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)
def forward(
self, hidden_states: torch.Tensor, inference_params=None, *args, **kwargs
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.mamba(hidden_states, inference_params=inference_params)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
def merge_projections_for_layers(checkpoint, layer_indices):
"""Merge q_proj, k_proj, v_proj into in_proj for attention layers."""
for layer_idx in layer_indices:
q_proj_key = f"model.layers.{layer_idx}.self_attn.q_proj.weight"
k_proj_key = f"model.layers.{layer_idx}.self_attn.k_proj.weight"
v_proj_key = f"model.layers.{layer_idx}.self_attn.v_proj.weight"
o_proj_key = f"model.layers.{layer_idx}.self_attn.o_proj.weight"
if (
q_proj_key in checkpoint
and k_proj_key in checkpoint
and v_proj_key in checkpoint
):
q_proj_weight = checkpoint[q_proj_key]
k_proj_weight = checkpoint[k_proj_key]
v_proj_weight = checkpoint[v_proj_key]
in_proj_weight = torch.cat(
[q_proj_weight, k_proj_weight, v_proj_weight], dim=0
)
in_proj_key = f"model.layers.{layer_idx}.mha.in_proj.weight"
checkpoint[in_proj_key] = in_proj_weight
del checkpoint[q_proj_key]
del checkpoint[k_proj_key]
del checkpoint[v_proj_key]
if o_proj_key in checkpoint:
out_proj_key = f"model.layers.{layer_idx}.mha.out_proj.weight"
checkpoint[out_proj_key] = checkpoint[o_proj_key]
del checkpoint[o_proj_key]
return checkpoint
class QwerkyLlamaMambaHybridPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = QwerkyLlamaMambaHybridConfig
base_model_prefix = "model"
supports_gradient_checkpointing = False
_no_split_modules = ["MambaDecoderLayer", "MHADecoderLayer"]
_supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
class QwerkyLlamaMambaHybridModel(QwerkyLlamaMambaHybridPreTrainedModel):
"""
The bare QwerkyLlamaMambaHybrid Model transformer outputting raw hidden-states without any specific head on top.
"""
def __init__(self, config: QwerkyLlamaMambaHybridConfig, **kwargs):
super().__init__(config, **kwargs)
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[
MHADecoderLayer(config, layer_idx, device=None, dtype=None)
if layer_idx in config.attn_layers
else MambaDecoderLayer(config, layer_idx, device=None, dtype=None)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Register hook to transform state dict keys before loading
# This merges q_proj/k_proj/v_proj into mha.in_proj.weight for attention layers
self._register_load_state_dict_pre_hook(self.load_hook)
self.post_init()
def load_hook(self, state_dict, prefix, *args):
"""Transform state dict keys: merge q_proj/k_proj/v_proj into mha.in_proj.weight for attention layers."""
if self.config.attn_layers:
merge_projections_for_layers(state_dict, self.config.attn_layers)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inference_params=None,
num_last_tokens: int = 0,
**kwargs,
):
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
if input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
# Optimize: Ensure hidden_states is contiguous for better memory access patterns
if not hidden_states.is_contiguous():
hidden_states = hidden_states.contiguous()
for layer in self.layers:
hidden_states = layer(
hidden_states, inference_params=inference_params, **kwargs
)
# Optimize: Keep hidden_states contiguous between layers
if not hidden_states.is_contiguous():
hidden_states = hidden_states.contiguous()
hidden_states = self.norm(hidden_states)
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
return hidden_states
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
"""Allocate inference cache for all layers."""
return {
i: layer.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)
for i, layer in enumerate(self.layers)
}
class QwerkyLlamaMambaHybridForCausalLM(
QwerkyLlamaMambaHybridPreTrainedModel, MambaGenerationMixin
):
"""
The QwerkyLlamaMambaHybrid Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
"""
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: QwerkyLlamaMambaHybridConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = QwerkyLlamaMambaHybridModel(config, **kwargs)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Tie weights if configured
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
# Cache device to avoid repeated next(self.parameters()).device calls
self._cached_device = None
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
inference_params=None,
num_last_tokens: int = 0,
**kwargs,
) -> Union[Tuple, CausalLMOutput]:
"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
# Optimize TTFT: During prefill (prompt processing), only compute logits for the last token
# This saves computation in lm_head since we only need the last token's logits to generate the first token
# Conditions: not training (labels is None), in prefill phase (seqlen_offset == 0 or None), and num_last_tokens not explicitly set
is_prefill = (
labels is None # Not in training mode
and (
inference_params is None
or getattr(inference_params, "seqlen_offset", 0) == 0
) # Prefill phase
and num_last_tokens == 0 # Not explicitly set by caller
)
if is_prefill:
num_last_tokens = 1
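            # With num_last_tokens=1 the model returns only the final position, so logits have
            # shape (batch_size, 1, vocab_size), which is all the decode loop needs for the
            # first generated token.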
hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
inference_params=inference_params,
num_last_tokens=num_last_tokens,
**kwargs,
)
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return CausalLMOutput(
loss=loss,
logits=logits,
)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
"""Allocate inference cache for all layers."""
return self.model.allocate_inference_cache(
batch_size, max_seqlen, dtype=dtype, **kwargs
)
def generate(
self,
input_ids,
max_length=1024,
top_k=50,
top_p=1.0,
min_p=0.0,
temperature=1.0,
repetition_penalty=1.0,
return_dict_in_generate=False,
output_scores=False,
**kwargs,
):
"""
Generate sequences using the model.
Supports all standard Transformers generation parameters including:
- do_sample, temperature, top_k, top_p, repetition_penalty
- attention_mask, pad_token_id, eos_token_id
- max_new_tokens, use_cache, and more
"""
# Ensure input_ids is properly shaped (2D: batch_size, seq_len)
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0) # Add batch dimension
# Optimize: Cache device to avoid repeated next(self.parameters()).device calls
if self._cached_device is None:
self._cached_device = next(self.parameters()).device
device = self._cached_device
# Ensure input_ids is on the correct device and dtype for generation
# MambaGenerationMixin expects input_ids to match the model's device
if input_ids.device != device:
input_ids = input_ids.to(device)
# Ensure input_ids is long/int64 dtype (required for token IDs)
if input_ids.dtype != torch.long:
input_ids = input_ids.long()
# Get batch_size early for cache pre-allocation
batch_size = input_ids.shape[0]
if kwargs is not None:
max_new_tokens = kwargs.pop("max_new_tokens", None)
if max_new_tokens is not None:
max_length = max_new_tokens + input_ids.shape[1]
do_sample = kwargs.pop("do_sample", True)
if not do_sample:
top_k, top_p, min_p = 1, 0.0, 0.0
cg = kwargs.pop("cg", True)
eos_token_id = kwargs.pop("eos_token_id", self.config.eos_token_id)
# Convert eos_token_id to tensor to ensure compatibility with mamba_ssm tensor comparisons
if eos_token_id is not None:
if isinstance(eos_token_id, (list, tuple)):
eos_token_id = torch.tensor(
eos_token_id, dtype=torch.long, device=device
)
else:
eos_token_id = torch.tensor(
[eos_token_id], dtype=torch.long, device=device
)
attention_mask = kwargs.pop("attention_mask", None)
pad_token_id = kwargs.pop(
"pad_token_id", getattr(self.config, "pad_token_id", None)
)
# Optimize: Handle attention_mask more efficiently
# Skip expensive filtering if attention_mask is None or all ones
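            # Worked example (illustrative): attention_mask [[1, 1, 1], [1, 1, 0]] gives
            # seq_lengths [3, 2]; the second row's last position is replaced with pad_token_id
            # (or 0) in the filtered input_ids handed to the decode loop.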
if attention_mask is not None:
# Fast path: Check if all sequences are fully valid (all ones)
if attention_mask.all():
# No filtering needed, just ensure contiguous
input_ids = input_ids.contiguous()
else:
# Vectorized filtering: get sequence lengths and max length
seq_lengths = attention_mask.sum(dim=1) # (batch_size,)
max_seq_len = seq_lengths.max().item()
min_seq_len = seq_lengths.min().item()
original_seq_len = input_ids.shape[1]
# Fast path: if all sequences are the same length, just slice
if min_seq_len == max_seq_len and max_seq_len <= original_seq_len:
input_ids = input_ids[:, :max_seq_len].contiguous()
else:
# Fully vectorized approach: create padded tensor and copy sequences
batch_size = input_ids.shape[0]
dtype = input_ids.dtype
pad_value = pad_token_id if pad_token_id is not None else 0
# Create output tensor filled with pad_value (single vectorized operation)
input_ids_filtered = torch.full(
(batch_size, max_seq_len),
pad_value,
dtype=dtype,
device=device,
)
# Only copy up to the original sequence length to avoid out-of-bounds access
copy_len = min(max_seq_len, original_seq_len)
if copy_len > 0:
# Create a mask for valid positions (vectorized)
# Shape: (batch_size, copy_len) - True where we should copy from input_ids
valid_mask = torch.arange(
copy_len, device=device
).unsqueeze(0) < seq_lengths.unsqueeze(1)
# Copy valid positions using PyTorch masking operations
# Use .contiguous() to ensure proper memory layout
input_ids_slice = input_ids[:, :copy_len].contiguous()
input_ids_filtered_slice = input_ids_filtered[:, :copy_len]
# Use torch.where for safe vectorized copying
# valid_mask broadcasts automatically: (batch_size, copy_len) -> (batch_size, copy_len)
input_ids_filtered[:, :copy_len] = torch.where(
valid_mask, input_ids_slice, input_ids_filtered_slice
)
input_ids = input_ids_filtered.contiguous()
# Use repetition_penalty from parameter or kwargs (supported by decode function)
repetition_penalty = kwargs.pop("repetition_penalty", repetition_penalty)
# Extract other parameters that might be passed but not used by MambaGenerationMixin
# These are popped from kwargs to avoid passing them to the parent generate() method
use_cache = kwargs.pop(
"use_cache", None
) # Not supported by MambaGenerationMixin
no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", None)
length_penalty = kwargs.pop("length_penalty", None)
num_return_sequences = kwargs.pop("num_return_sequences", None)
num_beams = kwargs.pop("num_beams", None)
low_memory = kwargs.pop("low_memory", None)
stopping_criteria = kwargs.pop("stopping_criteria", None)
# Optimize TTFT: Pre-allocate inference cache before generation starts
# This avoids allocation overhead during the first forward pass
# Calculate max_seqlen: use max_length (which includes prompt + generation length)
max_seqlen = max_length
        # Pre-allocate the cache. The returned states are discarded here (the generation loop sets
        # up its own cache), but this warm-up allocation primes the memory allocator and reduces
        # latency on the first forward pass.
try:
# Get model dtype for cache allocation
model_dtype = next(self.parameters()).dtype
# Pre-allocate cache - this is a warm-up allocation that helps with memory timing
_ = self.allocate_inference_cache(
batch_size=batch_size,
max_seqlen=max_seqlen,
dtype=model_dtype,
)
except Exception:
# If allocation fails, continue without pre-allocation
# This shouldn't happen, but we don't want to break generation
pass
return super().generate(
input_ids=input_ids,
max_length=max_length,
cg=cg,
top_k=top_k,
top_p=top_p,
min_p=min_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
return_dict_in_generate=return_dict_in_generate,
output_scores=output_scores,
eos_token_id=eos_token_id,
**kwargs,
)