triflix committed on
Commit
b3f0838
·
verified ·
1 Parent(s): d14886f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +48 -75
main.py CHANGED
@@ -11,119 +11,92 @@ from huggingface_hub import login
11
  # ==========================================
12
  # 1. APP SETUP
13
  # ==========================================
 
14
 
15
# Application object plus module-level model state (filled in at startup).
app = FastAPI(
    title="FunctionGemma Brain API",
    version="1.0.0",
)

# Global variables
MODEL_ID = "google/functiongemma-270m-it"
tokenizer = None
model = None
24
 
25
  # ==========================================
26
- # 2. DATA MODELS
27
  # ==========================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
class ChatRequest(BaseModel):
    """Request payload for POST /generate."""
    # query: the user's natural-language question; length constrained to 1..4096.
    query: str = Field(..., min_length=1, max_length=4096)
    # tools: JSON tool/function schemas forwarded to the chat template.
    tools: List[Dict[str, Any]]
    # include_date: when True, today's ISO date is appended to the system prompt.
    include_date: bool = True
33
-
34
class HealthResponse(BaseModel):
    """Response schema for the health-check endpoint (GET /)."""
    # status: service state, e.g. "running" (see health_check).
    status: str
    # model: the configured MODEL_ID being served.
    model: str
    # auth_status: how authentication was performed, e.g. "secure_env".
    auth_status: str
38
 
39
# ==========================================
# 3. STARTUP (Auth + Load Model)
# ==========================================

@app.on_event("startup")
async def startup():
    """Authenticate with Hugging Face, then load tokenizer and model.

    Populates the module-level `tokenizer` and `model`. Raises RuntimeError
    (aborting app startup) when HF_TOKEN is missing or any step fails.
    """
    global tokenizer, model

    # A. Authenticate using Environment Variable
    print("🔐 Checking for HF_TOKEN...")
    hf_token = os.getenv("HF_TOKEN")

    if not hf_token:
        print("❌ Error: HF_TOKEN environment variable is missing.")
        raise RuntimeError("HF_TOKEN environment variable is missing in Space Settings.")

    try:
        login(token=hf_token)
        print("✅ Authentication successful.")
    except Exception as e:
        print(f"❌ Authentication Failed: {e}")
        raise RuntimeError(f"Hugging Face login failed: {e}")

    # B. Load Model
    print(f"🧠 Loading Model: {MODEL_ID}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        # CPU-only, full-precision load — presumably a CPU Space; confirm if a
        # GPU/quantized path is ever needed.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="cpu",
            torch_dtype=torch.float32,
        )
        print("✅ Model Loaded Successfully.")
    except Exception as e:
        print(f"❌ Model Load Failed: {e}")
        raise RuntimeError(f"Model load failed: {e}")
75
 
76
# ==========================================
# 4. API ENDPOINT
# ==========================================

