"""streamlit_ui

A sample user interface, powered by Streamlit, that implements a WebSocket client
connecting to the bot's WebSocket server.
"""
import base64
import io
import json
import logging
import queue
import sys
import threading
import time
from datetime import datetime

import pandas as pd
import plotly
import streamlit as st
import websocket
from audio_recorder_streamlit import audio_recorder
from streamlit.runtime import Runtime
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
from streamlit.web import cli as stcli

from besser.bot.core.file import File
from besser.bot.core.message import Message, MessageType
from besser.bot.platforms.payload import Payload, PayloadAction, PayloadEncoder

# Time interval to check if a streamlit session is still active, in seconds
SESSION_MONITORING_INTERVAL = 1


def get_streamlit_session() -> 'AppSession | None':
    """Return the Streamlit session bound to the calling thread's script-run context.

    The calling thread must carry a script-run context: either the Streamlit script
    thread itself, or a thread registered with ``add_script_run_ctx``.

    Returns:
        The ``AppSession`` whose id matches the context's ``session_id``, or ``None``
        if the thread has no script-run context or no listed session matches.
    """
    ctx = get_script_run_ctx()
    if ctx is None:
        # No script-run context is attached to this thread, so there is no session id to look up.
        return None
    runtime: Runtime = Runtime.instance()
    # NOTE: ``_session_mgr`` is private Streamlit API and may change between Streamlit versions.
    return next(
        (info.session for info in runtime._session_mgr.list_sessions() if info.session.id == ctx.session_id),
        None,
    )


def session_monitoring(interval: float):
    """Wait until the current Streamlit session becomes inactive, then release it.

    Meant to run in its own thread registered with ``add_script_run_ctx``, because the
    session is resolved through that thread's script-run context. Every ``interval``
    seconds it asks the runtime whether the session is still active. Once it is not,
    the session is closed in the runtime and its WebSocket
    (``session_state['websocket']``) is closed.

    Args:
        interval: polling period, in seconds.
    """
    runtime: Runtime = Runtime.instance()
    session = get_streamlit_session()
    if session is None:
        # Not bound to any session (e.g. no script-run context): nothing to monitor.
        return
    while True:
        time.sleep(interval)
        if not runtime.is_active_session(session.id):
            break
    # Read the socket reference first, so closing it does not depend on the session
    # state still being readable after close_session().
    ws = session.session_state['websocket']
    runtime.close_session(session.id)
    ws.close()


# Logger for this UI script (WebSocket errors, unsupported payloads).
_logger = logging.getLogger(__name__)

# Fallback values used when the corresponding script argument is not provided.
_DEFAULT_BOT_NAME = 'Chatbot Demo'
_DEFAULT_HOST = 'localhost'
_DEFAULT_PORT = '8765'

# Shown to the user when a payload cannot be delivered to the bot server.
_SEND_ERROR = 'Your message could not be sent. The connection is already closed'


def _script_arg(index: int, default: str) -> str:
    """Return the script argument ``sys.argv[index]``, or ``default`` if it was not provided."""
    return sys.argv[index] if len(sys.argv) > index else default


def _send_payload(ws: websocket.WebSocketApp, payload: Payload) -> bool:
    """Serialize ``payload`` as JSON and send it through ``ws``.

    A failed send is reported to the user with ``st.error`` instead of raising.

    Returns:
        ``True`` if the payload was handed to the socket, ``False`` if sending failed.
    """
    try:
        ws.send(json.dumps(payload, cls=PayloadEncoder))
    except (websocket.WebSocketException, OSError):
        st.error(_SEND_ERROR)
        return False
    return True


def _payload_to_message(payload: Payload) -> 'Message | None':
    """Build the chat ``Message`` for a bot reply ``payload``.

    Returns:
        The bot-authored message, or ``None`` if ``payload.action`` is not a bot reply
        action this UI knows how to display.
    """
    action = payload.action
    if action == PayloadAction.BOT_REPLY_STR.value:
        message_type, content = MessageType.STR, payload.message
    elif action == PayloadAction.BOT_REPLY_FILE.value:
        message_type, content = MessageType.FILE, payload.message
    elif action == PayloadAction.BOT_REPLY_DF.value:
        # Wrap the JSON text in a buffer: passing literal JSON strings to read_json is
        # deprecated since pandas 2.1.
        message_type, content = MessageType.DATAFRAME, pd.read_json(io.StringIO(payload.message))
    elif action == PayloadAction.BOT_REPLY_PLOTLY.value:
        message_type, content = MessageType.PLOTLY, plotly.io.from_json(payload.message)
    elif action == PayloadAction.BOT_REPLY_LOCATION.value:
        message_type = MessageType.LOCATION
        # Single-element columns, so the dict can be handed straight to st.map.
        content = {
            'latitude': [payload.message['latitude']],
            'longitude': [payload.message['longitude']],
        }
    elif action == PayloadAction.BOT_REPLY_OPTIONS.value:
        message_type = MessageType.OPTIONS
        # The options arrive as a JSON object; its values become the button labels.
        content = list(json.loads(payload.message).values())
    elif action == PayloadAction.BOT_REPLY_RAG.value:
        message_type, content = MessageType.RAG_ANSWER, payload.message
    else:
        return None
    return Message(t=message_type, content=content, is_user=False, timestamp=datetime.now())


