youssefleb committed
Commit 2960fc5 · verified · 1 Parent(s): 7daf486

Update agent_logic.py

Files changed (1)
  1. agent_logic.py +116 -96
agent_logic.py CHANGED
@@ -1,4 +1,4 @@
- # agent_logic.py (Now with error logging)
  import asyncio
  from typing import AsyncGenerator, Dict, Optional
  import json
@@ -10,83 +10,60 @@ from personas import PERSONAS_DATA
  import config
  from utils import load_prompt
  from mcp_servers import AgentCalibrator, BusinessSolutionEvaluator, get_llm_response

  CLASSIFIER_SYSTEM_PROMPT = load_prompt(config.PROMPT_FILES["classifier"])
  HOMOGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_homogeneous"])
  HETEROGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_heterogeneous"])

- # --- (Specialist Agent classes are unchanged) ---
  class Baseline_Single_Agent:
      def __init__(self, api_clients: dict):
          self.gemini_client = api_clients.get("Gemini")
-
      async def solve(self, problem: str, persona_prompt: str):
-         print(f"--- (Specialist Team: Single Agent) solving (live)... ---")
-         if not self.gemini_client:
-             raise ValueError("Single_Agent requires a Google/Gemini client.")
          return await get_llm_response("Gemini", self.gemini_client, persona_prompt, problem)

  class Baseline_Static_Homogeneous:
      def __init__(self, api_clients: dict):
          self.api_clients = {name: client for name, client in api_clients.items() if client}
          self.gemini_client = api_clients.get("Gemini")
-
      async def solve(self, problem: str, persona_prompt: str):
-         print(f"--- (Specialist Team: Homogeneous) solving (live)... ---")
-         if not self.gemini_client:
-             raise ValueError("Homogeneous_Team requires a Google/Gemini client for its manager.")
-
          system_prompt = persona_prompt
          user_prompt = f"As an expert Implementer, generate a detailed plan for this problem: {problem}"
-
-         tasks = []
-         for llm_name, client in self.api_clients.items():
-             tasks.append(get_llm_response(llm_name, client, system_prompt, user_prompt))
-
          responses = await asyncio.gather(*tasks)
-
          manager_system_prompt = HOMOGENEOUS_MANAGER_PROMPT
          reports_str = "\n\n".join(f"Report from Team Member {i+1}:\n{resp}" for i, resp in enumerate(responses))
          manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these reports into one final, comprehensive solution."
-
          return await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)

  class Baseline_Static_Heterogeneous:
      def __init__(self, api_clients: dict):
          self.api_clients = api_clients
          self.gemini_client = api_clients.get("Gemini")
-
      async def solve(self, problem: str, team_plan: dict):
-         print(f"--- (Specialist Team: Heterogeneous) solving (live)... ---")
-         if not self.gemini_client:
-             raise ValueError("Heterogeneous_Team requires a Google/Gemini client for its manager.")
-
          tasks = []
          for role, config_data in team_plan.items():
              llm_name = config_data["llm"]
              persona_key = config_data["persona"]
              client = self.api_clients.get(llm_name)
-
              if not client:
-                 print(f"Warning: Calibrated LLM '{llm_name}' for role '{role}' is not available. Defaulting to Gemini.")
                  llm_name = "Gemini"
                  client = self.gemini_client
-
              system_prompt = PERSONAS_DATA[persona_key]["description"]
              user_prompt = f"As the team's '{role}', provide your unique perspective on how to solve this problem: {problem}"
              tasks.append(get_llm_response(llm_name, client, system_prompt, user_prompt))
-
          responses = await asyncio.gather(*tasks)
-
          manager_system_prompt = HETEROGENEOUS_MANAGER_PROMPT
          reports_str = "\n\n".join(f"Report from {team_plan[role]['llm']} (as {role}):\n{resp}" for (role, resp) in zip(team_plan.keys(), responses))
          manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these specialist reports into one final, comprehensive solution."
-
          return await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)

  class StrategicSelectorAgent:
-     """This is MudabbirAI. It gets keys passed to it on creation."""
-
      def __init__(self, api_keys: Dict[str, Optional[str]]):
          self.api_keys = api_keys
          self.api_clients = { "Gemini": None, "Anthropic": None, "SambaNova": None }
@@ -95,46 +72,114 @@ class StrategicSelectorAgent:
          try:
              genai.configure(api_key=api_keys["google"])
              self.api_clients["Gemini"] = genai.GenerativeModel(config.MODELS["Gemini"]["default"])
-         except Exception as e:
-             print(f"Warning: Failed to initialize Gemini client. Error: {e}")
-
          if api_keys.get("anthropic"):
              try:
                  self.api_clients["Anthropic"] = AsyncAnthropic(api_key=api_keys["anthropic"])
-             except Exception as e:
-                 print(f"Warning: Failed to initialize Anthropic client. Error: {e}")
-
          if api_keys.get("sambanova"):
              try:
                  base_url = os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1")
                  self.api_clients["SambaNova"] = AsyncOpenAI(api_key=api_keys["sambanova"], base_url=base_url)
-             except Exception as e:
-                 print(f"Warning: Failed to initialize SambaNova client. Error: {e}")

          if not self.api_clients["Gemini"]:
-             raise ValueError("Google API Key is required or invalid. The agent cannot function without its 'Judge'.")

          self.evaluator = BusinessSolutionEvaluator(self.api_clients["Gemini"])
          self.calibrator = AgentCalibrator(self.api_clients, self.evaluator)
-
          self.single_agent = Baseline_Single_Agent(self.api_clients)
          self.homo_team = Baseline_Static_Homogeneous(self.api_clients)
          self.hetero_team = Baseline_Static_Heterogeneous(self.api_clients)

-         if "ERROR:" in CLASSIFIER_SYSTEM_PROMPT:
-             raise FileNotFoundError(CLASSIFIER_SYSTEM_PROMPT)

      async def _classify_problem(self, problem: str) -> AsyncGenerator[str, None]:
          yield "Classifying problem archetype (live)..."
-         classification = await get_llm_response(
-             "Gemini",
-             self.api_clients["Gemini"],
-             CLASSIFIER_SYSTEM_PROMPT,
-             problem
-         )
          classification = classification.strip().replace("\"", "")
          yield f"Diagnosis: {classification}"

      async def solve(self, problem: str) -> AsyncGenerator[str, None]:
          classification_generator = self._classify_problem(problem)
          classification = ""
@@ -147,60 +192,35 @@ class StrategicSelectorAgent:
              yield "Classifier failed. Defaulting to Single Agent."
              classification = "Direct_Procedure"

-         solution_draft = ""

          try:
-             default_persona = PERSONAS_DATA[config.DEFAULT_PERSONA_KEY]["description"]
-
-             if classification == "Direct_Procedure" or classification == "Holistic_Abstract_Reasoning":
-                 yield "Deploying: Baseline Single Agent (Simplicity Hypothesis)..."
-                 solution_draft = await self.single_agent.solve(problem, default_persona)
-
-             elif classification == "Local_Geometric_Procedural":
-                 yield "Deploying: Static Homogeneous Team (Expert Anomaly)..."
-                 solution_draft = await self.homo_team.solve(problem, default_persona)

-             elif classification == "Cognitive_Labyrinth":
-                 yield "Deploying: Static Heterogeneous Team (Cognitive Diversity)..."

-                 # --- NEW: Capture errors from calibration ---
-                 team_plan, calibration_errors = await self.calibrator.calibrate_team(problem)

-                 # --- NEW: Yield any calibration errors ---
-                 if calibration_errors:
-                     yield "--- CALIBRATION WARNINGS ---"
-                     for err in calibration_errors:
-                         yield err
-                     yield "-----------------------------"
-
-                 yield f"Calibration complete. Best Team: {json.dumps({k: v['llm'] for k, v in team_plan.items()})}"
-                 solution_draft = await self.hetero_team.solve(problem, team_plan)
-
-             else:
-                 yield f"Diagnosis '{classification}' is unknown. Defaulting to Single Agent."
-                 solution_draft = await self.single_agent.solve(problem, default_persona)
-
-             if "Error generating response" in solution_draft:
-                 raise Exception(f"The specialist team failed to generate a solution. Error: {solution_draft}")
-
-             yield f"Draft solution received: '{solution_draft[:60]}...'"
-
-             yield "Evaluating final draft (live)..."
-             v_fitness_json = await self.evaluator.evaluate(problem, solution_draft)
-
-             scores = {k: v.get('score', 0) for k, v in v_fitness_json.items()}
-             yield f"Initial Score: {scores}"
-             # If the score is the default '1', show the error message hidden in the justification
-             if scores.get('Novelty', 0) == 1:
-                 yield f"⚠️ Low Score Detected. Reason: {v_fitness_json.get('Novelty', {}).get('justification', 'Unknown')}"
-             # -----------------------
-
-             # --- This is where Milestone 5 will go ---
-             yield "Skipping self-correction for now..."

              await asyncio.sleep(0.5)
-             yield "Milestone 4 (with error logging) Complete."
-
              solution_draft_json_safe = json.dumps(solution_draft)
              yield f"FINAL: {{\"text\": {solution_draft_json_safe}, \"audio\": null}}"

+ # agent_logic.py (Milestone 5 - FINAL & ROBUST)
  import asyncio
  from typing import AsyncGenerator, Dict, Optional
  import json

  import config
  from utils import load_prompt
  from mcp_servers import AgentCalibrator, BusinessSolutionEvaluator, get_llm_response
+ from self_correction import SelfCorrector
+ from async_generator import async_generator, yield_

  CLASSIFIER_SYSTEM_PROMPT = load_prompt(config.PROMPT_FILES["classifier"])
  HOMOGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_homogeneous"])
  HETEROGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_heterogeneous"])

+ # (Baseline Agent Classes - UNCHANGED)
  class Baseline_Single_Agent:
      def __init__(self, api_clients: dict):
          self.gemini_client = api_clients.get("Gemini")
      async def solve(self, problem: str, persona_prompt: str):
+         if not self.gemini_client: raise ValueError("Single_Agent requires a Google/Gemini client.")
          return await get_llm_response("Gemini", self.gemini_client, persona_prompt, problem)

  class Baseline_Static_Homogeneous:
      def __init__(self, api_clients: dict):
          self.api_clients = {name: client for name, client in api_clients.items() if client}
          self.gemini_client = api_clients.get("Gemini")
      async def solve(self, problem: str, persona_prompt: str):
+         if not self.gemini_client: raise ValueError("Homogeneous_Team requires a Google/Gemini client.")
          system_prompt = persona_prompt
          user_prompt = f"As an expert Implementer, generate a detailed plan for this problem: {problem}"
+         tasks = [get_llm_response(llm, client, system_prompt, user_prompt) for llm, client in self.api_clients.items()]
          responses = await asyncio.gather(*tasks)
          manager_system_prompt = HOMOGENEOUS_MANAGER_PROMPT
          reports_str = "\n\n".join(f"Report from Team Member {i+1}:\n{resp}" for i, resp in enumerate(responses))
          manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these reports into one final, comprehensive solution."
          return await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)

  class Baseline_Static_Heterogeneous:
      def __init__(self, api_clients: dict):
          self.api_clients = api_clients
          self.gemini_client = api_clients.get("Gemini")
      async def solve(self, problem: str, team_plan: dict):
+         if not self.gemini_client: raise ValueError("Heterogeneous_Team requires a Google/Gemini client.")
          tasks = []
          for role, config_data in team_plan.items():
              llm_name = config_data["llm"]
              persona_key = config_data["persona"]
              client = self.api_clients.get(llm_name)
              if not client:
                  llm_name = "Gemini"
                  client = self.gemini_client
              system_prompt = PERSONAS_DATA[persona_key]["description"]
              user_prompt = f"As the team's '{role}', provide your unique perspective on how to solve this problem: {problem}"
              tasks.append(get_llm_response(llm_name, client, system_prompt, user_prompt))
          responses = await asyncio.gather(*tasks)
          manager_system_prompt = HETEROGENEOUS_MANAGER_PROMPT
          reports_str = "\n\n".join(f"Report from {team_plan[role]['llm']} (as {role}):\n{resp}" for (role, resp) in zip(team_plan.keys(), responses))
          manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these specialist reports into one final, comprehensive solution."
          return await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)

  class StrategicSelectorAgent:
      def __init__(self, api_keys: Dict[str, Optional[str]]):
          self.api_keys = api_keys
          self.api_clients = { "Gemini": None, "Anthropic": None, "SambaNova": None }

          try:
              genai.configure(api_key=api_keys["google"])
              self.api_clients["Gemini"] = genai.GenerativeModel(config.MODELS["Gemini"]["default"])
+         except Exception as e: print(f"Warning: Gemini init failed: {e}")
          if api_keys.get("anthropic"):
              try:
                  self.api_clients["Anthropic"] = AsyncAnthropic(api_key=api_keys["anthropic"])
+             except Exception as e: print(f"Warning: Anthropic init failed: {e}")
          if api_keys.get("sambanova"):
              try:
                  base_url = os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1")
                  self.api_clients["SambaNova"] = AsyncOpenAI(api_key=api_keys["sambanova"], base_url=base_url)
+             except Exception as e: print(f"Warning: SambaNova init failed: {e}")

          if not self.api_clients["Gemini"]:
+             raise ValueError("Google API Key is required.")

          self.evaluator = BusinessSolutionEvaluator(self.api_clients["Gemini"])
          self.calibrator = AgentCalibrator(self.api_clients, self.evaluator)
+         self.corrector = SelfCorrector(threshold=3.0)
          self.single_agent = Baseline_Single_Agent(self.api_clients)
          self.homo_team = Baseline_Static_Homogeneous(self.api_clients)
          self.hetero_team = Baseline_Static_Heterogeneous(self.api_clients)
+         self.current_team_plan = None

+         if "ERROR:" in CLASSIFIER_SYSTEM_PROMPT: raise FileNotFoundError(CLASSIFIER_SYSTEM_PROMPT)

      async def _classify_problem(self, problem: str) -> AsyncGenerator[str, None]:
          yield "Classifying problem archetype (live)..."
+         classification = await get_llm_response("Gemini", self.api_clients["Gemini"], CLASSIFIER_SYSTEM_PROMPT, problem)
          classification = classification.strip().replace("\"", "")
          yield f"Diagnosis: {classification}"

+     @async_generator
+     async def _generate_and_evaluate(self, problem: str, classification: str, correction_prompt: Optional[str] = None):
+         solution_draft = ""
+         team_plan = {}
+
+         if correction_prompt:
+             problem = f"{problem}\n\n{correction_prompt}"
+
+         default_persona = PERSONAS_DATA[config.DEFAULT_PERSONA_KEY]["description"]
+
+         if classification == "Direct_Procedure" or classification == "Holistic_Abstract_Reasoning":
+             if not correction_prompt:
+                 await yield_("Deploying: Baseline Single Agent (Simplicity Hypothesis)...")
+             solution_draft = await self.single_agent.solve(problem, default_persona)
+
+         elif classification == "Local_Geometric_Procedural":
+             if not correction_prompt:
+                 await yield_("Deploying: Static Homogeneous Team (Expert Anomaly)...")
+             solution_draft = await self.homo_team.solve(problem, default_persona)
+
+         elif classification == "Cognitive_Labyrinth":
+             if not correction_prompt:
+                 await yield_("Deploying: Static Heterogeneous Team (Cognitive Diversity)...")
+                 team_plan, calibration_errors = await self.calibrator.calibrate_team(problem)
+                 if calibration_errors:
+                     await yield_("--- CALIBRATION WARNINGS ---")
+                     for err in calibration_errors: await yield_(err)
+                     await yield_("-----------------------------")
+                 await yield_(f"Calibration complete. Best Team: {json.dumps({k: v['llm'] for k, v in team_plan.items()})}")
+                 self.current_team_plan = team_plan
+
+             # Reuse the calibrated team
+             solution_draft = await self.hetero_team.solve(problem, self.current_team_plan)
+
+         else:
+             if not correction_prompt:
+                 await yield_(f"Diagnosis '{classification}' is unknown. Defaulting to Single Agent.")
+             solution_draft = await self.single_agent.solve(problem, default_persona)
+
+         if "Error generating response" in solution_draft:
+             raise Exception(f"The specialist team failed to generate a solution. Error: {solution_draft}")
+
+         await yield_(f"Draft solution received: '{solution_draft[:60]}...'")
+
+         # --- EVALUATE ---
+         await yield_("Evaluating final draft (live)...")
+         v_fitness_json = await self.evaluator.evaluate(problem, solution_draft)
+
+         # --- NEW: Robust Normalization of Evaluation Data ---
+         # This block fixes the "list object has no attribute get" error
+         normalized_fitness = {}
+         if isinstance(v_fitness_json, dict):
+             for k, v in v_fitness_json.items():
+                 if isinstance(v, dict):
+                     normalized_fitness[k] = v
+                 elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
+                     # If the LLM wrapped the object in a list, unwrap it
+                     normalized_fitness[k] = v[0]
+                 else:
+                     # Fallback for unexpected structure
+                     normalized_fitness[k] = {'score': 0, 'justification': str(v)}
+         else:
+             # Fallback if the whole thing isn't a dict
+             await yield_(f"Warning: Invalid JSON structure from Judge: {type(v_fitness_json)}")
+             normalized_fitness = {k: {'score': 0, 'justification': "Invalid JSON structure"} for k in ["Novelty", "Usefulness_Feasibility", "Flexibility", "Elaboration", "Cultural_Appropriateness"]}
+
+         v_fitness_json = normalized_fitness
+         # ----------------------------------------------------
+
+         scores = {k: v.get('score', 0) for k, v in v_fitness_json.items()}
+         await yield_(f"Evaluation Score: {scores}")
+
+         # Debug info if score is low
+         if scores.get('Novelty', 0) <= 1:
+             await yield_(f"⚠️ Low Score Detected. Reason: {v_fitness_json.get('Novelty', {}).get('justification', 'Unknown')}")
+
+         return solution_draft, v_fitness_json, scores
+
      async def solve(self, problem: str) -> AsyncGenerator[str, None]:
          classification_generator = self._classify_problem(problem)
          classification = ""

              yield "Classifier failed. Defaulting to Single Agent."
              classification = "Direct_Procedure"

+         solution_draft, v_fitness_json, scores = "", {}, {}

          try:
+             # --- MAIN LOOP (Self-Correction) ---
+             for i in range(2):
+                 current_problem = problem
+                 if i > 0:
+                     yield f"--- (Loop {i}) Score is too low. Initiating Self-Correction... ---"
+                     correction_prompt_text = self.corrector.get_correction_plan(v_fitness_json)
+                     yield f"Diagnosis: {correction_prompt_text.splitlines()[3].strip()}"
+                     current_problem = f"{problem}\n\n{correction_prompt_text}"

+                 loop_generator = self._generate_and_evaluate(current_problem, classification, None if i==0 else "Correcting...")

+                 async for status_update in loop_generator:
+                     yield status_update

+                 solution_draft, v_fitness_json, scores = await loop_generator.aclose()  # Wait for return
+
+                 # Check if we passed
+                 if self.corrector.is_good_enough(scores):
+                     yield "--- Solution approved by self-corrector. ---"
+                     break
+                 elif i == 1:
+                     yield "--- Max correction loops reached. Accepting best effort. ---"

+             # --- FINALIZE ---
              await asyncio.sleep(0.5)
+             yield "Milestone 5 Complete. Self-Correction loop is live."
              solution_draft_json_safe = json.dumps(solution_draft)
              yield f"FINAL: {{\"text\": {solution_draft_json_safe}, \"audio\": null}}"
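
For reference, a minimal sketch of how the streaming solve() interface from this commit could be driven from a small script. It is an illustration only, not part of the commit: the environment variable names (GOOGLE_API_KEY, ANTHROPIC_API_KEY, SAMBANOVA_API_KEY), the sample problem string, and the entry-point script itself are assumptions, and it presumes the repository's other modules (config, personas, utils, mcp_servers, self_correction) are importable.

# Hypothetical driver script; assumes agent_logic.py and its sibling modules are on the path
# and that the environment variable names below are stand-ins for however the deployment
# actually supplies credentials.
import asyncio
import json
import os

from agent_logic import StrategicSelectorAgent


async def main() -> None:
    agent = StrategicSelectorAgent(api_keys={
        "google": os.getenv("GOOGLE_API_KEY"),        # required: the Gemini "Judge"
        "anthropic": os.getenv("ANTHROPIC_API_KEY"),  # optional team member
        "sambanova": os.getenv("SAMBANOVA_API_KEY"),  # optional team member
    })

    final_payload = None
    # solve() is an async generator: it streams human-readable status strings and ends
    # with one message prefixed "FINAL: " that carries a JSON payload with the text.
    async for update in agent.solve("Draft a market-entry plan for a new product."):
        if update.startswith("FINAL: "):
            final_payload = json.loads(update[len("FINAL: "):])
        else:
            print(update)

    if final_payload is not None:
        print(final_payload["text"])


if __name__ == "__main__":
    asyncio.run(main())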