| | from typing import Generator |
| | from utils import validate_api_key, get_info, validate_uri, extract_code_blocks, get_info_sqlalchemy |
| | from langchain_community.utilities import SQLDatabase |
| | from var import system_prompt, markdown_info, query_output, groq_models |
| | import streamlit as st |
| | from groq import Groq |
| |
|
st.set_page_config(layout="wide")

# First run of this browser session: create the chat stores.
if "messages" not in st.session_state:
    # Transcript of user/assistant turns, and one formatted executor output
    # per assistant turn; the two lists are created (and reset) together.
    st.session_state["messages"] = []
    st.session_state["sql_result"] = []

if "selected_model" not in st.session_state:
    # No model recorded yet for this session.
    st.session_state["selected_model"] = None
| |
|
st.markdown("# SQL Chat")

st.sidebar.title("Settings")
api_key = st.sidebar.text_input("Groq API Key", type="password")

# The model picker is only usable once the key passes validation; while it is
# disabled it still reports its default (first) option.
if validate_api_key(api_key):
    st.sidebar.success("API Key is valid")
    model = st.sidebar.selectbox("Select Model", groq_models, index=0)
else:
    st.sidebar.error("Enter valid API Key")
    model = st.sidebar.selectbox("Select Model", groq_models, disabled=True)

# Switching to a different model starts a fresh conversation.
if model != st.session_state.selected_model:
    st.session_state.messages = []
    st.session_state.sql_result = []
    st.session_state.selected_model = model
| |
|
uri = st.sidebar.text_input("Enter SQL Database URI")

if validate_uri(uri):
    st.sidebar.success("URI is valid")
    # Fill the markdown template with whatever get_info_sqlalchemy reports for
    # this database (assumed to be a dict of template fields — see utils).
    # Note: `markdown_info` and `system_prompt` are rebound here from the
    # templates imported from `var` to their formatted values.
    db_info = get_info_sqlalchemy(uri)
    markdown_info = markdown_info.format(**db_info)
    with st.expander("SQL Database Info"):
        st.markdown(markdown_info)
    system_prompt = system_prompt.format(markdown_info=markdown_info)
else:
    st.sidebar.error("Enter valid URI")
| |
|
def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
    """Yield the text content of each chunk of a streamed Groq chat completion.

    Chunks with no ``choices`` or with empty delta content are skipped, so
    metadata-only chunks do not raise ``IndexError`` or emit empty text.
    """
    for chunk in chat_completion:
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content


def _clear_chat_history() -> None:
    """Empty the chat transcript and the stored executor results together.

    Both lists must be cleared: history replay pairs the assistant message at
    index ``i`` with ``sql_result[i // 2]``. (The previous
    ``a.clear() and b.clear()`` lambda never cleared ``sql_result`` because
    ``list.clear()`` returns ``None``.)
    """
    st.session_state.messages.clear()
    st.session_state.sql_result.clear()


# Main chat UI: only available once both the API key and the URI validate.
if validate_api_key(api_key) and validate_uri(uri):
    client = Groq(
        api_key=api_key,
    )

    db = SQLDatabase.from_uri(uri)

    # The emoji literals in this copy of the file were mis-encoded (mojibake);
    # Streamlit treats a non-emoji avatar string as an image path/URL and
    # fails to load it. NOTE(review): the executor glyph could not be
    # recovered exactly — 🛢 (database-style icon) chosen; confirm.
    avatar = {"user": "👨‍💻", "assistant": "🤖", "executor": "🛢"}

    # Rendered on every run, not only right after a prompt, so it is always
    # reachable. Streamlit runs on_click callbacks before re-running the
    # script, so the history below is already empty after a click.
    st.sidebar.button("Clear Chat History", on_click=_clear_chat_history)

    # Replay the stored conversation.
    for i, message in enumerate(st.session_state.messages):
        with st.chat_message(message["role"], avatar=avatar[message["role"]]):
            st.markdown(message["content"])
        # Messages alternate user/assistant, so every odd index is an
        # assistant turn whose executor output is sql_result[i // 2]. The
        # bounds check guards against a transcript left misaligned by an
        # older session.
        if i % 2 == 1 and i // 2 < len(st.session_state.sql_result):
            with st.chat_message("SQL Executor", avatar=avatar["executor"]):
                st.markdown(st.session_state.sql_result[i // 2])

    if prompt := st.chat_input("Enter your prompt here..."):
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("user", avatar=avatar["user"]):
            st.markdown(prompt)

        try:
            # System prompt plus the last 8 transcript messages (this
            # includes the prompt appended above).
            chat_completion = client.chat.completions.create(
                model=model,
                messages=[{"role": "system", "content": system_prompt}]
                + [
                    {"role": m["role"], "content": m["content"]}
                    for m in st.session_state.messages[-8:]
                ],
                max_tokens=3000,
                stream=True,
            )

            with st.chat_message("SQL Assistant", avatar=avatar["assistant"]):
                llm_response = st.write_stream(
                    generate_chat_responses(chat_completion)
                )
        except Exception as e:
            # No assistant turn was produced: drop the unanswered prompt so
            # the transcript keeps its user/assistant alternation and stays
            # aligned with sql_result (and nothing below touches undefined
            # names).
            st.session_state.messages.pop()
            st.error(e, icon="🚨")
        else:
            # st.write_stream may return a list rather than a single string;
            # normalise before using the text anywhere.
            if not isinstance(llm_response, str):
                llm_response = "\n".join(str(item) for item in llm_response)
            st.session_state.messages.append(
                {"role": "assistant", "content": llm_response}
            )

            with st.chat_message("SQL Executor", avatar=avatar["executor"]):
                try:
                    queries = extract_code_blocks(llm_response)
                    if not queries:
                        raise ValueError(
                            "No SQL query found in the model response."
                        )
                    # NOTE(review): model-generated SQL is executed unmodified
                    # with the connection's full privileges (writes/DDL
                    # included). Consider a read-only database user.
                    result = db.run(queries[0])
                except Exception as e:
                    st.error(e, icon="🚨")
                    executor_output = f"Error: {e}"
                else:
                    executor_output = query_output.format(result=result)
                    st.markdown(executor_output)
                    # Keep only the head and tail of large results in the
                    # stored history, with a visible gap marker.
                    if len(str(result)) > 1000:
                        executor_output = (
                            executor_output[:500]
                            + "\n...\n"
                            + executor_output[-500:]
                        )

            # Exactly one executor entry per assistant turn, success or not.
            st.session_state.sql_result.append(executor_output)

else:
    st.error("Please enter valid Groq API Key and URI in the sidebar.")