import gradio as gr
from transformers import BlenderbotForConditionalGeneration, BlenderbotTokenizerFast
# "facebook/blenderbot-400M-distill" is a smaller drop-in alternative checkpoint
mname = "facebook/blenderbot-3B"
tokenizer = BlenderbotTokenizerFast.from_pretrained(mname)
model = BlenderbotForConditionalGeneration.from_pretrained(mname)
# ----------- new chat -----------
print(mname + " model loaded")
def predict(message, history=None):
    # Gradio passes the conversation state back in; start fresh on the first call
    history = history or []
    history.append(message)

    # Join the last three turns with Blenderbot's sentence separators
    # so the model sees recent context in its expected format
    context = "</s> <s>".join(str(turn) for turn in history[-3:])

    inputs = tokenizer([context], return_tensors="pt", max_length=512, truncation=True)
    reply_ids = model.generate(**inputs, max_length=512, pad_token_id=tokenizer.eos_token_id)
    response = tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0]
    history.append(response)

    # Pair up (user, bot) turns for the chatbot component
    pairs = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    return pairs, history
demo = gr.Interface(fn=predict, inputs=["text", "state"], outputs=["chatbot", "state"])
demo.launch()