# modeling_shivik_m1.py (PATCHED)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
class ShivikM1V3Config(PretrainedConfig):
# keep model_type stable so HF knows what this is
model_type = "shivik_m1"
def __init__(
self,
vocab_size=49156,
d_model=2048,
n_layers=24,
num_heads=16,
rotary_dim=128,
context_length=4096,
# legacy / generation-friendly aliases (kept in config for compatibility)
**kwargs,
):
super().__init__(**kwargs)
# core params
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layers = n_layers
self.num_heads = num_heads
self.rotary_dim = rotary_dim
self.context_length = context_length
# Generation compatibility fields (Transformers internals expect these)
# Keep several aliases so both old and new code find a supported name
self.num_hidden_layers = kwargs.get("num_hidden_layers", n_layers)
self.num_layers = kwargs.get("num_layers", n_layers)
self.n_layer = kwargs.get("n_layer", n_layers)
self.layer_types = kwargs.get("layer_types", ["full_attention"] * n_layers)
self.num_kv_shared_layers = kwargs.get("num_kv_shared_layers", 0)
self.use_cache = kwargs.get("use_cache", True)
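# Illustrative sketch only (not part of this file's behaviour): custom configs like this are
# usually exposed to AutoConfig / AutoModelForCausalLM via an "auto_map" entry in config.json
# so the repo can be loaded with trust_remote_code=True. The mapping below is an assumption
# about how this repo could be wired, not a statement of how it is wired.
#
#   cfg = ShivikM1V3Config()
#   cfg.auto_map = {
#       "AutoConfig": "modeling_shivik_m1.ShivikM1V3Config",
#       "AutoModelForCausalLM": "modeling_shivik_m1.ShivikM1V3ForCausalLM",
#   }
#   cfg.save_pretrained("./shivik-m1-v3-4")  # writes config.json including the auto_map entry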
class RMSNorm(nn.Module):
def __init__(self, d, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d))
def forward(self, x):
norm = x.pow(2).mean(-1, keepdim=True)
return x * torch.rsqrt(norm + self.eps) * self.weight
def apply_rope(x, cos, sin):
    # x: (..., seq_len, rotary_dim); cos/sin broadcast against (..., seq_len, rotary_dim/2).
    # Adjacent channel pairs (x[..., 2i], x[..., 2i+1]) are rotated together.
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    # rotate each pair, then re-interleave the halves via stack + reshape
    xr = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return xr.reshape(x.shape)
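# Quick sanity check for apply_rope (illustrative only; shapes are arbitrary): a zero rotation
# angle (cos=1, sin=0) must leave the input unchanged, and the output shape must match the input.
#
#   x = torch.randn(2, 4, 8, 16)   # (B, heads, T, rotary_dim)
#   cos = torch.ones(8, 8)         # (T, rotary_dim/2)
#   sin = torch.zeros(8, 8)
#   assert torch.allclose(apply_rope(x, cos, sin), x)
#   assert apply_rope(x, cos, sin).shape == x.shape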
class Attention(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.head_dim = cfg.d_model // cfg.num_heads
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
self.out = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
def split_heads(self, x):
B, T, C = x.shape
return x.view(B, T, self.cfg.num_heads, self.head_dim).transpose(1, 2)
def forward(self, x, cos, sin, mask, past=None):
B, T, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)
        rd = self.cfg.rotary_dim
        if rd > 0:
            # cos/sin have shape (T, rd/2); unsqueeze to (1, 1, T, rd/2) so they broadcast
            # against the even/odd halves of q[..., :rd] / k[..., :rd], which apply_rope
            # splits into shape (B, heads, T, rd/2).
            cos_b = cos.unsqueeze(0).unsqueeze(0)
            sin_b = sin.unsqueeze(0).unsqueeze(0)
            q = torch.cat([apply_rope(q[..., :rd], cos_b, sin_b), q[..., rd:]], dim=-1)
            k = torch.cat([apply_rope(k[..., :rd], cos_b, sin_b), k[..., rd:]], dim=-1)
if past is not None:
pk, pv = past
if pk is not None:
k = torch.cat([pk, k], dim=2)
if pv is not None:
v = torch.cat([pv, v], dim=2)
present = (k, v)
dk = q.shape[-1]
        # attention scores: (B, heads, T, T_total), where T_total includes any cached positions
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dk)
        # mask: (1, 1, T, T_total), broadcast over batch and heads
        scores = scores.masked_fill(~mask, float("-inf"))
att = torch.softmax(scores, dim=-1)
out = torch.matmul(att, v).transpose(1, 2).reshape(B, T, C)
return self.out(out), present
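# For reference (an equivalence sketch, not used by this module): with a boolean attention mask,
# the matmul / masked_fill / softmax / matmul sequence above computes the same result (before the
# transpose/reshape and output projection) as torch.nn.functional.scaled_dot_product_attention:
#
#   out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
#
# which would also pick up fused kernels where available.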
class SwiGLU(nn.Module):
def __init__(self, d):
super().__init__()
self.w1 = nn.Linear(d, 4 * d, bias=False)
self.w2 = nn.Linear(d, 4 * d, bias=False)
self.w3 = nn.Linear(4 * d, d, bias=False)
def forward(self, x):
return self.w3(F.silu(self.w1(x)) * self.w2(x))
class Block(nn.Module):
def __init__(self, cfg):
super().__init__()
self.norm1 = RMSNorm(cfg.d_model)
self.att = Attention(cfg)
self.norm2 = RMSNorm(cfg.d_model)
self.mlp = SwiGLU(cfg.d_model)
def forward(self, x, cos, sin, mask, past=None):
h, present = self.att(self.norm1(x), cos, sin, mask, past)
x = x + h
x = x + self.mlp(self.norm2(x))
return x, present
class ShivikM1V3Model(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
# position embedding (kept as parameter)
self.pos = nn.Parameter(torch.zeros(1, cfg.context_length, cfg.d_model))
mask = torch.tril(torch.ones(cfg.context_length, cfg.context_length)).bool()
self.register_buffer("mask", mask.unsqueeze(0).unsqueeze(0))
t = torch.arange(cfg.context_length)
# rotary frequencies: create half-dim angles (matching even/odd packing)
freqs = 1.0 / (10000 ** (torch.arange(0, cfg.rotary_dim, 2) / cfg.rotary_dim))
angles = torch.einsum("i,j->ij", t.float(), freqs.float()) # (T, rd/2)
# register cos/sin as (T, rd/2) and cast later by loading code if needed
self.register_buffer("cos", angles.cos())
self.register_buffer("sin", angles.sin())
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
self.norm = RMSNorm(cfg.d_model)
self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
# tie weights
self.lm_head.weight = self.embed.weight
    def forward(self, input_ids, past_kvs=None, use_cache=False, **kwargs):
        """
        Returns CausalLMOutputWithCrossAttentions to be compatible with .generate().
        past_kvs should be an iterable of per-layer (k, v) tuples, or None.
        """
        B, T = input_ids.shape
        # Offset positions by the length of any cached keys so the position embedding,
        # rotary angles and causal mask stay aligned during incremental decoding.
        past_len = 0
        if past_kvs is not None and past_kvs[0] is not None and past_kvs[0][0] is not None:
            past_len = past_kvs[0][0].shape[2]
        x = self.embed(input_ids) + self.pos[:, past_len:past_len + T]
        # (1, 1, T, past_len + T): each new position attends to all cached and earlier positions
        mask = self.mask[:, :, past_len:past_len + T, :past_len + T]
        cos = self.cos[past_len:past_len + T]  # (T, rd/2)
        sin = self.sin[past_len:past_len + T]  # (T, rd/2)
        if past_kvs is None:
            past_kvs = [None] * len(self.blocks)
        presents = []
        for block, p in zip(self.blocks, past_kvs):
            x, kv = block(x, cos, sin, mask, p)
            presents.append(kv)
        x = self.norm(x)
        logits = self.lm_head(x)
        # expose the per-layer (k, v) tuples as past_key_values when caching is requested
        past_key_values = tuple(presents) if use_cache else None
        return CausalLMOutputWithCrossAttentions(
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
            cross_attentions=None,
        )
class ShivikM1V3ForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = ShivikM1V3Config
    base_model_prefix = "shivik_m1_v3"
    def __init__(self, config):
        super().__init__(config)
        # Allow any of the layer-count aliases to drive model depth, then keep all of
        # them in sync so downstream code sees a consistent config.
        n = (
            getattr(config, "n_layers", None)
            or getattr(config, "n_layer", None)
            or getattr(config, "num_hidden_layers", None)
            or getattr(config, "num_layers", None)
        )
        config.n_layers = int(n)
        config.num_hidden_layers = int(n)
        config.num_layers = int(n)
        config.n_layer = int(n)
        self.model = ShivikM1V3Model(config)
    def forward(self, input_ids=None, past_key_values=None, use_cache=False, **kwargs):
        # pass through; ShivikM1V3Model returns a proper ModelOutput
        return self.model(input_ids, past_kvs=past_key_values, use_cache=use_cache)
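# Minimal local smoke test (a sketch, not the repo's documented usage): builds a scaled-down,
# randomly initialised model and runs a prefill plus one cached decoding step to confirm shapes
# and the KV-cache plumbing. The reduced hyperparameters are arbitrary and do not match the
# released checkpoint; whether model.generate() works out of the box additionally depends on how
# the installed transformers version handles legacy (tuple) past_key_values.
if __name__ == "__main__":
    cfg = ShivikM1V3Config(
        vocab_size=1000,
        d_model=256,
        n_layers=2,
        num_heads=4,
        rotary_dim=64,
        context_length=128,
    )
    model = ShivikM1V3ForCausalLM(cfg).eval()
    input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids, use_cache=True)                               # prefill
        next_tok = out.logits[:, -1:].argmax(-1)                             # greedy next token
        step = model(next_tok, past_key_values=out.past_key_values, use_cache=True)
    print(out.logits.shape, step.logits.shape)                               # (1, 8, 1000) (1, 1, 1000)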