File size: 2,921 Bytes
e62bece
7f3026b
c7c0d53
7f424d1
02976e0
e62bece
00c8a57
e95c2d3
 
bb16527
4bc3e8b
e95c2d3
 
 
 
15e7b42
e95c2d3
7f3026b
e95c2d3
 
f5903a4
e95c2d3
02976e0
f5903a4
 
 
 
 
 
 
 
 
 
 
 
 
15e7b42
bb16527
15e7b42
f5903a4
 
 
 
 
 
 
 
 
 
 
 
 
bb16527
a2f39c6
a0fbd48
f5903a4
a0fbd48
f5903a4
bb16527
f5903a4
4bc3e8b
f5903a4
bb16527
 
 
 
22df2c5
e95c2d3
 
bb16527
 
f5903a4
7f3026b
e62bece
7f424d1
22df2c5
e95c2d3
02976e0
f5903a4
 
 
 
 
 
 
8b67be0
15e7b42
e95c2d3
15e7b42
84031c5
e62bece
f5903a4
 
 
 
 
 
 
 
 
 
 
bb16527
 
 
f5903a4
 
bb16527
84031c5
1344c31
bb16527
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import re
import warnings
warnings.filterwarnings("ignore")

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# ─────────────────────────
# MODEL
# ─────────────────────────
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Restrict PyTorch to a single intra-op CPU thread.
torch.set_num_threads(1)

print("Loading model...")

# Tokenizer and weights come from the same checkpoint; weights stay in fp32.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float32)
model.eval()  # inference only: put the module in evaluation mode

print("Model ready")

# ─────────────────────────
# SQL FILTER
# ─────────────────────────
# Jargon keywords, matched as plain substrings of the lower-cased text so that
# compounds such as "postgresql" or "users_table" still count.
SQL_KEYWORDS = [
    "sql", "database", "table", "select", "insert",
    "update", "delete", "join", "group by",
    "postgres", "mysql", "sqlite", "query"
]

# Query-intent words that show up in plain-English data questions
# ("Top 5 highest paid employees", "Count orders per customer last month")
# which contain no SQL jargon at all. These are short everyday words, so they
# are matched as whole words only, avoiding hits inside "account", "laptop", ...
SQL_INTENT_WORDS = [
    "count", "sum", "average", "avg", "top", "highest", "lowest",
    "duplicate", "duplicates", "distinct", "row", "rows", "record",
    "records", "column", "columns", "schema", "order by", "having",
]

_SQL_INTENT_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(word) for word in SQL_INTENT_WORDS) + r")\b"
)


def is_sql_related(text):
    """Return True if *text* looks like a SQL / database question.

    This is a cheap keyword heuristic used as a pre-generation guard, not a
    classifier: a match on any ``SQL_KEYWORDS`` substring, or on any
    ``SQL_INTENT_WORDS`` whole word, is enough.

    Args:
        text: The user's question (case-insensitive).

    Returns:
        True when a keyword or intent word is found, otherwise False.
    """
    lowered = text.lower()
    if any(keyword in lowered for keyword in SQL_KEYWORDS):
        return True
    return _SQL_INTENT_RE.search(lowered) is not None

# ─────────────────────────
# GENERATION
# ─────────────────────────
SYSTEM_PROMPT = """
You are an expert SQL generator.

Rules:
- Only respond to SQL or database related questions.
- If the question is not about SQL or databases, refuse.
- Output ONLY SQL query.
- Do not explain.
"""


def _extract_sql(generated):
    """Trim raw model output down to its first SQL paragraph."""
    text = generated.strip()

    # Unwrap a Markdown code fence such as ```sql ... ``` if the model used one.
    if text.startswith("```"):
        text = text[3:]
        first_line, sep, rest = text.partition("\n")
        # Drop an optional language tag ("sql") on the opening fence line.
        if sep and first_line.strip().isalpha():
            text = rest
        text = text.split("```", 1)[0].strip()

    # Keep only the first paragraph; anything after a blank line is
    # typically an explanation the prompt asked the model not to give.
    text = text.split("\n\n")[0].strip()

    return text or "Could not generate a SQL query. Try rephrasing the question."


def generate_sql(user_input):
    """Turn a natural-language database question into a SQL query.

    Args:
        user_input: Free-text question from the UI textbox; may be empty
            or None.

    Returns:
        The generated SQL text, or a short message when the input is empty,
        is rejected by ``is_sql_related``, or the model produced nothing.
    """
    if not user_input or not user_input.strip():
        return "Enter SQL question."

    # HARD GUARD: cheap keyword filter before spending time on generation.
    if not is_sql_related(user_input):
        return (
            "I only respond to SQL and database related questions. "
            "If you want, I can craft helpful database queries for you."
        )

    prompt = f"""
{SYSTEM_PROMPT}

User request: {user_input}
SQL:
"""

    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            # Greedy decoding. Sampling knobs such as temperature are ignored
            # (and warned about) when do_sample=False, so none are passed.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on the last "SQL:" loses the answer whenever the model echoes "SQL:".
    generated = tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    )

    return _extract_sql(generated)

# ─────────────────────────
# UI
# ─────────────────────────
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        lines=3,
        label="SQL Question",
        placeholder="Find duplicate emails in users table"
    ),
    outputs=gr.Textbox(
        lines=8,
        label="Generated SQL"
    ),
    title="AI SQL Generator (Portfolio Project)",
    description="This model ONLY responds to SQL/database queries.",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"],
        ["Write a joke about cats"]  # will be blocked
    ],
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module (tests, other apps) does not launch it.
# server_name="0.0.0.0" listens on all network interfaces.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")