Sneha7 committed on
Commit c8e38f6 · verified · 1 Parent(s): 55bd1b0

Update grpo_train.py

Files changed (1)
grpo_train.py +17 -20
grpo_train.py CHANGED
@@ -12,25 +12,19 @@ def grpo_step(
     eps_clip: float = 0.2,
     group_size: int = 4,
 ):
-    """
-    GRPO step with:
-    - Sampling from gen_model (CPU)
-    - Policy/Ref both from policy_model on GPU (ref = frozen logits this step)
-    """
     device = policy_model.device

-    # 1) Tokenize on GPU for policy, but copy to CPU for gen_model
+    # 1) Tokenize
     inputs = tokenizer(prompt, return_tensors="pt")
     inputs_gpu = {k: v.to(device) for k, v in inputs.items()}
-    input_ids_gpu = inputs_gpu["input_ids"]  # [1, L]
+    input_ids_gpu = inputs_gpu["input_ids"]
     attn_gpu = inputs_gpu.get("attention_mask", None)

-    # Group repeat for GPU tensors
     input_ids_gpu = input_ids_gpu.repeat_interleave(group_size, dim=0)
     if attn_gpu is not None:
         attn_gpu = attn_gpu.repeat_interleave(group_size, dim=0)

-    # For CPU gen_model, keep a CPU copy
+    # CPU copy for gen_model
     input_ids_cpu = input_ids_gpu.cpu()
     attn_cpu = attn_gpu.cpu() if attn_gpu is not None else None

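A note on the grouping lines kept above: `repeat_interleave(group_size, dim=0)` tiles the single tokenized prompt into `group_size` identical rows, one per completion sampled in the GRPO group. A minimal shape-only sketch (the token values here are made up):

import torch

group_size = 4
input_ids = torch.tensor([[101, 7592, 2088, 102]])        # one prompt, shape [1, L]
grouped = input_ids.repeat_interleave(group_size, dim=0)  # shape [group_size, L], identical rows
print(grouped.shape)                                      # torch.Size([4, 4])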
@@ -38,7 +32,7 @@ def grpo_step(
     if attn_cpu is not None:
         gen_inputs["attention_mask"] = attn_cpu

-    # 2) Generate on CPU (slower but fits memory)
+    # 2) Generate on CPU
     with torch.no_grad():
         gen_output = gen_model.generate(
             **gen_inputs,
@@ -52,25 +46,25 @@ def grpo_step(
             output_scores=False,
         )

-    sequences_cpu = gen_output.sequences  # [G, L+T] on CPU
-    sequences = sequences_cpu.to(device)  # send batch to GPU once
+    sequences_cpu = gen_output.sequences
+    sequences = sequences_cpu.to(device)

     texts = [tokenizer.decode(seq, skip_special_tokens=True) for seq in sequences_cpu]
     rewards = torch.tensor(
         [reward_fn(text) for text in texts],
         device=device,
         dtype=torch.float32,
-    ).clamp_(-2.0, 2.0)
+    ).clamp(-2.0, 2.0)

     # 3) Group-normalized advantages
     group_mean = rewards.mean()
     group_std = rewards.std(unbiased=False) + 1e-8
     advantages = (rewards - group_mean) / group_std
-    advantages = torch.clamp(advantages, -5.0, 5.0)
+    advantages = advantages.clamp(-5.0, 5.0)

     orig_len = inputs["input_ids"].shape[1]

-    # 4) Compute "ref" logprobs as frozen snapshot of current policy
+    # 4) Ref logprobs (no grad)
     with torch.no_grad():
         ref_out = policy_model(sequences)
         ref_logits = ref_out.logits[:, :-1, :]
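The advantage block in this hunk is the standard GRPO group baseline: rewards are standardized within the group of `group_size` completions, so each sample is credited only for beating its siblings, and the out-of-place `clamp` bounds outliers. The same arithmetic in isolation (the reward values are invented for illustration):

import torch

rewards = torch.tensor([0.1, 1.0, -0.5, 0.4])    # hypothetical clamped rewards for one group
group_mean = rewards.mean()
group_std = rewards.std(unbiased=False) + 1e-8   # population std, as in the diff
advantages = (rewards - group_mean) / group_std  # zero mean, unit variance within the group
advantages = advantages.clamp(-5.0, 5.0)         # bound extreme advantages
print(advantages)                                # sums to ~0 across the group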
@@ -78,7 +72,7 @@ def grpo_step(
         ref_lp_all = ref_logprobs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
         ref_lp_gen = ref_lp_all[:, orig_len - 1 :]

-    # 5) Current policy logprobs (trainable)
+    # 5) Current policy logprobs (with grad)
     out = policy_model(sequences)
     logits = out.logits[:, :-1, :]
     logprobs = F.log_softmax(logits, dim=-1)
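Both the no-grad "ref" pass and the trainable pass above use the same shift-and-gather indexing: the logits at position t score token t+1, so the logits are truncated with `[:, :-1, :]` and gathered against `sequences[:, 1:]`. A self-contained sketch of just that indexing, with random logits standing in for a model:

import torch
import torch.nn.functional as F

B, T, V = 2, 6, 11                                   # batch, sequence length, vocab size (arbitrary)
sequences = torch.randint(0, V, (B, T))
logits = torch.randn(B, T, V)                        # stand-in for model(sequences).logits

logprobs = F.log_softmax(logits[:, :-1, :], dim=-1)  # position t predicts token t+1
token_lp = logprobs.gather(-1, sequences[:, 1:].unsqueeze(-1)).squeeze(-1)
print(token_lp.shape)                                # torch.Size([2, 5]): one log-prob per next token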
@@ -97,11 +91,14 @@ def grpo_step(
         "loss": 0.0,
     }

-    # 6) Ratios, KL, loss
-    log_ratio = (lp_gen - ref_lp_gen).mean(dim=1).clamp_(-10.0, 10.0)
-    ratio = torch.exp(log_ratio).clamp_(0.0, 10.0)
+    # 6) Ratios, KL, loss (no in-place ops)
+    log_ratio = (lp_gen - ref_lp_gen).mean(dim=1)
+    log_ratio = log_ratio.clamp(-10.0, 10.0)
+    ratio = torch.exp(log_ratio)
+    ratio = ratio.clamp(0.0, 10.0)

-    kl_per_sample = (lp_gen - ref_lp_gen).mean(dim=1).clamp_(-10.0, 10.0)
+    kl_per_sample = (lp_gen - ref_lp_gen).mean(dim=1)
+    kl_per_sample = kl_per_sample.clamp(-10.0, 10.0)
     kl_scalar = kl_per_sample.abs().mean()

     surr1 = ratio * advantages
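The hunk ends at `surr1`; the remainder of the objective is outside this diff. For orientation only, a typical PPO/GRPO-style clipped surrogate with a KL penalty looks roughly like the sketch below. The `kl_coef` weight, the `min`/negation convention, and the function name are assumptions, not code from grpo_train.py:

import torch

def clipped_surrogate_loss(ratio, advantages, kl_scalar, eps_clip=0.2, kl_coef=0.1):
    """Illustrative sketch only; not the objective defined in grpo_train.py."""
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()  # maximize the clipped surrogate
    return policy_loss + kl_coef * kl_scalar       # penalize drift from the reference pass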
 
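The common thread in this commit is swapping in-place `clamp_()` for out-of-place `clamp()` on tensors that sit in the autograd graph, such as `ratio`, `log_ratio`, and `kl_per_sample`. In-place ops can overwrite values that autograd saved for the backward pass, which typically surfaces as a RuntimeError at `backward()` time; a minimal reproduction of that failure mode, mirroring the old `torch.exp(...).clamp_(...)` pattern:

import torch

x = torch.randn(4, requires_grad=True)

# In-place clamp_ overwrites exp's output, which autograd saved for the backward pass.
ratio = torch.exp(x).clamp_(0.0, 10.0)
try:
    ratio.sum().backward()
except RuntimeError as err:
    print("in-place clamp_ fails at backward:", err)

# Out-of-place clamp creates a new tensor instead, so backward succeeds.
x2 = torch.randn(4, requires_grad=True)
ratio2 = torch.exp(x2).clamp(0.0, 10.0)
ratio2.sum().backward()
print("out-of-place clamp: grad computed?", x2.grad is not None)  # True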