Spaces:
Running
Running
File size: 4,608 Bytes
efa06b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import pandas as pd
import plotly.express as px
from config.constants import (
CC_BENCHMARKS,
LC_BENCHMARKS,
NON_RTL_METRICS,
RTL_METRICS,
S2R_BENCHMARKS,
SCATTER_PLOT_X_TICKS,
TYPE_COLORS,
Y_AXIS_LIMITS,
)
from utils import filter_bench, filter_bench_all, filter_RTLRepo, handle_special_cases
def filter_leaderboard(task, benchmark, model_type, search_query, max_params, state):
    """Filter leaderboard data based on user selections.

    Args:
        task: Task tab name — "Spec-to-RTL", "Code Completion",
            or "Line Completion †".
        benchmark: A specific benchmark name, or "All" for the task-wide view.
        model_type: Model-type label from the UI (may carry an emoji suffix),
            or "All".
        search_query: Case-insensitive substring matched against model names.
        max_params: Inclusive upper bound on the "Params" column (coerced
            to float; the UI may pass it as a string).
        state: App state exposing get_current_df() and get_current_agg().

    Returns:
        The DataFrame produced by the matching filter_* helper, or None for
        an unrecognized task when benchmark == "All" (original behavior).
    """
    # Benchmarks that make up each task's "All" view.
    task_benchmarks = {
        "Spec-to-RTL": S2R_BENCHMARKS,
        "Code Completion": CC_BENCHMARKS,
        "Line Completion †": LC_BENCHMARKS,
    }
    if benchmark == "All":
        subset = state.get_current_df().copy()
        valid_benchmarks = task_benchmarks.get(task)
        if valid_benchmarks is not None:
            subset = subset[subset["Benchmark"].isin(valid_benchmarks)]
    else:
        # A specific benchmark replaces the task-level pre-filter entirely.
        df = state.get_current_df()
        subset = df[df["Benchmark"] == benchmark]
    if model_type != "All":
        # Drop the emoji suffix from the UI label before comparing.
        subset = subset[subset["Model Type"] == model_type.split(" ")[0]]
    if search_query:
        subset = subset[subset["Model"].str.contains(search_query, case=False, na=False)]
    subset = subset[subset["Params"] <= float(max_params)]
    if benchmark == "All":
        if task == "Spec-to-RTL":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg S2R")
        if task == "Code Completion":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg MC")
        if task == "Line Completion †":
            return filter_RTLRepo(subset)
        return None  # unrecognized task: original code fell through to implicit None
    if benchmark == "RTL-Repo":
        return filter_RTLRepo(subset)
    # Remaining benchmarks map 1:1 to their aggregate-score column.
    agg_column = {
        "VerilogEval S2R": "Agg VerilogEval S2R",
        "VerilogEval MC": "Agg VerilogEval MC",
        "RTLLM": "Agg RTLLM",
        "VeriGen": "Agg VeriGen",
    }.get(benchmark)
    return filter_bench(subset, state.get_current_agg(), agg_column)
def generate_scatter_plot(benchmark, metric, state):
    """Generate a scatter plot for the given benchmark and metric.

    Plots model size ("Params", log x-axis) against the chosen metric, one
    bubble per model, sized by parameter count and colored by model type.

    Args:
        benchmark: Benchmark name; normalized via handle_special_cases.
        metric: Metric/column to plot on the y-axis; normalized likewise.
        state: App state exposing get_current_df().

    Returns:
        A plotly.graph_objects.Figure.
    """
    benchmark, metric = handle_special_cases(benchmark, metric)
    subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark]
    if benchmark == "RTL-Repo":
        # RTL-Repo reports a single Exact Matching score; average per model.
        subset = subset[subset["Metric"].str.contains("EM", case=False, na=False)]
        detailed_scores = subset.groupby("Model", as_index=False)["Score"].mean()
        detailed_scores.rename(columns={"Score": "Exact Matching (EM)"}, inplace=True)
    else:
        # Other benchmarks carry several metrics; pivot so each metric
        # becomes its own column keyed by model.
        detailed_scores = subset.pivot_table(index="Model", columns="Metric", values="Score").reset_index()
    details = state.get_current_df()[["Model", "Params", "Model Type"]].drop_duplicates("Model")
    scatter_data = pd.merge(detailed_scores, details, on="Model", how="left").dropna(
        subset=["Params", metric]
    )
    scatter_data["x"] = scatter_data["Params"]
    scatter_data["y"] = scatter_data[metric]
    # Sub-linear growth so large models don't dominate the canvas.
    scatter_data["size"] = (scatter_data["x"] ** 0.3) * 40
    y_range = Y_AXIS_LIMITS.get(metric, [0, 80])
    fig = px.scatter(
        scatter_data,
        x="x",
        y="y",
        log_x=True,
        size="size",
        color="Model Type",
        # Bug fix: TYPE_COLORS was previously mapped into an unused "color"
        # column, so the figure silently fell back to plotly's default
        # palette. Pass the mapping to px so the intended colors apply.
        # NOTE(review): types absent from TYPE_COLORS now get palette
        # defaults rather than the old (also unused) "gray" fallback.
        color_discrete_map=TYPE_COLORS,
        text="Model",
        hover_data={metric: ":.2f"},
        title=f"Params vs. {metric} for {benchmark}",
        labels={"x": "# Params (Log Scale)", "y": metric},
        template="plotly_white",
        height=600,
        width=1200,
    )
    fig.update_traces(
        textposition="top center",
        textfont_size=10,
        marker=dict(opacity=0.8, line=dict(width=0.5, color="black")),
    )
    fig.update_layout(
        xaxis=dict(
            showgrid=True,
            type="log",
            tickmode="array",
            tickvals=SCATTER_PLOT_X_TICKS["tickvals"],
            ticktext=SCATTER_PLOT_X_TICKS["ticktext"],
        ),
        showlegend=False,
        yaxis=dict(range=y_range),
        margin=dict(l=50, r=50, t=50, b=50),
        plot_bgcolor="white",
    )
    return fig
|