import torch
import torch.nn.functional as F
def grpo_step(model, tokenizer, prompt, reward_fn, beta: float = 0.1):
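    """Run one GRPO-style update step on a single prompt.

    Samples a completion, scores it with `reward_fn`, and applies a
    reward-weighted next-token log-prob surrogate with a KL penalty against
    a pre-step snapshot of the model. Expects an optimizer to be attached
    as `model.optimizer`.
    """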
    device = model.device
    # 1) Tokenize on CPU, truncate long prompts to their last 256 tokens
    inputs = tokenizer(prompt, return_tensors="pt")
    if inputs["input_ids"].shape[-1] > 256:
        inputs = {k: v[:, -256:] for k, v in inputs.items()}
    # 2) Move to the model's device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # 3) Sanity-check model parameters before any forward passes
    with torch.no_grad():
        for p in model.parameters():
            if torch.isnan(p).any() or torch.isinf(p).any():
                raise RuntimeError("Model parameters contain NaN or Inf before GRPO step.")
    # 4) Reference next-token log-probs (snapshot of the current policy, no grad)
    with torch.no_grad():
        ref_out = model(**inputs)
        ref_logits = ref_out.logits[:, -1, :]
        ref_logprobs = F.log_softmax(ref_logits, dim=-1)
    # 5) Sample a completion from the current model with top-p / top-k sampling
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        temperature=1.0,
        top_p=0.9,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens so the reward scores the completion
    prompt_len = inputs["input_ids"].shape[-1]
    output_text = tokenizer.decode(gen_ids[0, prompt_len:], skip_special_tokens=True)
    # 6) Scalar reward for the sampled completion
    reward = reward_fn(output_text)
    reward_tensor = torch.tensor(reward, dtype=torch.float32, device=device)
    # 7) Next-token log-probs under the current model, this time with grad
    new_out = model(**inputs)
    new_logits = new_out.logits[:, -1, :]
    new_logprobs = F.log_softmax(new_logits, dim=-1)
    # 8) KL(new || ref) over the next-token distribution
    kl = (new_logprobs.exp() * (new_logprobs - ref_logprobs)).sum(dim=-1).mean()
    # 9) GRPO-style surrogate: scale the next-token log-probs by the scalar
    #    reward and add the KL penalty
    loss = -(new_logprobs * reward_tensor).mean() + beta * kl
    # 10) Backward + step, with NaN/Inf guard
    if torch.isnan(loss) or torch.isinf(loss):
        raise RuntimeError("Loss became NaN or Inf in GRPO step.")
    loss.backward()
    # Assumes the training script has attached an optimizer as `model.optimizer`
    model.optimizer.step()
    model.optimizer.zero_grad()
    return {
        "text": output_text,
        "reward": float(reward_tensor.item()),
        "kl": float(kl.item()),
        "loss": float(loss.item()),
    }
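

# Minimal usage sketch (illustrative, not part of the original training loop):
# the model name, learning rate, and reward function below are placeholder
# choices; any causal LM plus a scalar reward function should work the same way.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "distilgpt2"  # small model chosen only to keep the demo cheap
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.train()

    # grpo_step expects the optimizer to be attached to the model
    model.optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    def toy_reward(text: str) -> float:
        # Placeholder reward: prefer longer completions, capped at 1.0
        return min(len(text.split()) / 32.0, 1.0)

    stats = grpo_step(model, tokenizer, "The capital of France is", toy_reward)
    print(stats)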