streamlit_ui

A sample user interface, built with Streamlit, that implements a WebSocket client connecting to the bot's WebSocket server.

  1import base64
  2import json
  3import queue
  4import sys
  5import threading
  6import time
  7from datetime import datetime
  8
  9import pandas as pd
 10import plotly
 11import streamlit as st
 12import websocket
 13from audio_recorder_streamlit import audio_recorder
 14from streamlit.runtime import Runtime
 15from streamlit.runtime.app_session import AppSession
 16from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
 17from streamlit.web import cli as stcli
 18
 19from besser.bot.core.file import File
 20from besser.bot.core.message import Message, MessageType
 21from besser.bot.platforms.payload import Payload, PayloadAction, PayloadEncoder
 22
 23# Polling period, in seconds, of the background thread that checks whether a Streamlit session is still active
 24SESSION_MONITORING_INTERVAL = 1
 25
 26
 27def get_streamlit_session() -> AppSession or None:
 28    session_id = get_script_run_ctx().session_id
 29    runtime: Runtime = Runtime.instance()
 30    return next((
 31        s.session
 32        for s in runtime._session_mgr.list_sessions()
 33        if s.session.id == session_id
 34    ), None)
 35
 36
 37def session_monitoring(interval: int):
 38    runtime: Runtime = Runtime.instance()
 39    session = get_streamlit_session()
 40    while True:
 41        time.sleep(interval)
 42        if not runtime.is_active_session(session.id):
 43            runtime.close_session(session.id)
 44            session.session_state['websocket'].close()
 45            break
 46
 47
 48def main():
 49    try:
 50        # We get the websocket host and port from the script arguments
 51        bot_name = sys.argv[1]
 52    except Exception as e:
 53        # If they are not provided, we use default values
 54        bot_name = 'Chatbot Demo'
 55    st.header(bot_name)
 56    st.markdown("[Github](https://github.com/BESSER-PEARL/BESSER-Bot-Framework)")
 57    # User input component. Must be declared before history writing
 58    user_input = st.chat_input("What is up?")
 59
 60    def on_message(ws, payload_str):
 61        # https://github.com/streamlit/streamlit/issues/2838
 62        streamlit_session = get_streamlit_session()
 63        payload: Payload = Payload.decode(payload_str)
 64        if payload.action == PayloadAction.BOT_REPLY_STR.value:
 65            content = payload.message
 66            t = MessageType.STR
 67        elif payload.action == PayloadAction.BOT_REPLY_FILE.value:
 68            content = payload.message
 69            t = MessageType.FILE
 70        elif payload.action == PayloadAction.BOT_REPLY_DF.value:
 71            content = pd.read_json(payload.message)
 72            t = MessageType.DATAFRAME
 73        elif payload.action == PayloadAction.BOT_REPLY_PLOTLY.value:
 74            content = plotly.io.from_json(payload.message)
 75            t = MessageType.PLOTLY
 76        elif payload.action == PayloadAction.BOT_REPLY_LOCATION.value:
 77            content = {
 78                'latitude': [payload.message['latitude']],
 79                'longitude': [payload.message['longitude']]
 80            }
 81            t = MessageType.LOCATION
 82        elif payload.action == PayloadAction.BOT_REPLY_OPTIONS.value:
 83            t = MessageType.OPTIONS
 84            d = json.loads(payload.message)
 85            content = []
 86            for button in d.values():
 87                content.append(button)
 88        elif payload.action == PayloadAction.BOT_REPLY_RAG.value:
 89            t = MessageType.RAG_ANSWER
 90            content = payload.message
 91        message = Message(t=t, content=content, is_user=False, timestamp=datetime.now())
 92        streamlit_session._session_state['queue'].put(message)
 93        streamlit_session._handle_rerun_script_request()
 94
 95    def on_error(ws, error):
 96        pass
 97
 98    def on_open(ws):
 99        pass
