Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import requests | |
| import markdown | |
# Inline stylesheet prepended to every rendered chat fragment.
_CHAT_STYLE = """
<style>
    .user-message, .system-message {
        display: flex;
        margin: 10px;
    }
    .user-message .message-content {
        background-color: #c2e3f7;
        color: #000000;
    }
    .system-message .message-content {
        background-color: #f5f5f5;
        color: #000000;
    }
    .message-content {
        padding: 10px;
        border-radius: 10px;
        max-width: 70%;
        word-wrap: break-word;
    }
    .container {
        display: flex;
        justify-content: space-between;
    }
    .column {
        width: 48%;
    }
</style>
"""


def _render_message(message, css_class, align, turn_number=None):
    """Render one chat message dict as an HTML bubble.

    Args:
        message: Mapping that may carry "role" and "content" entries.
        css_class: CSS class for the outer ``<div>``.
        align: Value used for the inline ``justify-content`` style.
        turn_number: When given, the heading is prefixed with "Turn N - ".

    Returns:
        The HTML string for this message.
    """
    from html import escape

    role = str(message.get("role") or "unknown")
    content = message.get("content")
    if content is None:
        # A null content would otherwise break both len() and markdown.
        content = ""
    elif not isinstance(content, str):
        content = str(content)

    # The role comes from dataset rows, so escape it before embedding.
    heading = escape(role.capitalize())
    if turn_number is not None:
        heading = f"Turn {turn_number} - {heading}"

    # NOTE(review): Python-Markdown does not sanitize raw HTML embedded in
    # its input by default; confirm whether dataset content needs sanitizing
    # before being shown in the browser.
    rendered = markdown.markdown(content)
    return (
        f'<div class="{css_class}" style="justify-content: {align};">'
        '<div class="message-content">'
        f"<strong>{heading}:</strong><br>"
        f"<em>Length: {len(content)} characters</em><br><br>"
        f"{rendered}"
        "</div></div>"
    )


def create_chat_html(messages, dataset_id, offset, compare_mode=False, column=""):
    """Render a list of chat messages as styled HTML plus a dataset link.

    Messages are consumed in pairs: the first of each pair is drawn
    right-aligned with a "Turn N" heading, the second (when present and
    non-empty) is drawn left-aligned as the reply.

    Args:
        messages: List of message dicts; ``None`` is treated as empty.
        dataset_id: Hub dataset identifier used to build the viewer link.
        offset: Row index used in the viewer link.
        compare_mode: When true, wrap the chat in a ``column`` div.
        column: Extra CSS class for the wrapper used in compare mode.

    Returns:
        A ``(dataset_link, html)`` tuple where ``dataset_link`` is a Markdown
        link to the dataset row and ``html`` is the stylesheet followed by
        the rendered chat.
    """
    messages = messages or []

    bubbles = []
    for turn_number, start in enumerate(range(0, len(messages), 2), start=1):
        bubbles.append(
            _render_message(messages[start], "user-message", "right", turn_number)
        )
        reply = messages[start + 1] if start + 1 < len(messages) else None
        if reply:
            bubbles.append(_render_message(reply, "system-message", "left"))

    chat_html = "".join(bubbles)
    if compare_mode:
        chat_html = f'<div class="column {column}">{chat_html}</div>'

    dataset_url = f"https://huggingface.co/datasets/{dataset_id}/viewer/default/train?row={offset}"
    dataset_link = f"[View dataset row]({dataset_url})"
    return dataset_link, _CHAT_STYLE + chat_html
def _looks_like_chat(value):
    """Return True when value is a non-empty list whose first item is a dict with a "role" key."""
    return (
        isinstance(value, list)
        and len(value) > 0
        and isinstance(value[0], dict)
        and "role" in value[0]
    )


