-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodule_chatbot.py
78 lines (63 loc) · 2.91 KB
/
module_chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 12 18:25:49 2024
@author: Jiaqi Ye
"""
import html

import streamlit as st
import torch
from transformers import pipeline
# Chatbot Module
class ChatbotModule:
    """Streamlit chatbot with predefined replies and an optional generative model.

    Attributes:
        responses: Mapping from an exact user message to its predefined reply.
        advanced_chatbot: The ``text2text-generation`` pipeline, or ``None``
            when the model could not be loaded.
    """

    # Reply of the simple chatbot when the message has no predefined answer.
    _FALLBACK_RESPONSE = "Sorry, I don't understand that."
    # Reply when the requested bot cannot serve the message.
    _UNAVAILABLE_RESPONSE = (
        "The advanced chatbot is currently unavailable. "
        "Please use predefined responses."
    )

    def __init__(self):
        """Set up the predefined replies and try to load the generative model.

        A load failure is reported through ``st.error`` and leaves
        ``advanced_chatbot`` set to ``None`` so the simple chatbot stays usable.
        """
        # Initialize predefined responses
        self.responses = {
            "Hi": "Hello!",
            "How are you?": "I am fine, thanks! How are you?",
        }

        # Attempt to load the advanced chatbot model on the first GPU if one
        # is available, otherwise on the CPU (-1).
        device = 0 if torch.cuda.is_available() else -1
        try:
            self.advanced_chatbot = pipeline(
                task="text2text-generation",
                model="facebook/blenderbot_small-90M",
                device=device,
            )
        except Exception as e:
            # Deliberate best-effort: any load problem degrades to the
            # simple chatbot instead of crashing the app.
            self.advanced_chatbot = None
            st.error(f"Failed to load the advanced chatbot model: {e}")

    def simple_chatbot(self, text):
        """Return the predefined reply for *text*, or a fallback sentence.

        The lookup is an exact, case-sensitive match against ``responses``.
        """
        return self.responses.get(text, self._FALLBACK_RESPONSE)

    def get_response(self, text, bot="simple"):
        """Return the reply to *text* from the selected chatbot.

        Args:
            text: The user's message.
            bot: ``"advanced"`` for the generative model, ``"simple"`` for the
                predefined replies.

        Returns:
            The reply string. If ``bot`` is ``"advanced"`` but the model is not
            loaded, or ``bot`` is any other unrecognised value, the
            "currently unavailable" message is returned.
        """
        if bot == "advanced" and self.advanced_chatbot:
            return self.advanced_chatbot(text)[0]['generated_text']
        if bot == "simple":
            return self.simple_chatbot(text)
        return self._UNAVAILABLE_RESPONSE

    @staticmethod
    def _format_message(role, message):
        """Return the HTML snippet for one chat-history entry.

        The message is HTML-escaped because the snippet is rendered with
        ``unsafe_allow_html=True``; without escaping, user input or model
        output containing markup would be injected into the page.

        Returns:
            The ``<div>`` markup for the ``"user"`` and ``"chatbot"`` roles,
            or ``None`` for any other role (such entries are not displayed).
        """
        safe_message = html.escape(str(message))
        if role == "user":
            return f"<div style='text-align:right;'><b>You:</b> {safe_message}</div>"
        if role == "chatbot":
            return f"<div style='text-align:left;'><b>Chatbot:</b> {safe_message}</div>"
        return None

    def display_chatbot(self):
        """Render the Streamlit GUI for the chatbot.

        Keeps the conversation in ``st.session_state.chat_history`` as a list
        of ``(role, message)`` tuples and redraws it on every run.
        """
        st.subheader("Interactive Chatbot")

        # Initialize session state for chat history
        if 'chat_history' not in st.session_state:
            st.session_state.chat_history = []

        # Text input for user messages
        user_input = st.text_input("Message to the chatbot:", "")

        # Ignore empty and whitespace-only messages.
        if st.button("Send") and user_input.strip():
            # Cache the response to minimize repeated computation
            @st.cache_data(show_spinner=False)
            def cached_response(text, bot="simple"):
                return self.get_response(text, bot)

            # Prefer the advanced chatbot whenever its model is loaded.
            chatbot_type = "advanced" if self.advanced_chatbot else "simple"

            # Generate response
            response = cached_response(user_input, bot=chatbot_type)

            # Append the user input and response to the chat history
            st.session_state.chat_history.append(("user", user_input))
            st.session_state.chat_history.append(("chatbot", response))

        # Display chat history with alignment (user right, chatbot left)
        for role, message in st.session_state.chat_history:
            markup = self._format_message(role, message)
            if markup is not None:
                st.markdown(markup, unsafe_allow_html=True)
# if __name__ == "__main__":
#     chatbot_app = ChatbotModule()
#     chatbot_app.display_chatbot()