import collections
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Union

import torch
import transformers as tr
from tqdm import tqdm
from transformers import AutoConfig

from relik.common.log import get_console_logger, get_logger
from relik.reader.data.relik_reader_data_utils import batchify, flatten
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.pytorch_modules.hf.modeling_relik import (
    RelikReaderConfig,
    RelikReaderSpanModel,
)
from relik.reader.relik_reader_predictor import RelikReaderPredictor
from relik.reader.utils.save_load_utilities import load_model_and_conf
from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols

console_logger = get_console_logger()
logger = get_logger(__name__, level=logging.INFO)


class RelikReaderForSpanExtraction(torch.nn.Module):
    def __init__(
        self,
        transformer_model: str | tr.PreTrainedModel | None = None,
        additional_special_symbols: int = 0,
        num_layers: int | None = None,
        activation: str = "gelu",
        linears_hidden_size: int | None = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: str | torch.device | None = None,
        tokenizer: str | tr.PreTrainedTokenizer | None = None,
        **kwargs,
    ) -> None:
        super().__init__()

        if isinstance(transformer_model, str):
            config = AutoConfig.from_pretrained(
                transformer_model, trust_remote_code=True
            )
            if "relik-reader" in config.model_type:
                transformer_model = RelikReaderSpanModel.from_pretrained(
                    transformer_model, **kwargs
                )
            else:
                reader_config = RelikReaderConfig(
                    transformer_model=transformer_model,
                    additional_special_symbols=additional_special_symbols,
                    num_layers=num_layers,
                    activation=activation,
                    linears_hidden_size=linears_hidden_size,
                    use_last_k_layers=use_last_k_layers,
                    training=training,
                )
                transformer_model = RelikReaderSpanModel(reader_config)

        self.relik_reader_model = transformer_model
        self._tokenizer = tokenizer

        # move the model to the device
        self.to(device or torch.device("cpu"))
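
    # Usage sketch (hedged): constructing a reader from a checkpoint. The
    # model id below is a placeholder, not a verified checkpoint name:
    #
    #   reader = RelikReaderForSpanExtraction(
    #       "path/or/hub-id-of-a-relik-reader", device="cuda"
    #   )
    #
    # Passing a plain encoder name (e.g. "microsoft/deberta-v3-small")
    # instead builds a fresh, untrained RelikReaderSpanModel around it.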

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        special_symbols_mask_entities: torch.Tensor | None = None,
        start_labels: torch.Tensor | None = None,
        end_labels: torch.Tensor | None = None,
        disambiguation_labels: torch.Tensor | None = None,
        relation_labels: torch.Tensor | None = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        return self.relik_reader_model(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            start_labels,
            end_labels,
            disambiguation_labels,
            relation_labels,
            is_validation,
            is_prediction,
            *args,
            **kwargs,
        )

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        sample: List[RelikReaderSample] | None = None,
        top_k: int = 5,  # how many of the most probable entities to keep per span
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """
        Runs prediction on a single batch and attaches the results to the samples.

        Args:
            input_ids: The input ids of the batch.
            attention_mask: The attention mask of the batch.
            token_type_ids: The token type ids of the batch.
            prediction_mask: The prediction mask of the batch.
            special_symbols_mask: The special symbols mask of the batch.
            sample: The samples the batch was built from.
            top_k: How many of the most probable entities to keep per span
                (-1 keeps them all).
            *args: Additional positional arguments.
            **kwargs: Must contain `predictable_candidates` and `patch_offset`.

        Returns:
            The samples, enriched with the predictions for this patch.
        """
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()

        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]
        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # predicted candidate
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )

                # candidates probabilities
                classes_probabilities = edpr[start_token_index + 1]
                classes_probabilities_best_indices = classes_probabilities.argsort()[
                    ::-1
                ]
                titles_2_probs = []
                top_k = (
                    min(
                        top_k,
                        len(classes_probabilities_best_indices),
                    )
                    if top_k != -1
                    else len(classes_probabilities_best_indices)
                )
                for i in range(top_k):
                    titles_2_probs.append(
                        (
                            pred_cands[classes_probabilities_best_indices[i] - 1],
                            classes_probabilities[
                                classes_probabilities_best_indices[i]
                            ].item(),
                        )
                    )
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs

            if "patches" not in ts._d:
                ts._d["patches"] = dict()

            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]

            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities

            # additional info
            sample_patch["predictable_candidates"] = pred_cands

            yield ts
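
    # A minimal sketch (hedged) of what `batch_predict` stores on each sample,
    # assuming one predicted span [2, 4] resolved to "Rome" with top_k=2:
    #
    #   sample._d["patches"][patch_offset] == {
    #       "predicted_window_labels": {"Rome": [[2, 4]]},
    #       "span_title_probabilities": {(2, 4): [("Rome", 0.91), ("Roma", 0.05)]},
    #       "predictable_candidates": [NME_SYMBOL, "Rome", "Roma", ...],
    #   }
    #
    # The titles and probabilities here are illustrative, not real output.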

    def _build_input(self, text: List[str], candidates: List[str]) -> List[str]:
        candidates_symbols = get_special_symbols(len(candidates))
        candidates = [
            [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL]
            for cs, ct in zip(candidates_symbols, candidates)
        ]
        return (
            [self.tokenizer.cls_token]
            + text
            + [self.tokenizer.sep_token]
            + flatten(candidates)
            + [self.tokenizer.sep_token]
        )
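
    # The built input is a flat list of word-level tokens, e.g. (hedged,
    # assuming two candidates and special symbols spelled like "[E-0]"):
    #
    #   ["[CLS]", "Rome", "is", "nice", "[SEP]",
    #    "[E-0]", "Rome", "[E-1]", "Roma", "[SEP]"]
    #
    # The tokenizer is later called with is_split_into_words=True, so each
    # element is subword-tokenized in place while word boundaries survive.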

    @staticmethod
    def _compute_offsets(offsets_mapping):
        offsets_mapping = offsets_mapping.numpy()
        token2word = []
        word2token = {}
        count = 0
        for i, offset in enumerate(offsets_mapping):
            if offset[0] == 0:
                token2word.append(i - count)
                word2token[i - count] = [i]
            else:
                token2word.append(token2word[-1])
                word2token[token2word[-1]].append(i)
                count += 1
        return token2word, word2token
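
    # Worked example (hedged): for subword offsets
    #   [(0, 2), (0, 3), (3, 6), (0, 4)]
    # the second word is split into two subwords, so _compute_offsets returns
    #   token2word == [0, 1, 1, 2]
    #   word2token == {0: [0], 1: [1, 2], 2: [3]}
    # i.e. a subword whose character offset starts at 0 opens a new word, and
    # any other subword is appended to the word before it.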

    @staticmethod
    def _convert_tokens_to_word_annotations(sample: RelikReaderSample):
        triplets = []
        entities = []
        for entity in sample.predicted_entities:
            if sample.entity_candidates:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        sample.entity_candidates[entity[2]],
                    )
                )
            else:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        -1,
                    )
                )
        for predicted_triplet, predicted_triplet_probabilities in zip(
            sample.predicted_relations, sample.predicted_relations_probabilities
        ):
            subject, object_, relation = predicted_triplet
            subject = entities[subject]
            object_ = entities[object_]
            relation = sample.candidates[relation]
            triplets.append(
                {
                    "subject": {
                        "start": subject[0],
                        "end": subject[1],
                        "type": subject[2],
                        "name": " ".join(sample.tokens[subject[0] : subject[1]]),
                    },
                    "relation": {
                        "name": relation,
                        "probability": float(predicted_triplet_probabilities.round(2)),
                    },
                    "object": {
                        "start": object_[0],
                        "end": object_[1],
                        "type": object_[2],
                        "name": " ".join(sample.tokens[object_[0] : object_[1]]),
                    },
                }
            )
        sample.predicted_entities = entities
        sample.predicted_relations = triplets
        sample.predicted_relations_probabilities = None

    def read(
        self,
        text: List[str] | List[List[str]] | None = None,
        samples: List[RelikReaderSample] | None = None,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        special_symbols_mask_entities: torch.Tensor | None = None,
        candidates: List[List[str]] | None = None,
        max_length: int | None = 1024,
        max_batch_size: int | None = 64,
        token_batch_size: int | None = None,
        progress_bar: bool = False,
        *args,
        **kwargs,
    ) -> List[List[RelikReaderSample]]:
        """
        Reads the given text.

        Args:
            text: The text to read, as lists of tokens.
            samples: Pre-built samples to read, as an alternative to `text`.
            input_ids: The input ids of the text.
            attention_mask: The attention mask of the text.
            token_type_ids: The token type ids of the text.
            prediction_mask: The prediction mask of the text.
            special_symbols_mask: The special symbols mask of the text.
            special_symbols_mask_entities: The special symbols mask entities of the text.
            candidates: The candidates for each text.
            max_length: The maximum length of the text.
            max_batch_size: The maximum number of samples per batch.
            token_batch_size: The maximum number of tokens per batch.
            progress_bar: Whether to show a progress bar.

        Returns:
            The predicted labels for each sample.
        """
        if text is None and input_ids is None and samples is None:
            raise ValueError(
                "Either `text` or `input_ids` or `samples` must be provided."
            )
        if (input_ids is None and samples is None) and (
            text is None or candidates is None
        ):
            raise ValueError(
                "`text` and `candidates` must be provided to return the predictions when "
                "`input_ids` and `samples` are not provided."
            )
        if text is not None and samples is None:
            if len(text) != len(candidates):
                raise ValueError("`text` and `candidates` must have the same length.")
            if isinstance(text[0], str):  # wrap a single text into a batch of one
                text = [text]
                candidates = [candidates]
            samples = [
                RelikReaderSample(tokens=t, candidates=c)
                for t, c in zip(text, candidates)
            ]

        if samples is not None:
            # function that creates a batch from the 'current_batch' list
            def output_batch() -> Dict[str, Any]:
                assert (
                    len(
                        set(
                            [
                                len(elem["predictable_candidates"])
                                for elem in current_batch
                            ]
                        )
                    )
                    == 1
                ), " ".join(
                    map(
                        str,
                        [len(elem["predictable_candidates"]) for elem in current_batch],
                    )
                )

                batch_dict = dict()

                de_values_by_field = {
                    fn: [de[fn] for de in current_batch if fn in de]
                    for fn in self.fields_batcher
                }

                # in case you provide fields batchers but in the batch
                # there are no elements for that field
                de_values_by_field = {
                    fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
                }

                assert len(set([len(v) for v in de_values_by_field.values()])) == 1

                # todo: maybe we should warn the user about possible
                #  fields filtering due to "None" instances
                de_values_by_field = {
                    fn: fvs
                    for fn, fvs in de_values_by_field.items()
                    if all([fv is not None for fv in fvs])
                }

                for field_name, field_values in de_values_by_field.items():
                    field_batch = (
                        self.fields_batcher[field_name]([fv[0] for fv in field_values])
                        if self.fields_batcher[field_name] is not None
                        else field_values
                    )
                    batch_dict[field_name] = field_batch

                batch_dict = {
                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch_dict.items()
                }

                return batch_dict

            current_batch = []
            predictions = []
            current_cand_len = -1

            for sample in tqdm(samples, disable=not progress_bar):
                sample.candidates = [NME_SYMBOL] + sample.candidates
                inputs_text = self._build_input(sample.tokens, sample.candidates)
                model_inputs = self.tokenizer(
                    inputs_text,
                    is_split_into_words=True,
                    add_special_tokens=False,
                    padding=False,
                    truncation=True,
                    max_length=max_length or self.tokenizer.model_max_length,
                    return_offsets_mapping=True,
                    return_tensors="pt",
                )
                model_inputs["special_symbols_mask"] = (
                    model_inputs["input_ids"] > self.tokenizer.vocab_size
                )
                # prediction mask is 0 until the first special symbol
                model_inputs["token_type_ids"] = (
                    torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0
                ).long()
                # shift prediction_mask to the left
                model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll(
                    shifts=-1, dims=1
                )
                model_inputs["prediction_mask"][:, -1] = 1
                model_inputs["prediction_mask"][:, 0] = 1

                assert (
                    len(model_inputs["special_symbols_mask"])
                    == len(model_inputs["prediction_mask"])
                    == len(model_inputs["input_ids"])
                )

                model_inputs["sample"] = sample

                # compute cand_len using special_symbols_mask
                model_inputs["predictable_candidates"] = sample.candidates[
                    : model_inputs["special_symbols_mask"].sum().item()
                ]
                # cand_len = sum([id_ > self.tokenizer.vocab_size for id_ in model_inputs["input_ids"]])
                offsets = model_inputs.pop("offset_mapping")
                offsets = offsets[model_inputs["prediction_mask"] == 0]
                sample.token2word, sample.word2token = self._compute_offsets(offsets)
                future_max_len = max(
                    len(model_inputs["input_ids"]),
                    max([len(b["input_ids"]) for b in current_batch], default=0),
                )
                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

                # close the current batch when the candidate count changes, the
                # token budget would be exceeded, or the batch is full
                if len(current_batch) > 0 and (
                    (
                        len(model_inputs["predictable_candidates"]) != current_cand_len
                        and current_cand_len != -1
                    )
                    or (
                        isinstance(token_batch_size, int)
                        and future_tokens_per_batch >= token_batch_size
                    )
                    or len(current_batch) == max_batch_size
                ):
                    batch_inputs = output_batch()
                    current_batch = []
                    predictions.extend(list(self.batch_predict(**batch_inputs)))
                current_cand_len = len(model_inputs["predictable_candidates"])
                current_batch.append(model_inputs)

            if current_batch:
                batch_inputs = output_batch()
                predictions.extend(list(self.batch_predict(**batch_inputs)))
        else:
            predictions = list(
                self.batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    *args,
                    **kwargs,
                )
            )
        return predictions
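
    # Usage sketch (hedged): reading one pre-tokenized sentence with two
    # candidate titles. Tokens and titles are illustrative:
    #
    #   predictions = reader.read(
    #       text=["Rome", "is", "the", "capital", "of", "Italy"],
    #       candidates=["Rome", "Italy"],
    #       progress_bar=True,
    #   )
    #
    # A single text is wrapped into a batch of one, so `predictions` holds
    # one enriched RelikReaderSample per input text.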

    @property
    def device(self) -> torch.device:
        """
        The device of the model.
        """
        return next(self.parameters()).device

    @property
    def tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The tokenizer.
        """
        if self._tokenizer:
            return self._tokenizer

        self._tokenizer = tr.AutoTokenizer.from_pretrained(
            self.relik_reader_model.config.name_or_path
        )
        return self._tokenizer

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
        }
        if "roberta" in self.relik_reader_model.config.model_type:
            del fields_batchers["token_type_ids"]
        return fields_batchers
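
    # Note on the padding values above: `prediction_mask` is padded with 1 so
    # that padded positions are excluded from prediction (0 marks real text
    # tokens elsewhere in this file), while attention and type ids are padded
    # with 0. Fields mapped to None (e.g. `sample`) are passed through
    # unbatched, and RoBERTa-style models take no token type ids, hence that
    # field is dropped for them.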

    def save_pretrained(
        self,
        output_dir: str,
        model_name: str | None = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """
        Saves the model to the given path.

        Args:
            output_dir: The path to save the model to.
            model_name: The name of the model.
            push_to_hub: Whether to push the model to the hub.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        model_name = model_name or "relik-reader-for-span-extraction"

        logger.info(f"Saving reader to {output_dir / model_name}")

        # save the model
        self.relik_reader_model.register_for_auto_class()
        self.relik_reader_model.save_pretrained(
            output_dir / model_name, push_to_hub=push_to_hub, **kwargs
        )
        logger.info("Saving reader to disk done.")

        if self.tokenizer:
            self.tokenizer.save_pretrained(
                output_dir / model_name, push_to_hub=push_to_hub, **kwargs
            )
            logger.info("Saving tokenizer to disk done.")


class RelikReader:
    def __init__(self, model_path: str, predict_nmes: bool = False):
        model, model_conf = load_model_and_conf(model_path)
        model.training = False
        model.eval()

        val_dataset_conf = model_conf.data.val_dataset
        val_dataset_conf.special_symbols = get_special_symbols(
            model_conf.model.entities_per_forward
        )
        val_dataset_conf.transformer_model = model_conf.model.model.transformer_model

        self.predictor = RelikReaderPredictor(
            model,
            dataset_conf=model_conf.data.val_dataset,
            predict_nmes=predict_nmes,
        )
        self.model_path = model_path

    def link_entities(
        self,
        dataset_path_or_samples: str | Iterator[RelikReaderSample],
        token_batch_size: int = 2048,
        progress_bar: bool = False,
    ) -> List[RelikReaderSample]:
        data_input = (
            (dataset_path_or_samples, None)
            if isinstance(dataset_path_or_samples, str)
            else (None, dataset_path_or_samples)
        )
        return self.predictor.predict(
            *data_input,
            dataset_conf=None,
            token_batch_size=token_batch_size,
            progress_bar=progress_bar,
        )
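
    # `link_entities` dispatches on its first argument: a string is treated
    # as a dataset path, anything else as an iterator of RelikReaderSample
    # objects, e.g. (hedged sketch, tokens and candidates elided):
    #
    #   samples = [RelikReaderSample(tokens=..., candidates=...)]
    #   predictions = reader.link_entities(samples, token_batch_size=1024)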

    # def save_pretrained(self, path: Union[str, Path]):
    #     self.predictor.save(path)


def main():
    rr = RelikReader("riccorl/relik-reader-aida-deberta-small-old", predict_nmes=True)
    predictions = rr.link_entities(
        "/Users/ric/Documents/PhD/Projects/relik/data/reader/aida/testa.jsonl"
    )
    print(predictions)


if __name__ == "__main__":
    main()