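"""Hugging Face wrapper around a from-scratch GPT implementation.

Adapts GPTModel to the transformers `PreTrainedModel` interface: config-driven
construction from a GPT2Config, a `forward` returning CausalLMOutput, and a
minimal `generate` loop supporting greedy and sampled decoding.
"""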
import torch

from transformers import PreTrainedModel, GPT2Config
from transformers.modeling_outputs import CausalLMOutput

from .gpt_model import GPTModel


class GPTScratchForCausalLM(PreTrainedModel):
    """Hugging Face-compatible wrapper around the from-scratch GPTModel."""

    config_class = GPT2Config

    def __init__(self, config, base_model=None):
        super().__init__(config)
        # Translate the GPT2Config fields into the plain dict config that
        # the from-scratch GPTModel expects.
        self.inner = base_model or GPTModel({
            "vocab_size": config.vocab_size,
            "emb_dim": config.n_embd,
            "n_heads": config.n_head,
            "n_layers": config.n_layer,
            "context_length": config.n_positions,
            "drop_rate": 0.1,
        })
        # Expose the inner model's output projection under the name that
        # HF tooling expects for a language-model head (weights are shared).
        self.lm_head = self.inner.out_head

    def forward(self, input_ids, **kwargs):
        # Extra kwargs (attention_mask, labels, ...) are accepted for API
        # compatibility but ignored; the inner model only takes token ids.
        logits = self.inner(input_ids)
        return CausalLMOutput(logits=logits)

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=32, eos_token_id=None,
                 do_sample=False, temperature=1.0, top_k=None, top_p=None,
                 repetition_penalty=1.1, **_):
        context_length = self.config.n_positions
        for _ in range(max_new_tokens):
            # Feed at most `context_length` tokens so the inner model's
            # positional embeddings are never exceeded.
            logits = self.forward(input_ids[:, -context_length:]).logits[:, -1, :]

            # Repetition penalty: shrink the scores of already-seen tokens.
            # Dividing is only correct for positive logits; negative ones
            # must be multiplied, otherwise the "penalty" would boost them.
            if repetition_penalty and repetition_penalty != 1.0:
                for b in range(input_ids.size(0)):
                    seen = logits[b, input_ids[b]]
                    logits[b, input_ids[b]] = torch.where(
                        seen < 0, seen * repetition_penalty, seen / repetition_penalty
                    )

            if do_sample:
                if temperature and temperature != 1.0:
                    logits = logits / temperature
                probs = logits.softmax(dim=-1)
                if top_k is not None:
                    # Keep only the k most probable tokens, then renormalize.
                    v, _ = torch.topk(probs, k=min(top_k, probs.size(-1)), dim=-1)
                    thresh = v[:, -1].unsqueeze(-1)
                    probs = torch.where(probs >= thresh, probs, torch.zeros_like(probs))
                    probs = probs / probs.sum(dim=-1, keepdim=True)
                if top_p is not None:
                    # Nucleus sampling: keep the smallest set of tokens whose
                    # cumulative probability reaches top_p.
                    sp, si = probs.sort(descending=True, dim=-1)
                    cum = sp.cumsum(dim=-1)
                    mask = cum > top_p
                    # Shift the mask right so the token that crosses the
                    # threshold is kept, and always keep the top-1 token.
                    mask[:, 1:] = mask[:, :-1].clone()
                    mask[:, 0] = False
                    sp[mask] = 0
                    probs = torch.zeros_like(probs).scatter(-1, si, sp)
                    probs = probs / probs.sum(dim=-1, keepdim=True)
                next_token = torch.multinomial(probs, 1)
            else:
                # Greedy decoding: pick the single most likely next token.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            input_ids = torch.cat([input_ids, next_token], dim=1)
            # Stop early once every sequence in the batch has emitted EOS.
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
        return input_ids

    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, *args, **kwargs):
        # Remap checkpoint keys saved under a doubly nested "inner.inner."
        # prefix to the current single "inner." prefix, then defer to the
        # stock transformers loading logic.
        remap = {}
        for k, v in list(state_dict.items()):
            if k.startswith("inner.inner."):
                remap[k.replace("inner.inner.", "inner.", 1)] = v
                del state_dict[k]
        state_dict.update(remap)
        return super()._load_state_dict_into_model(model, state_dict, *args, **kwargs)
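

# Usage sketch: random-weight smoke test with dummy token ids. Swap in a real
# tokenizer (e.g. GPT-2 BPE) to generate actual text; the GPT2Config sizes
# below are illustrative, not the values of any released checkpoint.
if __name__ == "__main__":
    cfg = GPT2Config(vocab_size=50257, n_embd=256, n_head=4, n_layer=4,
                     n_positions=256)
    model = GPTScratchForCausalLM(cfg).eval()
    prompt = torch.randint(0, cfg.vocab_size, (1, 8))
    out = model.generate(prompt, max_new_tokens=16, do_sample=True,
                         temperature=0.8, top_k=50, top_p=0.9)
    print(out.shape)  # torch.Size([1, 24]) -> 8 prompt + 16 generated tokens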