import gradio as gr
from PyPDF2 import PdfReader
from bs4 import BeautifulSoup
import requests
from io import BytesIO
from transformers import AutoTokenizer
import os
from openai import OpenAI

# Cache for tokenizers to avoid reloading
tokenizer_cache = {}

# Function to fetch paper title and author information from OpenReview
def fetch_paper_info_neurips(paper_id):
    url = f"https://openreview.net/forum?id={paper_id}"
    response = requests.get(url)
    if response.status_code != 200:
        return None

    soup = BeautifulSoup(response.content, 'html.parser')

    # Extract title
    title_tag = soup.find('h2', class_='citation_title')
    title = title_tag.get_text(strip=True) if title_tag else 'Title not found'

    # Extract authors
    authors = []
    author_div = soup.find('div', class_='forum-authors')
    if author_div:
        author_tags = author_div.find_all('a')
        authors = [tag.get_text(strip=True) for tag in author_tags]
    author_list = ', '.join(authors) if authors else 'Authors not found'

    # Extract abstract
    abstract_div = soup.find('strong', string='Abstract:')
    if abstract_div:
        abstract_paragraph = abstract_div.find_next_sibling('div')
        abstract = abstract_paragraph.get_text(strip=True) if abstract_paragraph else 'Abstract not found'
    else:
        abstract = 'Abstract not found'

    # Construct the preamble in Markdown (the abstract is extracted but currently left out of the preamble)
    # preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n**Abstract:**\n{abstract}"
    preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n"
    return preamble

def fetch_paper_content(paper_id):
    try:
        # Construct the PDF URL
        url = f"https://openreview.net/pdf?id={paper_id}"

        # Fetch the PDF
        response = requests.get(url)
        response.raise_for_status()  # Raise an exception for HTTP errors

        # Read the PDF content
        pdf_content = BytesIO(response.content)
        reader = PdfReader(pdf_content)

        # Extract text from the PDF
        text = ""
        for page in reader.pages:
            text += page.extract_text()

        return text  # Return the full text; truncation is handled later
    except Exception as e:
        print(f"An error occurred: {e}")
        return None

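
# Example (a hypothetical sketch, not part of the app wiring): the two helpers above
# can be exercised on their own with a forum ID taken from an
# https://openreview.net/forum?id=... URL, e.g.
#   preamble = fetch_paper_info_neurips("<forum-id>")  # Markdown title/author block, or None
#   full_text = fetch_paper_content("<forum-id>")      # raw PDF text, or None on failure
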
def paper_chat_tab(paper_id):
    """Build the paper-chat UI. `paper_id` is expected to be a Gradio component
    (e.g. a Textbox) created by the caller; its .change event drives the updates below."""
    with gr.Blocks() as demo:
        with gr.Column():
            # Markdown area that displays the paper title and authors
            content = gr.Markdown(value="")

            # Hint for the user
            gr.Markdown("**Note:** Providing your own SambaNova token can help you avoid rate limits.")

            # Input for the SambaNova API token
            hf_token_input = gr.Textbox(
                label="Enter your SambaNova token (optional)",
                type="password",
                placeholder="Enter your SambaNova token to avoid rate limits"
            )

            models = [
                "Meta-Llama-3.1-8B-Instruct",
                "Meta-Llama-3.1-70B-Instruct",
                "Meta-Llama-3.1-405B-Instruct",
            ]
            default_model = models[-1]

            # Dropdown for selecting the model
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=models,
                value=default_model
            )

            # State to store the paper content
            paper_content = gr.State()

            # Create a column for each model, visible only if it is the default model
            columns = []
            for model_name in models:
                column = gr.Column(visible=(model_name == default_model))
                with column:
                    create_chat_interface(model_name, paper_content, hf_token_input)
                columns.append(column)

            gr.HTML(
                '<img src="https://venturebeat.com/wp-content/uploads/2020/02/SambaNovaLogo_H_F.jpg" width="100px" />')
            gr.Markdown("**Note:** This model is supported by SambaNova.")

            # Update the visibility of the columns based on the selected model
            def update_columns(selected_model):
                visibility = []
                for model_name in models:
                    is_visible = model_name == selected_model
                    visibility.append(gr.update(visible=is_visible))
                return visibility

            model_dropdown.change(
                fn=update_columns,
                inputs=model_dropdown,
                outputs=columns,
                api_name=False,
                queue=False,
            )

            # Update the content Markdown and paper_content state when the paper ID or model changes
            def update_paper_info(paper_id, selected_model):
                preamble = fetch_paper_info_neurips(paper_id)
                text = fetch_paper_content(paper_id)
                if text is None:
                    return preamble, None
                return preamble, text

            paper_id.change(
                fn=update_paper_info,
                inputs=[paper_id, model_dropdown],
                outputs=[content, paper_content]
            )

            model_dropdown.change(
                fn=update_paper_info,
                inputs=[paper_id, model_dropdown],
                outputs=[content, paper_content],
                queue=False,
            )
    return demo

def create_chat_interface(model_name, paper_content, hf_token_input):
    # Load the tokenizer and cache it
    if model_name not in tokenizer_cache:
        # A small Llama tokenizer is loaded as a stand-in for token counting,
        # rather than the tokenizer of the selected model
        # tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
        tokenizer_cache[model_name] = tokenizer
    else:
        tokenizer = tokenizer_cache[model_name]

    max_total_tokens = 50000  # Maximum number of tokens allowed for context plus chat history

    # Define the function that handles the chat
    def get_fn(message, history, paper_content_value, hf_token_value):
        # Include the paper content as context
        if paper_content_value:
            context = f"The following is the content of the paper:\n{paper_content_value}\n\n"
        else:
            context = ""

        # Tokenize the context
        context_tokens = tokenizer.encode(context)
        context_token_length = len(context_tokens)

        # Prepare the messages without the context
        messages = []
        message_tokens_list = []
        total_tokens = context_token_length  # Start with the context tokens

        for user_msg, assistant_msg in history:
            # Tokenize the user message
            user_tokens = tokenizer.encode(user_msg)
            messages.append({"role": "user", "content": user_msg})
            message_tokens_list.append(len(user_tokens))
            total_tokens += len(user_tokens)

            # Tokenize the assistant message
            if assistant_msg:
                assistant_tokens = tokenizer.encode(assistant_msg)
                messages.append({"role": "assistant", "content": assistant_msg})
                message_tokens_list.append(len(assistant_tokens))
                total_tokens += len(assistant_tokens)

        # Tokenize the new user message
        message_tokens = tokenizer.encode(message)
        messages.append({"role": "user", "content": message})
        message_tokens_list.append(len(message_tokens))
        total_tokens += len(message_tokens)

        # Check whether the total tokens exceed the maximum allowed
        if total_tokens > max_total_tokens:
            # Attempt to truncate the context first
            available_tokens = max_total_tokens - (total_tokens - context_token_length)
            if available_tokens > 0:
                # Truncate the context to fit the available tokens
                truncated_context_tokens = context_tokens[:available_tokens]
                context = tokenizer.decode(truncated_context_tokens)
                context_token_length = available_tokens
                total_tokens = total_tokens - len(context_tokens) + context_token_length
            else:
                # Not enough space for the context; drop it entirely
                context = ""
                total_tokens -= context_token_length
                context_token_length = 0

        # If the total still exceeds the limit, truncate the message history
        while total_tokens > max_total_tokens and len(messages) > 1:
            # Remove the oldest message
            messages.pop(0)
            removed_tokens = message_tokens_list.pop(0)
            total_tokens -= removed_tokens

        # Rebuild the final messages list, including the (possibly truncated) context
        final_messages = []
        if context:
            final_messages.append({"role": "system", "content": context})
        final_messages.extend(messages)

        # Use the user-provided SambaNova token if given, otherwise fall back to the environment
        api_key = hf_token_value or os.environ.get("SAMBANOVA_API_KEY")
        if not api_key:
            raise ValueError("API token is not provided.")

        # Initialize the OpenAI-compatible client pointed at the SambaNova endpoint
        client = OpenAI(
            base_url="https://api.sambanova.ai/v1/",
            api_key=api_key,
        )

        try:
            # Create the streaming chat completion
            completion = client.chat.completions.create(
                model=model_name,
                messages=final_messages,
                stream=True,
            )
            response_text = ""
            for chunk in completion:
                delta = chunk.choices[0].delta.content or ""
                response_text += delta
                yield response_text
        except Exception as e:
            error_message = f"Error: {str(e)}"
            yield error_message

    # Create the ChatInterface
    chat_interface = gr.ChatInterface(
        fn=get_fn,
        chatbot=gr.Chatbot(
            label="Chatbot",
            scale=1,
            height=400,
            autoscroll=True
        ),
        additional_inputs=[paper_content, hf_token_input],
        # examples=["What are the main findings of this paper?", "Explain the methodology used in this research."]
    )
    return chat_interface
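
# A minimal launch sketch (an assumption about how this module is mounted; the hosting
# app presumably wires this up elsewhere). `paper_chat_tab` expects a Gradio component
# for `paper_id`, so a parent Blocks supplies a textbox whose value changes drive the
# chat tab. The textbox label and placeholder are illustrative, not from the original file.
if __name__ == "__main__":
    with gr.Blocks() as app:
        paper_id_box = gr.Textbox(
            label="OpenReview forum ID",
            placeholder="Paste the id=... value from the paper's OpenReview URL",
        )
        paper_chat_tab(paper_id_box)
    app.launch()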