Spaces:
Runtime error
Runtime error
File size: 1,589 Bytes
7487ca9 e354192 7487ca9 4fc5777 55bd1b0 7487ca9 55bd1b0 7487ca9 bb39a36 4fc5777 bb39a36 4fc5777 bb39a36 7487ca9 e354192 5981a94 4fc5777 e354192 55bd1b0 e354192 4fc5777 bb39a36 7487ca9 bb39a36 7487ca9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import gradio as gr
import matplotlib.pyplot as plt
from policy import load_policy_model, save_checkpoint
from reward_fn import reward_fn
from grpo_train import grpo_step
# Loaded once at module level; these three objects are handed to grpo_step
# on every call made from run_step.
policy_model, gen_model, tokenizer = load_policy_model()
# Reward of each completed step, stored as a float; source data for the plot.
reward_history = []
# Step counter maintained by run_step; a checkpoint is saved whenever it is a
# multiple of 10.
global_step = 0
def plot_rewards(history):
    """Build a line plot of the rewards collected so far.

    Args:
        history: Sequence of numeric reward values in step order. An empty
            sequence yields an empty set of axes.

    Returns:
        A matplotlib Figure with the rewards plotted against their index,
        titled "Reward History" with "Step"/"Reward" axis labels.
    """
    # Draw through the explicit Axes instead of pyplot's implicit "current
    # figure", so the plot cannot land on a figure created elsewhere.
    fig, ax = plt.subplots()
    ax.plot(history, marker="o")
    ax.set_title("Reward History")
    ax.set_xlabel("Step")
    ax.set_ylabel("Reward")
    # pyplot keeps every figure it creates registered until it is closed.
    # This function runs once per step, so without closing, figures pile up
    # for the lifetime of the process. Closing only deregisters the figure
    # from pyplot; the returned Figure object can still be rendered.
    plt.close(fig)
    return fig
def run_step(prompt):
    """Run one GRPO step for *prompt* and return the values shown in the UI.

    The step result is read as a mapping with the keys ``"text"``,
    ``"reward"``, ``"kl"`` and ``"loss"``.

    Side effects: increments the module-level ``global_step``, appends the
    reward to ``reward_history``, and calls ``save_checkpoint`` on every
    10th completed step.

    Args:
        prompt: Prompt value forwarded unchanged to ``grpo_step``.

    Returns:
        A 5-tuple ``(text, reward, kl, loss, figure)`` where ``reward`` is
        the step reward converted to ``float`` and ``figure`` is the updated
        reward-history plot.
    """
    global global_step

    result = grpo_step(
        policy_model=policy_model,
        gen_model=gen_model,
        tokenizer=tokenizer,
        prompt=prompt,
        reward_fn=reward_fn,
    )

    # Count the step only after grpo_step has returned. If it raises, the
    # counter stays in sync with reward_history and the checkpoint cadence
    # reflects completed steps rather than attempts.
    global_step += 1

    # Convert once and reuse: the same plain float goes into the history
    # and back to the UI, instead of returning the raw result value.
    reward = float(result["reward"])
    reward_history.append(reward)
    reward_plot = plot_rewards(reward_history)

    if global_step % 10 == 0:
        save_checkpoint(policy_model, global_step)

    return result["text"], reward, result["kl"], result["loss"], reward_plot
# Page layout: a title, the prompt input with its trigger button, then the
# five outputs that run_step fills in (text, three numbers, and the plot).
with gr.Blocks() as demo:
    # NOTE(review): the non-ASCII characters in this title look mis-encoded
    # (possibly an emoji and a dash originally) -- confirm the intended text.
    gr.Markdown("# π€ GRPO with Phi-2 β Helpfulness Reward Demo")

    prompt = gr.Textbox(label="Prompt", placeholder="Ask something the model should answer helpfully...")
    run_btn = gr.Button("Run GRPO Step")

    # Output widgets, declared in the same order run_step returns its values.
    output = gr.Textbox(label="Model Output")
    reward_box = gr.Number(label="Reward")
    kl_box = gr.Number(label="KL")
    loss_box = gr.Number(label="Loss")
    plot = gr.Plot(label="Reward Over Time")

    # Each click runs one step on the current prompt and refreshes every
    # output widget.
    run_btn.click(fn=run_step, inputs=[prompt], outputs=[output, reward_box, kl_box, loss_box, plot])

demo.launch()
|