import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import pandas as pd
import altair as alt
from collections import OrderedDict
from nltk.tokenize import sent_tokenize
import trafilatura
import validators
# Download the sentence tokenizer data for nltk
# (newer NLTK releases look for 'punkt_tab'; fetching both keeps the app working across versions)
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')

# Load the fine-tuned model and its tokenizer
model_name = 'dejanseo/sentiment'
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sentiment labels as textual descriptions (index = model class id)
sentiment_labels = {
    0: "very positive",
    1: "positive",
    2: "somewhat positive",
    3: "neutral",
    4: "somewhat negative",
    5: "negative",
    6: "very negative"
}

# Background colors for sentiments (green for positive, red for negative, grey for neutral)
background_colors = {
    "very positive": "rgba(0, 255, 0, 0.5)",
    "positive": "rgba(0, 255, 0, 0.3)",
    "somewhat positive": "rgba(0, 255, 0, 0.1)",
    "neutral": "rgba(128, 128, 128, 0.1)",
    "somewhat negative": "rgba(255, 0, 0, 0.1)",
    "negative": "rgba(255, 0, 0, 0.3)",
    "very negative": "rgba(255, 0, 0, 0.5)"
}

# Classify text and return the per-class probability distribution
def classify_text(text, max_length):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1).squeeze().tolist()
    return probabilities
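
# Illustrative only: the return value is one probability per class id in
# sentiment_labels (assuming the checkpoint exposes 7 labels), e.g.
#   classify_text("I loved this article.", 512)
#   -> [0.62, 0.21, 0.08, 0.05, 0.02, 0.01, 0.01]  # hypothetical values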

# Fetch and extract the main text from a URL, restricted to Medium stories/articles
def get_text_from_url(url):
    if not validators.url(url):
        return None, "Invalid URL"
    if "medium.com/" not in url:  # Check that it is a Medium URL
        return None, "URL is not a Medium story/article."
    try:
        downloaded = trafilatura.fetch_url(url)
        if downloaded:
            return trafilatura.extract(downloaded), None
        else:
            return None, "Could not download content from URL."
    except Exception as e:
        return None, f"Error extracting text: {e}"

# Handle long texts by chunking, classifying each chunk, and averaging the scores
def classify_long_text(text):
    max_length = tokenizer.model_max_length
    # Slice by character count, using the tokenizer's max length as a rough proxy
    # for tokens; anything that still overflows is truncated inside classify_text
    chunks = [text[i:i + max_length] for i in range(0, len(text), max_length)]
    aggregate_scores = [0] * len(sentiment_labels)
    chunk_scores_list = []
    for chunk in chunks:
        chunk_scores = classify_text(chunk, max_length)
        chunk_scores_list.append(chunk_scores)
        aggregate_scores = [x + y for x, y in zip(aggregate_scores, chunk_scores)]
    # Average the per-chunk scores into a document-level distribution
    aggregate_scores = [x / len(chunks) for x in aggregate_scores]
    return aggregate_scores, chunk_scores_list, chunks
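
# Design note: the average weights every chunk equally, so a short final chunk
# counts as much as a full-length one; acceptable for a demo, but worth keeping
# in mind when reading the aggregate chart.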

# Classify each sentence and pair it with its most likely sentiment label
def classify_sentences(text):
    sentences = sent_tokenize(text)
    sentence_scores = []
    for sentence in sentences:
        scores = classify_text(sentence, tokenizer.model_max_length)
        sentiment_idx = scores.index(max(scores))
        sentiment = sentiment_labels[sentiment_idx]
        sentence_scores.append((sentence, sentiment))
    return sentence_scores
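
# Illustrative only (hypothetical labels):
#   classify_sentences("Great intro. Weak ending.")
#   -> [("Great intro.", "very positive"), ("Weak ending.", "negative")]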

# Streamlit UI
st.title("Sentiment Classification Model (Medium Only)")

url = st.text_input("Enter Medium URL:")

if url:
    text, error_message = get_text_from_url(url)
    if error_message:
        st.error(error_message)  # Display the error message
    elif text:
        scores, chunk_scores_list, chunks = classify_long_text(text)
        scores_dict = {sentiment_labels[i]: scores[i] for i in range(len(sentiment_labels))}

        # Ensure the exact order of labels in the chart
        sentiment_order = [
            "very positive", "positive", "somewhat positive",
            "neutral",
            "somewhat negative", "negative", "very negative"
        ]
        ordered_scores_dict = OrderedDict((label, scores_dict[label]) for label in sentiment_order)

        # Prepare the DataFrame and reindex to the fixed label order
        df = pd.DataFrame.from_dict(ordered_scores_dict, orient='index', columns=['Likelihood']).reindex(sentiment_order)

        # Use Altair to plot the document-level bar chart
        chart = alt.Chart(df.reset_index()).mark_bar().encode(
            x=alt.X('index', sort=sentiment_order, title='Sentiment'),
            y='Likelihood'
        ).properties(
            width=600,
            height=400
        )
        st.altair_chart(chart, use_container_width=True)

        # Display each chunk with its own chart
        for i, (chunk_scores, chunk) in enumerate(zip(chunk_scores_list, chunks)):
            chunk_scores_dict = {sentiment_labels[j]: chunk_scores[j] for j in range(len(sentiment_labels))}
            ordered_chunk_scores_dict = OrderedDict((label, chunk_scores_dict[label]) for label in sentiment_order)
            df_chunk = pd.DataFrame.from_dict(ordered_chunk_scores_dict, orient='index', columns=['Likelihood']).reindex(sentiment_order)
            chunk_chart = alt.Chart(df_chunk.reset_index()).mark_bar().encode(
                x=alt.X('index', sort=sentiment_order, title='Sentiment'),
                y='Likelihood'
            ).properties(
                width=600,
                height=400
            )
            st.write(f"Chunk {i + 1}:")
            st.altair_chart(chunk_chart, use_container_width=True)  # Render the per-chunk chart

        # Sentence-level classification with background colors
        st.write("Extracted Text with Sentiment Highlights:")
        sentence_scores = classify_sentences(text)
        for sentence, sentiment in sentence_scores:
            bg_color = background_colors[sentiment]
            st.markdown(f'<span style="background-color: {bg_color}">{sentence}</span>', unsafe_allow_html=True)