|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTorch QwerkyLlamaMambaHybrid model for inference only.""" |
|
|
|
|
|
from typing import Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import CrossEntropyLoss |
|
|
|
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.utils import logging |
|
|
from transformers.modeling_outputs import CausalLMOutput |
|
|
|
|
|
from mamba_ssm.ops.triton.layer_norm import RMSNorm |
|
|
from mamba_ssm.modules.mha import MHA |
|
|
from mamba_ssm.utils.generation import GenerationMixin as MambaGenerationMixin |
|
|
from transformers.activations import ACT2FN |
|
|
|
|
|
|
|
|
import math |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange, repeat |
|
|
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn |
|
|
|
|
|
try: |
|
|
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update |
|
|
except ImportError: |
|
|
causal_conv1d_fn, causal_conv1d_update = None, None |
|
|
|
|
|
try: |
|
|
from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
|
|
except ImportError: |
|
|
selective_state_update = None |
|
|
|
|
|
from .configuration_qwerky_llama_mamba_hybrid import QwerkyLlamaMambaHybridConfig |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
|
""" |
|
|
    This is the equivalent of torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep). The hidden states go
    from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
|
|
""" |
|
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
|
if n_rep == 1: |
|
|
return hidden_states |
|
|
hidden_states = hidden_states[:, :, None, :, :].expand( |
|
|
batch, num_key_value_heads, n_rep, slen, head_dim |
|
|
) |
|
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
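
# The Mamba mixer below is a grouped variant of the standard Mamba block: the SSM input x
# and the input projection B are produced at a reduced width d_xb and then shared across
# head groups with repeat_kv (mirroring grouped-query attention), while the gate z, the
# output projection C, and dt use the full inner width d_inner.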
|
|
|
|
|
|
|
|
|
|
|
class Mamba(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
d_model, |
|
|
d_inner, |
|
|
d_xb, |
|
|
d_state=16, |
|
|
d_conv=4, |
|
|
expand=2, |
|
|
dt_rank="auto", |
|
|
dt_min=0.001, |
|
|
dt_max=0.1, |
|
|
dt_init="random", |
|
|
dt_scale=1.0, |
|
|
dt_init_floor=1e-4, |
|
|
repeat_kv_before_conv=True, |
|
|
conv_bias=True, |
|
|
proj_x_bias=False, |
|
|
proj_z_bias=False, |
|
|
out_proj_bias=False, |
|
|
use_fast_path=True, |
|
|
layer_idx=None, |
|
|
device=None, |
|
|
dtype=None, |
|
|
): |
|
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
|
super().__init__() |
|
|
self.d_model = d_model |
|
|
self.d_xb = d_xb |
|
|
self.d_state = d_state |
|
|
self.d_conv = d_conv |
|
|
self.expand = expand |
|
|
self.d_inner = ( |
|
|
d_inner if d_inner is not None else int(self.expand * self.d_model) |
|
|
) |
|
|
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank |
|
|
self.use_fast_path = use_fast_path |
|
|
self.layer_idx = layer_idx |
|
|
self.repeat_kv_before_conv = repeat_kv_before_conv |
|
|
|
|
|
if self.repeat_kv_before_conv: |
|
|
self.conv1d = nn.Conv1d( |
|
|
in_channels=self.d_inner, |
|
|
out_channels=self.d_inner, |
|
|
bias=conv_bias, |
|
|
kernel_size=d_conv, |
|
|
groups=self.d_inner, |
|
|
padding=d_conv - 1, |
|
|
**factory_kwargs, |
|
|
) |
|
|
else: |
|
|
self.conv1d = nn.Conv1d( |
|
|
in_channels=self.d_xb, |
|
|
out_channels=self.d_xb, |
|
|
bias=conv_bias, |
|
|
kernel_size=d_conv, |
|
|
groups=self.d_xb, |
|
|
padding=d_conv - 1, |
|
|
**factory_kwargs, |
|
|
) |
|
|
|
|
|
self.activation = "silu" |
|
|
self.act = nn.SiLU() |
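
        # Head bookkeeping: x/B carry num_xb_head heads of size d_state (total d_xb),
        # z/C carry num_C_head heads (total d_inner); each x/B head is shared by
        # repeat_group C-heads, analogous to grouped-query attention.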
|
|
|
|
|
self.num_xb_head = self.d_xb // self.d_state |
|
|
self.num_C_head = self.d_inner // self.d_state |
|
|
self.repeat_group = self.num_C_head // self.num_xb_head |
|
|
|
|
|
self.in_proj = nn.Linear( |
|
|
self.d_model, |
|
|
2 * self.d_xb + 2 * self.d_inner + self.dt_rank, |
|
|
bias=False, |
|
|
**factory_kwargs, |
|
|
) |
|
|
self.dt_proj = nn.Linear( |
|
|
self.dt_rank, self.d_inner, bias=True, **factory_kwargs |
|
|
) |
|
|
|
|
|
|
|
|
dt_init_std = self.dt_rank**-0.5 * dt_scale |
|
|
if dt_init == "constant": |
|
|
nn.init.constant_(self.dt_proj.weight, dt_init_std) |
|
|
elif dt_init == "random": |
|
|
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) |
|
|
else: |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
dt = torch.exp( |
|
|
torch.rand(self.d_inner, **factory_kwargs) |
|
|
* (math.log(dt_max) - math.log(dt_min)) |
|
|
+ math.log(dt_min) |
|
|
).clamp(min=dt_init_floor) |
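        # Invert softplus so that softplus(dt_proj.bias) equals the sampled dt at init
        # (inv_dt = log(exp(dt) - 1)).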
|
|
|
|
|
inv_dt = dt + torch.log(-torch.expm1(-dt)) |
|
|
with torch.no_grad(): |
|
|
self.dt_proj.bias.copy_(inv_dt) |
|
|
|
|
|
self.dt_proj.bias._no_reinit = True |
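
        # S4D-real initialization: A[d, n] = -(n + 1), stored as log(-A) so that the
        # effective A recovered with -exp(A_log) stays negative (stable).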
|
|
|
|
|
|
|
|
A = repeat( |
|
|
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), |
|
|
"n -> d n", |
|
|
d=self.d_inner, |
|
|
).contiguous() |
|
|
A_log = torch.log(A) |
|
|
self.A_log = nn.Parameter(A_log) |
|
|
self.A_log._no_weight_decay = True |
|
|
|
|
|
|
|
|
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) |
|
|
self.D._no_weight_decay = True |
|
|
|
|
|
self.out_proj = nn.Linear( |
|
|
self.d_inner, self.d_model, bias=out_proj_bias, **factory_kwargs |
|
|
) |
|
|
|
|
|
def forward(self, hidden_states, inference_params=None): |
|
|
""" |
|
|
hidden_states: (B, L, D) |
|
|
Returns: same shape as hidden_states |
|
|
""" |
|
|
batch, seqlen, dim = hidden_states.shape |
|
|
|
|
|
conv_state, ssm_state = None, None |
|
|
if inference_params is not None: |
|
|
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) |
|
|
if inference_params.seqlen_offset > 0: |
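                # Single-token decode: reuse the cached conv/SSM states instead of
                # re-running the full selective scan.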
|
|
|
|
|
out, _, _ = self.step(hidden_states, conv_state, ssm_state) |
|
|
return out |
|
|
|
|
|
A = -torch.exp(self.A_log.float()) |
|
|
|
|
|
|
|
|
if not hidden_states.is_contiguous(): |
|
|
hidden_states = hidden_states.contiguous() |
|
|
|
|
|
zxbcdt = self.in_proj(hidden_states) |
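
        # z gates the SSM output (SiLU), x is the SSM input, B/C are the input/output
        # state-space projections, and dt is the low-rank timestep before dt_proj.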
|
|
z, x, B, C, dt = torch.split( |
|
|
zxbcdt, |
|
|
[self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], |
|
|
dim=-1, |
|
|
) |
|
|
|
|
|
x = rearrange(x, "b l d -> b d l") |
|
|
z = rearrange(z, "b l d -> b d l") |
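
        # B is produced per x/B head group; repeat it so every C head sees a B matrix,
        # then lay B and C out as (batch, groups, d_state, seqlen) for selective_scan_fn.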
|
|
|
|
|
B = rearrange( |
|
|
B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state |
|
|
) |
|
|
B = repeat_kv(B, self.repeat_group) |
|
|
B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() |
|
|
C = rearrange( |
|
|
C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state |
|
|
).contiguous() |
|
|
|
|
|
dt = self.dt_proj(dt) |
|
|
dt = rearrange(dt, "b l d -> b d l") |
|
|
|
|
|
if self.repeat_kv_before_conv: |
|
|
|
|
|
x = rearrange( |
|
|
x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state |
|
|
) |
|
|
x = repeat_kv(x, self.repeat_group) |
|
|
x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
need_state_update = conv_state is not None |
|
|
if need_state_update: |
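            # Cache the last d_conv inputs for streaming decode. F.pad with a negative
            # left pad crops from the left when seqlen > d_conv, keeping the most recent
            # d_conv columns.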
|
|
|
|
|
|
|
|
|
|
|
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) |
|
|
if causal_conv1d_fn is None: |
|
|
x = self.act(self.conv1d(x)[..., :seqlen]) |
|
|
else: |
|
|
assert self.activation in ["silu", "swish"] |
|
|
x = causal_conv1d_fn( |
|
|
x=x, |
|
|
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), |
|
|
bias=self.conv1d.bias, |
|
|
activation=self.activation, |
|
|
) |
|
|
|
|
|
if not self.repeat_kv_before_conv: |
|
|
x = rearrange( |
|
|
x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state |
|
|
) |
|
|
x = repeat_kv(x, self.repeat_group) |
|
|
x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") |
|
|
|
|
|
assert self.activation in ["silu", "swish"] |
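
        # Fused selective scan over the whole sequence: discretizes (A, B) with dt,
        # applies the gate z and the skip connection D, and (during prefill) also
        # returns the final SSM state so decoding can continue from it.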
|
|
|
|
|
return_last_state = ssm_state is not None |
|
|
y = selective_scan_fn( |
|
|
x, |
|
|
dt, |
|
|
A, |
|
|
B, |
|
|
C, |
|
|
self.D.float(), |
|
|
z=z, |
|
|
delta_bias=self.dt_proj.bias.float(), |
|
|
delta_softplus=True, |
|
|
return_last_state=return_last_state, |
|
|
) |
|
|
if return_last_state: |
|
|
y, last_state = y |
|
|
|
|
|
ssm_state.copy_( |
|
|
rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head) |
|
|
) |
|
|
y = rearrange(y, "b d l -> b l d") |
|
|
out = self.out_proj(y) |
|
|
|
|
|
return out |
|
|
|
|
|
def step(self, hidden_states, conv_state, ssm_state): |
|
|
dtype = hidden_states.dtype |
|
|
assert hidden_states.shape[1] == 1, ( |
|
|
"Only support decoding with 1 token at a time for now" |
|
|
) |
|
|
|
|
|
hidden_states_input = hidden_states.squeeze(1) |
|
|
|
|
|
A = -torch.exp(self.A_log.float()) |
|
|
|
|
|
zxbcdt = self.in_proj(hidden_states_input) |
|
|
z, x, B, C, dt = torch.split( |
|
|
zxbcdt, |
|
|
[self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], |
|
|
dim=-1, |
|
|
) |
|
|
|
|
|
B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) |
|
|
B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group) |
|
|
|
|
|
C = rearrange( |
|
|
C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state |
|
|
).contiguous() |
|
|
|
|
|
dt = self.dt_proj(dt) |
|
|
|
|
|
if self.repeat_kv_before_conv: |
|
|
x = rearrange( |
|
|
x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state |
|
|
) |
|
|
x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group) |
|
|
x = rearrange(x, "b n_group dstate -> b (n_group dstate)") |
|
|
|
|
|
|
|
|
if causal_conv1d_update is None: |
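            # Fallback without the causal_conv1d CUDA kernel: roll the conv buffer by
            # one position and apply the depthwise convolution manually.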
|
|
|
|
|
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) |
|
|
conv_state[:, :, -1] = x |
|
|
x = torch.sum( |
|
|
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 |
|
|
) |
|
|
if self.conv1d.bias is not None: |
|
|
x = x + self.conv1d.bias |
|
|
x = self.act(x).to(dtype=dtype) |
|
|
else: |
|
|
x = causal_conv1d_update( |
|
|
x, |
|
|
conv_state, |
|
|
rearrange(self.conv1d.weight, "d 1 w -> d w"), |
|
|
self.conv1d.bias, |
|
|
self.activation, |
|
|
) |
|
|
|
|
|
if not self.repeat_kv_before_conv: |
|
|
x = rearrange( |
|
|
x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state |
|
|
) |
|
|
x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group) |
|
|
x = rearrange(x, "b n_group dstate -> b (n_group dstate)") |
|
|
|
|
|
x = rearrange(x, "b (h d) -> b h d", h=self.num_C_head) |
|
|
dt = rearrange(dt, "b (h d) -> b h d", h=self.num_C_head) |
|
|
A = rearrange(A, "(h d) n -> h d n", h=self.num_C_head) |
|
|
D = rearrange(self.D, "(h d) -> h d", h=self.num_C_head) |
|
|
z = rearrange(z, "b (h d) -> b h d", h=self.num_C_head) |
|
|
dt_bias = rearrange(self.dt_proj.bias, "(h d) -> h d", h=self.num_C_head) |
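
        # selective_state_update works on per-head tensors:
        # x/dt/z: (batch, heads, head_dim), A: (heads, head_dim, d_state),
        # ssm_state: (batch, heads, head_dim, d_state).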
|
|
|
|
|
|
|
|
assert selective_state_update is not None |
|
|
y = selective_state_update( |
|
|
ssm_state, x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=True |
|
|
) |
|
|
|
|
|
y = rearrange(y, "b h d -> b (h d)") |
|
|
out = self.out_proj(y) |
|
|
|
|
|
return out.unsqueeze(1), conv_state, ssm_state |
|
|
|
|
|
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): |
|
|
device = self.out_proj.weight.device |
|
|
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype |
|
|
if self.repeat_kv_before_conv: |
|
|
conv_state = torch.zeros( |
|
|
batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype |
|
|
) |
|
|
else: |
|
|
conv_state = torch.zeros( |
|
|
batch_size, self.d_xb, self.d_conv, device=device, dtype=conv_dtype |
|
|
) |
|
|
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype |
|
|
ssm_state = torch.zeros( |
|
|
batch_size, |
|
|
self.num_C_head, |
|
|
self.d_inner // self.num_C_head, |
|
|
self.d_state, |
|
|
device=device, |
|
|
dtype=ssm_dtype, |
|
|
) |
|
|
return conv_state, ssm_state |
|
|
|
|
|
def _get_states_from_cache( |
|
|
self, inference_params, batch_size, initialize_states=False |
|
|
): |
|
|
assert self.layer_idx is not None |
|
|
if self.layer_idx not in inference_params.key_value_memory_dict: |
|
|
if self.repeat_kv_before_conv: |
|
|
conv_state = torch.zeros( |
|
|
batch_size, |
|
|
self.d_inner, |
|
|
self.d_conv, |
|
|
device=self.conv1d.weight.device, |
|
|
dtype=self.conv1d.weight.dtype, |
|
|
) |
|
|
else: |
|
|
conv_state = torch.zeros( |
|
|
batch_size, |
|
|
self.d_xb, |
|
|
self.d_conv, |
|
|
device=self.conv1d.weight.device, |
|
|
dtype=self.conv1d.weight.dtype, |
|
|
) |
|
|
ssm_state = torch.zeros( |
|
|
batch_size, |
|
|
self.num_C_head, |
|
|
self.d_inner // self.num_C_head, |
|
|
self.d_state, |
|
|
device=self.dt_proj.weight.device, |
|
|
dtype=self.dt_proj.weight.dtype, |
|
|
) |
|
|
inference_params.key_value_memory_dict[self.layer_idx] = ( |
|
|
conv_state, |
|
|
ssm_state, |
|
|
) |
|
|
else: |
|
|
conv_state, ssm_state = inference_params.key_value_memory_dict[ |
|
|
self.layer_idx |
|
|
] |
|
|
if initialize_states: |
|
|
conv_state.zero_() |
|
|
ssm_state.zero_() |
|
|
return conv_state, ssm_state |
|
|
|
|
|
|
|
|
class MLP(nn.Module): |
|
|
def __init__(self, d_model, intermediate_size, hidden_act, device=None, dtype=None): |
|
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
|
super().__init__() |
|
|
self.hidden_size = d_model |
|
|
self.intermediate_size = intermediate_size |
|
|
self.gate_proj = nn.Linear( |
|
|
self.hidden_size, self.intermediate_size, bias=False, **factory_kwargs |
|
|
) |
|
|
self.up_proj = nn.Linear( |
|
|
self.hidden_size, self.intermediate_size, bias=False, **factory_kwargs |
|
|
) |
|
|
self.down_proj = nn.Linear( |
|
|
self.intermediate_size, self.hidden_size, bias=False, **factory_kwargs |
|
|
) |
|
|
self.act_fn = ACT2FN[hidden_act] |
|
|
|
|
|
def forward(self, x): |
|
|
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
|
|
|
|
|
|
|
|
class MHADecoderLayer(nn.Module): |
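    # Attention layer in the Llama layout: mamba_ssm's MHA (fused QKV, rotary
    # embeddings) plus the gated MLP, wired as a pre-norm residual block.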
|
|
def __init__( |
|
|
self, |
|
|
config: QwerkyLlamaMambaHybridConfig, |
|
|
layer_idx: int, |
|
|
device=None, |
|
|
dtype=None, |
|
|
): |
|
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
|
super(MHADecoderLayer, self).__init__() |
|
|
self.layer_idx = layer_idx |
|
|
self.mha = MHA( |
|
|
embed_dim=config.hidden_size, |
|
|
num_heads=config.num_attention_heads, |
|
|
num_heads_kv=config.num_key_value_heads, |
|
|
layer_idx=layer_idx, |
|
|
mlp_dim=0, |
|
|
qkv_proj_bias=False, |
|
|
out_proj_bias=False, |
|
|
rotary_emb_dim=config.hidden_size // config.num_attention_heads, |
|
|
rotary_emb_base=config.rope_theta, |
|
|
causal=True, |
|
|
device=device, |
|
|
dtype=dtype, |
|
|
) |
|
|
self.mlp = MLP( |
|
|
config.hidden_size, |
|
|
config.intermediate_size, |
|
|
config.hidden_act, |
|
|
**factory_kwargs, |
|
|
) |
|
|
self.input_layernorm = RMSNorm( |
|
|
config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs |
|
|
) |
|
|
self.post_attention_layernorm = RMSNorm( |
|
|
config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs |
|
|
) |
|
|
self.residual_in_fp32 = True |
|
|
|
|
|
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): |
|
|
return self.mha.allocate_inference_cache( |
|
|
batch_size, max_seqlen, dtype=dtype, **kwargs |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, inference_params=None, *args, **kwargs |
|
|
): |
|
|
residual = hidden_states |
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
hidden_states = self.mha(hidden_states, inference_params) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
hidden_states = self.mlp(hidden_states) |
|
|
hidden_states = residual + hidden_states |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class MambaDecoderLayer(nn.Module): |
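    # Same pre-norm residual layout as the attention layer, but with the Mamba mixer
    # in place of self-attention.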
|
|
def __init__( |
|
|
self, config: QwerkyLlamaMambaHybridConfig, layer_idx: int, device=None, dtype=None |
|
|
): |
|
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
|
super(MambaDecoderLayer, self).__init__() |
|
|
self.layer_idx = layer_idx |
|
|
|
|
|
|
|
|
self.mamba = Mamba( |
|
|
d_model=config.d_model, |
|
|
d_inner=config.d_inner, |
|
|
d_xb=config.d_xb, |
|
|
layer_idx=layer_idx, |
|
|
**config.ssm_cfg, |
|
|
**factory_kwargs, |
|
|
) |
|
|
self.mlp = MLP( |
|
|
config.d_model, |
|
|
config.intermediate_size, |
|
|
config.hidden_act, |
|
|
**factory_kwargs, |
|
|
) |
|
|
self.input_layernorm = RMSNorm( |
|
|
config.d_model, eps=config.rms_norm_eps, **factory_kwargs |
|
|
) |
|
|
self.post_attention_layernorm = RMSNorm( |
|
|
config.d_model, eps=config.rms_norm_eps, **factory_kwargs |
|
|
) |
|
|
|
|
|
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): |
|
|
return self.mamba.allocate_inference_cache( |
|
|
batch_size, max_seqlen, dtype=dtype, **kwargs |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, inference_params=None, *args, **kwargs |
|
|
): |
|
|
residual = hidden_states |
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
hidden_states = self.mamba(hidden_states, inference_params=inference_params) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
hidden_states = self.mlp(hidden_states) |
|
|
hidden_states = residual + hidden_states |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
def merge_projections_for_layers(checkpoint, layer_indices): |
|
|
"""Merge q_proj, k_proj, v_proj into in_proj for attention layers.""" |
|
|
for layer_idx in layer_indices: |
|
|
q_proj_key = f"model.layers.{layer_idx}.self_attn.q_proj.weight" |
|
|
k_proj_key = f"model.layers.{layer_idx}.self_attn.k_proj.weight" |
|
|
v_proj_key = f"model.layers.{layer_idx}.self_attn.v_proj.weight" |
|
|
o_proj_key = f"model.layers.{layer_idx}.self_attn.o_proj.weight" |
|
|
|
|
|
if ( |
|
|
q_proj_key in checkpoint |
|
|
and k_proj_key in checkpoint |
|
|
and v_proj_key in checkpoint |
|
|
): |
|
|
q_proj_weight = checkpoint[q_proj_key] |
|
|
k_proj_weight = checkpoint[k_proj_key] |
|
|
v_proj_weight = checkpoint[v_proj_key] |
|
|
|
|
|
in_proj_weight = torch.cat( |
|
|
[q_proj_weight, k_proj_weight, v_proj_weight], dim=0 |
|
|
) |
|
|
in_proj_key = f"model.layers.{layer_idx}.mha.in_proj.weight" |
|
|
checkpoint[in_proj_key] = in_proj_weight |
|
|
|
|
|
del checkpoint[q_proj_key] |
|
|
del checkpoint[k_proj_key] |
|
|
del checkpoint[v_proj_key] |
|
|
|
|
|
if o_proj_key in checkpoint: |
|
|
out_proj_key = f"model.layers.{layer_idx}.mha.out_proj.weight" |
|
|
checkpoint[out_proj_key] = checkpoint[o_proj_key] |
|
|
del checkpoint[o_proj_key] |
|
|
|
|
|
return checkpoint |
|
|
|
|
|
|
|
|
class QwerkyLlamaMambaHybridPreTrainedModel(PreTrainedModel): |
|
|
""" |
|
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
|
|
models. |
|
|
""" |
|
|
|
|
|
config_class = QwerkyLlamaMambaHybridConfig |
|
|
base_model_prefix = "model" |
|
|
supports_gradient_checkpointing = False |
|
|
_no_split_modules = ["MambaDecoderLayer", "MHADecoderLayer"] |
|
|
_supports_flash_attn_2 = True |
|
|
|
|
|
def _init_weights(self, module): |
|
|
"""Initialize the weights.""" |
|
|
if isinstance(module, nn.Linear): |
|
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
|
if module.bias is not None: |
|
|
module.bias.data.zero_() |
|
|
elif isinstance(module, nn.Embedding): |
|
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
|
|
|
|
|
|
|
class QwerkyLlamaMambaHybridModel(QwerkyLlamaMambaHybridPreTrainedModel): |
|
|
""" |
|
|
The bare QwerkyLlamaMambaHybrid Model transformer outputting raw hidden-states without any specific head on top. |
|
|
""" |
|
|
|
|
|
def __init__(self, config: QwerkyLlamaMambaHybridConfig, **kwargs): |
|
|
super().__init__(config, **kwargs) |
|
|
self.config = config |
|
|
self.vocab_size = config.vocab_size |
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) |
|
|
|
|
|
self.layers = nn.ModuleList( |
|
|
[ |
|
|
MHADecoderLayer(config, layer_idx, device=None, dtype=None) |
|
|
if layer_idx in config.attn_layers |
|
|
else MambaDecoderLayer(config, layer_idx, device=None, dtype=None) |
|
|
for layer_idx in range(config.num_hidden_layers) |
|
|
] |
|
|
) |
|
|
|
|
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
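
        # Checkpoints in the Llama layout keep separate q_proj/k_proj/v_proj weights;
        # this pre-hook rewrites them into the fused mha.in_proj expected by the
        # attention layers before loading.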
|
|
|
|
|
|
|
|
|
|
|
self._register_load_state_dict_pre_hook(self.load_hook) |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def load_hook(self, state_dict, prefix, *args): |
|
|
"""Transform state dict keys: merge q_proj/k_proj/v_proj into mha.in_proj.weight for attention layers.""" |
|
|
if self.config.attn_layers: |
|
|
merge_projections_for_layers(state_dict, self.config.attn_layers) |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.embed_tokens |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.embed_tokens = value |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
inference_params=None, |
|
|
num_last_tokens: int = 0, |
|
|
**kwargs, |
|
|
): |
|
|
if input_ids is not None and inputs_embeds is not None: |
|
|
raise ValueError( |
|
|
"You cannot specify both input_ids and inputs_embeds at the same time" |
|
|
) |
|
|
if input_ids is None and inputs_embeds is None: |
|
|
raise ValueError("You have to specify either input_ids or inputs_embeds") |
|
|
|
|
|
if inputs_embeds is None: |
|
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
|
|
hidden_states = inputs_embeds |
|
|
|
|
|
|
|
|
if not hidden_states.is_contiguous(): |
|
|
hidden_states = hidden_states.contiguous() |
|
|
|
|
|
for layer in self.layers: |
|
|
hidden_states = layer( |
|
|
hidden_states, inference_params=inference_params, **kwargs |
|
|
) |
|
|
|
|
|
if not hidden_states.is_contiguous(): |
|
|
hidden_states = hidden_states.contiguous() |
|
|
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
if num_last_tokens > 0: |
|
|
hidden_states = hidden_states[:, -num_last_tokens:] |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): |
|
|
"""Allocate inference cache for all layers.""" |
|
|
return { |
|
|
i: layer.allocate_inference_cache( |
|
|
batch_size, max_seqlen, dtype=dtype, **kwargs |
|
|
) |
|
|
for i, layer in enumerate(self.layers) |
|
|
} |
|
|
|
|
|
|
|
|
class QwerkyLlamaMambaHybridForCausalLM( |
|
|
QwerkyLlamaMambaHybridPreTrainedModel, MambaGenerationMixin |
|
|
): |
|
|
""" |
|
|
    The QwerkyLlamaMambaHybrid Model transformer with a language modeling head on top (a linear layer whose weights
    are tied to the input embeddings when `config.tie_word_embeddings` is set).
|
|
""" |
|
|
|
|
|
_tied_weights_keys = ["lm_head.weight"] |
|
|
|
|
|
def __init__(self, config: QwerkyLlamaMambaHybridConfig, **kwargs): |
|
|
super().__init__(config, **kwargs) |
|
|
self.model = QwerkyLlamaMambaHybridModel(config, **kwargs) |
|
|
self.vocab_size = config.vocab_size |
|
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
|
|
|
|
|
|
if config.tie_word_embeddings: |
|
|
self.lm_head.weight = self.model.embed_tokens.weight |
|
|
|
|
|
|
|
|
self._cached_device = None |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.model.get_input_embeddings() |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.model.set_input_embeddings(value) |
|
|
|
|
|
def get_output_embeddings(self): |
|
|
return self.lm_head |
|
|
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
|
self.lm_head = new_embeddings |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
labels: Optional[torch.LongTensor] = None, |
|
|
inference_params=None, |
|
|
num_last_tokens: int = 0, |
|
|
**kwargs, |
|
|
) -> Union[Tuple, CausalLMOutput]: |
|
|
""" |
|
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
|
            Labels for computing the language modeling loss. Indices should either be in `[0, ...,
|
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
|
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
is_prefill = ( |
|
|
labels is None |
|
|
and ( |
|
|
inference_params is None |
|
|
or getattr(inference_params, "seqlen_offset", 0) == 0 |
|
|
) |
|
|
and num_last_tokens == 0 |
|
|
) |
|
|
|
|
|
if is_prefill: |
|
|
num_last_tokens = 1 |
|
|
|
|
|
hidden_states = self.model( |
|
|
input_ids=input_ids, |
|
|
inputs_embeds=inputs_embeds, |
|
|
position_ids=position_ids, |
|
|
inference_params=inference_params, |
|
|
num_last_tokens=num_last_tokens, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
logits = self.lm_head(hidden_states) |
|
|
|
|
|
loss = None |
|
|
if labels is not None: |
|
|
|
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
|
shift_labels = labels[..., 1:].contiguous() |
|
|
|
|
|
loss_fct = CrossEntropyLoss() |
|
|
shift_logits = shift_logits.view(-1, self.config.vocab_size) |
|
|
shift_labels = shift_labels.view(-1) |
|
|
|
|
|
shift_labels = shift_labels.to(shift_logits.device) |
|
|
loss = loss_fct(shift_logits, shift_labels) |
|
|
|
|
|
return CausalLMOutput( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
) |
|
|
|
|
|
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): |
|
|
"""Allocate inference cache for all layers.""" |
|
|
return self.model.allocate_inference_cache( |
|
|
batch_size, max_seqlen, dtype=dtype, **kwargs |
|
|
) |
|
|
|
|
|
def generate( |
|
|
self, |
|
|
input_ids, |
|
|
max_length=1024, |
|
|
top_k=50, |
|
|
top_p=1.0, |
|
|
min_p=0.0, |
|
|
temperature=1.0, |
|
|
repetition_penalty=1.0, |
|
|
return_dict_in_generate=False, |
|
|
output_scores=False, |
|
|
**kwargs, |
|
|
): |
|
|
""" |
|
|
Generate sequences using the model. |
|
|
|
|
|
        Accepts the standard Transformers generation arguments, including:

        - do_sample, temperature, top_k, top_p, repetition_penalty
        - attention_mask, pad_token_id, eos_token_id
        - max_new_tokens

        Arguments that the mamba_ssm decode loop does not implement (e.g. num_beams,
        no_repeat_ngram_size, length_penalty, use_cache, stopping_criteria) are
        accepted for compatibility but ignored.
|
|
""" |
|
|
|
|
|
if input_ids.dim() == 1: |
|
|
input_ids = input_ids.unsqueeze(0) |
|
|
|
|
|
|
|
|
if self._cached_device is None: |
|
|
self._cached_device = next(self.parameters()).device |
|
|
device = self._cached_device |
|
|
|
|
|
|
|
|
|
|
|
if input_ids.device != device: |
|
|
input_ids = input_ids.to(device) |
|
|
|
|
|
if input_ids.dtype != torch.long: |
|
|
input_ids = input_ids.long() |
|
|
|
|
|
|
|
|
batch_size = input_ids.shape[0] |
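
        # Pop generation options out of **kwargs so that unsupported Hugging Face
        # arguments never reach mamba_ssm's decode loop.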
|
|
|
|
|
if kwargs is not None: |
|
|
max_new_tokens = kwargs.pop("max_new_tokens", None) |
|
|
if max_new_tokens is not None: |
|
|
max_length = max_new_tokens + input_ids.shape[1] |
|
|
|
|
|
do_sample = kwargs.pop("do_sample", True) |
|
|
if not do_sample: |
|
|
top_k, top_p, min_p = 1, 0.0, 0.0 |
|
|
|
|
|
cg = kwargs.pop("cg", True) |
|
|
|
|
|
eos_token_id = kwargs.pop("eos_token_id", self.config.eos_token_id) |
|
|
|
|
|
if eos_token_id is not None: |
|
|
if isinstance(eos_token_id, (list, tuple)): |
|
|
eos_token_id = torch.tensor( |
|
|
eos_token_id, dtype=torch.long, device=device |
|
|
) |
|
|
else: |
|
|
eos_token_id = torch.tensor( |
|
|
[eos_token_id], dtype=torch.long, device=device |
|
|
) |
|
|
|
|
|
attention_mask = kwargs.pop("attention_mask", None) |
|
|
pad_token_id = kwargs.pop( |
|
|
"pad_token_id", getattr(self.config, "pad_token_id", None) |
|
|
) |
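
        # mamba_ssm's decode loop takes no attention_mask. If the mask is all ones it
        # is simply dropped; otherwise, assuming right padding, the batch is rebuilt so
        # that each row keeps its first seq_length real tokens and the rest is
        # pad_token_id (or 0 when no pad token is configured).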
|
|
|
|
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
|
|
|
if attention_mask.all(): |
|
|
|
|
|
input_ids = input_ids.contiguous() |
|
|
else: |
|
|
|
|
|
seq_lengths = attention_mask.sum(dim=1) |
|
|
max_seq_len = seq_lengths.max().item() |
|
|
min_seq_len = seq_lengths.min().item() |
|
|
original_seq_len = input_ids.shape[1] |
|
|
|
|
|
|
|
|
if min_seq_len == max_seq_len and max_seq_len <= original_seq_len: |
|
|
input_ids = input_ids[:, :max_seq_len].contiguous() |
|
|
else: |
|
|
|
|
|
batch_size = input_ids.shape[0] |
|
|
dtype = input_ids.dtype |
|
|
pad_value = pad_token_id if pad_token_id is not None else 0 |
|
|
|
|
|
|
|
|
input_ids_filtered = torch.full( |
|
|
(batch_size, max_seq_len), |
|
|
pad_value, |
|
|
dtype=dtype, |
|
|
device=device, |
|
|
) |
|
|
|
|
|
|
|
|
copy_len = min(max_seq_len, original_seq_len) |
|
|
if copy_len > 0: |
|
|
|
|
|
|
|
|
valid_mask = torch.arange( |
|
|
copy_len, device=device |
|
|
).unsqueeze(0) < seq_lengths.unsqueeze(1) |
|
|
|
|
|
|
|
|
|
|
|
input_ids_slice = input_ids[:, :copy_len].contiguous() |
|
|
input_ids_filtered_slice = input_ids_filtered[:, :copy_len] |
|
|
|
|
|
|
|
|
|
|
|
input_ids_filtered[:, :copy_len] = torch.where( |
|
|
valid_mask, input_ids_slice, input_ids_filtered_slice |
|
|
) |
|
|
|
|
|
input_ids = input_ids_filtered.contiguous() |
|
|
|
|
|
|
|
|
repetition_penalty = kwargs.pop("repetition_penalty", repetition_penalty) |
|
|
|
|
|
|
|
|
|
|
|
use_cache = kwargs.pop( |
|
|
"use_cache", None |
|
|
) |
|
|
no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", None) |
|
|
length_penalty = kwargs.pop("length_penalty", None) |
|
|
num_return_sequences = kwargs.pop("num_return_sequences", None) |
|
|
num_beams = kwargs.pop("num_beams", None) |
|
|
low_memory = kwargs.pop("low_memory", None) |
|
|
stopping_criteria = kwargs.pop("stopping_criteria", None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
max_seqlen = max_length |
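
        # Best-effort pre-allocation of the inference cache (surfaces OOM early for
        # long max_length); the returned cache is intentionally discarded and any
        # allocation failure is ignored.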
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
model_dtype = next(self.parameters()).dtype |
|
|
|
|
|
_ = self.allocate_inference_cache( |
|
|
batch_size=batch_size, |
|
|
max_seqlen=max_seqlen, |
|
|
dtype=model_dtype, |
|
|
) |
|
|
except Exception: |
|
|
|
|
|
|
|
|
pass |
|
|
|
|
|
return super().generate( |
|
|
input_ids=input_ids, |
|
|
max_length=max_length, |
|
|
cg=cg, |
|
|
top_k=top_k, |
|
|
top_p=top_p, |
|
|
min_p=min_p, |
|
|
temperature=temperature, |
|
|
repetition_penalty=repetition_penalty, |
|
|
return_dict_in_generate=return_dict_in_generate, |
|
|
output_scores=output_scores, |
|
|
eos_token_id=eos_token_id, |
|
|
**kwargs, |
|
|
) |
|
|
|