Spaces:
Runtime error
Runtime error
File size: 1,482 Bytes
e4c07fc fb391c1 e4c07fc 30a2ce8 fb391c1 19afcd9 fb391c1 30a2ce8 fb391c1 cdc84bc fb391c1 cdc84bc fb391c1 cdc84bc fb391c1 cdc84bc fb391c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import copy
import os
# Hugging Face model id loaded for both the policy and generation models.
MODEL_NAME = "microsoft/phi-2"
# Default output directory for saved training checkpoints.
CHECKPOINT_DIR = "checkpoints"
def load_policy_model(lr: float = 1e-6, device=None):
    """Load a trainable policy model, a frozen generation model, and a tokenizer.

    Args:
        lr: Learning rate for the AdamW optimizer over the trainable
            (lm_head) parameters.
        device: Device for the trainable policy model. When None (the
            default), uses "cuda" if available, otherwise "cpu". The
            previous hard-coded ``.to("cuda")`` raised a runtime error on
            CPU-only hosts.

    Returns:
        Tuple ``(policy_model, gen_model, tokenizer)``:
        - ``policy_model``: in train mode, only ``lm_head`` parameters
          trainable, with its AdamW optimizer attached as
          ``policy_model.optimizer``.
        - ``gen_model``: a second, fully frozen eval-mode copy of the model
          left on CPU (never moved to the accelerator).
        - ``tokenizer``: the matching tokenizer.
    """
    if device is None:
        # Fall back gracefully so the loader also works on hosts without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Trainable policy model on the selected device.
    policy_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    policy_model.to(device)
    policy_model.train()

    # Freeze everything except the lm_head projection.
    for name, param in policy_model.named_parameters():
        param.requires_grad = ("lm_head" in name)

    optimizer = torch.optim.AdamW(
        (p for p in policy_model.parameters() if p.requires_grad),
        lr=lr,
    )
    # Attach the optimizer to the model so checkpointing code can reach it.
    policy_model.optimizer = optimizer

    # Frozen generation model stays on CPU (deliberately not moved to `device`).
    gen_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    gen_model.eval()
    for p in gen_model.parameters():
        p.requires_grad_(False)

    return policy_model, gen_model, tokenizer
def save_checkpoint(policy_model, step: int, ckpt_dir: str = CHECKPOINT_DIR):
    """Persist the policy model (and its optimizer, if attached) to disk.

    Writes ``<ckpt_dir>/step_<step>.pt`` containing the step number, the
    model's state dict, and the optimizer's state dict — or ``None`` when
    the model carries no ``optimizer`` attribute. Creates ``ckpt_dir`` if
    it does not exist, then logs the saved path to stdout.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"step_{step}.pt")

    # Optimizer state is optional: only present when load_policy_model
    # (or equivalent) attached one to the model.
    if hasattr(policy_model, "optimizer"):
        opt_state = policy_model.optimizer.state_dict()
    else:
        opt_state = None

    payload = {
        "step": step,
        "model_state_dict": policy_model.state_dict(),
        "optimizer_state_dict": opt_state,
    }
    torch.save(payload, path)
    print(f"[CKPT] Saved checkpoint at {path}")
|