import logging
from typing import Iterable, Iterator, List, Optional

import hydra
import torch
from lightning.pytorch.utilities import move_data_to_device
from torch.utils.data import DataLoader
from tqdm import tqdm

from relik.reader.data.patches import merge_patches_predictions
from relik.reader.data.relik_reader_sample import (
    RelikReaderSample,
    load_relik_reader_samples,
)
from relik.reader.relik_reader_core import RelikReaderCoreModel
from relik.reader.utils.special_symbols import NME_SYMBOL

logger = logging.getLogger(__name__)


def convert_tokens_to_char_annotations(
    sample: RelikReaderSample, remove_nmes: bool = False
):
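    """Convert the token-level span predictions stored on ``sample`` into
    character-level annotations, using the sample's token-to-character
    offset mappings.

    Results are written in place to ``sample.predicted_window_labels_chars``
    and ``sample.probs_window_labels_chars``; if ``remove_nmes`` is True,
    spans whose predicted entity is ``NME_SYMBOL`` are dropped.
    """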
    char_annotations = set()
    for (
        predicted_entity,
        predicted_spans,
    ) in sample.predicted_window_labels.items():
        if predicted_entity == NME_SYMBOL and remove_nmes:
            continue

        for span_start, span_end in predicted_spans:
            span_start = sample.token2char_start[str(span_start)]
            span_end = sample.token2char_end[str(span_end)]
            char_annotations.add((span_start, span_end, predicted_entity))

    char_probs_annotations = dict()
    for (
        span_start,
        span_end,
    ), candidates_probs in sample.span_title_probabilities.items():
        span_start = sample.token2char_start[str(span_start)]
        span_end = sample.token2char_end[str(span_end)]
        char_probs_annotations[(span_start, span_end)] = {
            title for title, _ in candidates_probs
        }

    sample.predicted_window_labels_chars = char_annotations
    sample.probs_window_labels_chars = char_probs_annotations


class RelikReaderPredictor:
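    """Runs a :class:`RelikReaderCoreModel` over samples loaded from a file
    or passed as an iterable, returning them in their original order with
    character-level span annotations attached.
    """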
    def __init__(
        self,
        relik_reader_core: RelikReaderCoreModel,
        dataset_conf: Optional[dict] = None,
        predict_nmes: bool = False,
    ) -> None:
        self.relik_reader_core = relik_reader_core
        self.dataset_conf = dataset_conf
        self.predict_nmes = predict_nmes

        if self.dataset_conf is not None:
            # instantiate dataset
            self.dataset = hydra.utils.instantiate(
                dataset_conf,
                dataset_path=None,
                samples=None,
            )

    def predict(
        self,
        path: Optional[str],
        samples: Optional[Iterable[RelikReaderSample]],
        dataset_conf: Optional[dict],
        token_batch_size: int = 1024,
        progress_bar: bool = False,
        **kwargs,
    ) -> List[RelikReaderSample]:
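        """Annotate samples read from ``path`` or taken from ``samples``
        (exactly one of the two must be provided): per-window patch
        predictions are merged and token-level spans are converted to
        character-level annotations on each returned sample.
        """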
        annotated_samples = list(
            self._predict(path, samples, dataset_conf, token_batch_size, progress_bar)
        )
        for sample in annotated_samples:
            merge_patches_predictions(sample)
            convert_tokens_to_char_annotations(
                sample, remove_nmes=not self.predict_nmes
            )
        return annotated_samples

    def _predict(
        self,
        path: Optional[str],
        samples: Optional[Iterable[RelikReaderSample]],
        dataset_conf: dict,
        token_batch_size: int = 1024,
        progress_bar: bool = False,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
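        """Lazily yield predicted samples, preserving the input order.

        Each sample is tagged with its input position, batches are built by
        token count, and predictions are buffered until they can be re-emitted
        in order. The DataLoader uses ``batch_size=None`` because the dataset
        itself yields ready-made batches.
        """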
        assert (
            path is not None or samples is not None
        ), "Either predict on a path or on an iterable of samples"
        samples = load_relik_reader_samples(path) if samples is None else samples

        # setup infrastructure to re-yield in order
        def samples_it():
            for i, sample in enumerate(samples):
                assert sample._mixin_prediction_position is None
                sample._mixin_prediction_position = i
                yield sample

        next_prediction_position = 0
        position2predicted_sample = {}

        # instantiate dataset
        if getattr(self, "dataset", None) is not None:
            dataset = self.dataset
            dataset.samples = samples_it()
            dataset.tokens_per_batch = token_batch_size
        else:
            dataset = hydra.utils.instantiate(
                dataset_conf,
                dataset_path=None,
                samples=samples_it(),
                tokens_per_batch=token_batch_size,
            )

        # instantiate dataloader
        iterator = DataLoader(dataset, batch_size=None, num_workers=0, shuffle=False)
        if progress_bar:
            iterator = tqdm(iterator, desc="Predicting")

        model_device = next(self.relik_reader_core.parameters()).device

        with torch.inference_mode():
            for batch in iterator:
                # do batch predict
                with torch.autocast(
                    "cpu" if model_device == torch.device("cpu") else "cuda"
                ):
                    batch = move_data_to_device(batch, model_device)
                    batch_out = self.relik_reader_core.batch_predict(**batch)
                # update prediction position
                for sample in batch_out:
                    if sample._mixin_prediction_position >= next_prediction_position:
                        position2predicted_sample[
                            sample._mixin_prediction_position
                        ] = sample

                # yield
                while next_prediction_position in position2predicted_sample:
                    yield position2predicted_sample[next_prediction_position]
                    del position2predicted_sample[next_prediction_position]
                    next_prediction_position += 1

        if len(position2predicted_sample) > 0:
            logger.warning(
                "It seems samples have been discarded in your dataset. "
                "This means that you WON'T have a prediction for each input sample. "
                "Prediction order will also be partially disrupted"
            )

        for k, v in sorted(position2predicted_sample.items(), key=lambda x: x[0]):
            yield v

        if progress_bar:
            iterator.close()
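

# A minimal usage sketch, not part of the module: it assumes you already have
# a trained RelikReaderCoreModel and a hydra-instantiable dataset config from
# your training artifacts. The placeholders below (including the samples path)
# are hypothetical and must be replaced with real values.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    core_model: RelikReaderCoreModel = ...  # placeholder: load your trained reader
    dataset_conf: dict = ...  # placeholder: config with a `_target_` for the reader dataset

    predictor = RelikReaderPredictor(core_model, predict_nmes=False)
    predictions = predictor.predict(
        path="samples.jsonl",  # hypothetical path readable by load_relik_reader_samples
        samples=None,
        dataset_conf=dataset_conf,
        token_batch_size=1024,
        progress_bar=True,
    )
    for prediction in predictions:
        print(prediction.predicted_window_labels_chars)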