|
|
import json |
|
|
import random |
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
from datasets import Dataset, load_dataset |
|
|
from loguru import logger |
|
|
|
|
|
from src.config import DATASET_MAPPING, DEFAULT_DATASET_PREFIX, MODELS_PATH, TEST_TYPES |
|
|
|
|
|
|
|
|
class DatasetManager:
    """Manages the loading and retrieval of evaluation datasets.

    One dataset is loaded per entry in ``TEST_TYPES``; the manager then
    tracks a single "current" dataset from which random examples are drawn.
    """

    def __init__(self):
        # Loaded datasets keyed by test type (the raw TEST_TYPES entry).
        self.datasets: Dict[str, Dataset] = {}
        # Dataset chosen by ``switch_dataset``; None until a switch succeeds.
        self.current_dataset: Optional[Dataset] = None
        # Name of the current dataset as built by ``_dataset_name_for``.
        self.current_dataset_name: str = ""
        # Test type of the current dataset; starts at the first configured type.
        self.current_type: str = TEST_TYPES[0]

    @staticmethod
    def _dataset_name_for(test_type: str) -> str:
        """Return the dataset name for *test_type* (spaces become hyphens)."""
        test_type_kebab = test_type.replace(" ", "-")
        return f"{DEFAULT_DATASET_PREFIX}-{test_type_kebab}"

    def load_datasets(self) -> List[str]:
        """Load the ``train`` split of the dataset for every test type.

        A dataset that fails to load is logged and skipped, so one failure
        does not prevent the remaining datasets from loading.

        Returns:
            The names of the datasets that loaded successfully.
        """
        dataset_names = []

        for test_type in TEST_TYPES:
            # Built before the ``try`` so the error log below always names
            # this iteration's dataset, never an unbound or stale one.
            dataset_name = self._dataset_name_for(test_type)
            try:
                logger.info(f"Loading dataset: {dataset_name}")
                self.datasets[test_type] = load_dataset(
                    dataset_name,
                    split="train",
                )
                dataset_names.append(dataset_name)
            except Exception as e:
                logger.error(f"Failed to load dataset {dataset_name}: {e}")

        return dataset_names

    def switch_dataset(self, test_type: str) -> None:
        """Make the already-loaded dataset for *test_type* the current one.

        If no dataset was loaded for *test_type*, an error is logged and the
        current selection is left unchanged.
        """
        if test_type not in self.datasets:
            logger.error(f"Dataset for test type '{test_type}' not loaded")
            return

        self.current_dataset = self.datasets[test_type]
        self.current_dataset_name = self._dataset_name_for(test_type)
        self.current_type = test_type
        logger.info(f"Switched to dataset: {self.current_dataset_name}")

    def get_random_example(self) -> Dict[str, Any]:
        """Return a uniformly random row of the current dataset.

        Raises:
            ValueError: If no dataset is selected, or the selected dataset
                has no rows.
        """
        # Explicit ``is None``: an empty Dataset is also falsy, so a plain
        # truthiness test would misreport it as "not loaded".
        if self.current_dataset is None:
            raise ValueError("No dataset loaded")
        if len(self.current_dataset) == 0:
            raise ValueError(
                f"Dataset {self.current_dataset_name!r} is empty",
            )

        idx = random.randrange(len(self.current_dataset))
        return self.current_dataset[idx]
|
|
|
|
|
|
|
|
def load_models(path: Optional[str] = None) -> List[Dict[str, Any]]:
    """Load model entries from a JSON Lines file.

    Each non-blank line is expected to hold one JSON object. Lines that are
    not valid JSON, or that decode to something other than an object, are
    logged and skipped.

    Args:
        path: File to read. Defaults to ``MODELS_PATH`` when omitted.

    Returns:
        The decoded model dicts in file order, or an empty list if the file
        cannot be opened or read.
    """
    if path is None:
        path = MODELS_PATH

    try:
        # Explicit UTF-8 so decoding does not depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            models = []
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning(
                        f"Skipping invalid JSON in line: {line}.",
                    )
                    continue
                # Valid JSON that is not an object (list, number, ...) would
                # break the List[Dict] contract callers rely on.
                if not isinstance(entry, dict):
                    logger.warning(
                        f"Skipping non-object JSON in line: {line}.",
                    )
                    continue
                models.append(entry)
            return models
    except Exception as e:
        # Best-effort: a missing or unreadable file yields no models.
        logger.error(f"Error loading models: {e}")
        return []
