EvalArena / src /data_manager.py
dror44's picture
hotfix - tsq table
1ef8a9e
import json
import random
from typing import Any, Dict, List, Optional
from datasets import Dataset, load_dataset
from loguru import logger
from src.config import DATASET_MAPPING, DEFAULT_DATASET_PREFIX, MODELS_PATH, TEST_TYPES
class DatasetManager:
    """Manages the loading and retrieval of evaluation datasets.

    One dataset per entry in ``TEST_TYPES`` is loaded by name
    ``{DEFAULT_DATASET_PREFIX}-{test type with spaces replaced by "-"}`` and
    stored in ``self.datasets`` keyed by the test type. ``switch_dataset``
    selects which loaded dataset ``get_random_example`` samples from.
    """

    def __init__(self):
        # Loaded datasets keyed by the raw TEST_TYPES entry.
        self.datasets: Dict[str, Dataset] = {}
        # Dataset currently sampled from; None until switch_dataset succeeds.
        self.current_dataset: Optional[Dataset] = None
        # Name of current_dataset; "" until switch_dataset succeeds.
        self.current_dataset_name: str = ""
        # Test type of the active dataset. NOTE: defaults to the first test
        # type even though no dataset has been selected yet.
        self.current_type: str = TEST_TYPES[0]

    @staticmethod
    def _dataset_name_for(test_type: str) -> str:
        """Return the dataset name for *test_type* (spaces become hyphens)."""
        test_type_kebab = test_type.replace(" ", "-")
        return f"{DEFAULT_DATASET_PREFIX}-{test_type_kebab}"

    def load_datasets(self) -> List[str]:
        """Load the "train" split of the dataset for every test type.

        A dataset that fails to load is logged and skipped, so one failure
        does not prevent the remaining test types from loading.

        Returns:
            The names of the datasets that loaded successfully, in
            ``TEST_TYPES`` order.
        """
        dataset_names = []
        for test_type in TEST_TYPES:
            # Computed before the try so the except clause always reports the
            # dataset that actually failed (never an unbound or stale name).
            dataset_name = self._dataset_name_for(test_type)
            try:
                logger.info(f"Loading dataset: {dataset_name}")
                self.datasets[test_type] = load_dataset(
                    dataset_name,
                    split="train",
                )
            except Exception as e:
                logger.error(f"Failed to load dataset {dataset_name}: {e}")
                continue
            dataset_names.append(dataset_name)
        return dataset_names

    def switch_dataset(self, test_type: str) -> None:
        """Make the already-loaded dataset for *test_type* the current one.

        If no dataset was loaded for *test_type*, an error is logged and the
        current selection is left unchanged.
        """
        if test_type not in self.datasets:
            logger.error(f"Dataset for test type '{test_type}' not loaded")
            return
        self.current_dataset = self.datasets[test_type]
        self.current_dataset_name = self._dataset_name_for(test_type)
        self.current_type = test_type
        logger.info(f"Switched to dataset: {self.current_dataset_name}")

    def get_random_example(self) -> Dict[str, Any]:
        """Return one uniformly chosen row of the current dataset.

        Raises:
            ValueError: if no dataset has been selected, or the selected
                dataset has no rows.
        """
        # Explicit None check: a Dataset's truthiness comes from its length,
        # so `not dataset` would misreport an empty dataset as "not loaded".
        if self.current_dataset is None:
            raise ValueError("No dataset loaded")
        size = len(self.current_dataset)
        if size == 0:
            raise ValueError(f"Dataset '{self.current_dataset_name}' is empty")
        return self.current_dataset[random.randrange(size)]
