Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from policy import load_policy_model, save_checkpoint | |
| from reward_fn import reward_fn | |
| from grpo_train import grpo_step | |
# Mutable training state shared by every button click in this process.
global_step = 0
reward_history = []

# Load the model objects once at import time; run_step passes them to
# grpo_step on every click.
policy_model, gen_model, tokenizer = load_policy_model()
def plot_rewards(history):
    """Build a line chart of the rewards collected so far.

    Args:
        history: Sequence of numeric rewards in the order they were
            recorded; plotted against their positional index.

    Returns:
        The matplotlib figure holding the chart.
    """
    # Draw through the explicit Axes object rather than pyplot's implicit
    # "current figure", so the chart cannot end up on some other figure.
    fig, ax = plt.subplots()
    ax.plot(history, marker="o")
    ax.set_title("Reward History")
    ax.set_xlabel("Step")
    ax.set_ylabel("Reward")

    # pyplot keeps every figure it creates registered until it is closed;
    # without this, each call would leak one figure for the life of the
    # process. The returned Figure object can still be rendered afterwards.
    plt.close(fig)
    return fig
def run_step(prompt):
    """Run one GRPO step on ``prompt`` and report its metrics.

    Updates the module-level ``global_step`` counter and ``reward_history``
    list, and saves a checkpoint on every 10th successful step.

    Args:
        prompt: Text from the prompt box, forwarded unchanged to
            ``grpo_step``.

    Returns:
        A 5-tuple ``(text, reward, kl, loss, figure)`` matching the order of
        the output components wired to the button.
    """
    global global_step

    result = grpo_step(
        policy_model=policy_model,
        gen_model=gen_model,
        tokenizer=tokenizer,
        prompt=prompt,
        reward_fn=reward_fn,
    )

    # Count the step only after grpo_step has returned: if it raises, the
    # counter stays in sync with reward_history and the checkpoint cadence.
    global_step += 1

    # Convert once and use the same plain float for both the history and the
    # Number box, instead of handing the UI the raw value.
    reward = float(result["reward"])
    reward_history.append(reward)
    reward_plot = plot_rewards(reward_history)

    if global_step % 10 == 0:
        save_checkpoint(policy_model, global_step)

    # NOTE(review): "kl" and "loss" are passed through as-is; presumably they
    # are already plain numbers — confirm against grpo_step.
    return result["text"], reward, result["kl"], result["loss"], reward_plot
# UI layout: a prompt box and a button on the input side; generated text,
# three numeric metrics and the reward chart on the output side.
with gr.Blocks() as demo:
    # Heading restored from mis-decoded UTF-8 (it previously rendered as
    # garbled Greek letters in place of the emoji and the dash).
    gr.Markdown("# 🤖 GRPO with Phi-2 — Helpfulness Reward Demo")
    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Ask something the model should answer helpfully...",
    )
    run_btn = gr.Button("Run GRPO Step")
    output = gr.Textbox(label="Model Output")
    reward_box = gr.Number(label="Reward")
    kl_box = gr.Number(label="KL")
    loss_box = gr.Number(label="Loss")
    plot = gr.Plot(label="Reward Over Time")

    # The order of `outputs` must match the tuple returned by run_step.
    run_btn.click(
        fn=run_step,
        inputs=[prompt],
        outputs=[output, reward_box, kl_box, loss_box, plot],
    )

demo.launch()