ipd commited on
Commit
6747ba1
·
1 Parent(s): 90ca16d
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +4 -0
  2. app.py +717 -0
  3. data/.DS_Store +0 -0
  4. data/bace/test.csv +0 -0
  5. data/bace/train.csv +0 -0
  6. data/bace/valid.csv +0 -0
  7. data/esol/test.csv +109 -0
  8. data/esol/train.csv +0 -0
  9. img/.DS_Store +0 -0
  10. img/img1.png +0 -0
  11. img/img2.png +0 -0
  12. img/img3.png +0 -0
  13. img/img4.png +0 -0
  14. img/img5.png +0 -0
  15. img/introduction.png +0 -0
  16. img/latent_multi_bace.png +0 -0
  17. log.csv +1 -0
  18. models/.DS_Store +0 -0
  19. models/__pycache__/fm4m.cpython-310.pyc +0 -0
  20. models/fm4m.py +663 -0
  21. models/mhg_model/.DS_Store +0 -0
  22. models/mhg_model/README.md +75 -0
  23. models/mhg_model/__init__.py +5 -0
  24. models/mhg_model/__pycache__/__init__.cpython-310.pyc +0 -0
  25. models/mhg_model/__pycache__/load.cpython-310.pyc +0 -0
  26. models/mhg_model/graph_grammar/__init__.py +19 -0
  27. models/mhg_model/graph_grammar/__pycache__/__init__.cpython-310.pyc +0 -0
  28. models/mhg_model/graph_grammar/__pycache__/hypergraph.cpython-310.pyc +0 -0
  29. models/mhg_model/graph_grammar/algo/__init__.py +20 -0
  30. models/mhg_model/graph_grammar/algo/__pycache__/__init__.cpython-310.pyc +0 -0
  31. models/mhg_model/graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc +0 -0
  32. models/mhg_model/graph_grammar/algo/tree_decomposition.py +821 -0
  33. models/mhg_model/graph_grammar/graph_grammar/__init__.py +20 -0
  34. models/mhg_model/graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc +0 -0
  35. models/mhg_model/graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc +0 -0
  36. models/mhg_model/graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc +0 -0
  37. models/mhg_model/graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc +0 -0
  38. models/mhg_model/graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc +0 -0
  39. models/mhg_model/graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc +0 -0
  40. models/mhg_model/graph_grammar/graph_grammar/base.py +30 -0
  41. models/mhg_model/graph_grammar/graph_grammar/corpus.py +152 -0
  42. models/mhg_model/graph_grammar/graph_grammar/hrg.py +1065 -0
  43. models/mhg_model/graph_grammar/graph_grammar/symbols.py +180 -0
  44. models/mhg_model/graph_grammar/graph_grammar/utils.py +130 -0
  45. models/mhg_model/graph_grammar/hypergraph.py +544 -0
  46. models/mhg_model/graph_grammar/io/__init__.py +20 -0
  47. models/mhg_model/graph_grammar/io/__pycache__/__init__.cpython-310.pyc +0 -0
  48. models/mhg_model/graph_grammar/io/__pycache__/smi.cpython-310.pyc +0 -0
  49. models/mhg_model/graph_grammar/io/smi.py +559 -0
  50. models/mhg_model/graph_grammar/nn/__init__.py +11 -0
README.md CHANGED
@@ -8,6 +8,10 @@ sdk_version: 5.4.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ models:
12
+ - ibm/materials.smi-ted
13
+ - ibm/materials.selfies-ted
14
+ - ibm/materials.mhg-ged
15
  ---
16
 
