import logging
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import numpy as np
import torch
from torch.utils.data import IterableDataset
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizer

from relik.reader.data.relik_reader_data_utils import (
    add_noise_to_value,
    batchify,
    chunks,
    flatten,
)
from relik.reader.data.relik_reader_sample import (
    RelikReaderSample,
    load_relik_reader_samples,
)
from relik.reader.utils.special_symbols import NME_SYMBOL

logger = logging.getLogger(__name__)

def preprocess_dataset(
    input_dataset: Iterable[dict],
    transformer_model: str,
    add_topic: bool,
) -> Iterable[dict]:
    tokenizer = AutoTokenizer.from_pretrained(transformer_model)
    for dataset_elem in tqdm(input_dataset, desc="Preprocessing input dataset"):
        if len(dataset_elem["tokens"]) == 0:
            print(
                f"Dataset element with doc id: {dataset_elem['doc_id']}",
                f"and offset {dataset_elem['offset']} does not contain any token",
                "Skipping it",
            )
            continue

        new_dataset_elem = dict(
            doc_id=dataset_elem["doc_id"],
            offset=dataset_elem["offset"],
        )

        tokenization_out = tokenizer(
            dataset_elem["tokens"],
            return_offsets_mapping=True,
            add_special_tokens=False,
        )

        window_tokens = tokenization_out.input_ids
        window_tokens = flatten(window_tokens)

        offsets_mapping = [
            [
                (
                    ss + dataset_elem["token2char_start"][str(i)],
                    se + dataset_elem["token2char_start"][str(i)],
                )
                for ss, se in tokenization_out.offset_mapping[i]
            ]
            for i in range(len(dataset_elem["tokens"]))
        ]
        offsets_mapping = flatten(offsets_mapping)

        assert len(offsets_mapping) == len(window_tokens)

        window_tokens = (
            [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]
        )

        topic_offset = 0
        if add_topic:
            topic_tokens = tokenizer(
                dataset_elem["doc_topic"], add_special_tokens=False
            ).input_ids
            topic_offset = len(topic_tokens)
            new_dataset_elem["topic_tokens"] = topic_offset
            window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]

        new_dataset_elem.update(
            dict(
                tokens=window_tokens,
                token2char_start={
                    str(i): s
                    for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
                },
                token2char_end={
                    str(i): e
                    for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
                },
                window_candidates=dataset_elem["window_candidates"],
                window_candidates_scores=dataset_elem.get("window_candidates_scores"),
            )
        )

        if "window_labels" in dataset_elem:
            window_labels = [
                (s, e, l.replace("_", " ")) for s, e, l in dataset_elem["window_labels"]
            ]

            new_dataset_elem["window_labels"] = window_labels

            if not all(
                [
                    s in new_dataset_elem["token2char_start"].values()
                    for s, _, _ in new_dataset_elem["window_labels"]
                ]
            ):
                print(
                    "Mismatching token start char mapping with labels",
                    new_dataset_elem["token2char_start"],
                    new_dataset_elem["window_labels"],
                    dataset_elem["tokens"],
                )
                continue

            if not all(
                [
                    e in new_dataset_elem["token2char_end"].values()
                    for _, e, _ in new_dataset_elem["window_labels"]
                ]
            ):
                print(
                    "Mismatching token end char mapping with labels",
                    new_dataset_elem["token2char_end"],
                    new_dataset_elem["window_labels"],
                    dataset_elem["tokens"],
                )
                continue

        yield new_dataset_elem
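
# Example (illustrative only): `preprocess_dataset` is a lazy generator, so it has to
# be consumed for the preprocessing to happen. The file name and model below are
# assumptions for the sake of the example, not part of this module.
#
#   import json
#
#   with open("windows.jsonl") as f:  # hypothetical retriever output, one window per line
#       raw_windows = (json.loads(line) for line in f)
#       for window in preprocess_dataset(raw_windows, "bert-base-cased", add_topic=False):
#           ...  # e.g. write the preprocessed window back to disk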

def preprocess_sample(
    relik_sample: RelikReaderSample,
    tokenizer,
    lowercase_policy: float,
    add_topic: bool = False,
) -> None:
    if len(relik_sample.tokens) == 0:
        return

    if lowercase_policy > 0:
        lc_tokens = np.random.uniform(0, 1, len(relik_sample.tokens)) < lowercase_policy
        relik_sample.tokens = [
            t.lower() if lc else t for t, lc in zip(relik_sample.tokens, lc_tokens)
        ]

    tokenization_out = tokenizer(
        relik_sample.tokens,
        return_offsets_mapping=True,
        add_special_tokens=False,
    )

    window_tokens = tokenization_out.input_ids
    window_tokens = flatten(window_tokens)

    offsets_mapping = [
        [
            (
                ss + relik_sample.token2char_start[str(i)],
                se + relik_sample.token2char_start[str(i)],
            )
            for ss, se in tokenization_out.offset_mapping[i]
        ]
        for i in range(len(relik_sample.tokens))
    ]
    offsets_mapping = flatten(offsets_mapping)

    assert len(offsets_mapping) == len(window_tokens)

    window_tokens = [tokenizer.cls_token_id] + window_tokens + [tokenizer.sep_token_id]

    topic_offset = 0
    if add_topic:
        topic_tokens = tokenizer(
            relik_sample.doc_topic, add_special_tokens=False
        ).input_ids
        topic_offset = len(topic_tokens)
        relik_sample.topic_tokens = topic_offset
        window_tokens = window_tokens[:1] + topic_tokens + window_tokens[1:]

    relik_sample._d.update(
        dict(
            tokens=window_tokens,
            token2char_start={
                str(i): s
                for i, (s, _) in enumerate(offsets_mapping, start=topic_offset)
            },
            token2char_end={
                str(i): e
                for i, (_, e) in enumerate(offsets_mapping, start=topic_offset)
            },
        )
    )

    if "window_labels" in relik_sample._d:
        relik_sample.window_labels = [
            (s, e, l.replace("_", " ")) for s, e, l in relik_sample.window_labels
        ]
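
# Note: unlike `preprocess_dataset`, which yields new dictionaries, `preprocess_sample`
# mutates the RelikReaderSample in place (it rewrites `tokens`, `token2char_start`,
# `token2char_end` and, when present, `window_labels`) and returns None.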

class TokenizationOutput(NamedTuple):
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    token_type_ids: torch.Tensor
    prediction_mask: torch.Tensor
    special_symbols_mask: torch.Tensor
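
# All TokenizationOutput fields are 1D tensors aligned with `input_ids`:
# `prediction_mask` is 1 on positions the model must NOT predict (CLS/SEP, topic
# tokens, candidate tokens) and 0 on the predictable sentence subwords, while
# `special_symbols_mask` is True on the candidate marker symbols (and on position 0).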

class RelikDataset(IterableDataset):
    def __init__(
        self,
        dataset_path: Optional[str],
        materialize_samples: bool,
        transformer_model: Union[str, PreTrainedTokenizer],
        special_symbols: List[str],
        shuffle_candidates: Optional[Union[bool, float]] = False,
        for_inference: bool = False,
        noise_param: float = 0.1,
        sorting_fields: Optional[str] = None,
        tokens_per_batch: int = 2048,
        batch_size: Optional[int] = None,
        max_batch_size: int = 128,
        section_size: int = 50_000,
        prebatch: bool = True,
        random_drop_gold_candidates: float = 0.0,
        use_nme: bool = True,
        max_subwords_per_candidate: int = 22,
        mask_by_instances: bool = False,
        min_length: int = 5,
        max_length: int = 2048,
        model_max_length: int = 1000,
        split_on_cand_overload: bool = True,
        skip_empty_training_samples: bool = False,
        drop_last: bool = False,
        samples: Optional[Iterator[RelikReaderSample]] = None,
        lowercase_policy: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dataset_path = dataset_path
        self.materialize_samples = materialize_samples
        self.samples: Optional[List[RelikReaderSample]] = None
        if self.materialize_samples:
            self.samples = list()

        if isinstance(transformer_model, str):
            self.tokenizer = self._build_tokenizer(transformer_model, special_symbols)
        else:
            self.tokenizer = transformer_model
        self.special_symbols = special_symbols
        self.shuffle_candidates = shuffle_candidates
        self.for_inference = for_inference
        self.noise_param = noise_param
        self.batching_fields = ["input_ids"]
        self.sorting_fields = (
            sorting_fields if sorting_fields is not None else self.batching_fields
        )

        self.tokens_per_batch = tokens_per_batch
        self.batch_size = batch_size
        self.max_batch_size = max_batch_size
        self.section_size = section_size
        self.prebatch = prebatch

        self.random_drop_gold_candidates = random_drop_gold_candidates
        self.use_nme = use_nme
        self.max_subwords_per_candidate = max_subwords_per_candidate
        self.mask_by_instances = mask_by_instances
        self.min_length = min_length
        self.max_length = max_length
        self.model_max_length = (
            model_max_length
            if model_max_length < self.tokenizer.model_max_length
            else self.tokenizer.model_max_length
        )

        # backward-compatibility workaround
        self.transformer_model = (
            transformer_model
            if isinstance(transformer_model, str)
            else transformer_model.name_or_path
        )

        self.split_on_cand_overload = split_on_cand_overload
        self.skip_empty_training_samples = skip_empty_training_samples
        self.drop_last = drop_last
        self.lowercase_policy = lowercase_policy
        self.samples = samples

    def _build_tokenizer(self, transformer_model: str, special_symbols: List[str]):
        return AutoTokenizer.from_pretrained(
            transformer_model,
            additional_special_tokens=[ss for ss in special_symbols],
            add_prefix_space=True,
        )

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "start_labels": lambda x: batchify(x, padding_value=-100),
            "end_labels": lambda x: batchify(x, padding_value=-100),
            "predictable_candidates_symbols": None,
            "predictable_candidates": None,
            "patch_offset": None,
            "optimus_labels": None,
        }
        if "roberta" in self.transformer_model:
            del fields_batchers["token_type_ids"]
        return fields_batchers

    def _build_input_ids(
        self, sentence_input_ids: List[int], candidates_input_ids: List[List[int]]
    ) -> List[int]:
        return (
            [self.tokenizer.cls_token_id]
            + sentence_input_ids
            + [self.tokenizer.sep_token_id]
            + flatten(candidates_input_ids)
            + [self.tokenizer.sep_token_id]
        )
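
    # The resulting sequence layout is:
    #   [CLS] <sentence subwords> [SEP] <candidate subwords ...> [SEP]
    # where every candidate is rendered as "<marker symbol> <candidate title>"
    # (see `produce_sample_bag`) before being flattened into the suffix.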

    def _get_special_symbols_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        special_symbols_mask = input_ids >= (
            len(self.tokenizer) - len(self.special_symbols)
        )
        special_symbols_mask[0] = True
        return special_symbols_mask

    def _build_tokenizer_essentials(
        self, input_ids, original_sequence, sample
    ) -> TokenizationOutput:
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

        total_sequence_len = len(input_ids)
        predictable_sentence_len = len(original_sequence)

        # token type ids
        token_type_ids = torch.cat(
            [
                input_ids.new_zeros(
                    predictable_sentence_len + 2
                ),  # original sentence bpes + CLS and SEP
                input_ids.new_ones(total_sequence_len - predictable_sentence_len - 2),
            ]
        )

        # prediction mask -> 1 on tokens that cannot be predicted, 0 on predictable ones
        prediction_mask = torch.tensor(
            [1]
            + ([0] * predictable_sentence_len)
            + ([1] * (total_sequence_len - predictable_sentence_len - 1))
        )

        # add topic tokens to the prediction mask so that they cannot be predicted
        # or optimized during training
        topic_tokens = getattr(sample, "topic_tokens", None)
        if topic_tokens is not None:
            prediction_mask[1 : 1 + topic_tokens] = 1

        # If mask_by_instances is active, the prediction mask is applied to everything
        # that is not marked as an instance in the training set.
        if self.mask_by_instances:
            char_start2token = {
                cs: int(tok) for tok, cs in sample.token2char_start.items()
            }
            char_end2token = {ce: int(tok) for tok, ce in sample.token2char_end.items()}
            instances_mask = torch.ones_like(prediction_mask)
            for _, span_info in sample.instance_id2span_data.items():
                span_info = span_info[0]
                token_start = char_start2token[span_info[0]] + 1  # +1 for the CLS
                token_end = char_end2token[span_info[1]] + 1  # +1 for the CLS
                instances_mask[token_start : token_end + 1] = 0

            prediction_mask += instances_mask
            prediction_mask[prediction_mask > 1] = 1

        assert len(prediction_mask) == len(input_ids)

        # special symbols mask
        special_symbols_mask = self._get_special_symbols_mask(input_ids)

        return TokenizationOutput(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

    def _build_labels(
        self,
        sample,
        tokenization_output: TokenizationOutput,
        predictable_candidates: List[str],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        start_labels = [0] * len(tokenization_output.input_ids)
        end_labels = [0] * len(tokenization_output.input_ids)

        char_start2token = {v: int(k) for k, v in sample.token2char_start.items()}
        char_end2token = {v: int(k) for k, v in sample.token2char_end.items()}
        for cs, ce, gold_candidate_title in sample.window_labels:
            if gold_candidate_title not in predictable_candidates:
                if self.use_nme:
                    gold_candidate_title = NME_SYMBOL
                else:
                    continue
            # +1 is to account for the CLS token
            start_bpe = char_start2token[cs] + 1
            end_bpe = char_end2token[ce] + 1
            class_index = predictable_candidates.index(gold_candidate_title)
            if (
                start_labels[start_bpe] == 0 and end_labels[end_bpe] == 0
            ):  # skip entities whose start or end subword is already taken by another entity
                start_labels[start_bpe] = class_index + 1  # +1 for the NONE class
                end_labels[end_bpe] = class_index + 1  # +1 for the NONE class
            else:
                print(
                    "Found entity with the same last subword, it will not be included."
                )
                print(
                    cs,
                    ce,
                    gold_candidate_title,
                    start_labels,
                    end_labels,
                    sample.doc_id,
                )

        ignored_labels_indices = tokenization_output.prediction_mask == 1

        start_labels = torch.tensor(start_labels, dtype=torch.long)
        start_labels[ignored_labels_indices] = -100

        end_labels = torch.tensor(end_labels, dtype=torch.long)
        end_labels[ignored_labels_indices] = -100

        return start_labels, end_labels
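
    # Label encoding: 0 means "no entity boundary at this subword", i + 1 points to
    # the i-th entry of `predictable_candidates` (so, when `use_nme` is True, 1 is
    # the NME class), and -100 marks positions covered by `prediction_mask`, i.e. the
    # conventional ignore index for the loss.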

    def produce_sample_bag(
        self, sample, predictable_candidates: List[str], candidates_starting_offset: int
    ) -> Optional[Tuple[dict, list, int]]:
        # input sentence tokenization
        input_subwords = sample.tokens[1:-1]  # removing special tokens
        candidates_symbols = self.special_symbols[candidates_starting_offset:]

        predictable_candidates = list(predictable_candidates)
        original_predictable_candidates = list(predictable_candidates)

        # add NME as a possible candidate
        if self.use_nme:
            predictable_candidates.insert(0, NME_SYMBOL)

        # candidates encoding
        candidates_symbols = candidates_symbols[: len(predictable_candidates)]
        candidates_encoding_result = self.tokenizer.batch_encode_plus(
            [
                "{} {}".format(cs, ct) if ct != NME_SYMBOL else NME_SYMBOL
                for cs, ct in zip(candidates_symbols, predictable_candidates)
            ],
            add_special_tokens=False,
        ).input_ids

        if (
            self.max_subwords_per_candidate is not None
            and self.max_subwords_per_candidate > 0
        ):
            candidates_encoding_result = [
                cer[: self.max_subwords_per_candidate]
                for cer in candidates_encoding_result
            ]

        # drop candidates if the number of input tokens is too long for the model
        if (
            sum(map(len, candidates_encoding_result))
            + len(input_subwords)
            + 20  # + 20 special tokens
            > self.model_max_length
        ):
            acceptable_tokens_from_candidates = (
                self.model_max_length - 20 - len(input_subwords)
            )
            i = 0
            cum_len = 0
            while (
                cum_len + len(candidates_encoding_result[i])
                < acceptable_tokens_from_candidates
            ):
                cum_len += len(candidates_encoding_result[i])
                i += 1

            candidates_encoding_result = candidates_encoding_result[:i]
            candidates_symbols = candidates_symbols[:i]
            predictable_candidates = predictable_candidates[:i]

        # final input_ids build
        input_ids = self._build_input_ids(
            sentence_input_ids=input_subwords,
            candidates_input_ids=candidates_encoding_result,
        )

        # complete input building (e.g. attention / prediction mask)
        tokenization_output = self._build_tokenizer_essentials(
            input_ids, input_subwords, sample
        )

        output_dict = {
            "input_ids": tokenization_output.input_ids,
            "attention_mask": tokenization_output.attention_mask,
            "token_type_ids": tokenization_output.token_type_ids,
            "prediction_mask": tokenization_output.prediction_mask,
            "special_symbols_mask": tokenization_output.special_symbols_mask,
            "sample": sample,
            "predictable_candidates_symbols": candidates_symbols,
            "predictable_candidates": predictable_candidates,
        }

        # labels creation
        if sample.window_labels is not None:
            start_labels, end_labels = self._build_labels(
                sample,
                tokenization_output,
                predictable_candidates,
            )
            output_dict.update(start_labels=start_labels, end_labels=end_labels)

        if (
            "roberta" in self.transformer_model
            or "longformer" in self.transformer_model
        ):
            del output_dict["token_type_ids"]

        predictable_candidates_set = set(predictable_candidates)
        remaining_candidates = [
            candidate
            for candidate in original_predictable_candidates
            if candidate not in predictable_candidates_set
        ]
        total_used_candidates = (
            candidates_starting_offset
            + len(predictable_candidates)
            - (1 if self.use_nme else 0)
        )

        if self.use_nme:
            assert predictable_candidates[0] == NME_SYMBOL

        return output_dict, remaining_candidates, total_used_candidates
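
    # When the candidates do not fit within `model_max_length`, the cut-off ones come
    # back as `remaining_candidates`; `dataset_iterator_func` keeps calling
    # `produce_sample_bag` on the same sample, producing one "patch" per call, until
    # every candidate has been consumed (unless `split_on_cand_overload` is False).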

    def __iter__(self):
        dataset_iterator = self.dataset_iterator_func()

        current_dataset_elements = []

        i = None
        for i, dataset_elem in enumerate(dataset_iterator, start=1):
            if (
                self.section_size is not None
                and len(current_dataset_elements) == self.section_size
            ):
                for batch in self.materialize_batches(current_dataset_elements):
                    yield batch
                current_dataset_elements = []

            current_dataset_elements.append(dataset_elem)

            if i % 50_000 == 0:
                logger.info(f"Processed {i} elements")

        if len(current_dataset_elements) != 0:
            for batch in self.materialize_batches(current_dataset_elements):
                yield batch

        if i is not None:
            logger.info(f"Dataset finished: {i} elements processed")
        else:
            logger.warning("Dataset empty")

    def dataset_iterator_func(self):
        skipped_instances = 0
        data_samples = (
            load_relik_reader_samples(self.dataset_path)
            if self.samples is None
            else self.samples
        )
        for sample in data_samples:
            preprocess_sample(
                sample, self.tokenizer, lowercase_policy=self.lowercase_policy
            )
            current_patch = 0
            sample_bag, used_candidates = None, None
            remaining_candidates = list(sample.window_candidates)

            if not self.for_inference:
                # randomly drop gold candidates at training time
                if (
                    self.random_drop_gold_candidates > 0.0
                    and np.random.uniform() < self.random_drop_gold_candidates
                    and len(set(ct for _, _, ct in sample.window_labels)) > 1
                ):
                    # select the candidates to drop
                    np.random.shuffle(sample.window_labels)
                    n_dropped_candidates = np.random.randint(
                        0, len(sample.window_labels) - 1
                    )
                    dropped_candidates = [
                        label_elem[-1]
                        for label_elem in sample.window_labels[:n_dropped_candidates]
                    ]
                    dropped_candidates = set(dropped_candidates)

                    # NMEs must never be dropped
                    if NME_SYMBOL in dropped_candidates:
                        dropped_candidates.remove(NME_SYMBOL)

                    # sample update
                    sample.window_labels = [
                        (s, e, _l)
                        if _l not in dropped_candidates
                        else (s, e, NME_SYMBOL)
                        for s, e, _l in sample.window_labels
                    ]
                    remaining_candidates = [
                        wc
                        for wc in remaining_candidates
                        if wc not in dropped_candidates
                    ]

                # shuffle candidates
                if (
                    isinstance(self.shuffle_candidates, bool)
                    and self.shuffle_candidates
                ) or (
                    isinstance(self.shuffle_candidates, float)
                    and np.random.uniform() < self.shuffle_candidates
                ):
                    np.random.shuffle(remaining_candidates)

            while len(remaining_candidates) != 0:
                sample_bag = self.produce_sample_bag(
                    sample,
                    predictable_candidates=remaining_candidates,
                    candidates_starting_offset=used_candidates
                    if used_candidates is not None
                    else 0,
                )
                if sample_bag is not None:
                    sample_bag, remaining_candidates, used_candidates = sample_bag
                    if (
                        self.for_inference
                        or not self.skip_empty_training_samples
                        or (
                            (
                                sample_bag.get("start_labels") is not None
                                and torch.any(sample_bag["start_labels"] > 1).item()
                            )
                            or (
                                sample_bag.get("optimus_labels") is not None
                                and len(sample_bag["optimus_labels"]) > 0
                            )
                        )
                    ):
                        sample_bag["patch_offset"] = current_patch
                        current_patch += 1
                        yield sample_bag
                    else:
                        skipped_instances += 1
                        if skipped_instances % 1000 == 0 and skipped_instances != 0:
                            logger.info(
                                f"Skipped {skipped_instances} instances since they did not have any gold labels..."
                            )

                # if splitting on candidate overload is disabled, just use the first
                # bag of candidates that fits and move on to the next sample
                if not self.split_on_cand_overload:
                    break

    def preshuffle_elements(self, dataset_elements: List):
        # This shuffling makes the whole operation non-deterministic: the sorting
        # function alone would be deterministic given a collection and its order,
        # and the aim is not to build the same batches every time.
        if not self.for_inference:
            dataset_elements = np.random.permutation(dataset_elements)

        sorting_fn = (
            lambda elem: add_noise_to_value(
                sum(len(elem[k]) for k in self.sorting_fields),
                noise_param=self.noise_param,
            )
            if not self.for_inference
            else sum(len(elem[k]) for k in self.sorting_fields)
        )

        dataset_elements = sorted(dataset_elements, key=sorting_fn)

        if self.for_inference:
            return dataset_elements

        ds = list(chunks(dataset_elements, 64))
        np.random.shuffle(ds)
        return flatten(ds)

    def materialize_batches(
        self, dataset_elements: List[Dict[str, Any]]
    ) -> Generator[Dict[str, Any], None, None]:
        if self.prebatch:
            dataset_elements = self.preshuffle_elements(dataset_elements)

        current_batch = []

        # function that creates a batch from the 'current_batch' list
        def output_batch() -> Dict[str, Any]:
            assert (
                len(
                    set([len(elem["predictable_candidates"]) for elem in current_batch])
                )
                == 1
            ), " ".join(
                map(
                    str, [len(elem["predictable_candidates"]) for elem in current_batch]
                )
            )

            batch_dict = dict()

            de_values_by_field = {
                fn: [de[fn] for de in current_batch if fn in de]
                for fn in self.fields_batcher
            }

            # in case you provide field batchers but the batch
            # has no elements for that field
            de_values_by_field = {
                fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
            }

            assert len(set([len(v) for v in de_values_by_field.values()]))

            # todo: maybe we should warn the user about fields that get
            #  filtered out because they contain "None" instances
            de_values_by_field = {
                fn: fvs
                for fn, fvs in de_values_by_field.items()
                if all([fv is not None for fv in fvs])
            }

            for field_name, field_values in de_values_by_field.items():
                field_batch = (
                    self.fields_batcher[field_name](field_values)
                    if self.fields_batcher[field_name] is not None
                    else field_values
                )

                batch_dict[field_name] = field_batch

            return batch_dict

        max_len_discards, min_len_discards = 0, 0

        should_token_batch = self.batch_size is None

        curr_pred_elements = -1
        for de in dataset_elements:
            if (
                should_token_batch
                and self.max_batch_size != -1
                and len(current_batch) == self.max_batch_size
            ) or (not should_token_batch and len(current_batch) == self.batch_size):
                yield output_batch()
                current_batch = []
                curr_pred_elements = -1

            too_long_fields = [
                k
                for k in de
                if self.max_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) > self.max_length
            ]
            if len(too_long_fields) > 0:
                max_len_discards += 1
                continue

            too_short_fields = [
                k
                for k in de
                if self.min_length != -1
                and torch.is_tensor(de[k])
                and len(de[k]) < self.min_length
            ]
            if len(too_short_fields) > 0:
                min_len_discards += 1
                continue

            if should_token_batch:
                de_len = sum(len(de[k]) for k in self.batching_fields)

                future_max_len = max(
                    de_len,
                    max(
                        [
                            sum(len(bde[k]) for k in self.batching_fields)
                            for bde in current_batch
                        ],
                        default=0,
                    ),
                )

                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

                num_predictable_candidates = len(de["predictable_candidates"])

                if len(current_batch) > 0 and (
                    future_tokens_per_batch >= self.tokens_per_batch
                    or (
                        num_predictable_candidates != curr_pred_elements
                        and curr_pred_elements != -1
                    )
                ):
                    yield output_batch()
                    current_batch = []

            current_batch.append(de)
            curr_pred_elements = len(de["predictable_candidates"])

        if len(current_batch) != 0 and not self.drop_last:
            yield output_batch()

        if max_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"Inference mode is True but {max_len_discards} samples longer than max length were "
                    f"found. The {max_len_discards} samples will be DISCARDED. If you are doing some kind of "
                    f"evaluation, this can INVALIDATE results. This might happen if the max length was not set "
                    f"to -1 or if the sample length exceeds the maximum length supported by the current model."
                )
            else:
                logger.warning(
                    f"During iteration, {max_len_discards} elements were "
                    f"discarded since they were longer than max length {self.max_length}"
                )

        if min_len_discards > 0:
            if self.for_inference:
                logger.warning(
                    f"Inference mode is True but {min_len_discards} samples shorter than min length were "
                    f"found. The {min_len_discards} samples will be DISCARDED. If you are doing some kind of "
                    f"evaluation, this can INVALIDATE results. This might happen if the min length was not set "
                    f"to -1 or if the sample length is shorter than the minimum length supported by the "
                    f"current model."
                )
            else:
                logger.warning(
                    f"During iteration, {min_len_discards} elements were "
                    f"discarded since they were shorter than min length {self.min_length}"
                )
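
# Example (illustrative only; the path, model name and special symbols below are
# assumptions, not part of this module). RelikDataset is an IterableDataset that
# already yields fully collated batches, so it is usually wrapped in a DataLoader
# with batch_size=None to avoid a second collation:
#
#   from torch.utils.data import DataLoader
#
#   dataset = RelikDataset(
#       dataset_path="data/train_windows.jsonl",
#       materialize_samples=False,
#       transformer_model="microsoft/deberta-v3-base",
#       special_symbols=[f"[E-{i}]" for i in range(100)],
#       for_inference=False,
#   )
#   loader = DataLoader(dataset, batch_size=None, num_workers=0)
#   for batch in loader:
#       ...  # e.g. batch["input_ids"], batch["start_labels"], batch["end_labels"]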

def convert_tokens_to_char_annotations(
    sample: RelikReaderSample,
    remove_nmes: bool = True,
) -> RelikReaderSample:
    """
    Converts the token annotations to char annotations.

    Args:
        sample (:obj:`RelikReaderSample`):
            The sample to convert.
        remove_nmes (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to remove the NMEs from the annotations.

    Returns:
        :obj:`RelikReaderSample`: The converted sample.
    """
    char_annotations = set()
    for (
        predicted_entity,
        predicted_spans,
    ) in sample.predicted_window_labels.items():
        if predicted_entity == NME_SYMBOL and remove_nmes:
            continue

        for span_start, span_end in predicted_spans:
            span_start = sample.token2char_start[str(span_start)]
            span_end = sample.token2char_end[str(span_end)]

            char_annotations.add((span_start, span_end, predicted_entity))

    char_probs_annotations = dict()
    for (
        span_start,
        span_end,
    ), candidates_probs in sample.span_title_probabilities.items():
        span_start = sample.token2char_start[str(span_start)]
        span_end = sample.token2char_end[str(span_end)]
        char_probs_annotations[(span_start, span_end)] = {
            title for title, _ in candidates_probs
        }

    sample.predicted_window_labels_chars = char_annotations
    sample.probs_window_labels_chars = char_probs_annotations

    return sample

def merge_patches_predictions(sample) -> None:
    sample._d["predicted_window_labels"] = dict()
    predicted_window_labels = sample._d["predicted_window_labels"]

    sample._d["span_title_probabilities"] = dict()
    span_title_probabilities = sample._d["span_title_probabilities"]

    span2title = dict()
    for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
        # selecting span predictions
        for predicted_title, predicted_spans in patch_info[
            "predicted_window_labels"
        ].items():
            for pred_span in predicted_spans:
                pred_span = tuple(pred_span)
                curr_title = span2title.get(pred_span)
                if curr_title is None or curr_title == NME_SYMBOL:
                    span2title[pred_span] = predicted_title
                # else:
                #     print("Merging at patch level")

        # selecting span prediction probabilities
        for predicted_span, titles_probabilities in patch_info[
            "span_title_probabilities"
        ].items():
            if predicted_span not in span_title_probabilities:
                span_title_probabilities[predicted_span] = titles_probabilities

    for span, title in span2title.items():
        if title not in predicted_window_labels:
            predicted_window_labels[title] = list()
        predicted_window_labels[title].append(span)
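
# Example (illustrative only): a typical post-processing flow, assuming `sample.patches`
# has been filled with per-patch predictions by the prediction code (outside this module):
#
#   merge_patches_predictions(sample)             # aggregate patch-level predictions
#   sample = convert_tokens_to_char_annotations(sample, remove_nmes=True)
#   print(sample.predicted_window_labels_chars)   # {(char_start, char_end, title), ...}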