"""Helpers to map turns, backend memory dicts, and recall outputs into shared schemas.""" from __future__ import annotations from typing import Any, Mapping from eval_framework.datasets.schemas import ( MemorySnapshotRecord, NormalizedTurn, RetrievalItem, RetrievalRecord, ) def turn_to_observation_dict(turn: NormalizedTurn) -> dict[str, Any]: """Build a Mem-Gallery store observation from a normalized turn.""" parts: list[str] = [turn.text] for att in turn.attachments: parts.append(f"[{att.type}] {att.caption}") text = "\n".join(parts) obs: dict[str, Any] = {"text": text} if turn.timestamp: obs["timestamp"] = turn.timestamp obs["dialogue_id"] = f"{turn.session_id}:{turn.turn_index}" return obs def memory_element_text(element: Mapping[str, Any]) -> str: """Best-effort text extraction from a Mem-Gallery memory dict.""" raw = element.get("text", "") if isinstance(raw, list): return " ".join(str(x) for x in raw) if raw is None: base = "" else: base = str(raw) image = element.get("image") if isinstance(image, dict): cap = image.get("caption") if cap: base = f"{base}\n[image] {cap}".strip() return base def linear_element_to_snapshot( element: Mapping[str, Any], *, memory_id: str, session_id: str, source: str, status: str = "active", ) -> MemorySnapshotRecord: """Map a linear-storage memory dict into MemorySnapshotRecord.""" cid = element.get("counter_id") raw_id = str(cid) if cid is not None else memory_id return MemorySnapshotRecord( memory_id=memory_id, text=memory_element_text(element), session_id=session_id, status=status, source=source, raw_backend_id=raw_id, raw_backend_type="linear", metadata={}, ) def normalize_recall_to_retrieval( query: str, top_k: int, raw: Any, *, raw_trace: dict[str, Any] | None = None, ) -> RetrievalRecord: """Normalize Mem-Gallery recall outputs into RetrievalRecord.""" trace = dict(raw_trace or {}) items: list[RetrievalItem] = [] if isinstance(raw, str): items.append( RetrievalItem( rank=0, memory_id="memgallery:string_bundle", text=raw, score=1.0, raw_backend_id=None, ) ) elif isinstance(raw, list): for i, row in enumerate(raw[: max(0, top_k)]): if isinstance(row, dict): mid = row.get("counter_id") items.append( RetrievalItem( rank=i, memory_id=str(mid if mid is not None else i), text=memory_element_text(row), score=float(row.get("score", 1.0)), raw_backend_id=str(mid) if mid is not None else None, ) ) else: items.append( RetrievalItem( rank=i, memory_id=str(i), text=str(row), score=1.0, raw_backend_id=None, ) ) else: items.append( RetrievalItem( rank=0, memory_id="memgallery:object_bundle", text=str(raw), score=1.0, raw_backend_id=None, ) ) return RetrievalRecord(query=query, top_k=top_k, items=items[:top_k], raw_trace=trace)