Spaces:
Runtime error
Runtime error
| from data_provider.context_gen import * | |
def parse_args():
    """Build this script's command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the fields ``name``, ``seed``, ``epochs``,
        ``chunk_size``, ``rxn_num``, ``k`` and ``root``.
    """
    parser = argparse.ArgumentParser(description="A simple argument parser")
    # Script arguments as (flag, default, type), registered in this order.
    option_specs = (
        ('--name', 'none', str),
        ('--seed', 0, int),
        ('--epochs', 100, int),
        ('--chunk_size', 100, int),
        ('--rxn_num', 50000, int),
        ('--k', 4, int),
        ('--root', 'data/pretrain_data', str),
    )
    for flag, default_value, converter in option_specs:
        parser.add_argument(flag, default=default_value, type=converter)
    return parser.parse_args()
def pad_shorter_array(arr1, arr2):
    """Zero-pad the shorter of two arrays at the end so their first dimensions match.

    Only ``shape[0]`` is compared. The pad spec ``(0, n)`` appends ``n`` zeros
    to a 1-D array; callers here appear to pass 1-D arrays — TODO confirm.

    Returns:
        ``(arr1, arr2)``. The longer (or equal-length) input is returned as-is;
        the shorter one is replaced by a new zero-padded array.
    """
    length_gap = arr1.shape[0] - arr2.shape[0]
    if length_gap > 0:
        return arr1, np.pad(arr2, (0, length_gap), 'constant')
    if length_gap < 0:
        return np.pad(arr1, (0, -length_gap), 'constant'), arr2
    return arr1, arr2
def plot_distribution(values, target_path, x_lim=None, y_lim=None, chunk_size=100, color='blue',
                      xtick_values=(0, 200000, 400000, 600000, 800000, 1000000)):
    """Draw chunk-averaged values as a descending bar chart and save it to disk.

    ``values`` is cut into consecutive chunks of ``chunk_size`` items (a trailing
    partial chunk is dropped), each chunk is replaced by its mean, and the means
    are drawn as bars sorted from largest to smallest.

    Args:
        values: 1-D sequence of numbers; converted with ``np.asarray``.
        target_path: output file path passed to ``plt.savefig``.
        x_lim: optional ``(low, high)`` x-axis limits, in chunk units.
        y_lim: optional ``(low, high)`` y-axis limits.
        chunk_size: number of raw items averaged into one bar.
        color: bar colour.
        xtick_values: tick labels expressed in raw-item units; each label is
            placed at position ``int(value / chunk_size)`` on the chunk axis.
    """
    values = np.asarray(values)
    num_full_chunks = len(values) // chunk_size
    chunk_means = values[:num_full_chunks * chunk_size].reshape(-1, chunk_size).mean(axis=1)
    chunk_means = np.sort(chunk_means)[::-1]

    fig = plt.figure(figsize=(10, 4), dpi=100)
    try:
        x = np.arange(len(chunk_means))
        plt.bar(x, chunk_means, color=color)
        tick_labels = np.asarray(xtick_values, dtype=int)
        plt.xticks((tick_labels / chunk_size).astype(int), tick_labels)
        plt.ylabel('Molecule Frequency', fontsize=20)
        if x_lim is not None:
            plt.xlim(*x_lim)
        if y_lim is not None:
            plt.ylim(*y_lim)
        plt.tick_params(axis='both', which='major', labelsize=12)
        plt.tight_layout(pad=0.5)
        plt.savefig(target_path)
    finally:
        # Close rather than just clear: plt.figure() opens a new figure on every
        # call, so clearing alone would leave open figures accumulating.
        plt.close(fig)
    print(f'Figure saved to {target_path}')
