File size: 38,222 Bytes
1e2779e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 |
# coding=utf-8
# Copyright (c) 2025, Qwerky AI, Inc. All rights reserved.
#
# Licensed under the Qwerky Distilled Model License Agreement (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# See the LICENSE file in this repository
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch QwerkyLlamaMambaHybrid model for inference only."""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.modeling_outputs import CausalLMOutput
from mamba_ssm.ops.triton.layer_norm import RMSNorm
from mamba_ssm.modules.mha import MHA
from mamba_ssm.utils.generation import GenerationMixin as MambaGenerationMixin
from transformers.activations import ACT2FN
# Import Mamba dependencies
import math
import torch.nn.functional as F
from einops import rearrange, repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
from .configuration_qwerky_llama_mamba_hybrid import QwerkyLlamaMambaHybridConfig
# Module-level logger from transformers' logging utilities.
logger = logging.get_logger(__name__)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every head ``n_rep`` times along dim 1, keeping copies of a head adjacent.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When ``n_rep == 1``
    the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # unsqueeze + expand creates a broadcast view; reshape materialises the copies.
    widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
# Mamba class implementation (included directly for standalone HuggingFace repo)
class Mamba(nn.Module):
    """Selective state-space (Mamba) mixer with grouped x/B projections.

    ``in_proj`` maps ``d_model`` to ``[z, x, B, C, dt]`` with widths
    ``[d_inner, d_xb, d_xb, d_inner, dt_rank]``. ``x`` and ``B`` are produced at
    the narrower width ``d_xb``. They are repeated ``repeat_group`` times (via
    ``repeat_kv`` / ``torch.repeat_interleave``) to match the ``d_inner``-wide
    ``z``/``C``/``dt`` streams. ``x`` is repeated before or after the short causal
    convolution, depending on ``repeat_kv_before_conv``.

    Args:
        d_model: Input/output hidden size.
        d_inner: Width of z/C/dt; defaults to ``int(expand * d_model)`` when None.
        d_xb: Width of the x and B projections.
        d_state: SSM state size; also the per-group width used to split x/B/C.
        d_conv: Kernel width of the depthwise causal convolution.
        expand: Only used to derive ``d_inner`` when it is None.
        dt_rank: Rank of the dt projection; ``"auto"`` means ``ceil(d_model / 16)``.
        dt_min, dt_max, dt_init, dt_scale, dt_init_floor: Initialization of ``dt_proj``.
        repeat_kv_before_conv: If True, x is repeated to ``d_inner`` before the
            convolution (conv has ``d_inner`` channels). Otherwise it is repeated
            after (conv has ``d_xb`` channels).
        conv_bias: Whether the convolution has a bias.
        proj_x_bias, proj_z_bias: Accepted but not read anywhere in this class.
        out_proj_bias: Whether ``out_proj`` has a bias.
        use_fast_path: Stored on the instance; not read by this class's methods.
        layer_idx: Key of this layer's states in ``inference_params.key_value_memory_dict``.
        device, dtype: Forwarded to the submodule constructors.
    """

    def __init__(
        self,
        d_model,
        d_inner,
        d_xb,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        repeat_kv_before_conv=True,
        conv_bias=True,
        proj_x_bias=False,
        proj_z_bias=False,
        out_proj_bias=False,
        use_fast_path=True,
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_xb = d_xb
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = (
            d_inner if d_inner is not None else int(self.expand * self.d_model)
        )
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.repeat_kv_before_conv = repeat_kv_before_conv
        # Depthwise conv (groups == channels). padding=d_conv - 1, together with
        # the truncation to seqlen in forward(), makes it causal. The channel
        # count depends on where x is repeated up to d_inner.
        if self.repeat_kv_before_conv:
            self.conv1d = nn.Conv1d(
                in_channels=self.d_inner,
                out_channels=self.d_inner,
                bias=conv_bias,
                kernel_size=d_conv,
                groups=self.d_inner,
                padding=d_conv - 1,
                **factory_kwargs,
            )
        else:
            self.conv1d = nn.Conv1d(
                in_channels=self.d_xb,
                out_channels=self.d_xb,
                bias=conv_bias,
                kernel_size=d_conv,
                groups=self.d_xb,
                padding=d_conv - 1,
                **factory_kwargs,
            )
        self.activation = "silu"
        self.act = nn.SiLU()
        # Group bookkeeping: x/B split into num_xb_head groups of width d_state,
        # C/z/dt into num_C_head groups. Integer division, so presumably d_xb and
        # d_inner are multiples of d_state and num_C_head is a multiple of
        # num_xb_head (not checked here).
        self.num_xb_head = self.d_xb // self.d_state
        self.num_C_head = self.d_inner // self.d_state
        self.repeat_group = self.num_C_head // self.num_xb_head
        # Single projection for z, x, B, C and the low-rank dt input (split in
        # forward()/step() with widths [d_inner, d_xb, d_xb, d_inner, dt_rank]).
        self.in_proj = nn.Linear(
            self.d_model,
            2 * self.d_xb + 2 * self.d_inner + self.dt_rank,
            bias=False,
            **factory_kwargs,
        )
        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_inner, bias=True, **factory_kwargs
        )
        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True
        # S4D real initialization: row i of A is [1, 2, ..., d_state]; A = -exp(A_log) at use.
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True
        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True
        self.out_proj = nn.Linear(
            self.d_inner, self.d_model, bias=out_proj_bias, **factory_kwargs
        )

    def forward(self, hidden_states, inference_params=None):
        """
        Run the mixer over a full sequence, or decode one token from cached states.

        hidden_states: (B, L, D)
        Returns: same shape as hidden_states

        When ``inference_params`` is given, this layer's (conv_state, ssm_state)
        is fetched from ``inference_params.key_value_memory_dict`` (created on
        first use). If ``inference_params.seqlen_offset > 0`` the call is
        delegated to ``step``, which accepts exactly one token. Otherwise the
        whole sequence is processed and the final conv/SSM states are copied
        into the cached tensors.
        """
        batch, seqlen, dim = hidden_states.shape
        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # Ensure input is contiguous before the projection
        if not hidden_states.is_contiguous():
            hidden_states = hidden_states.contiguous()
        zxbcdt = self.in_proj(hidden_states)
        z, x, B, C, dt = torch.split(
            zxbcdt,
            [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank],
            dim=-1,
        )
        x = rearrange(x, "b l d -> b d l")
        z = rearrange(z, "b l d -> b d l")
        # B: d_xb -> num_xb_head groups, each repeated repeat_group times (adjacent copies).
        B = rearrange(
            B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state
        )
        B = repeat_kv(B, self.repeat_group)  # B, n_group, L, H
        B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous()
        C = rearrange(
            C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state
        ).contiguous()
        # NOTE(review): self.dt_proj(dt) already adds dt_proj.bias, and the same
        # bias is passed again below as delta_bias, so it is applied twice.
        # step() does the same, so this presumably matches how the checkpoint
        # was trained; confirm before changing either path.
        dt = self.dt_proj(dt)  # B, L, d_inner
        dt = rearrange(dt, "b l d -> b d l")  # B, d_inner, L
        if self.repeat_kv_before_conv:
            # b d l
            x = rearrange(
                x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state
            )
            x = repeat_kv(x, self.repeat_group)
            x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
        # Compute short convolution
        # The conv state only needs saving when a cache is in use (generation).
        need_state_update = conv_state is not None
        if need_state_update:
            # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
            # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
            # Update state (B D W)
            conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))
        if causal_conv1d_fn is None:
            # Fallback: symmetric padding, keep the first seqlen outputs -> causal conv.
            x = self.act(self.conv1d(x)[..., :seqlen])
        else:
            assert self.activation in ["silu", "swish"]
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias,
                activation=self.activation,
            )
        if not self.repeat_kv_before_conv:
            x = rearrange(
                x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state
            )
            x = repeat_kv(x, self.repeat_group)
            x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
        assert self.activation in ["silu", "swish"]
        # Only ask the scan for its final state when there is a cache to write it to.
        return_last_state = ssm_state is not None
        y = selective_scan_fn(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=z,
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
            return_last_state=return_last_state,
        )
        if return_last_state:
            y, last_state = y
            # Cache layout is (batch, num_C_head, d_state-wide head, d_state); see
            # allocate_inference_cache / _get_states_from_cache.
            ssm_state.copy_(
                rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)
            )
        y = rearrange(y, "b d l -> b l d")
        out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        """
        Decode a single token using (and updating) the cached states.

        hidden_states: (B, 1, D); asserted to contain exactly one position.
        Returns ``(out, conv_state, ssm_state)`` where ``out`` is (B, 1, d_model).
        In the non-kernel conv path, ``conv_state`` is updated in place here.
        The causal_conv1d_update / selective_state_update kernels presumably
        update their state arguments in place as well (forward() relies on
        this: "The states are updated inplace").
        """
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, (
            "Only support decoding with 1 token at a time for now"
        )
        hidden_states_input = hidden_states.squeeze(1)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        zxbcdt = self.in_proj(hidden_states_input)
        z, x, B, C, dt = torch.split(
            zxbcdt,
            [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank],
            dim=-1,
        )
        # Same adjacent-copy group ordering as repeat_kv in forward().
        B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
        B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group)
        C = rearrange(
            C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state
        ).contiguous()
        # Includes dt_proj.bias; dt_bias is also passed to selective_state_update below.
        dt = self.dt_proj(dt)  # B, d_inner
        if self.repeat_kv_before_conv:
            x = rearrange(
                x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state
            )
            x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group)
            x = rearrange(x, "b n_group dstate -> b (n_group dstate)")
        # Conv step
        if causal_conv1d_update is None:
            # Update state (B D W): shift left by one and append the new input.
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
            conv_state[:, :, -1] = x
            x = torch.sum(
                conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
            )  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )
        if not self.repeat_kv_before_conv:
            x = rearrange(
                x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state
            )
            x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group)
            x = rearrange(x, "b n_group dstate -> b (n_group dstate)")
        # Reshape everything to the per-head layout expected by selective_state_update.
        x = rearrange(x, "b (h d) -> b h d", h=self.num_C_head)
        dt = rearrange(dt, "b (h d) -> b h d", h=self.num_C_head)
        A = rearrange(A, "(h d) n -> h d n", h=self.num_C_head)
        D = rearrange(self.D, "(h d) -> h d", h=self.num_C_head)
        z = rearrange(z, "b (h d) -> b h d", h=self.num_C_head)
        dt_bias = rearrange(self.dt_proj.bias, "(h d) -> h d", h=self.num_C_head)
        # SSM step (no pure-PyTorch fallback: the Triton kernel is required)
        assert selective_state_update is not None
        y = selective_state_update(
            ssm_state, x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=True
        )
        y = rearrange(y, "b h d -> b (h d)")
        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """
        Allocate zero-filled ``(conv_state, ssm_state)`` for ``batch_size`` sequences.

        conv_state: (batch, d_inner or d_xb, d_conv), depending on ``repeat_kv_before_conv``.
        ssm_state: (batch, num_C_head, d_inner // num_C_head, d_state).
        ``max_seqlen`` is not used (state size is independent of length). ``dtype``
        overrides the conv/dt_proj weight dtypes when given.
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        if self.repeat_kv_before_conv:
            conv_state = torch.zeros(
                batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype
            )
        else:
            conv_state = torch.zeros(
                batch_size, self.d_xb, self.d_conv, device=device, dtype=conv_dtype
            )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        ssm_state = torch.zeros(
            batch_size,
            self.num_C_head,
            self.d_inner // self.num_C_head,
            self.d_state,
            device=device,
            dtype=ssm_dtype,
        )
        return conv_state, ssm_state

    def _get_states_from_cache(
        self, inference_params, batch_size, initialize_states=False
    ):
        """
        Return this layer's ``(conv_state, ssm_state)`` from ``inference_params``.

        On first use, zero states (same shapes as ``allocate_inference_cache``,
        dtypes taken from the conv/dt_proj weights) are created and stored under
        ``self.layer_idx``. Existing states are returned as-is, or zeroed in place
        when ``initialize_states`` is True. Requires ``layer_idx`` to be set.
        """
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            if self.repeat_kv_before_conv:
                conv_state = torch.zeros(
                    batch_size,
                    self.d_inner,
                    self.d_conv,
                    device=self.conv1d.weight.device,
                    dtype=self.conv1d.weight.dtype,
                )
            else:
                conv_state = torch.zeros(
                    batch_size,
                    self.d_xb,
                    self.d_conv,
                    device=self.conv1d.weight.device,
                    dtype=self.conv1d.weight.dtype,
                )
            ssm_state = torch.zeros(
                batch_size,
                self.num_C_head,
                self.d_inner // self.num_C_head,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (
                conv_state,
                ssm_state,
            )
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[
                self.layer_idx
            ]
            # NOTE(review): a cached entry is reused even if batch_size differs.
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
class MLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    Args:
        d_model: Hidden size (input and output width).
        intermediate_size: Width of the gated inner layer.
        hidden_act: Key into ``ACT2FN`` selecting the gate activation.
        device, dtype: Forwarded to the three bias-free ``nn.Linear`` layers.
    """

    def __init__(self, d_model, intermediate_size, hidden_act, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.hidden_size = d_model
        self.intermediate_size = intermediate_size

        def _dense(in_features, out_features):
            # All projections are bias-free and share device/dtype.
            return nn.Linear(in_features, out_features, bias=False, **factory_kwargs)

        # Creation order (gate, up, down) is kept so random init is unchanged.
        self.gate_proj = _dense(self.hidden_size, self.intermediate_size)
        self.up_proj = _dense(self.hidden_size, self.intermediate_size)
        self.down_proj = _dense(self.intermediate_size, self.hidden_size)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class MHADecoderLayer(nn.Module):
    """Pre-norm attention decoder block.

    ``x + mha(input_layernorm(x))`` followed by ``h + mlp(post_attention_layernorm(h))``.
    The attention is ``mamba_ssm``'s causal ``MHA`` (grouped KV heads) with rotary
    embeddings over the full head dimension ``hidden_size // num_attention_heads``.
    """

    def __init__(
        self,
        config: QwerkyLlamaMambaHybridConfig,
        layer_idx: int,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.layer_idx = layer_idx
        hidden = config.hidden_size
        head_dim = hidden // config.num_attention_heads
        # Submodule creation order is significant for random init; keep it.
        self.mha = MHA(
            embed_dim=hidden,
            num_heads=config.num_attention_heads,
            num_heads_kv=config.num_key_value_heads,
            layer_idx=layer_idx,
            mlp_dim=0,
            qkv_proj_bias=False,
            out_proj_bias=False,
            rotary_emb_dim=head_dim,
            rotary_emb_base=config.rope_theta,
            causal=True,
            **factory_kwargs,
        )
        self.mlp = MLP(hidden, config.intermediate_size, config.hidden_act, **factory_kwargs)
        self.input_layernorm = RMSNorm(hidden, eps=config.rms_norm_eps, **factory_kwargs)
        self.post_attention_layernorm = RMSNorm(
            hidden, eps=config.rms_norm_eps, **factory_kwargs
        )
        # Set but not read by forward() in this class (residuals are added in
        # the incoming dtype).
        self.residual_in_fp32 = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate KV-cache allocation to the wrapped ``MHA`` module."""
        return self.mha.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self, hidden_states: torch.Tensor, inference_params=None, *args, **kwargs
    ):
        attn_out = self.mha(self.input_layernorm(hidden_states), inference_params)
        hidden_states = hidden_states + attn_out
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + ffn_out
class MambaDecoderLayer(nn.Module):
    """Pre-norm decoder block whose token mixer is ``Mamba``.

    ``x + mamba(input_layernorm(x))`` followed by ``h + mlp(post_attention_layernorm(h))``.
    ``Mamba`` is built from ``config.d_model``/``d_inner``/``d_xb`` plus the extra
    keyword arguments in ``config.ssm_cfg``.
    """

    def __init__(
        self, config: QwerkyLlamaMambaHybridConfig, layer_idx: int, device=None, dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.layer_idx = layer_idx
        width = config.d_model
        eps = config.rms_norm_eps
        # Submodule creation order is significant for random init; keep it.
        self.mamba = Mamba(
            d_model=width,
            d_inner=config.d_inner,
            d_xb=config.d_xb,
            layer_idx=layer_idx,
            **config.ssm_cfg,
            **factory_kwargs,
        )
        self.mlp = MLP(width, config.intermediate_size, config.hidden_act, **factory_kwargs)
        self.input_layernorm = RMSNorm(width, eps=eps, **factory_kwargs)
        self.post_attention_layernorm = RMSNorm(width, eps=eps, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate (conv_state, ssm_state) allocation to the wrapped ``Mamba`` mixer."""
        return self.mamba.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self, hidden_states: torch.Tensor, inference_params=None, *args, **kwargs
    ):
        mixed = self.mamba(
            self.input_layernorm(hidden_states), inference_params=inference_params
        )
        hidden_states = hidden_states + mixed
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + ffn_out
def merge_projections_for_layers(checkpoint, layer_indices, prefix="model."):
    """Rewrite Llama-style attention weights into the ``mha`` layout, in place.

    For every index ``i`` in ``layer_indices``:

    * ``{prefix}layers.{i}.self_attn.{q,k,v}_proj.weight`` are concatenated along
      dim 0 (q, then k, then v) into ``{prefix}layers.{i}.mha.in_proj.weight``.
      This only happens when all three are present; a partial set is left untouched.
    * ``{prefix}layers.{i}.self_attn.o_proj.weight`` is renamed to
      ``{prefix}layers.{i}.mha.out_proj.weight`` whenever it is present. This is
      independent of q/k/v, so it still happens when the mapping passed in holds
      only part of a layer's weights (e.g. one shard of a split checkpoint).

    Args:
        checkpoint: Mutable mapping of parameter name to tensor; modified in place.
        layer_indices: Iterable of attention-layer indices.
        prefix: Key prefix in front of ``layers.``; defaults to ``"model."``.

    Returns:
        The same ``checkpoint`` object.
    """
    for layer_idx in layer_indices:
        src = f"{prefix}layers.{layer_idx}.self_attn."
        dst = f"{prefix}layers.{layer_idx}.mha."
        qkv_keys = [f"{src}{name}_proj.weight" for name in ("q", "k", "v")]
        if all(key in checkpoint for key in qkv_keys):
            checkpoint[f"{dst}in_proj.weight"] = torch.cat(
                [checkpoint.pop(key) for key in qkv_keys], dim=0
            )
        o_proj_key = f"{src}o_proj.weight"
        if o_proj_key in checkpoint:
            checkpoint[f"{dst}out_proj.weight"] = checkpoint.pop(o_proj_key)
    return checkpoint
class QwerkyLlamaMambaHybridPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = QwerkyLlamaMambaHybridConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["MambaDecoderLayer", "MHADecoderLayer"]
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize the weights of ``module``.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        N(0, ``config.initializer_range``). Linear biases are zeroed, except
        biases flagged with ``_no_reinit``. ``Mamba`` sets that flag on
        ``dt_proj.bias`` so its softplus-inverse dt initialization survives.
        Other module types are left unchanged.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            bias = module.bias
            if bias is not None and not getattr(bias, "_no_reinit", False):
                bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
class QwerkyLlamaMambaHybridModel(QwerkyLlamaMambaHybridPreTrainedModel):
    """
    The bare QwerkyLlamaMambaHybrid Model transformer outputting raw hidden-states without any specific head on top.

    Layer ``i`` is an attention block (``MHADecoderLayer``) when ``i`` is in
    ``config.attn_layers`` and a Mamba block (``MambaDecoderLayer``) otherwise.
    """

    def __init__(self, config: QwerkyLlamaMambaHybridConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        def build_layer(idx):
            # Pick the block type for position idx in the hybrid stack.
            block_cls = (
                MHADecoderLayer if idx in config.attn_layers else MambaDecoderLayer
            )
            return block_cls(config, idx, device=None, dtype=None)

        self.layers = nn.ModuleList(
            [build_layer(idx) for idx in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Pre-load hook: turns Llama q/k/v/o projection keys into mha.in_proj /
        # mha.out_proj keys for the attention layers.
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.post_init()

    def load_hook(self, state_dict, prefix, *args):
        """Merge q_proj/k_proj/v_proj (and rename o_proj) into ``mha`` keys, in place.

        Does nothing when ``config.attn_layers`` is empty.
        """
        attn_layers = self.config.attn_layers
        if not attn_layers:
            return
        merge_projections_for_layers(state_dict, attn_layers)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inference_params=None,
        num_last_tokens: int = 0,
        **kwargs,
    ):
        """Embed (unless ``inputs_embeds`` is given), run every layer, apply the final norm.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given, otherwise
        ``ValueError`` is raised. ``position_ids`` is accepted but not used.
        ``inference_params`` and extra ``kwargs`` are passed to every layer. When
        ``num_last_tokens > 0`` only that many trailing positions are returned.
        """
        if (input_ids is None) == (inputs_embeds is None):
            if input_ids is not None:
                raise ValueError(
                    "You cannot specify both input_ids and inputs_embeds at the same time"
                )
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        embeds = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds
        # .contiguous() returns the tensor itself when it is already contiguous.
        hidden_states = embeds.contiguous()
        for block in self.layers:
            hidden_states = block(
                hidden_states, inference_params=inference_params, **kwargs
            ).contiguous()
        hidden_states = self.norm(hidden_states)
        if num_last_tokens <= 0:
            return hidden_states
        return hidden_states[:, -num_last_tokens:]

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate inference cache for all layers, keyed by layer index."""
        caches = {}
        for idx, block in enumerate(self.layers):
            caches[idx] = block.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
        return caches
class QwerkyLlamaMambaHybridForCausalLM(
    QwerkyLlamaMambaHybridPreTrainedModel, MambaGenerationMixin
):
    """
    The QwerkyLlamaMambaHybrid Model transformer with a language modeling head on top (linear layer whose weight is
    tied to the input embeddings when ``config.tie_word_embeddings`` is set).

    Text generation is delegated to ``mamba_ssm``'s ``GenerationMixin`` via ``generate``.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: QwerkyLlamaMambaHybridConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = QwerkyLlamaMambaHybridModel(config, **kwargs)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Tie weights if configured
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inference_params=None,
        num_last_tokens: int = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutput]:
        """
        Compute next-token logits (and the LM loss when ``labels`` is given).

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        num_last_tokens: When > 0, only the last ``num_last_tokens`` positions are projected to logits.

        A plain call (no ``inference_params``) returns logits for every position. During a cached prefill
        (``inference_params`` given with ``seqlen_offset == 0``, no labels, ``num_last_tokens == 0``) only the
        last position is projected, because only it is needed to emit the first new token.
        """
        # Restrict the last-token shortcut to cached generation. Applying it to
        # uncached calls silently dropped all but the last position's logits,
        # which breaks per-token scoring (e.g. log-likelihood evaluation).
        if (
            labels is None
            and num_last_tokens == 0
            and inference_params is not None
            and getattr(inference_params, "seqlen_offset", 0) == 0
        ):
            num_last_tokens = 1
        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=num_last_tokens,
            **kwargs,
        )
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens; move labels to the logits' device (model parallelism).
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)
        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate inference cache for all layers."""
        return self.model.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    @staticmethod
    def _compact_by_attention_mask(input_ids, attention_mask, pad_value):
        """Remove masked-out positions from each row of ``input_ids``.

        Valid tokens (nonzero ``attention_mask``) keep their relative order and are
        moved to the front of their row. Rows shorter than the longest one are
        right-padded with ``pad_value``. This is correct for both left- and
        right-padded batches.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len) mask, nonzero for real tokens.
            pad_value: Token id used to fill rows shorter than the longest valid row.

        Returns:
            (batch, max_valid_len) tensor with ``input_ids``' device and dtype.

        Raises:
            ValueError: If ``attention_mask`` and ``input_ids`` shapes differ.
        """
        if attention_mask.shape != input_ids.shape:
            raise ValueError(
                f"attention_mask shape {tuple(attention_mask.shape)} does not match "
                f"input_ids shape {tuple(input_ids.shape)}"
            )
        mask = attention_mask.to(device=input_ids.device).bool()
        max_seq_len = int(mask.sum(dim=1).max().item())
        compacted = torch.full(
            (input_ids.shape[0], max_seq_len),
            pad_value,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        # Destination column of a valid token = number of valid tokens before it in its row.
        dest_cols = mask.long().cumsum(dim=1) - 1
        rows, src_cols = mask.nonzero(as_tuple=True)
        compacted[rows, dest_cols[rows, src_cols]] = input_ids[rows, src_cols]
        return compacted

    def generate(
        self,
        input_ids,
        max_length=1024,
        top_k=50,
        top_p=1.0,
        min_p=0.0,
        temperature=1.0,
        repetition_penalty=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        """
        Generate continuations with ``mamba_ssm``'s ``GenerationMixin``.

        ``input_ids`` may be 1-D (a batch dimension is added). It is moved to the
        model's device and cast to ``torch.long``. Transformers-style keyword
        arguments are translated before delegating:

        - ``max_new_tokens``: overrides ``max_length`` with ``prompt_len + max_new_tokens``.
        - ``do_sample=False``: greedy settings (``top_k=1``, ``top_p=0.0``, ``min_p=0.0``).
        - ``eos_token_id`` (int or list; defaults to ``config.eos_token_id``): converted to a
          ``torch.long`` tensor on the model device. NOTE(review): whether generation stops on
          *any* of several EOS ids depends on how mamba_ssm compares tokens with this tensor.
        - ``attention_mask``: masked-out positions are removed from each row (left or right
          padding). Rows of different valid length are right-padded with ``pad_token_id``
          (or 0). The mask itself is not forwarded, so such rows presumably continue after
          their pad tokens.
        - ``cg`` (default True): forwarded unchanged (presumably CUDA-graph decoding).
        - ``use_cache``, ``no_repeat_ngram_size``, ``length_penalty``, ``num_return_sequences``,
          ``num_beams``, ``low_memory``, ``stopping_criteria``: accepted and ignored.

        All other keyword arguments are forwarded unchanged.
        """
        # Ensure input_ids is properly shaped (2D: batch_size, seq_len)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        # Look the device up on every call: a cached value goes stale if the model
        # is moved (.to()/.cuda()) between generate() calls.
        device = next(self.parameters()).device
        if input_ids.device != device:
            input_ids = input_ids.to(device)
        if input_ids.dtype != torch.long:
            input_ids = input_ids.long()
        batch_size = input_ids.shape[0]

        max_new_tokens = kwargs.pop("max_new_tokens", None)
        if max_new_tokens is not None:
            max_length = max_new_tokens + input_ids.shape[1]
        if not kwargs.pop("do_sample", True):
            top_k, top_p, min_p = 1, 0.0, 0.0
        cg = kwargs.pop("cg", True)

        eos_token_id = kwargs.pop("eos_token_id", self.config.eos_token_id)
        if eos_token_id is not None:
            # mamba_ssm compares generated tokens against a tensor.
            if not isinstance(eos_token_id, (list, tuple)):
                eos_token_id = [eos_token_id]
            eos_token_id = torch.tensor(eos_token_id, dtype=torch.long, device=device)

        attention_mask = kwargs.pop("attention_mask", None)
        pad_token_id = kwargs.pop(
            "pad_token_id", getattr(self.config, "pad_token_id", None)
        )
        if attention_mask is not None:
            if attention_mask.dim() == 1:
                attention_mask = attention_mask.unsqueeze(0)
            if attention_mask.all():
                # Nothing is masked out: no filtering needed.
                input_ids = input_ids.contiguous()
            else:
                input_ids = self._compact_by_attention_mask(
                    input_ids,
                    attention_mask,
                    pad_token_id if pad_token_id is not None else 0,
                )

        # Transformers-only options that MambaGenerationMixin.generate() does not accept.
        for unsupported in (
            "use_cache",
            "no_repeat_ngram_size",
            "length_penalty",
            "num_return_sequences",
            "num_beams",
            "low_memory",
            "stopping_criteria",
        ):
            kwargs.pop(unsupported, None)

        # Best-effort warm-up of the allocator with a cache of the final size.
        # The result is intentionally NOT kept: holding a reference during
        # super().generate() would keep a second full cache alive.
        try:
            self.allocate_inference_cache(
                batch_size=batch_size,
                max_seqlen=max_length,
                dtype=next(self.parameters()).dtype,
            )
        except Exception:
            logger.debug(
                "Inference-cache pre-allocation failed; continuing without it.",
                exc_info=True,
            )

        return super().generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=cg,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            return_dict_in_generate=return_dict_in_generate,
            output_scores=output_scores,
            eos_token_id=eos_token_id,
            **kwargs,
        )
|