def _on_message(ws: websocket.WebSocketApp, payload_str: str) -> None:
    """WebSocket ``on_message`` callback: queue the bot reply and request a script rerun.

    Runs in the WebSocket thread, which carries the session's script-run context
    (registered with ``add_script_run_ctx``), so the session is reached through its
    AppSession rather than ``st.session_state``.
    """
    # https://github.com/streamlit/streamlit/issues/2838
    streamlit_session = get_streamlit_session()
    if streamlit_session is None:
        # No matching Streamlit session (e.g. it was already closed): nothing to update.
        return
    payload: Payload = Payload.decode(payload_str)
    message = _payload_to_message(payload)
    if message is None:
        _logger.warning('Ignoring payload with unsupported action: %s', payload.action)
        return
    streamlit_session._session_state['queue'].put(message)
    streamlit_session._handle_rerun_script_request()


def _on_error(ws: websocket.WebSocketApp, error: Exception) -> None:
    """WebSocket ``on_error`` callback: log the error instead of silently dropping it."""
    _logger.error('WebSocket error: %s', error)


def _render_rag_answer(content: dict) -> None:
    """Render a RAG answer: the answer text, plus an expander with the LLM name and source documents."""
    st.write(f'🔮 {content["answer"]}')
    with st.expander('Details'):
        st.write(f'This answer has been generated by an LLM: **{content["llm_name"]}**')
        st.write('It received the following documents as input to come up with a relevant answer:')
        if 'docs' in content:
            docs = content['docs']
            for number, doc in enumerate(docs, start=1):
                st.write(f'**Document {number}/{len(docs)}**')
                st.write(f'- **Source:** {doc["metadata"]["source"]}')
                st.write(f'- **Page:** {doc["metadata"]["page"]}')
                st.write(f'- **Content:** {doc["content"]}')


def _render_message(message: Message, key: str) -> None:
    """Render ``message`` inside the current chat container.

    Args:
        message: the chat message to display.
        key: widget key, unique within the script run (used by the download button of
            file messages).
    """
    if message.type == MessageType.AUDIO:
        st.audio(message.content, format="audio/wav")
    elif message.type == MessageType.FILE:
        file: File = File.from_dict(message.content)
        file_data = base64.b64decode(file.base64.encode('utf-8'))
        st.download_button(label='Download ' + file.name, file_name=file.name, data=file_data,
                           mime=file.type, key=key)
    elif message.type == MessageType.LOCATION:
        st.map(message.content)
    elif message.type == MessageType.RAG_ANSWER:
        _render_rag_answer(message.content)
    else:
        # Text, DataFrames and Plotly figures are all displayed by st.write.
        st.write(message.content)


def _typing_delay(message: Message, first_message: bool) -> float:
    """Return how long (seconds) to show the "typing" spinner before a bot reply.

    The delay is 3 s per 1000 units of ``len(message.content)``, capped at 3 s, and raised
    to at least 1 s for the first reply handled in a script run. Content without a
    length gets 2 s.
    """
    content = message.content
    delay = len(content) / 1000 * 3 if hasattr(content, '__len__') else 2
    if delay > 3:
        delay = 3
    elif delay < 1 and first_message:
        delay = 1
    return delay


