import torch
from transformers import PreTrainedModel, GPT2Config
from transformers.modeling_outputs import CausalLMOutput

from .gpt_model import GPTModel


class GPTScratchForCausalLM(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``GPTModel``.

    The wrapped network is stored as ``self.inner``; its ``out_head`` is also
    exposed as ``self.lm_head``. ``forward`` returns a ``CausalLMOutput`` and
    ``generate`` is a small self-contained decoding loop (greedy or sampling).
    """

    config_class = GPT2Config

    def __init__(self, config, base_model=None):
        """Build the wrapper.

        Args:
            config: A ``GPT2Config``; its ``vocab_size``, ``n_embd``,
                ``n_head``, ``n_layer`` and ``n_positions`` are mapped onto
                the ``GPTModel`` settings dict when no model is supplied.
            base_model: Optional pre-built inner model. When ``None`` a new
                ``GPTModel`` is created from ``config``.
        """
        super().__init__(config)
        # Explicit None check: do not rely on the truthiness of a module object.
        if base_model is None:
            base_model = GPTModel({
                "vocab_size": config.vocab_size,
                "emb_dim": config.n_embd,
                "n_heads": config.n_head,
                "n_layers": config.n_layer,
                "context_length": config.n_positions,
                "drop_rate": 0.1,
            })
        self.inner = base_model
        self.lm_head = self.inner.out_head  # expose for HF tools

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the inner model and wrap its logits.

        Args:
            input_ids: Token ids passed straight to the inner model.
            labels: Optional target ids with the same shape as ``input_ids``.
                When given, a next-token cross-entropy loss is computed
                (logits at position ``t`` are scored against the label at
                ``t + 1``; label value ``-100`` is ignored).
            **kwargs: Accepted for compatibility with HF callers and ignored
                (this includes ``attention_mask``).

        Returns:
            ``CausalLMOutput`` with ``logits`` and, if ``labels`` was given,
            ``loss``.
        """
        logits = self.inner(input_ids)

        loss = None
        if labels is not None:
            shifted_logits = logits[:, :-1, :]
            shifted_labels = labels[:, 1:].to(shifted_logits.device)
            loss = torch.nn.functional.cross_entropy(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=32, eos_token_id=None,
                 do_sample=False, temperature=1.0, top_k=None, top_p=None,
                 repetition_penalty=1.1, **unused_kwargs):
        """Autoregressively extend ``input_ids``.

        Args:
            input_ids: Prompt token ids, one row per sequence.
            max_new_tokens: Upper bound on the number of appended tokens.
            eos_token_id: Optional single end-of-sequence id. A row that has
                produced it is padded with the same id on later steps, and
                decoding stops once every row has produced it.
            do_sample: Sample from the filtered distribution when true,
                otherwise take the argmax.
            temperature: Logit divisor applied before softmax when sampling.
            top_k: Keep only the ``top_k`` most probable tokens when
                sampling. ``None`` or a non-positive value disables the
                filter; values above the vocabulary size are clamped.
            top_p: Nucleus threshold when sampling: keep the smallest set of
                most-probable tokens whose cumulative probability reaches
                ``top_p``.
            repetition_penalty: Penalty applied to every token already in
                the sequence; ``1.0`` or a falsy value disables it.
            **unused_kwargs: Extra HF-style generation arguments; ignored.

        Returns:
            The prompt with the generated tokens concatenated along dim 1.
        """
        # NOTE(review): assumes config.n_positions matches the inner model's
        # context length (true when the model was built in __init__) - confirm
        # for externally supplied base models.
        max_context = getattr(self.config, "n_positions", None)
        use_penalty = bool(repetition_penalty) and repetition_penalty != 1.0
        finished = torch.zeros(
            input_ids.size(0), dtype=torch.bool, device=input_ids.device
        )

        for _step in range(max_new_tokens):
            # Feed at most the last `max_context` tokens to the model.
            model_input = input_ids[:, -max_context:] if max_context else input_ids
            logits = self.forward(model_input).logits[:, -1, :]

            if use_penalty:
                # Sign-aware penalty: dividing a negative logit would make the
                # token *more* likely, so negatives are multiplied instead.
                token_index = input_ids.to(torch.long)
                seen = torch.gather(logits, 1, token_index)
                seen = torch.where(
                    seen < 0, seen * repetition_penalty, seen / repetition_penalty
                )
                logits = logits.scatter(1, token_index, seen)

            if do_sample:
                if temperature and temperature != 1.0:
                    logits = logits / temperature
                probs = logits.softmax(dim=-1)

                if top_k is not None and top_k > 0:
                    k = min(int(top_k), probs.size(-1))
                    top_values, _indices = torch.topk(probs, k=k, dim=-1)
                    threshold = top_values[:, -1].unsqueeze(-1)
                    probs = torch.where(
                        probs >= threshold, probs, torch.zeros_like(probs)
                    )
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                if top_p is not None:
                    sorted_probs, sorted_index = probs.sort(descending=True, dim=-1)
                    cumulative = sorted_probs.cumsum(dim=-1)
                    # Shift the mask right by one so the token that crosses
                    # the threshold is kept; the top token is always kept.
                    remove = torch.roll(cumulative > top_p, shifts=1, dims=-1)
                    remove[:, 0] = False
                    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
                    probs = torch.zeros_like(probs).scatter(
                        -1, sorted_index, sorted_probs
                    )
                    probs = probs / probs.sum(dim=-1, keepdim=True)

                next_token = torch.multinomial(probs, 1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            if eos_token_id is not None:
                # Rows that already ended keep emitting EOS as padding.
                next_token = next_token.masked_fill(
                    finished.unsqueeze(-1), eos_token_id
                )
                finished = finished | (next_token.squeeze(-1) == eos_token_id)

            input_ids = torch.cat([input_ids, next_token], dim=1)

            if eos_token_id is not None and bool(finished.all()):
                break

        return input_ids

    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, *args, **kwargs):
        """Absorb old checkpoints whose keys start with ``'inner.inner.'``.

        Every such key is renamed in place (``state_dict`` is mutated) so that
        the doubled prefix becomes a single ``'inner.'``; the call is then
        delegated to the parent implementation.

        NOTE(review): this overrides a private transformers hook; whether the
        parent class defines it as a classmethod depends on the installed
        transformers version - confirm against the pinned release.
        """
        legacy_prefix = "inner.inner."
        legacy_keys = [key for key in state_dict if key.startswith(legacy_prefix)]
        # Strip exactly one leading "inner." from each legacy key.
        renamed = {key[len("inner."):]: state_dict.pop(key) for key in legacy_keys}
        state_dict.update(renamed)
        return super()._load_state_dict_into_model(model, state_dict, *args, **kwargs)