Sneha7 committed on
Commit 30a2ce8 · verified · 1 Parent(s): cbb254e

Update policy.py

Files changed (1)
  1. policy.py +38 -44
policy.py CHANGED
@@ -1,59 +1,53 @@
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer

- def load_policy_model():
-     model_name = "microsoft/phi-2"
-
-     print(">>> LOADING PHI-2...")
-     tokenizer = AutoTokenizer.from_pretrained(model_name)

-     model = AutoModelForCausalLM.from_pretrained(
-         model_name,
-         device_map="auto",
-         torch_dtype=torch.float16
-     )
-
-     # -----------------------------------------------------------
-     # 1. Identify the REAL lm_head and embedding weights
-     # -----------------------------------------------------------
-     embed = model.model.embed_tokens
-     old_lm_head = model.lm_head  # This is actually tied to embed
-
-     print(">>> UNTIEING LM HEAD...")

-     # -----------------------------------------------------------
-     # 2. Create a new untied lm_head
-     # -----------------------------------------------------------
-     vocab_size, hidden_size = old_lm_head.weight.shape
-     new_lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=True)

-     new_lm_head.weight.data = old_lm_head.weight.data.clone()
-     if old_lm_head.bias is not None:
-         new_lm_head.bias.data = old_lm_head.bias.data.clone()

-     # Replace tied head with untied one
-     model.lm_head = new_lm_head.to(model.device)

-     # -----------------------------------------------------------
-     # 3. Freeze EVERYTHING
-     # -----------------------------------------------------------
      for name, param in model.named_parameters():
-         param.requires_grad = False

-     # -----------------------------------------------------------
-     # 4. Unfreeze ONLY the new lm_head
-     # -----------------------------------------------------------
      for name, param in model.named_parameters():
-         if name.startswith("lm_head"):
-             param.requires_grad = True
              print("TRAINABLE:", name)

-     # -----------------------------------------------------------
-     # 5. Count trainable params
-     # -----------------------------------------------------------
-     trainable = [p for p in model.parameters() if p.requires_grad]
-     total = sum(p.numel() for p in trainable)
-     print(">>> FINAL TRAINABLE PARAM COUNT:", total)

-     model.optimizer = torch.optim.Adam(trainable, lr=1e-4)
      return model, tokenizer
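Note on the change in the updated file below: in many causal LMs the output projection shares its weight tensor with the input embeddings, so training only lm_head on a tied model would silently update the embeddings as well. A minimal standalone sketch of tying versus untying (toy sizes, illustrative only, not part of this commit):

import torch

embed = torch.nn.Embedding(10, 4)
head = torch.nn.Linear(4, 10, bias=False)
head.weight = embed.weight  # weight tying: both modules share one tensor
assert head.weight.data_ptr() == embed.weight.data_ptr()

# Untying: a new parameter with the same values but separate storage
untied = torch.nn.Linear(4, 10, bias=False)
with torch.no_grad():
    untied.weight.copy_(head.weight)
assert untied.weight.data_ptr() != embed.weight.data_ptr()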
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer

+ MODEL_NAME = "microsoft/phi-2"


+ def load_policy_model():
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token = tokenizer.eos_token

+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_NAME,
+         torch_dtype=torch.float16,
+         device_map=None,
+     ).to(device)
+
+     # Untie lm_head and freeze everything except lm_head
+     old_head = model.lm_head
+     model.lm_head = torch.nn.Linear(
+         old_head.in_features,
+         old_head.out_features,
+         bias=True,
+         device=device,
+         dtype=torch.float16,
+     )
+     # Start the untied head from the pretrained weights rather than a
+     # random re-initialization, matching the behavior of the old code
+     with torch.no_grad():
+         model.lm_head.weight.copy_(old_head.weight)
+         if old_head.bias is not None:
+             model.lm_head.bias.copy_(old_head.bias)
+         else:
+             model.lm_head.bias.zero_()
      for name, param in model.named_parameters():
+         param.requires_grad = name.startswith("lm_head")

+     print(">>> UNTYING LM HEAD...")
      for name, param in model.named_parameters():
+         if param.requires_grad:
              print("TRAINABLE:", name)

+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print(">>> FINAL TRAINABLE PARAM COUNT:", trainable_params)
+
+     # Optimizer: only lm_head, small LR
+     optimizer = torch.optim.AdamW(
+         (p for p in model.parameters() if p.requires_grad),
+         lr=1e-5,
+     )
+     model.optimizer = optimizer
+
+     # Sanity check: no NaN / Inf in fresh weights
+     with torch.no_grad():
+         for p in model.parameters():
+             if torch.isnan(p).any() or torch.isinf(p).any():
+                 raise RuntimeError("Loaded model checkpoint has NaN/Inf parameters.")

      return model, tokenizer
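A minimal usage sketch for the updated loader (the prompt and single-step loop are illustrative, not part of the commit; real fp16 training would typically add autocast and a GradScaler):

model, tokenizer = load_policy_model()
model.train()

# One toy optimization step on a single prompt
batch = tokenizer("The policy model says:", return_tensors="pt").to(model.device)
out = model(**batch, labels=batch["input_ids"])

model.optimizer.zero_grad()
out.loss.backward()
model.optimizer.step()

# Only the untied lm_head should have received gradients
for name, param in model.named_parameters():
    assert (param.grad is not None) == name.startswith("lm_head")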