17
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,717 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ import matplotlib.pyplot as plt
4
+ from PIL import Image
5
+ from rdkit.Chem import Descriptors, QED, Draw
6
+ from rdkit.Chem.Crippen import MolLogP
7
+ import pandas as pd
8
+ from rdkit.Contrib.SA_Score import sascorer
9
+ from rdkit.Chem import DataStructs, AllChem
10
+ from transformers import BartForConditionalGeneration, AutoTokenizer, AutoModel
11
+ from transformers.modeling_outputs import BaseModelOutput
12
+ import selfies as sf
13
+ from rdkit import Chem
14
+ import torch
15
+ import numpy as np
16
+ import umap
17
+ import pickle
18
+ import xgboost as xgb
19
+ from sklearn.svm import SVR
20
+ from sklearn.linear_model import LinearRegression
21
+ from sklearn.kernel_ridge import KernelRidge
22
+ import json
23
+
24
+ import os
25
+
26
+ os.environ["OMP_MAX_ACTIVE_LEVELS"] = "1"
27
+
28
+ # my_theme = gr.Theme.from_hub("ysharma/steampunk")
29
+ # my_theme = gr.themes.Glass()
30
+
31
+ """
32
+ # カスタムテーマ設定
33
+ theme = gr.themes.Default().set(
34
+ body_background_fill="#000000", # 背景色を黒に設定
35
+ text_color="#FFFFFF", # テキスト色を白に設定
36
+ )
37
+ """
38
+ """
39
+ import sys
40
+ sys.path.append("models")
41
+ sys.path.append("../models")
42
+ sys.path.append("../")"""
43
+
44
+
45
+ # Get the current file's directory
46
+ base_dir = os.path.dirname(__file__)
47
+ print("Base Dir : ", base_dir)
48
+
49
+ import models.fm4m as fm4m
50
+
51
+
52
+ # Function to display molecule image from SMILES
53
+ def smiles_to_image(smiles):
54
+ mol = Chem.MolFromSmiles(smiles)
55
+ if mol:
56
+ img = Draw.MolToImage(mol)
57
+ return img
58
+ return None
59
+
60
+
61
+ # Function to get canonical SMILES
62
+ def get_canonical_smiles(smiles):
63
+ mol = Chem.MolFromSmiles(smiles)
64
+ if mol:
65
+ return Chem.MolToSmiles(mol, canonical=True)
66
+ return None
67
+
68
+
69
+ # Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths)
70
+ smiles_image_mapping = {
71
+ "Mol 1": {"smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1", "image": "img/img1.png"},
72
+ # Example SMILES for ethanol
73
+ "Mol 2": {"smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1", "image": "img/img2.png"},
74
+ # Example SMILES for butane
75
+ "Mol 3": {"smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
76
+ "image": "img/img3.png"}, # Example SMILES for ethylamine
77
+ "Mol 4": {"smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1", "image": "img/img4.png"},
78
+ # Example SMILES for diethyl ether
79
+ "Mol 5": {"smiles": "C=CCS[C@@H](C)CC(=O)OCC", "image": "img/img5.png"} # Example SMILES for chloroethane
80
+ }
81
+
82
+ datasets = [" ","BACE", "ESOL", "Custom Dataset"]
83
+
84
+ models_enabled = ["SELFIES-TED", "MHG-GED", "MolFormer", "SMI-TED"]
85
+
86
+ fusion_available = ["Concat"]
87
+
88
+ global log_df
89
+ log_df = pd.DataFrame(columns=["Selected Models", "Dataset", "Task", "Result"])
90
+
91
+
92
+ def log_selection(models, dataset, task_type, result, log_df):
93
+ # Append the new entry to the DataFrame
94
+ new_entry = {"Selected Models": str(models), "Dataset": dataset, "Task": task_type, "Result": result}
95
+ updated_log_df = log_df.append(new_entry, ignore_index=True)
96
+ return updated_log_df
97
+
98
+
99
+ # Function to handle evaluation and logging
100
+ def save_rep(models, dataset, task_type, eval_output):
101
+ return
102
+ def evaluate_and_log(models, dataset, task_type, eval_output):
103
+ task_dic = {'Classification': 'CLS', 'Regression': 'RGR'}
104
+ result = f"{eval_output}"#display_eval(models, dataset, task_type, fusion_type=None)
105
+ result = result.replace(" Score", "")
106
+
107
+ new_entry = {"Selected Models": str(models), "Dataset": dataset, "Task": task_dic[task_type], "Result": result}
108
+ new_entry_df = pd.DataFrame([new_entry])
109
+
110
+ log_df = pd.read_csv('log.csv', index_col=0)
111
+ log_df = pd.concat([new_entry_df, log_df])
112
+
113
+ log_df.to_csv('log.csv')
114
+
115
+ return log_df
116
+
117
+
118
+ log_df = pd.read_csv('log.csv', index_col=0)
119
+
120
+
121
+ # Load images for selection
122
+ def load_image(path):
123
+ return Image.open(smiles_image_mapping[path]["image"])# Image.1open(path)
124
+
125
+
126
+ # Function to handle image selection
127
+ def handle_image_selection(image_key):
128
+ smiles = smiles_image_mapping[image_key]["smiles"]
129
+ mol_image = smiles_to_image(smiles)
130
+ return smiles, mol_image
131
+
132
+
133
+ def calculate_properties(smiles):
134
+ mol = Chem.MolFromSmiles(smiles)
135
+ if mol:
136
+ qed = QED.qed(mol)
137
+ logp = MolLogP(mol)
138
+ sa = sascorer.calculateScore(mol)
139
+ wt = Descriptors.MolWt(mol)
140
+ return qed, sa, logp, wt
141
+ return None, None, None, None
142
+
143
+
144
+ # Function to calculate Tanimoto similarity
145
+ def calculate_tanimoto(smiles1, smiles2):
146
+ mol1 = Chem.MolFromSmiles(smiles1)
147
+ mol2 = Chem.MolFromSmiles(smiles2)
148
+ if mol1 and mol2:
149
+ # fp1 = FingerprintMols.FingerprintMol(mol1)
150
+ # fp2 = FingerprintMols.FingerprintMol(mol2)
151
+ fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2)
152
+ fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2)
153
+ return round(DataStructs.FingerprintSimilarity(fp1, fp2), 2)
154
+ return None
155
+
156
+
157
+ #with open("models/selfies_model/bart-2908.pickle", "rb") as input_file:
158
+ # gen_model, gen_tokenizer = pickle.load(input_file)
159
+
160
+ gen_tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
161
+ gen_model = BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")
162
+
163
+
164
+ def generate(latent_vector, mask):
165
+ encoder_outputs = BaseModelOutput(latent_vector)
166
+ decoder_output = gen_model.generate(encoder_outputs=encoder_outputs, attention_mask=mask,
167
+ max_new_tokens=64, do_sample=True, top_k=5, top_p=0.95, num_return_sequences=1)
168
+ selfies = gen_tokenizer.batch_decode(decoder_output, skip_special_tokens=True)
169
+ outs = []
170
+ for i in selfies:
171
+ outs.append(sf.decoder(i.replace("] [", "][")))
172
+ return outs
173
+
174
+
175
+ def perturb_latent(latent_vecs, noise_scale=0.5):
176
+ modified_vec = torch.tensor(np.random.uniform(0, 1, latent_vecs.shape) * noise_scale,
177
+ dtype=torch.float32) + latent_vecs
178
+ return modified_vec
179
+
180
+
181
+ def encode(selfies):
182
+ encoding = gen_tokenizer(selfies, return_tensors='pt', max_length=128, truncation=True, padding='max_length')
183
+ input_ids = encoding['input_ids']
184
+ attention_mask = encoding['attention_mask']
185
+ outputs = gen_model.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
186
+ model_output = outputs.last_hidden_state
187
+
188
+ """input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
189
+ sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
190
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
191
+ model_output = sum_embeddings / sum_mask"""
192
+ return model_output, attention_mask
193
+
194
+
195
+ # Function to generate canonical SMILES and molecule image
196
+ def generate_canonical(smiles):
197
+ s = sf.encoder(smiles)
198
+ selfie = s.replace("][", "] [")
199
+ latent_vec, mask = encode([selfie])
200
+ gen_mol = None
201
+ for i in range(5, 51):
202
+ noise = i / 10
203
+ perturbed_latent = perturb_latent(latent_vec, noise_scale=noise)
204
+ gen = generate(perturbed_latent, mask)
205
+ gen_mol = Chem.MolToSmiles(Chem.MolFromSmiles(gen[0]))
206
+ if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)): break
207
+
208
+ if gen_mol:
209
+ # Calculate properties for ref and gen molecules
210
+ ref_properties = calculate_properties(smiles)
211
+ gen_properties = calculate_properties(gen_mol)
212
+ tanimoto_similarity = calculate_tanimoto(smiles, gen_mol)
213
+
214
+ # Prepare the table with ref mol and gen mol
215
+ data = {
216
+ "Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
217
+ "Reference Mol": [ref_properties[0], ref_properties[1], ref_properties[2], ref_properties[3],
218
+ tanimoto_similarity],
219
+ "Generated Mol": [gen_properties[0], gen_properties[1], gen_properties[2], gen_properties[3], ""]
220
+ }
221
+ df = pd.DataFrame(data)
222
+
223
+ # Display molecule image of canonical smiles
224
+ mol_image = smiles_to_image(gen_mol)
225
+
226
+ return df, gen_mol, mol_image
227
+ return "Invalid SMILES", None, None
228
+
229
+
230
+ # Function to display evaluation score
231
+ def display_eval(selected_models, dataset, task_type, downstream, fusion_type):
232
+ result = None
233
+
234
+ try:
235
+ downstream_model = downstream.split("*")[0].lstrip()
236
+ downstream_model = downstream_model.rstrip()
237
+ hyp_param = downstream.split("*")[-1].lstrip()
238
+ hyp_param = hyp_param.rstrip()
239
+ hyp_param = hyp_param.replace("nan", "float('nan')")
240
+ params = eval(hyp_param)
241
+ except:
242
+ downstream_model = downstream.split("*")[0].lstrip()
243
+ downstream_model = downstream_model.rstrip()
244
+ params = None
245
+
246
+
247
+
248
+
249
+ try:
250
+ if not selected_models:
251
+ return "Please select at least one enabled model."
252
+
253
+ if task_type == "Classification":
254
+ global roc_auc, fpr, tpr, x_batch, y_batch
255
+ elif task_type == "Regression":
256
+ global RMSE, y_batch_test, y_prob
257
+
258
+ if len(selected_models) > 1:
259
+ if task_type == "Classification":
260
+ #result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
261
+ # downstream_model="XGBClassifier",
262
+ # dataset=dataset.lower())
263
+ if downstream_model == "Default Settings":
264
+ downstream_model = "DefaultClassifier"
265
+ params = None
266
+ result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
267
+ downstream_model=downstream_model,
268
+ params = params,
269
+ dataset=dataset)
270
+
271
+ elif task_type == "Regression":
272
+ #result, RMSE, y_batch_test, y_prob = fm4m.multi_modal(model_list=selected_models,
273
+ # downstream_model="XGBRegressor",
274
+ # dataset=dataset.lower())
275
+
276
+ if downstream_model == "Default Settings":
277
+ downstream_model = "DefaultRegressor"
278
+ params = None
279
+
280
+ result, RMSE, y_batch_test, y_prob, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
281
+ downstream_model=downstream_model,
282
+ params=params,
283
+ dataset=dataset)
284
+
285
+ else:
286
+ if task_type == "Classification":
287
+ #result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
288
+ # downstream_model="XGBClassifier",
289
+ # dataset=dataset.lower())
290
+ if downstream_model == "Default Settings":
291
+ downstream_model = "DefaultClassifier"
292
+ params = None
293
+
294
+ result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
295
+ downstream_model=downstream_model,
296
+ params=params,
297
+ dataset=dataset)
298
+
299
+ elif task_type == "Regression":
300
+ #result, RMSE, y_batch_test, y_prob = fm4m.single_modal(model=selected_models[0],
301
+ # downstream_model="XGBRegressor",
302
+ # dataset=dataset.lower())
303
+
304
+ if downstream_model == "Default Settings":
305
+ downstream_model = "DefaultRegressor"
306
+ params = None
307
+
308
+ result, RMSE, y_batch_test, y_prob, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
309
+ downstream_model=downstream_model,
310
+ params=params,
311
+ dataset=dataset)
312
+
313
+ if result == None:
314
+ result = "Data & Model Setting is incorrect"
315
+ except Exception as e:
316
+ return f"An error occurred: {e}"
317
+ return f"{result}"
318
+
319
+
320
+ # Function to handle plot display
321
+ def display_plot(plot_type):
322
+ fig, ax = plt.subplots()
323
+
324
+ if plot_type == "Latent Space":
325
+ global x_batch, y_batch
326
+ ax.set_title("T-SNE Plot")
327
+ # reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False)
328
+ # features_umap = reducer.fit_transform(x_batch[:500])
329
+ # x = y_batch.values[:500]
330
+ # index_0 = [index for index in range(len(x)) if x[index] == 0]
331
+ # index_1 = [index for index in range(len(x)) if x[index] == 1]
332
+ class_0 = x_batch # features_umap[index_0]
333
+ class_1 = y_batch # features_umap[index_1]
334
+
335
+ """with open("latent_multi_bace.pkl", "rb") as f:
336
+ class_0, class_1 = pickle.load(f)
337
+ """
338
+ plt.scatter(class_1[:, 0], class_1[:, 1], c='red', label='Class 1')
339
+ plt.scatter(class_0[:, 0], class_0[:, 1], c='blue', label='Class 0')
340
+
341
+ ax.set_xlabel('Feature 1')
342
+ ax.set_ylabel('Feature 2')
343
+ ax.set_title('Dataset Distribution')
344
+
345
+ elif plot_type == "ROC-AUC":
346
+ global roc_auc, fpr, tpr
347
+ ax.set_title("ROC-AUC Curve")
348
+ try:
349
+ ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.4f})')
350
+ ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
351
+ ax.set_xlim([0.0, 1.0])
352
+ ax.set_ylim([0.0, 1.05])
353
+ except:
354
+ pass
355
+ ax.set_xlabel('False Positive Rate')
356
+ ax.set_ylabel('True Positive Rate')
357
+ ax.set_title('Receiver Operating Characteristic')
358
+ ax.legend(loc='lower right')
359
+
360
+ elif plot_type == "Parity Plot":
361
+ global RMSE, y_batch_test, y_prob
362
+ ax.set_title("Parity plot")
363
+
364
+ # change format
365
+ try:
366
+ print(y_batch_test)
367
+ print(y_prob)
368
+ y_batch_test = np.array(y_batch_test, dtype=float)
369
+ y_prob = np.array(y_prob, dtype=float)
370
+ ax.scatter(y_batch_test, y_prob, color="blue", label=f"Predicted vs Actual (RMSE: {RMSE:.4f})")
371
+ min_val = min(min(y_batch_test), min(y_prob))
372
+ max_val = max(max(y_batch_test), max(y_prob))
373
+ ax.plot([min_val, max_val], [min_val, max_val], 'r-')
374
+
375
+ except:
376
+
377
+ y_batch_test = []
378
+ y_prob = []
379
+ RMSE = None
380
+ print(y_batch_test)
381
+ print(y_prob)
382
+
383
+
384
+
385
+
386
+
387
+ ax.set_xlabel('Actual Values')
388
+ ax.set_ylabel('Predicted Values')
389
+
390
+ ax.legend(loc='lower right')
391
+ return fig
392
+
393
+
394
+ # Predefined dataset paths (these should be adjusted to your file paths)
395
+ predefined_datasets = {
396
+ "BACE": f"./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
397
+ "ESOL": f"./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
398
+ }
399
+
400
+
401
+ # Function to load a predefined dataset from the local path
402
+ def load_predefined_dataset(dataset_name):
403
+ val = predefined_datasets.get(dataset_name)
404
+ try: file_path = val.split(",")[0]
405
+ except:file_path=False
406
+
407
+ if file_path:
408
+ df = pd.read_csv(file_path)
409
+ return df.head(), gr.update(choices=list(df.columns)), gr.update(choices=list(df.columns)), f"{dataset_name.lower()}"
410
+ return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[]), f"Dataset not found"
411
+
412
+
413
+ # Function to display the head of the uploaded CSV file
414
+ def display_csv_head(file):
415
+ if file is not None:
416
+ # Load the CSV file into a DataFrame
417
+ df = pd.read_csv(file.name)
418
+ return df.head(), gr.update(choices=list(df.columns)), gr.update(choices=list(df.columns))
419
+ return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
420
+
421
+
422
+ # Function to handle dataset selection (predefined or custom)
423
+ def handle_dataset_selection(selected_dataset):
424
+ if selected_dataset == "Custom Dataset":
425
+ # Show file upload fields for train and test datasets if "Custom Dataset" is selected
426
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
427
+ visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
428
+ else:
429
+ #[dataset_name, train_file, train_display, test_file, test_display, predefined_display,
430
+ # input_column_selector, output_column_selector]
431
+
432
+
433
+
434
+ # Load the predefined dataset from its local path
435
+ #return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
436
+ # visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
437
+ #return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(
438
+ # visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
439
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(
440
+ visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
441
+
442
+
443
+ # Function to select input and output columns and display a message
444
+ def select_columns(input_column, output_column, train_data, test_data,dataset_name):
445
+ if input_column and output_column:
446
+ return f"{train_data.name},{test_data.name},{input_column},{output_column},{dataset_name}"
447
+ return "Please select both input and output columns."
448
+
449
+ def set_dataname(dataset_name, dataset_selector ):
450
+ if dataset_selector == "Custom Dataset":
451
+ return f"{dataset_name}"
452
+ return f"{dataset_selector}"
453
+
454
+ # Function to create model based on user input
455
+ def create_model(model_name, max_depth=None, n_estimators=None, alpha=None, degree=None, kernel=None):
456
+ if model_name == "XGBClassifier":
457
+ model = xgb.XGBClassifier(objective='binary:logistic',eval_metric= 'auc', max_depth=max_depth, n_estimators=n_estimators, alpha=alpha)
458
+ elif model_name == "SVR":
459
+ model = SVR(degree=degree, kernel=kernel)
460
+ elif model_name == "Kernel Ridge":
461
+ model = KernelRidge(alpha=alpha, degree=degree, kernel=kernel)
462
+ elif model_name == "Linear Regression":
463
+ model = LinearRegression()
464
+ elif model_name == "Default - Auto":
465
+ model = "Default Settings"
466
+ return f"{model}"
467
+ else:
468
+ return "Model not supported."
469
+
470
+ return f"{model_name} * {model.get_params()}"
471
+ def model_selector(model_name):
472
+ # Dynamically return the appropriate hyperparameter components based on the selected model
473
+ if model_name == "XGBClassifier":
474
+ return (
475
+ gr.Slider(1, 10, label="max_depth"),
476
+ gr.Slider(50, 500, label="n_estimators"),
477
+ gr.Slider(0.1, 10.0, step=0.1, label="alpha")
478
+ )
479
+ elif model_name == "SVR":
480
+ return (
481
+ gr.Slider(1, 5, label="degree"),
482
+ gr.Dropdown(["rbf", "poly", "linear"], label="kernel")
483
+ )
484
+ elif model_name == "Kernel Ridge":
485
+ return (
486
+ gr.Slider(0.1, 10.0, step=0.1, label="alpha"),
487
+ gr.Slider(1, 5, label="degree"),
488
+ gr.Dropdown(["rbf", "poly", "linear"], label="kernel")
489
+ )
490
+ elif model_name == "Linear Regression":
491
+ return () # No hyperparameters for Linear Regression
492
+ else:
493
+ return ()
494
+
495
+
496
+
497
+ # Define the Gradio layout
498
+ # with gr.Blocks(theme=my_theme) as demo:
499
+ with gr.Blocks() as demo:
500
+ with gr.Row():
501
+ # Left Column
502
+ with gr.Column():
503
+ gr.HTML('''
504
+ <div style="background-color: #6A8EAE; color: #FFFFFF; padding: 10px;">
505
+ <h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Data & Model Setting</h3>
506
+ </div>
507
+ ''')
508
+ # gr.Markdown("## Data & Model Setting")
509
+ #dataset_dropdown = gr.Dropdown(choices=datasets, label="Select Dat")
510
+
511
+ # Dropdown menu for predefined datasets including "Custom Dataset" option
512
+ dataset_selector = gr.Dropdown(label="Select Dataset",
513
+ choices=list(predefined_datasets.keys()) + ["Custom Dataset"])
514
+ # Display the message for selected columns
515
+ selected_columns_message = gr.Textbox(label="Selected Columns Info", visible=False)
516
+
517
+ with gr.Accordion("Dataset Settings", open=True):
518
+ # File upload options for custom dataset (train and test)
519
+ dataset_name = gr.Textbox(label="Dataset Name", visible=False)
520
+ train_file = gr.File(label="Upload Custom Train Dataset", file_types=[".csv"], visible=False)
521
+ train_display = gr.Dataframe(label="Train Dataset Preview (First 5 Rows)", visible=False, interactive=False)
522
+
523
+ test_file = gr.File(label="Upload Custom Test Dataset", file_types=[".csv"], visible=False)
524
+ test_display = gr.Dataframe(label="Test Dataset Preview (First 5 Rows)", visible=False, interactive=False)
525
+
526
+ # Predefined dataset displays
527
+ predefined_display = gr.Dataframe(label="Predefined Dataset Preview (First 5 Rows)", visible=False,
528
+ interactive=False)
529
+
530
+
531
+
532
+ # Dropdowns for selecting input and output columns for the custom dataset
533
+ input_column_selector = gr.Dropdown(label="Select Input Column", choices=[], visible=False)
534
+ output_column_selector = gr.Dropdown(label="Select Output Column", choices=[], visible=False)
535
+
536
+ #selected_columns_message = gr.Textbox(label="Selected Columns Info", visible=True)
537
+
538
+ # When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
539
+ dataset_selector.change(handle_dataset_selection,
540
+ inputs=dataset_selector,
541
+ outputs=[dataset_name, train_file, train_display, test_file, test_display, predefined_display,
542
+ input_column_selector, output_column_selector])
543
+
544
+ # When a predefined dataset is selected, load its head and update column selectors
545
+ dataset_selector.change(load_predefined_dataset,
546
+ inputs=dataset_selector,
547
+ outputs=[predefined_display, input_column_selector, output_column_selector, selected_columns_message])
548
+
549
+ # When a custom train file is uploaded, display its head and update column selectors
550
+ train_file.change(display_csv_head, inputs=train_file,
551
+ outputs=[train_display, input_column_selector, output_column_selector])
552
+
553
+ # When a custom test file is uploaded, display its head
554
+ test_file.change(display_csv_head, inputs=test_file,
555
+ outputs=[test_display, input_column_selector, output_column_selector])
556
+
557
+ dataset_selector.change(set_dataname,
558
+ inputs=[dataset_name, dataset_selector],
559
+ outputs=dataset_name)
560
+
561
+ # Update the selected columns information when dropdown values are changed
562
+ input_column_selector.change(select_columns,
563
+ inputs=[input_column_selector, output_column_selector, train_file, test_file, dataset_name],
564
+ outputs=selected_columns_message)
565
+
566
+ output_column_selector.change(select_columns,
567
+ inputs=[input_column_selector, output_column_selector, train_file, test_file, dataset_name],
568
+ outputs=selected_columns_message)
569
+
570
+ model_checkbox = gr.CheckboxGroup(choices=models_enabled, label="Select Model")
571
+
572
+ # Add disabled checkboxes for GNN and FNN
573
+ # gnn_checkbox = gr.Checkbox(label="GNN (Disabled)", value=False, interactive=False)
574
+ # fnn_checkbox = gr.Checkbox(label="FNN (Disabled)", value=False, interactive=False)
575
+
576
+ task_radiobutton = gr.Radio(choices=["Classification", "Regression"], label="Task Type")
577
+
578
+ ####### adding hyper parameter tuning ###########
579
+ model_name = gr.Dropdown(["Default - Auto", "XGBClassifier", "SVR", "Kernel Ridge", "Linear Regression"], label="Select Downstream Model")
580
+ with gr.Accordion("Downstream Hyperparameter Settings", open=True):
581
+ # Create placeholders for hyperparameter components
582
+ max_depth = gr.Slider(1, 20, step=1,visible=False, label="max_depth")
583
+ n_estimators = gr.Slider(100, 5000, step=100, visible=False, label="n_estimators")
584
+ alpha = gr.Slider(0.1, 10.0, step=0.1, visible=False, label="alpha")
585
+ degree = gr.Slider(1, 20, step=1,visible=False, label="degree")
586
+ kernel = gr.Dropdown(choices=["rbf", "poly", "linear"], visible=False, label="kernel")
587
+
588
+ # Output textbox
589
+ output = gr.Textbox(label="Loaded Parameters")
590
+
591
+
592
+ # Dynamically show relevant hyperparameters based on selected model
593
+ def update_hyperparameters(model_name):
594
+ if model_name == "XGBClassifier":
595
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
596
+ visible=False), gr.update(visible=False)
597
+ elif model_name == "SVR":
598
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
599
+ visible=True), gr.update(visible=True)
600
+ elif model_name == "Kernel Ridge":
601
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(
602
+ visible=True), gr.update(visible=True)
603
+ elif model_name == "Linear Regression":
604
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
605
+ visible=False), gr.update(visible=False)
606
+ elif model_name == "Default - Auto":
607
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
608
+ visible=False), gr.update(visible=False)
609
+
610
+
611
+ # When model is selected, update which hyperparameters are visible
612
+ model_name.change(update_hyperparameters, inputs=[model_name],
613
+ outputs=[max_depth, n_estimators, alpha, degree, kernel])
614
+
615
+ # Submit button to create the model with selected hyperparameters
616
+ submit_button = gr.Button("Create Downstream Model")
617
+
618
+
619
+ # Function to handle model creation based on input parameters
620
+ def on_submit(model_name, max_depth, n_estimators, alpha, degree, kernel):
621
+ if model_name == "XGBClassifier":
622
+ return create_model(model_name, max_depth=max_depth, n_estimators=n_estimators, alpha=alpha)
623
+ elif model_name == "SVR":
624
+ return create_model(model_name, degree=degree, kernel=kernel)
625
+ elif model_name == "Kernel Ridge":
626
+ return create_model(model_name, alpha=alpha, degree=degree, kernel=kernel)
627
+ elif model_name == "Linear Regression":
628
+ return create_model(model_name)
629
+ elif model_name == "Default - Auto":
630
+ return create_model(model_name)
631
+
632
+ # When the submit button is clicked, run the on_submit function
633
+ submit_button.click(on_submit, inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel],
634
+ outputs=output)
635
+ ###### End of hyper param tuning #########
636
+
637
+ fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")
638
+
639
+
640
+
641
+ eval_button = gr.Button("Train downstream model")
642
+ #eval_button.style(css_class="custom-button-left")
643
+
644
+ # Middle Column
645
+ with gr.Column():
646
+ gr.HTML('''
647
+ <div style="background-color: #8F9779; color: #FFFFFF; padding: 10px;">
648
+ <h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 1: Property Prediction</h3>
649
+ </div>
650
+ ''')
651
+ # gr.Markdown("## Downstream task Result")
652
+ eval_output = gr.Textbox(label="Train downstream model")
653
+
654
+ plot_radio = gr.Radio(choices=["ROC-AUC", "Parity Plot", "Latent Space"], label="Select Plot Type")
655
+ plot_output = gr.Plot(label="Visualization")#, height=250, width=250)
656
+
657
+ #download_rep = gr.Button("Download representation")
658
+
659
+ create_log = gr.Button("Store log")
660
+
661
+ log_table = gr.Dataframe(value=log_df, label="Log of Selections and Results", interactive=False)
662
+
663
+ eval_button.click(display_eval,
664
+ inputs=[model_checkbox, selected_columns_message, task_radiobutton, output, fusion_radiobutton],
665
+ outputs=eval_output)
666
+
667
+ plot_radio.change(display_plot, inputs=plot_radio, outputs=plot_output)
668
+
669
+
670
+ # Function to gather selected models
671
+ def gather_selected_models(*models):
672
+ selected = [model for model in models if model]
673
+ return selected
674
+
675
+
676
+ create_log.click(evaluate_and_log, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output],
677
+ outputs=log_table)
678
+ #download_rep.click(save_rep, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output],
679
+ # outputs=None)
680
+
681
+ # Right Column
682
+ with gr.Column():
683
+ gr.HTML('''
684
+ <div style="background-color: #D2B48C; color: #FFFFFF; padding: 10px;">
685
+ <h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 2: Molecule Generation</h3>
686
+ </div>
687
+ ''')
688
+ # gr.Markdown("## Molecular Generation")
689
+ smiles_input = gr.Textbox(label="Input SMILES String")
690
+ image_display = gr.Image(label="Molecule Image", height=250, width=250)
691
+ # Show images for selection
692
+ with gr.Accordion("Select from sample molecules", open=False):
693
+ image_selector = gr.Radio(
694
+ choices=list(smiles_image_mapping.keys()),
695
+ label="Select from sample molecules",
696
+ value=None,
697
+ #item_images=[load_image(smiles_image_mapping[key]["image"]) for key in smiles_image_mapping.keys()]
698
+ )
699
+ image_selector.change(load_image, image_selector, image_display)
700
+ generate_button = gr.Button("Generate")
701
+ gen_image_display = gr.Image(label="Generated Molecule Image", height=250, width=250)
702
+ generated_output = gr.Textbox(label="Generated Output")
703
+ property_table = gr.Dataframe(label="Molecular Properties Comparison")
704
+
705
+
706
+
707
+ # Handle image selection
708
+ image_selector.change(handle_image_selection, inputs=image_selector, outputs=[smiles_input, image_display])
709
+ smiles_input.change(smiles_to_image, inputs=smiles_input, outputs=image_display)
710
+
711
+ # Generate button to display canonical SMILES and molecule image
712
+ generate_button.click(generate_canonical, inputs=smiles_input,
713
+ outputs=[property_table, generated_output, gen_image_display])
714
+
715
+
716
+ if __name__ == "__main__":
717
+ demo.launch(share=True)
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/bace/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/bace/train.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/bace/valid.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/esol/test.csv ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,selfies,prop,smiles
2
+ 0,[Cl] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [C] [C] [C] [C] [Branch1] [Branch2] [C] [O] [C] [Ring1] [=Branch1] [Ring1] [Ring1] [C] [Ring1] [Branch2] [C] [Ring1] [=C] [Branch1] [C] [Cl] [C] [Ring1] [=N] [Branch1] [C] [Cl] [Cl],-4.533,ClC4=C(Cl)C5(Cl)C3C1CC(C2OC12)C3C4(Cl)C5(Cl)Cl
3
+ 1,[C] [C] [C] [C] [C] [=O],-1.103,CCCCC=O
4
+ 2,[O] [C] [C] [C] [C] [=C],-0.7909999999999999,OCCCC=C
5
+ 3,[C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [N] [N] [=C] [C] [Branch1] [C] [N] [=C] [Branch1] [C] [Br] [C] [Ring1] [Branch2] [=O],-3.005,c1ccccc1n2ncc(N)c(Br)c2(=O)
6
+ 4,[N] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-1.231,Nc1ccc(O)cc1
7
+ 5,[C] [C] [Branch1] [C] [C] [C] [C] [O] [C] [=Branch1] [C] [=O] [C],-1.817,CC(C)CCOC(=O)C
8
+ 6,[C] [O] [P] [=Branch1] [C] [=S] [Branch1] [Ring1] [O] [C] [S] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [C] [C] [C] [=O],-2.087,COP(=S)(OC)SCC(=O)N(C)C=O
9
+ 7,[Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=Branch1] [Ring2] [=C] [Ring1] [#Branch1] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-6.312,Clc1ccc(Cl)c(c1)c2ccc(Cl)c(Cl)c2
10
+ 8,[C] [Branch1] [C] [Cl] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [N] [=C] [C] [=N] [C] [Ring1] [=Branch1] [=C] [Ring1] [=N] [Cl],-4.438,c2(Cl)c(Cl)c(Cl)c1nccnc1c2(Cl)
11
+ 9,[C] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [N] [=C] [Branch1] [=Branch1] [N] [=C] [Ring1] [#Branch1] [O] [N] [Branch1] [C] [C] [C],-3.57,CCCCc1c(C)nc(nc1O)N(C)C
12
+ 10,[C] [C] [O] [C] [=Branch1] [C] [=O] [C] [C] [=Branch1] [C] [=O] [O] [C] [C],-1.413,CCOC(=O)CC(=O)OCC
13
+ 11,[C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-3.192,CC(C)(C)c1ccc(O)cc1
14
+ 12,[C] [C] [=C] [C] [=C] [C] [Branch1] [C] [C] [=C] [Ring1] [#Branch1],-3.035,Cc1cccc(C)c1
15
+ 13,[C] [C] [C] [O] [C] [=Branch1] [C] [=O] [C],-1.125,CCCOC(=O)C
16
+ 14,[C] [S] [C] [=N] [N] [=C] [Branch1] [=Branch2] [C] [=Branch1] [C] [=O] [N] [Ring1] [#Branch1] [N] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C],-2.324,CSc1nnc(c(=O)n1N)C(C)(C)C
17
+ 15,[Cl] [C] [=C] [C] [=C] [Branch1] [Branch1] [C] [=C] [Ring1] [=Branch1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [Cl],-5.142,Clc1ccc(cc1)c2ccccc2Cl
18
+ 16,[C] [C] [C] [C] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [C] [Branch1] [Ring2] [C] [Ring1] [Branch2] [C] [Branch1] [C] [O] [C] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2],-1.5319999999999998,CC1CC(C)C(=O)C(C1)C(O)CC2CC(=O)NC(=O)C2
19
+ 17,[C] [N] [C] [=Branch1] [C] [=O] [O] [C] [=C] [C] [=C] [C] [Branch1] [Branch2] [N] [=C] [N] [Branch1] [C] [C] [C] [=C] [Ring1] [O],-1.846,CNC(=O)Oc1cccc(N=CN(C)C)c1
20
+ 18,[C] [C] [=C] [C] [=N] [C] [N] [Branch1] [=Branch1] [C] [C] [C] [Ring1] [Ring1] [C] [=N] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=Branch1] [C] [=O] [N] [C] [Ring2] [Ring1] [Ring1] [=Ring1] [#C],-3.397,Cc3ccnc4N(C1CC1)c2ncccc2C(=O)Nc34
21
+ 19,[C] [C] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-2.389,CCNc1ccccc1
22
+ 20,[C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [Ring2] [Ring1] [Ring1] [Ring1] [#Branch2],-6.297000000000001,Cc1c2ccccc2c(C)c3ccc4ccccc4c13
23
+ 21,[F] [C] [=C] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [Branch1] [C] [Cl] [=C] [Branch1] [C] [F] [C] [Branch1] [C] [Cl] [=C] [Ring1] [=Branch2] [F],-5.462000000000001,Fc1cccc(F)c1C(=O)NC(=O)Nc2cc(Cl)c(F)c(Cl)c2F
24
+ 22,[C] [O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-3.057,COc1ccc(Cl)cc1
25
+ 23,[O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=N] [Ring1] [=Branch1],-4.2010000000000005,o1c2ccccc2c3ccccc13
26
+ 24,[C] [=C] [C] [=C] [N] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [C] [Ring1] [#Branch2] [=C] [Ring1] [=C],-3.846,c3ccc2nc1ccccc1cc2c3
27
+ 25,[C] [C] [C] [C] [=Branch1] [C] [=O] [C] [C] [Branch1] [P] [C] [C] [C] [=C] [C] [=Branch1] [C] [=O] [C] [C] [C] [Ring1] [O] [Ring1] [#Branch1] [C] [C] [Ring1] [P] [C] [C] [C] [Ring2] [Ring1] [Ring2] [Branch1] [C] [O] [C] [=Branch1] [C] [=O] [C] [O],-2.893,CC12CC(=O)C3C(CCC4=CC(=O)CCC34C)C2CCC1(O)C(=O)CO
28
+ 26,[C] [C] [C] [=C] [C] [=C] [C] [Branch1] [Ring1] [C] [C] [=C] [Ring1] [Branch2] [N] [Branch1] [Ring2] [C] [O] [C] [C] [=Branch1] [C] [=O] [C] [Cl],-3.319,CCc1cccc(CC)c1N(COC)C(=O)CCl
29
+ 27,[C] [C] [C] [C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-4.157,CCCCN(C)C(=O)Nc1ccc(Cl)c(Cl)c1
30
+ 28,[C] [S] [C] [=Branch1] [C] [=S] [N] [C] [Ring1] [=Branch1] [=O],-0.396,C1SC(=S)NC1(=O)
31
+ 29,[O] [C] [=C] [C] [=C] [Branch1] [Branch2] [C] [Branch1] [C] [O] [=C] [Ring1] [#Branch1] [C] [O] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [Branch1] [C] [O] [=C] [Ring1] [Branch2] [C] [=Branch1] [C] [=O] [C] [=Ring1] [=N] [O],-2.7310000000000003,Oc1ccc(c(O)c1)c3oc2cc(O)cc(O)c2c(=O)c3O
32
+ 30,[C] [N] [Branch1] [C] [C] [C] [=N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [C],-3.164,CN(C)C=Nc1ccc(Cl)cc1C
33
+ 31,[N] [C] [=Branch1] [C] [=O] [N] [C] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [=Branch1] [=O],0.652,NC(=O)NC1NC(=O)NC1=O
34
+ 32,[Cl] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-4.063,Clc1cccc2ccccc12
35
+ 33,[O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.352,Oc1ccc(Cl)c(Cl)c1
36
+ 34,[C] [C] [Branch1] [C] [C] [C] [Branch1] [#Branch1] [C] [=C] [Branch1] [C] [Cl] [Cl] [C] [Ring1] [Branch2] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [Ring1] [C] [#N] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N],-6.775,CC1(C)C(C=C(Cl)Cl)C1C(=O)OC(C#N)c2cccc(Oc3ccccc3)c2
37
+ 35,[C] [=C] [C] [=C] [NH1] [N] [=N] [C] [Ring1] [Branch1] [=C] [Ring1] [=Branch2],-2.21,c2ccc1[nH]nnc1c2
38
+ 36,[C] [C] [Branch1] [C] [C] [C] [Branch2] [Ring1] [Branch1] [N] [C] [=C] [C] [=C] [Branch1] [=Branch1] [C] [=C] [Ring1] [=Branch1] [Cl] [C] [Branch1] [C] [F] [Branch1] [C] [F] [F] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [Ring1] [C] [#N] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N],-8.057,CC(C)C(Nc1ccc(cc1Cl)C(F)(F)F)C(=O)OC(C#N)c2cccc(Oc3ccccc3)c2
39
+ 37,[C] [C] [C],-1.5530000000000002,CCC
40
+ 38,[C] [C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [O] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-3.792,C1Cc2cccc3cccc1c23
41
+ 39,[C] [C] [C] [#C],-1.092,CCC#C
42
+ 40,[Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-3.5580000000000003,Clc1ccc(Cl)cc1
43
+ 41,[C] [C] [=C] [NH1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch2] [Ring1] [=Branch1],-2.9810000000000003,Cc1c[nH]c2ccccc12
44
+ 42,[C] [C] [#N],0.152,CC#N
45
+ 43,[C] [C] [C] [C] [O],-0.688,CCCCO
46
+ 44,[C] [C] [=Branch1] [C] [=C] [C] [=Branch1] [C] [=C] [C],-2.052,CC(=C)C(=C)C
47
+ 45,[C] [C] [C] [Branch1] [C] [C] [C] [C] [O],-1.308,CCC(C)CCO
48
+ 46,[Cl] [C] [=C] [C] [=C] [Branch1] [=Branch2] [C] [Branch1] [C] [Cl] [=C] [Ring1] [#Branch1] [Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2] [Cl],-7.192,Clc1ccc(c(Cl)c1Cl)c2ccc(Cl)c(Cl)c2Cl
49
+ 47,[C] [C] [=C] [C] [=Branch2] [Ring1] [=Branch1] [=C] [C] [=C] [Ring1] [=Branch1] [N] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [Branch1] [C] [F] [Branch1] [C] [F] [F] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-4.945,Cc1cc(ccc1NS(=O)(=O)C(F)(F)F)S(=O)(=O)c2ccccc2
50
+ 48,[O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [Cl],-3.22,Oc1ccc(Cl)cc1Cl
51
+ 49,[C] [N] [C] [=Branch2] [Ring1] [Ring2] [=C] [Branch1] [C] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [S] [Ring1] [O] [=Branch1] [C] [=O] [=O] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=N] [Ring1] [=Branch1],-3.4730000000000003,CN2C(=C(O)c1ccccc1S2(=O)=O)C(=O)Nc3ccccn3
52
+ 50,[C] [C] [C] [C] [C] [C] [Branch1] [S] [C] [C] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1] [C] [Ring1] [#C] [C] [C] [C] [Ring2] [Ring1] [C] [=O],-3.872,CC12CCC3C(CCc4cc(O)ccc34)C2CCC1=O
53
+ 51,[C] [C] [=C] [C] [=C] [C] [=C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1],-4.147,Cc1cccc2c(C)cccc12
54
+ 52,[N] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [Branch1] [O] [N] [C] [N] [S] [Ring1] [=Branch1] [=Branch1] [C] [=O] [=O] [C] [=C] [Ring1] [N] [Cl],-1.72,NS(=O)(=O)c2cc1c(NCNS1(=O)=O)cc2Cl
55
+ 53,[O] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [N] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-2.725,Oc1cccc2cccnc12
56
+ 54,[C] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [Ring1] [#Branch2],-3.447,C1CCc2ccccc2C1
57
+ 55,[C] [C] [O] [C] [Branch1] [C] [C] [O] [C] [C],-0.899,CCOC(C)OCC
58
+ 56,[C] [C] [C] [C] [Ring1] [Ring1] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [Branch1] [C] [Ring1] [Branch2] [=O] [C] [=C] [C] [Branch1] [C] [Cl] [=C] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.464,CC12CC2(C)C(=O)N(C1=O)c3cc(Cl)cc(Cl)c3
59
+ 57,[C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=C] [Ring1] [=Branch1],-4.87,Cc1c2ccccc2cc3ccccc13
60
+ 58,[C] [C] [C] [C] [O] [C],-1.072,CCCCOC
61
+ 59,[C] [C] [C] [C] [C] [=Branch1] [C] [=O] [C] [=C] [Ring1] [#Branch1] [C] [C] [C] [C] [C] [C] [C] [Branch1] [#Branch1] [C] [=Branch1] [C] [=O] [C] [O] [C] [Ring1] [=Branch2] [Branch1] [N] [C] [C] [Branch1] [C] [O] [C] [Ring2] [Ring1] [#Branch1] [Ring1] [=C] [C] [=O],-3.0660000000000003,CC13CCC(=O)C=C1CCC4C2CCC(C(=O)CO)C2(CC(O)C34)C=O
62
+ 60,[C] [C] [C] [Branch1] [=Branch1] [C] [Branch1] [C] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [O] [=O],-1.6030000000000002,CCC1(C(C)C)C(=O)NC(=O)NC1=O
63
+ 61,[C] [C] [O] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-2.761,CCOC(=O)c1ccc(O)cc1
64
+ 62,[C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring2] [Ring1] [C] [C] [=C] [Ring2] [Ring1] [C] [C] [Ring1] [S] [=C] [Ring1] [=C] [C] [Ring1] [N] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-6.885,c1cc2ccc3ccc4ccc5ccc6ccc1c7c2c3c4c5c67
65
+ 63,[C] [C] [N] [C] [=C] [C] [Branch1] [=Branch1] [N] [Branch1] [C] [C] [C] [=C] [C] [Branch1] [C] [C] [=C] [Ring1] [#Branch2] [N] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [N] [=C] [Ring2] [Ring1] [Ring2] [Ring1] [=Branch1],-4.408,CCN2c1cc(N(C)C)cc(C)c1NC(=O)c3cccnc23
66
+ 64,[C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.301,CN(C)C(=O)Nc1ccc(Cl)c(Cl)c1
67
+ 65,[C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [C],-3.3080000000000003,CCCCCC(C)C
68
+ 66,[C] [O] [C] [=C] [C] [=C] [Branch1] [C] [N] [N] [=C] [Branch1] [#C] [N] [=C] [Ring1] [#Branch1] [C] [Branch1] [Ring1] [O] [C] [=C] [Ring1] [=N] [O] [C] [N] [C] [C] [N] [Branch1] [Branch1] [C] [C] [Ring1] [=Branch1] [C] [=Branch1] [C] [=O] [O] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O],-3.958,COc2cc1c(N)nc(nc1c(OC)c2OC)N3CCN(CC3)C(=O)OCC(C)(C)O
69
+ 67,[C] [=C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [#Branch1] [C] [=C] [Ring1] [O],-0.636,c1cC2C(=O)NC(=O)C2cc1
70
+ 68,[C] [C] [C] [=O],-0.3939999999999999,CCC=O
71
+ 69,[Cl] [C] [=C] [C] [=C] [Branch2] [Ring1] [=Branch2] [C] [N] [Branch1] [Branch2] [C] [C] [C] [C] [C] [Ring1] [Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [Ring2] [Ring1] [=Branch1],-5.126,Clc1ccc(CN(C2CCCC2)C(=O)Nc3ccccc3)cc1
72
+ 70,[C] [C] [C] [C] [C] [Branch1] [Ring1] [C] [C] [C] [=O],-2.232,CCCCC(CC)C=O
73
+ 71,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [Ring1] [C] [C] [C] [C] [C] [Branch1] [C] [C] [C],-2.312,O=C1NC(=O)NC(=O)C1(CC)CCC(C)C
74
+ 72,[C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-1.857,CC(=O)Nc1ccccc1
75
+ 73,[C] [=N] [C] [=C] [C] [Branch1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [N] [=C] [Ring1] [#Branch2],-0.7170000000000001,c1nccc(C(=O)NN)c1
76
+ 74,[C] [C] [Branch1] [C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [Branch1] [Ring2] [C] [Ring1] [=Branch1] [C] [Ring1] [=Branch2] [=O],-2.158,CC2(C)C1CCC(C)(C1)C2=O
77
+ 75,[C] [O] [C] [=C] [N] [=C] [C] [=N] [C] [=N] [C] [Ring1] [=Branch1] [=N] [Ring1] [#Branch2],-1.589,COc2cnc1cncnc1n2
78
+ 76,[C] [N] [C] [=Branch1] [C] [=O] [C] [=C] [Branch1] [C] [C] [O] [P] [=Branch1] [C] [=O] [Branch1] [Ring1] [O] [C] [O] [C],-0.949,CNC(=O)C=C(C)OP(=O)(OC)OC
79
+ 77,[O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [N] [Branch1] [Ring1] [C] [C] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [=C] [Ring2] [Ring1] [C] [Ring1] [=Branch1],-3.784,O2c1ccccc1N(CC)C(=O)c3ccccc23
80
+ 78,[C] [=C] [C] [=C] [C] [=C] [Branch1] [Ring1] [O] [C] [C] [Branch1] [Branch2] [C] [C] [=C] [Branch1] [C] [C] [C] [=C] [Ring1] [=N] [O] [C] [Ring1] [P] [=O],-4.0760000000000005,c1cc2ccc(OC)c(CC=C(C)(C))c2oc1=O
81
+ 79,[C] [C] [C] [S] [C] [C] [C],-2.307,CCCSCCC
82
+ 80,[C] [O] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-2.948,CON(C)C(=O)Nc1ccc(Cl)cc1
83
+ 81,[C] [C] [O] [C] [C],-0.718,CCOCC
84
+ 82,[C] [C] [C] [C] [C] [C] [Branch1] [S] [C] [C] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1] [C] [Ring1] [#C] [C] [C] [Branch1] [C] [O] [C] [Ring2] [Ring1] [Ring1] [O],-3.858,CC34CCC1C(CCc2cc(O)ccc12)C3CC(O)C4O
85
+ 83,[C] [C] [N] [C] [=N] [C] [Branch1] [C] [Cl] [=N] [C] [Branch1] [O] [N] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C] [#N] [=N] [Ring1] [=N],-2.49,CCNc1nc(Cl)nc(NC(C)(C)C#N)n1
86
+ 84,[C] [C] [Branch1] [C] [C] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O],-1.6469999999999998,CC(C)CC(C)(C)O
87
+ 85,[Cl] [C] [=C] [C] [=C] [C] [Branch1] [C] [Br] [=C] [Ring1] [#Branch1],-3.928,Clc1cccc(Br)c1
88
+ 86,[C] [C] [C] [C] [C] [C] [Branch1] [C] [O] [C] [C],-2.033,CCCCCC(O)CC
89
+ 87,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [Ring1] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [C],-2.126,O=C1NC(=O)NC(=O)C1(CC)CC=C(C)C
90
+ 88,[C] [C] [C] [Branch1] [C] [C] [C] [Branch1] [#Branch1] [C] [C] [Branch1] [C] [Br] [=C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [N] [=O],-2.766,CCC(C)C1(CC(Br)=C)C(=O)NC(=O)NC1=O
91
+ 89,[C] [O] [C] [=Branch1] [C] [=O] [C],-0.416,COC(=O)C
92
+ 90,[C] [C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Branch1] [C] [C] [C] [=C] [Ring1] [#Branch1] [O],-3.129,CC(C)c1ccc(C)cc1O
93
+ 91,[C],-0.636,C
94
+ 92,[N] [C] [=N] [C] [Branch1] [C] [O] [=N] [C] [N] [=C] [NH1] [C] [Ring1] [#Branch2] [=Ring1] [Branch1],-1.74,Nc1nc(O)nc2nc[nH]c12
95
+ 93,[F] [C] [=C] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-4.692,Fc1cccc(F)c1C(=O)NC(=O)Nc2ccc(Cl)cc2
96
+ 94,[C] [C] [C] [C] [C] [Branch1] [Branch1] [C] [C] [Ring1] [=Branch1] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O] [Ring1] [#Branch2],-2.579,CC12CCC(CC1)C(C)(C)O2
97
+ 95,[C] [C] [O],0.02,CCO
98
+ 96,[C] [=C] [Branch2] [Ring1] [C] [N] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [C] [C] [=C] [C] [=C] [Ring1] [P],-2.29,c1c(NC(=O)OC(C)C(=O)NCC)cccc1
99
+ 97,[C] [C] [Branch1] [C] [C] [=C] [C] [C] [Branch2] [Ring1] [#Branch2] [C] [=Branch1] [C] [=O] [O] [C] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N] [C] [Ring2] [Ring1] [Ring2] [Branch1] [C] [C] [C],-6.763,CC(C)=CC3C(C(=O)OCc2cccc(Oc1ccccc1)c2)C3(C)C
100
+ 98,[C] [C] [C] [C] [N] [C] [=Branch1] [C] [=O] [N] [C] [Branch1] [Branch2] [N] [C] [=Branch1] [C] [=O] [O] [C] [=N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=C] [Ring1] [=Branch1],-2.902,CCCCNC(=O)n1c(NC(=O)OC)nc2ccccc12
101
+ 99,[C] [N] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-2.542,CN(C)c1ccccc1
102
+ 100,[C] [O] [C] [=Branch1] [C] [=O] [C] [=C],-0.878,COC(=O)C=C
103
+ 101,[C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [=N] [O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [C] [=C] [Ring1] [=C],-4.477,CN(C)C(=O)Nc2ccc(Oc1ccc(Cl)cc1)cc2
104
+ 102,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [=Branch1] [C] [Branch1] [C] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [C],-2.465,O=C1NC(=O)NC(=O)C1(C(C)C)CC=C(C)C
105
+ 103,[C] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1] [C],-2.6210000000000004,Cc1ccc(O)cc1C
106
+ 104,[Cl] [C] [=C] [C] [=C] [C] [=Branch1] [Ring2] [=N] [Ring1] [=Branch1] [C] [Branch1] [C] [Cl] [Branch1] [C] [Cl] [Cl],-3.833,Clc1cccc(n1)C(Cl)(Cl)Cl
107
+ 105,[C] [C] [=Branch1] [C] [=O] [O] [C] [Branch2] [Ring1] [=C] [C] [C] [C] [C] [C] [C] [C] [=C] [C] [=Branch1] [C] [=O] [C] [C] [C] [Ring1] [#Branch1] [C] [Ring1] [O] [C] [C] [C] [Ring2] [Ring1] [C] [Ring1] [#C] [C] [C] [#C],-4.2410000000000005,CC(=O)OC3(CCC4C2CCC1=CC(=O)CCC1C2CCC34C)C#C
108
+ 106,[C] [N] [C] [=Branch1] [C] [=O] [O] [N] [=C] [Branch1] [Ring2] [C] [S] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C],-2.7,CNC(=O)ON=C(CSC)C(C)(C)C
109
+ 107,[C] [C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [O],-2.033,CCCCCCC(C)O
data/esol/train.csv ADDED
The diff for this file is too large to render. See raw diff
 
img/.DS_Store ADDED
Binary file (6.15 kB). View file
 
img/img1.png ADDED
img/img2.png ADDED
img/img3.png ADDED
img/img4.png ADDED
img/img5.png ADDED
img/introduction.png ADDED
img/latent_multi_bace.png ADDED
log.csv ADDED
@@ -0,0 +1 @@
 
 
1
+ ,Selected Models,Dataset,Task,Result
models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/__pycache__/fm4m.cpython-310.pyc ADDED
Binary file (14.4 kB). View file
 
models/fm4m.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.metrics import roc_auc_score, roc_curve
2
+
3
+ import datetime
4
+ import os
5
+ import umap
6
+ import numpy as np
7
+
8
+ import matplotlib.pyplot as plt
9
+ import pandas as pd
10
+ import pickle
11
+ import json
12
+
13
+ from xgboost import XGBClassifier, XGBRegressor
14
+ import xgboost as xgb
15
+ from sklearn.metrics import roc_auc_score, mean_squared_error
16
+ import xgboost as xgb
17
+ from sklearn.svm import SVR
18
+ from sklearn.linear_model import LinearRegression
19
+ from sklearn.kernel_ridge import KernelRidge
20
+ import json
21
+ from sklearn.compose import TransformedTargetRegressor
22
+ from sklearn.preprocessing import MinMaxScaler
23
+
24
+
25
+ import torch
26
+ from transformers import AutoTokenizer, AutoModel
27
+
28
+ from .selfies_model.load import SELFIES as bart
29
+ from .mhg_model import load as mhg
30
+ from .smi_ted.smi_ted_light.load import load_smi_ted
31
+
32
+ datasets = {}
33
+ models = {}
34
+ downstream_models ={}
35
+
36
+
37
+ def avail_models_data():
38
+ global datasets
39
+ global models
40
+
41
+ datasets = [{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv", "Timestamp": "2024-06-26 11:27:37"},
42
+ {"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre", "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
43
+ {"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv", "Timestamp": "2024-06-26 11:33:47"},
44
+ {"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo", "Timestamp": "2024-06-26 11:34:37"},
45
+ {"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace", "Timestamp": "2024-06-26 11:36:40"},
46
+ {"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp", "Timestamp": "2024-06-26 11:39:23"},
47
+ {"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox", "Timestamp": "2024-06-26 11:42:43"}]
48
+
49
+
50
+ models = [{"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality", "Timestamp": "2024-06-21 12:32:20"},
51
+ {"Name": "mol-xl","Model Name": "Molformer", "Description": "MolFormer model for string based SMILES modality", "Timestamp": "2024-06-21 12:35:56"},
52
+ {"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model", "Timestamp": "2024-07-10 00:09:42"},
53
+ {"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model", "Timestamp": "2024-07-10 00:09:42"}]
54
+
55
+
56
+ def avail_models(raw=False):
57
+ global models
58
+
59
+ models = [{"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model"},
60
+ {"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality"},
61
+ {"Name": "mol-xl","Model Name": "Molformer", "Description": "MolFormer model for string based SMILES modality"},
62
+ {"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model"},
63
+ ]
64
+
65
+
66
+
67
+ if raw: return models
68
+ else:
69
+ return pd.DataFrame(models).drop('Name', axis=1)
70
+
71
+ return models
72
+
73
+ def avail_downstream_models():
74
+ global downstream_models
75
+
76
+ with open("downstream_models.json", "r") as outfile:
77
+ downstream_models = json.load(outfile)
78
+ return downstream_models
79
+
80
+ def avail_datasets():
81
+ global datasets
82
+
83
+ datasets = [{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv",
84
+ "Timestamp": "2024-06-26 11:27:37"},
85
+ {"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre",
86
+ "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
87
+ {"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv",
88
+ "Timestamp": "2024-06-26 11:33:47"},
89
+ {"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo",
90
+ "Timestamp": "2024-06-26 11:34:37"},
91
+ {"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace",
92
+ "Timestamp": "2024-06-26 11:36:40"},
93
+ {"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp",
94
+ "Timestamp": "2024-06-26 11:39:23"},
95
+ {"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox",
96
+ "Timestamp": "2024-06-26 11:42:43"}]
97
+
98
+ return datasets
99
+
100
+ def reset():
101
+
102
+ """datasets = {"esol": ["smiles", "ESOL predicted log solubility in mols per litre", "data/esol", "2024-06-26 11:36:46.509324"],
103
+ "freesolv": ["smiles", "expt", "data/freesolv", "2024-06-26 11:37:37.393273"],
104
+ "lipo": ["smiles", "y", "data/lipo", "2024-06-26 11:37:37.393273"],
105
+ "hiv": ["smiles", "HIV_active", "data/hiv", "2024-06-26 11:37:37.393273"],
106
+ "bace": ["smiles", "Class", "data/bace", "2024-06-26 11:38:40.058354"],
107
+ "bbbp": ["smiles", "p_np", "data/bbbp","2024-06-26 11:38:40.058354"],
108
+ "clintox": ["smiles", "CT_TOX", "data/clintox","2024-06-26 11:38:40.058354"],
109
+ "sider": ["smiles","1:", "data/sider","2024-06-26 11:38:40.058354"],
110
+ "tox21": ["smiles",":-2", "data/tox21","2024-06-26 11:38:40.058354"]
111
+ }"""
112
+
113
+ datasets = [
114
+ {"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv", "Timestamp": "2024-06-26 11:27:37"},
115
+ {"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre", "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
116
+ {"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv", "Timestamp": "2024-06-26 11:33:47"},
117
+ {"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo", "Timestamp": "2024-06-26 11:34:37"},
118
+ {"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace", "Timestamp": "2024-06-26 11:36:40"},
119
+ {"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp", "Timestamp": "2024-06-26 11:39:23"},
120
+ {"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox", "Timestamp": "2024-06-26 11:42:43"},
121
+ #{"Dataset": "sider", "Input": "smiles", "Output": "1:", "path": "data/sider", "Timestamp": "2024-06-26 11:38:40.058354"},
122
+ #{"Dataset": "tox21", "Input": "smiles", "Output": ":-2", "path": "data/tox21", "Timestamp": "2024-06-26 11:38:40.058354"}
123
+ ]
124
+
125
+ models = [{"Name": "bart", "Description": "BART model for string based SELFIES modality",
126
+ "Timestamp": "2024-06-21 12:32:20"},
127
+ {"Name": "mol-xl", "Description": "MolFormer model for string based SMILES modality",
128
+ "Timestamp": "2024-06-21 12:35:56"},
129
+ {"Name": "mhg", "Description": "MHG", "Timestamp": "2024-07-10 00:09:42"},
130
+ {"Name": "spec-gru", "Description": "Spectrum modality with GRU", "Timestamp": "2024-07-10 00:09:42"},
131
+ {"Name": "spec-lstm", "Description": "Spectrum modality with LSTM", "Timestamp": "2024-07-10 00:09:54"},
132
+ {"Name": "3d-vae", "Description": "VAE model for 3D atom positions", "Timestamp": "2024-07-10 00:10:08"}]
133
+
134
+
135
+ downstream_models = [
136
+ {"Name": "XGBClassifier", "Description": "XG Boost Classifier",
137
+ "Timestamp": "2024-06-21 12:31:20"},
138
+ {"Name": "XGBRegressor", "Description": "XG Boost Regressor",
139
+ "Timestamp": "2024-06-21 12:32:56"},
140
+ {"Name": "2-FNN", "Description": "A two layer feedforward network",
141
+ "Timestamp": "2024-06-24 14:34:16"},
142
+ {"Name": "3-FNN", "Description": "A three layer feedforward network",
143
+ "Timestamp": "2024-06-24 14:38:37"},
144
+ ]
145
+
146
+ with open("datasets.json", "w") as outfile:
147
+ json.dump(datasets, outfile)
148
+
149
+ with open("models.json", "w") as outfile:
150
+ json.dump(models, outfile)
151
+
152
+ with open("downstream_models.json", "w") as outfile:
153
+ json.dump(downstream_models, outfile)
154
+
155
+ def update_data_list(list_data):
156
+ #datasets[list_data[0]] = list_data[1:]
157
+
158
+ with open("datasets.json", "w") as outfile:
159
+ json.dump(datasets, outfile)
160
+
161
+ avail_models_data()
162
+
163
+ def update_model_list(list_model):
164
+ #models[list_model[0]] = list_model[1]
165
+
166
+ with open("models.json", "w") as outfile:
167
+ json.dump(list_model, outfile)
168
+
169
+ avail_models_data()
170
+
171
+ def update_downstream_model_list(list_model):
172
+ #models[list_model[0]] = list_model[1]
173
+
174
+ with open("downstream_models.json", "w") as outfile:
175
+ json.dump(list_model, outfile)
176
+
177
+ avail_models_data()
178
+
179
+ avail_models_data()
180
+
181
+ def get_representation(train_data,test_data,model_type, return_tensor=True):
182
+ alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
183
+ if model_type in alias.keys():
184
+ model_type = alias[model_type]
185
+
186
+ if model_type == "mhg":
187
+ model = mhg.load("models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle")
188
+ with torch.no_grad():
189
+ train_emb = model.encode(train_data)
190
+ x_batch = torch.stack(train_emb)
191
+
192
+ test_emb = model.encode(test_data)
193
+ x_batch_test = torch.stack(test_emb)
194
+ if not return_tensor:
195
+ x_batch = pd.DataFrame(x_batch)
196
+ x_batch_test = pd.DataFrame(x_batch_test)
197
+
198
+
199
+
200
+ elif model_type == "bart":
201
+ model = bart()
202
+ model.load()
203
+ x_batch = model.encode(train_data, return_tensor=return_tensor)
204
+ x_batch_test = model.encode(test_data, return_tensor=return_tensor)
205
+
206
+ elif model_type == "smi-ted":
207
+ model = load_smi_ted(folder='./models/smi_ted/smi_ted_light', ckpt_filename='smi-ted-Light_40.pt')
208
+ with torch.no_grad():
209
+ x_batch = model.encode(train_data, return_torch=return_tensor)
210
+ x_batch_test = model.encode(test_data, return_torch=return_tensor)
211
+
212
+ elif model_type == "mol-xl":
213
+ model = AutoModel.from_pretrained("ibm/MoLFormer-XL-both-10pct", deterministic_eval=True,
214
+ trust_remote_code=True)
215
+ tokenizer = AutoTokenizer.from_pretrained("ibm/MoLFormer-XL-both-10pct", trust_remote_code=True)
216
+
217
+ if type(train_data) == list:
218
+ inputs = tokenizer(train_data, padding=True, return_tensors="pt")
219
+ else:
220
+ inputs = tokenizer(list(train_data.values), padding=True, return_tensors="pt")
221
+
222
+ with torch.no_grad():
223
+ outputs = model(**inputs)
224
+
225
+ x_batch = outputs.pooler_output
226
+
227
+ if type(test_data) == list:
228
+ inputs = tokenizer(test_data, padding=True, return_tensors="pt")
229
+ else:
230
+ inputs = tokenizer(list(test_data.values), padding=True, return_tensors="pt")
231
+
232
+ with torch.no_grad():
233
+ outputs = model(**inputs)
234
+
235
+ x_batch_test = outputs.pooler_output
236
+
237
+ if not return_tensor:
238
+ x_batch = pd.DataFrame(x_batch)
239
+ x_batch_test = pd.DataFrame(x_batch_test)
240
+
241
+
242
+ return x_batch, x_batch_test
243
+
244
+ def single_modal(model,dataset, downstream_model,params):
245
+ print(model)
246
+ alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "SMI-TED": "smi-ted"}
247
+ data = avail_models(raw=True)
248
+ df = pd.DataFrame(data)
249
+ print(list(df["Name"].values))
250
+ if alias[model] in list(df["Name"].values):
251
+ if model in alias.keys():
252
+ model_type = alias[model]
253
+ else:
254
+ model_type = model
255
+ else:
256
+ print("Model not available")
257
+ return
258
+
259
+ data = avail_datasets()
260
+ df = pd.DataFrame(data)
261
+ print(list(df["Dataset"].values))
262
+
263
+ if dataset in list(df["Dataset"].values):
264
+ task = dataset
265
+ with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
266
+ x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
267
+ print(f" Representation loaded successfully")
268
+ else:
269
+
270
+ print("Custom Dataset")
271
+ #return
272
+ components = dataset.split(",")
273
+ train_data = pd.read_csv(components[0])[components[2]]
274
+ test_data = pd.read_csv(components[1])[components[2]]
275
+
276
+ y_batch = pd.read_csv(components[0])[components[3]]
277
+ y_batch_test = pd.read_csv(components[1])[components[3]]
278
+
279
+
280
+ x_batch, x_batch_test = get_representation(train_data,test_data,model_type)
281
+
282
+
283
+
284
+ print(f" Representation loaded successfully")
285
+
286
+
287
+
288
+
289
+
290
+ print(f" Calculating ROC AUC Score ...")
291
+
292
+ if downstream_model == "XGBClassifier":
293
+ xgb_predict_concat = XGBClassifier(**params) # n_estimators=5000, learning_rate=0.01, max_depth=10
294
+ xgb_predict_concat.fit(x_batch, y_batch)
295
+
296
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
297
+
298
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
299
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
300
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
301
+
302
+ try:
303
+ with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1:
304
+ class_0,class_1 = pickle.load(f1)
305
+ except:
306
+ print("Generating latent plots")
307
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
308
+ verbose=False)
309
+ n_samples = np.minimum(1000, len(x_batch))
310
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
311
+ x = y_batch.values[:n_samples]
312
+ index_0 = [index for index in range(len(x)) if x[index] == 0]
313
+ index_1 = [index for index in range(len(x)) if x[index] == 1]
314
+
315
+ class_0 = features_umap[index_0]
316
+ class_1 = features_umap[index_1]
317
+ print("Generating latent plots : Done")
318
+
319
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
320
+
321
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
322
+
323
+ return result, roc_auc,fpr, tpr, class_0, class_1
324
+
325
+ elif downstream_model == "DefaultClassifier":
326
+ xgb_predict_concat = XGBClassifier() # n_estimators=5000, learning_rate=0.01, max_depth=10
327
+ xgb_predict_concat.fit(x_batch, y_batch)
328
+
329
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
330
+
331
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
332
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
333
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
334
+
335
+ try:
336
+ with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1:
337
+ class_0,class_1 = pickle.load(f1)
338
+ except:
339
+ print("Generating latent plots")
340
+ reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False)
341
+ n_samples = np.minimum(1000,len(x_batch))
342
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
343
+ x = y_batch.values[:n_samples]
344
+ index_0 = [index for index in range(len(x)) if x[index] == 0]
345
+ index_1 = [index for index in range(len(x)) if x[index] == 1]
346
+
347
+ class_0 = features_umap[index_0]
348
+ class_1 = features_umap[index_1]
349
+ print("Generating latent plots : Done")
350
+
351
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
352
+
353
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
354
+
355
+ return result, roc_auc,fpr, tpr, class_0, class_1
356
+
357
+ elif downstream_model == "SVR":
358
+ regressor = SVR(**params)
359
+ model = TransformedTargetRegressor(regressor= regressor,
360
+ transformer = MinMaxScaler(feature_range=(-1, 1))
361
+ ).fit(x_batch,y_batch)
362
+
363
+ y_prob = model.predict(x_batch_test)
364
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
365
+
366
+ print(f"RMSE Score: {RMSE_score:.4f}")
367
+ result = f"RMSE Score: {RMSE_score:.4f}"
368
+
369
+ print("Generating latent plots")
370
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
371
+ verbose=False)
372
+ n_samples = np.minimum(1000, len(x_batch))
373
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
374
+ x = y_batch.values[:n_samples]
375
+ #index_0 = [index for index in range(len(x)) if x[index] == 0]
376
+ #index_1 = [index for index in range(len(x)) if x[index] == 1]
377
+
378
+ class_0 = features_umap#[index_0]
379
+ class_1 = features_umap#[index_1]
380
+ print("Generating latent plots : Done")
381
+
382
+ return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
383
+
384
+ elif downstream_model == "Kernel Ridge":
385
+ regressor = KernelRidge(**params)
386
+ model = TransformedTargetRegressor(regressor=regressor,
387
+ transformer=MinMaxScaler(feature_range=(-1, 1))
388
+ ).fit(x_batch, y_batch)
389
+
390
+ y_prob = model.predict(x_batch_test)
391
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
392
+
393
+ print(f"RMSE Score: {RMSE_score:.4f}")
394
+ result = f"RMSE Score: {RMSE_score:.4f}"
395
+
396
+ print("Generating latent plots")
397
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
398
+ verbose=False)
399
+ n_samples = np.minimum(1000, len(x_batch))
400
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
401
+ x = y_batch.values[:n_samples]
402
+ # index_0 = [index for index in range(len(x)) if x[index] == 0]
403
+ # index_1 = [index for index in range(len(x)) if x[index] == 1]
404
+
405
+ class_0 = features_umap#[index_0]
406
+ class_1 = features_umap#[index_1]
407
+ print("Generating latent plots : Done")
408
+
409
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
410
+
411
+
412
+ elif downstream_model == "Linear Regression":
413
+ regressor = LinearRegression(**params)
414
+ model = TransformedTargetRegressor(regressor=regressor,
415
+ transformer=MinMaxScaler(feature_range=(-1, 1))
416
+ ).fit(x_batch, y_batch)
417
+
418
+ y_prob = model.predict(x_batch_test)
419
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
420
+
421
+ print(f"RMSE Score: {RMSE_score:.4f}")
422
+ result = f"RMSE Score: {RMSE_score:.4f}"
423
+
424
+ print("Generating latent plots")
425
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
426
+ verbose=False)
427
+ n_samples = np.minimum(1000, len(x_batch))
428
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
429
+ x = y_batch.values[:n_samples]
430
+ # index_0 = [index for index in range(len(x)) if x[index] == 0]
431
+ # index_1 = [index for index in range(len(x)) if x[index] == 1]
432
+
433
+ class_0 = features_umap#[index_0]
434
+ class_1 = features_umap#[index_1]
435
+ print("Generating latent plots : Done")
436
+
437
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
438
+
439
+
440
+ elif downstream_model == "DefaultRegressor":
441
+ regressor = SVR(kernel="rbf", degree=3, C=5, gamma="scale", epsilon=0.01)
442
+ model = TransformedTargetRegressor(regressor=regressor,
443
+ transformer=MinMaxScaler(feature_range=(-1, 1))
444
+ ).fit(x_batch, y_batch)
445
+
446
+ y_prob = model.predict(x_batch_test)
447
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
448
+
449
+ print(f"RMSE Score: {RMSE_score:.4f}")
450
+ result = f"RMSE Score: {RMSE_score:.4f}"
451
+
452
+ print("Generating latent plots")
453
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
454
+ verbose=False)
455
+ n_samples = np.minimum(1000, len(x_batch))
456
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
457
+ x = y_batch.values[:n_samples]
458
+ # index_0 = [index for index in range(len(x)) if x[index] == 0]
459
+ # index_1 = [index for index in range(len(x)) if x[index] == 1]
460
+
461
+ class_0 = features_umap#[index_0]
462
+ class_1 = features_umap#[index_1]
463
+ print("Generating latent plots : Done")
464
+
465
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
466
+
467
+
468
+ def multi_modal(model_list,dataset, downstream_model,params):
469
+ print(model_list)
470
+ data = avail_datasets()
471
+ df = pd.DataFrame(data)
472
+ list(df["Dataset"].values)
473
+
474
+ if dataset in list(df["Dataset"].values):
475
+ task = dataset
476
+ predefined = True
477
+ else:
478
+ predefined = False
479
+ components = dataset.split(",")
480
+ train_data = pd.read_csv(components[0])[components[2]]
481
+ test_data = pd.read_csv(components[1])[components[2]]
482
+
483
+ y_batch = pd.read_csv(components[0])[components[3]]
484
+ y_batch_test = pd.read_csv(components[1])[components[3]]
485
+
486
+ print("Custom Dataset loaded")
487
+
488
+
489
+ data = avail_models(raw=True)
490
+ df = pd.DataFrame(data)
491
+ list(df["Name"].values)
492
+
493
+ alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "SMI-TED":"smi-ted"}
494
+ #if set(model_list).issubset(list(df["Name"].values)):
495
+ if set(model_list).issubset(list(alias.keys())):
496
+ for i, model in enumerate(model_list):
497
+ if model in alias.keys():
498
+ model_type = alias[model]
499
+ else:
500
+ model_type = model
501
+
502
+ if i == 0:
503
+ if predefined:
504
+ with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
505
+ x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
506
+ print(f" Loaded representation/{task}_{model_type}.pkl")
507
+ else:
508
+ x_batch, x_batch_test = get_representation(train_data, test_data, model_type)
509
+ x_batch = pd.DataFrame(x_batch)
510
+ x_batch_test = pd.DataFrame(x_batch_test)
511
+
512
+ else:
513
+ if predefined:
514
+ with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
515
+ x_batch_1, y_batch_1, x_batch_test_1, y_batch_test_1 = pickle.load(f1)
516
+ print(f" Loaded representation/{task}_{model_type}.pkl")
517
+ else:
518
+ x_batch_1, x_batch_test_1 = get_representation(train_data, test_data, model_type)
519
+ x_batch_1 = pd.DataFrame(x_batch_1)
520
+ x_batch_test_1 = pd.DataFrame(x_batch_test_1)
521
+
522
+ x_batch = pd.concat([x_batch, x_batch_1], axis=1)
523
+ x_batch_test = pd.concat([x_batch_test, x_batch_test_1], axis=1)
524
+
525
+
526
+ else:
527
+ print("Model not available")
528
+ return
529
+
530
+ num_columns = x_batch_test.shape[1]
531
+ x_batch_test.columns = [f'{i + 1}' for i in range(num_columns)]
532
+
533
+ num_columns = x_batch.shape[1]
534
+ x_batch.columns = [f'{i + 1}' for i in range(num_columns)]
535
+
536
+
537
+ print(f"Representations loaded successfully")
538
+ try:
539
+ with open(f"./plot_emb/{task}_multi.pkl", "rb") as f1:
540
+ class_0, class_1 = pickle.load(f1)
541
+ except:
542
+ print("Generating latent plots")
543
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
544
+ verbose=False)
545
+ n_samples = np.minimum(1000, len(x_batch))
546
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
547
+
548
+ if "Classifier" in downstream_model:
549
+ x = y_batch.values[:n_samples]
550
+ index_0 = [index for index in range(len(x)) if x[index] == 0]
551
+ index_1 = [index for index in range(len(x)) if x[index] == 1]
552
+
553
+ class_0 = features_umap[index_0]
554
+ class_1 = features_umap[index_1]
555
+
556
+ else:
557
+ class_0 = features_umap
558
+ class_1 = features_umap
559
+
560
+ print("Generating latent plots : Done")
561
+
562
+ print(f" Calculating ROC AUC Score ...")
563
+
564
+
565
+ if downstream_model == "XGBClassifier":
566
+ xgb_predict_concat = XGBClassifier(**params)#n_estimators=5000, learning_rate=0.01, max_depth=10)
567
+ xgb_predict_concat.fit(x_batch, y_batch)
568
+
569
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
570
+
571
+
572
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
573
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
574
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
575
+
576
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
577
+
578
+ #vizualize(x_batch_test, y_batch_test)
579
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
580
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
581
+
582
+ return result, roc_auc,fpr, tpr, class_0, class_1
583
+
584
+ elif downstream_model == "DefaultClassifier":
585
+ xgb_predict_concat = XGBClassifier()#n_estimators=5000, learning_rate=0.01, max_depth=10)
586
+ xgb_predict_concat.fit(x_batch, y_batch)
587
+
588
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
589
+
590
+
591
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
592
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
593
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
594
+
595
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
596
+
597
+ #vizualize(x_batch_test, y_batch_test)
598
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
599
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
600
+
601
+ return result, roc_auc,fpr, tpr, class_0, class_1
602
+
603
+ elif downstream_model == "SVR":
604
+ regressor = SVR(**params)
605
+ model = TransformedTargetRegressor(regressor= regressor,
606
+ transformer = MinMaxScaler(feature_range=(-1, 1))
607
+ ).fit(x_batch,y_batch)
608
+
609
+ y_prob = model.predict(x_batch_test)
610
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
611
+
612
+ print(f"RMSE Score: {RMSE_score:.4f}")
613
+ result = f"RMSE Score: {RMSE_score:.4f}"
614
+
615
+ return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
616
+
617
+ elif downstream_model == "Linear Regression":
618
+ regressor = LinearRegression(**params)
619
+ model = TransformedTargetRegressor(regressor=regressor,
620
+ transformer=MinMaxScaler(feature_range=(-1, 1))
621
+ ).fit(x_batch, y_batch)
622
+
623
+ y_prob = model.predict(x_batch_test)
624
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
625
+
626
+ print(f"RMSE Score: {RMSE_score:.4f}")
627
+ result = f"RMSE Score: {RMSE_score:.4f}"
628
+
629
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
630
+
631
+ elif downstream_model == "Kernel Ridge":
632
+ regressor = KernelRidge(**params)
633
+ model = TransformedTargetRegressor(regressor=regressor,
634
+ transformer=MinMaxScaler(feature_range=(-1, 1))
635
+ ).fit(x_batch, y_batch)
636
+
637
+ y_prob = model.predict(x_batch_test)
638
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
639
+
640
+ print(f"RMSE Score: {RMSE_score:.4f}")
641
+ result = f"RMSE Score: {RMSE_score:.4f}"
642
+
643
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
644
+
645
+ elif downstream_model == "DefaultRegressor":
646
+ regressor = SVR(kernel="rbf", degree=3, C=5, gamma="scale", epsilon=0.01)
647
+ model = TransformedTargetRegressor(regressor=regressor,
648
+ transformer=MinMaxScaler(feature_range=(-1, 1))
649
+ ).fit(x_batch, y_batch)
650
+
651
+ y_prob = model.predict(x_batch_test)
652
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
653
+
654
+ print(f"RMSE Score: {RMSE_score:.4f}")
655
+ result = f"RMSE Score: {RMSE_score:.4f}"
656
+
657
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
658
+
659
+
660
+
661
+
662
+
663
+
models/mhg_model/.DS_Store ADDED
Binary file (8.2 kB). View file
 
models/mhg_model/README.md ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mhg-gnn
2
+
3
+ This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
4
+
5
+ **Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
6
+
7
+ ![mhg-gnn](images/mhg_example1.png)
8
+
9
+ ## Introduction
10
+
11
+ We present MHG-GNN, an autoencoder architecture
12
+ that has an encoder based on GNN and a decoder based on a sequential model with MHG.
13
+ Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
14
+ demonstrate high predictive performance on molecular graph data.
15
+ In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
16
+
17
+ ## Table of Contents
18
+
19
+ 1. [Getting Started](#getting-started)
20
+ 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
21
+ 2. [Installation](#installation)
22
+ 2. [Feature Extraction](#feature-extraction)
23
+
24
+ ## Getting Started
25
+
26
+ **This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
27
+
28
+ ### Pretrained Models and Training Logs
29
+
30
+ We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
31
+
32
+ Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
33
+
34
+ ### Installation
35
+
36
+ We recommend to create a virtual environment. For example:
37
+
38
+ ```
39
+ python3 -m venv .venv
40
+ . .venv/bin/activate
41
+ ```
42
+
43
+ Type the following command once the virtual environment is activated:
44
+
45
+ ```
46
+ git clone [email protected]:CMD-TRL/mhg-gnn.git
47
+ cd ./mhg-gnn
48
+ pip install .
49
+ ```
50
+
51
+ ## Feature Extraction
52
+
53
+ The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
54
+
55
+ To load mhg-gnn, you can simply use:
56
+
57
+ ```python
58
+ import torch
59
+ import load
60
+
61
+ model = load.load()
62
+ ```
63
+
64
+ To encode SMILES into embeddings, you can use:
65
+
66
+ ```python
67
+ with torch.no_grad():
68
+ repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
69
+ ```
70
+
71
+ For decoder, you can use the function, so you can return from embeddings to SMILES strings:
72
+
73
+ ```python
74
+ orig = model.decode(repr)
75
+ ```
models/mhg_model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
models/mhg_model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (224 Bytes). View file
 
models/mhg_model/__pycache__/load.cpython-310.pyc ADDED
Binary file (3.16 kB). View file
 
models/mhg_model/graph_grammar/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+ """
8
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
9
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
10
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
11
+ """
12
+
13
+ """ Title """
14
+
15
+ __author__ = "Hiroshi Kajino <[email protected]>"
16
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
17
+ __version__ = "0.1"
18
+ __date__ = "Jan 1 2018"
19
+
models/mhg_model/graph_grammar/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (676 Bytes). View file
 
models/mhg_model/graph_grammar/__pycache__/hypergraph.cpython-310.pyc ADDED
Binary file (15.3 kB). View file
 
models/mhg_model/graph_grammar/algo/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding:utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
models/mhg_model/graph_grammar/algo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (681 Bytes). View file
 
models/mhg_model/graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc ADDED
Binary file (19.5 kB). View file
 
models/mhg_model/graph_grammar/algo/tree_decomposition.py ADDED
@@ -0,0 +1,821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from copy import deepcopy
22
+ from itertools import combinations
23
+ from ..hypergraph import Hypergraph
24
+ import networkx as nx
25
+ import numpy as np
26
+
27
+
28
+ class CliqueTree(nx.Graph):
29
+ ''' clique tree object
30
+
31
+ Attributes
32
+ ----------
33
+ hg : Hypergraph
34
+ This hypergraph will be decomposed.
35
+ root_hg : Hypergraph
36
+ Hypergraph on the root node.
37
+ ident_node_dict : dict
38
+ ident_node_dict[key_node] gives a list of nodes that are identical (i.e., the adjacent hyperedges are common)
39
+ '''
40
+ def __init__(self, hg=None, **kwargs):
41
+ self.hg = deepcopy(hg)
42
+ if self.hg is not None:
43
+ self.ident_node_dict = self.hg.get_identical_node_dict()
44
+ else:
45
+ self.ident_node_dict = {}
46
+ super().__init__(**kwargs)
47
+
48
+ @property
49
+ def root_hg(self):
50
+ ''' return the hypergraph on the root node
51
+ '''
52
+ return self.nodes[0]['subhg']
53
+
54
+ @root_hg.setter
55
+ def root_hg(self, hypergraph):
56
+ ''' set the hypergraph on the root node
57
+ '''
58
+ self.nodes[0]['subhg'] = hypergraph
59
+
60
+ def insert_subhg(self, subhypergraph: Hypergraph) -> None:
61
+ ''' insert a subhypergraph, which is extracted from a root hypergraph, into the tree.
62
+
63
+ Parameters
64
+ ----------
65
+ subhg : Hypergraph
66
+ '''
67
+ num_nodes = self.number_of_nodes()
68
+ self.add_node(num_nodes, subhg=subhypergraph)
69
+ self.add_edge(num_nodes, 0)
70
+ adj_nodes = deepcopy(list(self.adj[0].keys()))
71
+ for each_node in adj_nodes:
72
+ if len(self.nodes[each_node]["subhg"].nodes.intersection(
73
+ self.nodes[num_nodes]["subhg"].nodes)\
74
+ - self.root_hg.nodes) != 0 and each_node != num_nodes:
75
+ self.remove_edge(0, each_node)
76
+ self.add_edge(each_node, num_nodes)
77
+
78
+ def to_irredundant(self) -> None:
79
+ ''' convert the clique tree to be irredundant
80
+ '''
81
+ for each_node in self.hg.nodes:
82
+ subtree = self.subgraph([
83
+ each_tree_node for each_tree_node in self.nodes()\
84
+ if each_node in self.nodes[each_tree_node]["subhg"].nodes]).copy()
85
+ leaf_node_list = [x for x in subtree.nodes() if subtree.degree(x)==1]
86
+ redundant_leaf_node_list = []
87
+ for each_leaf_node in leaf_node_list:
88
+ if len(self.nodes[each_leaf_node]["subhg"].adj_edges(each_node)) == 0:
89
+ redundant_leaf_node_list.append(each_leaf_node)
90
+ for each_red_leaf_node in redundant_leaf_node_list:
91
+ current_node = each_red_leaf_node
92
+ while subtree.degree(current_node) == 1 \
93
+ and len(subtree.nodes[current_node]["subhg"].adj_edges(each_node)) == 0:
94
+ self.nodes[current_node]["subhg"].remove_node(each_node)
95
+ remove_node = current_node
96
+ current_node = list(dict(subtree[remove_node]).keys())[0]
97
+ subtree.remove_node(remove_node)
98
+
99
+ fixed_node_set = deepcopy(self.nodes)
100
+ for each_node in fixed_node_set:
101
+ if self.nodes[each_node]["subhg"].num_edges == 0:
102
+ if len(self[each_node]) == 1:
103
+ self.remove_node(each_node)
104
+ elif len(self[each_node]) == 2:
105
+ self.add_edge(*self[each_node])
106
+ self.remove_node(each_node)
107
+ else:
108
+ pass
109
+ else:
110
+ pass
111
+
112
+ redundant = True
113
+ while redundant:
114
+ redundant = False
115
+ fixed_edge_set = deepcopy(self.edges)
116
+ remove_node_set = set()
117
+ for node_1, node_2 in fixed_edge_set:
118
+ if node_1 in remove_node_set or node_2 in remove_node_set:
119
+ pass
120
+ else:
121
+ if self.nodes[node_1]['subhg'].is_subhg(self.nodes[node_2]['subhg']):
122
+ redundant = True
123
+ adj_node_list = set(self.adj[node_1]) - {node_2}
124
+ self.remove_node(node_1)
125
+ remove_node_set.add(node_1)
126
+ for each_node in adj_node_list:
127
+ self.add_edge(node_2, each_node)
128
+
129
+ elif self.nodes[node_2]['subhg'].is_subhg(self.nodes[node_1]['subhg']):
130
+ redundant = True
131
+ adj_node_list = set(self.adj[node_2]) - {node_1}
132
+ self.remove_node(node_2)
133
+ remove_node_set.add(node_2)
134
+ for each_node in adj_node_list:
135
+ self.add_edge(node_1, each_node)
136
+
137
+ def node_update(self, key_node: str, subhg) -> None:
138
+ """ given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
139
+
140
+ Parameters
141
+ ----------
142
+ key_node : str
143
+ key node that must be removed.
144
+ subhg : Hypegraph
145
+ """
146
+ for each_edge in subhg.edges:
147
+ self.root_hg.remove_edge(each_edge)
148
+ self.root_hg.remove_nodes(self.ident_node_dict[key_node])
149
+
150
+ adj_node_list = list(subhg.nodes)
151
+ for each_node in subhg.nodes:
152
+ if each_node not in self.ident_node_dict[key_node]:
153
+ if set(self.root_hg.adj_edges(each_node)).issubset(subhg.edges):
154
+ self.root_hg.remove_node(each_node)
155
+ adj_node_list.remove(each_node)
156
+ else:
157
+ adj_node_list.remove(each_node)
158
+
159
+ for each_node_1, each_node_2 in combinations(adj_node_list, 2):
160
+ if not self.root_hg.is_adj(each_node_1, each_node_2):
161
+ self.root_hg.add_edge(set([each_node_1, each_node_2]), attr_dict=dict(tmp=True))
162
+
163
+ subhg.remove_edges_with_attr({'tmp' : True})
164
+ self.insert_subhg(subhg)
165
+
166
+ def update(self, subhg, remove_nodes=False):
167
+ """ given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
168
+
169
+ Parameters
170
+ ----------
171
+ subhg : Hypegraph
172
+ """
173
+ for each_edge in subhg.edges:
174
+ self.root_hg.remove_edge(each_edge)
175
+ if remove_nodes:
176
+ remove_edge_list = []
177
+ for each_edge in self.root_hg.edges:
178
+ if set(self.root_hg.nodes_in_edge(each_edge)).issubset(subhg.nodes)\
179
+ and self.root_hg.edge_attr(each_edge).get('tmp', False):
180
+ remove_edge_list.append(each_edge)
181
+ self.root_hg.remove_edges(remove_edge_list)
182
+
183
+ adj_node_list = list(subhg.nodes)
184
+ for each_node in subhg.nodes:
185
+ if self.root_hg.degree(each_node) == 0:
186
+ self.root_hg.remove_node(each_node)
187
+ adj_node_list.remove(each_node)
188
+
189
+ if len(adj_node_list) != 1 and not remove_nodes:
190
+ self.root_hg.add_edge(set(adj_node_list), attr_dict=dict(tmp=True))
191
+ '''
192
+ else:
193
+ for each_node_1, each_node_2 in combinations(adj_node_list, 2):
194
+ if not self.root_hg.is_adj(each_node_1, each_node_2):
195
+ self.root_hg.add_edge(
196
+ [each_node_1, each_node_2], attr_dict=dict(tmp=True))
197
+ '''
198
+ subhg.remove_edges_with_attr({'tmp':True})
199
+ self.insert_subhg(subhg)
200
+
201
+
202
+ def _get_min_deg_node(hg, ident_node_dict: dict, mode='mol'):
203
+ if mode == 'standard':
204
+ degree_dict = hg.degrees()
205
+ min_deg_node = min(degree_dict, key=degree_dict.get)
206
+ min_deg_subhg = hg.adj_subhg(min_deg_node, ident_node_dict)
207
+ return min_deg_node, min_deg_subhg
208
+ elif mode == 'mol':
209
+ degree_dict = hg.degrees()
210
+ min_deg = min(degree_dict.values())
211
+ min_deg_node_list = [each_node for each_node in hg.nodes if degree_dict[each_node]==min_deg]
212
+ min_deg_subhg_list = [hg.adj_subhg(each_min_deg_node, ident_node_dict)
213
+ for each_min_deg_node in min_deg_node_list]
214
+ best_score = np.inf
215
+ best_idx = -1
216
+ for each_idx in range(len(min_deg_subhg_list)):
217
+ if min_deg_subhg_list[each_idx].num_nodes < best_score:
218
+ best_idx = each_idx
219
+ return min_deg_node_list[each_idx], min_deg_subhg_list[each_idx]
220
+ else:
221
+ raise ValueError
222
+
223
+
224
+ def tree_decomposition(hg, irredundant=True):
225
+ """ compute a tree decomposition of the input hypergraph
226
+
227
+ Parameters
228
+ ----------
229
+ hg : Hypergraph
230
+ hypergraph to be decomposed
231
+ irredundant : bool
232
+ if True, irredundant tree decomposition will be computed.
233
+
234
+ Returns
235
+ -------
236
+ clique_tree : nx.Graph
237
+ each node contains a subhypergraph of `hg`
238
+ """
239
+ org_hg = hg.copy()
240
+ ident_node_dict = hg.get_identical_node_dict()
241
+ clique_tree = CliqueTree(org_hg)
242
+ clique_tree.add_node(0, subhg=org_hg)
243
+ while True:
244
+ degree_dict = org_hg.degrees()
245
+ min_deg_node = min(degree_dict, key=degree_dict.get)
246
+ min_deg_subhg = org_hg.adj_subhg(min_deg_node, ident_node_dict)
247
+ if org_hg.nodes == min_deg_subhg.nodes:
248
+ break
249
+
250
+ # org_hg and min_deg_subhg are divided
251
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
252
+
253
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
254
+
255
+ if irredundant:
256
+ clique_tree.to_irredundant()
257
+ return clique_tree
258
+
259
+
260
+ def tree_decomposition_with_hrg(hg, hrg, irredundant=True, return_root=False):
261
+ ''' compute a tree decomposition given a hyperedge replacement grammar.
262
+ the resultant clique tree should induce a less compact HRG.
263
+
264
+ Parameters
265
+ ----------
266
+ hg : Hypergraph
267
+ hypergraph to be decomposed
268
+ hrg : HyperedgeReplacementGrammar
269
+ current HRG
270
+ irredundant : bool
271
+ if True, irredundant tree decomposition will be computed.
272
+
273
+ Returns
274
+ -------
275
+ clique_tree : nx.Graph
276
+ each node contains a subhypergraph of `hg`
277
+ '''
278
+ org_hg = hg.copy()
279
+ ident_node_dict = hg.get_identical_node_dict()
280
+ clique_tree = CliqueTree(org_hg)
281
+ clique_tree.add_node(0, subhg=org_hg)
282
+ root_node = 0
283
+
284
+ # construct a clique tree using HRG
285
+ success_any = True
286
+ while success_any:
287
+ success_any = False
288
+ for each_prod_rule in hrg.prod_rule_list:
289
+ org_hg, success, subhg = each_prod_rule.revert(org_hg, True)
290
+ if success:
291
+ if each_prod_rule.is_start_rule: root_node = clique_tree.number_of_nodes()
292
+ success_any = True
293
+ subhg.remove_edges_with_attr({'terminal' : False})
294
+ clique_tree.root_hg = org_hg
295
+ clique_tree.insert_subhg(subhg)
296
+
297
+ clique_tree.root_hg = org_hg
298
+
299
+ for each_edge in deepcopy(org_hg.edges):
300
+ if not org_hg.edge_attr(each_edge)['terminal']:
301
+ node_list = org_hg.nodes_in_edge(each_edge)
302
+ org_hg.remove_edge(each_edge)
303
+
304
+ for each_node_1, each_node_2 in combinations(node_list, 2):
305
+ if not org_hg.is_adj(each_node_1, each_node_2):
306
+ org_hg.add_edge([each_node_1, each_node_2], attr_dict=dict(tmp=True))
307
+
308
+ # construct a clique tree using the existing algorithm
309
+ degree_dict = org_hg.degrees()
310
+ if degree_dict:
311
+ while True:
312
+ min_deg_node, min_deg_subhg = _get_min_deg_node(org_hg, ident_node_dict)
313
+ if org_hg.nodes == min_deg_subhg.nodes: break
314
+
315
+ # org_hg and min_deg_subhg are divided
316
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
317
+
318
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
319
+ if irredundant:
320
+ clique_tree.to_irredundant()
321
+
322
+ if return_root:
323
+ if root_node == 0 and 0 not in clique_tree.nodes:
324
+ root_node = clique_tree.number_of_nodes()
325
+ while root_node not in clique_tree.nodes:
326
+ root_node -= 1
327
+ elif root_node not in clique_tree.nodes:
328
+ while root_node not in clique_tree.nodes:
329
+ root_node -= 1
330
+ else:
331
+ pass
332
+ return clique_tree, root_node
333
+ else:
334
+ return clique_tree
335
+
336
+
337
+ def tree_decomposition_from_leaf(hg, irredundant=True):
338
+ """ compute a tree decomposition of the input hypergraph
339
+
340
+ Parameters
341
+ ----------
342
+ hg : Hypergraph
343
+ hypergraph to be decomposed
344
+ irredundant : bool
345
+ if True, irredundant tree decomposition will be computed.
346
+
347
+ Returns
348
+ -------
349
+ clique_tree : nx.Graph
350
+ each node contains a subhypergraph of `hg`
351
+ """
352
+ def apply_normal_decomposition(clique_tree):
353
+ degree_dict = clique_tree.root_hg.degrees()
354
+ min_deg_node = min(degree_dict, key=degree_dict.get)
355
+ min_deg_subhg = clique_tree.root_hg.adj_subhg(min_deg_node, clique_tree.ident_node_dict)
356
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
357
+ return clique_tree, False
358
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
359
+ return clique_tree, True
360
+
361
+ def apply_min_edge_deg_decomposition(clique_tree):
362
+ edge_degree_dict = clique_tree.root_hg.edge_degrees()
363
+ non_tmp_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
364
+ if not clique_tree.root_hg.edge_attr(each_edge).get('tmp')]
365
+ if not non_tmp_edge_list:
366
+ return clique_tree, False
367
+ min_deg_edge = None
368
+ min_deg = np.inf
369
+ for each_edge in non_tmp_edge_list:
370
+ if min_deg > edge_degree_dict[each_edge]:
371
+ min_deg_edge = each_edge
372
+ min_deg = edge_degree_dict[each_edge]
373
+ node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
374
+ min_deg_subhg = clique_tree.root_hg.get_subhg(
375
+ node_list, [min_deg_edge], clique_tree.ident_node_dict)
376
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
377
+ return clique_tree, False
378
+ clique_tree.update(min_deg_subhg)
379
+ return clique_tree, True
380
+
381
+ org_hg = hg.copy()
382
+ clique_tree = CliqueTree(org_hg)
383
+ clique_tree.add_node(0, subhg=org_hg)
384
+
385
+ success = True
386
+ while success:
387
+ clique_tree, success = apply_min_edge_deg_decomposition(clique_tree)
388
+ if not success:
389
+ clique_tree, success = apply_normal_decomposition(clique_tree)
390
+
391
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
392
+ if irredundant:
393
+ clique_tree.to_irredundant()
394
+ return clique_tree
395
+
396
+ def topological_tree_decomposition(
397
+ hg, irredundant=True, rip_labels=True, shrink_cycle=False, contract_cycles=False):
398
+ ''' compute a tree decomposition of the input hypergraph
399
+
400
+ Parameters
401
+ ----------
402
+ hg : Hypergraph
403
+ hypergraph to be decomposed
404
+ irredundant : bool
405
+ if True, irredundant tree decomposition will be computed.
406
+
407
+ Returns
408
+ -------
409
+ clique_tree : CliqueTree
410
+ each node contains a subhypergraph of `hg`
411
+ '''
412
+ def _contract_tree(clique_tree):
413
+ ''' contract a single leaf
414
+
415
+ Parameters
416
+ ----------
417
+ clique_tree : CliqueTree
418
+
419
+ Returns
420
+ -------
421
+ CliqueTree, bool
422
+ bool represents whether this operation succeeds or not.
423
+ '''
424
+ edge_degree_dict = clique_tree.root_hg.edge_degrees()
425
+ leaf_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
426
+ if (not clique_tree.root_hg.edge_attr(each_edge).get('tmp'))\
427
+ and edge_degree_dict[each_edge] == 1]
428
+ if not leaf_edge_list:
429
+ return clique_tree, False
430
+ min_deg_edge = leaf_edge_list[0]
431
+ node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
432
+ min_deg_subhg = clique_tree.root_hg.get_subhg(
433
+ node_list, [min_deg_edge], clique_tree.ident_node_dict)
434
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
435
+ return clique_tree, False
436
+ clique_tree.update(min_deg_subhg)
437
+ return clique_tree, True
438
+
439
+ def _rip_labels_from_cycles(clique_tree, org_hg):
440
+ ''' rip hyperedge-labels off
441
+
442
+ Parameters
443
+ ----------
444
+ clique_tree : CliqueTree
445
+ org_hg : Hypergraph
446
+
447
+ Returns
448
+ -------
449
+ CliqueTree, bool
450
+ bool represents whether this operation succeeds or not.
451
+ '''
452
+ ident_node_dict = clique_tree.ident_node_dict #hg.get_identical_node_dict()
453
+ for each_edge in clique_tree.root_hg.edges:
454
+ if each_edge in org_hg.edges:
455
+ if org_hg.in_cycle(each_edge):
456
+ node_list = clique_tree.root_hg.nodes_in_edge(each_edge)
457
+ subhg = clique_tree.root_hg.get_subhg(
458
+ node_list, [each_edge], ident_node_dict)
459
+ if clique_tree.root_hg.nodes == subhg.nodes:
460
+ return clique_tree, False
461
+ clique_tree.update(subhg)
462
+ '''
463
+ in_cycle_dict = {each_node: org_hg.node_attr(each_node)['is_in_ring'] for each_node in node_list}
464
+ if not all(in_cycle_dict.values()):
465
+ node_not_in_cycle = [each_node for each_node in in_cycle_dict.keys() if not in_cycle_dict[each_node]][0]
466
+ node_list = [node_not_in_cycle]
467
+ node_list.extend(clique_tree.root_hg.adj_nodes(node_not_in_cycle))
468
+ edge_list = clique_tree.root_hg.adj_edges(node_not_in_cycle)
469
+ import pdb; pdb.set_trace()
470
+ subhg = clique_tree.root_hg.get_subhg(
471
+ node_list, edge_list, ident_node_dict)
472
+
473
+ clique_tree.update(subhg)
474
+ '''
475
+ return clique_tree, True
476
+ return clique_tree, False
477
+
478
+ def _shrink_cycle(clique_tree):
479
+ ''' shrink a cycle
480
+
481
+ Parameters
482
+ ----------
483
+ clique_tree : CliqueTree
484
+
485
+ Returns
486
+ -------
487
+ CliqueTree, bool
488
+ bool represents whether this operation succeeds or not.
489
+ '''
490
+ def filter_subhg(subhg, hg, key_node):
491
+ num_nodes_cycle = 0
492
+ nodes_in_cycle_list = []
493
+ for each_node in subhg.nodes:
494
+ if hg.in_cycle(each_node):
495
+ num_nodes_cycle += 1
496
+ if each_node != key_node:
497
+ nodes_in_cycle_list.append(each_node)
498
+ if num_nodes_cycle > 3:
499
+ break
500
+ if num_nodes_cycle != 3:
501
+ return False
502
+ else:
503
+ for each_edge in hg.edges:
504
+ if set(nodes_in_cycle_list).issubset(hg.nodes_in_edge(each_edge)):
505
+ return False
506
+ return True
507
+
508
+ #ident_node_dict = hg.get_identical_node_dict()
509
+ ident_node_dict = clique_tree.ident_node_dict
510
+ for each_node in clique_tree.root_hg.nodes:
511
+ if clique_tree.root_hg.in_cycle(each_node)\
512
+ and filter_subhg(clique_tree.root_hg.adj_subhg(each_node, ident_node_dict),
513
+ clique_tree.root_hg,
514
+ each_node):
515
+ target_node = each_node
516
+ target_subhg = clique_tree.root_hg.adj_subhg(target_node, ident_node_dict)
517
+ if clique_tree.root_hg.nodes == target_subhg.nodes:
518
+ return clique_tree, False
519
+ clique_tree.update(target_subhg)
520
+ return clique_tree, True
521
+ return clique_tree, False
522
+
523
+ def _contract_cycles(clique_tree):
524
+ '''
525
+ remove a subhypergraph that looks like a cycle on a leaf.
526
+
527
+ Parameters
528
+ ----------
529
+ clique_tree : CliqueTree
530
+
531
+ Returns
532
+ -------
533
+ CliqueTree, bool
534
+ bool represents whether this operation succeeds or not.
535
+ '''
536
+ def _divide_hg(hg):
537
+ ''' divide a hypergraph into subhypergraphs such that
538
+ each subhypergraph is connected to each other in a tree-like way.
539
+
540
+ Parameters
541
+ ----------
542
+ hg : Hypergraph
543
+
544
+ Returns
545
+ -------
546
+ list of Hypergraphs
547
+ each element corresponds to a subhypergraph of `hg`
548
+ '''
549
+ for each_node in hg.nodes:
550
+ if hg.is_dividable(each_node):
551
+ adj_edges_dict = {each_edge: hg.in_cycle(each_edge) for each_edge in hg.adj_edges(each_node)}
552
+ '''
553
+ if any(adj_edges_dict.values()):
554
+ import pdb; pdb.set_trace()
555
+ edge_in_cycle = [each_key for each_key, each_val in adj_edges_dict.items() if each_val][0]
556
+ subhg1, subhg2, subhg3 = hg.divide(each_node, edge_in_cycle)
557
+ return _divide_hg(subhg1) + _divide_hg(subhg2) + _divide_hg(subhg3)
558
+ else:
559
+ '''
560
+ subhg1, subhg2 = hg.divide(each_node)
561
+ return _divide_hg(subhg1) + _divide_hg(subhg2)
562
+ return [hg]
563
+
564
+ def _is_leaf(hg, divided_subhg) -> bool:
565
+ ''' judge whether subhg is a leaf-like in the original hypergraph
566
+
567
+ Parameters
568
+ ----------
569
+ hg : Hypergraph
570
+ divided_subhg : Hypergraph
571
+ `divided_subhg` is a subhypergraph of `hg`
572
+
573
+ Returns
574
+ -------
575
+ bool
576
+ '''
577
+ '''
578
+ adj_edges_set = set([])
579
+ for each_node in divided_subhg.nodes:
580
+ adj_edges_set.update(set(hg.adj_edges(each_node)))
581
+
582
+
583
+ _hg = deepcopy(hg)
584
+ _hg.remove_subhg(divided_subhg)
585
+ if nx.is_connected(_hg.hg) != (len(adj_edges_set - divided_subhg.edges) == 1):
586
+ import pdb; pdb.set_trace()
587
+ return len(adj_edges_set - divided_subhg.edges) == 1
588
+ '''
589
+ _hg = deepcopy(hg)
590
+ _hg.remove_subhg(divided_subhg)
591
+ return nx.is_connected(_hg.hg)
592
+
593
+ subhg_list = _divide_hg(clique_tree.root_hg)
594
+ if len(subhg_list) == 1:
595
+ return clique_tree, False
596
+ else:
597
+ while len(subhg_list) > 1:
598
+ max_leaf_subhg = None
599
+ for each_subhg in subhg_list:
600
+ if _is_leaf(clique_tree.root_hg, each_subhg):
601
+ if max_leaf_subhg is None:
602
+ max_leaf_subhg = each_subhg
603
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
604
+ max_leaf_subhg = each_subhg
605
+ clique_tree.update(max_leaf_subhg)
606
+ subhg_list.remove(max_leaf_subhg)
607
+ return clique_tree, True
608
+
609
+ org_hg = hg.copy()
610
+ clique_tree = CliqueTree(org_hg)
611
+ clique_tree.add_node(0, subhg=org_hg)
612
+
613
+ success = True
614
+ while success:
615
+ '''
616
+ clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
617
+ if not success:
618
+ clique_tree, success = _contract_cycles(clique_tree)
619
+ '''
620
+ clique_tree, success = _contract_tree(clique_tree)
621
+ if not success:
622
+ if rip_labels:
623
+ clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
624
+ if not success:
625
+ if shrink_cycle:
626
+ clique_tree, success = _shrink_cycle(clique_tree)
627
+ if not success:
628
+ if contract_cycles:
629
+ clique_tree, success = _contract_cycles(clique_tree)
630
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
631
+ if irredundant:
632
+ clique_tree.to_irredundant()
633
+ return clique_tree
634
+
635
+ def molecular_tree_decomposition(hg, irredundant=True):
636
+ """ compute a tree decomposition of the input molecular hypergraph
637
+
638
+ Parameters
639
+ ----------
640
+ hg : Hypergraph
641
+ molecular hypergraph to be decomposed
642
+ irredundant : bool
643
+ if True, irredundant tree decomposition will be computed.
644
+
645
+ Returns
646
+ -------
647
+ clique_tree : CliqueTree
648
+ each node contains a subhypergraph of `hg`
649
+ """
650
+ def _divide_hg(hg):
651
+ ''' divide a hypergraph into subhypergraphs such that
652
+ each subhypergraph is connected to each other in a tree-like way.
653
+
654
+ Parameters
655
+ ----------
656
+ hg : Hypergraph
657
+
658
+ Returns
659
+ -------
660
+ list of Hypergraphs
661
+ each element corresponds to a subhypergraph of `hg`
662
+ '''
663
+ is_ring = False
664
+ for each_node in hg.nodes:
665
+ if hg.node_attr(each_node)['is_in_ring']:
666
+ is_ring = True
667
+ if not hg.node_attr(each_node)['is_in_ring'] \
668
+ and hg.degree(each_node) == 2:
669
+ subhg1, subhg2 = hg.divide(each_node)
670
+ return _divide_hg(subhg1) + _divide_hg(subhg2)
671
+
672
+ if is_ring:
673
+ subhg_list = []
674
+ remove_edge_list = []
675
+ remove_node_list = []
676
+ for each_edge in hg.edges:
677
+ node_list = hg.nodes_in_edge(each_edge)
678
+ subhg = hg.get_subhg(node_list, [each_edge], hg.get_identical_node_dict())
679
+ subhg_list.append(subhg)
680
+ remove_edge_list.append(each_edge)
681
+ for each_node in node_list:
682
+ if not hg.node_attr(each_node)['is_in_ring']:
683
+ remove_node_list.append(each_node)
684
+ hg.remove_edges(remove_edge_list)
685
+ hg.remove_nodes(remove_node_list, False)
686
+ return subhg_list + [hg]
687
+ else:
688
+ return [hg]
689
+
690
+ org_hg = hg.copy()
691
+ clique_tree = CliqueTree(org_hg)
692
+ clique_tree.add_node(0, subhg=org_hg)
693
+
694
+ subhg_list = _divide_hg(deepcopy(clique_tree.root_hg))
695
+ #_subhg_list = deepcopy(subhg_list)
696
+ if len(subhg_list) == 1:
697
+ pass
698
+ else:
699
+ while len(subhg_list) > 1:
700
+ max_leaf_subhg = None
701
+ for each_subhg in subhg_list:
702
+ if _is_leaf(clique_tree.root_hg, each_subhg) and not _is_ring(each_subhg):
703
+ if max_leaf_subhg is None:
704
+ max_leaf_subhg = each_subhg
705
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
706
+ max_leaf_subhg = each_subhg
707
+
708
+ if max_leaf_subhg is None:
709
+ for each_subhg in subhg_list:
710
+ if _is_ring_label(clique_tree.root_hg, each_subhg):
711
+ if max_leaf_subhg is None:
712
+ max_leaf_subhg = each_subhg
713
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
714
+ max_leaf_subhg = each_subhg
715
+ if max_leaf_subhg is not None:
716
+ clique_tree.update(max_leaf_subhg)
717
+ subhg_list.remove(max_leaf_subhg)
718
+ else:
719
+ for each_subhg in subhg_list:
720
+ if _is_leaf(clique_tree.root_hg, each_subhg):
721
+ if max_leaf_subhg is None:
722
+ max_leaf_subhg = each_subhg
723
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
724
+ max_leaf_subhg = each_subhg
725
+ if max_leaf_subhg is not None:
726
+ clique_tree.update(max_leaf_subhg, True)
727
+ subhg_list.remove(max_leaf_subhg)
728
+ else:
729
+ break
730
+ if len(subhg_list) > 1:
731
+ '''
732
+ for each_idx, each_subhg in enumerate(subhg_list):
733
+ each_subhg.draw(f'{each_idx}', True)
734
+ clique_tree.root_hg.draw('root', True)
735
+ import pickle
736
+ with open('buggy_hg.pkl', 'wb') as f:
737
+ pickle.dump(hg, f)
738
+ return clique_tree, subhg_list, _subhg_list
739
+ '''
740
+ raise RuntimeError('bug in tree decomposition algorithm')
741
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
742
+
743
+ '''
744
+ for each_tree_node in clique_tree.adj[0]:
745
+ subhg = clique_tree.nodes[each_tree_node]['subhg']
746
+ for each_edge in subhg.edges:
747
+ if set(subhg.nodes_in_edge(each_edge)).issubset(clique_tree.root_hg.nodes):
748
+ clique_tree.root_hg.add_edge(set(subhg.nodes_in_edge(each_edge)), attr_dict=dict(tmp=True))
749
+ '''
750
+ if irredundant:
751
+ clique_tree.to_irredundant()
752
+ return clique_tree #, _subhg_list
753
+
754
+ def _is_leaf(hg, subhg) -> bool:
755
+ ''' judge whether subhg is a leaf-like in the original hypergraph
756
+
757
+ Parameters
758
+ ----------
759
+ hg : Hypergraph
760
+ subhg : Hypergraph
761
+ `subhg` is a subhypergraph of `hg`
762
+
763
+ Returns
764
+ -------
765
+ bool
766
+ '''
767
+ if len(subhg.edges) == 0:
768
+ adj_edge_set = set([])
769
+ subhg_edge_set = set([])
770
+ for each_edge in hg.edges:
771
+ if set(hg.nodes_in_edge(each_edge)).issubset(subhg.nodes) and hg.edge_attr(each_edge).get('tmp', False):
772
+ subhg_edge_set.add(each_edge)
773
+ for each_node in subhg.nodes:
774
+ adj_edge_set.update(set(hg.adj_edges(each_node)))
775
+ if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
776
+ return True
777
+ else:
778
+ return False
779
+ elif len(subhg.edges) == 1:
780
+ adj_edge_set = set([])
781
+ subhg_edge_set = subhg.edges
782
+ for each_node in subhg.nodes:
783
+ for each_adj_edge in hg.adj_edges(each_node):
784
+ adj_edge_set.add(each_adj_edge)
785
+ if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
786
+ return True
787
+ else:
788
+ return False
789
+ else:
790
+ raise ValueError('subhg should be nodes only or one-edge hypergraph.')
791
+
792
+ def _is_ring_label(hg, subhg):
793
+ if len(subhg.edges) != 1:
794
+ return False
795
+ edge_name = list(subhg.edges)[0]
796
+ #assert edge_name in hg.edges, f'{edge_name}'
797
+ is_in_ring = False
798
+ for each_node in subhg.nodes:
799
+ if subhg.node_attr(each_node)['is_in_ring']:
800
+ is_in_ring = True
801
+ else:
802
+ adj_edge_list = list(hg.adj_edges(each_node))
803
+ adj_edge_list.remove(edge_name)
804
+ if len(adj_edge_list) == 1:
805
+ if not hg.edge_attr(adj_edge_list[0]).get('tmp', False):
806
+ return False
807
+ elif len(adj_edge_list) == 0:
808
+ pass
809
+ else:
810
+ raise ValueError
811
+ if is_in_ring:
812
+ return True
813
+ else:
814
+ return False
815
+
816
+ def _is_ring(hg):
817
+ for each_node in hg.nodes:
818
+ if not hg.node_attr(each_node)['is_in_ring']:
819
+ return False
820
+ return True
821
+
models/mhg_model/graph_grammar/graph_grammar/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
models/mhg_model/graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (690 Bytes). View file
 
models/mhg_model/graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc ADDED
Binary file (1.19 kB). View file
 
models/mhg_model/graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc ADDED
Binary file (4.73 kB). View file
 
models/mhg_model/graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc ADDED
Binary file (29.1 kB). View file
 
models/mhg_model/graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc ADDED
Binary file (5.39 kB). View file
 
models/mhg_model/graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.65 kB). View file
 
models/mhg_model/graph_grammar/graph_grammar/base.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from abc import ABCMeta, abstractmethod
22
+
23
+ class GraphGrammarBase(metaclass=ABCMeta):
24
+ @abstractmethod
25
+ def learn(self):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def sample(self):
30
+ pass
models/mhg_model/graph_grammar/graph_grammar/corpus.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jun 4 2018"
20
+
21
+ from collections import Counter
22
+ from functools import partial
23
+ from .utils import _easy_node_match, _edge_match, _node_match, common_node_list, _node_match_prod_rule
24
+ from networkx.algorithms.isomorphism import GraphMatcher
25
+ import os
26
+
27
+
28
+ class CliqueTreeCorpus(object):
29
+
30
+ ''' clique tree corpus
31
+
32
+ Attributes
33
+ ----------
34
+ clique_tree_list : list of CliqueTree
35
+ subhg_list : list of Hypergraph
36
+ '''
37
+
38
+ def __init__(self):
39
+ self.clique_tree_list = []
40
+ self.subhg_list = []
41
+
42
+ @property
43
+ def size(self):
44
+ return len(self.subhg_list)
45
+
46
+ def add_clique_tree(self, clique_tree):
47
+ for each_node in clique_tree.nodes:
48
+ subhg = clique_tree.nodes[each_node]['subhg']
49
+ subhg_idx = self.add_subhg(subhg)
50
+ clique_tree.nodes[each_node]['subhg_idx'] = subhg_idx
51
+ self.clique_tree_list.append(clique_tree)
52
+
53
+ def add_to_subhg_list(self, clique_tree, root_node):
54
+ parent_node_dict = {}
55
+ current_node = None
56
+ parent_node_dict[root_node] = None
57
+ stack = [root_node]
58
+ while stack:
59
+ current_node = stack.pop()
60
+ current_subhg = clique_tree.nodes[current_node]['subhg']
61
+ for each_child in clique_tree.adj[current_node]:
62
+ if each_child != parent_node_dict[current_node]:
63
+ stack.append(each_child)
64
+ parent_node_dict[each_child] = current_node
65
+ if parent_node_dict[current_node] is not None:
66
+ parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
67
+ common, _ = common_node_list(parent_subhg, current_subhg)
68
+ parent_subhg.add_edge(set(common), attr_dict={'tmp': True})
69
+
70
+ parent_node_dict = {}
71
+ current_node = None
72
+ parent_node_dict[root_node] = None
73
+ stack = [root_node]
74
+ while stack:
75
+ current_node = stack.pop()
76
+ current_subhg = clique_tree.nodes[current_node]['subhg']
77
+ for each_child in clique_tree.adj[current_node]:
78
+ if each_child != parent_node_dict[current_node]:
79
+ stack.append(each_child)
80
+ parent_node_dict[each_child] = current_node
81
+ if parent_node_dict[current_node] is not None:
82
+ parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
83
+ common, _ = common_node_list(parent_subhg, current_subhg)
84
+ for each_idx, each_node in enumerate(common):
85
+ current_subhg.set_node_attr(each_node, {'ext_id': each_idx})
86
+
87
+ subhg_idx, is_new = self.add_subhg(current_subhg)
88
+ clique_tree.nodes[current_node]['subhg_idx'] = subhg_idx
89
+ return clique_tree
90
+
91
+ def add_subhg(self, subhg):
92
+ if len(self.subhg_list) == 0:
93
+ node_dict = {}
94
+ for each_node in subhg.nodes:
95
+ node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
96
+ node_list = []
97
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
98
+ node_list.append(each_key)
99
+ for each_idx, each_node in enumerate(node_list):
100
+ subhg.node_attr(each_node)['order4hrg'] = each_idx
101
+ self.subhg_list.append(subhg)
102
+ return 0, True
103
+ else:
104
+ match = False
105
+ subhg_bond_symbol_counter \
106
+ = Counter([subhg.node_attr(each_node)['symbol'] \
107
+ for each_node in subhg.nodes])
108
+ subhg_atom_symbol_counter \
109
+ = Counter([subhg.edge_attr(each_edge).get('symbol', None) \
110
+ for each_edge in subhg.edges])
111
+ for each_idx, each_subhg in enumerate(self.subhg_list):
112
+ each_bond_symbol_counter \
113
+ = Counter([each_subhg.node_attr(each_node)['symbol'] \
114
+ for each_node in each_subhg.nodes])
115
+ each_atom_symbol_counter \
116
+ = Counter([each_subhg.edge_attr(each_edge).get('symbol', None) \
117
+ for each_edge in each_subhg.edges])
118
+ if not match \
119
+ and (subhg.num_nodes == each_subhg.num_nodes
120
+ and subhg.num_edges == each_subhg.num_edges
121
+ and subhg_bond_symbol_counter == each_bond_symbol_counter
122
+ and subhg_atom_symbol_counter == each_atom_symbol_counter):
123
+ gm = GraphMatcher(each_subhg.hg,
124
+ subhg.hg,
125
+ node_match=_easy_node_match,
126
+ edge_match=_edge_match)
127
+ try:
128
+ isomap = next(gm.isomorphisms_iter())
129
+ match = True
130
+ for each_node in each_subhg.nodes:
131
+ subhg.node_attr(isomap[each_node])['order4hrg'] \
132
+ = each_subhg.node_attr(each_node)['order4hrg']
133
+ if 'ext_id' in each_subhg.node_attr(each_node):
134
+ subhg.node_attr(isomap[each_node])['ext_id'] \
135
+ = each_subhg.node_attr(each_node)['ext_id']
136
+ return each_idx, False
137
+ except StopIteration:
138
+ match = False
139
+ if not match:
140
+ node_dict = {}
141
+ for each_node in subhg.nodes:
142
+ node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
143
+ node_list = []
144
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
145
+ node_list.append(each_key)
146
+ for each_idx, each_node in enumerate(node_list):
147
+ subhg.node_attr(each_node)['order4hrg'] = each_idx
148
+
149
+ #for each_idx, each_node in enumerate(subhg.nodes):
150
+ # subhg.node_attr(each_node)['order4hrg'] = each_idx
151
+ self.subhg_list.append(subhg)
152
+ return len(self.subhg_list) - 1, True
models/mhg_model/graph_grammar/graph_grammar/hrg.py ADDED
@@ -0,0 +1,1065 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from .corpus import CliqueTreeCorpus
22
+ from .base import GraphGrammarBase
23
+ from .symbols import TSymbol, NTSymbol, BondSymbol
24
+ from .utils import _node_match, _node_match_prod_rule, _edge_match, masked_softmax, common_node_list
25
+ from ..hypergraph import Hypergraph
26
+ from collections import Counter
27
+ from copy import deepcopy
28
+ from ..algo.tree_decomposition import (
29
+ tree_decomposition,
30
+ tree_decomposition_with_hrg,
31
+ tree_decomposition_from_leaf,
32
+ topological_tree_decomposition,
33
+ molecular_tree_decomposition)
34
+ from functools import partial
35
+ from networkx.algorithms.isomorphism import GraphMatcher
36
+ from typing import List, Dict, Tuple
37
+ import networkx as nx
38
+ import numpy as np
39
+ import torch
40
+ import os
41
+ import random
42
+
43
+ DEBUG = False
44
+
45
+
46
+ class ProductionRule(object):
47
+ """ A class of a production rule
48
+
49
+ Attributes
50
+ ----------
51
+ lhs : Hypergraph or None
52
+ the left hand side of the production rule.
53
+ if None, the rule is a starting rule.
54
+ rhs : Hypergraph
55
+ the right hand side of the production rule.
56
+ """
57
+ def __init__(self, lhs, rhs):
58
+ self.lhs = lhs
59
+ self.rhs = rhs
60
+
61
+ @property
62
+ def is_start_rule(self) -> bool:
63
+ return self.lhs.num_nodes == 0
64
+
65
+ @property
66
+ def ext_node(self) -> Dict[int, str]:
67
+ """ return a dict of external nodes
68
+ """
69
+ if self.is_start_rule:
70
+ return {}
71
+ else:
72
+ ext_node_dict = {}
73
+ for each_node in self.lhs.nodes:
74
+ ext_node_dict[self.lhs.node_attr(each_node)["ext_id"]] = each_node
75
+ return ext_node_dict
76
+
77
+ @property
78
+ def lhs_nt_symbol(self) -> NTSymbol:
79
+ if self.is_start_rule:
80
+ return NTSymbol(degree=0, is_aromatic=False, bond_symbol_list=[])
81
+ else:
82
+ return self.lhs.edge_attr(list(self.lhs.edges)[0])['symbol']
83
+
84
+ def rhs_adj_mat(self, node_edge_list):
85
+ ''' return the adjacency matrix of rhs of the production rule
86
+ '''
87
+ return nx.adjacency_matrix(self.rhs.hg, node_edge_list)
88
+
89
+ def draw(self, file_path=None):
90
+ return self.rhs.draw(file_path)
91
+
92
+ def is_same(self, prod_rule, ignore_order=False):
93
+ """ judge whether this production rule is
94
+ the same as the input one, `prod_rule`
95
+
96
+ Parameters
97
+ ----------
98
+ prod_rule : ProductionRule
99
+ production rule to be compared
100
+
101
+ Returns
102
+ -------
103
+ is_same : bool
104
+ isomap : dict
105
+ isomorphism of nodes and hyperedges.
106
+ ex) {'bond_42': 'bond_37', 'bond_2': 'bond_1',
107
+ 'e36': 'e11', 'e16': 'e12', 'e25': 'e18',
108
+ 'bond_40': 'bond_38', 'e26': 'e21', 'bond_41': 'bond_39'}.
109
+ key comes from `prod_rule`, value comes from `self`.
110
+ """
111
+ if self.is_start_rule:
112
+ if not prod_rule.is_start_rule:
113
+ return False, {}
114
+ else:
115
+ if prod_rule.is_start_rule:
116
+ return False, {}
117
+ else:
118
+ if prod_rule.lhs.num_nodes != self.lhs.num_nodes:
119
+ return False, {}
120
+
121
+ if prod_rule.rhs.num_nodes != self.rhs.num_nodes:
122
+ return False, {}
123
+ if prod_rule.rhs.num_edges != self.rhs.num_edges:
124
+ return False, {}
125
+
126
+ subhg_bond_symbol_counter \
127
+ = Counter([prod_rule.rhs.node_attr(each_node)['symbol'] \
128
+ for each_node in prod_rule.rhs.nodes])
129
+ each_bond_symbol_counter \
130
+ = Counter([self.rhs.node_attr(each_node)['symbol'] \
131
+ for each_node in self.rhs.nodes])
132
+ if subhg_bond_symbol_counter != each_bond_symbol_counter:
133
+ return False, {}
134
+
135
+ subhg_atom_symbol_counter \
136
+ = Counter([prod_rule.rhs.edge_attr(each_edge)['symbol'] \
137
+ for each_edge in prod_rule.rhs.edges])
138
+ each_atom_symbol_counter \
139
+ = Counter([self.rhs.edge_attr(each_edge)['symbol'] \
140
+ for each_edge in self.rhs.edges])
141
+ if subhg_atom_symbol_counter != each_atom_symbol_counter:
142
+ return False, {}
143
+
144
+ gm = GraphMatcher(prod_rule.rhs.hg,
145
+ self.rhs.hg,
146
+ partial(_node_match_prod_rule,
147
+ ignore_order=ignore_order),
148
+ partial(_edge_match,
149
+ ignore_order=ignore_order))
150
+ try:
151
+ return True, next(gm.isomorphisms_iter())
152
+ except StopIteration:
153
+ return False, {}
154
+
155
+ def applied_to(self,
156
+ hg: Hypergraph,
157
+ edge: str) -> Tuple[Hypergraph, List[str]]:
158
+ """ augment `hg` by replacing `edge` with `self.rhs`.
159
+
160
+ Parameters
161
+ ----------
162
+ hg : Hypergraph
163
+ edge : str
164
+ `edge` must belong to `hg`
165
+
166
+ Returns
167
+ -------
168
+ hg : Hypergraph
169
+ resultant hypergraph
170
+ nt_edge_list : list
171
+ list of non-terminal edges
172
+ """
173
+ nt_edge_dict = {}
174
+ if self.is_start_rule:
175
+ if (edge is not None) or (hg is not None):
176
+ ValueError("edge and hg must be None for this prod rule.")
177
+ hg = Hypergraph()
178
+ node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
179
+ for num_idx, each_node in enumerate(self.rhs.nodes):
180
+ hg.add_node(f"bond_{num_idx}",
181
+ #attr_dict=deepcopy(self.rhs.node_attr(each_node)))
182
+ attr_dict=self.rhs.node_attr(each_node))
183
+ node_map_rhs[each_node] = f"bond_{num_idx}"
184
+ for each_edge in self.rhs.edges:
185
+ node_list = []
186
+ for each_node in self.rhs.nodes_in_edge(each_edge):
187
+ node_list.append(node_map_rhs[each_node])
188
+ if isinstance(self.rhs.nodes_in_edge(each_edge), set):
189
+ node_list = set(node_list)
190
+ edge_id = hg.add_edge(
191
+ node_list,
192
+ #attr_dict=deepcopy(self.rhs.edge_attr(each_edge)))
193
+ attr_dict=self.rhs.edge_attr(each_edge))
194
+ if "nt_idx" in hg.edge_attr(edge_id):
195
+ nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
196
+ nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
197
+ return hg, nt_edge_list
198
+ else:
199
+ if edge not in hg.edges:
200
+ raise ValueError("the input hyperedge does not exist.")
201
+ if hg.edge_attr(edge)["terminal"]:
202
+ raise ValueError("the input hyperedge is terminal.")
203
+ if hg.edge_attr(edge)['symbol'] != self.lhs_nt_symbol:
204
+ print(hg.edge_attr(edge)['symbol'], self.lhs_nt_symbol)
205
+ raise ValueError("the input hyperedge and lhs have inconsistent number of nodes.")
206
+ if DEBUG:
207
+ for node_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
208
+ other_node = self.lhs.nodes_in_edge(list(self.lhs.edges)[0])[node_idx]
209
+ attr = deepcopy(self.lhs.node_attr(other_node))
210
+ attr.pop('ext_id')
211
+ if hg.node_attr(each_node) != attr:
212
+ raise ValueError('node attributes are inconsistent.')
213
+
214
+ # order of nodes that belong to the non-terminal edge in hg
215
+ nt_order_dict = {} # hg_node -> order ("bond_17" : 1)
216
+ nt_order_dict_inv = {} # order -> hg_node
217
+ for each_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
218
+ nt_order_dict[each_node] = each_idx
219
+ nt_order_dict_inv[each_idx] = each_node
220
+
221
+ # construct a node_map_rhs: rhs -> new hg
222
+ node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
223
+ node_idx = hg.num_nodes
224
+ for each_node in self.rhs.nodes:
225
+ if "ext_id" in self.rhs.node_attr(each_node):
226
+ node_map_rhs[each_node] \
227
+ = nt_order_dict_inv[
228
+ self.rhs.node_attr(each_node)["ext_id"]]
229
+ else:
230
+ node_map_rhs[each_node] = f"bond_{node_idx}"
231
+ node_idx += 1
232
+
233
+ # delete non-terminal
234
+ hg.remove_edge(edge)
235
+
236
+ # add nodes to hg
237
+ for each_node in self.rhs.nodes:
238
+ hg.add_node(node_map_rhs[each_node],
239
+ attr_dict=self.rhs.node_attr(each_node))
240
+
241
+ # add hyperedges to hg
242
+ for each_edge in self.rhs.edges:
243
+ node_list_hg = []
244
+ for each_node in self.rhs.nodes_in_edge(each_edge):
245
+ node_list_hg.append(node_map_rhs[each_node])
246
+ edge_id = hg.add_edge(
247
+ node_list_hg,
248
+ attr_dict=self.rhs.edge_attr(each_edge))#deepcopy(self.rhs.edge_attr(each_edge)))
249
+ if "nt_idx" in hg.edge_attr(edge_id):
250
+ nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
251
+ nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
252
+ return hg, nt_edge_list
253
+
254
+ def revert(self, hg: Hypergraph, return_subhg=False):
255
+ ''' revert applying this production rule.
256
+ i.e., if there exists a subhypergraph that matches the r.h.s. of this production rule,
257
+ this method replaces the subhypergraph with a non-terminal hyperedge.
258
+
259
+ Parameters
260
+ ----------
261
+ hg : Hypergraph
262
+ hypergraph to be reverted
263
+ return_subhg : bool
264
+ if True, the removed subhypergraph will be returned.
265
+
266
+ Returns
267
+ -------
268
+ hg : Hypergraph
269
+ the resultant hypergraph. if it cannot be reverted, the original one is returned without any replacement.
270
+ success : bool
271
+ this indicates whether reverting is successed or not.
272
+ '''
273
+ gm = GraphMatcher(hg.hg, self.rhs.hg, node_match=_node_match_prod_rule,
274
+ edge_match=_edge_match)
275
+ try:
276
+ # in case when the matched subhg is connected to the other part via external nodes and more.
277
+ not_iso = True
278
+ while not_iso:
279
+ isomap = next(gm.subgraph_isomorphisms_iter())
280
+ adj_node_set = set([]) # reachable nodes from the internal nodes
281
+ subhg_node_set = set(isomap.keys()) # nodes in subhg
282
+ for each_node in subhg_node_set:
283
+ adj_node_set.add(each_node)
284
+ if isomap[each_node] not in self.ext_node.values():
285
+ adj_node_set.update(hg.hg.adj[each_node])
286
+ if adj_node_set == subhg_node_set:
287
+ not_iso = False
288
+ else:
289
+ if return_subhg:
290
+ return hg, False, Hypergraph()
291
+ else:
292
+ return hg, False
293
+ inv_isomap = {v: k for k, v in isomap.items()}
294
+ '''
295
+ isomap = {'e35': 'e8', 'bond_13': 'bond_18', 'bond_14': 'bond_19',
296
+ 'bond_15': 'bond_17', 'e29': 'e23', 'bond_12': 'bond_20'}
297
+ where keys come from `hg` and values come from `self.rhs`
298
+ '''
299
+ except StopIteration:
300
+ if return_subhg:
301
+ return hg, False, Hypergraph()
302
+ else:
303
+ return hg, False
304
+
305
+ if return_subhg:
306
+ subhg = Hypergraph()
307
+ for each_node in hg.nodes:
308
+ if each_node in isomap:
309
+ subhg.add_node(each_node, attr_dict=hg.node_attr(each_node))
310
+ for each_edge in hg.edges:
311
+ if each_edge in isomap:
312
+ subhg.add_edge(hg.nodes_in_edge(each_edge),
313
+ attr_dict=hg.edge_attr(each_edge),
314
+ edge_name=each_edge)
315
+ subhg.edge_idx = hg.edge_idx
316
+
317
+ # remove subhg except for the externael nodes
318
+ for each_key, each_val in isomap.items():
319
+ if each_key.startswith('e'):
320
+ hg.remove_edge(each_key)
321
+ for each_key, each_val in isomap.items():
322
+ if each_key.startswith('bond_'):
323
+ if each_val not in self.ext_node.values():
324
+ hg.remove_node(each_key)
325
+
326
+ # add non-terminal hyperedge
327
+ nt_node_list = []
328
+ for each_ext_id in self.ext_node.keys():
329
+ nt_node_list.append(inv_isomap[self.ext_node[each_ext_id]])
330
+
331
+ hg.add_edge(nt_node_list,
332
+ attr_dict=dict(
333
+ terminal=False,
334
+ symbol=self.lhs_nt_symbol))
335
+ if return_subhg:
336
+ return hg, True, subhg
337
+ else:
338
+ return hg, True
339
+
340
+
341
+ class ProductionRuleCorpus(object):
342
+
343
+ '''
344
+ A corpus of production rules.
345
+ This class maintains
346
+ (i) list of unique production rules,
347
+ (ii) list of unique edge symbols (both terminal and non-terminal), and
348
+ (iii) list of unique node symbols.
349
+
350
+ Attributes
351
+ ----------
352
+ prod_rule_list : list
353
+ list of unique production rules
354
+ edge_symbol_list : list
355
+ list of unique symbols (including both terminal and non-terminal)
356
+ node_symbol_list : list
357
+ list of node symbols
358
+ nt_symbol_list : list
359
+ list of unique lhs symbols
360
+ ext_id_list : list
361
+ list of ext_ids
362
+ lhs_in_prod_rule : array
363
+ a matrix of lhs vs prod_rule (= lhs_in_prod_rule)
364
+ '''
365
+
366
+ def __init__(self):
367
+ self.prod_rule_list = []
368
+ self.edge_symbol_list = []
369
+ self.edge_symbol_dict = {}
370
+ self.node_symbol_list = []
371
+ self.node_symbol_dict = {}
372
+ self.nt_symbol_list = []
373
+ self.ext_id_list = []
374
+ self._lhs_in_prod_rule = None
375
+ self.lhs_in_prod_rule_row_list = []
376
+ self.lhs_in_prod_rule_col_list = []
377
+
378
+ @property
379
+ def lhs_in_prod_rule(self):
380
+ if self._lhs_in_prod_rule is None:
381
+ self._lhs_in_prod_rule = torch.sparse.FloatTensor(
382
+ torch.LongTensor(list(zip(self.lhs_in_prod_rule_row_list, self.lhs_in_prod_rule_col_list))).t(),
383
+ torch.FloatTensor([1.0]*len(self.lhs_in_prod_rule_col_list)),
384
+ torch.Size([len(self.nt_symbol_list), len(self.prod_rule_list)])
385
+ ).to_dense()
386
+ return self._lhs_in_prod_rule
387
+
388
+ @property
389
+ def num_prod_rule(self):
390
+ ''' return the number of production rules
391
+
392
+ Returns
393
+ -------
394
+ int : the number of unique production rules
395
+ '''
396
+ return len(self.prod_rule_list)
397
+
398
+ @property
399
+ def start_rule_list(self):
400
+ ''' return a list of start rules
401
+
402
+ Returns
403
+ -------
404
+ list : list of start rules
405
+ '''
406
+ start_rule_list = []
407
+ for each_prod_rule in self.prod_rule_list:
408
+ if each_prod_rule.is_start_rule:
409
+ start_rule_list.append(each_prod_rule)
410
+ return start_rule_list
411
+
412
+ @property
413
+ def num_edge_symbol(self):
414
+ return len(self.edge_symbol_list)
415
+
416
+ @property
417
+ def num_node_symbol(self):
418
+ return len(self.node_symbol_list)
419
+
420
+ @property
421
+ def num_ext_id(self):
422
+ return len(self.ext_id_list)
423
+
424
+ def construct_feature_vectors(self):
425
+ ''' this method constructs feature vectors for the production rules collected so far.
426
+ currently, NTSymbol and TSymbol are treated in the same manner.
427
+ '''
428
+ feature_id_dict = {}
429
+ feature_id_dict['TSymbol'] = 0
430
+ feature_id_dict['NTSymbol'] = 1
431
+ feature_id_dict['BondSymbol'] = 2
432
+ for each_edge_symbol in self.edge_symbol_list:
433
+ for each_attr in each_edge_symbol.__dict__.keys():
434
+ each_val = each_edge_symbol.__dict__[each_attr]
435
+ if isinstance(each_val, list):
436
+ each_val = tuple(each_val)
437
+ if (each_attr, each_val) not in feature_id_dict:
438
+ feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
439
+
440
+ for each_node_symbol in self.node_symbol_list:
441
+ for each_attr in each_node_symbol.__dict__.keys():
442
+ each_val = each_node_symbol.__dict__[each_attr]
443
+ if isinstance(each_val, list):
444
+ each_val = tuple(each_val)
445
+ if (each_attr, each_val) not in feature_id_dict:
446
+ feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
447
+ for each_ext_id in self.ext_id_list:
448
+ feature_id_dict[('ext_id', each_ext_id)] = len(feature_id_dict)
449
+ dim = len(feature_id_dict)
450
+
451
+ feature_dict = {}
452
+ for each_edge_symbol in self.edge_symbol_list:
453
+ idx_list = []
454
+ idx_list.append(feature_id_dict[each_edge_symbol.__class__.__name__])
455
+ for each_attr in each_edge_symbol.__dict__.keys():
456
+ each_val = each_edge_symbol.__dict__[each_attr]
457
+ if isinstance(each_val, list):
458
+ each_val = tuple(each_val)
459
+ idx_list.append(feature_id_dict[(each_attr, each_val)])
460
+ feature = torch.sparse.LongTensor(
461
+ torch.LongTensor([idx_list]),
462
+ torch.ones(len(idx_list)),
463
+ torch.Size([len(feature_id_dict)])
464
+ )
465
+ feature_dict[each_edge_symbol] = feature
466
+
467
+ for each_node_symbol in self.node_symbol_list:
468
+ idx_list = []
469
+ idx_list.append(feature_id_dict[each_node_symbol.__class__.__name__])
470
+ for each_attr in each_node_symbol.__dict__.keys():
471
+ each_val = each_node_symbol.__dict__[each_attr]
472
+ if isinstance(each_val, list):
473
+ each_val = tuple(each_val)
474
+ idx_list.append(feature_id_dict[(each_attr, each_val)])
475
+ feature = torch.sparse.LongTensor(
476
+ torch.LongTensor([idx_list]),
477
+ torch.ones(len(idx_list)),
478
+ torch.Size([len(feature_id_dict)])
479
+ )
480
+ feature_dict[each_node_symbol] = feature
481
+ for each_ext_id in self.ext_id_list:
482
+ idx_list = [feature_id_dict[('ext_id', each_ext_id)]]
483
+ feature_dict[('ext_id', each_ext_id)] \
484
+ = torch.sparse.LongTensor(
485
+ torch.LongTensor([idx_list]),
486
+ torch.ones(len(idx_list)),
487
+ torch.Size([len(feature_id_dict)])
488
+ )
489
+ return feature_dict, dim
490
+
491
+ def edge_symbol_idx(self, symbol):
492
+ return self.edge_symbol_dict[symbol]
493
+
494
+ def node_symbol_idx(self, symbol):
495
+ return self.node_symbol_dict[symbol]
496
+
497
+ def append(self, prod_rule: ProductionRule) -> Tuple[int, ProductionRule]:
498
+ """ return whether the input production rule is new or not, and its production rule id.
499
+ Production rules are regarded as the same if
500
+ i) there exists a one-to-one mapping of nodes and edges, and
501
+ ii) all the attributes associated with nodes and hyperedges are the same.
502
+
503
+ Parameters
504
+ ----------
505
+ prod_rule : ProductionRule
506
+
507
+ Returns
508
+ -------
509
+ prod_rule_id : int
510
+ production rule index. if new, a new index will be assigned.
511
+ prod_rule : ProductionRule
512
+ """
513
+ num_lhs = len(self.nt_symbol_list)
514
+ for each_idx, each_prod_rule in enumerate(self.prod_rule_list):
515
+ is_same, isomap = prod_rule.is_same(each_prod_rule)
516
+ if is_same:
517
+ # we do not care about edge and node names, but care about the order of non-terminal edges.
518
+ for key, val in isomap.items(): # key : edges & nodes in each_prod_rule.rhs , val : those in prod_rule.rhs
519
+ if key.startswith("bond_"):
520
+ continue
521
+
522
+ # rewrite `nt_idx` in `prod_rule` for further processing
523
+ if "nt_idx" in prod_rule.rhs.edge_attr(val).keys():
524
+ if "nt_idx" not in each_prod_rule.rhs.edge_attr(key).keys():
525
+ raise ValueError
526
+ prod_rule.rhs.set_edge_attr(
527
+ val,
528
+ {'nt_idx': each_prod_rule.rhs.edge_attr(key)["nt_idx"]})
529
+ return each_idx, prod_rule
530
+ self.prod_rule_list.append(prod_rule)
531
+ self._update_edge_symbol_list(prod_rule)
532
+ self._update_node_symbol_list(prod_rule)
533
+ self._update_ext_id_list(prod_rule)
534
+
535
+ lhs_idx = self.nt_symbol_list.index(prod_rule.lhs_nt_symbol)
536
+ self.lhs_in_prod_rule_row_list.append(lhs_idx)
537
+ self.lhs_in_prod_rule_col_list.append(len(self.prod_rule_list)-1)
538
+ self._lhs_in_prod_rule = None
539
+ return len(self.prod_rule_list)-1, prod_rule
540
+
541
+ def get_prod_rule(self, prod_rule_idx: int) -> ProductionRule:
542
+ return self.prod_rule_list[prod_rule_idx]
543
+
544
+ def sample(self, unmasked_logit_array, nt_symbol, deterministic=False):
545
+ ''' sample a production rule whose lhs is `nt_symbol`, followihng `unmasked_logit_array`.
546
+
547
+ Parameters
548
+ ----------
549
+ unmasked_logit_array : array-like, length `num_prod_rule`
550
+ nt_symbol : NTSymbol
551
+ '''
552
+ if not isinstance(unmasked_logit_array, np.ndarray):
553
+ unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
554
+ if deterministic:
555
+ prob = masked_softmax(unmasked_logit_array,
556
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
557
+ return self.prod_rule_list[np.argmax(prob)]
558
+ else:
559
+ return np.random.choice(
560
+ self.prod_rule_list, 1,
561
+ p=masked_softmax(unmasked_logit_array,
562
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)))[0]
563
+
564
+ def masked_logprob(self, unmasked_logit_array, nt_symbol):
565
+ if not isinstance(unmasked_logit_array, np.ndarray):
566
+ unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
567
+ prob = masked_softmax(unmasked_logit_array,
568
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
569
+ return np.log(prob)
570
+
571
+ def _update_edge_symbol_list(self, prod_rule: ProductionRule):
572
+ ''' update edge symbol list
573
+
574
+ Parameters
575
+ ----------
576
+ prod_rule : ProductionRule
577
+ '''
578
+ if prod_rule.lhs_nt_symbol not in self.nt_symbol_list:
579
+ self.nt_symbol_list.append(prod_rule.lhs_nt_symbol)
580
+
581
+ for each_edge in prod_rule.rhs.edges:
582
+ if prod_rule.rhs.edge_attr(each_edge)['symbol'] not in self.edge_symbol_dict:
583
+ edge_symbol_idx = len(self.edge_symbol_list)
584
+ self.edge_symbol_list.append(prod_rule.rhs.edge_attr(each_edge)['symbol'])
585
+ self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']] = edge_symbol_idx
586
+ else:
587
+ edge_symbol_idx = self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']]
588
+ prod_rule.rhs.edge_attr(each_edge)['symbol_idx'] = edge_symbol_idx
589
+ pass
590
+
591
+ def _update_node_symbol_list(self, prod_rule: ProductionRule):
592
+ ''' update node symbol list
593
+
594
+ Parameters
595
+ ----------
596
+ prod_rule : ProductionRule
597
+ '''
598
+ for each_node in prod_rule.rhs.nodes:
599
+ if prod_rule.rhs.node_attr(each_node)['symbol'] not in self.node_symbol_dict:
600
+ node_symbol_idx = len(self.node_symbol_list)
601
+ self.node_symbol_list.append(prod_rule.rhs.node_attr(each_node)['symbol'])
602
+ self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']] = node_symbol_idx
603
+ else:
604
+ node_symbol_idx = self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']]
605
+ prod_rule.rhs.node_attr(each_node)['symbol_idx'] = node_symbol_idx
606
+
607
+ def _update_ext_id_list(self, prod_rule: ProductionRule):
608
+ for each_node in prod_rule.rhs.nodes:
609
+ if 'ext_id' in prod_rule.rhs.node_attr(each_node):
610
+ if prod_rule.rhs.node_attr(each_node)['ext_id'] not in self.ext_id_list:
611
+ self.ext_id_list.append(prod_rule.rhs.node_attr(each_node)['ext_id'])
612
+
613
+
614
+ class HyperedgeReplacementGrammar(GraphGrammarBase):
615
+ """
616
+ Learn a hyperedge replacement grammar from a set of hypergraphs.
617
+
618
+ Attributes
619
+ ----------
620
+ prod_rule_list : list of ProductionRule
621
+ production rules learned from the input hypergraphs
622
+ """
623
+ def __init__(self,
624
+ tree_decomposition=molecular_tree_decomposition,
625
+ ignore_order=False, **kwargs):
626
+ from functools import partial
627
+ self.prod_rule_corpus = ProductionRuleCorpus()
628
+ self.clique_tree_corpus = CliqueTreeCorpus()
629
+ self.ignore_order = ignore_order
630
+ self.tree_decomposition = partial(tree_decomposition, **kwargs)
631
+
632
+ @property
633
+ def num_prod_rule(self):
634
+ ''' return the number of production rules
635
+
636
+ Returns
637
+ -------
638
+ int : the number of unique production rules
639
+ '''
640
+ return self.prod_rule_corpus.num_prod_rule
641
+
642
+ @property
643
+ def start_rule_list(self):
644
+ ''' return a list of start rules
645
+
646
+ Returns
647
+ -------
648
+ list : list of start rules
649
+ '''
650
+ return self.prod_rule_corpus.start_rule_list
651
+
652
+ @property
653
+ def prod_rule_list(self):
654
+ return self.prod_rule_corpus.prod_rule_list
655
+
656
+ def learn(self, hg_list, logger=print, max_mol=np.inf, print_freq=500):
657
+ """ learn from a list of hypergraphs
658
+
659
+ Parameters
660
+ ----------
661
+ hg_list : list of Hypergraph
662
+
663
+ Returns
664
+ -------
665
+ prod_rule_seq_list : list of integers
666
+ each element corresponds to a sequence of production rules to generate each hypergraph.
667
+ """
668
+ prod_rule_seq_list = []
669
+ idx = 0
670
+ for each_idx, each_hg in enumerate(hg_list):
671
+ clique_tree = self.tree_decomposition(each_hg)
672
+
673
+ # get a pair of myself and children
674
+ root_node = _find_root(clique_tree)
675
+ clique_tree = self.clique_tree_corpus.add_to_subhg_list(clique_tree, root_node)
676
+ prod_rule_seq = []
677
+ stack = []
678
+
679
+ children = sorted(list(clique_tree[root_node].keys()))
680
+
681
+ # extract a temporary production rule
682
+ prod_rule = extract_prod_rule(
683
+ None,
684
+ clique_tree.nodes[root_node]["subhg"],
685
+ [clique_tree.nodes[each_child]["subhg"]
686
+ for each_child in children],
687
+ clique_tree.nodes[root_node].get('subhg_idx', None))
688
+
689
+ # update the production rule list
690
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
691
+ children = reorder_children(root_node,
692
+ children,
693
+ prod_rule,
694
+ clique_tree)
695
+ stack.extend([(root_node, each_child) for each_child in children[::-1]])
696
+ prod_rule_seq.append(prod_rule_id)
697
+
698
+ while len(stack) != 0:
699
+ # get a triple of parent, myself, and children
700
+ parent, myself = stack.pop()
701
+ children = sorted(list(dict(clique_tree[myself]).keys()))
702
+ children.remove(parent)
703
+
704
+ # extract a temp prod rule
705
+ prod_rule = extract_prod_rule(
706
+ clique_tree.nodes[parent]["subhg"],
707
+ clique_tree.nodes[myself]["subhg"],
708
+ [clique_tree.nodes[each_child]["subhg"]
709
+ for each_child in children],
710
+ clique_tree.nodes[myself].get('subhg_idx', None))
711
+
712
+ # update the prod rule list
713
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
714
+ children = reorder_children(myself,
715
+ children,
716
+ prod_rule,
717
+ clique_tree)
718
+ stack.extend([(myself, each_child)
719
+ for each_child in children[::-1]])
720
+ prod_rule_seq.append(prod_rule_id)
721
+ prod_rule_seq_list.append(prod_rule_seq)
722
+ if (each_idx+1) % print_freq == 0:
723
+ msg = f'#(molecules processed)={each_idx+1}\t'\
724
+ f'#(production rules)={self.prod_rule_corpus.num_prod_rule}\t#(subhg in corpus)={self.clique_tree_corpus.size}'
725
+ logger(msg)
726
+ if each_idx > max_mol:
727
+ break
728
+
729
+ print(f'corpus_size = {self.clique_tree_corpus.size}')
730
+ return prod_rule_seq_list
731
+
732
+ def sample(self, z, deterministic=False):
733
+ """ sample a new hypergraph from HRG.
734
+
735
+ Parameters
736
+ ----------
737
+ z : array-like, shape (len, num_prod_rule)
738
+ logit
739
+ deterministic : bool
740
+ if True, deterministic sampling
741
+
742
+ Returns
743
+ -------
744
+ Hypergraph
745
+ """
746
+ seq_idx = 0
747
+ stack = []
748
+ z = z[:, :-1]
749
+ init_prod_rule = self.prod_rule_corpus.sample(z[0], NTSymbol(degree=0,
750
+ is_aromatic=False,
751
+ bond_symbol_list=[]),
752
+ deterministic=deterministic)
753
+ hg, nt_edge_list = init_prod_rule.applied_to(None, None)
754
+ stack = deepcopy(nt_edge_list[::-1])
755
+ while len(stack) != 0 and seq_idx < z.shape[0]-1:
756
+ seq_idx += 1
757
+ nt_edge = stack.pop()
758
+ nt_symbol = hg.edge_attr(nt_edge)['symbol']
759
+ prod_rule = self.prod_rule_corpus.sample(z[seq_idx], nt_symbol, deterministic=deterministic)
760
+ hg, nt_edge_list = prod_rule.applied_to(hg, nt_edge)
761
+ stack.extend(nt_edge_list[::-1])
762
+ if len(stack) != 0:
763
+ raise RuntimeError(f'{len(stack)} non-terminals are left.')
764
+ return hg
765
+
766
+ def construct(self, prod_rule_seq):
767
+ """ construct a hypergraph following `prod_rule_seq`
768
+
769
+ Parameters
770
+ ----------
771
+ prod_rule_seq : list of integers
772
+ a sequence of production rules.
773
+
774
+ Returns
775
+ -------
776
+ UndirectedHypergraph
777
+ """
778
+ seq_idx = 0
779
+ init_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
780
+ hg, nt_edge_list = init_prod_rule.applied_to(None, None)
781
+ stack = deepcopy(nt_edge_list[::-1])
782
+ while len(stack) != 0:
783
+ seq_idx += 1
784
+ nt_edge = stack.pop()
785
+ hg, nt_edge_list = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx]).applied_to(hg, nt_edge)
786
+ stack.extend(nt_edge_list[::-1])
787
+ return hg
788
+
789
+ def update_prod_rule_list(self, prod_rule):
790
+ """ return whether the input production rule is new or not, and its production rule id.
791
+ Production rules are regarded as the same if
792
+ i) there exists a one-to-one mapping of nodes and edges, and
793
+ ii) all the attributes associated with nodes and hyperedges are the same.
794
+
795
+ Parameters
796
+ ----------
797
+ prod_rule : ProductionRule
798
+
799
+ Returns
800
+ -------
801
+ is_new : bool
802
+ if True, this production rule is new
803
+ prod_rule_id : int
804
+ production rule index. if new, a new index will be assigned.
805
+ """
806
+ return self.prod_rule_corpus.append(prod_rule)
807
+
808
+
809
+ class IncrementalHyperedgeReplacementGrammar(HyperedgeReplacementGrammar):
810
+ '''
811
+ This class learns HRG incrementally leveraging the previously obtained production rules.
812
+ '''
813
+ def __init__(self, tree_decomposition=tree_decomposition_with_hrg, ignore_order=False):
814
+ self.prod_rule_list = []
815
+ self.tree_decomposition = tree_decomposition
816
+ self.ignore_order = ignore_order
817
+
818
+ def learn(self, hg_list):
819
+ """ learn from a list of hypergraphs
820
+
821
+ Parameters
822
+ ----------
823
+ hg_list : list of UndirectedHypergraph
824
+
825
+ Returns
826
+ -------
827
+ prod_rule_seq_list : list of integers
828
+ each element corresponds to a sequence of production rules to generate each hypergraph.
829
+ """
830
+ prod_rule_seq_list = []
831
+ for each_hg in hg_list:
832
+ clique_tree, root_node = tree_decomposition_with_hrg(each_hg, self, return_root=True)
833
+
834
+ prod_rule_seq = []
835
+ stack = []
836
+
837
+ # get a pair of myself and children
838
+ children = sorted(list(clique_tree[root_node].keys()))
839
+
840
+ # extract a temporary production rule
841
+ prod_rule = extract_prod_rule(None, clique_tree.nodes[root_node]["subhg"],
842
+ [clique_tree.nodes[each_child]["subhg"] for each_child in children])
843
+
844
+ # update the production rule list
845
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
846
+ children = reorder_children(root_node, children, prod_rule, clique_tree)
847
+ stack.extend([(root_node, each_child) for each_child in children[::-1]])
848
+ prod_rule_seq.append(prod_rule_id)
849
+
850
+ while len(stack) != 0:
851
+ # get a triple of parent, myself, and children
852
+ parent, myself = stack.pop()
853
+ children = sorted(list(dict(clique_tree[myself]).keys()))
854
+ children.remove(parent)
855
+
856
+ # extract a temp prod rule
857
+ prod_rule = extract_prod_rule(
858
+ clique_tree.nodes[parent]["subhg"], clique_tree.nodes[myself]["subhg"],
859
+ [clique_tree.nodes[each_child]["subhg"] for each_child in children])
860
+
861
+ # update the prod rule list
862
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
863
+ children = reorder_children(myself, children, prod_rule, clique_tree)
864
+ stack.extend([(myself, each_child) for each_child in children[::-1]])
865
+ prod_rule_seq.append(prod_rule_id)
866
+ prod_rule_seq_list.append(prod_rule_seq)
867
+ self._compute_stats()
868
+ return prod_rule_seq_list
869
+
870
+
871
+ def reorder_children(myself, children, prod_rule, clique_tree):
872
+ """ reorder children so that they match the order in `prod_rule`.
873
+
874
+ Parameters
875
+ ----------
876
+ myself : int
877
+ children : list of int
878
+ prod_rule : ProductionRule
879
+ clique_tree : nx.Graph
880
+
881
+ Returns
882
+ -------
883
+ new_children : list of str
884
+ reordered children
885
+ """
886
+ perm = {} # key : `nt_idx`, val : child
887
+ for each_edge in prod_rule.rhs.edges:
888
+ if "nt_idx" in prod_rule.rhs.edge_attr(each_edge).keys():
889
+ for each_child in children:
890
+ common_node_set = set(
891
+ common_node_list(clique_tree.nodes[myself]["subhg"],
892
+ clique_tree.nodes[each_child]["subhg"])[0])
893
+ if set(prod_rule.rhs.nodes_in_edge(each_edge)) == common_node_set:
894
+ assert prod_rule.rhs.edge_attr(each_edge)["nt_idx"] not in perm
895
+ perm[prod_rule.rhs.edge_attr(each_edge)["nt_idx"]] = each_child
896
+ new_children = []
897
+ assert len(perm) == len(children)
898
+ for i in range(len(perm)):
899
+ new_children.append(perm[i])
900
+ return new_children
901
+
902
+
903
+ def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None):
904
+ """ extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`.
905
+
906
+ Parameters
907
+ ----------
908
+ parent_hg : Hypergraph
909
+ myself_hg : Hypergraph
910
+ children_hg_list : list of Hypergraph
911
+
912
+ Returns
913
+ -------
914
+ ProductionRule, consisting of
915
+ lhs : Hypergraph or None
916
+ rhs : Hypergraph
917
+ """
918
+ def _add_ext_node(hg, ext_nodes):
919
+ """ mark nodes to be external (ordered ids are assigned)
920
+
921
+ Parameters
922
+ ----------
923
+ hg : UndirectedHypergraph
924
+ ext_nodes : list of str
925
+ list of external nodes
926
+
927
+ Returns
928
+ -------
929
+ hg : Hypergraph
930
+ nodes in `ext_nodes` are marked to be external
931
+ """
932
+ ext_id = 0
933
+ ext_id_exists = []
934
+ for each_node in ext_nodes:
935
+ ext_id_exists.append('ext_id' in hg.node_attr(each_node))
936
+ if ext_id_exists and any(ext_id_exists) != all(ext_id_exists):
937
+ raise ValueError
938
+ if not all(ext_id_exists):
939
+ for each_node in ext_nodes:
940
+ hg.node_attr(each_node)['ext_id'] = ext_id
941
+ ext_id += 1
942
+ return hg
943
+
944
+ def _check_aromatic(hg, node_list):
945
+ is_aromatic = False
946
+ node_aromatic_list = []
947
+ for each_node in node_list:
948
+ if hg.node_attr(each_node)['symbol'].is_aromatic:
949
+ is_aromatic = True
950
+ node_aromatic_list.append(True)
951
+ else:
952
+ node_aromatic_list.append(False)
953
+ return is_aromatic, node_aromatic_list
954
+
955
+ def _check_ring(hg):
956
+ for each_edge in hg.edges:
957
+ if not ('tmp' in hg.edge_attr(each_edge) or (not hg.edge_attr(each_edge)['terminal'])):
958
+ return False
959
+ return True
960
+
961
+ if parent_hg is None:
962
+ lhs = Hypergraph()
963
+ node_list = []
964
+ else:
965
+ lhs = Hypergraph()
966
+ node_list, edge_exists = common_node_list(parent_hg, myself_hg)
967
+ for each_node in node_list:
968
+ lhs.add_node(each_node,
969
+ deepcopy(myself_hg.node_attr(each_node)))
970
+ is_aromatic, _ = _check_aromatic(parent_hg, node_list)
971
+ for_ring = _check_ring(myself_hg)
972
+ bond_symbol_list = []
973
+ for each_node in node_list:
974
+ bond_symbol_list.append(parent_hg.node_attr(each_node)['symbol'])
975
+ lhs.add_edge(
976
+ node_list,
977
+ attr_dict=dict(
978
+ terminal=False,
979
+ edge_exists=edge_exists,
980
+ symbol=NTSymbol(
981
+ degree=len(node_list),
982
+ is_aromatic=is_aromatic,
983
+ bond_symbol_list=bond_symbol_list,
984
+ for_ring=for_ring)))
985
+ try:
986
+ lhs = _add_ext_node(lhs, node_list)
987
+ except ValueError:
988
+ import pdb; pdb.set_trace()
989
+
990
+ rhs = remove_tmp_edge(deepcopy(myself_hg))
991
+ #rhs = remove_ext_node(rhs)
992
+ #rhs = remove_nt_edge(rhs)
993
+ try:
994
+ rhs = _add_ext_node(rhs, node_list)
995
+ except ValueError:
996
+ import pdb; pdb.set_trace()
997
+
998
+ nt_idx = 0
999
+ if children_hg_list is not None:
1000
+ for each_child_hg in children_hg_list:
1001
+ node_list, edge_exists = common_node_list(myself_hg, each_child_hg)
1002
+ is_aromatic, _ = _check_aromatic(myself_hg, node_list)
1003
+ for_ring = _check_ring(each_child_hg)
1004
+ bond_symbol_list = []
1005
+ for each_node in node_list:
1006
+ bond_symbol_list.append(myself_hg.node_attr(each_node)['symbol'])
1007
+ rhs.add_edge(
1008
+ node_list,
1009
+ attr_dict=dict(
1010
+ terminal=False,
1011
+ nt_idx=nt_idx,
1012
+ edge_exists=edge_exists,
1013
+ symbol=NTSymbol(degree=len(node_list),
1014
+ is_aromatic=is_aromatic,
1015
+ bond_symbol_list=bond_symbol_list,
1016
+ for_ring=for_ring)))
1017
+ nt_idx += 1
1018
+ prod_rule = ProductionRule(lhs, rhs)
1019
+ prod_rule.subhg_idx = subhg_idx
1020
+ if DEBUG:
1021
+ if sorted(list(prod_rule.ext_node.keys())) \
1022
+ != list(np.arange(len(prod_rule.ext_node))):
1023
+ raise RuntimeError('ext_id is not continuous')
1024
+ return prod_rule
1025
+
1026
+
1027
+ def _find_root(clique_tree):
1028
+ max_node = None
1029
+ num_nodes_max = -np.inf
1030
+ for each_node in clique_tree.nodes:
1031
+ if clique_tree.nodes[each_node]['subhg'].num_nodes > num_nodes_max:
1032
+ max_node = each_node
1033
+ num_nodes_max = clique_tree.nodes[each_node]['subhg'].num_nodes
1034
+ '''
1035
+ children = sorted(list(clique_tree[each_node].keys()))
1036
+ prod_rule = extract_prod_rule(None,
1037
+ clique_tree.nodes[each_node]["subhg"],
1038
+ [clique_tree.nodes[each_child]["subhg"]
1039
+ for each_child in children])
1040
+ for each_start_rule in start_rule_list:
1041
+ if prod_rule.is_same(each_start_rule):
1042
+ return each_node
1043
+ '''
1044
+ return max_node
1045
+
1046
+ def remove_ext_node(hg):
1047
+ for each_node in hg.nodes:
1048
+ hg.node_attr(each_node).pop('ext_id', None)
1049
+ return hg
1050
+
1051
+ def remove_nt_edge(hg):
1052
+ remove_edge_list = []
1053
+ for each_edge in hg.edges:
1054
+ if not hg.edge_attr(each_edge)['terminal']:
1055
+ remove_edge_list.append(each_edge)
1056
+ hg.remove_edges(remove_edge_list)
1057
+ return hg
1058
+
1059
+ def remove_tmp_edge(hg):
1060
+ remove_edge_list = []
1061
+ for each_edge in hg.edges:
1062
+ if hg.edge_attr(each_edge).get('tmp', False):
1063
+ remove_edge_list.append(each_edge)
1064
+ hg.remove_edges(remove_edge_list)
1065
+ return hg
models/mhg_model/graph_grammar/graph_grammar/symbols.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+
15
+ """ Title """
16
+
17
+ __author__ = "Hiroshi Kajino <[email protected]>"
18
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
19
+ __version__ = "0.1"
20
+ __date__ = "Jan 1 2018"
21
+
22
+ from typing import List
23
+
24
+ class TSymbol(object):
25
+
26
+ ''' terminal symbol
27
+
28
+ Attributes
29
+ ----------
30
+ degree : int
31
+ the number of nodes in a hyperedge
32
+ is_aromatic : bool
33
+ whether or not the hyperedge is in an aromatic ring
34
+ symbol : str
35
+ atomic symbol
36
+ num_explicit_Hs : int
37
+ the number of hydrogens associated to this hyperedge
38
+ formal_charge : int
39
+ charge
40
+ chirality : int
41
+ chirality
42
+ '''
43
+
44
+ def __init__(self, degree, is_aromatic,
45
+ symbol, num_explicit_Hs, formal_charge, chirality):
46
+ self.degree = degree
47
+ self.is_aromatic = is_aromatic
48
+ self.symbol = symbol
49
+ self.num_explicit_Hs = num_explicit_Hs
50
+ self.formal_charge = formal_charge
51
+ self.chirality = chirality
52
+
53
+ @property
54
+ def terminal(self):
55
+ return True
56
+
57
+ def __eq__(self, other):
58
+ if not isinstance(other, TSymbol):
59
+ return False
60
+ if self.degree != other.degree:
61
+ return False
62
+ if self.is_aromatic != other.is_aromatic:
63
+ return False
64
+ if self.symbol != other.symbol:
65
+ return False
66
+ if self.num_explicit_Hs != other.num_explicit_Hs:
67
+ return False
68
+ if self.formal_charge != other.formal_charge:
69
+ return False
70
+ if self.chirality != other.chirality:
71
+ return False
72
+ return True
73
+
74
+ def __hash__(self):
75
+ return self.__str__().__hash__()
76
+
77
+ def __str__(self):
78
+ return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
79
+ f'symbol={self.symbol}, '\
80
+ f'num_explicit_Hs={self.num_explicit_Hs}, '\
81
+ f'formal_charge={self.formal_charge}, chirality={self.chirality}'
82
+
83
+
84
+ class NTSymbol(object):
85
+
86
+ ''' non-terminal symbol
87
+
88
+ Attributes
89
+ ----------
90
+ degree : int
91
+ degree of the hyperedge
92
+ is_aromatic : bool
93
+ if True, at least one of the associated bonds must be aromatic.
94
+ node_aromatic_list : list of bool
95
+ indicate whether each of the nodes is aromatic or not.
96
+ bond_type_list : list of int
97
+ bond type of each node"
98
+ '''
99
+
100
+ def __init__(self, degree: int, is_aromatic: bool,
101
+ bond_symbol_list: list,
102
+ for_ring=False):
103
+ self.degree = degree
104
+ self.is_aromatic = is_aromatic
105
+ self.for_ring = for_ring
106
+ self.bond_symbol_list = bond_symbol_list
107
+
108
+ @property
109
+ def terminal(self) -> bool:
110
+ return False
111
+
112
+ @property
113
+ def symbol(self):
114
+ return f'NT{self.degree}'
115
+
116
+ def __eq__(self, other) -> bool:
117
+ if not isinstance(other, NTSymbol):
118
+ return False
119
+
120
+ if self.degree != other.degree:
121
+ return False
122
+ if self.is_aromatic != other.is_aromatic:
123
+ return False
124
+ if self.for_ring != other.for_ring:
125
+ return False
126
+ if len(self.bond_symbol_list) != len(other.bond_symbol_list):
127
+ return False
128
+ for each_idx in range(len(self.bond_symbol_list)):
129
+ if self.bond_symbol_list[each_idx] != other.bond_symbol_list[each_idx]:
130
+ return False
131
+ return True
132
+
133
+ def __hash__(self):
134
+ return self.__str__().__hash__()
135
+
136
+ def __str__(self) -> str:
137
+ return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
138
+ f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}'\
139
+ f'for_ring={self.for_ring}'
140
+
141
+
142
+ class BondSymbol(object):
143
+
144
+
145
+ ''' Bond symbol
146
+
147
+ Attributes
148
+ ----------
149
+ is_aromatic : bool
150
+ if True, at least one of the associated bonds must be aromatic.
151
+ bond_type : int
152
+ bond type of each node"
153
+ '''
154
+
155
+ def __init__(self, is_aromatic: bool,
156
+ bond_type: int,
157
+ stereo: int):
158
+ self.is_aromatic = is_aromatic
159
+ self.bond_type = bond_type
160
+ self.stereo = stereo
161
+
162
+ def __eq__(self, other) -> bool:
163
+ if not isinstance(other, BondSymbol):
164
+ return False
165
+
166
+ if self.is_aromatic != other.is_aromatic:
167
+ return False
168
+ if self.bond_type != other.bond_type:
169
+ return False
170
+ if self.stereo != other.stereo:
171
+ return False
172
+ return True
173
+
174
+ def __hash__(self):
175
+ return self.__str__().__hash__()
176
+
177
+ def __str__(self) -> str:
178
+ return f'is_aromatic={self.is_aromatic}, '\
179
+ f'bond_type={self.bond_type}, '\
180
+ f'stereo={self.stereo}, '
models/mhg_model/graph_grammar/graph_grammar/utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jun 4 2018"
20
+
21
+ from ..hypergraph import Hypergraph
22
+ from copy import deepcopy
23
+ from typing import List
24
+ import numpy as np
25
+
26
+
27
+ def common_node_list(hg1: Hypergraph, hg2: Hypergraph) -> List[str]:
28
+ """ return a list of common nodes
29
+
30
+ Parameters
31
+ ----------
32
+ hg1, hg2 : Hypergraph
33
+
34
+ Returns
35
+ -------
36
+ list of str
37
+ list of common nodes
38
+ """
39
+ if hg1 is None or hg2 is None:
40
+ return [], False
41
+ else:
42
+ node_set = hg1.nodes.intersection(hg2.nodes)
43
+ node_dict = {}
44
+ if 'order4hrg' in hg1.node_attr(list(hg1.nodes)[0]):
45
+ for each_node in node_set:
46
+ node_dict[each_node] = hg1.node_attr(each_node)['order4hrg']
47
+ else:
48
+ for each_node in node_set:
49
+ node_dict[each_node] = hg1.node_attr(each_node)['symbol'].__hash__()
50
+ node_list = []
51
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
52
+ node_list.append(each_key)
53
+ edge_name = hg1.has_edge(node_list, ignore_order=True)
54
+ if edge_name:
55
+ if not hg1.edge_attr(edge_name).get('terminal', True):
56
+ node_list = hg1.nodes_in_edge(edge_name)
57
+ return node_list, True
58
+ else:
59
+ return node_list, False
60
+
61
+
62
+ def _node_match(node1, node2):
63
+ # if the nodes are hyperedges, `atom_attr` determines the match
64
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
65
+ return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
66
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
67
+ # bond_symbol
68
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
69
+ else:
70
+ return False
71
+
72
+ def _easy_node_match(node1, node2):
73
+ # if the nodes are hyperedges, `atom_attr` determines the match
74
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
75
+ return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None)
76
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
77
+ # bond_symbol
78
+ return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\
79
+ and node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
80
+ else:
81
+ return False
82
+
83
+
84
+ def _node_match_prod_rule(node1, node2, ignore_order=False):
85
+ # if the nodes are hyperedges, `atom_attr` determines the match
86
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
87
+ return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
88
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
89
+ # ext_id, order4hrg, bond_symbol
90
+ if ignore_order:
91
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
92
+ else:
93
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\
94
+ and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)
95
+ else:
96
+ return False
97
+
98
+
99
+ def _edge_match(edge1, edge2, ignore_order=False):
100
+ #return True
101
+ if ignore_order:
102
+ return True
103
+ else:
104
+ return edge1["order"] == edge2["order"]
105
+
106
+ def masked_softmax(logit, mask):
107
+ ''' compute a probability distribution from logit
108
+
109
+ Parameters
110
+ ----------
111
+ logit : array-like, length D
112
+ each element indicates how each dimension is likely to be chosen
113
+ (the larger, the more likely)
114
+ mask : array-like, length D
115
+ each element is either 0 or 1.
116
+ if 0, the dimension is ignored
117
+ when computing the probability distribution.
118
+
119
+ Returns
120
+ -------
121
+ prob_dist : array, length D
122
+ probability distribution computed from logit.
123
+ if `mask[d] = 0`, `prob_dist[d] = 0`.
124
+ '''
125
+ if logit.shape != mask.shape:
126
+ raise ValueError('logit and mask must have the same shape')
127
+ c = np.max(logit)
128
+ exp_logit = np.exp(logit - c) * mask
129
+ sum_exp_logit = exp_logit @ mask
130
+ return exp_logit / sum_exp_logit
models/mhg_model/graph_grammar/hypergraph.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 31 2018"
20
+
21
+ from copy import deepcopy
22
+ from typing import List, Dict, Tuple
23
+ import networkx as nx
24
+ import numpy as np
25
+ import os
26
+
27
+
28
+ class Hypergraph(object):
29
+ '''
30
+ A class of a hypergraph.
31
+ Each hyperedge can be ordered. For the ordered case,
32
+ edges adjacent to the hyperedge node are labeled by their orders.
33
+
34
+ Attributes
35
+ ----------
36
+ hg : nx.Graph
37
+ a bipartite graph representation of a hypergraph
38
+ edge_idx : int
39
+ total number of hyperedges that exist so far
40
+ '''
41
+ def __init__(self):
42
+ self.hg = nx.Graph()
43
+ self.edge_idx = 0
44
+ self.nodes = set([])
45
+ self.num_nodes = 0
46
+ self.edges = set([])
47
+ self.num_edges = 0
48
+ self.nodes_in_edge_dict = {}
49
+
50
+ def add_node(self, node: str, attr_dict=None):
51
+ ''' add a node to hypergraph
52
+
53
+ Parameters
54
+ ----------
55
+ node : str
56
+ node name
57
+ attr_dict : dict
58
+ dictionary of node attributes
59
+ '''
60
+ self.hg.add_node(node, bipartite='node', attr_dict=attr_dict)
61
+ if node not in self.nodes:
62
+ self.num_nodes += 1
63
+ self.nodes.add(node)
64
+
65
+ def add_edge(self, node_list: List[str], attr_dict=None, edge_name=None):
66
+ ''' add an edge consisting of nodes `node_list`
67
+
68
+ Parameters
69
+ ----------
70
+ node_list : list
71
+ ordered list of nodes that consist the edge
72
+ attr_dict : dict
73
+ dictionary of edge attributes
74
+ '''
75
+ if edge_name is None:
76
+ edge = 'e{}'.format(self.edge_idx)
77
+ else:
78
+ assert edge_name not in self.edges
79
+ edge = edge_name
80
+ self.hg.add_node(edge, bipartite='edge', attr_dict=attr_dict)
81
+ if edge not in self.edges:
82
+ self.num_edges += 1
83
+ self.edges.add(edge)
84
+ self.nodes_in_edge_dict[edge] = node_list
85
+ if type(node_list) == list:
86
+ for node_idx, each_node in enumerate(node_list):
87
+ self.hg.add_edge(edge, each_node, order=node_idx)
88
+ if each_node not in self.nodes:
89
+ self.num_nodes += 1
90
+ self.nodes.add(each_node)
91
+
92
+ elif type(node_list) == set:
93
+ for each_node in node_list:
94
+ self.hg.add_edge(edge, each_node, order=-1)
95
+ if each_node not in self.nodes:
96
+ self.num_nodes += 1
97
+ self.nodes.add(each_node)
98
+ else:
99
+ raise ValueError
100
+ self.edge_idx += 1
101
+ return edge
102
+
103
+ def remove_node(self, node: str, remove_connected_edges=True):
104
+ ''' remove a node
105
+
106
+ Parameters
107
+ ----------
108
+ node : str
109
+ node name
110
+ remove_connected_edges : bool
111
+ if True, remove edges that are adjacent to the node
112
+ '''
113
+ if remove_connected_edges:
114
+ connected_edges = deepcopy(self.adj_edges(node))
115
+ for each_edge in connected_edges:
116
+ self.remove_edge(each_edge)
117
+ self.hg.remove_node(node)
118
+ self.num_nodes -= 1
119
+ self.nodes.remove(node)
120
+
121
+ def remove_nodes(self, node_iter, remove_connected_edges=True):
122
+ ''' remove a set of nodes
123
+
124
+ Parameters
125
+ ----------
126
+ node_iter : iterator of strings
127
+ nodes to be removed
128
+ remove_connected_edges : bool
129
+ if True, remove edges that are adjacent to the node
130
+ '''
131
+ for each_node in node_iter:
132
+ self.remove_node(each_node, remove_connected_edges)
133
+
134
+ def remove_edge(self, edge: str):
135
+ ''' remove an edge
136
+
137
+ Parameters
138
+ ----------
139
+ edge : str
140
+ edge to be removed
141
+ '''
142
+ self.hg.remove_node(edge)
143
+ self.edges.remove(edge)
144
+ self.num_edges -= 1
145
+ self.nodes_in_edge_dict.pop(edge)
146
+
147
+ def remove_edges(self, edge_iter):
148
+ ''' remove a set of edges
149
+
150
+ Parameters
151
+ ----------
152
+ edge_iter : iterator of strings
153
+ edges to be removed
154
+ '''
155
+ for each_edge in edge_iter:
156
+ self.remove_edge(each_edge)
157
+
158
+ def remove_edges_with_attr(self, edge_attr_dict):
159
+ remove_edge_list = []
160
+ for each_edge in self.edges:
161
+ satisfy = True
162
+ for each_key, each_val in edge_attr_dict.items():
163
+ if not satisfy:
164
+ break
165
+ try:
166
+ if self.edge_attr(each_edge)[each_key] != each_val:
167
+ satisfy = False
168
+ except KeyError:
169
+ satisfy = False
170
+ if satisfy:
171
+ remove_edge_list.append(each_edge)
172
+ self.remove_edges(remove_edge_list)
173
+
174
+ def remove_subhg(self, subhg):
175
+ ''' remove subhypergraph.
176
+ all of the hyperedges are removed.
177
+ each node of subhg is removed if its degree becomes 0 after removing hyperedges.
178
+
179
+ Parameters
180
+ ----------
181
+ subhg : Hypergraph
182
+ '''
183
+ for each_edge in subhg.edges:
184
+ self.remove_edge(each_edge)
185
+ for each_node in subhg.nodes:
186
+ if self.degree(each_node) == 0:
187
+ self.remove_node(each_node)
188
+
189
+ def nodes_in_edge(self, edge):
190
+ ''' return an ordered list of nodes in a given edge.
191
+
192
+ Parameters
193
+ ----------
194
+ edge : str
195
+ edge whose nodes are returned
196
+
197
+ Returns
198
+ -------
199
+ list or set
200
+ ordered list or set of nodes that belong to the edge
201
+ '''
202
+ if edge.startswith('e'):
203
+ return self.nodes_in_edge_dict[edge]
204
+ else:
205
+ adj_node_list = self.hg.adj[edge]
206
+ adj_node_order_list = []
207
+ adj_node_name_list = []
208
+ for each_node in adj_node_list:
209
+ adj_node_order_list.append(adj_node_list[each_node]['order'])
210
+ adj_node_name_list.append(each_node)
211
+ if adj_node_order_list == [-1] * len(adj_node_order_list):
212
+ return set(adj_node_name_list)
213
+ else:
214
+ return [adj_node_name_list[each_idx] for each_idx
215
+ in np.argsort(adj_node_order_list)]
216
+
217
+ def adj_edges(self, node):
218
+ ''' return a dict of adjacent hyperedges
219
+
220
+ Parameters
221
+ ----------
222
+ node : str
223
+
224
+ Returns
225
+ -------
226
+ set
227
+ set of edges that are adjacent to `node`
228
+ '''
229
+ return self.hg.adj[node]
230
+
231
+ def adj_nodes(self, node):
232
+ ''' return a set of adjacent nodes
233
+
234
+ Parameters
235
+ ----------
236
+ node : str
237
+
238
+ Returns
239
+ -------
240
+ set
241
+ set of nodes that are adjacent to `node`
242
+ '''
243
+ node_set = set([])
244
+ for each_adj_edge in self.adj_edges(node):
245
+ node_set.update(set(self.nodes_in_edge(each_adj_edge)))
246
+ node_set.discard(node)
247
+ return node_set
248
+
249
+ def has_edge(self, node_list, ignore_order=False):
250
+ for each_edge in self.edges:
251
+ if ignore_order:
252
+ if set(self.nodes_in_edge(each_edge)) == set(node_list):
253
+ return each_edge
254
+ else:
255
+ if self.nodes_in_edge(each_edge) == node_list:
256
+ return each_edge
257
+ return False
258
+
259
+ def degree(self, node):
260
+ return len(self.hg.adj[node])
261
+
262
+ def degrees(self):
263
+ return {each_node: self.degree(each_node) for each_node in self.nodes}
264
+
265
+ def edge_degree(self, edge):
266
+ return len(self.nodes_in_edge(edge))
267
+
268
+ def edge_degrees(self):
269
+ return {each_edge: self.edge_degree(each_edge) for each_edge in self.edges}
270
+
271
+ def is_adj(self, node1, node2):
272
+ return node1 in self.adj_nodes(node2)
273
+
274
+ def adj_subhg(self, node, ident_node_dict=None):
275
+ """ return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
276
+ if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
277
+
278
+ Parameters
279
+ ----------
280
+ node : str
281
+ ident_node_dict : dict
282
+ dict containing identical nodes. see `get_identical_node_dict` for more details
283
+
284
+ Returns
285
+ -------
286
+ subhg : Hypergraph
287
+ """
288
+ if ident_node_dict is None:
289
+ ident_node_dict = self.get_identical_node_dict()
290
+ adj_node_set = set(ident_node_dict[node])
291
+ adj_edge_set = set([])
292
+ for each_node in ident_node_dict[node]:
293
+ adj_edge_set.update(set(self.adj_edges(each_node)))
294
+ fixed_adj_edge_set = deepcopy(adj_edge_set)
295
+ for each_edge in fixed_adj_edge_set:
296
+ other_nodes = self.nodes_in_edge(each_edge)
297
+ adj_node_set.update(other_nodes)
298
+
299
+ # if the adjacent node has self-loop edge, it will be appended to adj_edge_list.
300
+ for each_node in other_nodes:
301
+ for other_edge in set(self.adj_edges(each_node)) - set([each_edge]):
302
+ if len(set(self.nodes_in_edge(other_edge)) \
303
+ - set(self.nodes_in_edge(each_edge))) == 0:
304
+ adj_edge_set.update(set([other_edge]))
305
+ subhg = Hypergraph()
306
+ for each_node in adj_node_set:
307
+ subhg.add_node(each_node, attr_dict=self.node_attr(each_node))
308
+ for each_edge in adj_edge_set:
309
+ subhg.add_edge(self.nodes_in_edge(each_edge),
310
+ attr_dict=self.edge_attr(each_edge),
311
+ edge_name=each_edge)
312
+ subhg.edge_idx = self.edge_idx
313
+ return subhg
314
+
315
+ def get_subhg(self, node_list, edge_list, ident_node_dict=None):
316
+ """ return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
317
+ if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
318
+
319
+ Parameters
320
+ ----------
321
+ node : str
322
+ ident_node_dict : dict
323
+ dict containing identical nodes. see `get_identical_node_dict` for more details
324
+
325
+ Returns
326
+ -------
327
+ subhg : Hypergraph
328
+ """
329
+ if ident_node_dict is None:
330
+ ident_node_dict = self.get_identical_node_dict()
331
+ adj_node_set = set([])
332
+ for each_node in node_list:
333
+ adj_node_set.update(set(ident_node_dict[each_node]))
334
+ adj_edge_set = set(edge_list)
335
+
336
+ subhg = Hypergraph()
337
+ for each_node in adj_node_set:
338
+ subhg.add_node(each_node,
339
+ attr_dict=deepcopy(self.node_attr(each_node)))
340
+ for each_edge in adj_edge_set:
341
+ subhg.add_edge(self.nodes_in_edge(each_edge),
342
+ attr_dict=deepcopy(self.edge_attr(each_edge)),
343
+ edge_name=each_edge)
344
+ subhg.edge_idx = self.edge_idx
345
+ return subhg
346
+
347
+ def copy(self):
348
+ ''' return a copy of the object
349
+
350
+ Returns
351
+ -------
352
+ Hypergraph
353
+ '''
354
+ return deepcopy(self)
355
+
356
+ def node_attr(self, node):
357
+ return self.hg.nodes[node]['attr_dict']
358
+
359
+ def edge_attr(self, edge):
360
+ return self.hg.nodes[edge]['attr_dict']
361
+
362
+ def set_node_attr(self, node, attr_dict):
363
+ for each_key, each_val in attr_dict.items():
364
+ self.hg.nodes[node]['attr_dict'][each_key] = each_val
365
+
366
+ def set_edge_attr(self, edge, attr_dict):
367
+ for each_key, each_val in attr_dict.items():
368
+ self.hg.nodes[edge]['attr_dict'][each_key] = each_val
369
+
370
+ def get_identical_node_dict(self):
371
+ ''' get identical nodes
372
+ nodes are identical if they share the same set of adjacent edges.
373
+
374
+ Returns
375
+ -------
376
+ ident_node_dict : dict
377
+ ident_node_dict[node] returns a list of nodes that are identical to `node`.
378
+ '''
379
+ ident_node_dict = {}
380
+ for each_node in self.nodes:
381
+ ident_node_list = []
382
+ for each_other_node in self.nodes:
383
+ if each_other_node == each_node:
384
+ ident_node_list.append(each_other_node)
385
+ elif self.adj_edges(each_node) == self.adj_edges(each_other_node) \
386
+ and len(self.adj_edges(each_node)) != 0:
387
+ ident_node_list.append(each_other_node)
388
+ ident_node_dict[each_node] = ident_node_list
389
+ return ident_node_dict
390
+ '''
391
+ ident_node_dict = {}
392
+ for each_node in self.nodes:
393
+ ident_node_dict[each_node] = [each_node]
394
+ return ident_node_dict
395
+ '''
396
+
397
+ def get_leaf_edge(self):
398
+ ''' get an edge that is incident only to one edge
399
+
400
+ Returns
401
+ -------
402
+ if exists, return a leaf edge. otherwise, return None.
403
+ '''
404
+ for each_edge in self.edges:
405
+ if len(self.adj_nodes(each_edge)) == 1:
406
+ if 'tmp' not in self.edge_attr(each_edge):
407
+ return each_edge
408
+ return None
409
+
410
+ def get_nontmp_edge(self):
411
+ for each_edge in self.edges:
412
+ if 'tmp' not in self.edge_attr(each_edge):
413
+ return each_edge
414
+ return None
415
+
416
+ def is_subhg(self, hg):
417
+ ''' return whether this hypergraph is a subhypergraph of `hg`
418
+
419
+ Returns
420
+ -------
421
+ True if self \in hg,
422
+ False otherwise.
423
+ '''
424
+ for each_node in self.nodes:
425
+ if each_node not in hg.nodes:
426
+ return False
427
+ for each_edge in self.edges:
428
+ if each_edge not in hg.edges:
429
+ return False
430
+ return True
431
+
432
+ def in_cycle(self, node, visited=None, parent='', root_node='') -> bool:
433
+ ''' if `node` is in a cycle, then return True. otherwise, False.
434
+
435
+ Parameters
436
+ ----------
437
+ node : str
438
+ node in a hypergraph
439
+ visited : list
440
+ list of visited nodes, used for recursion
441
+ parent : str
442
+ parent node, used to eliminate a cycle consisting of two nodes and one edge.
443
+
444
+ Returns
445
+ -------
446
+ bool
447
+ '''
448
+ if visited is None:
449
+ visited = []
450
+ if parent == '':
451
+ visited = []
452
+ if root_node == '':
453
+ root_node = node
454
+ visited.append(node)
455
+ for each_adj_node in self.adj_nodes(node):
456
+ if each_adj_node not in visited:
457
+ if self.in_cycle(each_adj_node, visited, node, root_node):
458
+ return True
459
+ elif each_adj_node != parent and each_adj_node == root_node:
460
+ return True
461
+ return False
462
+
463
+
464
+ def draw(self, file_path=None, with_node=False, with_edge_name=False):
465
+ ''' draw hypergraph
466
+ '''
467
+ import graphviz
468
+ G = graphviz.Graph(format='png')
469
+ for each_node in self.nodes:
470
+ if 'ext_id' in self.node_attr(each_node):
471
+ G.node(each_node, label='',
472
+ shape='circle', width='0.1', height='0.1', style='filled',
473
+ fillcolor='black')
474
+ else:
475
+ if with_node:
476
+ G.node(each_node, label='',
477
+ shape='circle', width='0.1', height='0.1', style='filled',
478
+ fillcolor='gray')
479
+ edge_list = []
480
+ for each_edge in self.edges:
481
+ if self.edge_attr(each_edge).get('terminal', False):
482
+ G.node(each_edge,
483
+ label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
484
+ else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
485
+ fontcolor='black', shape='square')
486
+ elif self.edge_attr(each_edge).get('tmp', False):
487
+ G.node(each_edge, label='tmp' if not with_edge_name else 'tmp, ' + each_edge,
488
+ fontcolor='black', shape='square')
489
+ else:
490
+ G.node(each_edge,
491
+ label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
492
+ else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
493
+ fontcolor='black', shape='square', style='filled')
494
+ if with_node:
495
+ for each_node in self.nodes_in_edge(each_edge):
496
+ G.edge(each_edge, each_node)
497
+ else:
498
+ for each_node in self.nodes_in_edge(each_edge):
499
+ if 'ext_id' in self.node_attr(each_node)\
500
+ and set([each_node, each_edge]) not in edge_list:
501
+ G.edge(each_edge, each_node)
502
+ edge_list.append(set([each_node, each_edge]))
503
+ for each_other_edge in self.adj_nodes(each_edge):
504
+ if set([each_edge, each_other_edge]) not in edge_list:
505
+ num_bond = 0
506
+ common_node_set = set(self.nodes_in_edge(each_edge))\
507
+ .intersection(set(self.nodes_in_edge(each_other_edge)))
508
+ for each_node in common_node_set:
509
+ if self.node_attr(each_node)['symbol'].bond_type in [1, 2, 3]:
510
+ num_bond += self.node_attr(each_node)['symbol'].bond_type
511
+ elif self.node_attr(each_node)['symbol'].bond_type in [12]:
512
+ num_bond += 1
513
+ else:
514
+ raise NotImplementedError('unsupported bond type')
515
+ for _ in range(num_bond):
516
+ G.edge(each_edge, each_other_edge)
517
+ edge_list.append(set([each_edge, each_other_edge]))
518
+ if file_path is not None:
519
+ G.render(file_path, cleanup=True)
520
+ #os.remove(file_path)
521
+ return G
522
+
523
+ def is_dividable(self, node):
524
+ _hg = deepcopy(self.hg)
525
+ _hg.remove_node(node)
526
+ return (not nx.is_connected(_hg))
527
+
528
+ def divide(self, node):
529
+ subhg_list = []
530
+
531
+ hg_wo_node = deepcopy(self)
532
+ hg_wo_node.remove_node(node, remove_connected_edges=False)
533
+ connected_components = nx.connected_components(hg_wo_node.hg)
534
+ for each_component in connected_components:
535
+ node_list = [node]
536
+ edge_list = []
537
+ node_list.extend([each_node for each_node in each_component
538
+ if each_node.startswith('bond_')])
539
+ edge_list.extend([each_edge for each_edge in each_component
540
+ if each_edge.startswith('e')])
541
+ subhg_list.append(self.get_subhg(node_list, edge_list))
542
+ #subhg_list[-1].set_node_attr(node, {'divided': True})
543
+ return subhg_list
544
+
models/mhg_model/graph_grammar/io/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
models/mhg_model/graph_grammar/io/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (679 Bytes). View file
 
models/mhg_model/graph_grammar/io/__pycache__/smi.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
models/mhg_model/graph_grammar/io/smi.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 12 2018"
20
+
21
+ from copy import deepcopy
22
+ from rdkit import Chem
23
+ from rdkit import RDLogger
24
+ import networkx as nx
25
+ import numpy as np
26
+ from ..hypergraph import Hypergraph
27
+ from ..graph_grammar.symbols import TSymbol, BondSymbol
28
+
29
+ # supress warnings
30
+ lg = RDLogger.logger()
31
+ lg.setLevel(RDLogger.CRITICAL)
32
+
33
+
34
+ class HGGen(object):
35
+ """
36
+ load .smi file and yield a hypergraph.
37
+
38
+ Attributes
39
+ ----------
40
+ path_to_file : str
41
+ path to .smi file
42
+ kekulize : bool
43
+ kekulize or not
44
+ add_Hs : bool
45
+ add implicit hydrogens to the molecule or not.
46
+ all_single : bool
47
+ if True, all multiple bonds are summarized into a single bond with some attributes
48
+
49
+ Yields
50
+ ------
51
+ Hypergraph
52
+ """
53
+ def __init__(self, path_to_file, kekulize=True, add_Hs=False, all_single=True):
54
+ self.num_line = 1
55
+ self.mol_gen = Chem.SmilesMolSupplier(path_to_file, titleLine=False)
56
+ self.kekulize = kekulize
57
+ self.add_Hs = add_Hs
58
+ self.all_single = all_single
59
+
60
+ def __iter__(self):
61
+ return self
62
+
63
+ def __next__(self):
64
+ '''
65
+ each_mol = None
66
+ while each_mol is None:
67
+ each_mol = next(self.mol_gen)
68
+ '''
69
+ # not ignoring parse errors
70
+ each_mol = next(self.mol_gen)
71
+ if each_mol is None:
72
+ raise ValueError(f'incorrect smiles in line {self.num_line}')
73
+ else:
74
+ self.num_line += 1
75
+ return mol_to_hg(each_mol, self.kekulize, self.add_Hs)
76
+
77
+
78
+ def mol_to_bipartite(mol, kekulize):
79
+ """
80
+ get a bipartite representation of a molecule.
81
+
82
+ Parameters
83
+ ----------
84
+ mol : rdkit.Chem.rdchem.Mol
85
+ molecule object
86
+
87
+ Returns
88
+ -------
89
+ nx.Graph
90
+ a bipartite graph representing which bond is connected to which atoms.
91
+ """
92
+ try:
93
+ mol = standardize_stereo(mol)
94
+ except KeyError:
95
+ print(Chem.MolToSmiles(mol))
96
+ raise KeyError
97
+
98
+ if kekulize:
99
+ Chem.Kekulize(mol)
100
+
101
+ bipartite_g = nx.Graph()
102
+ for each_atom in mol.GetAtoms():
103
+ bipartite_g.add_node(f"atom_{each_atom.GetIdx()}",
104
+ atom_attr=atom_attr(each_atom, kekulize))
105
+
106
+ for each_bond in mol.GetBonds():
107
+ bond_idx = each_bond.GetIdx()
108
+ bipartite_g.add_node(
109
+ f"bond_{bond_idx}",
110
+ bond_attr=bond_attr(each_bond, kekulize))
111
+ bipartite_g.add_edge(
112
+ f"atom_{each_bond.GetBeginAtomIdx()}",
113
+ f"bond_{bond_idx}")
114
+ bipartite_g.add_edge(
115
+ f"atom_{each_bond.GetEndAtomIdx()}",
116
+ f"bond_{bond_idx}")
117
+ return bipartite_g
118
+
119
+
120
+ def mol_to_hg(mol, kekulize, add_Hs):
121
+ """
122
+ get a bipartite representation of a molecule.
123
+
124
+ Parameters
125
+ ----------
126
+ mol : rdkit.Chem.rdchem.Mol
127
+ molecule object
128
+ kekulize : bool
129
+ kekulize or not
130
+ add_Hs : bool
131
+ add implicit hydrogens to the molecule or not.
132
+
133
+ Returns
134
+ -------
135
+ Hypergraph
136
+ """
137
+ if add_Hs:
138
+ mol = Chem.AddHs(mol)
139
+
140
+ if kekulize:
141
+ Chem.Kekulize(mol)
142
+
143
+ bipartite_g = mol_to_bipartite(mol, kekulize)
144
+ hg = Hypergraph()
145
+ for each_atom in [each_node for each_node in bipartite_g.nodes()
146
+ if each_node.startswith('atom_')]:
147
+ node_set = set([])
148
+ for each_bond in bipartite_g.adj[each_atom]:
149
+ hg.add_node(each_bond,
150
+ attr_dict=bipartite_g.nodes[each_bond]['bond_attr'])
151
+ node_set.add(each_bond)
152
+ hg.add_edge(node_set,
153
+ attr_dict=bipartite_g.nodes[each_atom]['atom_attr'])
154
+ return hg
155
+
156
+
157
+ def hg_to_mol(hg, verbose=False):
158
+ """ convert a hypergraph into Mol object
159
+
160
+ Parameters
161
+ ----------
162
+ hg : Hypergraph
163
+
164
+ Returns
165
+ -------
166
+ mol : Chem.RWMol
167
+ """
168
+ mol = Chem.RWMol()
169
+ atom_dict = {}
170
+ bond_set = set([])
171
+ for each_edge in hg.edges:
172
+ atom = Chem.Atom(hg.edge_attr(each_edge)['symbol'].symbol)
173
+ atom.SetNumExplicitHs(hg.edge_attr(each_edge)['symbol'].num_explicit_Hs)
174
+ atom.SetFormalCharge(hg.edge_attr(each_edge)['symbol'].formal_charge)
175
+ atom.SetChiralTag(
176
+ Chem.rdchem.ChiralType.values[
177
+ hg.edge_attr(each_edge)['symbol'].chirality])
178
+ atom_idx = mol.AddAtom(atom)
179
+ atom_dict[each_edge] = atom_idx
180
+
181
+ for each_node in hg.nodes:
182
+ edge_1, edge_2 = hg.adj_edges(each_node)
183
+ if edge_1+edge_2 not in bond_set:
184
+ if hg.node_attr(each_node)['symbol'].bond_type <= 3:
185
+ num_bond = hg.node_attr(each_node)['symbol'].bond_type
186
+ elif hg.node_attr(each_node)['symbol'].bond_type == 12:
187
+ num_bond = 1
188
+ else:
189
+ raise ValueError(f'too many bonds; {hg.node_attr(each_node)["bond_symbol"].bond_type}')
190
+ _ = mol.AddBond(atom_dict[edge_1],
191
+ atom_dict[edge_2],
192
+ order=Chem.rdchem.BondType.values[num_bond])
193
+ bond_idx = mol.GetBondBetweenAtoms(atom_dict[edge_1], atom_dict[edge_2]).GetIdx()
194
+
195
+ # stereo
196
+ mol.GetBondWithIdx(bond_idx).SetStereo(
197
+ Chem.rdchem.BondStereo.values[hg.node_attr(each_node)['symbol'].stereo])
198
+ bond_set.update([edge_1+edge_2])
199
+ bond_set.update([edge_2+edge_1])
200
+ mol.UpdatePropertyCache()
201
+ mol = mol.GetMol()
202
+ not_stereo_mol = deepcopy(mol)
203
+ if Chem.MolFromSmiles(Chem.MolToSmiles(not_stereo_mol)) is None:
204
+ raise RuntimeError('no valid molecule was obtained.')
205
+ try:
206
+ mol = set_stereo(mol)
207
+ is_stereo = True
208
+ except:
209
+ import traceback
210
+ traceback.print_exc()
211
+ is_stereo = False
212
+ mol_tmp = deepcopy(mol)
213
+ Chem.SetAromaticity(mol_tmp)
214
+ if Chem.MolFromSmiles(Chem.MolToSmiles(mol_tmp)) is not None:
215
+ mol = mol_tmp
216
+ else:
217
+ if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is None:
218
+ mol = not_stereo_mol
219
+ mol.UpdatePropertyCache()
220
+ Chem.GetSymmSSSR(mol)
221
+ mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
222
+ if verbose:
223
+ return mol, is_stereo
224
+ else:
225
+ return mol
226
+
227
+ def hgs_to_mols(hg_list, ignore_error=False):
228
+ if ignore_error:
229
+ mol_list = []
230
+ for each_hg in hg_list:
231
+ try:
232
+ mol = hg_to_mol(each_hg)
233
+ except:
234
+ mol = None
235
+ mol_list.append(mol)
236
+ else:
237
+ mol_list = [hg_to_mol(each_hg) for each_hg in hg_list]
238
+ return mol_list
239
+
240
+ def hgs_to_smiles(hg_list, ignore_error=False):
241
+ mol_list = hgs_to_mols(hg_list, ignore_error)
242
+ smiles_list = []
243
+ for each_mol in mol_list:
244
+ try:
245
+ smiles_list.append(
246
+ Chem.MolToSmiles(
247
+ Chem.MolFromSmiles(
248
+ Chem.MolToSmiles(
249
+ each_mol))))
250
+ except:
251
+ smiles_list.append(None)
252
+ return smiles_list
253
+
254
+ def atom_attr(atom, kekulize):
255
+ """
256
+ get atom's attributes
257
+
258
+ Parameters
259
+ ----------
260
+ atom : rdkit.Chem.rdchem.Atom
261
+ kekulize : bool
262
+ kekulize or not
263
+
264
+ Returns
265
+ -------
266
+ atom_attr : dict
267
+ "is_aromatic" : bool
268
+ the atom is aromatic or not.
269
+ "smarts" : str
270
+ SMARTS representation of the atom.
271
+ """
272
+ if kekulize:
273
+ return {'terminal': True,
274
+ 'is_in_ring': atom.IsInRing(),
275
+ 'symbol': TSymbol(degree=0,
276
+ #degree=atom.GetTotalDegree(),
277
+ is_aromatic=False,
278
+ symbol=atom.GetSymbol(),
279
+ num_explicit_Hs=atom.GetNumExplicitHs(),
280
+ formal_charge=atom.GetFormalCharge(),
281
+ chirality=atom.GetChiralTag().real
282
+ )}
283
+ else:
284
+ return {'terminal': True,
285
+ 'is_in_ring': atom.IsInRing(),
286
+ 'symbol': TSymbol(degree=0,
287
+ #degree=atom.GetTotalDegree(),
288
+ is_aromatic=atom.GetIsAromatic(),
289
+ symbol=atom.GetSymbol(),
290
+ num_explicit_Hs=atom.GetNumExplicitHs(),
291
+ formal_charge=atom.GetFormalCharge(),
292
+ chirality=atom.GetChiralTag().real
293
+ )}
294
+
295
+ def bond_attr(bond, kekulize):
296
+ """
297
+ get atom's attributes
298
+
299
+ Parameters
300
+ ----------
301
+ bond : rdkit.Chem.rdchem.Bond
302
+ kekulize : bool
303
+ kekulize or not
304
+
305
+ Returns
306
+ -------
307
+ bond_attr : dict
308
+ "bond_type" : int
309
+ {0: rdkit.Chem.rdchem.BondType.UNSPECIFIED,
310
+ 1: rdkit.Chem.rdchem.BondType.SINGLE,
311
+ 2: rdkit.Chem.rdchem.BondType.DOUBLE,
312
+ 3: rdkit.Chem.rdchem.BondType.TRIPLE,
313
+ 4: rdkit.Chem.rdchem.BondType.QUADRUPLE,
314
+ 5: rdkit.Chem.rdchem.BondType.QUINTUPLE,
315
+ 6: rdkit.Chem.rdchem.BondType.HEXTUPLE,
316
+ 7: rdkit.Chem.rdchem.BondType.ONEANDAHALF,
317
+ 8: rdkit.Chem.rdchem.BondType.TWOANDAHALF,
318
+ 9: rdkit.Chem.rdchem.BondType.THREEANDAHALF,
319
+ 10: rdkit.Chem.rdchem.BondType.FOURANDAHALF,
320
+ 11: rdkit.Chem.rdchem.BondType.FIVEANDAHALF,
321
+ 12: rdkit.Chem.rdchem.BondType.AROMATIC,
322
+ 13: rdkit.Chem.rdchem.BondType.IONIC,
323
+ 14: rdkit.Chem.rdchem.BondType.HYDROGEN,
324
+ 15: rdkit.Chem.rdchem.BondType.THREECENTER,
325
+ 16: rdkit.Chem.rdchem.BondType.DATIVEONE,
326
+ 17: rdkit.Chem.rdchem.BondType.DATIVE,
327
+ 18: rdkit.Chem.rdchem.BondType.DATIVEL,
328
+ 19: rdkit.Chem.rdchem.BondType.DATIVER,
329
+ 20: rdkit.Chem.rdchem.BondType.OTHER,
330
+ 21: rdkit.Chem.rdchem.BondType.ZERO}
331
+ """
332
+ if kekulize:
333
+ is_aromatic = False
334
+ if bond.GetBondType().real == 12:
335
+ bond_type = 1
336
+ else:
337
+ bond_type = bond.GetBondType().real
338
+ else:
339
+ is_aromatic = bond.GetIsAromatic()
340
+ bond_type = bond.GetBondType().real
341
+ return {'symbol': BondSymbol(is_aromatic=is_aromatic,
342
+ bond_type=bond_type,
343
+ stereo=int(bond.GetStereo())),
344
+ 'is_in_ring': bond.IsInRing()}
345
+
346
+
347
+ def standardize_stereo(mol):
348
+ '''
349
+ 0: rdkit.Chem.rdchem.BondDir.NONE,
350
+ 1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
351
+ 2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
352
+ 3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
353
+ 4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
354
+
355
+ '''
356
+ # mol = Chem.AddHs(mol) # this removes CIPRank !!!
357
+ for each_bond in mol.GetBonds():
358
+ if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
359
+ begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
360
+ end_stereo_atom_idx = each_bond.GetEndAtomIdx()
361
+ atom_idx_1 = each_bond.GetStereoAtoms()[0]
362
+ atom_idx_2 = each_bond.GetStereoAtoms()[1]
363
+ if mol.GetBondBetweenAtoms(atom_idx_1, begin_stereo_atom_idx):
364
+ begin_atom_idx = atom_idx_1
365
+ end_atom_idx = atom_idx_2
366
+ else:
367
+ begin_atom_idx = atom_idx_2
368
+ end_atom_idx = atom_idx_1
369
+
370
+ begin_another_atom_idx = None
371
+ assert len(mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()) <= 3
372
+ for each_neighbor in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors():
373
+ each_neighbor_idx = each_neighbor.GetIdx()
374
+ if each_neighbor_idx not in [end_stereo_atom_idx, begin_atom_idx]:
375
+ begin_another_atom_idx = each_neighbor_idx
376
+
377
+ end_another_atom_idx = None
378
+ assert len(mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()) <= 3
379
+ for each_neighbor in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors():
380
+ each_neighbor_idx = each_neighbor.GetIdx()
381
+ if each_neighbor_idx not in [begin_stereo_atom_idx, end_atom_idx]:
382
+ end_another_atom_idx = each_neighbor_idx
383
+
384
+ '''
385
+ relationship between begin_atom_idx and end_atom_idx is encoded in GetStereo
386
+ '''
387
+ begin_atom_rank = int(mol.GetAtomWithIdx(begin_atom_idx).GetProp('_CIPRank'))
388
+ end_atom_rank = int(mol.GetAtomWithIdx(end_atom_idx).GetProp('_CIPRank'))
389
+ try:
390
+ begin_another_atom_rank = int(mol.GetAtomWithIdx(begin_another_atom_idx).GetProp('_CIPRank'))
391
+ except:
392
+ begin_another_atom_rank = np.inf
393
+ try:
394
+ end_another_atom_rank = int(mol.GetAtomWithIdx(end_another_atom_idx).GetProp('_CIPRank'))
395
+ except:
396
+ end_another_atom_rank = np.inf
397
+ if begin_atom_rank < begin_another_atom_rank\
398
+ and end_atom_rank < end_another_atom_rank:
399
+ pass
400
+ elif begin_atom_rank < begin_another_atom_rank\
401
+ and end_atom_rank > end_another_atom_rank:
402
+ # (begin_atom_idx +) end_another_atom_idx should be in StereoAtoms
403
+ if each_bond.GetStereo() == 2:
404
+ # set stereo
405
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
406
+ # set bond dir
407
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
408
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
409
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
410
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
411
+ elif each_bond.GetStereo() == 3:
412
+ # set stereo
413
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
414
+ # set bond dir
415
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
416
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
417
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
418
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
419
+ else:
420
+ raise ValueError
421
+ each_bond.SetStereoAtoms(begin_atom_idx, end_another_atom_idx)
422
+ elif begin_atom_rank > begin_another_atom_rank\
423
+ and end_atom_rank < end_another_atom_rank:
424
+ # (end_atom_idx +) begin_another_atom_idx should be in StereoAtoms
425
+ if each_bond.GetStereo() == 2:
426
+ # set stereo
427
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
428
+ # set bond dir
429
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
430
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
431
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
432
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
433
+ elif each_bond.GetStereo() == 3:
434
+ # set stereo
435
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
436
+ # set bond dir
437
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
438
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
439
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
440
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
441
+ else:
442
+ raise ValueError
443
+ each_bond.SetStereoAtoms(begin_another_atom_idx, end_atom_idx)
444
+ elif begin_atom_rank > begin_another_atom_rank\
445
+ and end_atom_rank > end_another_atom_rank:
446
+ # begin_another_atom_idx + end_another_atom_idx should be in StereoAtoms
447
+ if each_bond.GetStereo() == 2:
448
+ # set bond dir
449
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
450
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
451
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
452
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
453
+ elif each_bond.GetStereo() == 3:
454
+ # set bond dir
455
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
456
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
457
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
458
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
459
+ else:
460
+ raise ValueError
461
+ each_bond.SetStereoAtoms(begin_another_atom_idx, end_another_atom_idx)
462
+ else:
463
+ raise RuntimeError
464
+ return mol
465
+
466
+
467
+ def set_stereo(mol):
468
+ '''
469
+ 0: rdkit.Chem.rdchem.BondDir.NONE,
470
+ 1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
471
+ 2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
472
+ 3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
473
+ 4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
474
+ '''
475
+ _mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
476
+ Chem.Kekulize(_mol, True)
477
+ substruct_match = mol.GetSubstructMatch(_mol)
478
+ if not substruct_match:
479
+ ''' mol and _mol are kekulized.
480
+ sometimes, the order of '=' and '-' changes, which causes mol and _mol not matched.
481
+ '''
482
+ Chem.SetAromaticity(mol)
483
+ Chem.SetAromaticity(_mol)
484
+ substruct_match = mol.GetSubstructMatch(_mol)
485
+ try:
486
+ atom_match = {substruct_match[_mol_atom_idx]: _mol_atom_idx for _mol_atom_idx in range(_mol.GetNumAtoms())} # mol to _mol
487
+ except:
488
+ raise ValueError('two molecules obtained from the same data do not match.')
489
+
490
+ for each_bond in mol.GetBonds():
491
+ begin_atom_idx = each_bond.GetBeginAtomIdx()
492
+ end_atom_idx = each_bond.GetEndAtomIdx()
493
+ _bond = _mol.GetBondBetweenAtoms(atom_match[begin_atom_idx], atom_match[end_atom_idx])
494
+ _bond.SetStereo(each_bond.GetStereo())
495
+
496
+ mol = _mol
497
+ for each_bond in mol.GetBonds():
498
+ if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
499
+ begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
500
+ end_stereo_atom_idx = each_bond.GetEndAtomIdx()
501
+ begin_atom_idx_set = set([each_neighbor.GetIdx()
502
+ for each_neighbor
503
+ in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()
504
+ if each_neighbor.GetIdx() != end_stereo_atom_idx])
505
+ end_atom_idx_set = set([each_neighbor.GetIdx()
506
+ for each_neighbor
507
+ in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()
508
+ if each_neighbor.GetIdx() != begin_stereo_atom_idx])
509
+ if not begin_atom_idx_set:
510
+ each_bond.SetStereo(Chem.rdchem.BondStereo(0))
511
+ continue
512
+ if not end_atom_idx_set:
513
+ each_bond.SetStereo(Chem.rdchem.BondStereo(0))
514
+ continue
515
+ if len(begin_atom_idx_set) == 1:
516
+ begin_atom_idx = begin_atom_idx_set.pop()
517
+ begin_another_atom_idx = None
518
+ if len(end_atom_idx_set) == 1:
519
+ end_atom_idx = end_atom_idx_set.pop()
520
+ end_another_atom_idx = None
521
+ if len(begin_atom_idx_set) == 2:
522
+ atom_idx_1 = begin_atom_idx_set.pop()
523
+ atom_idx_2 = begin_atom_idx_set.pop()
524
+ if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
525
+ begin_atom_idx = atom_idx_1
526
+ begin_another_atom_idx = atom_idx_2
527
+ else:
528
+ begin_atom_idx = atom_idx_2
529
+ begin_another_atom_idx = atom_idx_1
530
+ if len(end_atom_idx_set) == 2:
531
+ atom_idx_1 = end_atom_idx_set.pop()
532
+ atom_idx_2 = end_atom_idx_set.pop()
533
+ if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
534
+ end_atom_idx = atom_idx_1
535
+ end_another_atom_idx = atom_idx_2
536
+ else:
537
+ end_atom_idx = atom_idx_2
538
+ end_another_atom_idx = atom_idx_1
539
+
540
+ if each_bond.GetStereo() == 2: # same side
541
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
542
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
543
+ each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
544
+ elif each_bond.GetStereo() == 3: # opposite side
545
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
546
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
547
+ each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
548
+ else:
549
+ raise ValueError
550
+ return mol
551
+
552
+
553
+ def safe_set_bond_dir(mol, atom_idx_1, atom_idx_2, bond_dir_val):
554
+ if atom_idx_1 is None or atom_idx_2 is None:
555
+ return mol
556
+ else:
557
+ mol.GetBondBetweenAtoms(atom_idx_1, atom_idx_2).SetBondDir(Chem.rdchem.BondDir.values[bond_dir_val])
558
+ return mol
559
+
models/mhg_model/graph_grammar/nn/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
6
+
7
+ """
8
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
9
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
10
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
11
+ """