import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "microsoft/phi-2"
CHECKPOINT_DIR = "checkpoints"


def load_policy_model(lr: float = 1e-6):
    """Load the shared tokenizer, a trainable policy model on the GPU, and
    a frozen generation model on the CPU."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Trainable policy model on GPU
    policy_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    policy_model.to("cuda")
    policy_model.train()

    # Train only lm_head
    for name, param in policy_model.named_parameters():
        param.requires_grad = ("lm_head" in name)

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, policy_model.parameters()),
        lr=lr,
    )
    policy_model.optimizer = optimizer

    # Frozen generation model kept on CPU: it is never moved to "cuda",
    # so it does not compete with the policy model for GPU memory
    gen_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    gen_model.eval()
    for p in gen_model.parameters():
        p.requires_grad_(False)

    return policy_model, gen_model, tokenizer
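

# Sketch of one way to refresh the frozen generation model between updates
# (the helper name and sync cadence are assumptions, not part of the code
# above): since only lm_head is trainable, copying those tensors alone keeps
# the CPU copy current without moving the full model between devices.
def sync_gen_model(policy_model, gen_model):
    lm_head_state = {
        name: tensor.detach().cpu()
        for name, tensor in policy_model.state_dict().items()
        if "lm_head" in name
    }
    # strict=False allows loading a partial state dict (lm_head only)
    gen_model.load_state_dict(lm_head_state, strict=False)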


def save_checkpoint(policy_model, step: int, ckpt_dir: str = CHECKPOINT_DIR):
    """Save model weights (and optimizer state, if attached) under ckpt_dir."""
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"step_{step}.pt")
    torch.save(
        {
            "step": step,
            "model_state_dict": policy_model.state_dict(),
            "optimizer_state_dict": policy_model.optimizer.state_dict()
            if hasattr(policy_model, "optimizer")
            else None,
        },
        path,
    )
    print(f"[CKPT] Saved checkpoint at {path}")