# app.py — Hugging Face Space by Sneha7 ("Update app.py", commit bb39a36, 1.29 kB)
import gradio as gr
from policy import load_policy_model
from reward_fn import reward_fn
from grpo_train import grpo_step
import matplotlib.pyplot as plt
# Load the policy model and tokenizer once at import time so every request
# reuses the same weights (project-local helper; presumably returns a
# Hugging Face model/tokenizer pair — confirm in policy.py).
model, tokenizer = load_policy_model()
# Rewards observed so far, one float per GRPO step. Module-level state shared
# across all requests for the lifetime of the app process.
reward_history = []
def plot_rewards(history):
    """Return a matplotlib Figure charting reward values over training steps.

    Args:
        history: sequence of numeric reward values, one per GRPO step.

    Returns:
        A ``matplotlib.figure.Figure`` suitable for display in ``gr.Plot``.
    """
    # Build the Figure directly instead of via plt.figure(): pyplot-managed
    # figures are kept alive in the global figure manager and are never
    # plt.close()'d here, which leaks one figure per step in this
    # long-running app. plt.Figure is not registered with pyplot.
    fig = plt.Figure()
    ax = fig.add_subplot()
    ax.plot(history, marker="o")
    ax.set_title("Reward History")
    ax.set_xlabel("Step")
    ax.set_ylabel("Reward")
    return fig
def run_step(prompt):
    """Execute one GRPO optimization step for *prompt* and report metrics.

    Args:
        prompt: user-supplied text the policy model should answer.

    Returns:
        Tuple of (generated text, reward, KL divergence, loss, reward Figure).
    """
    step = grpo_step(model, tokenizer, prompt, reward_fn)
    # Record the reward in the module-level history so the plot reflects
    # the entire run so far, not just this step.
    reward_history.append(float(step["reward"]))
    return (
        step["text"],
        step["reward"],
        step["kl"],
        step["loss"],
        plot_rewards(reward_history),
    )
# --- Gradio UI -------------------------------------------------------------
# Component creation order inside the Blocks context defines the page layout,
# so these statements must stay in sequence.
with gr.Blocks() as demo:
    # Heading text repaired: the original literal was mojibake
    # ("🀝" / "β€”" — UTF-8 bytes read through a wrong codepage) and
    # rendered as garbage in the browser.
    gr.Markdown("# 🤝 GRPO with Phi-2 — Helpfulness Reward Demo")
    # Free-form prompt the policy model will answer.
    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Ask something the model should answer helpfully...",
    )
    run_btn = gr.Button("Run GRPO Step")
    # Output components, in the same order as run_step's return tuple.
    output = gr.Textbox(label="Model Output")
    reward_box = gr.Number(label="Reward")
    kl_box = gr.Number(label="KL")
    loss_box = gr.Number(label="Loss")
    plot = gr.Plot(label="Reward Over Time")
    run_btn.click(
        fn=run_step,
        inputs=[prompt],
        outputs=[output, reward_box, kl_box, loss_box, plot],
    )
demo.launch()