@app.post("/generate")
async def generate_function_call(request: ChatRequest):
    """Produce a function-call string for `request.query` given `request.tools`.

    Returns 503 until startup has loaded the model, 500 on any generation error.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not ready")

    try:
        # System context
        system_content = (
            "You are a model that can do function calling with the following functions."
        )
        if request.include_date:
            today = datetime.date.today().isoformat()
            system_content += f" Today is {today}."

        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": request.query},
        ]

        # Tokenize with the model's chat template; tools are injected so the
        # model sees the available function schemas.
        inputs = tokenizer.apply_chat_template(
            messages,
            tools=request.tools,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )

        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,  # deterministic
        )

        # Decode only the newly generated tokens (slice off the prompt).
        generated_text = tokenizer.decode(
            outputs[0][len(inputs["input_ids"][0]):],
            skip_special_tokens=True,
        )

        return {"response": generated_text}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
122
-
123
@app.get("/", response_model=HealthResponse)
def health_check():
    """Liveness probe: report service status, served model id and auth mode."""
    payload = {
        "status": "running",
        "model": MODEL_ID,
        "auth_status": "secure_env",
    }
    return payload
 
11
# ==========================================
# 1. APP SETUP
# ==========================================
app = FastAPI(title="FunctionGemma Brain API", version="1.0.0")

# Model to serve; tokenizer/model handles are populated by startup().
MODEL_ID = "google/functiongemma-270m-it"
tokenizer = None
model = None
19
 
20
# ==========================================
# 2. FEW-SHOT EXAMPLES (The Teacher)
# ==========================================
# A canned "previous conversation" demonstrating the exact tool-call syntax
# (and a refusal for off-topic questions) that is prepended to every request.
# Stored as (user, model) pairs, then flattened into the message list shape
# that apply_chat_template expects.
_FEW_SHOT_PAIRS = [
    # Counting / stats
    (
        {"role": "user", "content": "How many regions are there?"},
        {"role": "model", "content": "<start_function_call>call:get_aggregate_stats{target_entity:revenue_region}<end_function_call>"},
    ),
    # Specific search
    (
        {"role": "user", "content": "What is the water level in Aadale dam?"},
        {"role": "model", "content": "<start_function_call>call:search_specific_dam{dam_name:Aadale}<end_function_call>"},
    ),
    # Filtering
    (
        {"role": "user", "content": "Show me Major dams in Pune."},
        {"role": "model", "content": "<start_function_call>call:filter_dams{district:Pune,project_type:Major}<end_function_call>"},
    ),
    # Irrelevant question: teach the model to refuse instead of calling a tool.
    (
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "model", "content": "I cannot answer that as it is not related to the dam database."},
    ),
]
FEW_SHOT_MESSAGES = [message for pair in _FEW_SHOT_PAIRS for message in pair]
 
 
 
 
 
42
 
43
# ==========================================
# 3. STARTUP
# ==========================================
# NOTE(review): @app.on_event is deprecated in recent FastAPI versions;
# consider migrating to a lifespan context manager when convenient.
@app.on_event("startup")
async def startup():
    """Authenticate with Hugging Face, then load tokenizer + model on CPU.

    Populates the module-level `tokenizer` and `model`. Raises RuntimeError so
    the app fails fast (with a readable message) when HF_TOKEN is absent or
    authentication / model loading fails.
    """
    global tokenizer, model

    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("HF_TOKEN missing")

    # Wrap external calls so a failure surfaces as a clear startup error
    # instead of a raw traceback (matches the previous revision's behavior).
    try:
        login(token=hf_token)
    except Exception as e:
        raise RuntimeError(f"Hugging Face login failed: {e}") from e

    print(f"🧠 Loading {MODEL_ID}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cpu", torch_dtype=torch.float32
        )
    except Exception as e:
        raise RuntimeError(f"Model load failed: {e}") from e
    print("✅ Model Loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
# ==========================================
# 4. API ENDPOINT
# ==========================================
class ChatRequest(BaseModel):
    """Request payload for POST /generate."""
    # NOTE(review): the previous revision constrained this with
    # Field(..., min_length=1, max_length=4096); the validation was dropped
    # here — confirm whether unbounded/empty queries are intended.
    query: str
    # JSON tool/function schemas forwarded to apply_chat_template.
    tools: List[Dict[str, Any]]
    # When True, today's ISO date is appended to the system prompt.
    include_date: bool = True
65
 
66
@app.post("/generate")
async def generate_function_call(request: ChatRequest):
    """Generate a FunctionGemma tool-call (or refusal) for the user's query.

    Prompt layout: system -> few-shot examples -> current user query. Decoding
    is greedy (deterministic) and only the newly generated tokens are returned.
    Raises 503 while the model is still loading, 500 on any generation error.
    """
    # Check BOTH handles explicitly: tokenizer is assigned separately during
    # startup and can be None too; `is None` avoids relying on truthiness.
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model loading")

    try:
        # 1. System Prompt
        system_content = "You are a model that can do function calling with the following functions."
        if request.include_date:
            today = datetime.date.today().isoformat()
            system_content += f" Today is {today}."

        # 2. Construct History: System -> Examples -> Current User Query
        messages = [{"role": "system", "content": system_content}]
        messages.extend(FEW_SHOT_MESSAGES)
        messages.append({"role": "user", "content": request.query})

        # 3. Tokenize via the model's chat template, injecting tool schemas.
        inputs = tokenizer.apply_chat_template(
            messages,
            tools=request.tools,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )

        # 4. Generate (greedy), then decode only the tokens after the prompt.
        outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
        generated_text = tokenizer.decode(
            outputs[0][len(inputs["input_ids"][0]):],
            skip_special_tokens=True,
        )

        return {"response": generated_text}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))