import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "microsoft/phi-2"
CHECKPOINT_DIR = "checkpoints"

def load_policy_model(lr: float = 1e-6):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Trainable policy model on GPU
    policy_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    policy_model.to("cuda")
    policy_model.train()

    # Train only lm_head; freeze all other parameters
    for name, param in policy_model.named_parameters():
        param.requires_grad = ("lm_head" in name)

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, policy_model.parameters()),
        lr=lr,
    )
    # Attach the optimizer to the model so save_checkpoint can find it
    policy_model.optimizer = optimizer

    # Frozen generation model on CPU (note: no .to("cuda"))
    gen_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    gen_model.eval()
    for p in gen_model.parameters():
        p.requires_grad_(False)

    return policy_model, gen_model, tokenizer
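

# Illustrative sketch, not part of the original file: one hypothetical
# gradient step on the policy model returned above, scaling the LM loss by a
# scalar reward as a stand-in for a proper policy-gradient objective. The
# function name, prompt, and reward are placeholder assumptions; the real
# training loop presumably lives elsewhere in this repo.
def _example_update_step(policy_model, tokenizer, prompt: str, reward: float):
    device = next(policy_model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # With labels=input_ids, the model returns the causal-LM loss directly
    outputs = policy_model(**inputs, labels=inputs["input_ids"])
    loss = outputs.loss * reward
    # Only lm_head parameters receive gradients (see load_policy_model)
    policy_model.optimizer.zero_grad()
    loss.backward()
    policy_model.optimizer.step()
    return loss.item()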


def save_checkpoint(policy_model, step: int, ckpt_dir: str = CHECKPOINT_DIR):
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"step_{step}.pt")
    torch.save(
        {
            "step": step,
            "model_state_dict": policy_model.state_dict(),
            "optimizer_state_dict": policy_model.optimizer.state_dict()
            if hasattr(policy_model, "optimizer")
            else None,
        },
        path,
    )
    print(f"[CKPT] Saved checkpoint at {path}")