def plot_compare_distribution(list1, list2, target_path, x_lim=None, y_lim=None,
                              labels=('Random', 'Ours'), colors=('blue', 'orange'), chunk_size=100,
                              xtick_values=(0, 200000, 400000, 600000, 800000, 1000000)):
    """Overlay two chunk-averaged, descending bar charts and save the figure.

    Both inputs are zero-padded to a common length, cut into chunks of
    ``chunk_size`` items (a trailing partial chunk is dropped), averaged per
    chunk, and each series is sorted independently in descending order before
    being drawn semi-transparently on the same axes.

    Args:
        list1, list2: 1-D numeric sequences; converted with ``np.asarray``.
        target_path: output file path passed to ``plt.savefig``.
        x_lim: optional ``(low, high)`` x-axis limits, in chunk units.
        y_lim: optional ``(low, high)`` y-axis limits.
        labels: legend labels for ``list1`` and ``list2`` (indexable, length >= 2).
        colors: bar colours for ``list1`` and ``list2`` (indexable, length >= 2).
        chunk_size: number of raw items averaged into one bar.
        xtick_values: tick labels in raw-item units, placed at
            ``int(value / chunk_size)`` on the chunk axis.
    """
    list1, list2 = pad_shorter_array(np.asarray(list1), np.asarray(list2))
    # Count chunks AFTER padding so the longer series is not truncated to the
    # length of list1.
    num_full_chunks = len(list1) // chunk_size
    values1, values2 = [
        np.sort(series[:num_full_chunks * chunk_size].reshape(-1, chunk_size).mean(axis=1))[::-1]
        for series in (list1, list2)
    ]

    fig = plt.figure(figsize=(10, 6), dpi=100)
    try:
        x = np.arange(len(values1))
        plt.bar(x, values1, color=colors[0], label=labels[0], alpha=0.6)
        plt.bar(x, values2, color=colors[1], label=labels[1], alpha=0.5)
        tick_labels = np.asarray(xtick_values, dtype=int)
        plt.xticks((tick_labels / chunk_size).astype(int), tick_labels)
        plt.ylabel('Molecule Frequency', fontsize=20)
        if x_lim is not None:
            plt.xlim(*x_lim)
        if y_lim is not None:
            plt.ylim(*y_lim)
        plt.tick_params(axis='both', which='major', labelsize=18)
        plt.tight_layout(pad=0.5)
        plt.legend(fontsize=24, loc='upper right')
        plt.savefig(target_path)
    finally:
        # Close rather than just clear: a new figure is opened on every call.
        plt.close(fig)
    print(f'Figure saved to {target_path}')
def statistics(args):
    """Print summary counts for the reaction and molecule-property data under ``args.root``.

    Reports:
    - the number of reactions;
    - the number of distinct molecules listed under ``REACTANT``, ``CATALYST``,
      ``SOLVENT`` or ``PRODUCT`` in any reaction;
    - how many entries of ``cluster.property_data`` have an ``abstract``, a
      ``property`` dict, and ``Experimental Properties`` / ``Computed
      Properties`` sections inside it.

    Args:
        args: parsed CLI namespace; reads ``args.seed`` and ``args.root``.
    """
    # A seed of 0 (the CLI default) is falsy, so no seeding happens in that case.
    if args.seed:
        set_random_seed(args.seed)
    # 1141864 rxns from ord
    # 1120773 rxns from uspto
    cluster = Reaction_Cluster(args.root)
    rxn_num = len(cluster.reaction_data)

    mol_set = set()
    for rxn_dict in cluster.reaction_data:
        for key in ('REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT'):
            mol_set.update(rxn_dict[key])
    mol_num = len(mol_set)

    abstract_num = 0
    property_num = 0
    calculated_property_num = 0
    experimental_property_num = 0
    # Running totals of section lengths; divided by mol_num only when printed.
    total_calculated_property_len = 0
    total_experimental_property_len = 0
    for mol_dict in cluster.property_data:
        if 'abstract' in mol_dict:
            abstract_num += 1
        if 'property' not in mol_dict:
            continue
        property_num += 1
        properties = mol_dict['property']
        if 'Experimental Properties' in properties:
            experimental_property_num += 1
            total_experimental_property_len += len(properties['Experimental Properties'])
        if 'Computed Properties' in properties:
            calculated_property_num += 1
            total_calculated_property_len += len(properties['Computed Properties'])

    def pct(part, whole):
        # Percentage; 0.0 instead of ZeroDivisionError when the denominator is 0.
        return part / whole * 100 if whole else 0.0

    def per(total, whole):
        # Ratio; 0.0 instead of ZeroDivisionError when the denominator is 0.
        return total / whole if whole else 0.0

    print(f'Reaction Number: {rxn_num}')
    print(f'Molecule Number: {mol_num}')
    print(f'Abstract Number: {abstract_num}/{mol_num}({pct(abstract_num, mol_num):.2f}%)')
    print(f'Property Number: {property_num}/{mol_num}({pct(property_num, mol_num):.2f}%)')
    # "items per molecule" is normalized by mol_num (all reaction molecules),
    # not by the number of molecules that actually have the section.
    print(f'- Experimental Properties Number: {experimental_property_num}/{property_num}({pct(experimental_property_num, property_num):.2f}%), {per(total_experimental_property_len, mol_num):.2f} items per molecule')
    print(f'- Computed Properties: {calculated_property_num}/{property_num}({pct(calculated_property_num, property_num):.2f}%), {per(total_calculated_property_len, mol_num):.2f} items per molecule')
