import math
import pickle
import re

import bs4
import torch
from GoogleNews import GoogleNews
from newspaper import Article, ArticleException
from pyvis.network import Network
from tqdm import tqdm

from kb import KB
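
# Note: the KB class imported from kb.py (not shown in this file) is assumed
# to expose .entities, .relations, .add_relation(relation, title, date) and
# .merge_with_kb(other) -- the interface the functions below rely on.
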
def extract_relations_from_model_output(text):
    """Parse a REBEL-style linearized output string into {head, type, tail} triplets."""
    relations = []
    subject, relation, object_ = '', '', ''
    text = text.strip()
    current = 'x'
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<obj>":
            current = 'o'
            relation = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
            object_ = ''
        elif token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif current == 'o':
            relation += f' {token}'
        elif current == 's':
            object_ += f' {token}'
        elif current == 't':
            subject += f' {token}'
    # flush the last triplet if it is complete
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations
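
# Example (hypothetical REBEL-style output): the string
#   "<s><triplet> Napoleon <subj> France <obj> country of citizenship </s>"
# parses to
#   [{'head': 'Napoleon', 'type': 'country of citizenship', 'tail': 'France'}].
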
def from_text_to_kb(text, model, tokenizer, article_url, span_length=128,
                    article_title=None, article_publish_date=None, verbose=False):
    # tokenize whole text
    print('Start tokenizing')
    inputs = tokenizer([text], return_tensors="pt")
    print('End tokenizing')

    # compute span boundaries
    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))
    spans_boundaries = []
    start = 0
    for i in tqdm(range(num_spans)):
        spans_boundaries.append([start + span_length * i,
                                 start + span_length * (i + 1)])
        start -= overlap
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # transform input with spans
    tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                  for boundary in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                    for boundary in spans_boundaries]
    inputs = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }

    # generate relations
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(
        **inputs,
        **gen_kwargs,
    )

    # decode relations
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # create kb
    kb = KB()
    for i, sentence_pred in enumerate(decoded_preds):
        current_span_index = i // num_return_sequences
        relations = extract_relations_from_model_output(sentence_pred)
        for relation in relations:
            relation["meta"] = {
                article_url: {
                    "spans": [spans_boundaries[current_span_index]]
                }
            }
            kb.add_relation(relation, article_title, article_publish_date)
    return kb
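
# Worked example of the span arithmetic above: with num_tokens=300 and
# span_length=128, num_spans=3 and overlap=ceil((3*128 - 300) / 2)=42, so the
# boundaries are [0, 128], [86, 214], [172, 300]: adjacent spans share 42 tokens.
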
def get_article(url):
    article = Article(url)
    article.download()
    article.parse()
    return article

def from_url_to_kb(url, model, tokenizer):
    article = get_article(url)
    config = {
        "article_title": article.title,
        "article_publish_date": article.publish_date
    }
    return from_text_to_kb(article.text, model, tokenizer, article.url, **config)

def get_news_links(query, lang="en", region="US", pages=1):
    googlenews = GoogleNews(lang=lang, region=region)
    googlenews.search(query)
    all_urls = []
    for page in range(pages):
        googlenews.get_page(page)
        all_urls += googlenews.get_links()
    return list(set(all_urls))

def from_urls_to_kb(urls, model, tokenizer, verbose=False):
    kb = KB()
    if verbose:
        print(f"{len(urls)} links to visit")
    for url in urls:
        if verbose:
            print(f"Visiting {url}...")
        try:
            kb_url = from_url_to_kb(url, model, tokenizer)
            kb.merge_with_kb(kb_url)
        except ArticleException:
            if verbose:
                print(f"  Couldn't download article at url {url}")
    return kb

def save_network_html(kb, filename="network.html"):
    # create network
    net = Network(directed=True, width="700px", height="700px")

    # nodes
    color_entity = "#00FF00"
    for e in kb.entities:
        net.add_node(e, shape="circle", color=color_entity)

    # edges
    for r in kb.relations:
        net.add_edge(r["head"], r["tail"],
                     title=r["type"], label=r["type"])

    # save network
    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show_buttons(filter_=['physics'])
    net.show(filename)

def save_kb(kb, filename):
    with open(filename, "wb") as f:
        pickle.dump(kb, f)

class CustomUnpickler(pickle.Unpickler):
    # Resolve any pickled 'KB' class to the local kb.KB, regardless of the
    # module it was pickled from (e.g. __main__ in a notebook).
    def find_class(self, module, name):
        return KB if name == 'KB' else super().find_class(module, name)

def load_kb(filename):
    with open(filename, "rb") as f:
        return CustomUnpickler(f).load()

def process_transcript(src, dist):
    with open(src, 'r') as src_file:
        html = bs4.BeautifulSoup(src_file.read(), 'html.parser')
        transcript = html.findChildren('div', {'class': 'transcript-line'})
        with open(dist, 'w') as dist_file:
            transcript_texts = map(lambda x: x.find('span', {'class': 'transcript-text'}).text,
                                   transcript)
            # normalize whitespace: map every whitespace char to a space, then
            # collapse runs of spaces to a single space
            transcript_texts = map(lambda text: re.sub(r'\s(?=\s)', '', re.sub(r'\s', ' ', text)),
                                   transcript_texts)
            text = ' '.join(transcript_texts)
            dist_file.write(text)
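
# Usage sketch -- a minimal example, not part of the original file. Assumptions:
# the Babelscape/rebel-large checkpoint (a seq2seq model whose decoder emits the
# <triplet>/<subj>/<obj> markers that extract_relations_from_model_output
# parses), and a kb.KB class with the interface used above.
if __name__ == "__main__":
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
    model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

    # build a knowledge base from one page of news results and save it
    urls = get_news_links("artificial intelligence", pages=1)
    kb = from_urls_to_kb(urls, model, tokenizer, verbose=True)
    save_network_html(kb, "network.html")
    save_kb(kb, "kb.p")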