def load_models() -> List[Dict[str, Any]]:
    """Load model records from the JSON Lines file at ``MODELS_PATH``.

    Each non-blank line is parsed with ``json.loads`` (presumably one JSON
    object describing a model per line — verify against the file's writers).
    Lines that are not valid JSON are skipped with a warning instead of
    aborting the whole load.

    Returns:
        The parsed records in file order, or an empty list if the file could
        not be read (the error is logged, not raised).
    """
    try:
        # Explicit UTF-8: JSON interchange is UTF-8, and relying on the
        # platform's locale encoding can mis-decode the file (e.g. on Windows).
        with open(MODELS_PATH, "r", encoding="utf-8") as f:
            models = []
            for line in f:
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                try:
                    models.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.warning(
                        f"Skipping invalid JSON in line: {line}.",
                    )
            return models
    except Exception as e:
        # Deliberate best-effort: a missing or unreadable file yields no
        # models rather than crashing the caller.
        logger.error(f"Error loading models: {e}")
        return []
def save_model(model: Dict[str, Any]) -> None:
    """Append *model* to ``MODELS_PATH`` as one JSON-encoded line.

    The file is opened in append mode ("a"), so existing lines are kept and
    the file is created if it does not exist yet.
    """
    with open(MODELS_PATH, "a") as out:
        serialized = json.dumps(model)
        out.write(f"{serialized}\n")
# Keys present, in this order, in every dict returned by get_random_example.
_EXAMPLE_KEYS = ("text", "claim", "input", "output", "assertion")

# Per test type: result key -> column read from the sampled dataset row.
# Result keys not listed for a test type are left as "".
_EXAMPLE_FIELD_MAP: Dict[str, Dict[str, str]] = {
    "grounding": {"text": "doc", "claim": "claim"},
    "prompt_injections": {"text": "text"},
    "safety": {"text": "text"},
    "policy": {
        "input": "input_text",
        "output": "output_text",
        "assertion": "assertion",
    },
}


def _blank_example(**overrides: str) -> Dict[str, str]:
    """Return a result dict with every key set to "" except those in *overrides*."""
    example = dict.fromkeys(_EXAMPLE_KEYS, "")
    example.update(overrides)
    return example


def get_random_example(test_type: str) -> Dict[str, str]:
    """Get a random example from the dataset mapped to *test_type*.

    The dataset name is looked up in ``DATASET_MAPPING`` and loaded with
    ``load_dataset`` on every call (no in-process caching here). One row is
    sampled from the "train" split, or from the first split if there is no
    "train" split, and its columns are copied into the internal format
    according to ``_EXAMPLE_FIELD_MAP``.

    Returns:
        A dict with exactly the keys ``text``, ``claim``, ``input``,
        ``output`` and ``assertion``. This function never raises:
        - if *test_type* has no mapping, every key holds a
          "Sample <key> for <test_type>" placeholder;
        - if the dataset has no rows, ``text`` holds a "No examples found"
          message and the other keys are "";
        - on any error, ``text`` holds an error message and the other keys
          are "".
    """
    try:
        dataset_name = DATASET_MAPPING.get(test_type)
        if not dataset_name:
            logger.warning(
                f"No dataset mapping found for test type: {test_type}",
            )
            return {key: f"Sample {key} for {test_type}" for key in _EXAMPLE_KEYS}

        logger.info(f"Loading dataset: {dataset_name}")
        dataset = load_dataset(dataset_name)

        # Prefer "train"; otherwise fall back to the first split listed.
        # A dataset with no splits at all is treated like an empty one
        # rather than surfacing as an IndexError from the catch-all below.
        split_name = "train" if "train" in dataset else next(iter(dataset), None)
        examples = dataset[split_name] if split_name is not None else []
        if len(examples) == 0:
            logger.warning(f"No examples found in dataset: {dataset_name}")
            return _blank_example(text=f"No examples found for {test_type}")

        example = random.choice(examples)

        field_map = _EXAMPLE_FIELD_MAP.get(test_type)
        if field_map is None:
            # The dataset is mapped but we don't know its columns; make the
            # all-empty result visible in the logs instead of failing silently.
            logger.warning(
                f"No field mapping for test type {test_type}; returning empty fields",
            )
            field_map = {}

        result = _blank_example()
        for result_key, source_column in field_map.items():
            result[result_key] = example.get(source_column, "")
        return result
    except Exception as e:
        # Best-effort: callers always receive a well-formed example dict.
        logger.error(f"Error getting example for {test_type}: {e}")
        return _blank_example(text=f"Error getting example for {test_type}")