def fetch_data(
    dataset_id, chosen_column, rejected_column, current_offset, direction, compare_mode
):
    """Fetch one dataset row next to the current offset and render its chat.

    Args:
        dataset_id: Hub dataset identifier to query.
        chosen_column: Column holding the chat to show (or the "chosen" chat
            in compare mode). When empty outside compare mode, the first
            column that looks like a chat is used.
        rejected_column: Column holding the "rejected" chat (compare mode).
        current_offset: Row index currently shown.
        direction: "Next" moves forward one row; any other value moves back.
        compare_mode: When true, render chosen and rejected side by side.

    Returns:
        A ``(dataset_link, html_or_message, new_offset)`` tuple. On failure
        the link is empty and the second item is a plain-text message.
    """
    step = 1 if direction == "Next" else -1
    new_offset = max(0, current_offset + step)

    try:
        # Let requests build and URL-encode the query string; the timeout
        # keeps the UI from hanging on a stalled connection.
        response = requests.get(
            "https://datasets-server.huggingface.co/rows",
            params={
                "dataset": dataset_id,
                "config": "default",
                "split": "train",
                "offset": new_offset,
                "length": 1,
            },
            timeout=30,
        )
    except requests.RequestException:
        return "", "Failed to fetch data", new_offset
    if response.status_code != 200:
        return "", "Failed to fetch data", new_offset

    try:
        data = response.json()
    except ValueError:
        return "", "Failed to fetch data", new_offset

    rows = data.get("rows") or []
    if not rows:
        # Guard against indexing an empty result list.
        return "", "No row found at this offset", new_offset
    row = rows[0]["row"]

    if compare_mode:
        if not (chosen_column and rejected_column):
            return (
                "",
                "Please provide both chosen and rejected columns for comparison",
                new_offset,
            )
        # `or []` covers columns that are present but null.
        dataset_link, chosen_html = create_chat_html(
            row.get(chosen_column) or [],
            dataset_id,
            new_offset,
            compare_mode=True,
            column="chosen",
        )
        _, rejected_html = create_chat_html(
            row.get(rejected_column) or [],
            dataset_id,
            new_offset,
            compare_mode=True,
            column="rejected",
        )
        chat_html = f'<div class="container">{chosen_html}{rejected_html}</div>'
        return dataset_link, chat_html, new_offset

    if chosen_column:
        messages = row.get(chosen_column) or []
    else:
        messages = next(
            (value for value in row.values() if _looks_like_chat(value)), None
        )
        if messages is None:
            return "", "No suitable chat column found", new_offset

    dataset_link, chat_html = create_chat_html(messages, dataset_id, new_offset)
    return dataset_link, chat_html, new_offset
def update_column_names(compare_mode):
    """Return the (chosen, rejected) column names to show for the given mode.

    Compare mode yields "chosen" and "rejected"; otherwise both are empty.
    """
    names = ("chosen", "rejected") if compare_mode else ("", "")
    return names
def _navigate(direction):
    """Build a click handler that calls fetch_data with a fixed direction."""

    def handler(data, chosen, rejected, offset, compare):
        return fetch_data(data, chosen, rejected, offset, direction, compare)

    return handler


# NOTE(review): the original indentation was lost, so which components sit
# inside each gr.Row() is reconstructed; confirm the layout.
with gr.Blocks() as demo:
    with gr.Row():
        dataset_id = gr.Textbox(
            label="Dataset ID", placeholder="e.g., davanstrien/cosmochat"
        )
        chosen_column = gr.Textbox(
            label="Chosen Column",
            placeholder="Column containing chosen chat data",
        )
        rejected_column = gr.Textbox(
            label="Rejected Column",
            placeholder="Column containing rejected chat data",
        )
        compare_mode = gr.Checkbox(label="Compare chosen and rejected chats")
    # Row index currently shown; updated by every navigation click.
    current_offset = gr.State(value=0)
    with gr.Row():
        back_button = gr.Button("Back")
        next_button = gr.Button("Next")
    dataset_link = gr.Markdown()
    output_html = gr.HTML()

    # Toggling compare mode fills or clears the two column boxes.
    compare_mode.change(
        fn=update_column_names,
        inputs=compare_mode,
        outputs=[chosen_column, rejected_column],
    )

    # Both buttons share the same inputs and outputs; only direction differs.
    nav_inputs = [
        dataset_id,
        chosen_column,
        rejected_column,
        current_offset,
        compare_mode,
    ]
    nav_outputs = [dataset_link, output_html, current_offset]
    back_button.click(_navigate("Back"), inputs=nav_inputs, outputs=nav_outputs)
    next_button.click(_navigate("Next"), inputs=nav_inputs, outputs=nav_outputs)

demo.launch(debug=True, share=True)