"""## **Setting Up the Development Environment**"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import gradio as gr

# Check whether a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
| """## **Building a Baseline Chatbot**""" | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # Load the pretrained DialoGPT model and tokenizer | |
| MODEL_NAME= "microsoft/DialoGPT-medium" | |
| model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
| # Baseline chatbot function | |
| chat_history_ids = None | |
| def chatbot_response(user_input, chat_history_ids=None): | |
| new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt") | |
| # Add conversational history | |
| # torch.cat() concatenates tensors along the last dimension (dim=-1). | |
| # If this is the FIRST message (chat_history_ids is None), we just use new_input_ids. | |
| bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1) if chat_history_ids is not None else new_input_ids | |
| # Generate a response | |
| chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id) | |
| # bot_input_ids.shape[-1] → length of the input tokens | |
| # chat_history_ids[:, bot_input_ids.shape[-1]:] → slice off the input, keep only newly generated tokens | |
| response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True) | |
| return response | |
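# Quick sanity check before wiring up the UI (illustrative; greedy decoding is
# deterministic, but DialoGPT's replies are often terse):
print(chatbot_response("Hello! How are you today?"))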
| """ | |
| ## **Launch Your First Chatbot Locally**""" | |
| css = """ | |
| /* Container */ | |
| .container { | |
| background-color: #fdf4f4; | |
| border-radius: 15px; | |
| box-shadow: 0 6px 20px rgba(0, 0, 0, 0.1); | |
| padding: 25px; | |
| font-family: 'Comic Sans MS', sans-serif; | |
| } | |
| /* Title */ | |
| h1 { | |
| text-align: center; | |
| font-size: 32px; | |
| color: #ff7f7f; | |
| font-weight: 600; | |
| margin-bottom: 25px; | |
| font-family: 'Pacifico', sans-serif; | |
| } | |
| /* Outer box */ | |
| .input_output_outerbox { | |
| background-color: #f8d3d3; /* Light pink */ | |
| padding: 10px; | |
| border-radius: 15px; | |
| margin-bottom: 15px; | |
| } | |
| /* Input and Text area */ | |
| input[type="text"], textarea { | |
| width: 100%; | |
| padding: 18px 22px; | |
| font-size: 18px; | |
| border-radius: 25px; | |
| border: 2px solid #ff6f61; | |
| background-color: #fff9e6; /* Cream color */ | |
| color: brown; | |
| font-weight: bold; | |
| outline: none; | |
| transition: border-color 0.3s ease; | |
| } | |
| /* Keep background and text color on focus */ | |
| input[type="text"]:focus, textarea:focus { | |
| border-color: #ff1493; | |
| background-color: #fff9e6 !important; | |
| color: brown; | |
| font-weight: bold; | |
| box-shadow: none; | |
| } | |
| /* Output */ | |
| .output_text { | |
| padding: 16px 22px; | |
| background-color: #2e082e; | |
| border-radius: 20px; | |
| font-size: 18px; | |
| color: brown; | |
| font-weight: bold; | |
| border: 1px solid #ff6f61; | |
| word-wrap: break-word; | |
| min-height: 60px; | |
| } | |
| /* Button */ | |
| button { | |
| background-color: #ff6f61; | |
| color: red; | |
| padding: 16px 28px; | |
| font-size: 20px; | |
| font-weight: bold; | |
| border-radius: 30px; | |
| border: none; | |
| cursor: pointer; | |
| width: 100%; | |
| transition: background-color 0.3s ease, transform 0.2s; | |
| } | |
| /* Button hover effect with animation */ | |
| button:hover { | |
| background-color: #ff1493; | |
| transform: scale(1.1); | |
| } | |
| /* Cute footer with smaller text */ | |
| footer { | |
| text-align: center; | |
| margin-top: 20px; | |
| font-size: 16px; | |
| color: #ff6f61; | |
| } | |
| """ | |
iface = gr.Interface(fn=chatbot_response,
                     theme="default",
                     inputs="text",
                     outputs="text",
                     title="Baseline Chatbot",
                     css=css)
iface.launch()
| """## **Fine-Tuning the Chatbot for Better Conversations (Most effective upgrade)**""" | |
| # Load the SAMSum dataset (robust alternative to DailyDialog) | |
| # Using the full namespace 'knkarthick/samsum' to ensure access | |
| dataset = load_dataset("knkarthick/samsum") | |
| # Rename 'dialogue' to 'dialog' to match the expected variable name | |
| dataset = dataset.rename_column("dialogue", "dialog") | |
| # Split the dataset into training and validation sets | |
| # SAMSum already has 'train' and 'validation' splits | |
| train_data = dataset["train"].shuffle(seed=42).select(range(len(dataset["train"]) // 20)) | |
| valid_data = dataset["validation"].shuffle(seed=42).select(range(len(dataset["validation"]) // 20)) | |
| tokenizer.pad_token = tokenizer.eos_token | |
def tokenize_function(examples):
    # Flatten any multi-turn list structure, joining turns with newlines;
    # SAMSum dialogues are usually single newline-separated strings already
    text_list = ["\n".join(dialog) if isinstance(dialog, list) else dialog for dialog in examples["dialog"]]
    # Tokenize each conversation
    model_inputs = tokenizer(text_list, padding="max_length", truncation=True, max_length=128)
    # Causal LM objective: the labels are the input ids themselves
    model_inputs["labels"] = model_inputs["input_ids"].copy()
    return model_inputs
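# Peek at one tokenized example to confirm the shapes (a quick check; the
# exact token ids depend on the tokenizer version):
sample = tokenize_function({"dialog": [train_data[0]["dialog"]]})
print(len(sample["input_ids"][0]))  # 128, from max_length padding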
# Tokenize both subsets
tokenized_train = train_data.map(tokenize_function, batched=True, remove_columns=["dialog"])
tokenized_valid = valid_data.map(tokenize_function, batched=True, remove_columns=["dialog"])
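# Optional refinement (a sketch): mask padding positions in the labels with
# -100 so the loss ignores them. Because pad_token == eos_token here, we key
# off the attention mask rather than the token id, which keeps the real EOS
# supervised while skipping pure padding.
def mask_pad_labels(example):
    example["labels"] = [
        tok if mask == 1 else -100
        for tok, mask in zip(example["labels"], example["attention_mask"])
    ]
    return example

tokenized_train = tokenized_train.map(mask_pad_labels)
tokenized_valid = tokenized_valid.map(mask_pad_labels)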
# Convert to PyTorch tensors for the Trainer
tokenized_train.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
tokenized_valid.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
training_args = TrainingArguments(
    output_dir="./fine_tuned_chatbot",
    learning_rate=5e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    save_steps=500,
    save_total_limit=2  # keep only the two most recent checkpoints
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid
)
import os
from transformers.integrations import WandbCallback

# Disable wandb via its environment variable. (Passing report_to="none" in
# TrainingArguments achieves the same thing without the callback surgery below.)
os.environ["WANDB_DISABLED"] = "true"

# Remove the WandbCallback registered during Trainer initialization; this is
# necessary because the Trainer was created before wandb was disabled
try:
    trainer.remove_callback(WandbCallback)
except ValueError:
    pass

# Train the model
trainer.train()
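# Persist the fine-tuned weights and tokenizer so the model can be reloaded
# later without retraining (the subdirectory name is illustrative):
model.save_pretrained("./fine_tuned_chatbot/final")
tokenizer.save_pretrained("./fine_tuned_chatbot/final")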
def chatbot_response(user_input):
    input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        input_ids,
        max_new_tokens=30,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=0.7,
        repetition_penalty=1.2
    )
    response = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    return response

# Gradio UI: rebuild the interface so it binds to the new chatbot_response
# (the earlier iface still holds a reference to the baseline function)
iface = gr.Interface(fn=chatbot_response,
                     theme="default",
                     inputs="text",
                     outputs="text",
                     title="Fine-Tuned Chatbot",
                     css=css)
iface.launch()
| """#### **TESTED QUERIES** | |
| Ex: How is it going? | |
| Ex: I am feeling a bit stressed today. Any advice? | |
| Ex: Can you explain why people dream? | |
| Ex: Purple elephants dance faster in the rain, right? | |
| ## **Further Upgrading Chatbot Responses** | |
| ### **Upgrade 1: RAG (Retrieval-Augmented Generation)** | |
| """ | |
# Small knowledge base
knowledge_base = {
    "huggingface": "Hugging Face is a company specializing in Natural Language Processing technologies.",
    "transformers": "Transformers are a type of deep learning model introduced in the paper 'Attention is All You Need'.",
    "gradio": "Gradio is a Python library that allows you to rapidly create user interfaces for machine learning models."
}

def retrieve_relevant_info(query):
    # Simple keyword matching; a production system would use BM25 or a dense
    # passage retriever instead (see the sketch below)
    for keyword, info in knowledge_base.items():
        if keyword.lower() in query.lower():
            return info
    return ""
def chatbot_response(user_input):
    retrieved_info = retrieve_relevant_info(user_input)
    augmented_prompt = (retrieved_info + "\n" if retrieved_info else "") + "User: " + user_input + "\nBot:"
    input_ids = tokenizer.encode(augmented_prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        input_ids,
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_p=0.85,
        temperature=0.7,
        top_k=50,
        repetition_penalty=1.1
    )
    response = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    return response.strip()
| """### **Upgrade 2: Improving Response Coherence and Context Awareness**""" | |
| conversation_history = [] | |
| def chatbot_response(user_input): | |
| global conversation_history | |
| conversation_history.append(f"User: {user_input}") | |
| if len(conversation_history) > 6: # Limit to last 6 turns | |
| conversation_history = conversation_history[-6:] | |
| prompt = "\n".join(conversation_history) + "\nBot:" | |
| input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) | |
| output_ids = model.generate( | |
| input_ids, | |
| max_new_tokens=50, | |
| pad_token_id=tokenizer.eos_token_id, | |
| do_sample=True, | |
| top_p=0.85, | |
| temperature=0.7, | |
| top_k=50, | |
| repetition_penalty=1.1 | |
| ) | |
| response = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True).strip() | |
| conversation_history.append(f"Bot: {response}") | |
| return response | |
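# Quick multi-turn check (illustrative; sampled output varies from run to run,
# and a small model may still lose track of details):
print(chatbot_response("Hi! My name is Alex."))
print(chatbot_response("Do you remember my name?"))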
| """### **Upgrade 3: Handle Uncertain Responses with Fallback Mechanism**""" | |
| conversation_history = [] | |
| def chatbot_response(user_input): | |
| global conversation_history | |
| conversation_history.append(f"User: {user_input}") | |
| if len(conversation_history) > 6: | |
| conversation_history = conversation_history[-6:] | |
| prompt = "\n".join(conversation_history) + "\nBot:" | |
| input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) | |
| output_ids = model.generate( | |
| input_ids, | |
| max_new_tokens=50, | |
| pad_token_id=tokenizer.eos_token_id, | |
| do_sample=True, | |
| top_p=0.9, | |
| temperature=0.8, | |
| top_k=50, | |
| repetition_penalty=1.2 | |
| ) | |
| response = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True).strip() | |
| # Fallback if response is too short or vague | |
| if not response or len(response.split()) <= 2: | |
| response = "I'm not sure I understood that. Could you please rephrase?" | |
| conversation_history.append(f"Bot: {response}") | |
| return response | |
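# Gibberish input tends to produce an empty or tiny reply, which should
# trigger the fallback (not guaranteed, since generation is sampled):
print(chatbot_response("asdkjh qwerty zxcv"))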