# policy.py — author: Sneha7; last commit 30a2ce8 ("Update policy.py"), 1.64 kB.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub id of the pretrained causal LM (and its tokenizer) used as the policy.
MODEL_NAME: str = "microsoft/phi-2"
class _Float32Head(torch.nn.Linear):
    """Linear LM head that keeps its own (float32) parameters under a float16 backbone.

    Incoming hidden states are cast to the head's weight dtype before the
    projection, so the returned logits carry that dtype. Keeping the only
    trainable tensors in float32 avoids AdamW's eps (1e-8) underflowing to 0 in
    float16 (NaN updates) and lr-sized updates being rounded away.
    """

    def forward(self, hidden_states):
        return super().forward(hidden_states.to(self.weight.dtype))


def _untie_lm_head(model, device):
    """Replace ``model.lm_head`` with an independent float32 copy of the pretrained head."""
    old_head = model.lm_head
    new_head = _Float32Head(
        old_head.in_features,
        old_head.out_features,
        bias=True,
        device=device,
        dtype=torch.float32,
    )
    with torch.no_grad():
        # Start from the pretrained projection (not a random init) so the policy
        # initially reproduces the checkpoint; copy_ also converts dtype and
        # gives the head its own storage, breaking any tie with the embeddings.
        new_head.weight.copy_(old_head.weight)
        if old_head.bias is not None:
            new_head.bias.copy_(old_head.bias)
        else:
            new_head.bias.zero_()
    model.lm_head = new_head
    # Keep transformers from re-tying the head to the input embeddings later
    # (e.g. via tie_weights() when resizing embeddings or reloading).
    model.config.tie_word_embeddings = False


def load_policy_model():
    """Load the policy model and tokenizer with only the LM head trainable.

    Steps:
      * load the tokenizer for MODEL_NAME; if it has no pad token, EOS is reused;
      * load the causal LM on CUDA in float16 when available, otherwise on CPU
        in float32 (float16 kernels are slow or missing on CPU);
      * replace ``lm_head`` with an untied float32 copy of the pretrained head;
      * freeze every other parameter;
      * attach an AdamW optimizer (lr=1e-5) over the head as ``model.optimizer``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: if any model parameter contains NaN or Inf after setup.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    backbone_dtype = torch.float16 if use_cuda else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=backbone_dtype,
        device_map=None,
    ).to(device)

    print(">>> UNTIEING LM HEAD...")
    _untie_lm_head(model, device)

    # Freeze everything, then re-enable gradients on the head only (does not
    # depend on parameter-name prefixes).
    model.requires_grad_(False)
    model.lm_head.requires_grad_(True)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print("TRAINABLE:", name)
    trainable = [p for p in model.parameters() if p.requires_grad]
    print(">>> FINAL TRAINABLE PARAM COUNT:", sum(p.numel() for p in trainable))

    # Optimizer: only lm_head, small LR.
    model.optimizer = torch.optim.AdamW(trainable, lr=1e-5)

    # Sanity check: no NaN / Inf in the loaded (and copied) weights.
    with torch.no_grad():
        for p in model.parameters():
            if not torch.isfinite(p).all():
                raise RuntimeError(
                    "Loaded model checkpoint has NaN/Inf parameters."
                )
    return model, tokenizer