"""CLI: dataset path, baseline, output dir, dry-run, smoke eval. Evaluation uses batch LLM judge: 2 calls/session + 2 calls/QA. Session and QA evaluations run in parallel via ThreadPoolExecutor. Pipeline results are checkpointed before eval so --eval-only can resume. """ from __future__ import annotations import argparse import json import os from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict from pathlib import Path from typing import Any try: from openai import OpenAI except ImportError: OpenAI = None # type: ignore[assignment] from eval_framework.config import EvalConfig from eval_framework.datasets.domain_a_v2 import ( DomainAV2AcademicBundle, NormalizedCheckpointQuestion, load_domain_a_v2_academic, ) from eval_framework.datasets.schemas import ( MemoryDeltaRecord, MemorySnapshotRecord, RetrievalItem, RetrievalRecord, ) from eval_framework.evaluators.aggregate import aggregate_metrics from eval_framework.evaluators.extraction import evaluate_extraction from eval_framework.evaluators.qa import evaluate_checkpoint_qa from eval_framework.memory_adapters.base import MemoryAdapter from eval_framework.openai_compat import ( patch_openai_chat_completions, rewrite_chat_completion_kwargs, ) from eval_framework.pipeline.gold_state import GoldMemoryPoint, SessionGoldState from eval_framework.pipeline.records import PipelineCheckpointQARecord, PipelineSessionRecord from eval_framework.pipeline.runner import run_domain_a_v2_sample _CHECKPOINT_SESSIONS = "pipeline_sessions.jsonl" _CHECKPOINT_QA = "pipeline_qa.jsonl" # --------------------------------------------------------------------------- # Checkpoint deserialization: dict -> frozen dataclass # --------------------------------------------------------------------------- def _gold_point_from_dict(d: dict[str, Any]) -> GoldMemoryPoint: return GoldMemoryPoint( memory_id=d["memory_id"], memory_content=d["memory_content"], memory_type=d["memory_type"], memory_source=d["memory_source"], is_update=bool(d["is_update"]), original_memories=tuple(d.get("original_memories") or ()), importance=float(d.get("importance", 0.0)), timestamp=d.get("timestamp"), update_type=d.get("update_type", ""), ) def _gold_state_from_dict(d: dict[str, Any]) -> SessionGoldState: return SessionGoldState( session_id=d["session_id"], cumulative_gold_memories=tuple(_gold_point_from_dict(g) for g in d["cumulative_gold_memories"]), session_new_memories=tuple(_gold_point_from_dict(g) for g in d["session_new_memories"]), session_update_memories=tuple(_gold_point_from_dict(g) for g in d["session_update_memories"]), session_interference_memories=tuple(_gold_point_from_dict(g) for g in d["session_interference_memories"]), ) def _snapshot_record_from_dict(d: dict[str, Any]) -> MemorySnapshotRecord: return MemorySnapshotRecord( memory_id=d["memory_id"], text=d["text"], session_id=d["session_id"], status=d["status"], source=d.get("source"), raw_backend_id=d.get("raw_backend_id"), raw_backend_type=d.get("raw_backend_type"), metadata=d.get("metadata") or {}, ) def _delta_record_from_dict(d: dict[str, Any]) -> MemoryDeltaRecord: return MemoryDeltaRecord( session_id=d["session_id"], op=d["op"], text=d["text"], linked_previous=tuple(d.get("linked_previous") or ()), raw_backend_id=d.get("raw_backend_id"), metadata=d.get("metadata") or {}, ) def _retrieval_item_from_dict(d: dict[str, Any]) -> RetrievalItem: return RetrievalItem( rank=int(d["rank"]), memory_id=d["memory_id"], text=d["text"], score=float(d["score"]), 
def _retrieval_record_from_dict(d: dict[str, Any]) -> RetrievalRecord:
    return RetrievalRecord(
        query=d["query"],
        top_k=int(d["top_k"]),
        items=[_retrieval_item_from_dict(i) for i in d["items"]],
        raw_trace=d.get("raw_trace") or {},
    )


def _session_record_from_dict(d: dict[str, Any]) -> PipelineSessionRecord:
    return PipelineSessionRecord(
        sample_id=d["sample_id"],
        sample_uuid=d["sample_uuid"],
        session_id=d["session_id"],
        memory_snapshot=tuple(_snapshot_record_from_dict(s) for s in d["memory_snapshot"]),
        memory_delta=tuple(_delta_record_from_dict(dl) for dl in d["memory_delta"]),
        gold_state=_gold_state_from_dict(d["gold_state"]),
    )


def _qa_record_from_dict(d: dict[str, Any]) -> PipelineCheckpointQARecord:
    return PipelineCheckpointQARecord(
        sample_id=d["sample_id"],
        sample_uuid=d["sample_uuid"],
        checkpoint_id=d["checkpoint_id"],
        question=d["question"],
        gold_answer=d["gold_answer"],
        gold_evidence_memory_ids=tuple(d.get("gold_evidence_memory_ids") or ()),
        gold_evidence_contents=tuple(d.get("gold_evidence_contents") or ()),
        question_type=d["question_type"],
        question_type_abbrev=d["question_type_abbrev"],
        difficulty=d["difficulty"],
        retrieval=_retrieval_record_from_dict(d["retrieval"]),
        generated_answer=d["generated_answer"],
        cited_memories=tuple(d.get("cited_memories") or ()),
    )


def _read_jsonl(path: Path) -> list[dict[str, Any]]:
    rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


def _load_pipeline_checkpoint(
    output_dir: Path,
) -> tuple[list[PipelineSessionRecord], list[PipelineCheckpointQARecord]]:
    """Restore pipeline records from checkpoint JSONL files."""
    sess_path = output_dir / _CHECKPOINT_SESSIONS
    qa_path = output_dir / _CHECKPOINT_QA
    if not sess_path.exists() or not qa_path.exists():
        raise SystemExit(
            f"Checkpoint files not found in {output_dir}. "
            f"Run without --eval-only first to generate them."
        )
    session_records = [_session_record_from_dict(d) for d in _read_jsonl(sess_path)]
    qa_records = [_qa_record_from_dict(d) for d in _read_jsonl(qa_path)]
    return session_records, qa_records


def _default_create_adapter(baseline_name: str) -> MemoryAdapter:
    from eval_framework.memory_adapters import registry as reg

    if baseline_name in reg.MEMGALLERY_NATIVE_REGISTRY:
        return reg.MEMGALLERY_NATIVE_REGISTRY[baseline_name]()
    if baseline_name in reg.EXTERNAL_ADAPTER_REGISTRY:
        return reg.EXTERNAL_ADAPTER_REGISTRY[baseline_name]()
    known = sorted(reg.MEMGALLERY_NATIVE_BASELINES | reg.EXTERNAL_ADAPTER_KEYS)
    raise SystemExit(
        f"Unknown baseline {baseline_name!r}. "
        f"Expected one of: {', '.join(known)}"
    )
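
# Minimal resume sketch using the helpers above. The path is illustrative;
# the checkpoint files must have been produced by a prior run of this CLI
# without --eval-only:
#
#     sessions, qa = _load_pipeline_checkpoint(Path("results/my_baseline_run"))
#     print(f"resumed {len(sessions)} sessions, {len(qa)} QA records")
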
" f"Expected one of: {', '.join(known)}" ) def _gold_echo_answer( q: NormalizedCheckpointQuestion, _retrieval: RetrievalRecord ) -> tuple[str, list[str]]: return q.gold_answer, [] def _parse_answer_json(raw: str) -> tuple[str, list[str]]: """Extract answer and cited_memories from the model's JSON response.""" # Try to parse as JSON first try: data = json.loads(raw) answer = str(data.get("answer", "")) cited = data.get("cited_memories", []) if isinstance(cited, list): return answer, [str(c) for c in cited] return answer, [] except (json.JSONDecodeError, TypeError): pass # Fallback: try to find JSON block in the response import re m = re.search(r"\{[\s\S]*\}", raw) if m: try: data = json.loads(m.group()) answer = str(data.get("answer", "")) cited = data.get("cited_memories", []) if isinstance(cited, list): return answer, [str(c) for c in cited] except (json.JSONDecodeError, TypeError): pass # Final fallback: treat entire response as the answer, no citations return raw.strip(), [] def build_default_answer_fn() -> Callable[ [NormalizedCheckpointQuestion, RetrievalRecord], tuple[str, list[str]] ]: api_key = os.getenv("OPENAI_API_KEY") if not api_key or OpenAI is None: return _gold_echo_answer client = OpenAI( api_key=api_key, base_url=os.getenv("OPENAI_BASE_URL"), ) model = os.getenv("OPENAI_MODEL") or "gpt-4o" temperature = float(os.getenv("OPENAI_TEMPERATURE", "0.0")) max_tokens = int(os.getenv("OPENAI_MAX_TOKENS", "1024")) def _answer( q: NormalizedCheckpointQuestion, retrieval: RetrievalRecord ) -> tuple[str, list[str]]: context_lines = [ f"[{item.rank}] {item.text}" for item in retrieval.items[: retrieval.top_k] ] context = "\n".join(context_lines) if context_lines else "No retrieved memories." prompt = ( "Answer the user's question using only the retrieved memories below. " "If the memories are insufficient, answer exactly: Not mentioned in memory.\n\n" "You MUST also list the specific memory passages you relied on to produce " "the answer. Copy the relevant text verbatim from the retrieved memories.\n\n" f"Question: {q.question}\n\n" f"Retrieved memories:\n{context}\n\n" 'Respond in JSON:\n' '{\n' ' "answer": "your concise answer",\n' ' "cited_memories": ["verbatim passage 1", "verbatim passage 2"]\n' '}\n' ) request_kwargs = rewrite_chat_completion_kwargs( { "model": model, "messages": [ { "role": "system", "content": ( "You answer benchmark questions using only supplied memory context. " "Be concise and do not invent missing facts. " "Always respond in the requested JSON format." 
), }, {"role": "user", "content": prompt}, ], "temperature": temperature, "max_tokens": max_tokens, } ) response = client.chat.completions.create(**request_kwargs) raw = response.choices[0].message.content or "" return _parse_answer_json(raw) return _answer def config_from_namespace(ns: argparse.Namespace) -> EvalConfig: return EvalConfig( dataset_path=Path(ns.dataset).expanduser().resolve(), output_dir=Path(ns.output_dir).expanduser().resolve(), baseline=str(ns.baseline), smoke=bool(ns.smoke), dry_run=bool(ns.dry_run), ) def _record_to_json_obj(obj: Any) -> dict[str, Any]: return asdict(obj) def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None: path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8") as fh: for row in rows: fh.write(json.dumps(row, ensure_ascii=False) + "\n") def _write_json(path: Path, payload: dict[str, Any]) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text( json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8", ) def run_eval( config: EvalConfig, *, load_domain_bundle: Callable[[Path], DomainAV2AcademicBundle] = load_domain_a_v2_academic, create_adapter: Callable[[str], MemoryAdapter] | None = None, answer_fn: Callable | None = None, max_eval_workers: int = 5, eval_only: bool = False, ) -> None: """Load data, run pipeline (serial) + LLM eval (parallel).""" patch_openai_chat_completions() if config.dry_run: return out = config.output_dir out.mkdir(parents=True, exist_ok=True) if eval_only: # --- Resume from checkpoint --- print(f"[Eval-only] Loading pipeline checkpoint from {out}") session_records, qa_records = _load_pipeline_checkpoint(out) print(f"[Eval-only] Loaded {len(session_records)} sessions + {len(qa_records)} QA records") else: # --- Stage 1: Pipeline (serial — adapter is stateful) --- adapter_factory = create_adapter or _default_create_adapter bundle = load_domain_bundle(config.dataset_path) samples = bundle.samples[:1] if config.smoke else bundle.samples _answer = answer_fn if answer_fn is not None else build_default_answer_fn() session_records: list[PipelineSessionRecord] = [] qa_records: list[PipelineCheckpointQARecord] = [] print(f"[Pipeline] Running {len(samples)} sample(s) with baseline={config.baseline}") for i, sample in enumerate(samples): print(f" Sample {i + 1}/{len(samples)}: {sample.sample_id}") adapter = adapter_factory(config.baseline) sess, qa = run_domain_a_v2_sample( adapter, sample, answer_fn=_answer, ) session_records.extend(sess) qa_records.extend(qa) # --- Save checkpoint --- _write_jsonl(out / _CHECKPOINT_SESSIONS, [_record_to_json_obj(r) for r in session_records]) _write_jsonl(out / _CHECKPOINT_QA, [_record_to_json_obj(r) for r in qa_records]) print(f"[Checkpoint] Saved {len(session_records)} sessions + {len(qa_records)} QA to {out}") # --- Stage 2: Eval (parallel — each record is self-contained) --- print(f"[Eval] Evaluating {len(session_records)} sessions + {len(qa_records)} QA with LLM judge (workers={max_eval_workers})...") session_evals: list[dict[str, object] | None] = [None] * len(session_records) qa_evals: list[dict[str, object] | None] = [None] * len(qa_records) with ThreadPoolExecutor(max_workers=max_eval_workers) as pool: # Submit session evals session_futures = {} for idx, srec in enumerate(session_records): fut = pool.submit(evaluate_extraction, srec) session_futures[fut] = idx # Submit QA evals qa_futures = {} for idx, qrec in enumerate(qa_records): fut = pool.submit(evaluate_checkpoint_qa, qrec) qa_futures[fut] = idx # Collect session 
        # Collect session results
        done_sessions = 0
        for fut in as_completed(session_futures):
            idx = session_futures[fut]
            try:
                session_evals[idx] = fut.result()
            except Exception as e:
                session_evals[idx] = {"error": str(e)}
            done_sessions += 1
            if done_sessions % 10 == 0 or done_sessions == len(session_records):
                print(f"  Sessions: {done_sessions}/{len(session_records)} done")

        # Collect QA results
        done_qa = 0
        for fut in as_completed(qa_futures):
            idx = qa_futures[fut]
            try:
                qa_evals[idx] = fut.result()
            except Exception as e:
                qa_evals[idx] = {"error": str(e)}
            done_qa += 1
            if done_qa % 20 == 0 or done_qa == len(qa_records):
                print(f"  QA: {done_qa}/{len(qa_records)} done")

    # --- Stage 3: Aggregate + write ---
    agg = aggregate_metrics(
        config.baseline,
        session_evaluations=[e for e in session_evals if e is not None],
        qa_evaluations=[e for e in qa_evals if e is not None],
    )

    session_rows = []
    for srec, s_eval in zip(session_records, session_evals):
        row = _record_to_json_obj(srec)
        row["eval"] = s_eval
        session_rows.append(row)

    qa_rows = []
    for qrec, q_eval in zip(qa_records, qa_evals):
        row = _record_to_json_obj(qrec)
        row["eval"] = q_eval
        qa_rows.append(row)

    _write_jsonl(out / "session_records.jsonl", session_rows)
    _write_jsonl(out / "qa_records.jsonl", qa_rows)
    _write_json(out / "aggregate_metrics.json", agg)
    print(f"\n[Done] Results written to {out}")
    print(f"  Aggregate: {json.dumps(agg, indent=2)}")


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(prog="eval_framework")
    p.add_argument("--dataset", required=True)
    p.add_argument("--baseline", required=True)
    p.add_argument("--output-dir", default="eval_framework/results")
    p.add_argument("--smoke", action="store_true")
    p.add_argument("--dry-run", action="store_true")
    p.add_argument(
        "--eval-only",
        action="store_true",
        help="Skip pipeline, load from checkpoint in output-dir.",
    )
    p.add_argument(
        "--max-eval-workers",
        type=int,
        default=5,
        help="Parallel threads for eval stage (default 5).",
    )
    return p


def main(argv: list[str] | None = None) -> None:
    parser = build_parser()
    args = parser.parse_args(argv)
    cfg = config_from_namespace(args)
    if cfg.dry_run:
        print(json.dumps(cfg.to_display_dict(), indent=2))
        return
    eval_only = bool(args.eval_only)
    if not eval_only and not cfg.dataset_path.is_dir():
        raise SystemExit(f"Dataset path is not a directory: {cfg.dataset_path}")
    run_eval(cfg, max_eval_workers=args.max_eval_workers, eval_only=eval_only)


if __name__ == "__main__":
    main()
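
# Example invocations (dataset path, baseline name, and output dir are
# illustrative; assumes this module is wired up as the package entry point):
#
#     python -m eval_framework --dataset data/domain_a_v2 --baseline my_baseline \
#         --output-dir results/my_baseline_run
#
#     # Re-run only the LLM-judge stage from the saved checkpoint:
#     python -m eval_framework --dataset data/domain_a_v2 --baseline my_baseline \
#         --output-dir results/my_baseline_run --eval-only --max-eval-workers 8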