import torch
import torch.nn.functional as F


def grpo_step(model, tokenizer, prompt, reward_fn, optimizer, beta: float = 0.1):
    # Optimizer is passed in explicitly: standard Hugging Face models do not
    # carry a .optimizer attribute, which the original model.optimizer.step()
    # call assumed.
    device = model.device

    # 1) Tokenize on CPU, truncate long prompts to the last 256 tokens
    inputs = tokenizer(prompt, return_tensors="pt")
    if inputs["input_ids"].shape[-1] > 256:
        inputs = {k: v[:, -256:] for k, v in inputs.items()}

    # 2) Move to the model's device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # 3) Sanity-check model params before doing any CUDA work
    with torch.no_grad():
        for p in model.parameters():
            if torch.isnan(p).any() or torch.isinf(p).any():
                raise RuntimeError("Model parameters contain NaN or Inf before GRPO step.")

    # 4) Reference logprobs (snapshot; note this re-snapshots the live model on
    #    every call, whereas full GRPO keeps a frozen reference policy)
    with torch.no_grad():
        ref_out = model(**inputs)
        ref_logits = ref_out.logits[:, -1, :]
        ref_logprobs = F.log_softmax(ref_logits, dim=-1)

    # 5) Sample from the current model with conservative decoding
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        temperature=1.0,
        top_p=0.9,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    output_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

    # 6) Scalar reward for the sampled completion
    reward = reward_fn(output_text)
    reward_tensor = torch.tensor(reward, dtype=torch.float32, device=device)

    # 7) New logprobs, this time with gradients attached
    new_out = model(**inputs)
    new_logits = new_out.logits[:, -1, :]
    new_logprobs = F.log_softmax(new_logits, dim=-1)

    # 8) KL(new || ref) = sum over vocab of p_new * (log p_new - log p_ref),
    #    averaged over the batch. (The original took a plain mean of logprob
    #    differences, which is not a KL divergence.)
    kl = (new_logprobs.exp() * (new_logprobs - ref_logprobs)).sum(dim=-1).mean()

    # 9) Simplified GRPO-style objective: the scalar reward weights the full
    #    next-token distribution (a single-sample surrogate, not the per-token,
    #    group-relative objective of full GRPO), plus the KL penalty
    loss = -(new_logprobs * reward_tensor).mean() + beta * kl

    # 10) Backward + step, guarding against NaN/Inf loss
    if torch.isnan(loss) or torch.isinf(loss):
        raise RuntimeError("Loss became NaN or Inf in GRPO step.")
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return {
        "text": output_text,
        "reward": float(reward_tensor.item()),
        "kl": float(kl.item()),
        "loss": float(loss.item()),
    }
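

# ---------------------------------------------------------------------------
# Note on step 9: full GRPO samples a *group* of completions per prompt and
# weights each one by its group-relative advantage rather than by the raw
# reward. A minimal sketch of that normalization, assuming a list of
# per-completion rewards; this helper is illustrative and not part of the
# original script.
# ---------------------------------------------------------------------------
def group_relative_advantages(rewards, eps: float = 1e-8):
    # advantage_i = (r_i - mean(r)) / (std(r) + eps), computed within a group;
    # assumes at least two samples in the group, or std() is undefined.
    r = torch.tensor(rewards, dtype=torch.float32)
    return (r - r.mean()) / (r.std() + eps)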
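

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original Space): wires
# grpo_step to a small Hugging Face model with a toy brevity reward to show
# the expected call shape. The model name and the reward function are
# placeholders for a quick smoke test.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
    opt = torch.optim.AdamW(lm.parameters(), lr=1e-5)

    def brevity_reward(text: str) -> float:
        # Toy reward: prefer completions under 200 characters.
        return 1.0 if len(text) < 200 else -1.0

    stats = grpo_step(lm, tok, "The capital of France is", brevity_reward, opt)
    print(stats)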