Spaces: Runtime error
```python
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "microsoft/phi-2"
CHECKPOINT_DIR = "checkpoints"


def load_policy_model(lr: float = 1e-6):
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Trainable policy model on GPU
    policy_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    policy_model.to("cuda")
    policy_model.train()

    # Train only lm_head; freeze every other parameter
    for name, param in policy_model.named_parameters():
        param.requires_grad = ("lm_head" in name)

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, policy_model.parameters()),
        lr=lr,
    )
    # Attach the optimizer to the model so checkpointing can reach it
    policy_model.optimizer = optimizer

    # Frozen generation model kept on CPU (no .to("cuda"))
    gen_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    gen_model.eval()
    for p in gen_model.parameters():
        p.requires_grad_(False)

    return policy_model, gen_model, tokenizer


def save_checkpoint(policy_model, step: int, ckpt_dir: str = CHECKPOINT_DIR):
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"step_{step}.pt")
    torch.save(
        {
            "step": step,
            "model_state_dict": policy_model.state_dict(),
            "optimizer_state_dict": policy_model.optimizer.state_dict()
            if hasattr(policy_model, "optimizer")
            else None,
        },
        path,
    )
    print(f"[CKPT] Saved checkpoint at {path}")
```