Sneha7 committed on
Commit 19afcd9 · verified · 1 Parent(s): 5981a94

Update policy.py

Files changed (1)
  1. policy.py +25 -26
policy.py CHANGED
@@ -13,48 +13,47 @@ def load_policy_model():
         torch_dtype=torch.float16
     )

-    # -----------------------------------------
-    # 1. UNTIE LM HEAD FROM EMBEDDINGS
-    # -----------------------------------------
-    # Phi-2 ties lm_head.weight = embed_tokens.weight
-    # We replace lm_head with a *separate* nn.Linear so gradients do NOT flow to embeddings.
-    old_lm_head = model.lm_head
-    vocab_size, hidden_size = old_lm_head.weight.shape
+    # -----------------------------------------------------------
+    # 1. Identify the REAL lm_head and embedding weights
+    # -----------------------------------------------------------
+    embed = model.model.embed_tokens
+    old_lm_head = model.lm_head  # This is actually tied to embed

     print(">>> UNTIEING LM HEAD...")
+
+    # -----------------------------------------------------------
+    # 2. Create a new untied lm_head
+    # -----------------------------------------------------------
+    vocab_size, hidden_size = old_lm_head.weight.shape
     new_lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=True)
+
     new_lm_head.weight.data = old_lm_head.weight.data.clone()
     if old_lm_head.bias is not None:
         new_lm_head.bias.data = old_lm_head.bias.data.clone()

+    # Replace tied head with untied one
     model.lm_head = new_lm_head.to(model.device)

-    # -----------------------------------------
-    # 2. FREEZE EVERYTHING
-    # -----------------------------------------
+    # -----------------------------------------------------------
+    # 3. Freeze EVERYTHING
+    # -----------------------------------------------------------
     for name, param in model.named_parameters():
         param.requires_grad = False

-    # -----------------------------------------
-    # 3. UNFREEZE ONLY THE UNTIED LM HEAD
-    # -----------------------------------------
+    # -----------------------------------------------------------
+    # 4. Unfreeze ONLY the new lm_head
+    # -----------------------------------------------------------
     for name, param in model.named_parameters():
-        if "lm_head" in name:
+        if name.startswith("lm_head"):
             param.requires_grad = True
             print("TRAINABLE:", name)

-    # -----------------------------------------
-    # 4. VERIFY FINAL PARAM COUNT
-    # -----------------------------------------
-    trainable_params = [p for p in model.parameters() if p.requires_grad]
-    total = sum(p.numel() for p in trainable_params)
+    # -----------------------------------------------------------
+    # 5. Count trainable params
+    # -----------------------------------------------------------
+    trainable = [p for p in model.parameters() if p.requires_grad]
+    total = sum(p.numel() for p in trainable)
     print(">>> FINAL TRAINABLE PARAM COUNT:", total)

-    # -----------------------------------------
-    # 5. OPTIMIZER
-    # -----------------------------------------
-    optimizer = torch.optim.Adam(trainable_params, lr=1e-4)
-    model.optimizer = optimizer
-
-    print(">>> POLICY MODEL READY.")
+    model.optimizer = torch.optim.Adam(trainable, lr=1e-4)
     return model, tokenizer
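
As a quick sanity check on what this change is meant to guarantee, the sketch below shows one way the returned model could be exercised. It is not part of the commit: the "from policy import load_policy_model" import, the Phi-2-style layout (model.model.embed_tokens), the dummy prompt, and the labels=input_ids loss are all assumptions made for illustration.

# Hypothetical smoke test (not in this commit); assumes policy.py is importable
# and load_policy_model() returns (model, tokenizer) as in the diff above.
import torch
from policy import load_policy_model

model, tokenizer = load_policy_model()

# The untied head should no longer share storage with the input embeddings.
assert model.lm_head.weight.data_ptr() != model.model.embed_tokens.weight.data_ptr()

# After the freeze/unfreeze passes, only lm_head parameters should be trainable.
trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
assert trainable_names and all(n.startswith("lm_head") for n in trainable_names)

# One illustrative update with the optimizer attached by load_policy_model();
# the prompt and the labels=input_ids loss are placeholder choices.
batch = tokenizer("hello world", return_tensors="pt").to(model.device)
out = model(**batch, labels=batch["input_ids"])
out.loss.backward()
model.optimizer.step()
model.optimizer.zero_grad()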