Spaces:
Runtime error
Runtime error
| import random | |
| import os | |
| import numpy as np | |
| import argparse | |
| import json | |
| from collections import defaultdict | |
| from matplotlib import pyplot as plt | |
| from collections import Counter | |
| from .data_utils import json_read | |
def set_random_seed(seed):
    """Seed the random number generators this module relies on.

    Seeds Python's ``random`` module and NumPy's legacy global generator,
    and stores ``seed`` in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Seed value handed to ``random.seed`` and ``np.random.seed``.
    """
    # NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so this
    # presumably only matters for processes launched afterwards -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
class Reaction_Cluster:
    """Sample groups of captioned molecules that co-occur in reactions.

    Each reaction record gets four derived keys:

    * ``valid_mols``   -- molecules of the reaction that are present in the
      property file and carry an ``'abstract'`` entry, in
      REACTANT, CATALYST, SOLVENT, PRODUCT order;
    * ``mol_role_map`` -- molecule -> first role it was seen in;
    * ``weight``       -- P(r): normalised sum of ``1 / count(mol)`` over the
      reaction's valid molecules;
    * ``mol_weight``   -- P(i|r): ``1 / count(mol)`` normalised per reaction.

    Property dicts returned by ``sample_mol_batch`` and
    ``generate_batch_uniform_rxn`` are shallow copies, so adding ``role`` /
    ``r_smiles`` never leaks into ``mol_property_map`` or into batches that
    were returned earlier.
    """

    # Order matters: a molecule keeps the first role it is found under.
    _ROLE_KEYS = ('REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT')

    def __init__(self, root, reaction_filename, reverse_ratio=0.5):
        """Load the reaction and property files and pre-compute the weights.

        Args:
            root: Directory that holds ``reaction_filename`` and
                ``Abstract_property.json``.
            reaction_filename: Name of the reaction JSON file inside ``root``.
            reverse_ratio: Probability with which ``choose_mol`` returns the
                selected molecules in reversed order.

        Raises:
            ZeroDivisionError: If there are reactions but none of them has a
                valid molecule (the reaction weights cannot be normalised).
        """
        self.root = root
        self.reaction_data = json_read(os.path.join(self.root, reaction_filename))
        self.property_data = json_read(os.path.join(self.root, 'Abstract_property.json'))
        self.mol_property_map = {d['canon_smiles']: d for d in self.property_data}
        self.reverse_ratio = reverse_ratio
        self.rxn_mols_attr = defaultdict(lambda: {
            'freq': 0,
            'occurrence': 0,
            'in_caption': False,
        })
        self._read_reaction_mols()  # adds `valid_mols` to each rxn_dict
        self.mol_counter = Counter(
            mol for rxn_dict in self.reaction_data for mol in rxn_dict['valid_mols']
        )
        self._calculate_Pr()   # P(r): adds `weight` to each rxn_dict
        self._calculate_Pir()  # P(i|r): adds `mol_weight` to each rxn_dict

    def _read_reaction_mols(self):
        """Attach ``valid_mols`` / ``mol_role_map`` to every reaction.

        Also fills ``self.valid_rxn_indices`` with the indices of reactions
        that have at least one valid molecule.
        """
        self.valid_rxn_indices = []
        for rxn_id, rxn_dict in enumerate(self.reaction_data):
            mol_role_map = {}
            for key in self._ROLE_KEYS:
                for m in rxn_dict[key]:
                    # Only molecules known to the property file are kept, and
                    # the first role encountered wins.
                    if m not in mol_role_map and m in self.mol_property_map:
                        mol_role_map[m] = key
            # Dict order is insertion order, so the molecules stay in
            # R, C, S, P order. Membership in `mol_property_map` is
            # guaranteed by the filter above.
            valid_mols = [
                mol for mol in mol_role_map
                if 'abstract' in self.mol_property_map[mol]
            ]
            if valid_mols:
                self.valid_rxn_indices.append(rxn_id)
            rxn_dict['valid_mols'] = valid_mols
            rxn_dict['mol_role_map'] = mol_role_map

    def _calculate_Pr(self):
        """Compute the normalised reaction weight P(r) for every reaction."""
        total_weights = 0
        for rxn_dict in self.reaction_data:
            rxn_weight = sum(1 / self.mol_counter[mol] for mol in rxn_dict['valid_mols'])
            rxn_dict['weight'] = rxn_weight
            total_weights += rxn_weight
        for rxn_dict in self.reaction_data:
            rxn_dict['weight'] = rxn_dict['weight'] / total_weights

    def _calculate_Pir(self):
        """Compute the per-reaction molecule weights P(i|r).

        A reaction without valid molecules gets an empty ``mol_weight`` dict.
        """
        for rxn_dict in self.reaction_data:
            mol_weight = {mol: 1 / self.mol_counter[mol] for mol in rxn_dict['valid_mols']}
            total_weight = sum(mol_weight.values())
            rxn_dict['mol_weight'] = {m: w / total_weight for m, w in mol_weight.items()}

    def _build_property_batch(self, rxn, sampled_mols):
        """Return role-annotated copies of the property dicts of ``sampled_mols``.

        The dicts in ``mol_property_map`` are shared by every reaction a
        molecule occurs in, so they are copied (shallowly) before the
        reaction-specific ``role`` is written.
        """
        batch = []
        for mol in sampled_mols:
            mol_property = dict(self.mol_property_map[mol])
            mol_property['role'] = rxn['mol_role_map'][mol]
            batch.append(mol_property)
        return batch

    def choose_mol(self, valid_mols, k=4, weights=None):
        """Pick up to ``k`` molecules from ``valid_mols``, keeping their order.

        Args:
            valid_mols: Indexable sequence of molecules.
            k: Maximum number of molecules to return. If ``k`` is at least
                ``len(valid_mols)``, all molecules are returned.
            weights: Optional selection probabilities aligned with
                ``valid_mols``; ``None`` means uniform.

        Returns:
            The chosen molecules in their original relative order, or in
            reversed order with probability ``self.reverse_ratio``.
        """
        if k >= len(valid_mols):
            sampled_indices = list(range(len(valid_mols)))
        else:
            sampled_indices = list(
                np.random.choice(len(valid_mols), k, replace=False, p=weights)
            )
        sampled_indices = sorted(sampled_indices)
        if random.random() < self.reverse_ratio:  # reverse with reverse_ratio chance
            sampled_indices.reverse()
        return [valid_mols[i] for i in sampled_indices]

    def sample_mol_batch(self, index=None, k=4):
        """Sample up to ``k`` molecules of one reaction.

        Args:
            index: Reaction index; when ``None`` one is drawn with
                ``sample_rxn_index``.
            k: Maximum number of molecules in the batch.

        Returns:
            A list of property dicts (copies) with an added ``role`` key and,
            when the reaction has an ``rsmiles_map`` that covers the
            molecule, an ``r_smiles`` key. The list is empty when the
            reaction has no valid molecules.
        """
        if index is None:
            index = self.sample_rxn_index(1)[0]
        assert index < len(self.reaction_data)
        rxn = self.reaction_data[index]
        if not rxn['mol_weight']:
            # Nothing to sample from; unpacking zip() of nothing would raise.
            return []
        valid_mols, weights = zip(*rxn['mol_weight'].items())
        sampled_mols = self.choose_mol(valid_mols, k=k, weights=weights)
        mol_property_batch = self._build_property_batch(rxn, sampled_mols)
        if 'rsmiles_map' in rxn:
            # One mapping is picked per batch so all molecules are consistent.
            rsmiles_map = random.choice(rxn['rsmiles_map'])
            for mol_property in mol_property_batch:
                canon_smiles = mol_property['canon_smiles']
                if canon_smiles in rsmiles_map:
                    mol_property['r_smiles'] = rsmiles_map[canon_smiles]
        return mol_property_batch

    def sample_rxn_index(self, num_samples):
        """Draw ``num_samples`` distinct reaction indices according to P(r)."""
        indices = range(len(self.reaction_data))
        weights = [d['weight'] for d in self.reaction_data]
        return np.random.choice(indices, num_samples, replace=False, p=weights)

    def __call__(self, rxn_num=1000, k=4):
        """Return ``rxn_num`` molecule batches from weighted reaction sampling."""
        sampled_indices = self.sample_rxn_index(rxn_num)
        return [self.sample_mol_batch(idx, k=k) for idx in sampled_indices]

    def generate_batch_uniform_rxn(self, rxn_num=1000, k=4):
        """Return batches from reactions drawn uniformly among valid reactions.

        Molecules inside each reaction are chosen uniformly as well. The
        returned property dicts are copies carrying a ``role`` key.
        """
        assert rxn_num <= len(self.valid_rxn_indices)
        sampled_rxn_indices = random.sample(self.valid_rxn_indices, rxn_num)
        sampled_batch = []
        for rxn_id in sampled_rxn_indices:
            rxn = self.reaction_data[rxn_id]
            sampled_mols = self.choose_mol(rxn['valid_mols'], k=k, weights=None)
            sampled_batch.append(self._build_property_batch(rxn, sampled_mols))
        return sampled_batch

    def generate_batch_uniform_mol(self, rxn_num=1000, k=4):
        """Return ``rxn_num`` batches of ``k`` molecule occurrences each.

        Occurrences are drawn without replacement from the multiset of valid
        molecules (one entry per reaction a molecule appears in). The
        returned dicts are the entries of ``mol_property_map`` themselves,
        not copies.
        """
        valid_mols = list(self.mol_counter.elements())
        assert rxn_num * k <= len(valid_mols)
        sampled_mol_ids = random.sample(range(len(valid_mols)), rxn_num * k)
        return [
            [self.mol_property_map[valid_mols[mol_id]]
             for mol_id in sampled_mol_ids[i * k:(i + 1) * k]]
            for i in range(rxn_num)
        ]

    def generate_batch_single(self, rxn_num=1000):
        """Return ``rxn_num`` single-molecule batches.

        Molecules are drawn without replacement from the multiset of valid
        molecule occurrences. The returned dicts are the entries of
        ``mol_property_map`` themselves, not copies.
        """
        valid_mols = list(self.mol_counter.elements())
        sampled_mols = random.sample(valid_mols, rxn_num)
        return [[self.mol_property_map[mol]] for mol in sampled_mols]

    # Visualize the sampling probability of molecules in the caption dataset.
    def visualize_mol_distribution(self):
        """Return scaled sampling probabilities for molecules and reactions.

        Returns:
            ``(prob_values, rxn_weights)``: the per-molecule probability
            ``sum_r P(i|r) * P(r)`` multiplied by the number of molecules,
            and the per-reaction weight multiplied by the number of
            reactions.
        """
        prob_dict = {mol: 0.0 for mol in self.mol_property_map.keys()}
        N = len(prob_dict)
        M = len(self.reaction_data)
        assert N == len(self.mol_property_map)
        print(f'Number of molecules in Caption Dataset: {N}')
        print(f'Number of Reactions in Reaction Dataset: {M}')
        # Probability distribution over molecules.
        for rxn_dict in self.reaction_data:
            for mol, weight in rxn_dict['mol_weight'].items():
                prob_dict[mol] += weight * rxn_dict['weight']
        # sum of prob_dict.values() should already be 1.
        prob_values = np.array(list(prob_dict.values()))
        prob_values *= N
        # Probability distribution over reactions.
        rxn_weights = np.array([d['weight'] for d in self.reaction_data])
        # sum of rxn_weights should already be 1.
        rxn_weights *= M
        return prob_values, rxn_weights

    # Visualize the empirical frequency of molecules in the caption dataset.
    def visualize_mol_frequency(self, rxn_num=1000, k=4, epochs=100):
        """Simulate sampling and count how often molecules/reactions are drawn.

        Args:
            rxn_num: Reactions drawn per epoch.
            k: Maximum molecules chosen per reaction.
            epochs: Number of sampling rounds.

        Returns:
            ``(sampled_mols_count, sampled_rxns_count)``: arrays of counts,
            each ordered by sorted molecule / reaction index. Items that were
            never drawn do not appear.
        """
        sampled_mols_counter = Counter()
        sampled_rxns_counter = Counter()
        for _ in range(epochs):
            rxn_indices = self.sample_rxn_index(rxn_num)
            sampled_rxns_counter.update(rxn_indices)
            for index in rxn_indices:
                rxn = self.reaction_data[index]
                if not rxn['valid_mols']:
                    continue
                valid_mols, weights = zip(*rxn['mol_weight'].items())
                mol_batch = self.choose_mol(valid_mols, k=k, weights=weights)
                sampled_mols_counter.update(mol_batch)
        sampled_mols_count = np.array([c for _, c in sorted(sampled_mols_counter.items())])
        sampled_rxns_count = np.array([c for _, c in sorted(sampled_rxns_counter.items())])
        return sampled_mols_count, sampled_rxns_count

    def _randomly(self, func, *args, **kwargs):
        """Run ``func`` while all reaction and molecule weights are uniform.

        The real weights are restored afterwards, including when ``func``
        raises.

        Returns:
            Whatever ``func(*args, **kwargs)`` returns.
        """
        # Back the weights up locally so a failure (or a nested call) cannot
        # leave the reaction records in the temporary uniform state.
        backups = [
            (rxn_dict['weight'], rxn_dict['mol_weight'])
            for rxn_dict in self.reaction_data
        ]
        try:
            for rxn_dict in self.reaction_data:
                rxn_dict['weight'] = 1 / len(self.reaction_data)
                mol_num = len(rxn_dict['mol_weight'])
                rxn_dict['mol_weight'] = {m: 1 / mol_num for m in rxn_dict['mol_weight']}
            return func(*args, **kwargs)
        finally:
            for rxn_dict, (weight, mol_weight) in zip(self.reaction_data, backups):
                rxn_dict['weight'] = weight
                rxn_dict['mol_weight'] = mol_weight