# app.py
# Stable CPU-only Hugging Face Space
# Phi-3-mini + LoRA (NO bitsandbytes, NO SSR issues)
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# ─────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
MAX_NEW_TOKENS = 180
TEMPERATURE = 0.0
DO_SAMPLE = False
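# Greedy decoding (DO_SAMPLE=False) keeps outputs deterministic, which suits
# SQL generation; TEMPERATURE is only consulted when sampling is enabled.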
# ─────────────────────────────────────────────
# Load model & tokenizer (CPU SAFE)
# ─────────────────────────────────────────────
print("Loading base model on CPU...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="cpu",
    torch_dtype=torch.float32,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(model, LORA_PATH)
print("Merging LoRA weights...")
model = model.merge_and_unload()
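# Merging folds the LoRA deltas into the base weights and drops the PEFT
# wrapper, so inference runs on a plain transformers model with no adapter
# overhead per forward pass.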
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model.eval()
print("Model & tokenizer loaded successfully")
# ─────────────────────────────────────────────
# Inference
# ─────────────────────────────────────────────
def generate_sql(question: str) -> str:
    if not question or not question.strip():
        return "Please enter a SQL-related question."

    # Wrap the question in Phi-3's chat format and append the generation
    # prompt so the model responds as the assistant.
    messages = [
        {"role": "user", "content": question.strip()}
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # temperature only applies when sampling; passing it alongside greedy
    # decoding triggers a transformers warning, so include it conditionally.
    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        pad_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )
    if DO_SAMPLE:
        gen_kwargs["temperature"] = TEMPERATURE

    with torch.inference_mode():
        output_ids = model.generate(**gen_kwargs)

    # Decode only the newly generated tokens; decoding the full sequence
    # would echo the prompt back into the response. (skip_special_tokens
    # also strips the Phi-3 chat markers from the decoded text.)
    response = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
    )

    # Defensive cleanup in case any chat markers survive as literal text.
    for token in ["<|assistant|>", "<|user|>", "<|end|>"]:
        if token in response:
            response = response.split(token)[-1]

    return response.strip() or "(empty response)"
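# Optional smoke test (hypothetical query): uncomment to verify generation
# end-to-end before the UI comes up.
# print(generate_sql("Show the top 5 highest paid employees"))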
# ─────────────────────────────────────────────
# Gradio UI
# ─────────────────────────────────────────────
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="SQL Question",
        placeholder="Find duplicate emails in users table",
        lines=3,
    ),
    outputs=gr.Textbox(
        label="Generated SQL",
        lines=8,
    ),
    title="SQL Chat – Phi-3-mini (CPU)",
    description=(
        "CPU-only Hugging Face Space.\n"
        "First response may take 60–180 seconds. "
        "Subsequent requests are faster."
    ),
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
        ["Delete duplicate rows based on email"],
    ],
    cache_examples=False,
)
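# Cap concurrent generations at 1 so a single CPU worker queues requests
# instead of competing for cores. Recent Gradio versions queue by default;
# the explicit limit is an assumption about this Space's traffic, not a
# requirement.
demo.queue(default_concurrency_limit=1)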
# ─────────────────────────────────────────────
# Launch
# ─────────────────────────────────────────────
if __name__ == "__main__":
    print("Launching Gradio interface...")
    demo.launch(
        server_name="0.0.0.0",
        ssr_mode=False,  # important: avoids asyncio FD bug
        show_error=True,
    )