Sneha7 committed on
Commit 2f7e6f8 · verified · 1 Parent(s): c046c04

Create grpo_train.py

Files changed (1)
  1. grpo_train.py +45 -0
grpo_train.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ import torch.nn.functional as F
+
+ def grpo_step(model, tokenizer, optimizer, prompt, reward_fn, beta=0.1):
+     device = model.device
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+     # 1) Reference log-probs: detached next-token distribution at the last prompt position
+     with torch.no_grad():
+         ref_out = model(**inputs)
+         ref_logprobs = F.log_softmax(ref_out.logits[:, -1, :], dim=-1)
+
+     # 2) Sample a completion from the current policy
+     gen_ids = model.generate(
+         **inputs,
+         max_new_tokens=80,
+         do_sample=True,
+         temperature=0.7
+     )
+     output_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
+
+     # 3) Score the sampled text with the external reward function
+     reward = reward_fn(output_text)
+
+     # 4) New log-probs from the current policy (gradient enabled)
+     new_out = model(**inputs)
+     new_logprobs = F.log_softmax(new_out.logits[:, -1, :], dim=-1)
+
+     # 5) KL penalty against the detached reference snapshot
+     kl = torch.mean(new_logprobs - ref_logprobs)
+
+     # 6) Simplified GRPO-style objective: reward-weighted log-probs plus KL penalty
+     loss = -(new_logprobs * reward).mean() + beta * kl
+
+     loss.backward()
+     optimizer.step()
+     optimizer.zero_grad()
+
+     return {
+         "text": output_text,
+         "reward": float(reward),
+         "kl": float(kl.item()),
+         "loss": float(loss.item())
+     }
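
For context, a minimal usage sketch (not part of this commit) showing how grpo_step could be wired into a small loop. The checkpoint name, optimizer settings, toy_reward function, and prompt below are illustrative assumptions, not anything defined in grpo_train.py:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from grpo_train import grpo_step

# Assumptions for illustration only: model checkpoint, learning rate, and reward.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

def toy_reward(text: str) -> float:
    # Toy reward: longer completions score higher, capped at 1.0.
    return min(len(text.split()) / 50.0, 1.0)

prompt = "Explain group relative policy optimization in one sentence."
for step in range(3):
    stats = grpo_step(model, tokenizer, optimizer, prompt, toy_reward, beta=0.1)
    print(step, stats["reward"], stats["kl"], stats["loss"])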