File size: 10,321 Bytes
47e116f
 
 
 
 
b029cce
47e116f
b029cce
47e116f
 
b029cce
47e116f
b029cce
47e116f
b029cce
47e116f
b029cce
47e116f
 
 
 
b029cce
47e116f
 
b029cce
47e116f
 
 
 
 
 
 
 
 
 
 
 
b029cce
 
47e116f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b029cce
47e116f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b029cce
 
47e116f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b029cce
47e116f
 
 
b029cce
47e116f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
## **Setting Up the Development Environment**
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import gradio as gr

import torch

# Pick the compute device: CUDA when PyTorch reports an available GPU,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

"""## **Building a Baseline Chatbot**"""

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub checkpoint shared by every chatbot variant in this script.
MODEL_NAME = "microsoft/DialoGPT-medium"

# Tokenizer and causal-LM weights both come from the same pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Baseline chatbot function.
# NOTE: this module-level name is never read: chatbot_response's parameter of
# the same name shadows it, so no history is carried between calls.
chat_history_ids = None

def chatbot_response(user_input, chat_history_ids=None):
    """Return one DialoGPT reply to ``user_input``.

    Args:
        user_input: The user's message text.
        chat_history_ids: Optional tensor of token ids from earlier turns that
            is prepended to the new message. A text-only Gradio interface
            passes just ``user_input``, so there every call starts without
            history; the extended history built here is not returned.

    Returns:
        The decoded reply with the prompt tokens and special tokens removed.
    """
    # Each turn is terminated with the EOS token. Move the prompt to the
    # model's device: tokenizer.encode() returns CPU tensors, while the model
    # may have been moved to the GPU (e.g. by trainer.train()).
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt").to(model.device)

    # torch.cat() concatenates along the last (sequence) dimension. On the
    # first message there is no history, so the new message alone is the prompt.
    if chat_history_ids is not None:
        bot_input_ids = torch.cat([chat_history_ids.to(model.device), new_input_ids], dim=-1)
    else:
        bot_input_ids = new_input_ids

    # pad_token_id equals eos_token_id, so padding cannot be told apart from
    # the turn-separating EOS tokens; this single sequence is unpadded, so an
    # all-ones mask is exact.
    chat_history_ids = model.generate(
        bot_input_ids,
        attention_mask=torch.ones_like(bot_input_ids),
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )

    # bot_input_ids.shape[-1] is the prompt length; slice it off and keep only
    # the newly generated tokens.
    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
    return response

"""
## **Launch Your First Chatbot Locally**"""

css = """
/* Container */
.container {
    background-color: #fdf4f4;
    border-radius: 15px;
    box-shadow: 0 6px 20px rgba(0, 0, 0, 0.1);
    padding: 25px;
    font-family: 'Comic Sans MS', sans-serif;
}

/* Title */
h1 {
    text-align: center;
    font-size: 32px;
    color: #ff7f7f;
    font-weight: 600;
    margin-bottom: 25px;
    font-family: 'Pacifico', sans-serif;
}

/* Outer box */
.input_output_outerbox {
    background-color: #f8d3d3; /* Light pink */
    padding: 10px;
    border-radius: 15px;
    margin-bottom: 15px;
}

/* Input and Text area */
input[type="text"], textarea {
    width: 100%;
    padding: 18px 22px;
    font-size: 18px;
    border-radius: 25px;
    border: 2px solid #ff6f61;
    background-color: #fff9e6; /* Cream color */
    color: brown;
    font-weight: bold;
    outline: none;
    transition: border-color 0.3s ease;
}

/* Keep background and text color on focus */
input[type="text"]:focus, textarea:focus {
    border-color: #ff1493;
    background-color: #fff9e6 !important;
    color: brown;
    font-weight: bold;
    box-shadow: none;
}

/* Output */
.output_text {
    padding: 16px 22px;
    background-color: #2e082e;
    border-radius: 20px;
    font-size: 18px;
    color: brown;
    font-weight: bold;
    border: 1px solid #ff6f61;
    word-wrap: break-word;
    min-height: 60px;
}

/* Button */
button {
    background-color: #ff6f61;
    color: red;
    padding: 16px 28px;
    font-size: 20px;
    font-weight: bold;
    border-radius: 30px;
    border: none;
    cursor: pointer;
    width: 100%;
    transition: background-color 0.3s ease, transform 0.2s;
}

/* Button hover effect with animation */
button:hover {
    background-color: #ff1493;
    transform: scale(1.1);
}

/* Cute footer with smaller text */
footer {
    text-align: center;
    margin-top: 20px;
    font-size: 16px;
    color: #ff6f61;
}

"""

