youssefleb committed
Commit a4f27db · verified · 1 Parent(s): 8f7bac8

Update mcp_servers.py

Files changed (1):
  1. mcp_servers.py (+64 -41)
mcp_servers.py CHANGED
@@ -1,4 +1,4 @@
-# mcp_servers.py (FIXED: Prompt-Schema Alignment + Detailed Logging)
+# mcp_servers.py (FIXED: Schema Enforcement + Detailed Logging + Usage Tracking)
 import asyncio
 import json
 import re
@@ -13,7 +13,6 @@ from personas import PERSONAS_DATA
 EVALUATION_PROMPT_TEMPLATE = load_prompt(config.PROMPT_FILES["evaluator"])
 
 # --- DEFINING THE SCHEMA TO FORCE JUSTIFICATIONS ---
-# This forces the LLM to return a specific JSON structure with "score" and "justification"
 EVALUATION_SCHEMA = {
     "type": "OBJECT",
     "properties": {
@@ -87,16 +86,13 @@ class BusinessSolutionEvaluator:
         if "ERROR:" in EVALUATION_PROMPT_TEMPLATE:
             raise FileNotFoundError(EVALUATION_PROMPT_TEMPLATE)
 
-    async def evaluate(self, problem: str, solution_text: str) -> dict:
+    async def evaluate(self, problem: str, solution_text: str) -> Tuple[dict, dict]:
+        """Returns (evaluation_dict, usage_dict)"""
         print(f"Evaluating solution (live): {solution_text[:50]}...")
 
-        # 1. Base Prompt from the clean text file
         base_prompt = EVALUATION_PROMPT_TEMPLATE.format(problem=problem, solution_text=solution_text)
 
-        # 2. INJECT STRONG INSTRUCTION
-        # This prevents the model from regurgitating examples found in the prompt file
         schema_instruction = """
-
         [IMPORTANT SYSTEM INSTRUCTION]
         Ignore any previous examples of JSON formatting in this prompt.
         You MUST strictly follow the Output Schema provided below.
@@ -109,9 +105,9 @@ class BusinessSolutionEvaluator:
         """
 
         final_prompt = base_prompt + schema_instruction
+        usage = {"model": "Gemini", "input": 0, "output": 0}
 
         try:
-            # --- ENFORCE SCHEMA ---
             response = await self.gemini_model.generate_content_async(
                 final_prompt,
                 generation_config=genai.types.GenerationConfig(
@@ -120,14 +116,18 @@
                 )
             )
 
+            # Capture Usage
+            if hasattr(response, "usage_metadata"):
+                usage["input"] = response.usage_metadata.prompt_token_count
+                usage["output"] = response.usage_metadata.candidates_token_count
+
             v_fitness = extract_json(response.text)
 
-            # Strict Type Checking
             if not isinstance(v_fitness, (dict, list)):
                 raise ValueError(f"Judge returned invalid type: {type(v_fitness)}")
 
             print(f"Evaluation complete (live): {v_fitness}")
-            return v_fitness
+            return v_fitness, usage
 
         except Exception as e:
             print(f"ERROR: BusinessSolutionEvaluator failed: {e}")
@@ -137,7 +137,7 @@ class BusinessSolutionEvaluator:
                 "Flexibility": {"score": 1, "justification": f"Error: {str(e)}"},
                 "Elaboration": {"score": 1, "justification": f"Error: {str(e)}"},
                 "Cultural_Appropriateness": {"score": 1, "justification": f"Error: {str(e)}"}
-            }
+            }, usage
 
 class AgentCalibrator:
     def __init__(self, api_clients: dict, evaluator: BusinessSolutionEvaluator):
@@ -146,11 +146,11 @@ class AgentCalibrator:
         self.sponsor_llms = list(self.api_clients.keys())
         print(f"AgentCalibrator initialized with enabled clients: {self.sponsor_llms}")
 
-    # --- UPDATED: Return detailed results for logging ---
-    async def calibrate_team(self, problem: str) -> Tuple[Dict[str, Any], List[str], List[Dict[str, Any]]]:
+    async def calibrate_team(self, problem: str) -> Tuple[Dict[str, Any], List[str], List[Dict[str, Any]], List[Dict[str, Any]]]:
         print(f"Running LIVE calibration test for specialist team on {self.sponsor_llms}...")
         error_log = []
-        detailed_results = [] # To capture full calibration data
+        detailed_results = []
+        all_usage_stats = [] # Collect all usage data here
 
         if not self.sponsor_llms:
             raise Exception("AgentCalibrator cannot run: No LLM clients are configured.")
@@ -163,7 +163,7 @@
                 "Implementer": {"persona": config.CALIBRATION_CONFIG["roles_to_test"]["Implementer"], "llm": default_llm},
                 "Monitor": {"persona": config.CALIBRATION_CONFIG["roles_to_test"]["Monitor"], "llm": default_llm}
             }
-            return plan, error_log, []
+            return plan, error_log, [], []
 
         roles_to_test = {
             role: PERSONAS_DATA[key]["description"]
@@ -177,7 +177,12 @@
             tasks.append(self.run_calibration_test(problem, role, llm_name, persona, test_problem))
 
         results = await asyncio.gather(*tasks)
-        detailed_results = results # Store the full results here
+        detailed_results = results
+
+        # Flatten results to extract usage
+        for res in results:
+            if "usage_gen" in res: all_usage_stats.append(res["usage_gen"])
+            if "usage_eval" in res: all_usage_stats.append(res["usage_eval"])
 
         best_llms = {}
         role_metrics = config.CALIBRATION_CONFIG["role_metrics"]
@@ -194,20 +199,12 @@
 
                 # Robust Dict Access
                 raw_score_data = res.get("score", {})
-
-                if not isinstance(raw_score_data, (dict, list)):
-                    raw_score_data = {}
-
-                if isinstance(raw_score_data, list):
-                    raw_score_data = raw_score_data[0] if len(raw_score_data) > 0 else {}
+                if not isinstance(raw_score_data, (dict, list)): raw_score_data = {}
+                if isinstance(raw_score_data, list): raw_score_data = raw_score_data[0] if len(raw_score_data) > 0 else {}
 
                 metric_data = raw_score_data.get(metric, {})
-
-                if not isinstance(metric_data, (dict, list)):
-                    metric_data = {}
-
-                if isinstance(metric_data, list):
-                    metric_data = metric_data[0] if len(metric_data) > 0 else {}
+                if not isinstance(metric_data, (dict, list)): metric_data = {}
+                if isinstance(metric_data, list): metric_data = metric_data[0] if len(metric_data) > 0 else {}
 
                 score = metric_data.get("score", 0)
 
@@ -222,24 +219,34 @@
             "Monitor": {"persona": config.CALIBRATION_CONFIG["roles_to_test"]["Monitor"], "llm": best_llms["Monitor"]}
         }
         print(f"Calibration complete (live). Team plan: {team_plan}")
-
-        # Return 3 values: The plan, errors, and the full trace
-        return team_plan, error_log, detailed_results
+        return team_plan, error_log, detailed_results, all_usage_stats
 
     async def run_calibration_test(self, problem, role, llm_name, persona, test_problem):
         print(f"...Calibrating {role} on {llm_name}...")
         client = self.api_clients[llm_name]
-        solution = await get_llm_response(llm_name, client, persona, test_problem)
+
+        # 1. Generate Solution (and get usage)
+        solution, gen_usage = await get_llm_response(llm_name, client, persona, test_problem)
 
         if "Error generating response" in solution:
-            return {"role": role, "llm": llm_name, "error": solution, "output": solution}
+            return {"role": role, "llm": llm_name, "error": solution, "output": solution, "usage_gen": gen_usage}
+
+        # 2. Evaluate Solution (and get usage)
+        score, eval_usage = await self.evaluator.evaluate(problem, solution)
 
-        score = await self.evaluator.evaluate(problem, solution)
-        return {"role": role, "llm": llm_name, "score": score, "output": solution}
+        return {
+            "role": role,
+            "llm": llm_name,
+            "score": score,
+            "output": solution,
+            "usage_gen": gen_usage,
+            "usage_eval": eval_usage
+        }
 
 # --- Unified API Call Function ---
-async def get_llm_response(client_name: str, client, system_prompt: str, user_prompt: str) -> str:
-    """A single function to handle calling any of the three sponsor LLMs."""
+async def get_llm_response(client_name: str, client, system_prompt: str, user_prompt: str) -> Tuple[str, dict]:
+    """Returns (text_response, usage_dict)"""
+    usage = {"model": client_name, "input": 0, "output": 0}
    try:
        if client_name == "Gemini":
            model = client
@@ -249,7 +256,13 @@ async def get_llm_response(client_name: str, client, system_prompt: str, user_pr
                {'role': 'user', 'parts': [user_prompt]}
            ]
            response = await model.generate_content_async(full_prompt)
-            return response.text
+
+            # Capture Gemini Usage
+            if hasattr(response, "usage_metadata"):
+                usage["input"] = response.usage_metadata.prompt_token_count
+                usage["output"] = response.usage_metadata.candidates_token_count
+
+            return response.text, usage
 
        elif client_name == "Anthropic":
            response = await client.messages.create(
@@ -258,7 +271,12 @@ async def get_llm_response(client_name: str, client, system_prompt: str, user_pr
                system=system_prompt,
                messages=[{"role": "user", "content": user_prompt}]
            )
-            return response.content[0].text
+            # Capture Anthropic Usage
+            if hasattr(response, "usage"):
+                usage["input"] = response.usage.input_tokens
+                usage["output"] = response.usage.output_tokens
+
+            return response.content[0].text, usage
 
        elif client_name == "SambaNova":
            completion = await client.chat.completions.create(
@@ -268,9 +286,14 @@ async def get_llm_response(client_name: str, client, system_prompt: str, user_pr
                    {"role": "user", "content": user_prompt}
                ]
            )
-            return completion.choices[0].message.content
+            # Capture SambaNova Usage
+            if hasattr(completion, "usage"):
+                usage["input"] = completion.usage.prompt_tokens
+                usage["output"] = completion.usage.completion_tokens
+
+            return completion.choices[0].message.content, usage
 
    except Exception as e:
        error_message = f"Error generating response from {client_name}: {str(e)}"
        print(f"ERROR: API call to {client_name} failed: {e}")
-        return error_message
+        return error_message, usage
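
Because this commit changes several return signatures, callers need small updates. Below is a minimal sketch of how a caller might unpack the new four-value return from calibrate_team and total the usage records it collects; the function name run_calibration_with_costs, the example problem string, and the surrounding wiring are illustrative assumptions, not part of this commit.

# Illustrative sketch (not part of this commit): aggregate the usage records
# that calibrate_team now returns into per-model token totals.
import asyncio

async def run_calibration_with_costs(calibrator, problem: str):
    # calibrate_team now returns (team_plan, error_log, detailed_results, all_usage_stats)
    team_plan, error_log, detailed_results, all_usage_stats = await calibrator.calibrate_team(problem)

    # Each usage record looks like {"model": "Gemini", "input": 123, "output": 456}
    totals = {}
    for rec in all_usage_stats:
        model = rec.get("model", "unknown")
        bucket = totals.setdefault(model, {"input": 0, "output": 0})
        bucket["input"] += rec.get("input", 0)
        bucket["output"] += rec.get("output", 0)

    print(f"Token usage per model: {totals}")
    return team_plan, totals

# Example (hypothetical): asyncio.run(run_calibration_with_costs(calibrator, "How can we cut onboarding time?"))

Since evaluate() and get_llm_response() now return (result, usage) tuples as well, any other call sites that consumed their bare return values need the same two-value unpacking.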