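"""Leaderboard filtering and scatter-plot helpers.

filter_leaderboard narrows the current results table to the selected task,
benchmark, model type, name query, and parameter budget, then hands off to the
benchmark-specific aggregation helpers. generate_scatter_plot builds a
Params-vs-metric bubble chart for a single benchmark with Plotly Express.
"""
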
import pandas as pd
import plotly.express as px

from config.constants import (
    CC_BENCHMARKS,
    LC_BENCHMARKS,
    NON_RTL_METRICS,
    RTL_METRICS,
    S2R_BENCHMARKS,
    SCATTER_PLOT_X_TICKS,
    TYPE_COLORS,
    Y_AXIS_LIMITS,
)
from utils import filter_bench, filter_bench_all, filter_RTLRepo, handle_special_cases


def filter_leaderboard(task, benchmark, model_type, search_query, max_params, state):
    """Filter leaderboard data based on user selections."""
    subset = state.get_current_df().copy()

    # Restrict to the benchmarks belonging to the selected task when 'All' is chosen
    task_benchmarks = {
        "Spec-to-RTL": S2R_BENCHMARKS,
        "Code Completion": CC_BENCHMARKS,
        "Line Completion †": LC_BENCHMARKS,
    }
    if benchmark == "All" and task in task_benchmarks:
        subset = subset[subset["Benchmark"].isin(task_benchmarks[task])]

    # A specific benchmark selection overrides the task-level filtering above
    if benchmark != "All":
        subset = subset[subset["Benchmark"] == benchmark]

    if model_type != "All":
        # Strip the emoji suffix from the selected type label before comparing
        subset = subset[subset["Model Type"] == model_type.split(" ")[0]]
    if search_query:
        subset = subset[subset["Model"].str.contains(search_query, case=False, na=False)]
    max_params = float(max_params)
    subset = subset[subset["Params"] <= max_params]

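    # Dispatch to the aggregation helper that matches the task/benchmark selection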
    if benchmark == "All":
        if task == "Spec-to-RTL":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg S2R")
        elif task == "Code Completion":
            return filter_bench_all(subset, state.get_current_agg(), agg_column="Agg MC")
        elif task == "Line Completion †":
            return filter_RTLRepo(subset)
    elif benchmark == "RTL-Repo":
        return filter_RTLRepo(subset)
    else:
        # Map each remaining benchmark to its aggregate score column (None if unmapped)
        agg_columns = {
            "VerilogEval S2R": "Agg VerilogEval S2R",
            "VerilogEval MC": "Agg VerilogEval MC",
            "RTLLM": "Agg RTLLM",
            "VeriGen": "Agg VeriGen",
        }
        return filter_bench(subset, state.get_current_agg(), agg_columns.get(benchmark))


def generate_scatter_plot(benchmark, metric, state):
    """Generate a scatter plot for the given benchmark and metric."""
    benchmark, metric = handle_special_cases(benchmark, metric)

    subset = state.get_current_df()[state.get_current_df()["Benchmark"] == benchmark]
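    # RTL-Repo keeps a single averaged Exact Matching score; other benchmarks get one column per metric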
    if benchmark == "RTL-Repo":
        subset = subset[subset["Metric"].str.contains("EM", case=False, na=False)]
        detailed_scores = subset.groupby("Model", as_index=False)["Score"].mean()
        detailed_scores.rename(columns={"Score": "Exact Matching (EM)"}, inplace=True)
    else:
        detailed_scores = subset.pivot_table(index="Model", columns="Metric", values="Score").reset_index()

    details = state.get_current_df()[["Model", "Params", "Model Type"]].drop_duplicates("Model")
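    # Merge in params/type metadata and drop models missing either Params or the plotted metric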
    scatter_data = pd.merge(detailed_scores, details, on="Model", how="left").dropna(
        subset=["Params", metric]
    )

    scatter_data["x"] = scatter_data["Params"]
    scatter_data["y"] = scatter_data[metric]
    scatter_data["size"] = (scatter_data["x"] ** 0.3) * 40

    scatter_data["color"] = scatter_data["Model Type"].map(TYPE_COLORS).fillna("gray")

    y_range = Y_AXIS_LIMITS.get(metric, [0, 80])

    fig = px.scatter(
        scatter_data,
        x="x",
        y="y",
        log_x=True,
        size="size",
        color="Model Type",
        text="Model",
        hover_data={metric: ":.2f"},
        title=f"Params vs. {metric} for {benchmark}",
        labels={"x": "# Params (Log Scale)", "y": metric},
        template="plotly_white",
        height=600,
        width=1200,
    )

    fig.update_traces(
        textposition="top center",
        textfont_size=10,
        marker=dict(opacity=0.8, line=dict(width=0.5, color="black")),
    )
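    # Log-scale x-axis with tick positions/labels from SCATTER_PLOT_X_TICKS; fixed y-range per metric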
    fig.update_layout(
        xaxis=dict(
            showgrid=True,
            type="log",
            tickmode="array",
            tickvals=SCATTER_PLOT_X_TICKS["tickvals"],
            ticktext=SCATTER_PLOT_X_TICKS["ticktext"],
        ),
        showlegend=False,
        yaxis=dict(range=y_range),
        margin=dict(l=50, r=50, t=50, b=50),
        plot_bgcolor="white",
    )

    return fig