100
101    def on_close(ws, close_status_code, close_msg):
102        pass
103
104    def on_ping(ws, data):
105        pass
106
107    def on_pong(ws, data):
108        pass
109
110    user_type = {
111        0: 'assistant',
112        1: 'user'
113    }
114
115    if 'history' not in st.session_state:
116        st.session_state['history'] = []
117
118    if 'queue' not in st.session_state:
119        st.session_state['queue'] = queue.Queue()
120
121    if 'websocket' not in st.session_state:
122        try:
123            # We get the websocket host and port from the script arguments
124            host = sys.argv[2]
125            port = sys.argv[3]
126        except Exception as e:
127            # If they are not provided, we use default values
128            host = 'localhost'
129            port = '8765'
130        ws = websocket.WebSocketApp(f"ws://{host}:{port}/",
131                                    on_open=on_open,
132                                    on_message=on_message,
133                                    on_error=on_error,
134                                    on_close=on_close,
135                                    on_ping=on_ping,
136                                    on_pong=on_pong)
137        websocket_thread = threading.Thread(target=ws.run_forever)
138        add_script_run_ctx(websocket_thread)
139        websocket_thread.start()
140        st.session_state['websocket'] = ws
141
142    if 'session_monitoring' not in st.session_state:
143        session_monitoring_thread = threading.Thread(target=session_monitoring,
144                                                     kwargs={'interval': SESSION_MONITORING_INTERVAL})
145        add_script_run_ctx(session_monitoring_thread)
146        session_monitoring_thread.start()
147        st.session_state['session_monitoring'] = session_monitoring_thread
148
149    ws = st.session_state['websocket']
150
151    with st.sidebar:
152
153        if reset_button := st.button(label="Reset bot"):
154            st.session_state['history'] = []
155            st.session_state['queue'] = queue.Queue()
156            payload = Payload(action=PayloadAction.RESET)
157            ws.send(json.dumps(payload, cls=PayloadEncoder))
158
159        if voice_bytes := audio_recorder(text=None, pause_threshold=2):
160            if 'last_voice_message' not in st.session_state or st.session_state['last_voice_message'] != voice_bytes:
161                st.session_state['last_voice_message'] = voice_bytes
162                # Encode the audio bytes to a base64 string
163                voice_message = Message(t=MessageType.AUDIO, content=voice_bytes, is_user=True, timestamp=datetime.now())
164                st.session_state.history.append(voice_message)
165                voice_base64 = base64.b64encode(voice_bytes).decode('utf-8')
166                payload = Payload(action=PayloadAction.USER_VOICE, message=voice_base64)
167                try:
168                    ws.send(json.dumps(payload, cls=PayloadEncoder))
169                except Exception as e:
170                    st.error('Your message could not be sent. The connection is already closed')
171        if uploaded_file := st.file_uploader("Choose a file", accept_multiple_files=False):
172            if 'last_file' not in st.session_state or st.session_state['last_file'] != uploaded_file:
173                st.session_state['last_file'] = uploaded_file
174                bytes_data = uploaded_file.read()
175                file_object = File(file_base64=base64.b64encode(bytes_data).decode('utf-8'), file_name=uploaded_file.name, file_type=uploaded_file.type)
176                payload = Payload(action=PayloadAction.USER_FILE, message=file_object.get_json_string())
177                file_message = Message(t=MessageType.FILE, content=file_object.to_dict(), is_user=True, timestamp=datetime.now())
178                st.session_state.history.append(file_message)
179                try:
180                    ws.send(json.dumps(payload, cls=PayloadEncoder))
181                except Exception as e:
182                    st.error('Your message could not be sent. The connection is already closed')
183    for message in st.session_state['history']:
184        with st.chat_message(user_type[message.is_user]):
185            if message.type == MessageType.AUDIO:
186                st.audio(message.content, format="audio/wav")
187            elif message.type == MessageType.FILE:
188                file: File = File.from_dict(message.content)
189                file_name = file.name
190                file_type = file.type
191                file_data = base64.b64decode(file.base64.encode('utf-8'))
192                st.download_button(label='Download ' + file_name, file_name=file_name, data=file_data, mime=file_type,
193                                   key=file_name + str(time.time()))
194            elif message.type == MessageType.LOCATION:
195                st.map(message.content)
196            elif message.type == MessageType.RAG_ANSWER:
197                # TODO: Avoid duplicate in history and queue
198                st.write(f'🔮 {message.content["answer"]}')
199                with st.expander('Details'):
200                    st.write(f'This answer has been generated by an LLM: **{message.content["llm_name"]}**')
201                    st.write(f'It received the following documents as input to come up with a relevant answer:')
202                    if 'docs' in message.content:
203                        for i, doc in enumerate(message.content['docs']):
204                            st.write(f'**Document {i + 1}/{len(message.content["docs"])}**')
205                            st.write(f'- **Source:** {doc["metadata"]["source"]}')
206                            st.write(f'- **Page:** {doc["metadata"]["page"]}')
207                            st.write(f'- **Content:** {doc["content"]}')
208            else:
209                st.write(message.content)
210
211    first_message = True
212    while not st.session_state['queue'].empty():
213        with st.chat_message("assistant"):
214            message = st.session_state['queue'].get()
215            if hasattr(message, '__len__'):
216                t = len(message.content) / 1000 * 3
217            else:
218                t = 2
219            if t > 3:
220                t = 3
221            elif t < 1 and first_message:
222                t = 1
223            first_message = False
224            if message.type == MessageType.OPTIONS:
225                st.session_state['buttons'] = message.content
226            elif message.type == MessageType.FILE:
227                st.session_state['history'].append(message)
228                with st.spinner(''):
229                    time.sleep(t)
230                file: File = File.from_dict(message.content)
231                file_name = file.name
232                file_type = file.type
233                file_data = base64.b64decode(file.base64.encode('utf-8'))
234                st.download_button(label='Download ' + file_name, file_name=file_name, data=file_data, mime=file_type,
235                                   key=file_name + str(time.time()))
236            elif message.type == MessageType.LOCATION:
237                st.session_state['history'].append(message)
238                st.map(message.content)
239            elif message.type == MessageType.RAG_ANSWER:
240                st.session_state['history'].append(message)
241                st.write(f'🔮 {message.content["answer"]}')
242                with st.expander('Details'):
243                    st.write(f'This answer has been generated by an LLM: **{message.content["llm_name"]}**')
244                    st.write(f'It received the following documents as input to come up with a relevant answer:')
245                    if 'docs' in message.content:
246                        for i, doc in enumerate(message.content['docs']):
247                            st.write(f'**Document {i + 1}/{len(message.content["docs"])}**')
248                            st.write(f'- **Source:** {doc["metadata"]["source"]}')
249                            st.write(f'- **Page:** {doc["metadata"]["page"]}')
250                            st.write(f'- **Content:** {doc["content"]}')
251            elif message.type == MessageType.STR:
252                st.session_state['history'].append(message)
253                with st.spinner(''):
254                    time.sleep(t)
255                st.write(message.content)
256
257    if 'buttons' in st.session_state:
258        buttons = st.session_state['buttons']
259        cols = st.columns(1)
260        for i, option in enumerate(buttons):
261            if cols[0].button(option):
262                with st.chat_message("user"):
263                    st.write(option)
264                message = Message(t=MessageType.STR, content=option, is_user=True, timestamp=datetime.now())
265                st.session_state.history.append(message)
266                payload = Payload(action=PayloadAction.USER_MESSAGE,
267                                  message=option)
268                ws.send(json.dumps(payload, cls=PayloadEncoder))
269                del st.session_state['buttons']
270                break
271
272    if user_input:
273        if 'buttons' in st.session_state:
274            del st.session_state['buttons']
275        with st.chat_message("user"):
276            st.write(user_input)
277        message = Message(t=MessageType.STR, content=user_input, is_user=True, timestamp=datetime.now())
278        st.session_state.history.append(message)
279        payload = Payload(action=PayloadAction.USER_MESSAGE,
280                          message=user_input)
281        try:
282            ws.send(json.dumps(payload, cls=PayloadEncoder))
283        except Exception as e:
284            st.error('Your message could not be sent. The connection is already closed')
285
286    st.stop()
287
288
if __name__ == "__main__":
    if st.runtime.exists():
        # A Streamlit runtime already exists: run the UI
        main()
    else:
        # No Streamlit runtime yet: run this script through the Streamlit CLI.
        # The script arguments (bot name, websocket host and port) are forwarded after "--" so that
        # main() can still read them from sys.argv
        script_args = sys.argv[1:]
        sys.argv = ["streamlit", "run", sys.argv[0]]
        if script_args:
            sys.argv += ["--", *script_args]
        sys.exit(stcli.main())