def main():
    """Render the chat UI for the current Streamlit script run.

    Optional script arguments: ``sys.argv[1]`` is the bot name shown as the page header;
    ``sys.argv[2]`` and ``sys.argv[3]`` are the host and port of the bot WebSocket server.
    Missing arguments fall back to ``'Chatbot Demo'``, ``'localhost'`` and ``'8765'``.

    The first time it runs in a session (no ``'websocket'`` in ``st.session_state``), it
    opens the WebSocket connection in a background thread. It also starts the
    session-monitoring thread once per session. Every run then renders:
    - the sidebar controls (reset, voice recorder, file uploader);
    - the stored chat history;
    - the bot replies queued by the WebSocket thread;
    - the pending option buttons;
    - the user's new input.
    Finally it calls ``st.stop()``.
    """
    bot_name = _script_arg(1, _DEFAULT_BOT_NAME)
    st.header(bot_name)
    st.markdown("[Github](https://github.com/BESSER-PEARL/BESSER-Bot-Framework)")
    # User input component. Must be declared before history writing
    user_input = st.chat_input("What is up?")

    if 'history' not in st.session_state:
        st.session_state['history'] = []

    if 'queue' not in st.session_state:
        st.session_state['queue'] = queue.Queue()

    if 'websocket' not in st.session_state:
        host = _script_arg(2, _DEFAULT_HOST)
        port = _script_arg(3, _DEFAULT_PORT)
        ws = websocket.WebSocketApp(f"ws://{host}:{port}/", on_message=_on_message, on_error=_on_error)
        websocket_thread = threading.Thread(target=ws.run_forever)
        # Attach this session's script-run context so _on_message can find the session.
        add_script_run_ctx(websocket_thread)
        websocket_thread.start()
        st.session_state['websocket'] = ws

    if 'session_monitoring' not in st.session_state:
        session_monitoring_thread = threading.Thread(target=session_monitoring,
                                                     kwargs={'interval': SESSION_MONITORING_INTERVAL})
        add_script_run_ctx(session_monitoring_thread)
        session_monitoring_thread.start()
        st.session_state['session_monitoring'] = session_monitoring_thread

    ws = st.session_state['websocket']

    with st.sidebar:

        if st.button(label="Reset bot"):
            st.session_state['history'] = []
            st.session_state['queue'] = queue.Queue()
            # Option buttons from the previous conversation are no longer valid.
            if 'buttons' in st.session_state:
                del st.session_state['buttons']
            _send_payload(ws, Payload(action=PayloadAction.RESET))

        if voice_bytes := audio_recorder(text=None, pause_threshold=2):
            # Only handle a recording that differs from the last one already processed.
            if 'last_voice_message' not in st.session_state or st.session_state['last_voice_message'] != voice_bytes:
                st.session_state['last_voice_message'] = voice_bytes
                voice_message = Message(t=MessageType.AUDIO, content=voice_bytes, is_user=True,
                                        timestamp=datetime.now())
                st.session_state['history'].append(voice_message)
                # Encode the audio bytes to a base64 string for the JSON payload
                voice_base64 = base64.b64encode(voice_bytes).decode('utf-8')
                _send_payload(ws, Payload(action=PayloadAction.USER_VOICE, message=voice_base64))

        if uploaded_file := st.file_uploader("Choose a file", accept_multiple_files=False):
            # Only handle a file that differs from the last one already processed.
            if 'last_file' not in st.session_state or st.session_state['last_file'] != uploaded_file:
                st.session_state['last_file'] = uploaded_file
                bytes_data = uploaded_file.read()
                file_object = File(file_base64=base64.b64encode(bytes_data).decode('utf-8'),
                                   file_name=uploaded_file.name, file_type=uploaded_file.type)
                file_message = Message(t=MessageType.FILE, content=file_object.to_dict(), is_user=True,
                                       timestamp=datetime.now())
                st.session_state['history'].append(file_message)
                _send_payload(ws, Payload(action=PayloadAction.USER_FILE, message=file_object.get_json_string()))

    for index, message in enumerate(st.session_state['history']):
        with st.chat_message('user' if message.is_user else 'assistant'):
            _render_message(message, key=f'history_message_{index}')

    first_message = True
    message_queue = st.session_state['queue']
    while not message_queue.empty():
        message = message_queue.get()
        if message.type == MessageType.OPTIONS:
            # Options are shown as buttons below the chat, not as an (empty) chat bubble.
            st.session_state['buttons'] = message.content
            continue
        delay = _typing_delay(message, first_message)
        first_message = False
        st.session_state['history'].append(message)
        # Same key scheme as the history loop above: the message's index in the history.
        key = f'history_message_{len(st.session_state["history"]) - 1}'
        with st.chat_message("assistant"):
            if message.type not in (MessageType.LOCATION, MessageType.RAG_ANSWER):
                with st.spinner(''):
                    time.sleep(delay)
            _render_message(message, key=key)

    if 'buttons' in st.session_state:
        buttons = st.session_state['buttons']
        cols = st.columns(1)
        for i, option in enumerate(buttons):
            # Explicit key: two options with the same label must not share a widget id.
            if cols[0].button(option, key=f'option_{i}'):
                with st.chat_message("user"):
                    st.write(option)
                message = Message(t=MessageType.STR, content=option, is_user=True, timestamp=datetime.now())
                st.session_state['history'].append(message)
                _send_payload(ws, Payload(action=PayloadAction.USER_MESSAGE, message=option))
                del st.session_state['buttons']
                break

    if user_input:
        if 'buttons' in st.session_state:
            del st.session_state['buttons']
        with st.chat_message("user"):
            st.write(user_input)
        message = Message(t=MessageType.STR, content=user_input, is_user=True, timestamp=datetime.now())
        st.session_state['history'].append(message)
        _send_payload(ws, Payload(action=PayloadAction.USER_MESSAGE, message=user_input))

    st.stop()


if __name__ == "__main__":
    if st.runtime.exists():
        # Already running inside `streamlit run`: render the UI.
        main()
    else:
        # Started as a plain Python script: relaunch it through the Streamlit CLI,
        # forwarding the script arguments (bot name, host, port) so main() still sees them.
        script_args = sys.argv[1:]
        sys.argv = ["streamlit", "run", sys.argv[0]]
        if script_args:
            # Documented form: `streamlit run your_script.py [-- script args]`
            sys.argv += ["--", *script_args]
        sys.exit(stcli.main())