# agent_logic.py (Milestone 5 - FINAL & ROBUST + LOGGING + NATURAL TEXT + ALLOWLIST FILTER)
import asyncio
from typing import AsyncGenerator, Dict, Optional
import json
import os
import uuid
import datetime
import google.generativeai as genai
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
import re
from personas import PERSONAS_DATA
import config
from utils import load_prompt
# Removed extract_json_str as we no longer need to parse the solution
from mcp_servers import AgentCalibrator, BusinessSolutionEvaluator, get_llm_response
from self_correction import SelfCorrector
CLASSIFIER_SYSTEM_PROMPT = load_prompt(config.PROMPT_FILES["classifier"])
HOMOGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_homogeneous"])
HETEROGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_heterogeneous"])
# --- METRIC BOUNCER (Allowlist) ---
# We map any variation of the key to the canonical internal name.
# If a key isn't in here, it gets dropped.
METRIC_MAPPING = {
"novelty": "Novelty",
"usefulness": "Usefulness_Feasibility",
"feasibility": "Usefulness_Feasibility",
"usefulness_feasibility": "Usefulness_Feasibility",
"usefulness/feasibility": "Usefulness_Feasibility",
"flexibility": "Flexibility",
"elaboration": "Elaboration",
"cultural_appropriateness": "Cultural_Appropriateness",
"cultural_sensitivity": "Cultural_Appropriateness",
"cultural appropriateness": "Cultural_Appropriateness",
"cultural appropriateness/sensitivity": "Cultural_Appropriateness"
}
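# Example: an evaluator payload like {"usefulness/feasibility": {"score": "4/5"}}
# normalizes to {"Usefulness_Feasibility": {"score": 4, ...}}, while an unmapped
# key such as "overall_quality" is silently dropped.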
class Baseline_Single_Agent:
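    """Baseline: a single Gemini agent answers the problem directly from one persona prompt."""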
def __init__(self, api_clients: dict):
self.gemini_client = api_clients.get("Gemini")
async def solve(self, problem: str, persona_prompt: str):
if not self.gemini_client: raise ValueError("Single_Agent requires a Google/Gemini client.")
return await get_llm_response("Gemini", self.gemini_client, persona_prompt, problem)
class Baseline_Static_Homogeneous:
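    """Baseline: every available LLM answers the same 'Implementer' prompt in parallel,
    then a Gemini manager synthesizes the reports into one solution."""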
def __init__(self, api_clients: dict):
self.api_clients = {name: client for name, client in api_clients.items() if client}
self.gemini_client = api_clients.get("Gemini")
async def solve(self, problem: str, persona_prompt: str):
if not self.gemini_client: raise ValueError("Homogeneous_Team requires a Google/Gemini client.")
system_prompt = persona_prompt
user_prompt = f"As an expert Implementer, generate a detailed plan for this problem: {problem}"
tasks = [get_llm_response(llm, client, system_prompt, user_prompt) for llm, client in self.api_clients.items()]
responses = await asyncio.gather(*tasks)
manager_system_prompt = HOMOGENEOUS_MANAGER_PROMPT
reports_str = "\n\n".join(f"Report from Team Member {i+1}:\n{resp}" for i, resp in enumerate(responses))
manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these reports into one final, comprehensive solution."
return await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)
class Baseline_Static_Heterogeneous:
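    """Baseline: each role in the team plan gets its own LLM and persona (falling back
    to Gemini when a client is unavailable), then a Gemini manager synthesizes the reports."""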
def __init__(self, api_clients: dict):
self.api_clients = api_clients
self.gemini_client = api_clients.get("Gemini")
async def solve(self, problem: str, team_plan: dict):
if not self.gemini_client: raise ValueError("Heterogeneous_Team requires a Google/Gemini client.")
tasks = []
for role, config_data in team_plan.items():
llm_name = config_data["llm"]
persona_key = config_data["persona"]
client = self.api_clients.get(llm_name)
if not client:
llm_name = "Gemini"
client = self.gemini_client
system_prompt = PERSONAS_DATA[persona_key]["description"]
user_prompt = f"As the team's '{role}', provide your unique perspective on how to solve this problem: {problem}"
tasks.append(get_llm_response(llm_name, client, system_prompt, user_prompt))
responses = await asyncio.gather(*tasks)
manager_system_prompt = HETEROGENEOUS_MANAGER_PROMPT
reports_str = "\n\n".join(f"Report from {team_plan[role]['llm']} (as {role}):\n{resp}" for (role, resp) in zip(team_plan.keys(), responses))
manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these specialist reports into one final, comprehensive solution."
return await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)
class StrategicSelectorAgent:
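    """Top-level orchestrator: classifies the problem archetype, deploys the matching
    strategy (single agent, homogeneous team, or calibrated heterogeneous team),
    evaluates the draft, and runs at most one self-correction pass."""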
def __init__(self, api_keys: Dict[str, Optional[str]]):
self.api_keys = api_keys
self.api_clients = { "Gemini": None, "Anthropic": None, "SambaNova": None }
if api_keys.get("google"):
try:
genai.configure(api_key=api_keys["google"])
self.api_clients["Gemini"] = genai.GenerativeModel(config.MODELS["Gemini"]["default"])
except Exception as e: print(f"Warning: Gemini init failed: {e}")
if api_keys.get("anthropic"):
try:
self.api_clients["Anthropic"] = AsyncAnthropic(api_key=api_keys["anthropic"])
except Exception as e: print(f"Warning: Anthropic init failed: {e}")
if api_keys.get("sambanova"):
try:
self.api_clients["SambaNova"] = AsyncOpenAI(api_key=api_keys["sambanova"], base_url=os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1"))
except Exception as e: print(f"Warning: SambaNova init failed: {e}")
if not self.api_clients["Gemini"]:
raise ValueError("Google API Key is required.")
self.evaluator = BusinessSolutionEvaluator(self.api_clients["Gemini"])
self.calibrator = AgentCalibrator(self.api_clients, self.evaluator)
self.corrector = SelfCorrector(threshold=3.0)
self.single_agent = Baseline_Single_Agent(self.api_clients)
self.homo_team = Baseline_Static_Homogeneous(self.api_clients)
self.hetero_team = Baseline_Static_Heterogeneous(self.api_clients)
self.current_team_plan = None
if "ERROR:" in CLASSIFIER_SYSTEM_PROMPT: raise FileNotFoundError(CLASSIFIER_SYSTEM_PROMPT)
async def _classify_problem(self, problem: str) -> AsyncGenerator[str, None]:
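        """Yield status updates while the Gemini classifier labels the problem archetype."""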
yield "Classifying problem archetype (live)..."
classification = await get_llm_response("Gemini", self.api_clients["Gemini"], CLASSIFIER_SYSTEM_PROMPT, problem)
classification = classification.strip().replace("\"", "")
yield f"Diagnosis: {classification}"
async def solve(self, problem: str) -> AsyncGenerator[str, None]:
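        """Main entry point: stream status updates and end with a single 'FINAL: {...}' JSON line."""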
# --- 1. Initialize Logging ---
run_id = str(uuid.uuid4())[:8]
debug_log = {
"run_id": run_id,
"timestamp": datetime.datetime.now().isoformat(),
"problem": problem,
"classification": "",
"trace": []
}
try:
classification_generator = self._classify_problem(problem)
classification = ""
async for status_update in classification_generator:
yield status_update
if "Diagnosis: " in status_update:
classification = status_update.split(": ")[-1]
debug_log["classification"] = classification
if "Error generating response" in classification:
yield "Classifier failed. Defaulting to Single Agent."
classification = "Direct_Procedure"
solution_draft = ""
v_fitness_json = {}
scores = {}
            # --- MAIN LOOP: one initial attempt plus at most one self-correction retry ---
            for i in range(2):
current_problem = problem
if i > 0:
yield f"--- (Loop {i}) Score is too low. Initiating Self-Correction... ---"
correction_prompt_text = self.corrector.get_correction_plan(v_fitness_json)
yield f"Diagnosis: {correction_prompt_text.splitlines()[3].strip()}"
current_problem = f"{problem}\n\n{correction_prompt_text}"
debug_log["trace"].append({
"step_type": "correction_plan",
"loop_index": i,
"prompt": correction_prompt_text
})
# --- DEPLOY ---
default_persona = PERSONAS_DATA[config.DEFAULT_PERSONA_KEY]["description"]
                if classification in ("Direct_Procedure", "Holistic_Abstract_Reasoning"):
if i == 0: yield "Deploying: Baseline Single Agent (Simplicity Hypothesis)..."
solution_draft = await self.single_agent.solve(current_problem, default_persona)
elif classification == "Local_Geometric_Procedural":
if i == 0: yield "Deploying: Static Homogeneous Team (Expert Anomaly)..."
solution_draft = await self.homo_team.solve(current_problem, default_persona)
elif classification == "Cognitive_Labyrinth":
if i == 0:
yield "Deploying: Static Heterogeneous Team (Cognitive Diversity)..."
team_plan, calibration_errors, calib_details = await self.calibrator.calibrate_team(current_problem)
debug_log["trace"].append({
"step_type": "calibration",
"details": calib_details,
"errors": calibration_errors,
"selected_plan": team_plan
})
if calibration_errors:
yield "--- CALIBRATION WARNINGS ---"
for err in calibration_errors: yield err
yield "-----------------------------"
yield f"Calibration complete. Best Team: {json.dumps({k: v['llm'] for k, v in team_plan.items()})}"
self.current_team_plan = team_plan
solution_draft = await self.hetero_team.solve(current_problem, self.current_team_plan)
else:
if i == 0: yield f"Diagnosis '{classification}' is unknown. Defaulting to Single Agent."
solution_draft = await self.single_agent.solve(current_problem, default_persona)
if "Error generating response" in solution_draft:
raise Exception(f"The specialist team failed to generate a solution. Error: {solution_draft}")
yield f"Draft solution received: '{solution_draft[:60]}...'"
# --- EVALUATE ---
yield "Evaluating draft (live)..."
v_fitness_json = await self.evaluator.evaluate(current_problem, solution_draft)
# --- Safety Check for List ---
if isinstance(v_fitness_json, list):
if len(v_fitness_json) > 0 and isinstance(v_fitness_json[0], dict):
v_fitness_json = v_fitness_json[0]
else:
v_fitness_json = {}
# --- ROBUST NORMALIZATION WITH ALLOWLIST FILTER ---
normalized_fitness = {}
if isinstance(v_fitness_json, dict):
for k, v in v_fitness_json.items():
# 1. Map fuzzy keys to canonical keys
canonical_key = None
clean_k = k.lower().strip()
# Check exact match or known variation
if clean_k in METRIC_MAPPING:
canonical_key = METRIC_MAPPING[clean_k]
# If we couldn't map it to a valid metric, SKIP IT.
if not canonical_key:
continue
# 2. Extract Score Value
if isinstance(v, dict):
score_value = v.get('score')
justification_value = v.get('justification', str(v))
elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
score_value = v[0].get('score')
justification_value = v[0].get('justification', str(v[0]))
else:
# Flat value case
score_value = v
justification_value = "Score extracted directly."
# 3. Clean Score (handle "4/5" strings)
if isinstance(score_value, str):
                            # Take the first integer found in the string, e.g. "4/5" -> 4.
                            match = re.search(r'\d+', score_value)
                            score_value = int(match.group()) if match else 0
try:
score_value = int(score_value)
except (ValueError, TypeError):
score_value = 0
normalized_fitness[canonical_key] = {'score': score_value, 'justification': justification_value}
else:
# Fallback for total failure
normalized_fitness = {k: {'score': 0, 'justification': "Invalid JSON structure"} for k in ["Novelty", "Usefulness_Feasibility", "Flexibility", "Elaboration", "Cultural_Appropriateness"]}
v_fitness_json = normalized_fitness
scores = {k: v.get('score', 0) for k, v in v_fitness_json.items()}
yield f"Evaluation Score: {scores}"
debug_log["trace"].append({
"step_type": "attempt",
"loop_index": i,
"draft": solution_draft,
"scores": scores,
"full_evaluation": v_fitness_json
})
if scores.get('Novelty', 0) <= 1:
yield f"⚠️ Low Score Detected. Reason: {v_fitness_json.get('Novelty', {}).get('justification', 'Unknown')}"
if self.corrector.is_good_enough(scores):
yield "--- Solution approved by self-corrector. ---"
break
elif i == 1:
yield "--- Max correction loops reached. Accepting best effort. ---"
# --- FINALIZE ---
await asyncio.sleep(0.5)
yield "Milestone 5 Complete. Self-Correction loop is live."
solution_draft_json_safe = json.dumps(solution_draft)
debug_log_json_safe = json.dumps(debug_log)
yield f"FINAL: {{\"text\": {solution_draft_json_safe}, \"audio\": null, \"log\": {debug_log_json_safe}}}"
except Exception as e:
error_msg = f"An error occurred in the agent's solve loop: {e}"
print(error_msg)
debug_log["error"] = str(e)
yield error_msg
finally:
try:
os.makedirs("logs", exist_ok=True)
log_path = f"logs/run_{run_id}.json"
with open(log_path, "w", encoding="utf-8") as f:
json.dump(debug_log, f, indent=2)
print(f"Detailed execution log saved to {log_path}")
except Exception as log_err:
print(f"Failed to save log: {log_err}")