streamlit_ui

A sample user interface built with Streamlit. It implements a WebSocket client that connects to the bot's WebSocket server.

  1import json
  2import queue
  3import sys
  4import threading
  5import time
  6
  7import pandas as pd
  8import streamlit as st
  9import websocket
 10from streamlit.runtime import Runtime
 11from streamlit.runtime.app_session import AppSession
 12from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
 13from streamlit.web import cli as stcli
 14
 15from besser.bot.platforms.payload import Payload, PayloadAction, PayloadEncoder
 16
# Time interval to check if a streamlit session is still active, in seconds.
# Passed as the ``interval`` argument of ``session_monitoring``, which sleeps
# this long between two consecutive activity checks.
SESSION_MONITORING_INTERVAL = 10
 19
 20
 21def get_streamlit_session() -> AppSession or None:
 22    session_id = get_script_run_ctx().session_id
 23    runtime: Runtime = Runtime.instance()
 24    return next((
 25        s.session
 26        for s in runtime._session_mgr.list_sessions()
 27        if s.session.id == session_id
 28    ), None)
 29
 30
 31def session_monitoring(interval: int):
 32    runtime: Runtime = Runtime.instance()
 33    session = get_streamlit_session()
 34    while True:
 35        time.sleep(interval)
 36        if not runtime.is_active_session(session.id):
 37            runtime.close_session(session.id)
 38            session.session_state['websocket'].close()
 39            break
 40
 41
 42def main():
 43
 44    def on_message(ws, payload_str):
 45        # https://github.com/streamlit/streamlit/issues/2838
 46        streamlit_session = get_streamlit_session()
 47        payload: Payload = Payload.decode(payload_str)
 48        if payload.action == PayloadAction.BOT_REPLY_STR.value:
 49            message = payload.message
 50        elif payload.action == PayloadAction.BOT_REPLY_DF.value:
 51            message = pd.read_json(payload.message)
 52        elif payload.action == PayloadAction.BOT_REPLY_OPTIONS.value:
 53            d = json.loads(payload.message)
 54            message = []
 55            for button in d.values():
 56                message.append(button)
 57        streamlit_session._session_state['queue'].put(message)
 58        streamlit_session._handle_rerun_script_request()
 59
 60    def on_error(ws, error):
 61        pass
 62
 63    def on_open(ws):
 64        pass
 65
 66    def on_close(ws, close_status_code, close_msg):
 67        pass
 68
 69    def on_ping(ws, data):
 70        pass
 71
 72    def on_pong(ws, data):
 73        pass
 74
 75    st.set_page_config(
 76        page_title="Streamlit Chat - Demo",
 77        page_icon=":robot:"
 78    )
 79
 80    user_type = {
 81        0: 'assistant',
 82        1: 'user'
 83    }
 84
 85    st.header("Chat Demo")
 86    st.markdown("[Github](https://github.com/BESSER-PEARL/bot-framework)")
 87
 88    if 'history' not in st.session_state:
 89        st.session_state['history'] = []
 90
 91    if 'queue' not in st.session_state:
 92        st.session_state['queue'] = queue.Queue()
 93
 94    if 'websocket' not in st.session_state:
 95        ws = websocket.WebSocketApp("ws://localhost:8765/",
 96                                    on_open=on_open,
 97                                    on_message=on_message,
 98                                    on_error=on_error,
 99                                    on_close=on_close,
100                                    on_ping=on_ping,
101                                    on_pong=on_pong)
102        websocket_thread = threading.Thread(target=ws.run_forever)
103        add_script_run_ctx(websocket_thread)
104        websocket_thread.start()
105        st.session_state['websocket'] = ws
106
107    if 'session_monitoring' not in st.session_state:
108        session_monitoring_thread = threading.Thread(target=session_monitoring,
109                                                     kwargs={'interval': SESSION_MONITORING_INTERVAL})
110        add_script_run_ctx(session_monitoring_thread)
111        session_monitoring_thread.start()
112        st.session_state['session_monitoring'] = session_monitoring_thread
113
114    ws = st.session_state['websocket']
115
116    with st.sidebar:
117        reset_button = st.button(label="Reset bot")
118        if reset_button:
119            st.session_state['history'] = []
120            st.session_state['queue'] = queue.Queue()
121            payload = Payload(action=PayloadAction.RESET)
122            ws.send(json.dumps(payload, cls=PayloadEncoder))
123
124    for message in st.session_state['history']:
125        with st.chat_message(user_type[message[1]]):
126            st.write(message[0])
127
128    first_message = True
129    while not st.session_state['queue'].empty():
130        message = st.session_state['queue'].get()
131        t = len(message) / 1000 * 3
132        if t > 3:
133            t = 3
134        elif t < 1 and first_message:
135            t = 1
136        first_message = False
137        if isinstance(message, list):
138            st.session_state['buttons'] = message
139        else:
140            st.session_state['history'].append((message, 0))
141            with st.chat_message("assistant"):
142                with st.spinner(''):
143                    time.sleep(t)
144                st.write(message)
145
146    if 'buttons' in st.session_state:
147        buttons = st.session_state['buttons']
148        cols = st.columns(1)
149        for i, option in enumerate(buttons):
150            if cols[0].button(option):
151                with st.chat_message("user"):
152                    st.write(option)
153                st.session_state.history.append((option, 1))
154                payload = Payload(action=PayloadAction.USER_MESSAGE,
155                                  message=option)
156                ws.send(json.dumps(payload, cls=PayloadEncoder))
157                del st.session_state['buttons']
158                break
159
160    # React to user input
161    if user_input := st.chat_input("What is up?"):
162        if 'buttons' in st.session_state:
163            del st.session_state['buttons']
164        with st.chat_message("user"):
165            st.write(user_input)
166        st.session_state.history.append((user_input, 1))
167        payload = Payload(action=PayloadAction.USER_MESSAGE,
168                          message=user_input)
169        try:
170            ws.send(json.dumps(payload, cls=PayloadEncoder))
171        except Exception as e:
172            st.error('Your message could not be sent. The connection is already closed')
173
174    st.stop()
175
176
if __name__ == "__main__":
    # When no Streamlit runtime exists yet, relaunch this file through the
    # Streamlit CLI; sys.exit() ends the process with the CLI's result, so
    # main() below only runs when a runtime already exists.
    if not st.runtime.exists():
        sys.argv = ["streamlit", "run", sys.argv[0]]
        sys.exit(stcli.main())
    main()