Source code for diva.gui.service_streamlit.tab1

# Copyright 2024 Mews
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.

"""
Tab 1 that contains the chatbot page
"""
import streamlit as st
import time
import psutil
from pyJoules.energy_meter import EnergyMeter

from loguru import logger

import diva.config as config
import diva.tools as tools
from diva import parameters
from diva import energy
from diva.chat import ModuleChat
from diva.graphs import service_graph_generation
from diva.logging_.logger import Process


def get_module_llm():
    """
    Imports and returns the default language model generator.

    This function lazily imports `llms` from the `diva.llm` package and
    returns the generator (`llms.generator`) defined within that module.

    Returns
    -------
    object
        The language model generator (`llms.generator`).
    """
    from diva.llm import llms
    return llms.generator

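# Usage note (illustrative): the function-local import above defers loading
# `diva.llm.llms` until the first call, which keeps importing this module
# cheap if `llms` is expensive to initialize (whether it is, is an assumption):
#
#     generator = get_module_llm()   # triggers the import on first call
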
def initialize_session_state(module_chat, langage):
    """Initialize all session state variables."""
    if "module_config" not in st.session_state:
        st.session_state.module_config = config.ModuleConfig(module_chat)
    if "p" not in st.session_state:
        st.session_state.p = psutil.Process()
        st.session_state.p.cpu_percent()
    if "time_last_call_p" not in st.session_state:
        st.session_state.time_last_call_p = time.time()
    if "energy_meter" not in st.session_state:
        st.session_state.energy_meter = EnergyMeter(energy.devices)
    if "energy_records" not in st.session_state:
        st.session_state.energy_records = energy.EnergyRecords()
    if "messages" not in st.session_state:
        first_sentence = parameters.first_sentence
        source_lang = tools.langage_to_iso(langage)
        first_sentence = tools.translate_from_en(first_sentence, source_lang)
        st.session_state.messages = {
            "msg_0": {"role": "assistant", "content": first_sentence}
        }
    if "plots_history" not in st.session_state:
        st.session_state.plots_history = {}
    if "config_params" not in st.session_state:
        st.session_state.config_params = {}
    if "last_response_is_fig" not in st.session_state:
        st.session_state.last_response_is_fig = False
    if "conv_cleared" not in st.session_state:
        st.session_state.conv_cleared = True
    if "data_history" not in st.session_state:
        st.session_state.data_history = {}
    st.session_state.process_logging = Process(text_language='html')

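# Background for the guards above (a minimal sketch, nothing diva-specific):
# Streamlit re-executes the whole script on every interaction, while
# `st.session_state` persists across those reruns, so each
# `if "key" not in st.session_state` block runs exactly once per session:
#
#     if "counter" not in st.session_state:
#         st.session_state.counter = 0   # first run of the session only
#     st.session_state.counter += 1      # every rerun
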
def clear_conversation(source_lang):
    """Clear all conversation history and reset state."""
    first_sentence = parameters.first_sentence
    first_sentence = tools.translate_from_en(first_sentence, source_lang)
    st.session_state.messages = {
        "msg_0": {"role": "assistant", "content": first_sentence}
    }
    st.session_state.plots_history = {}
    st.session_state.config_params = {}
    st.session_state.data_history = {}
    st.session_state.last_response_is_fig = False
    st.session_state.conv_cleared = True
    st.session_state.module_chat = ModuleChat(st.session_state.connected_user['user_type'])
    st.session_state.module_config = config.ModuleConfig(st.session_state.module_chat)
    st.session_state.process_logging = Process(text_language="html")
    st.session_state.history_logging = ''

def rebuild_chat_history(source_lang):
    """Rebuild and display all messages and figures from history."""
    for msg_id in st.session_state.messages.keys():
        message = st.session_state.messages[msg_id]
        st.chat_message(message["role"]).write(
            tools.translate_chat_msgs(message["role"], message["content"], source_lang)
        )
        if msg_id in st.session_state.plots_history.keys():
            fig = st.session_state.plots_history[msg_id]
            if isinstance(fig, (parameters.altair_type_fig, parameters.altair_type_fig_concat)):
                st.altair_chart(fig, use_container_width=True)
            else:
                st.plotly_chart(fig, use_container_width=True, key=msg_id + '_plotly_chart')
            if msg_id != list(st.session_state.plots_history.keys())[-1]:
                title, data_csv = st.session_state.data_history[msg_id]
                tools.add_download_button_after_prompts(
                    title, data_csv, message,
                    st.session_state.config_params[msg_id],
                    key=msg_id + '_download_button'
                )

    # Display download & feedback buttons for last figure
    if st.session_state.config_params != {} and st.session_state.last_response_is_fig:
        msg_plots = list(st.session_state.plots_history.keys())
        msg_prompt = list(st.session_state.messages.keys())
        if msg_plots != []:
            msg_id = msg_plots[-1]
            msg_id_ = msg_prompt[-1]
            tools.add_download_feedback_button(
                st.session_state.messages[msg_id_],
                st.session_state.config_params[msg_id],
                st.session_state.plots_history[msg_id],
            )

def get_user_input(source_lang, dev_mode):
    """Get user input from chat or suggested prompts."""
    prompt = st.chat_input("Ask question...")
    if st.session_state.conv_cleared:
        sentences_proposed = (parameters.sentences_proposed_dev if dev_mode
                              else parameters.sentences_proposed_user)
        sentences_on_cards = list(sentences_proposed.keys())
        cols = st.columns(3)
        for i, col in enumerate(cols):
            with col:
                label = tools.translate_from_en(sentences_on_cards[i], source_lang)
                if st.button(label):
                    prompt = tools.translate_from_en(
                        sentences_proposed[sentences_on_cards[i]], source_lang
                    )
                    st.session_state.conv_cleared = False
    return prompt

def handle_discussion(module_chat, prompt, source_lang):
    """Handle simple discussion requests."""
    llm_answer = module_chat.chat.llm_answer
    llm_answer = tools.translate_from_en(llm_answer, source_lang)
    with st.chat_message("assistant"):
        st.write_stream(tools.write_like_chatGPT(llm_answer))
    st.session_state.messages[f"msg_{len(st.session_state.messages)}"] = {
        "role": "assistant",
        "content": llm_answer
    }
    tools.update_logs(prompt, str(module_chat.prompt), llm_answer)
    tools.add_feedback_msg_button(prompt, llm_answer)
    st.session_state.last_response_is_fig = False

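# For context (a sketch, not the diva implementation): `st.write_stream`
# consumes any generator of text chunks, so `tools.write_like_chatGPT`
# presumably yields the answer piece by piece, along these lines:
#
#     def write_like_chatgpt_sketch(text, delay=0.02):
#         for word in text.split(" "):
#             yield word + " "
#             time.sleep(delay)   # paced output, ChatGPT-style
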
def handle_visualization(module_chat, module_config, prompt, source_lang,
                         energy_meter, energy_records, user_type, langage):
    """Handle visualization requests."""
    with st.spinner("Extracting information...⌛"):
        energy_meter.start()
        module_config.prompt_to_config(module_chat.prompt)
        energy_meter.stop()
        energy_records.set_gpu(energy_meter)
        llm_answer = module_chat.chat.llm_answer
        config_params = tools.convert_to_dict_config(module_config.config)

    logger.info(f"missing info --> {module_config.missings}")
    if not module_config.missings:
        _handle_complete_visualization(
            module_chat, module_config, config_params, llm_answer,
            prompt, source_lang, user_type, langage
        )
    else:
        _handle_missing_params(
            module_chat, module_config, config_params, prompt, source_lang
        )

def _handle_complete_visualization(module_chat, module_config, config_params, llm_answer,
                                   prompt, source_lang, user_type, langage):
    """Handle visualization when all parameters are present."""
    if len(module_config.config.not_in_shp) > 0:
        extra = "location" if len(module_config.config.location.split(", ")) == 1 else "locations"
        llm_answer = (
            f"My apologies, I don't know {tools.enumeration(module_config.config.not_in_shp)}"
            f" but I can answer for the other {extra}."
        )

    graph_gen = service_graph_generation.ServiceGeneratePlotlyGraph(
        config_params, langage, user_type
    )
    config_params['aggreg_type'] = graph_gen.aggreg_type

    display_config = f"{llm_answer} \n" + tools.display_config(config_params) + " \n"
    if config_params['graph_type'] == 'warming stripes':
        display_config += "The calculation of warming stripes is based on the period 1971-2000."

    with st.chat_message("assistant"):
        display_config = tools.translate_from_en(display_config, source_lang)
        st.write_stream(tools.write_like_chatGPT(display_config))

    st.session_state.config_params[f"msg_{len(st.session_state.messages)}"] = config_params
    st.session_state.messages[f"msg_{len(st.session_state.messages)}"] = {
        "role": "assistant",
        "content": display_config
    }
    st.session_state.last_response_is_fig = True

    with st.spinner(tools.generate_waiting_for_graph_msg() + "⌛"):
        graph_gen.generate()

    msg_plots = list(st.session_state.plots_history.keys())
    if msg_plots:
        msg_id = msg_plots[-1]
        title, data_csv = tools.add_download_feedback_button(
            prompt,
            config_params,
            st.session_state.plots_history[msg_id],
            "in_prompt",
        )
        st.session_state.data_history[msg_id] = (title, data_csv)

    tools.update_logs(prompt, str(module_chat.prompt), tools.from_dict_to_str(config_params))


def _handle_missing_params(module_chat, module_config, config_params, prompt, source_lang):
    """Handle visualization when some parameters are missing."""
    link_sentences_for_missings = [
        tools.get_link_question_for_missing(link) for link in module_config.missings
    ]
    if len(module_config.config.not_in_shp) > 0:
        extra = "another location" if len(module_config.config.not_in_shp) == 1 else "other locations"
        ask_for_complete_request = (
            f"My apologies, I don't know {tools.enumeration(module_config.config.not_in_shp)}."
            f" Could you please ask me again for {extra}?"
        )
        to_remove = list(tools.get_link_question_for_missing.link_questions.values())[2]
        link_sentences_for_missings.remove(to_remove)
        if len(link_sentences_for_missings) > 0:
            ask_for_complete_request += (
                f" In addition, could you please specify {tools.enumeration(link_sentences_for_missings)}?"
            )
    else:
        ask_for_complete_request = "Thank you for your request. "
        if config_params['climate_variable'].lower() in parameters.available_vars:
            ask_for_complete_request += (
                f"I can give you a graphical view of the {config_params['climate_variable']}. "
            )
        ask_for_complete_request += (
            f"Could you please specify {tools.enumeration(link_sentences_for_missings)}?"
        )

    _ = module_config.ask_missings()

    with st.chat_message("assistant"):
        ask_for_complete_request = tools.translate_from_en(ask_for_complete_request, source_lang)
        st.write_stream(tools.write_like_chatGPT(ask_for_complete_request))

    st.session_state.messages[f"msg_{len(st.session_state.messages)}"] = {
        "role": "assistant",
        "content": ask_for_complete_request
    }
    st.session_state.last_response_is_fig = False
    tools.update_logs(prompt, str(module_chat.prompt), ask_for_complete_request)
    tools.add_feedback_msg_button(prompt, ask_for_complete_request)

def process_prompt(module_chat, module_config, prompt, source_lang,
                   energy_meter, energy_records, p, user_type, langage):
    """Process user prompt and generate appropriate response."""
    tools.send_user_event("prompt")
    prompt = tools.clean_prompt(prompt)

    # Energy recording setup
    energy_records.clear_query()
    cpu_percent = p.cpu_percent()
    st.session_state.time_last_call_p = time.time()

    st.session_state.messages[f"msg_{len(st.session_state.messages)}"] = {
        "role": "user",
        "content": prompt,
    }
    st.session_state.conv_cleared = False
    with st.chat_message("user"):
        st.markdown(prompt)

    prompt = tools.translate_to_en(prompt, source_lang)

    with st.spinner("Processing the request...⌛"):
        module_chat.create_user_prompt(prompt)
        assert module_chat.prompt is not None
        energy_meter.start()
        module_chat.prompt_classification().lower()
        module_chat.prompt_rephrasing()
        module_chat.is_prompt_in_scope()
        module_chat.generate_text_answer()
        energy_meter.stop()
        energy_records.set_gpu(energy_meter)

    type_of_request = module_chat.prompt.type
    print(repr(module_chat.prompt))

    if "discussion" in type_of_request:
        handle_discussion(module_chat, prompt, source_lang)
    if "visualisation" in type_of_request or "visualization" in type_of_request:
        handle_visualization(
            module_chat, module_config, prompt, source_lang,
            energy_meter, energy_records, user_type, langage
        )

    # Finalize energy recording
    cpu_percent = p.cpu_percent()
    duration = min(time.time() - st.session_state.time_last_call_p, 11)
    st.session_state.time_last_call_p = time.time()
    energy_records.set_cpu(
        energy.get_energy_cpu(cpu_percent=cpu_percent, duration=duration)
        - energy.get_energy_cpu(cpu_percent=0, duration=duration)
    )
    energy_records.set_ram(
        energy.get_energy_ram(cpu_percent=cpu_percent, duration=duration)
        - energy.get_energy_ram(cpu_percent=0, duration=duration)
    )
    st.session_state.energy_consumption = {
        "query_energy": round(energy_records.query / 1000, 2),
        "query_CO2": energy_records.get_co2(type_="query"),
        "query_water": energy_records.get_water(type_="query"),
        "total_energy": round(energy_records.total / 1000, 2),
        "total_CO2": energy_records.get_co2(type_="total"),
        "total_water": energy_records.get_water(type_="total"),
    }

    st.session_state.history_logging = tools.update_logging_history(
        st.session_state.history_logging,
        st.session_state.process_logging.doc(),
        prompt
    )
    st.session_state.log_container.markdown(st.session_state.history_logging)

    nrg = st.session_state.energy_consumption
    tools.show_energy_consumption(nrg)

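# About the baseline subtraction above: recording
# get_energy_cpu(cpu_percent, d) - get_energy_cpu(0, d) attributes to the
# query only the energy spent above an idle machine. If the model were linear
# in CPU load plus an idle term (an assumption; the real formula lives in
# diva.energy), the difference would reduce to:
#
#     E_query = P_cpu * (cpu_percent / 100) * duration
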
def main(tab1_options):
    """Main function for Tab 1."""
    # Initialize module_chat
    if "module_chat" not in st.session_state:
        st.session_state.module_chat = ModuleChat(st.session_state.connected_user['user_type'])
    module_chat = st.session_state.module_chat
    langage = tab1_options["langage"]

    # Initialize all session state variables
    initialize_session_state(module_chat, langage)

    source_lang = tools.langage_to_iso(langage)
    user_type = st.session_state.connected_user['user_type']
    module_config = st.session_state.module_config
    energy_meter = st.session_state.energy_meter
    energy_records = st.session_state.energy_records
    p = st.session_state.p

    # Handle conversation clearing
    if tab1_options["clear_conv"]:
        print("CLEAR CONVERSATION")
        clear_conversation(source_lang)
        st.rerun()

    # Rebuild chat history
    rebuild_chat_history(source_lang)

    # Get user input
    prompt = get_user_input(source_lang, tab1_options['dev_mode'])

    # Check for empty prompt
    if prompt:
        prompt_is_empty = tools.chech_if_empty_prompt(prompt)
        if prompt_is_empty:
            prompt = False
            st.session_state.messages[f"msg_{len(st.session_state.messages)}"] = {
                "role": "user",
                "content": ""
            }
            with st.chat_message("user"):
                st.markdown("")
            with st.chat_message("assistant"):
                answer = "Please write something ..."
                st.write_stream(tools.write_like_chatGPT(answer))
            st.session_state.messages[f"msg_{len(st.session_state.messages)}"] = {
                "role": "assistant",
                "content": answer
            }

    # Process the prompt if valid
    if prompt:
        process_prompt(
            module_chat, module_config, prompt, source_lang,
            energy_meter, energy_records, p, user_type, langage
        )
        st.rerun()

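# Usage sketch (illustrative, not part of the module): `main` reads exactly
# the three keys below from `tab1_options`, and relies on `connected_user`
# and `log_container` having been placed in `st.session_state` by the
# surrounding app; the sidebar widgets here are assumptions:
#
#     tab1_options = {
#         "langage": "English",                                   # UI language
#         "clear_conv": st.sidebar.button("Clear conversation"),  # reset flag
#         "dev_mode": st.sidebar.checkbox("Developer mode"),      # prompt cards
#     }
#     main(tab1_options)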