Sneha7 committed
Commit fb391c1 · verified · 1 Parent(s): 7ffc118

Update policy.py

Files changed (1)
  1. policy.py +58 -27
policy.py CHANGED
@@ -1,34 +1,65 @@
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import copy
+import os
 
 MODEL_NAME = "microsoft/phi-2"
+CHECKPOINT_DIR = "checkpoints"
 
 
-def load_policy_model():
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
+def load_policy_model(lr: float = 1e-6):
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    if tokenizer.pad_token_id is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_NAME,
-        torch_dtype=torch.float16,
-        device_map=None,
-    ).to(device)
-
-    # Freeze everything
-    for p in model.parameters():
-        p.requires_grad = False
-
-    # Enable training only for lm_head
-    trainable = []
-    for name, p in model.named_parameters():
-        if "lm_head" in name:
-            p.requires_grad = True
-            trainable.append(p)
-            print("TRAINABLE: ", name)
-
-    model.optimizer = torch.optim.Adam(trainable, lr=1e-5)
-    print(">>> POLICY MODEL LOADED")
-    return model, tokenizer
+
+    # Trainable policy model
+    policy_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+    policy_model.to("cuda")
+    policy_model.train()
+
+    # Only train lm_head
+    for name, param in policy_model.named_parameters():
+        param.requires_grad = ("lm_head" in name)
+
+    optimizer = torch.optim.AdamW(
+        filter(lambda p: p.requires_grad, policy_model.parameters()),
+        lr=lr,
+    )
+    policy_model.optimizer = optimizer
+
+    # Frozen generation model
+    gen_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+    gen_model.to("cuda")
+    gen_model.eval()
+    for p in gen_model.parameters():
+        p.requires_grad_(False)
+
+    # Frozen reference model (can just deepcopy gen_model)
+    ref_model = copy.deepcopy(gen_model)
+    ref_model.eval()
+    for p in ref_model.parameters():
+        p.requires_grad_(False)
+
+    return policy_model, gen_model, ref_model, tokenizer
+
+
+def save_checkpoint(policy_model, step: int, ckpt_dir: str = CHECKPOINT_DIR):
+    os.makedirs(ckpt_dir, exist_ok=True)
+    path = os.path.join(ckpt_dir, f"step_{step}.pt")
+    torch.save(
+        {
+            "step": step,
+            "model_state_dict": policy_model.state_dict(),
+            "optimizer_state_dict": policy_model.optimizer.state_dict()
+            if hasattr(policy_model, "optimizer")
+            else None,
+        },
+        path,
+    )
+    print(f"[CKPT] Saved checkpoint at {path}")
+
+
+def load_checkpoint(policy_model, optimizer, ckpt_path: str):
+    ckpt = torch.load(ckpt_path, map_location="cuda")
+    policy_model.load_state_dict(ckpt["model_state_dict"])
+    if optimizer is not None and ckpt.get("optimizer_state_dict") is not None:
+        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
+    print(f"[CKPT] Loaded checkpoint from {ckpt_path} at step={ckpt.get('step')}")