import datetime
import difflib
import json
import re
import tempfile

import gradio as gr
import polars as pl
from gradio_modal import Modal
from huggingface_hub import CommitOperationAdd, HfApi

from table import PATCH_REPO_ID, df_orig
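
# This page lets users propose fixes to a paper's metadata: clicking the 📝 cell of a row
# opens a modal pre-filled with that row's data, the change can be previewed as a unified
# diff, and submitting it opens a pull request against the patch dataset repo.
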
# TODO: remove this once https://github.com/gradio-app/gradio/issues/11022 is fixed  # noqa: FIX002, TD002
NOTE = """\
#### ⚠️ Note
You may encounter an issue when selecting table data after using the search bar.
This is due to a known bug in Gradio.
The issue typically occurs when multiple rows remain after filtering.
If only one row remains, the selection should work as expected.
"""

api = HfApi()
PR_VIEW_COLUMNS = [
    "title",
    "authors_str",
    "openreview_md",
    "arxiv_id",
    "github_md",
    "Spaces",
    "Models",
    "Datasets",
    "paper_id",
]
PR_RAW_COLUMNS = [
    "paper_id",
    "title",
    "authors",
    "arxiv_id",
    "project_page",
    "github",
    "space_ids",
    "model_ids",
    "dataset_ids",
]

df_pr_view = df_orig.with_columns(pl.lit("📝").alias("Fix")).select(["Fix", *PR_VIEW_COLUMNS])
df_pr_view = df_pr_view.with_columns(pl.col("arxiv_id").fill_null(""))
df_pr_raw = df_orig.select(PR_RAW_COLUMNS)
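

# When a row's 📝 cell is clicked, open the edit modal pre-filled with that row's raw data.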
def df_pr_row_selected(
    evt: gr.SelectData,
) -> tuple[
    Modal,
    gr.Textbox,  # title
    gr.Textbox,  # authors
    gr.Textbox,  # arxiv_id
    gr.Textbox,  # project_page
    gr.Textbox,  # github
    gr.Textbox,  # space_ids
    gr.Textbox,  # model_ids
    gr.Textbox,  # dataset_ids
    dict | None,  # original_data
]:
    if evt.value != "📝":
        return (
            Modal(),
            gr.Textbox(),  # title
            gr.Textbox(),  # authors
            gr.Textbox(),  # arxiv_id
            gr.Textbox(),  # project_page
            gr.Textbox(),  # github
            gr.Textbox(),  # space_ids
            gr.Textbox(),  # model_ids
            gr.Textbox(),  # dataset_ids
            None,  # original_data
        )

    paper_id = evt.row_value[-1]
    row = df_pr_raw.filter(pl.col("paper_id") == paper_id)
    original_data = row.to_dicts()[0]
    authors = original_data["authors"]
    space_ids = original_data["space_ids"]
    model_ids = original_data["model_ids"]
    dataset_ids = original_data["dataset_ids"]
    return (
        Modal(visible=True),
        gr.Textbox(value=row["title"].item()),  # title
        gr.Textbox(value="\n".join(authors)),  # authors
        gr.Textbox(value=row["arxiv_id"].item()),  # arxiv_id
        gr.Textbox(value=row["project_page"].item()),  # project_page
        gr.Textbox(value=row["github"].item()),  # github
        gr.Textbox(value="\n".join(space_ids)),  # space_ids
        gr.Textbox(value="\n".join(model_ids)),  # model_ids
        gr.Textbox(value="\n".join(dataset_ids)),  # dataset_ids
        original_data,  # original_data
    )
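

# Validation patterns for the user-editable fields.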
URL_PATTERN = re.compile(r"^(https?://)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}(:\d+)?(/.*)?$")
GITHUB_PATTERN = re.compile(r"^https://github\.com/[^/\s]+/[^/\s]+(/tree/[^/\s]+/[^/\s].*)?$")
REPO_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+$")
ARXIV_ID_PATTERN = re.compile(r"^\d{4}\.\d{4,5}$")


def is_valid_url(url: str) -> bool:
    return URL_PATTERN.match(url) is not None


def is_valid_github_url(url: str) -> bool:
    return GITHUB_PATTERN.match(url) is not None


def is_valid_repo_id(repo_id: str) -> bool:
    return REPO_ID_PATTERN.match(repo_id) is not None


def is_valid_arxiv_id(arxiv_id: str) -> bool:
    return ARXIV_ID_PATTERN.match(arxiv_id) is not None
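

# Raise a user-facing error for the first invalid field; all checks pass silently otherwise.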
def validate_pr_data(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids: list[str],
    model_ids: list[str],
    dataset_ids: list[str],
) -> None:
    if not title_pr:
        raise gr.Error("Title cannot be empty", print_exception=False)
    if not authors_pr:
        raise gr.Error("Authors cannot be empty", print_exception=False)
    if arxiv_id_pr and not is_valid_arxiv_id(arxiv_id_pr):
        raise gr.Error(
            "Invalid arXiv ID format. Expected format: 'YYMM.NNNNN' (e.g., '2301.01234')", print_exception=False
        )
    if project_page_pr and not is_valid_url(project_page_pr):
        raise gr.Error("Project page must be a valid URL", print_exception=False)
    if github_pr and not is_valid_github_url(github_pr):
        raise gr.Error("GitHub must be a valid GitHub URL", print_exception=False)
    for repo_id in space_ids + model_ids + dataset_ids:
        if not is_valid_repo_id(repo_id):
            error_msg = f"Space/Model/Dataset ID must be in the format 'org_name/repo_name'. Got: {repo_id}"
            raise gr.Error(error_msg, print_exception=False)
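

# Split the newline-separated textbox values, validate them, and return a dict
# shaped like a row of the raw table (without paper_id).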
def format_submitted_data(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
) -> dict:
    space_ids = [repo_id for repo_id in space_ids_pr.split("\n") if repo_id.strip()]
    model_ids = [repo_id for repo_id in model_ids_pr.split("\n") if repo_id.strip()]
    dataset_ids = [repo_id for repo_id in dataset_ids_pr.split("\n") if repo_id.strip()]
    validate_pr_data(title_pr, authors_pr, arxiv_id_pr, project_page_pr, github_pr, space_ids, model_ids, dataset_ids)
    return {
        "title": title_pr,
        "authors": [a for a in authors_pr.split("\n") if a.strip()],
        "arxiv_id": arxiv_id_pr if arxiv_id_pr else None,
        "project_page": project_page_pr if project_page_pr else None,
        "github": github_pr if github_pr else None,
        "space_ids": space_ids,
        "model_ids": model_ids,
        "dataset_ids": dataset_ids,
    }
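

# Build a unified diff between the original row and the submitted data, and reveal the "Open PR" button.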
def preview_diff(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
    original_data: dict,
) -> tuple[gr.Markdown, gr.Button]:
    submitted_data = format_submitted_data(
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        space_ids_pr,
        model_ids_pr,
        dataset_ids_pr,
    )
    submitted_data = {"paper_id": original_data["paper_id"], **submitted_data}
    original_json = json.dumps(original_data, indent=2)
    submitted_json = json.dumps(submitted_data, indent=2)
    diff = difflib.unified_diff(
        original_json.splitlines(),
        submitted_json.splitlines(),
        fromfile="before",
        tofile="after",
        lineterm="",
    )
    diff_str = "\n".join(diff)
    return gr.Markdown(value=f"```diff\n{diff_str}\n```"), gr.Button(visible=True)
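

# Serialize only the changed fields (plus the diff and a UTC timestamp) to a JSON file and open a PR
# against the patch dataset repo. The resulting patch file looks roughly like this (hypothetical values):
#   {"github": "https://github.com/org/repo", "paper_id": "...", "diff": "--- before\n+++ after\n...",
#    "timestamp": "2025-01-01T00:00:00+00:00"}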
def open_pr(
    title_pr: str,
    authors_pr: str,
    arxiv_id_pr: str,
    project_page_pr: str,
    github_pr: str,
    space_ids_pr: str,
    model_ids_pr: str,
    dataset_ids_pr: str,
    original_data: dict,
    oauth_token: gr.OAuthToken | None,
) -> gr.Markdown:
    submitted_data = format_submitted_data(
        title_pr,
        authors_pr,
        arxiv_id_pr,
        project_page_pr,
        github_pr,
        space_ids_pr,
        model_ids_pr,
        dataset_ids_pr,
    )
    diff_dict = {key: submitted_data[key] for key in submitted_data if submitted_data[key] != original_data[key]}
    if not diff_dict:
        gr.Info("No data to submit")
        return gr.Markdown(value="")

    paper_id = original_data["paper_id"]
    diff_dict["paper_id"] = paper_id

    original_json = json.dumps(original_data, indent=2)
    submitted_json = json.dumps(submitted_data, indent=2)
    diff = "\n".join(
        difflib.unified_diff(
            original_json.splitlines(),
            submitted_json.splitlines(),
            fromfile="before",
            tofile="after",
            lineterm="",
        )
    )
    diff_dict["diff"] = diff

    timestamp = datetime.datetime.now(datetime.timezone.utc)
    diff_dict["timestamp"] = timestamp.isoformat()

    with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f:
        json.dump(diff_dict, f, indent=2)
        f.flush()
        commit = CommitOperationAdd(f"data/{paper_id}--{timestamp.strftime('%Y-%m-%d-%H-%M-%S')}.json", f.name)
        res = api.create_commit(
            repo_id=PATCH_REPO_ID,
            operations=[commit],
            commit_message=f"Update {paper_id}",
            repo_type="dataset",
            create_pr=True,
            token=oauth_token.token if oauth_token else None,
        )
    return gr.Markdown(value=res.pr_url, visible=True)
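

# The editing UI is only shown to logged-in users (toggled on page load).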
def render_open_pr_page(profile: gr.OAuthProfile | None) -> gr.Column:
    return gr.Column(visible=profile is not None)
with gr.Blocks() as demo:
    gr.LoginButton()

    with gr.Column(visible=False) as open_pr_col:
        gr.Markdown(NOTE)
        df_pr = gr.Dataframe(
            value=df_pr_view,
            datatype=[
                "str",  # Fix
                "str",  # Title
                "str",  # Authors
                "markdown",  # openreview
                "str",  # arxiv_id
                "markdown",  # github
                "markdown",  # spaces
                "markdown",  # models
                "markdown",  # datasets
                "str",  # paper id
            ],
            column_widths=[
                "50px",  # Fix
                "40%",  # Title
                "20%",  # Authors
                None,  # openreview
                "100px",  # arxiv_id
                None,  # github
                None,  # spaces
                None,  # models
                None,  # datasets
                None,  # paper id
            ],
            type="polars",
            row_count=(0, "dynamic"),
            interactive=False,
            max_height=1000,
            show_search="search",
        )

    with Modal(visible=False) as pr_modal:
        with gr.Group():
            title_pr = gr.Textbox(label="Title")
            authors_pr = gr.Textbox(label="Authors")
            arxiv_id_pr = gr.Textbox(label="arXiv ID")
            project_page_pr = gr.Textbox(label="Project page")
            github_pr = gr.Textbox(label="GitHub")
            spaces_pr = gr.Textbox(
                label="Spaces",
                info="Enter one space ID (e.g., 'org_name/space_name') per line.",
            )
            models_pr = gr.Textbox(
                label="Models",
                info="Enter one model ID (e.g., 'org_name/model_name') per line.",
            )
            datasets_pr = gr.Textbox(
                label="Datasets",
                info="Enter one dataset ID (e.g., 'org_name/dataset_name') per line.",
            )
        original_data = gr.State()
        preview_diff_button = gr.Button("Preview diff")
        diff_view = gr.Markdown()
        open_pr_button = gr.Button("Open PR", visible=False)
        pr_url = gr.Markdown(visible=False)

    pr_modal.blur(
        fn=lambda: (None, gr.Button(visible=False), gr.Markdown(visible=False)),
        outputs=[diff_view, open_pr_button, pr_url],
    )
    df_pr.select(
        fn=df_pr_row_selected,
        outputs=[
            pr_modal,
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
    )
    preview_diff_button.click(
        fn=preview_diff,
        inputs=[
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
        outputs=[diff_view, open_pr_button],
    )
    open_pr_button.click(
        fn=open_pr,
        inputs=[
            title_pr,
            authors_pr,
            arxiv_id_pr,
            project_page_pr,
            github_pr,
            spaces_pr,
            models_pr,
            datasets_pr,
            original_data,
        ],
        outputs=pr_url,
    )

    demo.load(fn=render_open_pr_page, outputs=open_pr_col)

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False)