# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
from pytorch_lightning import LightningDataModule
import torch_geometric
# from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader
from torch_geometric.loader.dataloader import Collater
from data_provider.reaction_action_dataset import ActionDataset
from data_provider.synthesis_dataset import SynthesisDataset
from data_provider.caption_dataset import CaptionDataset
from data_provider.chebi_dataset import ChEBI_dataset
import re

# We split individual characters inside special tokens like [START_DNA].
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# Token added to implement a custom sequence tokenization. This token is added at the
# corpus-cleaning step and removed in pretokenization. The digits are added to increase
# the chance that they do not occur in the corpus. The digits are escaped so that the
# token does not appear literally in the source code in case we ever include it in the
# training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"

def _insert_split_marker(m: re.Match):
    """
    Applies the split marker based on a regex match of special tokens such as
    [START_DNA].

    Parameters
    ----------
    m : re.Match
        Match over a custom sequence, e.g. [START_SMILES]...[END_SMILES].

    Returns
    -------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"

def smiles_handler(text, mol_ph, is_gal=True):
    smiles_list = []
    for match in CUSTOM_SEQ_RE.finditer(text):
        smiles = match.group(3)
        smiles_list.append(smiles)
    if is_gal:
        text = CUSTOM_SEQ_RE.sub(r'\1\3\4%s' % mol_ph, text)
        text = escape_custom_split_sequence(text)
    else:
        text = CUSTOM_SEQ_RE.sub(r'\3%s' % mol_ph, text)
    return text, smiles_list
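
# Illustrative sketch (added for clarity, not part of the original module):
# expected behaviour of smiles_handler on a toy input, assuming a two-token
# molecule placeholder. With is_gal=False the special tokens are dropped and
# the placeholder is appended after each SMILES span:
#
#   >>> smiles_handler("[START_SMILES]CCO[END_SMILES]", "<mol><mol>", is_gal=False)
#   ('CCO<mol><mol>', ['CCO'])
#
# With is_gal=True the [START_SMILES]/[END_SMILES] tokens are kept and the
# result is additionally passed through escape_custom_split_sequence below.
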
def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    -------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
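
# Illustrative sketch (added for clarity, not part of the original module):
# SPLIT_MARKER evaluates to the literal "SPL1T-TH1S-Pl3A5E", so the escaping
# interleaves it with every character of the wrapped sequence:
#
#   >>> escape_custom_split_sequence("[START_SMILES]CO[END_SMILES]")
#   '[START_SMILES]SPL1T-TH1S-Pl3A5ECSPL1T-TH1S-Pl3A5EOSPL1T-TH1S-Pl3A5E[END_SMILES]'
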
class TrainCollater:
    def __init__(self, tokenizer, text_max_len, rxn_max_len, mol_ph, mol_token_id, is_gal=True, use_graph=True, use_qa_pair=True):
        self.rxn_max_len = rxn_max_len
        self.text_max_len = text_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.use_graph = use_graph
        self.use_qa_pair = use_qa_pair

    def __call__(self, batch):
        return self.collate_qa(batch) if self.use_qa_pair else self.collate(batch)

    def collate(self, batch):
        rxn_ids, graphs, texts, smiles_prompt = zip(*batch)
        if graphs:
            graphs = self.collater(graphs)
        ## deal with prompt
        if self.use_graph:
            smiles_prompt = [smiles_handler(p, self.mol_ph, self.is_gal)[0] for p in smiles_prompt]
        else:
            smiles_prompt = [escape_custom_split_sequence(p) for p in smiles_prompt]
        # left-pad prompts so generation continues directly from the prompt end
        self.tokenizer.padding_side = 'left'
        smiles_prompt_tokens = self.tokenizer(text=smiles_prompt,
                                              truncation=False,
                                              padding='longest',
                                              add_special_tokens=True,
                                              return_tensors='pt',
                                              return_attention_mask=True)
        is_mol_token = smiles_prompt_tokens.input_ids == self.mol_token_id
        smiles_prompt_tokens['is_mol_token'] = is_mol_token
        # right-pad the target text
        self.tokenizer.padding_side = 'right'
        text_tokens = self.tokenizer(text=texts,
                                     truncation=True,
                                     padding='longest',
                                     add_special_tokens=True,
                                     max_length=self.text_max_len,
                                     return_tensors='pt',
                                     return_attention_mask=True)
        return rxn_ids, graphs, smiles_prompt_tokens, text_tokens
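
    # Illustrative sketch (hypothetical values, not part of the original
    # module): with a prompt containing the placeholder "<mol><mol>" and
    # mol_token_id = 32001, the mask marks exactly the placeholder slots:
    #
    #   input_ids    = [[  1, 523, 32001, 32001, 742]]
    #   is_mol_token = [[  F,   F,     T,     T,   F]]
    #
    # Downstream code can scatter molecule-graph embeddings into those slots.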
    def collate_qa(self, batch):
        rxn_ids, graphs, texts, input_prompt = zip(*batch)
        # each sample may carry several graphs; flatten before collating
        graphs = [graph for graph_batch in graphs for graph in graph_batch]
        if graphs:
            graphs = self.collater(graphs)
        ## deal with prompt
        if self.use_graph:
            input_prompt = [smiles_handler(p, self.mol_ph, self.is_gal)[0] for p in input_prompt]
        else:
            input_prompt = [escape_custom_split_sequence(p) for p in input_prompt]
        self.tokenizer.padding_side = 'right'
        qa_pair = [[q, a] for q, a in zip(input_prompt, texts)]
        qa_batch = self.tokenizer(qa_pair,
                                  truncation=True,
                                  padding='longest',
                                  add_special_tokens=True,
                                  max_length=self.rxn_max_len + self.text_max_len,
                                  return_tensors='pt',
                                  return_attention_mask=True,
                                  return_token_type_ids=True)
        is_mol_token = qa_batch.input_ids == self.mol_token_id
        qa_batch['is_mol_token'] = is_mol_token
        return rxn_ids, graphs, qa_batch
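
# Illustrative usage sketch (hypothetical objects; a real run needs the
# project's tokenizer with a registered "<mol>" token and samples that carry
# PyG graphs):
#
#   collater = TrainCollater(tokenizer, text_max_len=128, rxn_max_len=128,
#                            mol_ph='<mol>' * 8, mol_token_id=mol_token_id,
#                            is_gal=True, use_graph=True, use_qa_pair=True)
#   rxn_ids, graphs, qa_batch = collater(batch)  # batch: list of dataset samples
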
class InferenceCollater:
    def __init__(self, tokenizer, text_max_len, rxn_max_len, mol_ph, mol_token_id, is_gal=True):
        self.text_max_len = text_max_len
        self.rxn_max_len = rxn_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal

    def __call__(self, batch):
        rxn_ids, graphs, texts, input_prompt = zip(*batch)
        inputs = input_prompt  # keep the raw prompts before placeholder insertion
        graphs = [graph for graph_batch in graphs for graph in graph_batch]
        if graphs:
            graphs = self.collater(graphs)
        ## deal with prompt
        input_prompt = [smiles_handler(p, self.mol_ph, self.is_gal)[0] for p in input_prompt]
        self.tokenizer.padding_side = 'left'
        input_prompt_tokens = self.tokenizer(input_prompt,
                                             truncation=True,
                                             padding='longest',
                                             add_special_tokens=True,
                                             max_length=self.rxn_max_len,
                                             return_tensors='pt',
                                             return_attention_mask=True)
        is_mol_token = input_prompt_tokens.input_ids == self.mol_token_id
        input_prompt_tokens['is_mol_token'] = is_mol_token
        return rxn_ids, graphs, input_prompt_tokens, texts, inputs
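
# Note (added for clarity, not part of the original module): inference prompts
# are left-padded because decoder-only models generate from the final position;
# right padding would leave pad tokens between the prompt and the generation.
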
class TuneDM(LightningDataModule):
    def __init__(
        self,
        num_workers: int = 0,
        batch_size: int = 256,
        root: str = 'data/',
        text_max_len: int = 128,
        smi_max_len: int = 128,
        rxn_max_len: int = 128,
        tokenizer=None,
        downstream_task='action',
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = num_workers
        self.rxn_max_len = rxn_max_len
        self.text_max_len = text_max_len
        self.prompt = args.prompt
        DownstreamDataset = {
            'action': ActionDataset,
            'synthesis': SynthesisDataset,
            'caption': CaptionDataset,
            'chebi': ChEBI_dataset,
        }[downstream_task]
        ds_args = {
            'use_graph': not args.disable_graphs,
            'disable_graph_cache': args.disable_graph_cache,
            'smiles_type': args.smiles_type,
        }
        if downstream_task == 'action':
            ds_args['predict_rxn_condition'] = args.predict_rxn_condition
        if downstream_task == 'synthesis':
            ds_args['roundrobin_train'] = args.roundrobin_train
            ds_args['test_subset'] = args.test_subset
        self.train_dataset = DownstreamDataset(root, 'train', smi_max_len, **ds_args)
        self.val_dataset = DownstreamDataset(root, 'valid', smi_max_len, **ds_args)
        self.test_dataset = DownstreamDataset(root, 'test', smi_max_len, **ds_args)
        self.init_tokenizer(tokenizer)
        self.mol_ph_token = '<mol>' * self.args.num_query_token
        self.is_gal = 'galactica' in args.opt_model
        self.use_graph = not args.disable_graphs
        self.is_t5 = 't5' in args.opt_model

    def init_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
        self.train_dataset.tokenizer = tokenizer
        self.val_dataset.tokenizer = tokenizer
        self.test_dataset.tokenizer = tokenizer
        self.mol_token_id = self.tokenizer.mol_token_id
        # self.tokenizer.mol_token_id = tokenizer("<mol>", add_special_tokens=False).input_ids[0]
    def train_dataloader(self):
        if self.args.roundrobin_train:
            self.train_dataset.reload_data()
        if hasattr(self.train_dataset, 'renew_r_smiles'):
            self.train_dataset.renew_r_smiles()
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            # persistent workers require num_workers > 0; guard to avoid a
            # ValueError when running with num_workers=0
            persistent_workers=self.num_workers > 0,
            collate_fn=TrainCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                rxn_max_len=self.rxn_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
                use_graph=self.use_graph,
                use_qa_pair=not self.is_t5,
            ),
        )
        return loader
    def val_dataloader(self):
        # NOTE: validation currently evaluates on the test split only; the
        # val_dataset loader below is unreachable (dead code after the early
        # return) and is kept for reference.
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=self.num_workers > 0,
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                rxn_max_len=self.rxn_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
            ),
        )
        return [test_loader]

        val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=self.num_workers > 0,
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                rxn_max_len=self.rxn_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
            ),
        )
        return [val_loader, test_loader]
    def test_dataloader(self):
        loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=self.num_workers > 0,
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                rxn_max_len=self.rxn_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
            ),
        )
        return loader
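
# Illustrative usage sketch (hypothetical; `args` must provide the attributes
# read above, e.g. inference_batch_size, prompt, disable_graphs, opt_model and
# num_query_token, and the tokenizer must expose mol_token_id):
#
#   dm = TuneDM(num_workers=4, batch_size=32, root='data/',
#               tokenizer=tokenizer, downstream_task='caption', args=args)
#   trainer = pytorch_lightning.Trainer(max_epochs=10)
#   trainer.fit(model, datamodule=dm)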