File size: 5,446 Bytes
af28f6f 94407ab af28f6f 5bed6f3 af28f6f 94407ab af28f6f 94407ab af28f6f 94407ab 06e7b99 5bed6f3 06e7b99 5bed6f3 94407ab af28f6f 94407ab 5bed6f3 94407ab 3df66f9 94407ab 6b070cd 94407ab 1ef8a9e 94407ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import json
import random
from typing import Any, Dict, List, Optional
from datasets import Dataset, load_dataset
from loguru import logger
from src.config import DATASET_MAPPING, DEFAULT_DATASET_PREFIX, MODELS_PATH, TEST_TYPES
class DatasetManager:
    """Manages the loading and retrieval of evaluation datasets.

    Datasets are stored per test type. One of them can be selected as the
    "current" dataset, from which random examples are drawn.
    """

    def __init__(self):
        # Loaded datasets, keyed by test type (the entries of TEST_TYPES).
        self.datasets: Dict[str, Dataset] = {}
        # Dataset selected via switch_dataset(); None until a switch succeeds.
        self.current_dataset: Optional[Dataset] = None
        self.current_dataset_name: str = ""
        self.current_type: str = TEST_TYPES[0]

    @staticmethod
    def _dataset_name_for(test_type: str) -> str:
        """Build the dataset name for a test type (spaces become hyphens)."""
        test_type_kebab = test_type.replace(" ", "-")
        return f"{DEFAULT_DATASET_PREFIX}-{test_type_kebab}"

    def load_datasets(self) -> List[str]:
        """Load the "train" split of the dataset for every test type.

        A dataset that fails to load is logged and skipped; the remaining
        test types are still attempted.

        Returns:
            Names of the datasets that were loaded successfully.
        """
        dataset_names: List[str] = []
        for test_type in TEST_TYPES:
            # Resolve the name before the try block so the error log below
            # always refers to the dataset that actually failed.
            dataset_name = self._dataset_name_for(test_type)
            logger.info(f"Loading dataset: {dataset_name}")
            try:
                self.datasets[test_type] = load_dataset(
                    dataset_name,
                    split="train",
                )
            except Exception as e:
                logger.error(f"Failed to load dataset {dataset_name}: {e}")
                continue
            dataset_names.append(dataset_name)
        return dataset_names

    def switch_dataset(self, test_type: str) -> None:
        """Make the dataset for ``test_type`` the current one.

        If no dataset was loaded for ``test_type``, an error is logged and
        the current selection is left unchanged.
        """
        if test_type not in self.datasets:
            logger.error(f"Dataset for test type '{test_type}' not loaded")
            return
        self.current_dataset = self.datasets[test_type]
        self.current_dataset_name = self._dataset_name_for(test_type)
        self.current_type = test_type
        logger.info(f"Switched to dataset: {self.current_dataset_name}")

    def get_random_example(self) -> Dict[str, Any]:
        """Get a random example from the current dataset.

        Returns:
            The item at a uniformly chosen index of the current dataset.

        Raises:
            ValueError: If no dataset has been selected, or if the selected
                dataset contains no examples.
        """
        # Compare against None explicitly: a zero-length dataset is falsy as
        # well and would otherwise be misreported as "No dataset loaded".
        if self.current_dataset is None:
            raise ValueError("No dataset loaded")
        size = len(self.current_dataset)
        if size == 0:
            raise ValueError(
                f"Dataset '{self.current_dataset_name}' is empty",
            )
        return self.current_dataset[random.randrange(size)]
def load_models() -> List[Dict[str, Any]]:
    """Load models from the models file.

    The file at ``MODELS_PATH`` is read line by line, each non-blank line
    being parsed as one JSON document. Lines that are not valid JSON are
    logged and skipped.

    Returns:
        The parsed entries in file order, or an empty list if the file
        could not be opened or read.
    """
    models: List[Dict[str, Any]] = []
    try:
        # Decode as UTF-8 explicitly rather than relying on the platform's
        # default encoding, which varies between systems.
        with open(MODELS_PATH, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:  # Skip empty lines
                    continue
                try:
                    models.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.warning(
                        f"Skipping invalid JSON in line: {line}.",
                    )
    except Exception as e:
        # Best-effort loader: any failure yields an empty list, discarding
        # entries parsed before the error.
        logger.error(f"Error loading models: {e}")
        return []
    return models
def save_model(model: Dict[str, Any]) -> None:
    """Append a model to the models file as a single JSON line.

    Args:
        model: JSON-serialisable mapping describing the model.

    Raises:
        TypeError: If ``model`` contains values ``json.dumps`` cannot encode.
        OSError: If the file at ``MODELS_PATH`` cannot be opened or written.
    """
    # Serialise before opening the file so that a non-serialisable model
    # does not create or touch the file.
    record = json.dumps(model)
    with open(MODELS_PATH, "a", encoding="utf-8") as f:
        f.write(record + "\n")
def get_random_example(test_type: str) -> Dict[str, str]:
    """Return one randomly picked example for ``test_type``.

    The dataset name is looked up in ``DATASET_MAPPING`` and loaded with
    ``load_dataset``; the "train" split is used when present, otherwise the
    first available split. The result always carries the keys ``text``,
    ``claim``, ``input``, ``output`` and ``assertion``; only the ones
    relevant to the test type are filled from the chosen example.

    Placeholder records are returned (and a message is logged) when there
    is no mapping for the test type, when the split is empty, or when any
    exception is raised along the way.
    """

    def _blank(text: str = "") -> Dict[str, str]:
        # Record in the internal format with everything but ``text`` empty.
        return {
            "text": text,
            "claim": "",
            "input": "",
            "output": "",
            "assertion": "",
        }

    # For each test type: (result key, dataset field) pairs to copy over.
    field_map = {
        "grounding": (("text", "doc"), ("claim", "claim")),
        "prompt_injections": (("text", "text"),),
        "safety": (("text", "text"),),
        "policy": (
            ("input", "input_text"),
            ("output", "output_text"),
            ("assertion", "assertion"),
        ),
    }

    try:
        dataset_name = DATASET_MAPPING.get(test_type)
        if not dataset_name:
            logger.warning(
                f"No dataset mapping found for test type: {test_type}",
            )
            return {
                "text": f"Sample text for {test_type}",
                "claim": f"Sample claim for {test_type}",
                "input": f"Sample input for {test_type}",
                "output": f"Sample output for {test_type}",
                "assertion": f"Sample assertion for {test_type}",
            }

        logger.info(f"Loading dataset: {dataset_name}")
        dataset = load_dataset(dataset_name)

        # Prefer the "train" split; otherwise fall back to the first one.
        split = "train" if "train" in dataset else list(dataset.keys())[0]
        examples = dataset[split]

        if len(examples) == 0:
            logger.warning(f"No examples found in dataset: {dataset_name}")
            return _blank(f"No examples found for {test_type}")

        example = random.choice(examples)

        # Start from an all-empty record and copy the mapped fields;
        # unknown test types leave the record untouched.
        result = _blank()
        for target, source_field in field_map.get(test_type, ()):
            result[target] = example.get(source_field, "")
        return result
    except Exception as e:
        logger.error(f"Error getting example for {test_type}: {e}")
        return _blank(f"Error getting example for {test_type}")
|