# Page header captured with the file (not part of the program):
# Sneha7's picture
# Update app.py
# 55bd1b0 verified
import gradio as gr
import matplotlib.pyplot as plt
from policy import load_policy_model, save_checkpoint
from reward_fn import reward_fn
from grpo_train import grpo_step
# Mutable module-level state shared by the UI callback:
# ``global_step`` counts how many times ``run_step`` has been invoked and
# ``reward_history`` collects one float reward per invocation.
global_step = 0
reward_history = []

# Loaded once at import time; the same objects are handed to every GRPO step.
policy_model, gen_model, tokenizer = load_policy_model()
def plot_rewards(history):
    """Build a line chart of the rewards recorded so far.

    The figure is constructed directly from ``matplotlib.figure.Figure``
    rather than through ``plt.figure()``. Figures created via pyplot stay
    registered in pyplot's global figure manager until explicitly closed,
    so creating one per call (this function runs on every step) would keep
    every earlier chart alive for the lifetime of the process.

    Args:
        history: Sequence of numeric rewards in step order. Values are
            plotted against their index.

    Returns:
        A matplotlib ``Figure`` containing a single axes with the reward
        curve, titled "Reward History".
    """
    # Local import: the file only imports ``matplotlib.pyplot`` at top level.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    ax.plot(history, marker="o")
    ax.set_title("Reward History")
    ax.set_xlabel("Step")
    ax.set_ylabel("Reward")
    return fig
def run_step(prompt):
    """Run one GRPO step for ``prompt`` and return values for the UI.

    Increments the module-level ``global_step`` counter, calls ``grpo_step``
    with the module-level model objects and ``reward_fn``, appends the
    step's reward (as a float) to ``reward_history``, and redraws the reward
    chart. On every 10th step the policy model is passed to
    ``save_checkpoint`` together with the step number.

    Args:
        prompt: The prompt forwarded to ``grpo_step``.

    Returns:
        A 5-tuple ``(text, reward, kl, loss, figure)`` where the first four
        items are read from the ``grpo_step`` result under the keys
        ``"text"``, ``"reward"``, ``"kl"`` and ``"loss"``, and ``figure`` is
        the chart produced by ``plot_rewards``.
    """
    global global_step
    global_step += 1

    outcome = grpo_step(
        policy_model=policy_model,
        gen_model=gen_model,
        tokenizer=tokenizer,
        prompt=prompt,
        reward_fn=reward_fn,
    )

    # Record the reward first so the chart includes the current step.
    reward_history.append(float(outcome["reward"]))
    figure = plot_rewards(reward_history)

    checkpoint_due = global_step % 10 == 0
    if checkpoint_due:
        save_checkpoint(policy_model, global_step)

    text, reward, kl, loss = (
        outcome[key] for key in ("text", "reward", "kl", "loss")
    )
    return text, reward, kl, loss, figure
# Gradio UI: one prompt box and a button that triggers a single GRPO step,
# with the generated text, the step's metrics and the reward chart as outputs.
with gr.Blocks() as demo:
    # Title text restored from a mis-encoded original (garbled emoji and
    # em dash); NOTE(review): confirm the emoji matches the intended one.
    gr.Markdown("# 🤝 GRPO with Phi-2 — Helpfulness Reward Demo")

    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Ask something the model should answer helpfully...",
    )
    run_btn = gr.Button("Run GRPO Step")

    output = gr.Textbox(label="Model Output")
    reward_box = gr.Number(label="Reward")
    kl_box = gr.Number(label="KL")
    loss_box = gr.Number(label="Loss")
    plot = gr.Plot(label="Reward Over Time")

    # The output order must match the tuple returned by ``run_step``:
    # (text, reward, kl, loss, figure).
    run_btn.click(
        fn=run_step,
        inputs=[prompt],
        outputs=[output, reward_box, kl_box, loss_box, plot],
    )

demo.launch()