Spaces:
Runtime error
Runtime error
| import torch | |
| from torch_geometric.data import Dataset | |
| import os | |
| from torch_geometric.data import InMemoryDataset | |
| from .data_utils import reformat_smiles | |
| import random | |
| import json | |
class PubChemDataset(InMemoryDataset):
    """In-memory dataset restored from one serialized ``(data, slices)`` file."""

    def __init__(self, path):
        """Populate ``self.data`` and ``self.slices`` from the file at ``path``.

        Args:
            path: File handed to ``torch.load``. The loaded object is
                unpacked into exactly two values, in the order
                ``(data, slices)``.
        """
        super().__init__()
        # NOTE(review): torch.load unpickles its input, so ``path`` should
        # only ever point at a trusted file.
        storage, slice_table = torch.load(path)
        self.data = storage
        self.slices = slice_table

    def __getitem__(self, idx):
        """Return the sample at ``idx`` by forwarding straight to ``self.get``."""
        return self.get(idx)
class CaptionDataset(Dataset):
    """Molecule/caption samples backed by the ``<root>/<mode>.pt`` file.

    Each sample pairs a graph record with its text description and a SMILES
    prompt wrapped in ``[START_I_SMILES]`` / ``[END_I_SMILES]`` markers.
    """

    def __init__(
        self,
        root,
        mode,
        smi_max_len=128,
        use_graph=True,
        disable_graph_cache=False,
        smiles_type='default',
    ):
        """Open the split named ``mode`` located under ``root``.

        Args:
            root: Directory containing the ``.pt`` files; also passed to the
                base ``Dataset`` constructor.
            mode: Split name; the file read is ``<root>/<mode>.pt``.
            smi_max_len: Slice length applied to the reformatted SMILES
                before it is embedded in the prompt.
            use_graph: When true, each sample carries its graph record in a
                one-element list; otherwise that list is empty.
            disable_graph_cache: Accepted for call compatibility; this class
                does not read it.
            smiles_type: Forwarded to ``reformat_smiles`` as ``smiles_type``.
        """
        super().__init__(root)
        self.root = root
        self.file_path = os.path.join(root, f'{mode}.pt')
        self.data = PubChemDataset(self.file_path)
        self.use_graph = use_graph
        self.smi_max_len = smi_max_len
        self.smiles_type = smiles_type
        # Initialised to None here; nothing in this class assigns or uses it.
        self.tokenizer = None

    def get(self, index):
        """Return the same tuple as ``__getitem__`` for ``index``."""
        return self.__getitem__(index)

    def len(self):
        """Return the number of samples (same value as ``len(self)``)."""
        return len(self)

    def __len__(self):
        """Return the size of the underlying ``PubChemDataset``."""
        return len(self.data)

    def __getitem__(self, index):
        """Build one sample.

        Args:
            index: Position of the record in the underlying dataset.

        Returns:
            A 4-tuple ``(index, graph_list, text, smiles_prompt)`` where
            ``graph_list`` is ``[record]`` or ``[]`` depending on
            ``self.use_graph``, ``text`` is the description collapsed onto a
            single newline-terminated line, and ``smiles_prompt`` is the
            truncated SMILES wrapped in its start/end markers.
        """
        record = self.data[index]

        reformatted = reformat_smiles(record.smiles, smiles_type=self.smiles_type)
        truncated = reformatted[:self.smi_max_len]
        smiles_prompt = f'[START_I_SMILES]{truncated}[END_I_SMILES]. '

        # Keep at most the first 101 lines of the description, strip each
        # one, and join them with single spaces.
        max_lines = 101
        kept_lines = record.text.split('\n')[:max_lines]
        text = ' '.join(piece.strip() for piece in kept_lines) + '\n'

        if self.use_graph:
            graph_list = [record]
        else:
            graph_list = []
        return index, graph_list, text, smiles_prompt
class PretrainCaptionDataset(Dataset):
    """Concatenation of the ``pretrain`` and ``train`` caption splits.

    Indices ``0 .. len(pretrain) - 1`` address the ``pretrain`` split; the
    remaining indices address the ``train`` split.
    """

    def __init__(self, root, smi_max_len=128, use_graph=True, disable_graph_cache=False):
        """Open both splits stored under ``root``.

        Args:
            root: Directory holding ``pretrain.pt`` and ``train.pt``; also
                passed to the base ``Dataset`` constructor.
            smi_max_len: Forwarded to both ``CaptionDataset`` instances.
            use_graph: Forwarded to both ``CaptionDataset`` instances.
            disable_graph_cache: Accepted for call compatibility; this class
                neither reads nor forwards it.
        """
        super().__init__(root)
        shared_options = dict(smi_max_len=smi_max_len, use_graph=use_graph)
        self.pre_train_data = CaptionDataset(root, 'pretrain', **shared_options)
        self.train_data = CaptionDataset(root, 'train', **shared_options)

    def get(self, index):
        """Return the same tuple as ``__getitem__`` for ``index``."""
        return self.__getitem__(index)

    def len(self):
        """Return the number of samples (same value as ``len(self)``)."""
        return len(self)

    def __len__(self):
        """Return the combined size of the two splits."""
        return len(self.pre_train_data) + len(self.train_data)

    def __getitem__(self, index):
        """Fetch one sample from whichever split ``index`` falls into.

        Args:
            index: Position in the concatenated dataset.

        Returns:
            A 3-tuple ``(graph_item, text, smiles_prompt)``. ``graph_item``
            is the graph record with its ``text`` and ``smiles`` attributes
            removed, along with ``iupac`` and ``cid`` when present.

        NOTE(review): when the splits were built with ``use_graph=False``
        the inner dataset yields an empty graph list, so the ``[0]`` lookup
        below raises ``IndexError``.
        """
        pretrain_size = len(self.pre_train_data)
        if index < pretrain_size:
            split, local_index = self.pre_train_data, index
        else:
            split, local_index = self.train_data, index - pretrain_size
        _, graph_list, text, smiles_prompt = split[local_index]

        graph_item = graph_list[0]
        # Strip the non-graph fields from the record; the first two are
        # optional, the last two are removed unconditionally.
        for optional_field in ('iupac', 'cid'):
            if hasattr(graph_item, optional_field):
                delattr(graph_item, optional_field)
        for required_field in ('text', 'smiles'):
            delattr(graph_item, required_field)
        return graph_item, text, smiles_prompt