def visualize(args):
    """Plot molecule and reaction distributions, directly and under ``cluster._randomly``.

    Writes six PDFs into ``results/<args.name>/``:
    - the molecule and reaction distributions;
    - the same two distributions computed through ``cluster._randomly``;
    - two overlays comparing each pair.

    Args:
        args: parsed CLI namespace; reads ``args.seed``, ``args.root`` and
            ``args.name``.
    """
    import os

    # A seed of 0 (the CLI default) is falsy, so no seeding happens in that case.
    if args.seed:
        set_random_seed(args.seed)
    cluster = Reaction_Cluster(args.root)
    prob_values, rxn_weights = cluster.visualize_mol_distribution()
    rand_prob_values, rand_rxn_weights = cluster._randomly(
        cluster.visualize_mol_distribution
    )
    fig_root = f'results/{args.name}/'
    # savefig does not create missing directories; make sure the target exists.
    os.makedirs(fig_root, exist_ok=True)
    plot_distribution(prob_values, fig_root+'mol_distribution.pdf')
    plot_distribution(rxn_weights, fig_root+'rxns_distribution.pdf')
    plot_distribution(rand_prob_values, fig_root+'mol_distribution_random.pdf')
    plot_distribution(rand_rxn_weights, fig_root+'rxns_distribution_random.pdf')
    plot_compare_distribution(prob_values, rand_prob_values, fig_root+'Compare_mol.pdf', y_lim=(-0.5,15.5))
    plot_compare_distribution(rxn_weights, rand_rxn_weights, fig_root+'Compare_rxns.pdf')
def visualize_frequency(args):
    """Plot molecule-frequency distributions, directly and under ``cluster._randomly``.

    The four frequency results from ``Reaction_Cluster.visualize_mol_frequency``
    are cached in ``results/<name>/freq_E<epochs>_Rxn<rxn_num>_K<k>.npy``. A
    rerun with the same settings loads the cache instead of recomputing.

    Three PDFs are written to ``results/<name>/``:
    - the molecule frequencies;
    - the ``_randomly`` molecule frequencies;
    - an overlay labelled 'Before Adjustment' (``_randomly``) versus
      'After Adjustment'.

    Args:
        args: parsed CLI namespace; reads ``seed``, ``name``, ``epochs``,
            ``rxn_num``, ``k``, ``root`` and ``chunk_size``.
    """
    # A seed of 0 (the CLI default) is falsy, so no seeding happens in that case.
    if args.seed:
        set_random_seed(args.seed)
    fig_root = f'results/{args.name}/'
    # np.save and savefig both fail if the output directory does not exist yet.
    os.makedirs(fig_root, exist_ok=True)
    name_suffix = f'E{args.epochs}_Rxn{args.rxn_num}_K{args.k}'
    cache_path = f'{fig_root}freq_{name_suffix}.npy'
    if os.path.exists(cache_path):
        # allow_pickle is required for the object array. This is only safe
        # because the cache is presumably one written by this function.
        # Never point it at untrusted files.
        mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq = np.load(cache_path, allow_pickle=True)
    else:
        cluster = Reaction_Cluster(args.root)
        mol_freq, rxn_freq = cluster.visualize_mol_frequency(rxn_num=args.rxn_num, k=args.k, epochs=args.epochs)
        rand_mol_freq, rand_rxn_freq = cluster._randomly(
            cluster.visualize_mol_frequency,
            rxn_num=args.rxn_num, k=args.k, epochs=args.epochs
        )
        # Fill a 1-D object array slot by slot. np.array([...], dtype=object)
        # would instead build a 2-D object array when all four results have
        # the same length.
        cache = np.empty(4, dtype=object)
        for slot, result in enumerate((mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq)):
            cache[slot] = result
        np.save(cache_path, cache, allow_pickle=True)
    color1 = '#FA7F6F'
    color2 = '#80AFBF'
    mol_x_lim = (-50000//args.chunk_size, 1200000//args.chunk_size)
    plot_distribution(mol_freq, fig_root+f'mol_frequency_{name_suffix}.pdf', x_lim=mol_x_lim, y_lim=(-2, 62), chunk_size=args.chunk_size, color=color2)
    # plot_distribution(rxn_freq, fig_root+f'rxns_frequency_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)
    plot_distribution(rand_mol_freq, fig_root+f'mol_frequency_random_{name_suffix}.pdf', x_lim=mol_x_lim, y_lim=(-2, 62), chunk_size=args.chunk_size, color=color2)
    # plot_distribution(rand_rxn_freq, fig_root+f'rxns_frequency_random_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)
    plot_compare_distribution(rand_mol_freq, mol_freq, fig_root+f'Compare_mol_{name_suffix}.pdf', y_lim=(-2, 62), labels=['Before Adjustment', 'After Adjustment'], colors=[color1, color2], chunk_size=args.chunk_size)
    # plot_compare_distribution(rxn_freq, rand_rxn_freq, fig_root+f'Compare_rxns_{name_suffix}.pdf', chunk_size=args.chunk_size)
if __name__ == '__main__':
    # Parse the CLI, echo the resulting namespace, then run the frequency plots.
    # Alternative entry points (disabled): statistics(cli_args), visualize(cli_args).
    cli_args = parse_args()
    print(cli_args, flush=True)
    visualize_frequency(cli_args)