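"""Utilities for loading evaluation datasets and reading/writing model records.

Dataset names are derived from DEFAULT_DATASET_PREFIX and the configured
TEST_TYPES; model records are stored as one JSON object per line at MODELS_PATH.
"""
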
import json
import random
from typing import Any, Dict, List, Optional

from datasets import Dataset, load_dataset
from loguru import logger

from src.config import DATASET_MAPPING, DEFAULT_DATASET_PREFIX, MODELS_PATH, TEST_TYPES


class DatasetManager:
    """Manages the loading and retrieval of evaluation datasets."""

    def __init__(self):
        self.datasets: Dict[str, Dataset] = {}
        self.current_dataset: Optional[Dataset] = None
        self.current_dataset_name: str = ""
        self.current_type: str = TEST_TYPES[0]

    def load_datasets(self) -> List[str]:
        """Load all available datasets based on test types."""
        dataset_names = []

        for test_type in TEST_TYPES:
            # Build the dataset name outside the try block so the error log
            # below always has a well-defined name to report.
            test_type_kebab = test_type.replace(" ", "-")
            dataset_name = f"{DEFAULT_DATASET_PREFIX}-{test_type_kebab}"
            try:
                logger.info(f"Loading dataset: {dataset_name}")
                self.datasets[test_type] = load_dataset(
                    dataset_name,
                    split="train",
                )
                dataset_names.append(dataset_name)
            except Exception as e:
                logger.error(f"Failed to load dataset {dataset_name}: {e}")

        return dataset_names

    def switch_dataset(self, test_type: str) -> None:
        """Switch to a different dataset based on test type."""
        if test_type not in self.datasets:
            logger.error(f"Dataset for test type '{test_type}' not loaded")
            return

        self.current_dataset = self.datasets[test_type]
        test_type_kebab = test_type.replace(" ", "-")
        self.current_dataset_name = f"{DEFAULT_DATASET_PREFIX}-{test_type_kebab}"
        self.current_type = test_type
        logger.info(f"Switched to dataset: {self.current_dataset_name}")

    def get_random_example(self) -> Dict[str, Any]:
        """Get a random example from the current dataset."""
        if self.current_dataset is None or len(self.current_dataset) == 0:
            raise ValueError("No dataset loaded or dataset is empty")

        idx = random.randint(0, len(self.current_dataset) - 1)
        return self.current_dataset[idx]


def load_models() -> List[Dict[str, Any]]:
    """Load models from the models file."""
    try:
        with open(MODELS_PATH, "r") as f:
            models = []
            for line in f:
                line = line.strip()
                if line:  # Skip empty lines
                    try:
                        models.append(json.loads(line))
                    except json.JSONDecodeError:
                        logger.warning(
                            f"Skipping invalid JSON in line: {line}.",
                        )
        return models
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        return []


def save_model(model: Dict[str, Any]) -> None:
    """Save a model to the models.jsonl file."""
    with open(MODELS_PATH, "a") as f:
        f.write(json.dumps(model) + "\n")


def get_random_example(test_type: str) -> Dict[str, str]:
    """Get a random example from the dataset for the given test type."""
    try:
        dataset_name = DATASET_MAPPING.get(test_type)
        if not dataset_name:
            logger.warning(
                f"No dataset mapping found for test type: {test_type}",
            )
            return {
                "text": f"Sample text for {test_type}",
                "claim": f"Sample claim for {test_type}",
                "input": f"Sample input for {test_type}",
                "output": f"Sample output for {test_type}",
                "assertion": f"Sample assertion for {test_type}",
            }

        # Load the dataset
        logger.info(f"Loading dataset: {dataset_name}")
        dataset = load_dataset(dataset_name)

        # Get a random example from the dataset
        if "train" in dataset:
            examples = dataset["train"]
        else:
            # Use the first split available
            examples = dataset[list(dataset.keys())[0]]

        if len(examples) == 0:
            logger.warning(f"No examples found in dataset: {dataset_name}")
            return {
                "text": f"No examples found for {test_type}",
                "claim": "",
                "input": "",
                "output": "",
                "assertion": "",
            }

        # Get a random example
        example = random.choice(examples)

        # Map dataset fields to our internal format
        result = {
            "text": "",
            "claim": "",
            "input": "",
            "output": "",
            "assertion": "",
        }

        # Map fields based on test type
        if test_type == "grounding":
            result["text"] = example.get("doc", "")
            result["claim"] = example.get("claim", "")
        elif test_type in ["prompt_injections", "safety"]:
            result["text"] = example.get("text", "")
        elif test_type == "policy":
            result["input"] = example.get("input_text", "")
            result["output"] = example.get("output_text", "")
            result["assertion"] = example.get("assertion", "")

        return result

    except Exception as e:
        logger.error(f"Error getting example for {test_type}: {e}")
        return {
            "text": f"Error getting example for {test_type}",
            "claim": "",
            "input": "",
            "output": "",
            "assertion": "",
        }
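

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's public surface. It assumes
    # the datasets derived from DEFAULT_DATASET_PREFIX and TEST_TYPES are
    # reachable on the Hugging Face Hub and that MODELS_PATH points to a
    # readable JSONL file.
    manager = DatasetManager()
    loaded = manager.load_datasets()
    logger.info(f"Loaded datasets: {loaded}")

    if manager.datasets:
        first_type = next(iter(manager.datasets))
        manager.switch_dataset(first_type)
        logger.info(f"Random example: {manager.get_random_example()}")

    for model in load_models():
        logger.info(f"Known model: {model}")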