# modeling_gptscratch.py
# Custom modeling file that wraps the workshop's from-scratch GPT implementation
# so it can be loaded through the Hugging Face Transformers API.
import torch
from transformers import PreTrainedModel, GPT2Config
from transformers.modeling_outputs import CausalLMOutput
from .gpt_model import GPTModel
class GPTScratchForCausalLM(PreTrainedModel):
    """Hugging Face wrapper around the from-scratch GPTModel defined in gpt_model.py."""

    config_class = GPT2Config

    def __init__(self, config, base_model=None):
        super().__init__(config)
        # Translate the GPT2Config field names into the plain dict that the
        # scratch GPTModel expects.
        self.inner = base_model or GPTModel({
            "vocab_size": config.vocab_size,
            "emb_dim": config.n_embd,
            "n_heads": config.n_head,
            "n_layers": config.n_layer,
            "context_length": config.n_positions,
            "drop_rate": 0.1,
        })
        self.lm_head = self.inner.out_head  # expose the output projection for HF tools
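    # For reference, a GPT-2-small-sized configuration would look roughly like the
    # sketch below (these are the well-known GPT-2 small hyperparameters; the exact
    # config shipped with this repo may differ):
    #
    #   config = GPT2Config(vocab_size=50257, n_positions=1024,
    #                       n_embd=768, n_layer=12, n_head=12)
    #   model = GPTScratchForCausalLM(config)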
    def forward(self, input_ids, **kwargs):
        # Extra arguments (attention_mask, labels, ...) are accepted but ignored;
        # only logits are returned, so any loss must be computed by the caller.
        logits = self.inner(input_ids)
        return CausalLMOutput(logits=logits)
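    # Because forward() returns logits only, a training loss has to be computed
    # outside the model. A minimal sketch of the standard next-token cross-entropy
    # with a one-position shift (`batch` is a hypothetical LongTensor of token ids):
    #
    #   import torch.nn.functional as F
    #   out = model(batch)                              # CausalLMOutput
    #   loss = F.cross_entropy(
    #       out.logits[:, :-1, :].reshape(-1, out.logits.size(-1)),
    #       batch[:, 1:].reshape(-1),
    #   )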
    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=32, eos_token_id=None,
                 do_sample=False, temperature=1.0, top_k=None, top_p=None,
                 repetition_penalty=1.1, **_):
        for _ in range(max_new_tokens):
            # Crop the context so the positional embeddings are never exceeded.
            context = input_ids[:, -self.config.n_positions:]
            logits = self.forward(context).logits[:, -1, :]
            if repetition_penalty and repetition_penalty != 1.0:
                # CTRL-style penalty: push already-generated tokens away from being
                # picked again, regardless of the sign of their logit.
                for b in range(input_ids.size(0)):
                    prev = input_ids[b]
                    scores = logits[b, prev]
                    logits[b, prev] = torch.where(
                        scores > 0, scores / repetition_penalty, scores * repetition_penalty
                    )
            if do_sample:
                if temperature and temperature != 1.0:
                    logits = logits / temperature
                probs = logits.softmax(dim=-1)
                if top_k is not None:
                    # Keep only the k most likely tokens, then renormalize.
                    v, _ = torch.topk(probs, k=top_k, dim=-1)
                    thresh = v[:, -1].unsqueeze(-1)
                    probs = torch.where(probs >= thresh, probs, torch.zeros_like(probs))
                    probs = probs / probs.sum(dim=-1, keepdim=True)
                if top_p is not None:
                    # Nucleus sampling: keep the smallest set of tokens whose cumulative
                    # probability reaches top_p (the top token is always kept).
                    sp, si = probs.sort(descending=True, dim=-1)
                    cum = sp.cumsum(dim=-1)
                    mask = cum - sp > top_p  # drop tokens that start beyond the nucleus
                    mask[:, 0] = False
                    sp[mask] = 0
                    probs = torch.zeros_like(probs).scatter(-1, si, sp)
                    probs = probs / probs.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(probs, 1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
        return input_ids
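    # Direct usage sketch (the token ids are placeholders; decoding is greedy by
    # default and stochastic when do_sample=True):
    #
    #   ids = torch.tensor([[464, 3290]])
    #   out = model.generate(ids, max_new_tokens=16,
    #                        do_sample=True, temperature=0.8, top_k=40, top_p=0.9)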
    # Backwards compatibility: absorb old checkpoints whose parameter keys were
    # saved with a doubled prefix ('inner.inner.*') by remapping them to 'inner.*'.
    # This relies on a private Transformers hook, so it is tied to the library
    # version used in the workshop.
    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, *args, **kwargs):
        remap = {}
        for k, v in list(state_dict.items()):
            if k.startswith("inner.inner."):
                remap[k.replace("inner.inner.", "inner.", 1)] = v
                del state_dict[k]
        state_dict.update(remap)
        return super()._load_state_dict_into_model(model, state_dict, *args, **kwargs)
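
# -----------------------------------------------------------------------------
# Loading sketch. This assumes the Hub repo ships this file together with
# gpt_model.py, a tokenizer, and a config.json whose auto_map points at
# GPTScratchForCausalLM; the repo id below is a placeholder for the actual
# repository name.
#
#   from transformers import AutoTokenizer, AutoModelForCausalLM
#
#   tok = AutoTokenizer.from_pretrained("itarutomy/llm_workshop_hands_on_gpt-model")
#   model = AutoModelForCausalLM.from_pretrained(
#       "itarutomy/llm_workshop_hands_on_gpt-model", trust_remote_code=True
#   )
#   ids = tok("Hello", return_tensors="pt").input_ids
#   print(tok.decode(model.generate(ids, max_new_tokens=32)[0]))
# -----------------------------------------------------------------------------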