import json
import random
from typing import Any, Dict, List, Optional

from datasets import Dataset, load_dataset
from loguru import logger

from src.config import DATASET_MAPPING, DEFAULT_DATASET_PREFIX, MODELS_PATH, TEST_TYPES

# Keys of the normalized example dict returned by ``get_random_example``,
# in the order they appear in that dict.
_EXAMPLE_FIELDS = ("text", "claim", "input", "output", "assertion")

# For each test type: internal example field -> dataset column it is read from.
# Test types not listed here produce an example with every field empty.
_FIELD_MAPPING: Dict[str, Dict[str, str]] = {
    "grounding": {"text": "doc", "claim": "claim"},
    "prompt_injections": {"text": "text"},
    "safety": {"text": "text"},
    "policy": {
        "input": "input_text",
        "output": "output_text",
        "assertion": "assertion",
    },
}


def _dataset_name_for(test_type: str) -> str:
    """Return the dataset id for *test_type*: the prefix plus the kebab-cased type."""
    return f"{DEFAULT_DATASET_PREFIX}-{test_type.replace(' ', '-')}"


class DatasetManager:
    """Manages the loading and retrieval of evaluation datasets.

    ``load_datasets`` loads one dataset per entry in ``TEST_TYPES`` into
    ``self.datasets``; ``switch_dataset`` selects which of them
    ``get_random_example`` samples from.
    """

    def __init__(self):
        # Loaded datasets, keyed by test type as spelled in TEST_TYPES.
        self.datasets: Dict[str, Dataset] = {}
        # Dataset selected by switch_dataset(); None until one is selected.
        self.current_dataset: Optional[Dataset] = None
        self.current_dataset_name: str = ""
        # NOTE(review): defaults to TEST_TYPES[0], but current_dataset stays
        # None until switch_dataset() is called — confirm callers do so.
        self.current_type: str = TEST_TYPES[0]

    def load_datasets(self) -> List[str]:
        """Load the "train" split of the dataset for every test type.

        A dataset that fails to load is logged and skipped, so one failure
        does not prevent the others from loading.

        Returns:
            The names of the datasets that were loaded successfully.
        """
        dataset_names = []
        for test_type in TEST_TYPES:
            # Computed before the try so the error log below always names the
            # dataset that actually failed (it used to be unbound or stale).
            dataset_name = _dataset_name_for(test_type)
            try:
                logger.info(f"Loading dataset: {dataset_name}")
                self.datasets[test_type] = load_dataset(
                    dataset_name,
                    split="train",
                )
            except Exception as e:
                logger.error(f"Failed to load dataset {dataset_name}: {e}")
                continue
            dataset_names.append(dataset_name)
        return dataset_names

    def switch_dataset(self, test_type: str) -> None:
        """Make the loaded dataset for *test_type* the current one.

        If that test type was not loaded, an error is logged and the current
        selection is left unchanged.
        """
        if test_type not in self.datasets:
            logger.error(f"Dataset for test type '{test_type}' not loaded")
            return
        self.current_dataset = self.datasets[test_type]
        self.current_dataset_name = _dataset_name_for(test_type)
        self.current_type = test_type
        logger.info(f"Switched to dataset: {self.current_dataset_name}")

    def get_random_example(self) -> Dict[str, Any]:
        """Return one uniformly random row from the current dataset.

        Raises:
            ValueError: If no dataset is selected, or the selected one is empty.
        """
        # Compare with None explicitly: Dataset defines __len__, so an empty
        # dataset is falsy and was previously misreported as "not loaded".
        if self.current_dataset is None:
            raise ValueError("No dataset loaded")
        if len(self.current_dataset) == 0:
            raise ValueError(f"Dataset {self.current_dataset_name} is empty")
        idx = random.randrange(len(self.current_dataset))
        return self.current_dataset[idx]


def load_models() -> List[Dict[str, Any]]:
    """Load model records from the JSON Lines file at ``MODELS_PATH``.

    Blank lines are ignored. Lines that are not valid JSON are logged and
    skipped. Any other error (for example, a missing file) is logged and an
    empty list is returned.
    """
    models: List[Dict[str, Any]] = []
    try:
        with open(MODELS_PATH, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                try:
                    models.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.warning(
                        f"Skipping invalid JSON in line: {line}.",
                    )
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        return []
    return models


def save_model(model: Dict[str, Any]) -> None:
    """Append *model* as a single JSON line to the models file at ``MODELS_PATH``."""
    with open(MODELS_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(model) + "\n")


def _empty_example(**overrides: str) -> Dict[str, str]:
    """Return an example dict with every field set to "", except *overrides*."""
    example = dict.fromkeys(_EXAMPLE_FIELDS, "")
    example.update(overrides)
    return example


def get_random_example(test_type: str) -> Dict[str, str]:
    """Get a random example from the dataset mapped to *test_type*.

    The dataset id is looked up in ``DATASET_MAPPING``. The "train" split is
    used when present; otherwise the first available split is used. The
    row's columns are copied into the fixed keys ``text``, ``claim``,
    ``input``, ``output`` and ``assertion``. Keys with no source column for
    this test type are "".

    This function does not raise. An unmapped test type yields placeholder
    text in every field. An empty dataset or any error yields a dict whose
    ``text`` describes the problem and whose other fields are "".
    """
    try:
        dataset_name = DATASET_MAPPING.get(test_type)
        if not dataset_name:
            logger.warning(
                f"No dataset mapping found for test type: {test_type}",
            )
            return {field: f"Sample {field} for {test_type}" for field in _EXAMPLE_FIELDS}

        logger.info(f"Loading dataset: {dataset_name}")
        dataset = load_dataset(dataset_name)

        # Prefer the "train" split; otherwise fall back to the first split.
        split = "train" if "train" in dataset else list(dataset.keys())[0]
        examples = dataset[split]

        if len(examples) == 0:
            logger.warning(f"No examples found in dataset: {dataset_name}")
            return _empty_example(text=f"No examples found for {test_type}")

        example = random.choice(examples)

        result = _empty_example()
        for field, column in _FIELD_MAPPING.get(test_type, {}).items():
            # `or ""` also turns a None value (e.g. a missing cell) into "",
            # so every returned field stays a str.
            result[field] = example.get(column) or ""
        return result
    except Exception as e:
        # logger.exception also records the traceback, which logger.error omitted.
        logger.exception(f"Error getting example for {test_type}: {e}")
        return _empty_example(text=f"Error getting example for {test_type}")