import gradio as gr
import matplotlib.pyplot as plt

from policy import load_policy_model, save_checkpoint
from reward_fn import reward_fn
from grpo_train import grpo_step

# Load the trainable policy, the frozen generation/reference model, and the
# shared tokenizer once at startup so every GRPO step reuses them.
policy_model, gen_model, tokenizer = load_policy_model()

reward_history = []
global_step = 0


def plot_rewards(history):
    """Render the reward trajectory as a matplotlib figure for gr.Plot."""
    fig, ax = plt.subplots()
    ax.plot(history, marker="o")
    ax.set_title("Reward History")
    ax.set_xlabel("Step")
    ax.set_ylabel("Reward")
    # Close the figure so repeated clicks don't leak open pyplot figures;
    # Gradio renders the returned Figure object, so closing it here is safe.
    plt.close(fig)
    return fig


def run_step(prompt):
    """Run one GRPO update on the given prompt and report its metrics."""
    global global_step
    global_step += 1
    result = grpo_step(
        policy_model=policy_model,
        gen_model=gen_model,
        tokenizer=tokenizer,
        prompt=prompt,
        reward_fn=reward_fn,
    )
    reward_history.append(float(result["reward"]))
    reward_plot = plot_rewards(reward_history)
    # Persist the policy every 10 steps so training progress survives a restart.
    if global_step % 10 == 0:
        save_checkpoint(policy_model, global_step)
    return result["text"], result["reward"], result["kl"], result["loss"], reward_plot


with gr.Blocks() as demo:
    gr.Markdown("# 🤝 GRPO with Phi-2 — Helpfulness Reward Demo")
    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Ask something the model should answer helpfully...",
    )
    run_btn = gr.Button("Run GRPO Step")
    output = gr.Textbox(label="Model Output")
    reward_box = gr.Number(label="Reward")
    kl_box = gr.Number(label="KL")
    loss_box = gr.Number(label="Loss")
    plot = gr.Plot(label="Reward Over Time")
    run_btn.click(
        fn=run_step,
        inputs=[prompt],
        outputs=[output, reward_box, kl_box, loss_box, plot],
    )

demo.launch()
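
# ---------------------------------------------------------------------------
# Appendix: assumed interfaces and a toy GRPO sketch.
#
# The local modules imported above (policy.py, reward_fn.py, grpo_train.py)
# are not shown in this file. The signatures below are inferred purely from
# the calls above; they are assumptions, not the actual implementations:
#
#   load_policy_model() -> (policy_model, gen_model, tokenizer)
#   save_checkpoint(policy_model, step)   # persists the policy weights
#   grpo_step(...) -> dict with keys "text", "reward", "kl", "loss"
#
# For orientation: GRPO samples a group of completions per prompt, scores
# each with the reward function, and uses the group-normalized reward as the
# per-completion advantage. The helper below is a minimal, dependency-free
# sketch of that advantage computation for illustration only; it is
# hypothetical and not part of grpo_train.grpo_step. Since demo.launch()
# blocks until the server stops, nothing here affects the running demo.
def group_relative_advantages(rewards, eps=1e-6):
    """Normalize each reward against its group's mean and std (GRPO-style)."""
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
    # Completions that beat their siblings get a positive advantage,
    # completions that lag behind get a negative one.
    return [(r - mean) / (std + eps) for r in rewards]

# Example: group_relative_advantages([1.0, 0.5, 0.0, 0.5])
#   -> [~1.41, 0.0, ~-1.41, 0.0]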