itarutomy committed
Commit df157f8 · verified · 1 Parent(s): 1cf3bf6

Add custom modeling file

Files changed (1):
  1. modeling_gptscratch.py +8 -13
modeling_gptscratch.py CHANGED
@@ -1,15 +1,13 @@
-# modeling_gptscratch.py
 import torch
 from transformers import PreTrainedModel, GPT2Config
 from transformers.modeling_outputs import CausalLMOutput
-from .gpt_model import GPTModel  # reads gpt_model.py placed at the repo root
+from .gpt_model import GPTModel
 
 class GPTScratchForCausalLM(PreTrainedModel):
     config_class = GPT2Config
 
     def __init__(self, config, base_model=None):
         super().__init__(config)
-        # build the inner model to match the training-time hyperparameters
         self.inner = base_model or GPTModel({
             "vocab_size": config.vocab_size,
             "emb_dim": config.n_embd,
@@ -18,8 +16,7 @@ class GPTScratchForCausalLM(PreTrainedModel):
             "context_length": config.n_positions,
             "drop_rate": 0.1,
         })
-        # HF compat: expose lm_head (weights are shared with the inner model as-is)
-        self.lm_head = self.inner.out_head
+        self.lm_head = self.inner.out_head  # expose for HF tools
 
     def forward(self, input_ids, **kwargs):
         logits = self.inner(input_ids)
@@ -29,11 +26,9 @@ class GPTScratchForCausalLM(PreTrainedModel):
     def generate(self, input_ids, max_new_tokens=32, eos_token_id=None,
                  do_sample=False, temperature=1.0, top_k=None, top_p=None,
                  repetition_penalty=1.1, **_):
-        # minimal implementation (greedy or simple sampling)
         for _ in range(max_new_tokens):
             logits = self.forward(input_ids).logits[:, -1, :]
 
-            # suppress repetition
             if repetition_penalty and repetition_penalty != 1.0:
                 for b in range(input_ids.size(0)):
                     logits[b, input_ids[b]] /= repetition_penalty
@@ -48,14 +43,14 @@ class GPTScratchForCausalLM(PreTrainedModel):
                     probs = torch.where(probs >= thresh, probs, torch.zeros_like(probs))
                     probs = probs / probs.sum(dim=-1, keepdim=True)
                 if top_p is not None:
-                    sorted_probs, sorted_idx = probs.sort(descending=True, dim=-1)
-                    cum = sorted_probs.cumsum(dim=-1)
+                    sp, si = probs.sort(descending=True, dim=-1)
+                    cum = sp.cumsum(dim=-1)
                     mask = cum > top_p
                     mask[:, 0] = False
-                    sorted_probs[mask] = 0
-                    probs = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
+                    sp[mask] = 0
+                    probs = torch.zeros_like(probs).scatter(-1, si, sp)
                     probs = probs / probs.sum(dim=-1, keepdim=True)
-                next_token = torch.multinomial(probs, num_samples=1)
+                next_token = torch.multinomial(probs, 1)
             else:
                 next_token = torch.argmax(logits, dim=-1, keepdim=True)
 
@@ -64,7 +59,7 @@ class GPTScratchForCausalLM(PreTrainedModel):
                 break
         return input_ids
 
-    # absorb 'inner.inner.' keys from old checkpoints into 'inner.'
+    # absorb old checkpoints whose keys start with 'inner.inner.'
     @classmethod
     def _load_state_dict_into_model(cls, model, state_dict, *args, **kwargs):
         remap = {}
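
Because GPTScratchForCausalLM ships as a custom modeling file rather than inside transformers itself, loading it from the Hub needs trust_remote_code=True. A minimal usage sketch, assuming the repo's config.json maps AutoModelForCausalLM to this class via auto_map; the repo id and the GPT-2 tokenizer are assumptions, not taken from this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # assumed tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "itarutomy/gptscratch",  # hypothetical repo id
    trust_remote_code=True,  # required to run modeling_gptscratch.py
)

ids = tok("Hello", return_tensors="pt").input_ids
out = model.generate(ids, max_new_tokens=16, do_sample=True, top_p=0.9)
print(tok.decode(out[0]))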
 
 
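The diff cuts off inside _load_state_dict_into_model, so only the remap's intent is visible from the new comment. A standalone sketch of that intent, assuming the method simply rewrites the doubled prefix before handing off to the normal loading path; remap_legacy_keys is a hypothetical helper, not the commit's code:

def remap_legacy_keys(state_dict):
    # Old checkpoints saved keys as 'inner.inner.*'; current code expects 'inner.*'.
    remap = {}
    for key, value in state_dict.items():
        if key.startswith("inner.inner."):
            remap["inner." + key[len("inner.inner."):]] = value
        else:
            remap[key] = value
    return remap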