Spaces:
Running
Running
File size: 2,921 Bytes
e62bece 7f3026b c7c0d53 7f424d1 02976e0 e62bece 00c8a57 e95c2d3 bb16527 4bc3e8b e95c2d3 15e7b42 e95c2d3 7f3026b e95c2d3 f5903a4 e95c2d3 02976e0 f5903a4 15e7b42 bb16527 15e7b42 f5903a4 bb16527 a2f39c6 a0fbd48 f5903a4 a0fbd48 f5903a4 bb16527 f5903a4 4bc3e8b f5903a4 bb16527 22df2c5 e95c2d3 bb16527 f5903a4 7f3026b e62bece 7f424d1 22df2c5 e95c2d3 02976e0 f5903a4 8b67be0 15e7b42 e95c2d3 15e7b42 84031c5 e62bece f5903a4 bb16527 f5903a4 bb16527 84031c5 1344c31 bb16527 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import warnings
warnings.filterwarnings("ignore")
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Limit intra-op parallelism to one thread — presumably to keep the shared
# CPU host of the Space responsive (TODO confirm this is intentional).
torch.set_num_threads(1)

BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print("Loading model...")
# CPU-friendly float32 load; no device_map is given, so the model stays on
# the default device.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model.eval()  # inference mode: disables dropout / training-only behavior
print("Model ready")
# ─────────────────────────
# SQL FILTER
# ─────────────────────────
import re

# Keywords that mark a request as SQL/database related.
SQL_KEYWORDS = [
    "sql", "database", "table", "select", "insert",
    "update", "delete", "join", "group by",
    "postgres", "mysql", "sqlite", "query"
]

# Compiled once at import time. \b word boundaries prevent false positives
# from plain substring matching (e.g. "tablet" triggering on "table",
# "updates on my order" triggering on "update").
_SQL_PATTERN = re.compile(
    r"\b(?:" + "|".join(re.escape(k) for k in SQL_KEYWORDS) + r")\b"
)


def is_sql_related(text):
    """Return True if *text* mentions any SQL/database keyword as a whole word.

    Matching is case-insensitive (input is lower-cased before the search).
    """
    return _SQL_PATTERN.search(text.lower()) is not None
# ─────────────────────────
# GENERATION
# ─────────────────────────
# Instruction prefix interpolated verbatim into every prompt built in
# generate_sql(); the leading/trailing newlines are part of the prompt.
SYSTEM_PROMPT = """
You are an expert SQL generator.
Rules:
- Only respond to SQL or database related questions.
- If the question is not about SQL or databases, refuse.
- Output ONLY SQL query.
- Do not explain.
"""
def generate_sql(user_input):
    """Generate a SQL query for *user_input*, refusing non-SQL requests.

    Returns a plain string: the generated SQL on success, or a short
    message for empty or off-topic input.
    """
    if not user_input.strip():
        return "Enter SQL question."
    # HARD GUARD: refuse anything the keyword filter does not flag as SQL.
    if not is_sql_related(user_input):
        return "I only respond to SQL and database related questions. If you want, I can craft helpful database queries for you."
    prompt = f"""
{SYSTEM_PROMPT}
User request: {user_input}
SQL:
"""
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            # Greedy decoding for deterministic output. No `temperature`:
            # it is ignored when do_sample=False and only emits a warning.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    # This is more robust than splitting the full text on "SQL:", which
    # breaks if the model itself emits that marker in its answer.
    prompt_len = inputs["input_ids"].shape[-1]
    result = tokenizer.decode(
        output[0][prompt_len:], skip_special_tokens=True
    ).strip()
    # Extra safety: keep only the first paragraph to drop any trailing
    # explanation the model appends despite the instructions.
    result = result.split("\n\n")[0]
    return result
# ─────────────────────────
# UI
# ─────────────────────────
# Single-function Gradio UI: one text input, one text output, wired to
# generate_sql above.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        lines=3,
        label="SQL Question",
        placeholder="Find duplicate emails in users table"
    ),
    outputs=gr.Textbox(
        lines=8,
        label="Generated SQL"
    ),
    title="AI SQL Generator (Portfolio Project)",
    description="This model ONLY responds to SQL/database queries.",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
        ["Write a joke about cats"]  # will be blocked by the keyword guard
    ],
)
# 0.0.0.0 so the server is reachable from outside the container.
demo.launch(server_name="0.0.0.0")
|