|
|
|
|
|
|
|
|
def save_model(model: Dict[str, Any], path: Optional[str] = None) -> None:
    """Append *model* as a single JSON line to the models file.

    Args:
        model: JSON-serializable model description.
        path: File to append to. Defaults to ``MODELS_PATH`` when omitted.

    Raises:
        TypeError: If *model* is not JSON-serializable (raised by
            ``json.dumps`` before the file is touched).
        OSError: If the file cannot be opened or written.
    """
    if path is None:
        path = MODELS_PATH

    # Serialize first so a bad model never creates or modifies the file.
    line = json.dumps(model) + "\n"
    with open(path, "a", encoding="utf-8") as f:
        f.write(line)
|
|
|
|
|
|
|
|
# Keys present in every example dict returned by ``get_random_example``.
_EXAMPLE_KEYS = ("text", "claim", "input", "output", "assertion")

# Per test type: result key -> column name in the source dataset row.
# Test types missing from this table yield an example with all keys blank.
_FIELD_SOURCES: Dict[str, Dict[str, str]] = {
    "grounding": {"text": "doc", "claim": "claim"},
    "prompt_injections": {"text": "text"},
    "safety": {"text": "text"},
    "policy": {
        "input": "input_text",
        "output": "output_text",
        "assertion": "assertion",
    },
}


def _blank_example(**fields: str) -> Dict[str, str]:
    """Return an example dict with every key blank except those in *fields*."""
    example = dict.fromkeys(_EXAMPLE_KEYS, "")
    example.update(fields)
    return example


def get_random_example(test_type: str) -> Dict[str, str]:
    """Get a random example from the dataset for the given test type.

    The dataset is looked up in ``DATASET_MAPPING`` and loaded on every
    call. Its ``train`` split is used if present, otherwise the first split.

    Args:
        test_type: Key into ``DATASET_MAPPING`` / the field table above.

    Returns:
        A dict with exactly the keys ``text``, ``claim``, ``input``,
        ``output`` and ``assertion``, all string-valued. Placeholder text
        is returned when the test type has no mapping, the dataset is
        empty, or loading fails; this function does not raise.
    """
    try:
        dataset_name = DATASET_MAPPING.get(test_type)
        if not dataset_name:
            logger.warning(
                f"No dataset mapping found for test type: {test_type}",
            )
            return {
                key: f"Sample {key} for {test_type}" for key in _EXAMPLE_KEYS
            }

        logger.info(f"Loading dataset: {dataset_name}")
        dataset = load_dataset(dataset_name)

        # Prefer the "train" split; otherwise fall back to the first split.
        if "train" in dataset:
            examples = dataset["train"]
        else:
            examples = dataset[list(dataset.keys())[0]]

        if len(examples) == 0:
            logger.warning(f"No examples found in dataset: {dataset_name}")
            return _blank_example(text=f"No examples found for {test_type}")

        row = random.choice(examples)

        result = _blank_example()
        for key, column in _FIELD_SOURCES.get(test_type, {}).items():
            value = row.get(column)
            # A column can be present with a null value; map it (and a
            # missing column) to "" so every value stays a string.
            result[key] = "" if value is None else value
        return result

    except Exception as e:
        # Best-effort boundary: log with traceback, return a placeholder.
        logger.exception(f"Error getting example for {test_type}: {e}")
        return _blank_example(text=f"Error getting example for {test_type}")
|
|
|