# Serve the baseline responder as a plain text-in / text-out Gradio app.
iface = gr.Interface(
    fn=chatbot_response,
    inputs="text",
    outputs="text",
    title="Baseline Chatbot",
    theme="default",
    css=css,
)
iface.launch()

"""## **Fine-Tuning the Chatbot for Better Conversations (Most effective upgrade)**"""

# SAMSum chat dialogues (used in place of DailyDialog), fetched by the fully
# qualified Hub id "knkarthick/samsum". The "dialogue" column is renamed to
# "dialog", the column name the tokenization step reads.
dataset = load_dataset("knkarthick/samsum").rename_column("dialogue", "dialog")

# SAMSum already ships "train" and "validation" splits; keep a reproducible
# (seed=42) random 1/20 of each so fine-tuning stays short.
train_data, valid_data = (
    dataset[split_name].shuffle(seed=42).select(range(len(dataset[split_name]) // 20))
    for split_name in ("train", "validation")
)



# Reuse EOS as the padding token so padding="max_length" below can pad batches.
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    """Tokenize a batch of dialogues for causal-LM fine-tuning.

    Args:
        examples: A batched ``datasets`` mapping whose ``"dialog"`` entry holds,
            per example, either a single string or a list of utterance strings.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``), padded and
        truncated to 128 tokens, plus ``labels``: ``input_ids`` with every
        padding position replaced by -100 so the loss ignores it.
    """
    eos = tokenizer.eos_token
    # Dialogs given as a list of turns are joined with EOS between turns (the
    # separator the baseline and fine-tuned chat functions append) instead of
    # being glued together with no delimiter. Every dialog also ends with EOS
    # so the model sees where a conversation stops.
    text_list = [
        (eos.join(dialog) if isinstance(dialog, list) else dialog) + eos
        for dialog in examples["dialog"]
    ]
    model_inputs = tokenizer(text_list, padding="max_length", truncation=True, max_length=128)

    # Padding reuses the EOS id, so padding must be identified through
    # attention_mask rather than by token id. -100 is the label value the
    # cross-entropy loss ignores, so padded positions do not train the model.
    model_inputs["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(model_inputs["input_ids"], model_inputs["attention_mask"])
    ]

    return model_inputs

# Tokenize both subsets batch-wise, dropping the raw "dialog" text column.
tokenized_train, tokenized_valid = (
    subset.map(tokenize_function, batched=True, remove_columns=["dialog"])
    for subset in (train_data, valid_data)
)

# Expose only the model-facing columns, as PyTorch tensors.
tokenized_train.set_format(columns=["input_ids", "attention_mask", "labels"], type="torch")
tokenized_valid.set_format(columns=["input_ids", "attention_mask", "labels"], type="torch")

# Fine-tuning hyper-parameters. report_to="none" stops the Trainer from
# attaching any experiment-tracking callbacks (e.g. Weights & Biases), so no
# after-the-fact callback removal is needed.
training_args = TrainingArguments(
    output_dir="./fine_tuned_chatbot",
    learning_rate=5e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    save_steps=500,
    save_total_limit=2,  # keep only the two most recent checkpoints on disk
    report_to="none",
)

# NOTE(review): eval_dataset is supplied, but training_args sets no evaluation
# strategy, so the validation split may never be evaluated during training
# unless trainer.evaluate() is called explicitly — confirm this is intended.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid,
)

import os
from transformers.integrations import WandbCallback

# Switch off the Weights & Biases integration via its environment flag.
os.environ.update(WANDB_DISABLED="true")

# The Trainer was built before the flag was set, so it may already hold a
# WandbCallback; detach it. A ValueError from remove_callback is ignored.
try:
    trainer.remove_callback(WandbCallback)
except ValueError:
    pass

# Run fine-tuning with the arguments configured above.
trainer.train()

def chatbot_response(user_input):
    """Reply to ``user_input`` with the (fine-tuned) model, without history.

    Samples at most 30 new tokens (top-k 50, nucleus 0.9, temperature 0.7,
    repetition penalty 1.2) and returns only the generated continuation,
    decoded without special tokens.
    """
    prompt_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt").to(model.device)
    sampling = dict(
        do_sample=True,
        max_new_tokens=30,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated = model.generate(prompt_ids, **sampling)
    prompt_len = prompt_ids.shape[-1]
    # Drop the echoed prompt; decode only the first sequence's new tokens.
    return tokenizer.decode(generated[0, prompt_len:], skip_special_tokens=True)



# Gradio UI for the fine-tuned chatbot.
# Build a fresh Interface: the earlier `iface` holds the baseline
# chatbot_response function object (rebinding the name does not update it),
# so relaunching it would never serve the reply function defined above.
iface = gr.Interface(fn=chatbot_response,
                     theme="default",
                     inputs="text",
                     outputs="text",
                     title="Fine-Tuned Chatbot",
                     css=css)
iface.launch()

"""#### **TESTED QUERIES**

Ex: How is it going?

Ex: I am feeling a bit stressed today. Any advice?

Ex: Can you explain why people dream?

Ex: Purple elephants dance faster in the rain, right?

## **Further Upgrading Chatbot Responses**

### **Upgrade 1: RAG (Retrieval-Augmented Generation)**
"""

# Tiny in-memory knowledge base mapping a lowercase keyword to a fact snippet.
knowledge_base = {
    "huggingface": "Hugging Face is a company specializing in Natural Language Processing technologies.",
    "transformers": "Transformers are a type of deep learning model introduced in the paper 'Attention is All You Need'.",
    "gradio": "Gradio is a Python library that allows you to rapidly create user interfaces for machine learning models."
}

def retrieve_relevant_info(query):
    """Return the first knowledge-base snippet whose keyword occurs in ``query``.

    Keywords are tried in the dict's insertion order using a case-insensitive
    substring test; an empty string is returned when none matches. This is a
    deliberately naive stand-in for a real retriever such as BM25 or Dense
    Passage Retrieval.
    """
    return next(
        (info for keyword, info in knowledge_base.items() if keyword.lower() in query.lower()),
        "",
    )

def chatbot_response(user_input):
    """Answer ``user_input`` with retrieval-augmented prompting (no history).

    When the knowledge base has a matching snippet, it is placed on its own
    line ahead of a ``User: ...`` / ``Bot:`` prompt. Returns the model's
    sampled continuation (at most 50 new tokens) with surrounding whitespace
    stripped.
    """
    context = retrieve_relevant_info(user_input)
    prompt = "User: " + user_input + "\nBot:"
    if context:
        prompt = context + "\n" + prompt

    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    generated = model.generate(
        prompt_ids,
        do_sample=True,
        max_new_tokens=50,
        temperature=0.7,
        top_k=50,
        top_p=0.85,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Keep only the tokens generated after the prompt.
    reply_ids = generated[0, prompt_ids.shape[-1]:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()

"""### **Upgrade 2: Improving Response Coherence and Context Awareness**"""

# Rolling transcript of "User: ..." / "Bot: ..." lines, oldest first. It lives
# at module level, so every call of the function below shares it.
conversation_history = []

def chatbot_response(user_input):
    """Reply to ``user_input`` using the recent transcript as context.

    Appends the user's line, trims the transcript to its 6 most recent lines
    (user and bot lines each count as one), prompts the model with those lines
    followed by ``Bot:``, records the stripped reply, and returns it.
    """
    global conversation_history
    conversation_history.append(f"User: {user_input}")
    # Cap the context window at the 6 most recent lines.
    if len(conversation_history) > 6:
        conversation_history = conversation_history[-6:]

    prompt_ids = tokenizer.encode("\n".join(conversation_history) + "\nBot:", return_tensors="pt").to(model.device)
    generated = model.generate(
        prompt_ids,
        do_sample=True,
        max_new_tokens=50,
        temperature=0.7,
        top_k=50,
        top_p=0.85,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
    )
    reply = tokenizer.decode(generated[0, prompt_ids.shape[-1]:], skip_special_tokens=True).strip()

    conversation_history.append(f"Bot: {reply}")
    return reply

"""### **Upgrade 3: Handle Uncertain Responses with Fallback Mechanism**"""

# Fresh rolling transcript ("User: ..." / "Bot: ..." lines, oldest first),
# shared at module level by every call of the function below.
conversation_history = []

def chatbot_response(user_input):
    """Reply to ``user_input`` with transcript context and a fallback answer.

    Works like the history-aware responder: keep the 6 most recent transcript
    lines, prompt with them plus ``Bot:``, and sample up to 50 new tokens.
    A reply of two words or fewer (including an empty one) is replaced by a
    fixed request to rephrase. The final reply is recorded and returned.
    """
    global conversation_history
    conversation_history.append(f"User: {user_input}")
    if len(conversation_history) > 6:
        conversation_history = conversation_history[-6:]

    prompt_ids = tokenizer.encode("\n".join(conversation_history) + "\nBot:", return_tensors="pt").to(model.device)
    generated = model.generate(
        prompt_ids,
        do_sample=True,
        max_new_tokens=50,
        temperature=0.8,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id,
    )
    reply = tokenizer.decode(generated[0, prompt_ids.shape[-1]:], skip_special_tokens=True).strip()

    # An empty string splits into zero words, so this single test covers both
    # "no reply" and "too short to be useful".
    if len(reply.split()) <= 2:
        reply = "I'm not sure I understood that. Could you please rephrase?"

    conversation_history.append(f"Bot: {reply}")
    return reply