import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class ShivikM1V3Config(PretrainedConfig):
    model_type = "shivik_m1"

    def __init__(
        self,
        vocab_size=49156,
        d_model=2048,
        n_layers=24,
        num_heads=16,
        rotary_dim=128,
        context_length=4096,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.num_heads = num_heads
        self.rotary_dim = rotary_dim
        self.context_length = context_length

        # Mirror the layer count under the attribute names that various
        # transformers utilities expect, so the config stays compatible.
        self.num_hidden_layers = kwargs.get("num_hidden_layers", n_layers)
        self.num_layers = kwargs.get("num_layers", n_layers)
        self.n_layer = kwargs.get("n_layer", n_layers)
        self.layer_types = kwargs.get("layer_types", ["full_attention"] * n_layers)
        self.num_kv_shared_layers = kwargs.get("num_kv_shared_layers", 0)
        self.use_cache = kwargs.get("use_cache", True)


class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        norm = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(norm + self.eps) * self.weight


def apply_rope(x, cos, sin):
    # Rotate interleaved (even, odd) channel pairs by the per-position angles.
    # cos/sin are expected to broadcast against x[..., 0::2].
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    xr = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return xr.reshape(x.shape)


class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.head_dim = cfg.d_model // cfg.num_heads
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.out = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

    def split_heads(self, x):
        B, T, C = x.shape
        return x.view(B, T, self.cfg.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, cos, sin, mask, past=None):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)

        # Apply rotary embeddings to the first rotary_dim channels of each head.
        rd = self.cfg.rotary_dim
        if rd > 0:
            cos_b = cos.unsqueeze(0).unsqueeze(0)
            sin_b = sin.unsqueeze(0).unsqueeze(0)
            q_rot = apply_rope(q[..., :rd], cos_b, sin_b)
            k_rot = apply_rope(k[..., :rd], cos_b, sin_b)
            q = torch.cat([q_rot, q[..., rd:]], dim=-1)
            k = torch.cat([k_rot, k[..., rd:]], dim=-1)

        # Prepend cached keys/values (if any) along the sequence dimension.
        if past is not None:
            pk, pv = past
            if pk is not None:
                k = torch.cat([pk, k], dim=2)
            if pv is not None:
                v = torch.cat([pv, v], dim=2)
        present = (k, v)

        # Scaled dot-product attention with a causal mask.
        dk = q.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dk)
        scores = scores.masked_fill(~mask, float("-inf"))
        att = torch.softmax(scores, dim=-1)
        out = torch.matmul(att, v).transpose(1, 2).reshape(B, T, C)
        return self.out(out), present


class SwiGLU(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.w1 = nn.Linear(d, 4 * d, bias=False)
        self.w2 = nn.Linear(d, 4 * d, bias=False)
        self.w3 = nn.Linear(4 * d, d, bias=False)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class Block(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.norm1 = RMSNorm(cfg.d_model)
        self.att = Attention(cfg)
        self.norm2 = RMSNorm(cfg.d_model)
        self.mlp = SwiGLU(cfg.d_model)

    def forward(self, x, cos, sin, mask, past=None):
        h, present = self.att(self.norm1(x), cos, sin, mask, past)
        x = x + h
        x = x + self.mlp(self.norm2(x))
        return x, present


class ShivikM1V3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)

        # Learned absolute position embeddings, used alongside RoPE.
        self.pos = nn.Parameter(torch.zeros(1, cfg.context_length, cfg.d_model))

        # Causal mask and rotary cos/sin tables precomputed for the full context.
        mask = torch.tril(torch.ones(cfg.context_length, cfg.context_length)).bool()
        self.register_buffer("mask", mask.unsqueeze(0).unsqueeze(0))
        t = torch.arange(cfg.context_length)
        freqs = 1.0 / (10000 ** (torch.arange(0, cfg.rotary_dim, 2) / cfg.rotary_dim))
        angles = torch.einsum("i,j->ij", t.float(), freqs.float())
        self.register_buffer("cos", angles.cos())
        self.register_buffer("sin", angles.sin())

        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.norm = RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        # Tie the output projection to the input embedding matrix.
        self.lm_head.weight = self.embed.weight

    def forward(self, input_ids, past_kvs=None, use_cache=False, **kwargs):
        """
        Returns CausalLMOutputWithCrossAttentions to stay compatible with .generate().
        past_kvs should be an iterable of per-layer (k, v) tuples, or None.
        """
        B, T = input_ids.shape

        # Offset by the cached length so incremental decoding with a KV cache
        # uses the correct position embeddings, rotary angles, and mask rows.
        past_len = 0
        if past_kvs is not None and past_kvs[0] is not None and past_kvs[0][0] is not None:
            past_len = past_kvs[0][0].shape[2]

        x = self.embed(input_ids) + self.pos[:, past_len:past_len + T]
        mask = self.mask[:, :, past_len:past_len + T, :past_len + T]
        cos = self.cos[past_len:past_len + T]
        sin = self.sin[past_len:past_len + T]

        if past_kvs is None:
            past_kvs = [None] * len(self.blocks)
        presents = []
        for block, p in zip(self.blocks, past_kvs):
            x, kv = block(x, cos, sin, mask, p)
            presents.append(kv)

        x = self.norm(x)
        logits = self.lm_head(x)

        past_key_values = None
        if use_cache:
            past_key_values = tuple(presents)

        return CausalLMOutputWithCrossAttentions(
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
            cross_attentions=None,
        )


class ShivikM1V3ForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = ShivikM1V3Config
    base_model_prefix = "shivik_m1_v3"

    def __init__(self, config):
        super().__init__(config)

        # Resolve the layer count from whichever alias the config provides,
        # then write it back under every alias so downstream code agrees.
        n = (
            getattr(config, "n_layers", None)
            or getattr(config, "n_layer", None)
            or getattr(config, "num_hidden_layers", None)
            or getattr(config, "num_layers", None)
            or config.n_layers
        )
        config.n_layers = int(n)
        config.num_hidden_layers = int(n)
        config.num_layers = int(n)
        config.n_layer = int(n)

        self.model = ShivikM1V3Model(config)

    def forward(self, input_ids=None, past_key_values=None, use_cache=False, **kwargs):
        return self.model(input_ids, past_key_values, use_cache=use_cache)
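

# Minimal smoke-test sketch, not part of the model definition: it builds a tiny
# randomly initialized model, runs one forward pass, and calls .generate() with
# greedy decoding. The small hyperparameters below are illustrative assumptions,
# and generation assumes a transformers version whose GenerationMixin supplies a
# default prepare_inputs_for_generation; use_cache=False keeps to the simplest path.
if __name__ == "__main__":
    cfg = ShivikM1V3Config(
        d_model=128, n_layers=2, num_heads=4, rotary_dim=32, context_length=128
    )
    model = ShivikM1V3ForCausalLM(cfg).eval()

    input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids)
    print("logits shape:", out.logits.shape)  # (1, 8, vocab_size)

    generated = model.generate(
        input_ids, max_new_tokens=8, do_sample=False, use_cache=False
    )
    print("generated shape:", generated.shape)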