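# Entry point of this Space (assumed to be its app.py): a Gradio demo around the
# `evaluate` metric "jordyvl/ece" (Expected Calibration Error). It computes ECE
# for a table of predicted probability vectors and reference labels and renders
# a reliability diagram with per-bin confidence counts.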
import evaluate
import json
import sys
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import ast

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

plt.rcParams["figure.dpi"] = 300
plt.switch_backend(
    "agg"
)  # https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
def default_plot():
    """Set up the empty two-panel figure: reliability diagram on top, confidence histogram below."""
    fig = plt.figure()
    ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
    ax2 = plt.subplot2grid((3, 1), (2, 0))

    ranged = np.linspace(0, 1, 10)
    ax1.plot(
        ranged,
        ranged,
        color="darkgreen",
        ls="dotted",
        label="Perfect",
    )

    # Bin differences
    ax1.set_ylabel("Conditional Expectation")
    ax1.set_ylim([0, 1.05])
    ax1.set_title("Reliability Diagram")
    ax1.set_xlim([-0.05, 1.05])  # respective to bin range

    # Bin frequencies
    ax2.set_xlabel("Confidence")
    ax2.set_ylabel("Count")
    ax2.set_xlim([-0.05, 1.05])  # respective to bin range
    return fig, ax1, ax2
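# NOTE on the `results` dict consumed below: the metric implementation itself is
# not part of this file, so these descriptions are inferred from how the keys are
# used here. "y_bar" provides the bin edges, "p_bar" the per-bin conditional
# expectation (NaN for empty bins), "bin_freq" the counts of the non-empty bins,
# "accuracy" and "p_bar_cont" the overall accuracy and average confidence, and
# "ECE" the scalar score shown in the interface.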
def reliability_plot(results):
    """Draw the reliability diagram (top) and per-bin confidence counts (bottom) from detailed ECE results."""
    # DEV: might still need to write tests in case of equal mass binning
    # DEV: nicer would be to plot like a polygon
    # see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py
    fig, ax1, ax2 = default_plot()

    # Bin differences
    bins_with_left_edge = np.insert(results["y_bar"], 0, 0, axis=0)
    bins_with_right_edge = np.insert(results["y_bar"], -1, 1.0, axis=0)
    bins_with_leftright_edge = np.insert(bins_with_left_edge, -1, 1.0, axis=0)
    weights = np.nan_to_num(results["p_bar"], copy=True, nan=0)

    # NOTE: the histogram API is strange
    _, _, patches = ax1.hist(
        bins_with_left_edge,
        weights=weights,
        bins=bins_with_leftright_edge,
    )
    for b in range(len(patches)):
        perfect = bins_with_right_edge[b]  # if b != n_bins else
        empirical = weights[b]  # patches[b]._height
        bin_color = (
            "limegreen"
            if perfect == empirical
            else "dodgerblue"
            if empirical < perfect
            else "orangered"
        )
        patches[b].set_facecolor(bin_color)  # color based on over/underconfidence

    ax1handles = [
        mpatches.Patch(color="orangered", label="Overconfident"),
        mpatches.Patch(color="limegreen", label="Perfect", linestyle="dotted"),
        mpatches.Patch(color="dodgerblue", label="Underconfident"),
    ]

    # Bin frequencies
    anindices = np.where(~np.isnan(results["p_bar"]))[0]
    bin_freqs = np.zeros(len(results["p_bar"]))
    bin_freqs[anindices] = results["bin_freq"]
    ax2.hist(
        bins_with_left_edge,
        weights=bin_freqs,
        color="midnightblue",
        bins=bins_with_leftright_edge,
    )
    acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
    conf_plt = ax2.axvline(
        x=results["p_bar_cont"], ls="dotted", lw=3, c="#444", label="Avg. confidence"
    )

    ax1.legend(loc="lower right", handles=ax1handles)
    ax2.legend(handles=[acc_plt, conf_plt])
    ax1.set_xticks(bins_with_left_edge)
    ax2.set_xticks(bins_with_left_edge)
    plt.tight_layout()
    return fig
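# The Gradio callback below receives the dataframe as entered in the UI:
# "predictions" holds probability vectors (often as string literals, hence
# ast.literal_eval) and "references" holds integer class labels. n_bins, scheme,
# proxy and p are forwarded to the metric; bin_range comes from a hidden slider
# and is currently unused.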
def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
    """Gradio callback: parse the input table, compute ECE with the chosen settings, and plot it."""
    # DEV: check on invalid datatypes with better warnings
    if isinstance(data, pd.DataFrame):
        data.dropna(inplace=True)

    predictions = [
        ast.literal_eval(prediction) if not isinstance(prediction, list) else prediction
        for prediction in data["predictions"]
    ]
    references = [reference for reference in data["references"]]

    results = metric._compute(
        predictions,
        references,
        n_bins=n_bins,
        scheme=scheme,
        proxy=proxy,
        p=p,
        detail=True,
    )
    plot = reliability_plot(results)
    return results["ECE"], plot
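# For reference, a minimal sketch of calling the metric outside of Gradio,
# assuming the public compute() forwards these keyword arguments to the private
# _compute() used above (kept as a comment so it does not run on import):
#
#   ece = evaluate.load("jordyvl/ece")
#   result = ece.compute(
#       predictions=[[0.6, 0.2, 0.2], [0.7, 0.1, 0.2], [0, 0.95, 0.05]],
#       references=[0, 2, 1],
#       n_bins=10,
#       scheme="equal-range",
#       proxy="upper-edge",
#       p=1,
#       detail=True,
#   )
#   print(result["ECE"])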
sliders = [
    gr.Slider(0, 100, value=10, label="n_bins"),
    gr.Slider(
        0, 100, value=None, label="bin_range", visible=False
    ),  # DEV: need to have a double slider
    gr.Dropdown(choices=["equal-range", "equal-mass"], value="equal-range", label="scheme"),
    gr.Dropdown(choices=["upper-edge", "center"], value="upper-edge", label="proxy"),
    gr.Dropdown(choices=[1, 2, np.inf], value=1, label="p"),
]
slider_defaults = [slider.value for slider in sliders]
# example data
component = gr.inputs.Dataframe(
    headers=["predictions", "references"], col_count=2, datatype="number", type="pandas"
)
component.value = [
    [[0.6, 0.2, 0.2], 0],
    [[0.7, 0.1, 0.2], 2],
    [[0, 0.95, 0.05], 1],
]
sample_data = [[component] + slider_defaults]
local_path = Path(sys.path[0])
metric = evaluate.load("jordyvl/ece")

outputs = [gr.outputs.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]
# outputs[1].value = default_plot().__dict__  # DEV: Does not work in gradio; needs to be JSON encoded

iface = gr.Interface(
    fn=compute_and_plot,
    inputs=[component] + sliders,
    outputs=outputs,
    description=metric.info.description,
    article=evaluate.utils.parse_readme(local_path / "README.md"),
    title=f"Metric: {metric.name}",
    # examples=sample_data  # DEV: ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.
).launch()
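# When deployed on Spaces this module is assumed to serve as the Space's app.py,
# so the Interface is launched at import time; locally, `python app.py` starts
# the same demo.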