import os
from copy import deepcopy
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import datasets
import psutil
import torch
import transformers as tr
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm

from relik.common.log import get_console_logger, get_logger
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset, IterableBaseDataset
from relik.retriever.data.utils import HardNegativesManager

console_logger = get_console_logger()
logger = get_logger(__name__)


class SubsampleStrategyEnum(Enum):
    NONE = "none"
    RANDOM = "random"
    IN_ORDER = "in_order"


class GoldenRetrieverDataset:
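    """
    Base class for Golden Retriever retrieval datasets.

    It loads DPR-style samples (a question plus positive, negative and hard-negative
    passages) from JSON/JSONL files, tokenizes them, and provides helpers to batch,
    pad and collate them into `ModelInputs` for training.
    """
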
    def __init__(
        self,
        name: str,
        path: Union[str, os.PathLike, List[str], List[os.PathLike]] = None,
        data: Any = None,
        tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        # passages: Union[str, os.PathLike, List[str]] = None,
        passage_batch_size: int = 32,
        question_batch_size: int = 32,
        max_positives: int = -1,
        max_negatives: int = 0,
        max_hard_negatives: int = 0,
        max_question_length: int = 256,
        max_passage_length: int = 64,
        shuffle: bool = False,
        subsample_strategy: Optional[str] = SubsampleStrategyEnum.NONE,
        subsample_portion: float = 0.1,
        num_proc: Optional[int] = None,
        load_from_cache_file: bool = True,
        keep_in_memory: bool = False,
        prefetch: bool = True,
        load_fn_kwargs: Optional[Dict[str, Any]] = None,
        batch_fn_kwargs: Optional[Dict[str, Any]] = None,
        collate_fn_kwargs: Optional[Dict[str, Any]] = None,
    ):
        if path is None and data is None:
            raise ValueError("Either `path` or `data` must be provided")
        if tokenizer is None:
            raise ValueError("A tokenizer must be provided")

        # dataset parameters
        self.name = name
        # normalize `path` to a list of `Path` objects (it may be a single path or a list)
        if path is not None and not isinstance(path, (list, tuple)):
            path = [path]
        self.path = [Path(p) for p in path] if path is not None else None
        # self.project_folder = Path(__file__).parent.parent.parent
        self.data = data
        # hyper-parameters
        self.passage_batch_size = passage_batch_size
        self.question_batch_size = question_batch_size
        self.max_positives = max_positives
        self.max_negatives = max_negatives
        self.max_hard_negatives = max_hard_negatives
        self.max_question_length = max_question_length
        self.max_passage_length = max_passage_length
        self.shuffle = shuffle
        self.num_proc = num_proc
        self.load_from_cache_file = load_from_cache_file
        self.keep_in_memory = keep_in_memory
        self.prefetch = prefetch

        self.tokenizer = tokenizer
        if isinstance(self.tokenizer, str):
            self.tokenizer = tr.AutoTokenizer.from_pretrained(self.tokenizer)
        self.padding_ops = {
            "input_ids": partial(
                self.pad_sequence,
                value=self.tokenizer.pad_token_id,
            ),
            "attention_mask": partial(self.pad_sequence, value=0),
            "token_type_ids": partial(
                self.pad_sequence,
                value=self.tokenizer.pad_token_type_id,
            ),
        }
        # check if subsample strategy is valid
        if subsample_strategy is not None:
            # subsample_strategy can be a string or a SubsampleStrategyEnum
            if isinstance(subsample_strategy, str):
                try:
                    subsample_strategy = SubsampleStrategyEnum(subsample_strategy)
                except ValueError:
                    raise ValueError(
                        f"Subsample strategy {subsample_strategy} is not valid. "
                        f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
                    )
            if not isinstance(subsample_strategy, SubsampleStrategyEnum):
                raise ValueError(
                    f"Subsample strategy {subsample_strategy} is not valid. "
                    f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
                )
        self.subsample_strategy = subsample_strategy
        self.subsample_portion = subsample_portion
        # load the dataset
        if data is None:
            self.data: Dataset = self.load(
                self.path,
                tokenizer=self.tokenizer,
                load_from_cache_file=load_from_cache_file,
                load_fn_kwargs=load_fn_kwargs,
                num_proc=num_proc,
                shuffle=shuffle,
                keep_in_memory=keep_in_memory,
                max_positives=max_positives,
                max_negatives=max_negatives,
                max_hard_negatives=max_hard_negatives,
                max_question_length=max_question_length,
                max_passage_length=max_passage_length,
            )
        else:
            self.data: Dataset = data

        self.hn_manager: Optional[HardNegativesManager] = None

        # keep track of how many times the dataset has been iterated over
        self.number_of_complete_iterations = 0
    def __repr__(self) -> str:
        return f"GoldenRetrieverDataset({self.name=}, {self.path=})"

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        raise NotImplementedError

    def to_torch_dataset(self, *args, **kwargs) -> torch.utils.data.Dataset:
        raise NotImplementedError
    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        tokenizer: tr.PreTrainedTokenizer = None,
        load_fn_kwargs: Dict = None,
        load_from_cache_file: bool = True,
        num_proc: Optional[int] = None,
        shuffle: bool = False,
        keep_in_memory: bool = True,
        max_positives: int = -1,
        max_negatives: int = -1,
        max_hard_negatives: int = -1,
        max_passages: int = -1,
        max_question_length: int = 256,
        max_passage_length: int = 64,
        *args,
        **kwargs,
    ) -> Any:
        # if isinstance(paths, Sequence):
        #     paths = [self.project_folder / path for path in paths]
        # else:
        #     paths = [self.project_folder / paths]

        # read the data and put it in a placeholder list
        for path in paths:
            if not path.exists():
                raise ValueError(f"{path} does not exist")

        fn_kwargs = dict(
            tokenizer=tokenizer,
            max_positives=max_positives,
            max_negatives=max_negatives,
            max_hard_negatives=max_hard_negatives,
            max_passages=max_passages,
            max_question_length=max_question_length,
            max_passage_length=max_passage_length,
        )
        if load_fn_kwargs is not None:
            fn_kwargs.update(load_fn_kwargs)

        if num_proc is None:
            num_proc = psutil.cpu_count(logical=False)
        # The data is a list of dictionaries, each dictionary is a sample.
        # Each sample has the following keys:
        #   - "question": the question
        #   - "answers": a list of answers
        #   - "positive_ctxs": a list of positive passages
        #   - "negative_ctxs": a list of negative passages
        #   - "hard_negative_ctxs": a list of hard negative passages
        # use the huggingface datasets library to load the data; by default it loads
        # the data into a dict with the key being "train".
        logger.info(f"Loading data for dataset {self.name}")
        data = load_dataset(
            "json",
            data_files=[str(p) for p in paths],  # datasets needs str paths and not Path
            split="train",
            streaming=False,  # TODO maybe we can make streaming work
            keep_in_memory=keep_in_memory,
        )

        # add id if not present
        if isinstance(data, datasets.Dataset):
            data = data.add_column("sample_idx", range(len(data)))
        else:
            # `dict.update` returns None, so build a new dict instead of updating in place
            data = data.map(
                lambda x, idx: {**x, "sample_idx": idx}, with_indices=True
            )
        map_kwargs = dict(
            function=self.load_fn,
            fn_kwargs=fn_kwargs,
        )
        if isinstance(data, datasets.Dataset):
            map_kwargs.update(
                dict(
                    load_from_cache_file=load_from_cache_file,
                    keep_in_memory=keep_in_memory,
                    num_proc=num_proc,
                    desc="Loading data",
                )
            )
        # preprocess the data
        data = data.map(**map_kwargs)

        # shuffle the data (`shuffle` returns a new dataset, it does not shuffle in place)
        if shuffle:
            data = data.shuffle(seed=42)

        return data
    @staticmethod
    def create_batches(
        data: Dataset,
        batch_fn: Callable,
        batch_fn_kwargs: Optional[Dict[str, Any]] = None,
        prefetch: bool = True,
        *args,
        **kwargs,
    ) -> Union[Iterable, List]:
        if not prefetch:
            # if we are streaming, we don't need to create batches right now,
            # we will create them on the fly when we need them
            batched_data = (
                batch
                for batch in batch_fn(
                    data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {})
                )
            )
        else:
            batched_data = [
                batch
                for batch in tqdm(
                    batch_fn(
                        data, **(batch_fn_kwargs if batch_fn_kwargs is not None else {})
                    ),
                    desc="Creating batches",
                )
            ]
        return batched_data
    @staticmethod
    def collate_batches(
        batched_data: Union[Iterable, List],
        collate_fn: Callable,
        collate_fn_kwargs: Optional[Dict[str, Any]] = None,
        prefetch: bool = True,
        *args,
        **kwargs,
    ) -> Union[Iterable, List]:
        if not prefetch:
            collated_data = (
                collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {}))
                for batch in batched_data
            )
        else:
            collated_data = [
                collate_fn(batch, **(collate_fn_kwargs if collate_fn_kwargs else {}))
                for batch in tqdm(batched_data, desc="Collating batches")
            ]
        return collated_data
    @staticmethod
    def load_fn(sample: Dict, *args, **kwargs) -> Dict:
        raise NotImplementedError

    @staticmethod
    def batch_fn(data: Dataset, *args, **kwargs) -> Any:
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        raise NotImplementedError
    @staticmethod
    def pad_sequence(
        sequence: Union[List, torch.Tensor],
        length: int,
        value: Any = None,
        pad_to_left: bool = False,
    ) -> Union[List, torch.Tensor]:
        """
        Pad the input to the specified length with the given value.

        Args:
            sequence (:obj:`List`, :obj:`torch.Tensor`):
                Element to pad, it can be either a :obj:`List` or a :obj:`torch.Tensor`.
            length (:obj:`int`):
                Length to pad the sequence to.
            value (:obj:`Any`, optional):
                Value to use as padding.
            pad_to_left (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True`, pads to the left, right otherwise.

        Returns:
            :obj:`List`, :obj:`torch.Tensor`: The padded sequence.
        """
        padding = [value] * abs(length - len(sequence))
        if isinstance(sequence, torch.Tensor):
            if len(sequence.shape) > 1:
                raise ValueError(
                    f"Sequence tensor must be 1D. Current shape is `{len(sequence.shape)}`"
                )
            padding = torch.as_tensor(padding)
        if pad_to_left:
            if isinstance(sequence, torch.Tensor):
                return torch.cat((padding, sequence), -1)
            return padding + sequence
        if isinstance(sequence, torch.Tensor):
            return torch.cat((sequence, padding), -1)
        return sequence + padding
    def convert_to_batch(
        self, samples: Any, *args, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Convert the list of samples to a batch.

        Args:
            samples (:obj:`List`):
                List of samples to convert to a batch.

        Returns:
            :obj:`Dict[str, torch.Tensor]`: The batch.
        """
        # invert questions from list of dict to dict of list
        samples = {k: [d[k] for d in samples] for k in samples[0]}
        # get max length of questions
        max_len = max(len(x) for x in samples["input_ids"])
        # pad the questions
        for key in samples:
            if key in self.padding_ops:
                samples[key] = torch.as_tensor(
                    [self.padding_ops[key](b, max_len) for b in samples[key]]
                )
        return samples

    def shuffle_data(self, seed: int = 42):
        self.data = self.data.shuffle(seed=seed)


class InBatchNegativesDataset(GoldenRetrieverDataset):
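    """
    Dataset for training with in-batch negatives: every passage in a batch acts as a
    candidate for every question in that batch, and the label matrix marks which
    passages are gold (positive) for each question.
    """
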
    def __len__(self) -> int:
        if isinstance(self.data, datasets.Dataset):
            return len(self.data)

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        return self.data[index]
    def to_torch_dataset(self) -> torch.utils.data.Dataset:
        shuffle_this_time = self.shuffle

        if (
            self.subsample_strategy
            and self.subsample_strategy != SubsampleStrategyEnum.NONE
        ):
            number_of_samples = int(len(self.data) * self.subsample_portion)
            if self.subsample_strategy == SubsampleStrategyEnum.RANDOM:
                logger.info(
                    f"Random subsampling {number_of_samples} samples from {len(self.data)}"
                )
                data = (
                    deepcopy(self.data)
                    .shuffle(seed=42 + self.number_of_complete_iterations)
                    .select(range(0, number_of_samples))
                )
            elif self.subsample_strategy == SubsampleStrategyEnum.IN_ORDER:
                already_selected = (
                    number_of_samples * self.number_of_complete_iterations
                )
                logger.info(
                    f"Subsampling {number_of_samples} samples out of {len(self.data)}"
                )
                to_select = min(already_selected + number_of_samples, len(self.data))
                logger.info(
                    f"Portion of data selected: {already_selected} to {to_select}"
                )
                data = deepcopy(self.data).select(range(already_selected, to_select))

                # don't shuffle the data if we are subsampling and we have still not
                # completed one full iteration over the dataset
                if self.number_of_complete_iterations > 0:
                    shuffle_this_time = False

                if to_select >= len(self.data):
                    # reset the number of complete iterations:
                    # we have completed one full iteration over the dataset;
                    # the value is -1 because we want to start from 0 at the next iteration
                    self.number_of_complete_iterations = -1
            else:
                raise ValueError(
                    f"Subsample strategy `{self.subsample_strategy}` is not valid. "
                    f"Valid strategies are: {SubsampleStrategyEnum.__members__}"
                )
        else:
            data = self.data

        # do we need to shuffle the data?
        if self.shuffle and shuffle_this_time:
            logger.info("Shuffling the data")
            data = data.shuffle(seed=42 + self.number_of_complete_iterations)

        batch_fn_kwargs = {
            "passage_batch_size": self.passage_batch_size,
            "question_batch_size": self.question_batch_size,
            "hard_negatives_manager": self.hn_manager,
        }
        batched_data = self.create_batches(
            data,
            batch_fn=self.batch_fn,
            batch_fn_kwargs=batch_fn_kwargs,
            prefetch=self.prefetch,
        )
        batched_data = self.collate_batches(
            batched_data, self.collate_fn, prefetch=self.prefetch
        )

        # increment the number of complete iterations
        self.number_of_complete_iterations += 1

        if self.prefetch:
            return BaseDataset(name=self.name, data=batched_data)
        else:
            return IterableBaseDataset(name=self.name, data=batched_data)
    @staticmethod
    def load_fn(
        sample: Dict,
        tokenizer: tr.PreTrainedTokenizer,
        max_positives: int,
        max_negatives: int,
        max_hard_negatives: int,
        max_passages: int = -1,
        max_question_length: int = 256,
        max_passage_length: int = 128,
        *args,
        **kwargs,
    ) -> Dict:
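        """
        Preprocess a single raw sample: deduplicate and optionally truncate the positive,
        negative and hard-negative contexts, tokenize the question and all passages, and
        return the tokenized question together with the passage list and the gold passages.
        """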
        # remove duplicates and limit the number of passages
        positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]]))
        if max_positives != -1:
            positives = positives[:max_positives]

        negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]]))
        if max_negatives != -1:
            negatives = negatives[:max_negatives]

        hard_negatives = list(
            set([h["text"].strip() for h in sample["hard_negative_ctxs"]])
        )
        if max_hard_negatives != -1:
            hard_negatives = hard_negatives[:max_hard_negatives]

        question = tokenizer(
            sample["question"], max_length=max_question_length, truncation=True
        )

        passage = positives + negatives + hard_negatives
        if max_passages != -1:
            passage = passage[:max_passages]

        passage = tokenizer(passage, max_length=max_passage_length, truncation=True)
        # invert the passage data structure from a dict of lists to a list of dicts
        passage = [dict(zip(passage, t)) for t in zip(*passage.values())]

        output = dict(
            question=question,
            passage=passage,
            positives=positives,
            positive_pssgs=passage[: len(positives)],
        )
        return output
    @staticmethod
    def batch_fn(
        data: Dataset,
        passage_batch_size: int,
        question_batch_size: int,
        hard_negatives_manager: Optional[HardNegativesManager] = None,
        *args,
        **kwargs,
    ) -> Dict[str, List[Dict[str, Any]]]:
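        """
        Greedily accumulate samples until the batch holds roughly `passage_batch_size`
        unique passages (deduplicated by their `input_ids`), then yield it as a
        `ModelInputs` object, splitting it further when it contains more than
        `question_batch_size` questions.
        """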
        def split_batch(
            batch: Union[Dict[str, Any], ModelInputs], question_batch_size: int
        ) -> List[ModelInputs]:
            """
            Split a batch into multiple batches of size `question_batch_size` while keeping
            the same number of passages.
            """

            def split_fn(x):
                return [
                    x[i : i + question_batch_size]
                    for i in range(0, len(x), question_batch_size)
                ]

            # split the sample_idx
            sample_idx = split_fn(batch["sample_idx"])
            # split the questions
            questions = split_fn(batch["questions"])
            # split the positives
            positives = split_fn(batch["positives"])
            # split the positives_pssgs
            positives_pssgs = split_fn(batch["positives_pssgs"])

            # collect the new batches
            batches = []
            for i in range(len(questions)):
                batches.append(
                    ModelInputs(
                        dict(
                            sample_idx=sample_idx[i],
                            questions=questions[i],
                            passages=batch["passages"],
                            positives=positives[i],
                            positives_pssgs=positives_pssgs[i],
                        )
                    )
                )
            return batches
        batch = []
        passages_in_batch = {}

        for sample in data:
            if len(passages_in_batch) >= passage_batch_size:
                # create the batch dict
                batch_dict = ModelInputs(
                    dict(
                        sample_idx=[s["sample_idx"] for s in batch],
                        questions=[s["question"] for s in batch],
                        passages=list(passages_in_batch.values()),
                        positives_pssgs=[s["positive_pssgs"] for s in batch],
                        positives=[s["positives"] for s in batch],
                    )
                )
                # split the batch if needed
                if len(batch) > question_batch_size:
                    for splited_batch in split_batch(batch_dict, question_batch_size):
                        yield splited_batch
                else:
                    yield batch_dict

                # reset batch
                batch = []
                passages_in_batch = {}

            batch.append(sample)
            # count the passages in the batch and stop once we reach the limit;
            # passages are deduplicated by keying them on the tuple of their input_ids
            # (a dict key can't be a list)
            passages_in_batch.update(
                {tuple(passage["input_ids"]): passage for passage in sample["passage"]}
            )
            # add the mined hard negatives for this sample, if any are available
            if hard_negatives_manager is not None:
                if sample["sample_idx"] in hard_negatives_manager:
                    passages_in_batch.update(
                        {
                            tuple(passage["input_ids"]): passage
                            for passage in hard_negatives_manager.get(
                                sample["sample_idx"]
                            )
                        }
                    )

        # leftover samples that did not fill a full passage batch
        if len(batch) > 0:
            # create the batch dict
            batch_dict = ModelInputs(
                dict(
                    sample_idx=[s["sample_idx"] for s in batch],
                    questions=[s["question"] for s in batch],
                    passages=list(passages_in_batch.values()),
                    positives_pssgs=[s["positive_pssgs"] for s in batch],
                    positives=[s["positives"] for s in batch],
                )
            )
            # split the batch if needed
            if len(batch) > question_batch_size:
                for splited_batch in split_batch(batch_dict, question_batch_size):
                    yield splited_batch
            else:
                yield batch_dict
    def collate_fn(self, batch: Any, *args, **kwargs) -> Any:
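        """
        Pad and tensorize the questions and passages of a batch, and build the binary
        label matrix of shape (num_questions, num_passages) that marks, for each
        question, which passages in the batch are its positives.
        """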
        # convert questions and passages to a batch
        questions = self.convert_to_batch(batch.questions)
        passages = self.convert_to_batch(batch.passages)

        # build an index to map the position of the passage in the batch
        passage_index = {tuple(c["input_ids"]): i for i, c in enumerate(batch.passages)}

        # now we can create the labels
        labels = torch.zeros(
            questions["input_ids"].shape[0], passages["input_ids"].shape[0]
        )
        # iterate over the questions and set the labels to 1 if the passage is positive
        for sample_idx in range(len(questions["input_ids"])):
            for pssg in batch["positives_pssgs"][sample_idx]:
                # get the index of the positive passage
                index = passage_index[tuple(pssg["input_ids"])]
                # set the label to 1
                labels[sample_idx, index] = 1

        model_inputs = ModelInputs(
            {
                "questions": questions,
                "passages": passages,
                "labels": labels,
                "positives": batch["positives"],
                "sample_idx": batch["sample_idx"],
            }
        )
        return model_inputs


class AidaInBatchNegativesDataset(InBatchNegativesDataset):
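    """
    In-batch negatives dataset for AIDA-style entity linking data: when `use_topics`
    is enabled and a sample carries a `doc_topic` field, the document topic is encoded
    together with the question as a sentence pair.
    """
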
    def __init__(self, use_topics: bool = False, *args, **kwargs):
        if "load_fn_kwargs" not in kwargs:
            kwargs["load_fn_kwargs"] = {}
        kwargs["load_fn_kwargs"]["use_topics"] = use_topics
        super().__init__(*args, **kwargs)
    @staticmethod
    def load_fn(
        sample: Dict,
        tokenizer: tr.PreTrainedTokenizer,
        max_positives: int,
        max_negatives: int,
        max_hard_negatives: int,
        max_passages: int = -1,
        max_question_length: int = 256,
        max_passage_length: int = 128,
        use_topics: bool = False,
        *args,
        **kwargs,
    ) -> Dict:
        # remove duplicates and limit the number of passages
        positives = list(set([p["text"].strip() for p in sample["positive_ctxs"]]))
        if max_positives != -1:
            positives = positives[:max_positives]

        negatives = list(set([n["text"].strip() for n in sample["negative_ctxs"]]))
        if max_negatives != -1:
            negatives = negatives[:max_negatives]

        hard_negatives = list(
            set([h["text"].strip() for h in sample["hard_negative_ctxs"]])
        )
        if max_hard_negatives != -1:
            hard_negatives = hard_negatives[:max_hard_negatives]

        question = sample["question"]
        if "doc_topic" in sample and use_topics:
            question = tokenizer(
                question,
                sample["doc_topic"],
                max_length=max_question_length,
                truncation=True,
            )
        else:
            question = tokenizer(
                question, max_length=max_question_length, truncation=True
            )

        passage = positives + negatives + hard_negatives
        if max_passages != -1:
            passage = passage[:max_passages]

        passage = tokenizer(passage, max_length=max_passage_length, truncation=True)
        # invert the passage data structure from a dict of lists to a list of dicts
        passage = [dict(zip(passage, t)) for t in zip(*passage.values())]

        output = dict(
            question=question,
            passage=passage,
            positives=positives,
            positive_pssgs=passage[: len(positives)],
        )
        return output
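

# Minimal usage sketch (illustrative only): the dataset path and tokenizer name below
# are placeholders, not values shipped with this module.
if __name__ == "__main__":
    dataset = AidaInBatchNegativesDataset(
        name="aida_train",
        path="path/to/aida_train.jsonl",  # hypothetical path to a DPR-style JSONL file
        tokenizer="bert-base-cased",  # any Hugging Face tokenizer name
        question_batch_size=32,
        passage_batch_size=64,
        use_topics=True,
    )
    # materialize the pre-batched, pre-collated torch dataset used by the dataloader
    torch_dataset = dataset.to_torch_dataset()
    print(dataset)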