Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,29 +4,34 @@ from reward_fn import reward_fn
|
|
| 4 |
from grpo_train import grpo_step
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
|
| 7 |
-
|
| 8 |
model, tokenizer = load_policy_model()
|
| 9 |
|
| 10 |
reward_history = []
|
| 11 |
|
|
|
|
| 12 |
def plot_rewards(history):
|
| 13 |
fig = plt.figure()
|
| 14 |
-
plt.plot(history, marker=
|
| 15 |
plt.title("Reward History")
|
| 16 |
plt.xlabel("Step")
|
| 17 |
plt.ylabel("Reward")
|
| 18 |
return fig
|
| 19 |
|
|
|
|
| 20 |
def run_step(prompt):
|
| 21 |
result = grpo_step(model, tokenizer, prompt, reward_fn)
|
| 22 |
reward_history.append(float(result["reward"]))
|
| 23 |
reward_plot = plot_rewards(reward_history)
|
| 24 |
return result["text"], result["reward"], result["kl"], result["loss"], reward_plot
|
| 25 |
-
|
|
|
|
| 26 |
with gr.Blocks() as demo:
|
| 27 |
gr.Markdown("# 🤖 GRPO with Phi-2 — Helpfulness Reward Demo")
|
| 28 |
|
| 29 |
-
prompt = gr.Textbox(
|
|
|
|
|
|
|
|
|
|
| 30 |
run_btn = gr.Button("Run GRPO Step")
|
| 31 |
|
| 32 |
output = gr.Textbox(label="Model Output")
|
|
@@ -36,20 +41,10 @@ with gr.Blocks() as demo:
|
|
| 36 |
|
| 37 |
plot = gr.Plot(label="Reward Over Time")
|
| 38 |
|
| 39 |
-
def update_plot(xs, ys):
|
| 40 |
-
import matplotlib.pyplot as plt
|
| 41 |
-
fig, ax = plt.subplots()
|
| 42 |
-
ax.plot(xs, ys)
|
| 43 |
-
ax.set_title("Reward Trend")
|
| 44 |
-
ax.set_xlabel("Step")
|
| 45 |
-
ax.set_ylabel("Reward")
|
| 46 |
-
return fig
|
| 47 |
-
|
| 48 |
run_btn.click(
|
| 49 |
fn=run_step,
|
| 50 |
inputs=[prompt],
|
| 51 |
outputs=[output, reward_box, kl_box, loss_box, plot],
|
| 52 |
-
postprocess=update_plot
|
| 53 |
)
|
| 54 |
|
| 55 |
demo.launch()
|
|
|
|
| 4 |
from grpo_train import grpo_step
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
|
|
|
|
| 7 |
model, tokenizer = load_policy_model()
|
| 8 |
|
| 9 |
reward_history = []
|
| 10 |
|
| 11 |
+
|
| 12 |
def plot_rewards(history):
|
| 13 |
fig = plt.figure()
|
| 14 |
+
plt.plot(history, marker="o")
|
| 15 |
plt.title("Reward History")
|
| 16 |
plt.xlabel("Step")
|
| 17 |
plt.ylabel("Reward")
|
| 18 |
return fig
|
| 19 |
|
| 20 |
+
|
| 21 |
def run_step(prompt):
|
| 22 |
result = grpo_step(model, tokenizer, prompt, reward_fn)
|
| 23 |
reward_history.append(float(result["reward"]))
|
| 24 |
reward_plot = plot_rewards(reward_history)
|
| 25 |
return result["text"], result["reward"], result["kl"], result["loss"], reward_plot
|
| 26 |
+
|
| 27 |
+
|
| 28 |
with gr.Blocks() as demo:
|
| 29 |
gr.Markdown("# 🤖 GRPO with Phi-2 — Helpfulness Reward Demo")
|
| 30 |
|
| 31 |
+
prompt = gr.Textbox(
|
| 32 |
+
label="Prompt",
|
| 33 |
+
placeholder="Ask something the model should answer helpfully...",
|
| 34 |
+
)
|
| 35 |
run_btn = gr.Button("Run GRPO Step")
|
| 36 |
|
| 37 |
output = gr.Textbox(label="Model Output")
|
|
|
|
| 41 |
|
| 42 |
plot = gr.Plot(label="Reward Over Time")
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
run_btn.click(
|
| 45 |
fn=run_step,
|
| 46 |
inputs=[prompt],
|
| 47 |
outputs=[output, reward_box, kl_box, loss_box, plot],
|
|
|
|
| 48 |
)
|
| 49 |
|
| 50 |
demo.launch()
|