Spaces:
Running
Running
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +4 -0
- app.py +717 -0
- data/.DS_Store +0 -0
- data/bace/test.csv +0 -0
- data/bace/train.csv +0 -0
- data/bace/valid.csv +0 -0
- data/esol/test.csv +109 -0
- data/esol/train.csv +0 -0
- img/.DS_Store +0 -0
- img/img1.png +0 -0
- img/img2.png +0 -0
- img/img3.png +0 -0
- img/img4.png +0 -0
- img/img5.png +0 -0
- img/introduction.png +0 -0
- img/latent_multi_bace.png +0 -0
- log.csv +1 -0
- models/.DS_Store +0 -0
- models/__pycache__/fm4m.cpython-310.pyc +0 -0
- models/fm4m.py +663 -0
- models/mhg_model/.DS_Store +0 -0
- models/mhg_model/README.md +75 -0
- models/mhg_model/__init__.py +5 -0
- models/mhg_model/__pycache__/__init__.cpython-310.pyc +0 -0
- models/mhg_model/__pycache__/load.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/__init__.py +19 -0
- models/mhg_model/graph_grammar/__pycache__/__init__.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/__pycache__/hypergraph.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/algo/__init__.py +20 -0
- models/mhg_model/graph_grammar/algo/__pycache__/__init__.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/algo/tree_decomposition.py +821 -0
- models/mhg_model/graph_grammar/graph_grammar/__init__.py +20 -0
- models/mhg_model/graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/graph_grammar/base.py +30 -0
- models/mhg_model/graph_grammar/graph_grammar/corpus.py +152 -0
- models/mhg_model/graph_grammar/graph_grammar/hrg.py +1065 -0
- models/mhg_model/graph_grammar/graph_grammar/symbols.py +180 -0
- models/mhg_model/graph_grammar/graph_grammar/utils.py +130 -0
- models/mhg_model/graph_grammar/hypergraph.py +544 -0
- models/mhg_model/graph_grammar/io/__init__.py +20 -0
- models/mhg_model/graph_grammar/io/__pycache__/__init__.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/io/__pycache__/smi.cpython-310.pyc +0 -0
- models/mhg_model/graph_grammar/io/smi.py +559 -0
- models/mhg_model/graph_grammar/nn/__init__.py +11 -0
README.md
CHANGED
@@ -8,6 +8,10 @@ sdk_version: 5.4.0
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
+
models:
|
12 |
+
- ibm/materials.smi-ted
|
13 |
+
- ibm/materials.selfies-ted
|
14 |
+
- ibm/materials.mhg-ged
|
15 |
---
|
16 |
|
17 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,717 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from huggingface_hub import InferenceClient
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from PIL import Image
|
5 |
+
from rdkit.Chem import Descriptors, QED, Draw
|
6 |
+
from rdkit.Chem.Crippen import MolLogP
|
7 |
+
import pandas as pd
|
8 |
+
from rdkit.Contrib.SA_Score import sascorer
|
9 |
+
from rdkit.Chem import DataStructs, AllChem
|
10 |
+
from transformers import BartForConditionalGeneration, AutoTokenizer, AutoModel
|
11 |
+
from transformers.modeling_outputs import BaseModelOutput
|
12 |
+
import selfies as sf
|
13 |
+
from rdkit import Chem
|
14 |
+
import torch
|
15 |
+
import numpy as np
|
16 |
+
import umap
|
17 |
+
import pickle
|
18 |
+
import xgboost as xgb
|
19 |
+
from sklearn.svm import SVR
|
20 |
+
from sklearn.linear_model import LinearRegression
|
21 |
+
from sklearn.kernel_ridge import KernelRidge
|
22 |
+
import json
|
23 |
+
|
24 |
+
import os
|
25 |
+
|
26 |
+
os.environ["OMP_MAX_ACTIVE_LEVELS"] = "1"
|
27 |
+
|
28 |
+
# my_theme = gr.Theme.from_hub("ysharma/steampunk")
|
29 |
+
# my_theme = gr.themes.Glass()
|
30 |
+
|
31 |
+
"""
|
32 |
+
# カスタムテーマ設定
|
33 |
+
theme = gr.themes.Default().set(
|
34 |
+
body_background_fill="#000000", # 背景色を黒に設定
|
35 |
+
text_color="#FFFFFF", # テキスト色を白に設定
|
36 |
+
)
|
37 |
+
"""
|
38 |
+
"""
|
39 |
+
import sys
|
40 |
+
sys.path.append("models")
|
41 |
+
sys.path.append("../models")
|
42 |
+
sys.path.append("../")"""
|
43 |
+
|
44 |
+
|
45 |
+
# Get the current file's directory
|
46 |
+
base_dir = os.path.dirname(__file__)
|
47 |
+
print("Base Dir : ", base_dir)
|
48 |
+
|
49 |
+
import models.fm4m as fm4m
|
50 |
+
|
51 |
+
|
52 |
+
# Function to display molecule image from SMILES
|
53 |
+
def smiles_to_image(smiles):
|
54 |
+
mol = Chem.MolFromSmiles(smiles)
|
55 |
+
if mol:
|
56 |
+
img = Draw.MolToImage(mol)
|
57 |
+
return img
|
58 |
+
return None
|
59 |
+
|
60 |
+
|
61 |
+
# Function to get canonical SMILES
|
62 |
+
def get_canonical_smiles(smiles):
|
63 |
+
mol = Chem.MolFromSmiles(smiles)
|
64 |
+
if mol:
|
65 |
+
return Chem.MolToSmiles(mol, canonical=True)
|
66 |
+
return None
|
67 |
+
|
68 |
+
|
69 |
+
# Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths)
|
70 |
+
smiles_image_mapping = {
|
71 |
+
"Mol 1": {"smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1", "image": "img/img1.png"},
|
72 |
+
# Example SMILES for ethanol
|
73 |
+
"Mol 2": {"smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1", "image": "img/img2.png"},
|
74 |
+
# Example SMILES for butane
|
75 |
+
"Mol 3": {"smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
|
76 |
+
"image": "img/img3.png"}, # Example SMILES for ethylamine
|
77 |
+
"Mol 4": {"smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1", "image": "img/img4.png"},
|
78 |
+
# Example SMILES for diethyl ether
|
79 |
+
"Mol 5": {"smiles": "C=CCS[C@@H](C)CC(=O)OCC", "image": "img/img5.png"} # Example SMILES for chloroethane
|
80 |
+
}
|
81 |
+
|
82 |
+
datasets = [" ","BACE", "ESOL", "Custom Dataset"]
|
83 |
+
|
84 |
+
models_enabled = ["SELFIES-TED", "MHG-GED", "MolFormer", "SMI-TED"]
|
85 |
+
|
86 |
+
fusion_available = ["Concat"]
|
87 |
+
|
88 |
+
global log_df
|
89 |
+
log_df = pd.DataFrame(columns=["Selected Models", "Dataset", "Task", "Result"])
|
90 |
+
|
91 |
+
|
92 |
+
def log_selection(models, dataset, task_type, result, log_df):
|
93 |
+
# Append the new entry to the DataFrame
|
94 |
+
new_entry = {"Selected Models": str(models), "Dataset": dataset, "Task": task_type, "Result": result}
|
95 |
+
updated_log_df = log_df.append(new_entry, ignore_index=True)
|
96 |
+
return updated_log_df
|
97 |
+
|
98 |
+
|
99 |
+
# Function to handle evaluation and logging
|
100 |
+
def save_rep(models, dataset, task_type, eval_output):
|
101 |
+
return
|
102 |
+
def evaluate_and_log(models, dataset, task_type, eval_output):
|
103 |
+
task_dic = {'Classification': 'CLS', 'Regression': 'RGR'}
|
104 |
+
result = f"{eval_output}"#display_eval(models, dataset, task_type, fusion_type=None)
|
105 |
+
result = result.replace(" Score", "")
|
106 |
+
|
107 |
+
new_entry = {"Selected Models": str(models), "Dataset": dataset, "Task": task_dic[task_type], "Result": result}
|
108 |
+
new_entry_df = pd.DataFrame([new_entry])
|
109 |
+
|
110 |
+
log_df = pd.read_csv('log.csv', index_col=0)
|
111 |
+
log_df = pd.concat([new_entry_df, log_df])
|
112 |
+
|
113 |
+
log_df.to_csv('log.csv')
|
114 |
+
|
115 |
+
return log_df
|
116 |
+
|
117 |
+
|
118 |
+
log_df = pd.read_csv('log.csv', index_col=0)
|
119 |
+
|
120 |
+
|
121 |
+
# Load images for selection
|
122 |
+
def load_image(path):
|
123 |
+
return Image.open(smiles_image_mapping[path]["image"])# Image.1open(path)
|
124 |
+
|
125 |
+
|
126 |
+
# Function to handle image selection
|
127 |
+
def handle_image_selection(image_key):
|
128 |
+
smiles = smiles_image_mapping[image_key]["smiles"]
|
129 |
+
mol_image = smiles_to_image(smiles)
|
130 |
+
return smiles, mol_image
|
131 |
+
|
132 |
+
|
133 |
+
def calculate_properties(smiles):
|
134 |
+
mol = Chem.MolFromSmiles(smiles)
|
135 |
+
if mol:
|
136 |
+
qed = QED.qed(mol)
|
137 |
+
logp = MolLogP(mol)
|
138 |
+
sa = sascorer.calculateScore(mol)
|
139 |
+
wt = Descriptors.MolWt(mol)
|
140 |
+
return qed, sa, logp, wt
|
141 |
+
return None, None, None, None
|
142 |
+
|
143 |
+
|
144 |
+
# Function to calculate Tanimoto similarity
|
145 |
+
def calculate_tanimoto(smiles1, smiles2):
|
146 |
+
mol1 = Chem.MolFromSmiles(smiles1)
|
147 |
+
mol2 = Chem.MolFromSmiles(smiles2)
|
148 |
+
if mol1 and mol2:
|
149 |
+
# fp1 = FingerprintMols.FingerprintMol(mol1)
|
150 |
+
# fp2 = FingerprintMols.FingerprintMol(mol2)
|
151 |
+
fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2)
|
152 |
+
fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2)
|
153 |
+
return round(DataStructs.FingerprintSimilarity(fp1, fp2), 2)
|
154 |
+
return None
|
155 |
+
|
156 |
+
|
157 |
+
#with open("models/selfies_model/bart-2908.pickle", "rb") as input_file:
|
158 |
+
# gen_model, gen_tokenizer = pickle.load(input_file)
|
159 |
+
|
160 |
+
gen_tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
|
161 |
+
gen_model = BartForConditionalGeneration.from_pretrained("ibm/materials.selfies-ted")
|
162 |
+
|
163 |
+
|
164 |
+
def generate(latent_vector, mask):
|
165 |
+
encoder_outputs = BaseModelOutput(latent_vector)
|
166 |
+
decoder_output = gen_model.generate(encoder_outputs=encoder_outputs, attention_mask=mask,
|
167 |
+
max_new_tokens=64, do_sample=True, top_k=5, top_p=0.95, num_return_sequences=1)
|
168 |
+
selfies = gen_tokenizer.batch_decode(decoder_output, skip_special_tokens=True)
|
169 |
+
outs = []
|
170 |
+
for i in selfies:
|
171 |
+
outs.append(sf.decoder(i.replace("] [", "][")))
|
172 |
+
return outs
|
173 |
+
|
174 |
+
|
175 |
+
def perturb_latent(latent_vecs, noise_scale=0.5):
|
176 |
+
modified_vec = torch.tensor(np.random.uniform(0, 1, latent_vecs.shape) * noise_scale,
|
177 |
+
dtype=torch.float32) + latent_vecs
|
178 |
+
return modified_vec
|
179 |
+
|
180 |
+
|
181 |
+
def encode(selfies):
|
182 |
+
encoding = gen_tokenizer(selfies, return_tensors='pt', max_length=128, truncation=True, padding='max_length')
|
183 |
+
input_ids = encoding['input_ids']
|
184 |
+
attention_mask = encoding['attention_mask']
|
185 |
+
outputs = gen_model.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
186 |
+
model_output = outputs.last_hidden_state
|
187 |
+
|
188 |
+
"""input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
|
189 |
+
sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
|
190 |
+
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
191 |
+
model_output = sum_embeddings / sum_mask"""
|
192 |
+
return model_output, attention_mask
|
193 |
+
|
194 |
+
|
195 |
+
# Function to generate canonical SMILES and molecule image
|
196 |
+
def generate_canonical(smiles):
|
197 |
+
s = sf.encoder(smiles)
|
198 |
+
selfie = s.replace("][", "] [")
|
199 |
+
latent_vec, mask = encode([selfie])
|
200 |
+
gen_mol = None
|
201 |
+
for i in range(5, 51):
|
202 |
+
noise = i / 10
|
203 |
+
perturbed_latent = perturb_latent(latent_vec, noise_scale=noise)
|
204 |
+
gen = generate(perturbed_latent, mask)
|
205 |
+
gen_mol = Chem.MolToSmiles(Chem.MolFromSmiles(gen[0]))
|
206 |
+
if gen_mol != Chem.MolToSmiles(Chem.MolFromSmiles(smiles)): break
|
207 |
+
|
208 |
+
if gen_mol:
|
209 |
+
# Calculate properties for ref and gen molecules
|
210 |
+
ref_properties = calculate_properties(smiles)
|
211 |
+
gen_properties = calculate_properties(gen_mol)
|
212 |
+
tanimoto_similarity = calculate_tanimoto(smiles, gen_mol)
|
213 |
+
|
214 |
+
# Prepare the table with ref mol and gen mol
|
215 |
+
data = {
|
216 |
+
"Property": ["QED", "SA", "LogP", "Mol Wt", "Tanimoto Similarity"],
|
217 |
+
"Reference Mol": [ref_properties[0], ref_properties[1], ref_properties[2], ref_properties[3],
|
218 |
+
tanimoto_similarity],
|
219 |
+
"Generated Mol": [gen_properties[0], gen_properties[1], gen_properties[2], gen_properties[3], ""]
|
220 |
+
}
|
221 |
+
df = pd.DataFrame(data)
|
222 |
+
|
223 |
+
# Display molecule image of canonical smiles
|
224 |
+
mol_image = smiles_to_image(gen_mol)
|
225 |
+
|
226 |
+
return df, gen_mol, mol_image
|
227 |
+
return "Invalid SMILES", None, None
|
228 |
+
|
229 |
+
|
230 |
+
# Function to display evaluation score
|
231 |
+
def display_eval(selected_models, dataset, task_type, downstream, fusion_type):
|
232 |
+
result = None
|
233 |
+
|
234 |
+
try:
|
235 |
+
downstream_model = downstream.split("*")[0].lstrip()
|
236 |
+
downstream_model = downstream_model.rstrip()
|
237 |
+
hyp_param = downstream.split("*")[-1].lstrip()
|
238 |
+
hyp_param = hyp_param.rstrip()
|
239 |
+
hyp_param = hyp_param.replace("nan", "float('nan')")
|
240 |
+
params = eval(hyp_param)
|
241 |
+
except:
|
242 |
+
downstream_model = downstream.split("*")[0].lstrip()
|
243 |
+
downstream_model = downstream_model.rstrip()
|
244 |
+
params = None
|
245 |
+
|
246 |
+
|
247 |
+
|
248 |
+
|
249 |
+
try:
|
250 |
+
if not selected_models:
|
251 |
+
return "Please select at least one enabled model."
|
252 |
+
|
253 |
+
if task_type == "Classification":
|
254 |
+
global roc_auc, fpr, tpr, x_batch, y_batch
|
255 |
+
elif task_type == "Regression":
|
256 |
+
global RMSE, y_batch_test, y_prob
|
257 |
+
|
258 |
+
if len(selected_models) > 1:
|
259 |
+
if task_type == "Classification":
|
260 |
+
#result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
|
261 |
+
# downstream_model="XGBClassifier",
|
262 |
+
# dataset=dataset.lower())
|
263 |
+
if downstream_model == "Default Settings":
|
264 |
+
downstream_model = "DefaultClassifier"
|
265 |
+
params = None
|
266 |
+
result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
|
267 |
+
downstream_model=downstream_model,
|
268 |
+
params = params,
|
269 |
+
dataset=dataset)
|
270 |
+
|
271 |
+
elif task_type == "Regression":
|
272 |
+
#result, RMSE, y_batch_test, y_prob = fm4m.multi_modal(model_list=selected_models,
|
273 |
+
# downstream_model="XGBRegressor",
|
274 |
+
# dataset=dataset.lower())
|
275 |
+
|
276 |
+
if downstream_model == "Default Settings":
|
277 |
+
downstream_model = "DefaultRegressor"
|
278 |
+
params = None
|
279 |
+
|
280 |
+
result, RMSE, y_batch_test, y_prob, x_batch, y_batch = fm4m.multi_modal(model_list=selected_models,
|
281 |
+
downstream_model=downstream_model,
|
282 |
+
params=params,
|
283 |
+
dataset=dataset)
|
284 |
+
|
285 |
+
else:
|
286 |
+
if task_type == "Classification":
|
287 |
+
#result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
|
288 |
+
# downstream_model="XGBClassifier",
|
289 |
+
# dataset=dataset.lower())
|
290 |
+
if downstream_model == "Default Settings":
|
291 |
+
downstream_model = "DefaultClassifier"
|
292 |
+
params = None
|
293 |
+
|
294 |
+
result, roc_auc, fpr, tpr, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
|
295 |
+
downstream_model=downstream_model,
|
296 |
+
params=params,
|
297 |
+
dataset=dataset)
|
298 |
+
|
299 |
+
elif task_type == "Regression":
|
300 |
+
#result, RMSE, y_batch_test, y_prob = fm4m.single_modal(model=selected_models[0],
|
301 |
+
# downstream_model="XGBRegressor",
|
302 |
+
# dataset=dataset.lower())
|
303 |
+
|
304 |
+
if downstream_model == "Default Settings":
|
305 |
+
downstream_model = "DefaultRegressor"
|
306 |
+
params = None
|
307 |
+
|
308 |
+
result, RMSE, y_batch_test, y_prob, x_batch, y_batch = fm4m.single_modal(model=selected_models[0],
|
309 |
+
downstream_model=downstream_model,
|
310 |
+
params=params,
|
311 |
+
dataset=dataset)
|
312 |
+
|
313 |
+
if result == None:
|
314 |
+
result = "Data & Model Setting is incorrect"
|
315 |
+
except Exception as e:
|
316 |
+
return f"An error occurred: {e}"
|
317 |
+
return f"{result}"
|
318 |
+
|
319 |
+
|
320 |
+
# Function to handle plot display
|
321 |
+
def display_plot(plot_type):
|
322 |
+
fig, ax = plt.subplots()
|
323 |
+
|
324 |
+
if plot_type == "Latent Space":
|
325 |
+
global x_batch, y_batch
|
326 |
+
ax.set_title("T-SNE Plot")
|
327 |
+
# reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False)
|
328 |
+
# features_umap = reducer.fit_transform(x_batch[:500])
|
329 |
+
# x = y_batch.values[:500]
|
330 |
+
# index_0 = [index for index in range(len(x)) if x[index] == 0]
|
331 |
+
# index_1 = [index for index in range(len(x)) if x[index] == 1]
|
332 |
+
class_0 = x_batch # features_umap[index_0]
|
333 |
+
class_1 = y_batch # features_umap[index_1]
|
334 |
+
|
335 |
+
"""with open("latent_multi_bace.pkl", "rb") as f:
|
336 |
+
class_0, class_1 = pickle.load(f)
|
337 |
+
"""
|
338 |
+
plt.scatter(class_1[:, 0], class_1[:, 1], c='red', label='Class 1')
|
339 |
+
plt.scatter(class_0[:, 0], class_0[:, 1], c='blue', label='Class 0')
|
340 |
+
|
341 |
+
ax.set_xlabel('Feature 1')
|
342 |
+
ax.set_ylabel('Feature 2')
|
343 |
+
ax.set_title('Dataset Distribution')
|
344 |
+
|
345 |
+
elif plot_type == "ROC-AUC":
|
346 |
+
global roc_auc, fpr, tpr
|
347 |
+
ax.set_title("ROC-AUC Curve")
|
348 |
+
try:
|
349 |
+
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.4f})')
|
350 |
+
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
351 |
+
ax.set_xlim([0.0, 1.0])
|
352 |
+
ax.set_ylim([0.0, 1.05])
|
353 |
+
except:
|
354 |
+
pass
|
355 |
+
ax.set_xlabel('False Positive Rate')
|
356 |
+
ax.set_ylabel('True Positive Rate')
|
357 |
+
ax.set_title('Receiver Operating Characteristic')
|
358 |
+
ax.legend(loc='lower right')
|
359 |
+
|
360 |
+
elif plot_type == "Parity Plot":
|
361 |
+
global RMSE, y_batch_test, y_prob
|
362 |
+
ax.set_title("Parity plot")
|
363 |
+
|
364 |
+
# change format
|
365 |
+
try:
|
366 |
+
print(y_batch_test)
|
367 |
+
print(y_prob)
|
368 |
+
y_batch_test = np.array(y_batch_test, dtype=float)
|
369 |
+
y_prob = np.array(y_prob, dtype=float)
|
370 |
+
ax.scatter(y_batch_test, y_prob, color="blue", label=f"Predicted vs Actual (RMSE: {RMSE:.4f})")
|
371 |
+
min_val = min(min(y_batch_test), min(y_prob))
|
372 |
+
max_val = max(max(y_batch_test), max(y_prob))
|
373 |
+
ax.plot([min_val, max_val], [min_val, max_val], 'r-')
|
374 |
+
|
375 |
+
except:
|
376 |
+
|
377 |
+
y_batch_test = []
|
378 |
+
y_prob = []
|
379 |
+
RMSE = None
|
380 |
+
print(y_batch_test)
|
381 |
+
print(y_prob)
|
382 |
+
|
383 |
+
|
384 |
+
|
385 |
+
|
386 |
+
|
387 |
+
ax.set_xlabel('Actual Values')
|
388 |
+
ax.set_ylabel('Predicted Values')
|
389 |
+
|
390 |
+
ax.legend(loc='lower right')
|
391 |
+
return fig
|
392 |
+
|
393 |
+
|
394 |
+
# Predefined dataset paths (these should be adjusted to your file paths)
|
395 |
+
predefined_datasets = {
|
396 |
+
"BACE": f"./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
|
397 |
+
"ESOL": f"./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
|
398 |
+
}
|
399 |
+
|
400 |
+
|
401 |
+
# Function to load a predefined dataset from the local path
|
402 |
+
def load_predefined_dataset(dataset_name):
|
403 |
+
val = predefined_datasets.get(dataset_name)
|
404 |
+
try: file_path = val.split(",")[0]
|
405 |
+
except:file_path=False
|
406 |
+
|
407 |
+
if file_path:
|
408 |
+
df = pd.read_csv(file_path)
|
409 |
+
return df.head(), gr.update(choices=list(df.columns)), gr.update(choices=list(df.columns)), f"{dataset_name.lower()}"
|
410 |
+
return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[]), f"Dataset not found"
|
411 |
+
|
412 |
+
|
413 |
+
# Function to display the head of the uploaded CSV file
|
414 |
+
def display_csv_head(file):
|
415 |
+
if file is not None:
|
416 |
+
# Load the CSV file into a DataFrame
|
417 |
+
df = pd.read_csv(file.name)
|
418 |
+
return df.head(), gr.update(choices=list(df.columns)), gr.update(choices=list(df.columns))
|
419 |
+
return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
|
420 |
+
|
421 |
+
|
422 |
+
# Function to handle dataset selection (predefined or custom)
|
423 |
+
def handle_dataset_selection(selected_dataset):
|
424 |
+
if selected_dataset == "Custom Dataset":
|
425 |
+
# Show file upload fields for train and test datasets if "Custom Dataset" is selected
|
426 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
|
427 |
+
visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
|
428 |
+
else:
|
429 |
+
#[dataset_name, train_file, train_display, test_file, test_display, predefined_display,
|
430 |
+
# input_column_selector, output_column_selector]
|
431 |
+
|
432 |
+
|
433 |
+
|
434 |
+
# Load the predefined dataset from its local path
|
435 |
+
#return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
|
436 |
+
# visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
437 |
+
#return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(
|
438 |
+
# visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
439 |
+
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(
|
440 |
+
visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
441 |
+
|
442 |
+
|
443 |
+
# Function to select input and output columns and display a message
|
444 |
+
def select_columns(input_column, output_column, train_data, test_data,dataset_name):
|
445 |
+
if input_column and output_column:
|
446 |
+
return f"{train_data.name},{test_data.name},{input_column},{output_column},{dataset_name}"
|
447 |
+
return "Please select both input and output columns."
|
448 |
+
|
449 |
+
def set_dataname(dataset_name, dataset_selector ):
|
450 |
+
if dataset_selector == "Custom Dataset":
|
451 |
+
return f"{dataset_name}"
|
452 |
+
return f"{dataset_selector}"
|
453 |
+
|
454 |
+
# Function to create model based on user input
|
455 |
+
def create_model(model_name, max_depth=None, n_estimators=None, alpha=None, degree=None, kernel=None):
|
456 |
+
if model_name == "XGBClassifier":
|
457 |
+
model = xgb.XGBClassifier(objective='binary:logistic',eval_metric= 'auc', max_depth=max_depth, n_estimators=n_estimators, alpha=alpha)
|
458 |
+
elif model_name == "SVR":
|
459 |
+
model = SVR(degree=degree, kernel=kernel)
|
460 |
+
elif model_name == "Kernel Ridge":
|
461 |
+
model = KernelRidge(alpha=alpha, degree=degree, kernel=kernel)
|
462 |
+
elif model_name == "Linear Regression":
|
463 |
+
model = LinearRegression()
|
464 |
+
elif model_name == "Default - Auto":
|
465 |
+
model = "Default Settings"
|
466 |
+
return f"{model}"
|
467 |
+
else:
|
468 |
+
return "Model not supported."
|
469 |
+
|
470 |
+
return f"{model_name} * {model.get_params()}"
|
471 |
+
def model_selector(model_name):
|
472 |
+
# Dynamically return the appropriate hyperparameter components based on the selected model
|
473 |
+
if model_name == "XGBClassifier":
|
474 |
+
return (
|
475 |
+
gr.Slider(1, 10, label="max_depth"),
|
476 |
+
gr.Slider(50, 500, label="n_estimators"),
|
477 |
+
gr.Slider(0.1, 10.0, step=0.1, label="alpha")
|
478 |
+
)
|
479 |
+
elif model_name == "SVR":
|
480 |
+
return (
|
481 |
+
gr.Slider(1, 5, label="degree"),
|
482 |
+
gr.Dropdown(["rbf", "poly", "linear"], label="kernel")
|
483 |
+
)
|
484 |
+
elif model_name == "Kernel Ridge":
|
485 |
+
return (
|
486 |
+
gr.Slider(0.1, 10.0, step=0.1, label="alpha"),
|
487 |
+
gr.Slider(1, 5, label="degree"),
|
488 |
+
gr.Dropdown(["rbf", "poly", "linear"], label="kernel")
|
489 |
+
)
|
490 |
+
elif model_name == "Linear Regression":
|
491 |
+
return () # No hyperparameters for Linear Regression
|
492 |
+
else:
|
493 |
+
return ()
|
494 |
+
|
495 |
+
|
496 |
+
|
497 |
+
# Define the Gradio layout
|
498 |
+
# with gr.Blocks(theme=my_theme) as demo:
|
499 |
+
with gr.Blocks() as demo:
|
500 |
+
with gr.Row():
|
501 |
+
# Left Column
|
502 |
+
with gr.Column():
|
503 |
+
gr.HTML('''
|
504 |
+
<div style="background-color: #6A8EAE; color: #FFFFFF; padding: 10px;">
|
505 |
+
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Data & Model Setting</h3>
|
506 |
+
</div>
|
507 |
+
''')
|
508 |
+
# gr.Markdown("## Data & Model Setting")
|
509 |
+
#dataset_dropdown = gr.Dropdown(choices=datasets, label="Select Dat")
|
510 |
+
|
511 |
+
# Dropdown menu for predefined datasets including "Custom Dataset" option
|
512 |
+
dataset_selector = gr.Dropdown(label="Select Dataset",
|
513 |
+
choices=list(predefined_datasets.keys()) + ["Custom Dataset"])
|
514 |
+
# Display the message for selected columns
|
515 |
+
selected_columns_message = gr.Textbox(label="Selected Columns Info", visible=False)
|
516 |
+
|
517 |
+
with gr.Accordion("Dataset Settings", open=True):
|
518 |
+
# File upload options for custom dataset (train and test)
|
519 |
+
dataset_name = gr.Textbox(label="Dataset Name", visible=False)
|
520 |
+
train_file = gr.File(label="Upload Custom Train Dataset", file_types=[".csv"], visible=False)
|
521 |
+
train_display = gr.Dataframe(label="Train Dataset Preview (First 5 Rows)", visible=False, interactive=False)
|
522 |
+
|
523 |
+
test_file = gr.File(label="Upload Custom Test Dataset", file_types=[".csv"], visible=False)
|
524 |
+
test_display = gr.Dataframe(label="Test Dataset Preview (First 5 Rows)", visible=False, interactive=False)
|
525 |
+
|
526 |
+
# Predefined dataset displays
|
527 |
+
predefined_display = gr.Dataframe(label="Predefined Dataset Preview (First 5 Rows)", visible=False,
|
528 |
+
interactive=False)
|
529 |
+
|
530 |
+
|
531 |
+
|
532 |
+
# Dropdowns for selecting input and output columns for the custom dataset
|
533 |
+
input_column_selector = gr.Dropdown(label="Select Input Column", choices=[], visible=False)
|
534 |
+
output_column_selector = gr.Dropdown(label="Select Output Column", choices=[], visible=False)
|
535 |
+
|
536 |
+
#selected_columns_message = gr.Textbox(label="Selected Columns Info", visible=True)
|
537 |
+
|
538 |
+
# When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
|
539 |
+
dataset_selector.change(handle_dataset_selection,
|
540 |
+
inputs=dataset_selector,
|
541 |
+
outputs=[dataset_name, train_file, train_display, test_file, test_display, predefined_display,
|
542 |
+
input_column_selector, output_column_selector])
|
543 |
+
|
544 |
+
# When a predefined dataset is selected, load its head and update column selectors
|
545 |
+
dataset_selector.change(load_predefined_dataset,
|
546 |
+
inputs=dataset_selector,
|
547 |
+
outputs=[predefined_display, input_column_selector, output_column_selector, selected_columns_message])
|
548 |
+
|
549 |
+
# When a custom train file is uploaded, display its head and update column selectors
|
550 |
+
train_file.change(display_csv_head, inputs=train_file,
|
551 |
+
outputs=[train_display, input_column_selector, output_column_selector])
|
552 |
+
|
553 |
+
# When a custom test file is uploaded, display its head
|
554 |
+
test_file.change(display_csv_head, inputs=test_file,
|
555 |
+
outputs=[test_display, input_column_selector, output_column_selector])
|
556 |
+
|
557 |
+
dataset_selector.change(set_dataname,
|
558 |
+
inputs=[dataset_name, dataset_selector],
|
559 |
+
outputs=dataset_name)
|
560 |
+
|
561 |
+
# Update the selected columns information when dropdown values are changed
|
562 |
+
input_column_selector.change(select_columns,
|
563 |
+
inputs=[input_column_selector, output_column_selector, train_file, test_file, dataset_name],
|
564 |
+
outputs=selected_columns_message)
|
565 |
+
|
566 |
+
output_column_selector.change(select_columns,
|
567 |
+
inputs=[input_column_selector, output_column_selector, train_file, test_file, dataset_name],
|
568 |
+
outputs=selected_columns_message)
|
569 |
+
|
570 |
+
model_checkbox = gr.CheckboxGroup(choices=models_enabled, label="Select Model")
|
571 |
+
|
572 |
+
# Add disabled checkboxes for GNN and FNN
|
573 |
+
# gnn_checkbox = gr.Checkbox(label="GNN (Disabled)", value=False, interactive=False)
|
574 |
+
# fnn_checkbox = gr.Checkbox(label="FNN (Disabled)", value=False, interactive=False)
|
575 |
+
|
576 |
+
task_radiobutton = gr.Radio(choices=["Classification", "Regression"], label="Task Type")
|
577 |
+
|
578 |
+
####### adding hyper parameter tuning ###########
|
579 |
+
model_name = gr.Dropdown(["Default - Auto", "XGBClassifier", "SVR", "Kernel Ridge", "Linear Regression"], label="Select Downstream Model")
|
580 |
+
with gr.Accordion("Downstream Hyperparameter Settings", open=True):
|
581 |
+
# Create placeholders for hyperparameter components
|
582 |
+
max_depth = gr.Slider(1, 20, step=1,visible=False, label="max_depth")
|
583 |
+
n_estimators = gr.Slider(100, 5000, step=100, visible=False, label="n_estimators")
|
584 |
+
alpha = gr.Slider(0.1, 10.0, step=0.1, visible=False, label="alpha")
|
585 |
+
degree = gr.Slider(1, 20, step=1,visible=False, label="degree")
|
586 |
+
kernel = gr.Dropdown(choices=["rbf", "poly", "linear"], visible=False, label="kernel")
|
587 |
+
|
588 |
+
# Output textbox
|
589 |
+
output = gr.Textbox(label="Loaded Parameters")
|
590 |
+
|
591 |
+
|
592 |
+
# Dynamically show relevant hyperparameters based on selected model
|
593 |
+
def update_hyperparameters(model_name):
|
594 |
+
if model_name == "XGBClassifier":
|
595 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
|
596 |
+
visible=False), gr.update(visible=False)
|
597 |
+
elif model_name == "SVR":
|
598 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
|
599 |
+
visible=True), gr.update(visible=True)
|
600 |
+
elif model_name == "Kernel Ridge":
|
601 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(
|
602 |
+
visible=True), gr.update(visible=True)
|
603 |
+
elif model_name == "Linear Regression":
|
604 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
|
605 |
+
visible=False), gr.update(visible=False)
|
606 |
+
elif model_name == "Default - Auto":
|
607 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(
|
608 |
+
visible=False), gr.update(visible=False)
|
609 |
+
|
610 |
+
|
611 |
+
# When model is selected, update which hyperparameters are visible
|
612 |
+
model_name.change(update_hyperparameters, inputs=[model_name],
|
613 |
+
outputs=[max_depth, n_estimators, alpha, degree, kernel])
|
614 |
+
|
615 |
+
# Submit button to create the model with selected hyperparameters
|
616 |
+
submit_button = gr.Button("Create Downstream Model")
|
617 |
+
|
618 |
+
|
619 |
+
# Function to handle model creation based on input parameters
|
620 |
+
def on_submit(model_name, max_depth, n_estimators, alpha, degree, kernel):
|
621 |
+
if model_name == "XGBClassifier":
|
622 |
+
return create_model(model_name, max_depth=max_depth, n_estimators=n_estimators, alpha=alpha)
|
623 |
+
elif model_name == "SVR":
|
624 |
+
return create_model(model_name, degree=degree, kernel=kernel)
|
625 |
+
elif model_name == "Kernel Ridge":
|
626 |
+
return create_model(model_name, alpha=alpha, degree=degree, kernel=kernel)
|
627 |
+
elif model_name == "Linear Regression":
|
628 |
+
return create_model(model_name)
|
629 |
+
elif model_name == "Default - Auto":
|
630 |
+
return create_model(model_name)
|
631 |
+
|
632 |
+
# When the submit button is clicked, run the on_submit function
|
633 |
+
submit_button.click(on_submit, inputs=[model_name, max_depth, n_estimators, alpha, degree, kernel],
|
634 |
+
outputs=output)
|
635 |
+
###### End of hyper param tuning #########
|
636 |
+
|
637 |
+
fusion_radiobutton = gr.Radio(choices=fusion_available, label="Fusion Type")
|
638 |
+
|
639 |
+
|
640 |
+
|
641 |
+
eval_button = gr.Button("Train downstream model")
|
642 |
+
#eval_button.style(css_class="custom-button-left")
|
643 |
+
|
644 |
+
# Middle Column
|
645 |
+
with gr.Column():
|
646 |
+
gr.HTML('''
|
647 |
+
<div style="background-color: #8F9779; color: #FFFFFF; padding: 10px;">
|
648 |
+
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 1: Property Prediction</h3>
|
649 |
+
</div>
|
650 |
+
''')
|
651 |
+
# gr.Markdown("## Downstream task Result")
|
652 |
+
eval_output = gr.Textbox(label="Train downstream model")
|
653 |
+
|
654 |
+
plot_radio = gr.Radio(choices=["ROC-AUC", "Parity Plot", "Latent Space"], label="Select Plot Type")
|
655 |
+
plot_output = gr.Plot(label="Visualization")#, height=250, width=250)
|
656 |
+
|
657 |
+
#download_rep = gr.Button("Download representation")
|
658 |
+
|
659 |
+
create_log = gr.Button("Store log")
|
660 |
+
|
661 |
+
log_table = gr.Dataframe(value=log_df, label="Log of Selections and Results", interactive=False)
|
662 |
+
|
663 |
+
eval_button.click(display_eval,
|
664 |
+
inputs=[model_checkbox, selected_columns_message, task_radiobutton, output, fusion_radiobutton],
|
665 |
+
outputs=eval_output)
|
666 |
+
|
667 |
+
plot_radio.change(display_plot, inputs=plot_radio, outputs=plot_output)
|
668 |
+
|
669 |
+
|
670 |
+
# Function to gather selected models
|
671 |
+
def gather_selected_models(*models):
|
672 |
+
selected = [model for model in models if model]
|
673 |
+
return selected
|
674 |
+
|
675 |
+
|
676 |
+
create_log.click(evaluate_and_log, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output],
|
677 |
+
outputs=log_table)
|
678 |
+
#download_rep.click(save_rep, inputs=[model_checkbox, dataset_name, task_radiobutton, eval_output],
|
679 |
+
# outputs=None)
|
680 |
+
|
681 |
+
# Right Column
|
682 |
+
with gr.Column():
|
683 |
+
gr.HTML('''
|
684 |
+
<div style="background-color: #D2B48C; color: #FFFFFF; padding: 10px;">
|
685 |
+
<h3 style="color: #FFFFFF; margin: 0;font-size: 20px;"> Downstream Task 2: Molecule Generation</h3>
|
686 |
+
</div>
|
687 |
+
''')
|
688 |
+
# gr.Markdown("## Molecular Generation")
|
689 |
+
smiles_input = gr.Textbox(label="Input SMILES String")
|
690 |
+
image_display = gr.Image(label="Molecule Image", height=250, width=250)
|
691 |
+
# Show images for selection
|
692 |
+
with gr.Accordion("Select from sample molecules", open=False):
|
693 |
+
image_selector = gr.Radio(
|
694 |
+
choices=list(smiles_image_mapping.keys()),
|
695 |
+
label="Select from sample molecules",
|
696 |
+
value=None,
|
697 |
+
#item_images=[load_image(smiles_image_mapping[key]["image"]) for key in smiles_image_mapping.keys()]
|
698 |
+
)
|
699 |
+
image_selector.change(load_image, image_selector, image_display)
|
700 |
+
generate_button = gr.Button("Generate")
|
701 |
+
gen_image_display = gr.Image(label="Generated Molecule Image", height=250, width=250)
|
702 |
+
generated_output = gr.Textbox(label="Generated Output")
|
703 |
+
property_table = gr.Dataframe(label="Molecular Properties Comparison")
|
704 |
+
|
705 |
+
|
706 |
+
|
707 |
+
# Handle image selection
|
708 |
+
image_selector.change(handle_image_selection, inputs=image_selector, outputs=[smiles_input, image_display])
|
709 |
+
smiles_input.change(smiles_to_image, inputs=smiles_input, outputs=image_display)
|
710 |
+
|
711 |
+
# Generate button to display canonical SMILES and molecule image
|
712 |
+
generate_button.click(generate_canonical, inputs=smiles_input,
|
713 |
+
outputs=[property_table, generated_output, gen_image_display])
|
714 |
+
|
715 |
+
|
716 |
+
if __name__ == "__main__":
|
717 |
+
demo.launch(share=True)
|
data/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
data/bace/test.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/bace/train.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/bace/valid.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/esol/test.csv
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,selfies,prop,smiles
|
2 |
+
0,[Cl] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [C] [C] [C] [C] [Branch1] [Branch2] [C] [O] [C] [Ring1] [=Branch1] [Ring1] [Ring1] [C] [Ring1] [Branch2] [C] [Ring1] [=C] [Branch1] [C] [Cl] [C] [Ring1] [=N] [Branch1] [C] [Cl] [Cl],-4.533,ClC4=C(Cl)C5(Cl)C3C1CC(C2OC12)C3C4(Cl)C5(Cl)Cl
|
3 |
+
1,[C] [C] [C] [C] [C] [=O],-1.103,CCCCC=O
|
4 |
+
2,[O] [C] [C] [C] [C] [=C],-0.7909999999999999,OCCCC=C
|
5 |
+
3,[C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [N] [N] [=C] [C] [Branch1] [C] [N] [=C] [Branch1] [C] [Br] [C] [Ring1] [Branch2] [=O],-3.005,c1ccccc1n2ncc(N)c(Br)c2(=O)
|
6 |
+
4,[N] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-1.231,Nc1ccc(O)cc1
|
7 |
+
5,[C] [C] [Branch1] [C] [C] [C] [C] [O] [C] [=Branch1] [C] [=O] [C],-1.817,CC(C)CCOC(=O)C
|
8 |
+
6,[C] [O] [P] [=Branch1] [C] [=S] [Branch1] [Ring1] [O] [C] [S] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [C] [C] [C] [=O],-2.087,COP(=S)(OC)SCC(=O)N(C)C=O
|
9 |
+
7,[Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=Branch1] [Ring2] [=C] [Ring1] [#Branch1] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-6.312,Clc1ccc(Cl)c(c1)c2ccc(Cl)c(Cl)c2
|
10 |
+
8,[C] [Branch1] [C] [Cl] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [N] [=C] [C] [=N] [C] [Ring1] [=Branch1] [=C] [Ring1] [=N] [Cl],-4.438,c2(Cl)c(Cl)c(Cl)c1nccnc1c2(Cl)
|
11 |
+
9,[C] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [N] [=C] [Branch1] [=Branch1] [N] [=C] [Ring1] [#Branch1] [O] [N] [Branch1] [C] [C] [C],-3.57,CCCCc1c(C)nc(nc1O)N(C)C
|
12 |
+
10,[C] [C] [O] [C] [=Branch1] [C] [=O] [C] [C] [=Branch1] [C] [=O] [O] [C] [C],-1.413,CCOC(=O)CC(=O)OCC
|
13 |
+
11,[C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-3.192,CC(C)(C)c1ccc(O)cc1
|
14 |
+
12,[C] [C] [=C] [C] [=C] [C] [Branch1] [C] [C] [=C] [Ring1] [#Branch1],-3.035,Cc1cccc(C)c1
|
15 |
+
13,[C] [C] [C] [O] [C] [=Branch1] [C] [=O] [C],-1.125,CCCOC(=O)C
|
16 |
+
14,[C] [S] [C] [=N] [N] [=C] [Branch1] [=Branch2] [C] [=Branch1] [C] [=O] [N] [Ring1] [#Branch1] [N] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C],-2.324,CSc1nnc(c(=O)n1N)C(C)(C)C
|
17 |
+
15,[Cl] [C] [=C] [C] [=C] [Branch1] [Branch1] [C] [=C] [Ring1] [=Branch1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [Cl],-5.142,Clc1ccc(cc1)c2ccccc2Cl
|
18 |
+
16,[C] [C] [C] [C] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [C] [Branch1] [Ring2] [C] [Ring1] [Branch2] [C] [Branch1] [C] [O] [C] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2],-1.5319999999999998,CC1CC(C)C(=O)C(C1)C(O)CC2CC(=O)NC(=O)C2
|
19 |
+
17,[C] [N] [C] [=Branch1] [C] [=O] [O] [C] [=C] [C] [=C] [C] [Branch1] [Branch2] [N] [=C] [N] [Branch1] [C] [C] [C] [=C] [Ring1] [O],-1.846,CNC(=O)Oc1cccc(N=CN(C)C)c1
|
20 |
+
18,[C] [C] [=C] [C] [=N] [C] [N] [Branch1] [=Branch1] [C] [C] [C] [Ring1] [Ring1] [C] [=N] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=Branch1] [C] [=O] [N] [C] [Ring2] [Ring1] [Ring1] [=Ring1] [#C],-3.397,Cc3ccnc4N(C1CC1)c2ncccc2C(=O)Nc34
|
21 |
+
19,[C] [C] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-2.389,CCNc1ccccc1
|
22 |
+
20,[C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [Ring2] [Ring1] [Ring1] [Ring1] [#Branch2],-6.297000000000001,Cc1c2ccccc2c(C)c3ccc4ccccc4c13
|
23 |
+
21,[F] [C] [=C] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [Branch1] [C] [Cl] [=C] [Branch1] [C] [F] [C] [Branch1] [C] [Cl] [=C] [Ring1] [=Branch2] [F],-5.462000000000001,Fc1cccc(F)c1C(=O)NC(=O)Nc2cc(Cl)c(F)c(Cl)c2F
|
24 |
+
22,[C] [O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-3.057,COc1ccc(Cl)cc1
|
25 |
+
23,[O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=N] [Ring1] [=Branch1],-4.2010000000000005,o1c2ccccc2c3ccccc13
|
26 |
+
24,[C] [=C] [C] [=C] [N] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [C] [Ring1] [#Branch2] [=C] [Ring1] [=C],-3.846,c3ccc2nc1ccccc1cc2c3
|
27 |
+
25,[C] [C] [C] [C] [=Branch1] [C] [=O] [C] [C] [Branch1] [P] [C] [C] [C] [=C] [C] [=Branch1] [C] [=O] [C] [C] [C] [Ring1] [O] [Ring1] [#Branch1] [C] [C] [Ring1] [P] [C] [C] [C] [Ring2] [Ring1] [Ring2] [Branch1] [C] [O] [C] [=Branch1] [C] [=O] [C] [O],-2.893,CC12CC(=O)C3C(CCC4=CC(=O)CCC34C)C2CCC1(O)C(=O)CO
|
28 |
+
26,[C] [C] [C] [=C] [C] [=C] [C] [Branch1] [Ring1] [C] [C] [=C] [Ring1] [Branch2] [N] [Branch1] [Ring2] [C] [O] [C] [C] [=Branch1] [C] [=O] [C] [Cl],-3.319,CCc1cccc(CC)c1N(COC)C(=O)CCl
|
29 |
+
27,[C] [C] [C] [C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-4.157,CCCCN(C)C(=O)Nc1ccc(Cl)c(Cl)c1
|
30 |
+
28,[C] [S] [C] [=Branch1] [C] [=S] [N] [C] [Ring1] [=Branch1] [=O],-0.396,C1SC(=S)NC1(=O)
|
31 |
+
29,[O] [C] [=C] [C] [=C] [Branch1] [Branch2] [C] [Branch1] [C] [O] [=C] [Ring1] [#Branch1] [C] [O] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [Branch1] [C] [O] [=C] [Ring1] [Branch2] [C] [=Branch1] [C] [=O] [C] [=Ring1] [=N] [O],-2.7310000000000003,Oc1ccc(c(O)c1)c3oc2cc(O)cc(O)c2c(=O)c3O
|
32 |
+
30,[C] [N] [Branch1] [C] [C] [C] [=N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [C],-3.164,CN(C)C=Nc1ccc(Cl)cc1C
|
33 |
+
31,[N] [C] [=Branch1] [C] [=O] [N] [C] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [=Branch1] [=O],0.652,NC(=O)NC1NC(=O)NC1=O
|
34 |
+
32,[Cl] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-4.063,Clc1cccc2ccccc12
|
35 |
+
33,[O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.352,Oc1ccc(Cl)c(Cl)c1
|
36 |
+
34,[C] [C] [Branch1] [C] [C] [C] [Branch1] [#Branch1] [C] [=C] [Branch1] [C] [Cl] [Cl] [C] [Ring1] [Branch2] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [Ring1] [C] [#N] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N],-6.775,CC1(C)C(C=C(Cl)Cl)C1C(=O)OC(C#N)c2cccc(Oc3ccccc3)c2
|
37 |
+
35,[C] [=C] [C] [=C] [NH1] [N] [=N] [C] [Ring1] [Branch1] [=C] [Ring1] [=Branch2],-2.21,c2ccc1[nH]nnc1c2
|
38 |
+
36,[C] [C] [Branch1] [C] [C] [C] [Branch2] [Ring1] [Branch1] [N] [C] [=C] [C] [=C] [Branch1] [=Branch1] [C] [=C] [Ring1] [=Branch1] [Cl] [C] [Branch1] [C] [F] [Branch1] [C] [F] [F] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [Ring1] [C] [#N] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N],-8.057,CC(C)C(Nc1ccc(cc1Cl)C(F)(F)F)C(=O)OC(C#N)c2cccc(Oc3ccccc3)c2
|
39 |
+
37,[C] [C] [C],-1.5530000000000002,CCC
|
40 |
+
38,[C] [C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [O] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-3.792,C1Cc2cccc3cccc1c23
|
41 |
+
39,[C] [C] [C] [#C],-1.092,CCC#C
|
42 |
+
40,[Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-3.5580000000000003,Clc1ccc(Cl)cc1
|
43 |
+
41,[C] [C] [=C] [NH1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch2] [Ring1] [=Branch1],-2.9810000000000003,Cc1c[nH]c2ccccc12
|
44 |
+
42,[C] [C] [#N],0.152,CC#N
|
45 |
+
43,[C] [C] [C] [C] [O],-0.688,CCCCO
|
46 |
+
44,[C] [C] [=Branch1] [C] [=C] [C] [=Branch1] [C] [=C] [C],-2.052,CC(=C)C(=C)C
|
47 |
+
45,[C] [C] [C] [Branch1] [C] [C] [C] [C] [O],-1.308,CCC(C)CCO
|
48 |
+
46,[Cl] [C] [=C] [C] [=C] [Branch1] [=Branch2] [C] [Branch1] [C] [Cl] [=C] [Ring1] [#Branch1] [Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2] [Cl],-7.192,Clc1ccc(c(Cl)c1Cl)c2ccc(Cl)c(Cl)c2Cl
|
49 |
+
47,[C] [C] [=C] [C] [=Branch2] [Ring1] [=Branch1] [=C] [C] [=C] [Ring1] [=Branch1] [N] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [Branch1] [C] [F] [Branch1] [C] [F] [F] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-4.945,Cc1cc(ccc1NS(=O)(=O)C(F)(F)F)S(=O)(=O)c2ccccc2
|
50 |
+
48,[O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [Cl],-3.22,Oc1ccc(Cl)cc1Cl
|
51 |
+
49,[C] [N] [C] [=Branch2] [Ring1] [Ring2] [=C] [Branch1] [C] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [S] [Ring1] [O] [=Branch1] [C] [=O] [=O] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=N] [Ring1] [=Branch1],-3.4730000000000003,CN2C(=C(O)c1ccccc1S2(=O)=O)C(=O)Nc3ccccn3
|
52 |
+
50,[C] [C] [C] [C] [C] [C] [Branch1] [S] [C] [C] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1] [C] [Ring1] [#C] [C] [C] [C] [Ring2] [Ring1] [C] [=O],-3.872,CC12CCC3C(CCc4cc(O)ccc34)C2CCC1=O
|
53 |
+
51,[C] [C] [=C] [C] [=C] [C] [=C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1],-4.147,Cc1cccc2c(C)cccc12
|
54 |
+
52,[N] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [Branch1] [O] [N] [C] [N] [S] [Ring1] [=Branch1] [=Branch1] [C] [=O] [=O] [C] [=C] [Ring1] [N] [Cl],-1.72,NS(=O)(=O)c2cc1c(NCNS1(=O)=O)cc2Cl
|
55 |
+
53,[O] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [N] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-2.725,Oc1cccc2cccnc12
|
56 |
+
54,[C] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [Ring1] [#Branch2],-3.447,C1CCc2ccccc2C1
|
57 |
+
55,[C] [C] [O] [C] [Branch1] [C] [C] [O] [C] [C],-0.899,CCOC(C)OCC
|
58 |
+
56,[C] [C] [C] [C] [Ring1] [Ring1] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [Branch1] [C] [Ring1] [Branch2] [=O] [C] [=C] [C] [Branch1] [C] [Cl] [=C] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.464,CC12CC2(C)C(=O)N(C1=O)c3cc(Cl)cc(Cl)c3
|
59 |
+
57,[C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=C] [Ring1] [=Branch1],-4.87,Cc1c2ccccc2cc3ccccc13
|
60 |
+
58,[C] [C] [C] [C] [O] [C],-1.072,CCCCOC
|
61 |
+
59,[C] [C] [C] [C] [C] [=Branch1] [C] [=O] [C] [=C] [Ring1] [#Branch1] [C] [C] [C] [C] [C] [C] [C] [Branch1] [#Branch1] [C] [=Branch1] [C] [=O] [C] [O] [C] [Ring1] [=Branch2] [Branch1] [N] [C] [C] [Branch1] [C] [O] [C] [Ring2] [Ring1] [#Branch1] [Ring1] [=C] [C] [=O],-3.0660000000000003,CC13CCC(=O)C=C1CCC4C2CCC(C(=O)CO)C2(CC(O)C34)C=O
|
62 |
+
60,[C] [C] [C] [Branch1] [=Branch1] [C] [Branch1] [C] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [O] [=O],-1.6030000000000002,CCC1(C(C)C)C(=O)NC(=O)NC1=O
|
63 |
+
61,[C] [C] [O] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-2.761,CCOC(=O)c1ccc(O)cc1
|
64 |
+
62,[C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring2] [Ring1] [C] [C] [=C] [Ring2] [Ring1] [C] [C] [Ring1] [S] [=C] [Ring1] [=C] [C] [Ring1] [N] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-6.885,c1cc2ccc3ccc4ccc5ccc6ccc1c7c2c3c4c5c67
|
65 |
+
63,[C] [C] [N] [C] [=C] [C] [Branch1] [=Branch1] [N] [Branch1] [C] [C] [C] [=C] [C] [Branch1] [C] [C] [=C] [Ring1] [#Branch2] [N] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [N] [=C] [Ring2] [Ring1] [Ring2] [Ring1] [=Branch1],-4.408,CCN2c1cc(N(C)C)cc(C)c1NC(=O)c3cccnc23
|
66 |
+
64,[C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.301,CN(C)C(=O)Nc1ccc(Cl)c(Cl)c1
|
67 |
+
65,[C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [C],-3.3080000000000003,CCCCCC(C)C
|
68 |
+
66,[C] [O] [C] [=C] [C] [=C] [Branch1] [C] [N] [N] [=C] [Branch1] [#C] [N] [=C] [Ring1] [#Branch1] [C] [Branch1] [Ring1] [O] [C] [=C] [Ring1] [=N] [O] [C] [N] [C] [C] [N] [Branch1] [Branch1] [C] [C] [Ring1] [=Branch1] [C] [=Branch1] [C] [=O] [O] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O],-3.958,COc2cc1c(N)nc(nc1c(OC)c2OC)N3CCN(CC3)C(=O)OCC(C)(C)O
|
69 |
+
67,[C] [=C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [#Branch1] [C] [=C] [Ring1] [O],-0.636,c1cC2C(=O)NC(=O)C2cc1
|
70 |
+
68,[C] [C] [C] [=O],-0.3939999999999999,CCC=O
|
71 |
+
69,[Cl] [C] [=C] [C] [=C] [Branch2] [Ring1] [=Branch2] [C] [N] [Branch1] [Branch2] [C] [C] [C] [C] [C] [Ring1] [Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [Ring2] [Ring1] [=Branch1],-5.126,Clc1ccc(CN(C2CCCC2)C(=O)Nc3ccccc3)cc1
|
72 |
+
70,[C] [C] [C] [C] [C] [Branch1] [Ring1] [C] [C] [C] [=O],-2.232,CCCCC(CC)C=O
|
73 |
+
71,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [Ring1] [C] [C] [C] [C] [C] [Branch1] [C] [C] [C],-2.312,O=C1NC(=O)NC(=O)C1(CC)CCC(C)C
|
74 |
+
72,[C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-1.857,CC(=O)Nc1ccccc1
|
75 |
+
73,[C] [=N] [C] [=C] [C] [Branch1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [N] [=C] [Ring1] [#Branch2],-0.7170000000000001,c1nccc(C(=O)NN)c1
|
76 |
+
74,[C] [C] [Branch1] [C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [Branch1] [Ring2] [C] [Ring1] [=Branch1] [C] [Ring1] [=Branch2] [=O],-2.158,CC2(C)C1CCC(C)(C1)C2=O
|
77 |
+
75,[C] [O] [C] [=C] [N] [=C] [C] [=N] [C] [=N] [C] [Ring1] [=Branch1] [=N] [Ring1] [#Branch2],-1.589,COc2cnc1cncnc1n2
|
78 |
+
76,[C] [N] [C] [=Branch1] [C] [=O] [C] [=C] [Branch1] [C] [C] [O] [P] [=Branch1] [C] [=O] [Branch1] [Ring1] [O] [C] [O] [C],-0.949,CNC(=O)C=C(C)OP(=O)(OC)OC
|
79 |
+
77,[O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [N] [Branch1] [Ring1] [C] [C] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [=C] [Ring2] [Ring1] [C] [Ring1] [=Branch1],-3.784,O2c1ccccc1N(CC)C(=O)c3ccccc23
|
80 |
+
78,[C] [=C] [C] [=C] [C] [=C] [Branch1] [Ring1] [O] [C] [C] [Branch1] [Branch2] [C] [C] [=C] [Branch1] [C] [C] [C] [=C] [Ring1] [=N] [O] [C] [Ring1] [P] [=O],-4.0760000000000005,c1cc2ccc(OC)c(CC=C(C)(C))c2oc1=O
|
81 |
+
79,[C] [C] [C] [S] [C] [C] [C],-2.307,CCCSCCC
|
82 |
+
80,[C] [O] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-2.948,CON(C)C(=O)Nc1ccc(Cl)cc1
|
83 |
+
81,[C] [C] [O] [C] [C],-0.718,CCOCC
|
84 |
+
82,[C] [C] [C] [C] [C] [C] [Branch1] [S] [C] [C] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1] [C] [Ring1] [#C] [C] [C] [Branch1] [C] [O] [C] [Ring2] [Ring1] [Ring1] [O],-3.858,CC34CCC1C(CCc2cc(O)ccc12)C3CC(O)C4O
|
85 |
+
83,[C] [C] [N] [C] [=N] [C] [Branch1] [C] [Cl] [=N] [C] [Branch1] [O] [N] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C] [#N] [=N] [Ring1] [=N],-2.49,CCNc1nc(Cl)nc(NC(C)(C)C#N)n1
|
86 |
+
84,[C] [C] [Branch1] [C] [C] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O],-1.6469999999999998,CC(C)CC(C)(C)O
|
87 |
+
85,[Cl] [C] [=C] [C] [=C] [C] [Branch1] [C] [Br] [=C] [Ring1] [#Branch1],-3.928,Clc1cccc(Br)c1
|
88 |
+
86,[C] [C] [C] [C] [C] [C] [Branch1] [C] [O] [C] [C],-2.033,CCCCCC(O)CC
|
89 |
+
87,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [Ring1] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [C],-2.126,O=C1NC(=O)NC(=O)C1(CC)CC=C(C)C
|
90 |
+
88,[C] [C] [C] [Branch1] [C] [C] [C] [Branch1] [#Branch1] [C] [C] [Branch1] [C] [Br] [=C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [N] [=O],-2.766,CCC(C)C1(CC(Br)=C)C(=O)NC(=O)NC1=O
|
91 |
+
89,[C] [O] [C] [=Branch1] [C] [=O] [C],-0.416,COC(=O)C
|
92 |
+
90,[C] [C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Branch1] [C] [C] [C] [=C] [Ring1] [#Branch1] [O],-3.129,CC(C)c1ccc(C)cc1O
|
93 |
+
91,[C],-0.636,C
|
94 |
+
92,[N] [C] [=N] [C] [Branch1] [C] [O] [=N] [C] [N] [=C] [NH1] [C] [Ring1] [#Branch2] [=Ring1] [Branch1],-1.74,Nc1nc(O)nc2nc[nH]c12
|
95 |
+
93,[F] [C] [=C] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-4.692,Fc1cccc(F)c1C(=O)NC(=O)Nc2ccc(Cl)cc2
|
96 |
+
94,[C] [C] [C] [C] [C] [Branch1] [Branch1] [C] [C] [Ring1] [=Branch1] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O] [Ring1] [#Branch2],-2.579,CC12CCC(CC1)C(C)(C)O2
|
97 |
+
95,[C] [C] [O],0.02,CCO
|
98 |
+
96,[C] [=C] [Branch2] [Ring1] [C] [N] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [C] [C] [=C] [C] [=C] [Ring1] [P],-2.29,c1c(NC(=O)OC(C)C(=O)NCC)cccc1
|
99 |
+
97,[C] [C] [Branch1] [C] [C] [=C] [C] [C] [Branch2] [Ring1] [#Branch2] [C] [=Branch1] [C] [=O] [O] [C] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N] [C] [Ring2] [Ring1] [Ring2] [Branch1] [C] [C] [C],-6.763,CC(C)=CC3C(C(=O)OCc2cccc(Oc1ccccc1)c2)C3(C)C
|
100 |
+
98,[C] [C] [C] [C] [N] [C] [=Branch1] [C] [=O] [N] [C] [Branch1] [Branch2] [N] [C] [=Branch1] [C] [=O] [O] [C] [=N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=C] [Ring1] [=Branch1],-2.902,CCCCNC(=O)n1c(NC(=O)OC)nc2ccccc12
|
101 |
+
99,[C] [N] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-2.542,CN(C)c1ccccc1
|
102 |
+
100,[C] [O] [C] [=Branch1] [C] [=O] [C] [=C],-0.878,COC(=O)C=C
|
103 |
+
101,[C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [=N] [O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [C] [=C] [Ring1] [=C],-4.477,CN(C)C(=O)Nc2ccc(Oc1ccc(Cl)cc1)cc2
|
104 |
+
102,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [=Branch1] [C] [Branch1] [C] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [C],-2.465,O=C1NC(=O)NC(=O)C1(C(C)C)CC=C(C)C
|
105 |
+
103,[C] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1] [C],-2.6210000000000004,Cc1ccc(O)cc1C
|
106 |
+
104,[Cl] [C] [=C] [C] [=C] [C] [=Branch1] [Ring2] [=N] [Ring1] [=Branch1] [C] [Branch1] [C] [Cl] [Branch1] [C] [Cl] [Cl],-3.833,Clc1cccc(n1)C(Cl)(Cl)Cl
|
107 |
+
105,[C] [C] [=Branch1] [C] [=O] [O] [C] [Branch2] [Ring1] [=C] [C] [C] [C] [C] [C] [C] [C] [=C] [C] [=Branch1] [C] [=O] [C] [C] [C] [Ring1] [#Branch1] [C] [Ring1] [O] [C] [C] [C] [Ring2] [Ring1] [C] [Ring1] [#C] [C] [C] [#C],-4.2410000000000005,CC(=O)OC3(CCC4C2CCC1=CC(=O)CCC1C2CCC34C)C#C
|
108 |
+
106,[C] [N] [C] [=Branch1] [C] [=O] [O] [N] [=C] [Branch1] [Ring2] [C] [S] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C],-2.7,CNC(=O)ON=C(CSC)C(C)(C)C
|
109 |
+
107,[C] [C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [O],-2.033,CCCCCCC(C)O
|
data/esol/train.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
img/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
img/img1.png
ADDED
img/img2.png
ADDED
img/img3.png
ADDED
img/img4.png
ADDED
img/img5.png
ADDED
img/introduction.png
ADDED
img/latent_multi_bace.png
ADDED
log.csv
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
,Selected Models,Dataset,Task,Result
|
models/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/__pycache__/fm4m.cpython-310.pyc
ADDED
Binary file (14.4 kB). View file
|
|
models/fm4m.py
ADDED
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sklearn.metrics import roc_auc_score, roc_curve
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
import os
|
5 |
+
import umap
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import pandas as pd
|
10 |
+
import pickle
|
11 |
+
import json
|
12 |
+
|
13 |
+
from xgboost import XGBClassifier, XGBRegressor
|
14 |
+
import xgboost as xgb
|
15 |
+
from sklearn.metrics import roc_auc_score, mean_squared_error
|
16 |
+
import xgboost as xgb
|
17 |
+
from sklearn.svm import SVR
|
18 |
+
from sklearn.linear_model import LinearRegression
|
19 |
+
from sklearn.kernel_ridge import KernelRidge
|
20 |
+
import json
|
21 |
+
from sklearn.compose import TransformedTargetRegressor
|
22 |
+
from sklearn.preprocessing import MinMaxScaler
|
23 |
+
|
24 |
+
|
25 |
+
import torch
|
26 |
+
from transformers import AutoTokenizer, AutoModel
|
27 |
+
|
28 |
+
from .selfies_model.load import SELFIES as bart
|
29 |
+
from .mhg_model import load as mhg
|
30 |
+
from .smi_ted.smi_ted_light.load import load_smi_ted
|
31 |
+
|
32 |
+
datasets = {}
|
33 |
+
models = {}
|
34 |
+
downstream_models ={}
|
35 |
+
|
36 |
+
|
37 |
+
def avail_models_data():
|
38 |
+
global datasets
|
39 |
+
global models
|
40 |
+
|
41 |
+
datasets = [{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv", "Timestamp": "2024-06-26 11:27:37"},
|
42 |
+
{"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre", "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
|
43 |
+
{"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv", "Timestamp": "2024-06-26 11:33:47"},
|
44 |
+
{"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo", "Timestamp": "2024-06-26 11:34:37"},
|
45 |
+
{"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace", "Timestamp": "2024-06-26 11:36:40"},
|
46 |
+
{"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp", "Timestamp": "2024-06-26 11:39:23"},
|
47 |
+
{"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox", "Timestamp": "2024-06-26 11:42:43"}]
|
48 |
+
|
49 |
+
|
50 |
+
models = [{"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality", "Timestamp": "2024-06-21 12:32:20"},
|
51 |
+
{"Name": "mol-xl","Model Name": "Molformer", "Description": "MolFormer model for string based SMILES modality", "Timestamp": "2024-06-21 12:35:56"},
|
52 |
+
{"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model", "Timestamp": "2024-07-10 00:09:42"},
|
53 |
+
{"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model", "Timestamp": "2024-07-10 00:09:42"}]
|
54 |
+
|
55 |
+
|
56 |
+
def avail_models(raw=False):
|
57 |
+
global models
|
58 |
+
|
59 |
+
models = [{"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model"},
|
60 |
+
{"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality"},
|
61 |
+
{"Name": "mol-xl","Model Name": "Molformer", "Description": "MolFormer model for string based SMILES modality"},
|
62 |
+
{"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model"},
|
63 |
+
]
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
if raw: return models
|
68 |
+
else:
|
69 |
+
return pd.DataFrame(models).drop('Name', axis=1)
|
70 |
+
|
71 |
+
return models
|
72 |
+
|
73 |
+
def avail_downstream_models():
|
74 |
+
global downstream_models
|
75 |
+
|
76 |
+
with open("downstream_models.json", "r") as outfile:
|
77 |
+
downstream_models = json.load(outfile)
|
78 |
+
return downstream_models
|
79 |
+
|
80 |
+
def avail_datasets():
|
81 |
+
global datasets
|
82 |
+
|
83 |
+
datasets = [{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv",
|
84 |
+
"Timestamp": "2024-06-26 11:27:37"},
|
85 |
+
{"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre",
|
86 |
+
"Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
|
87 |
+
{"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv",
|
88 |
+
"Timestamp": "2024-06-26 11:33:47"},
|
89 |
+
{"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo",
|
90 |
+
"Timestamp": "2024-06-26 11:34:37"},
|
91 |
+
{"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace",
|
92 |
+
"Timestamp": "2024-06-26 11:36:40"},
|
93 |
+
{"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp",
|
94 |
+
"Timestamp": "2024-06-26 11:39:23"},
|
95 |
+
{"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox",
|
96 |
+
"Timestamp": "2024-06-26 11:42:43"}]
|
97 |
+
|
98 |
+
return datasets
|
99 |
+
|
100 |
+
def reset():
|
101 |
+
|
102 |
+
"""datasets = {"esol": ["smiles", "ESOL predicted log solubility in mols per litre", "data/esol", "2024-06-26 11:36:46.509324"],
|
103 |
+
"freesolv": ["smiles", "expt", "data/freesolv", "2024-06-26 11:37:37.393273"],
|
104 |
+
"lipo": ["smiles", "y", "data/lipo", "2024-06-26 11:37:37.393273"],
|
105 |
+
"hiv": ["smiles", "HIV_active", "data/hiv", "2024-06-26 11:37:37.393273"],
|
106 |
+
"bace": ["smiles", "Class", "data/bace", "2024-06-26 11:38:40.058354"],
|
107 |
+
"bbbp": ["smiles", "p_np", "data/bbbp","2024-06-26 11:38:40.058354"],
|
108 |
+
"clintox": ["smiles", "CT_TOX", "data/clintox","2024-06-26 11:38:40.058354"],
|
109 |
+
"sider": ["smiles","1:", "data/sider","2024-06-26 11:38:40.058354"],
|
110 |
+
"tox21": ["smiles",":-2", "data/tox21","2024-06-26 11:38:40.058354"]
|
111 |
+
}"""
|
112 |
+
|
113 |
+
datasets = [
|
114 |
+
{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv", "Timestamp": "2024-06-26 11:27:37"},
|
115 |
+
{"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre", "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
|
116 |
+
{"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv", "Timestamp": "2024-06-26 11:33:47"},
|
117 |
+
{"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo", "Timestamp": "2024-06-26 11:34:37"},
|
118 |
+
{"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace", "Timestamp": "2024-06-26 11:36:40"},
|
119 |
+
{"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp", "Timestamp": "2024-06-26 11:39:23"},
|
120 |
+
{"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox", "Timestamp": "2024-06-26 11:42:43"},
|
121 |
+
#{"Dataset": "sider", "Input": "smiles", "Output": "1:", "path": "data/sider", "Timestamp": "2024-06-26 11:38:40.058354"},
|
122 |
+
#{"Dataset": "tox21", "Input": "smiles", "Output": ":-2", "path": "data/tox21", "Timestamp": "2024-06-26 11:38:40.058354"}
|
123 |
+
]
|
124 |
+
|
125 |
+
models = [{"Name": "bart", "Description": "BART model for string based SELFIES modality",
|
126 |
+
"Timestamp": "2024-06-21 12:32:20"},
|
127 |
+
{"Name": "mol-xl", "Description": "MolFormer model for string based SMILES modality",
|
128 |
+
"Timestamp": "2024-06-21 12:35:56"},
|
129 |
+
{"Name": "mhg", "Description": "MHG", "Timestamp": "2024-07-10 00:09:42"},
|
130 |
+
{"Name": "spec-gru", "Description": "Spectrum modality with GRU", "Timestamp": "2024-07-10 00:09:42"},
|
131 |
+
{"Name": "spec-lstm", "Description": "Spectrum modality with LSTM", "Timestamp": "2024-07-10 00:09:54"},
|
132 |
+
{"Name": "3d-vae", "Description": "VAE model for 3D atom positions", "Timestamp": "2024-07-10 00:10:08"}]
|
133 |
+
|
134 |
+
|
135 |
+
downstream_models = [
|
136 |
+
{"Name": "XGBClassifier", "Description": "XG Boost Classifier",
|
137 |
+
"Timestamp": "2024-06-21 12:31:20"},
|
138 |
+
{"Name": "XGBRegressor", "Description": "XG Boost Regressor",
|
139 |
+
"Timestamp": "2024-06-21 12:32:56"},
|
140 |
+
{"Name": "2-FNN", "Description": "A two layer feedforward network",
|
141 |
+
"Timestamp": "2024-06-24 14:34:16"},
|
142 |
+
{"Name": "3-FNN", "Description": "A three layer feedforward network",
|
143 |
+
"Timestamp": "2024-06-24 14:38:37"},
|
144 |
+
]
|
145 |
+
|
146 |
+
with open("datasets.json", "w") as outfile:
|
147 |
+
json.dump(datasets, outfile)
|
148 |
+
|
149 |
+
with open("models.json", "w") as outfile:
|
150 |
+
json.dump(models, outfile)
|
151 |
+
|
152 |
+
with open("downstream_models.json", "w") as outfile:
|
153 |
+
json.dump(downstream_models, outfile)
|
154 |
+
|
155 |
+
def update_data_list(list_data):
|
156 |
+
#datasets[list_data[0]] = list_data[1:]
|
157 |
+
|
158 |
+
with open("datasets.json", "w") as outfile:
|
159 |
+
json.dump(datasets, outfile)
|
160 |
+
|
161 |
+
avail_models_data()
|
162 |
+
|
163 |
+
def update_model_list(list_model):
|
164 |
+
#models[list_model[0]] = list_model[1]
|
165 |
+
|
166 |
+
with open("models.json", "w") as outfile:
|
167 |
+
json.dump(list_model, outfile)
|
168 |
+
|
169 |
+
avail_models_data()
|
170 |
+
|
171 |
+
def update_downstream_model_list(list_model):
|
172 |
+
#models[list_model[0]] = list_model[1]
|
173 |
+
|
174 |
+
with open("downstream_models.json", "w") as outfile:
|
175 |
+
json.dump(list_model, outfile)
|
176 |
+
|
177 |
+
avail_models_data()
|
178 |
+
|
179 |
+
avail_models_data()
|
180 |
+
|
181 |
+
def get_representation(train_data,test_data,model_type, return_tensor=True):
|
182 |
+
alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
|
183 |
+
if model_type in alias.keys():
|
184 |
+
model_type = alias[model_type]
|
185 |
+
|
186 |
+
if model_type == "mhg":
|
187 |
+
model = mhg.load("models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle")
|
188 |
+
with torch.no_grad():
|
189 |
+
train_emb = model.encode(train_data)
|
190 |
+
x_batch = torch.stack(train_emb)
|
191 |
+
|
192 |
+
test_emb = model.encode(test_data)
|
193 |
+
x_batch_test = torch.stack(test_emb)
|
194 |
+
if not return_tensor:
|
195 |
+
x_batch = pd.DataFrame(x_batch)
|
196 |
+
x_batch_test = pd.DataFrame(x_batch_test)
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
elif model_type == "bart":
|
201 |
+
model = bart()
|
202 |
+
model.load()
|
203 |
+
x_batch = model.encode(train_data, return_tensor=return_tensor)
|
204 |
+
x_batch_test = model.encode(test_data, return_tensor=return_tensor)
|
205 |
+
|
206 |
+
elif model_type == "smi-ted":
|
207 |
+
model = load_smi_ted(folder='./models/smi_ted/smi_ted_light', ckpt_filename='smi-ted-Light_40.pt')
|
208 |
+
with torch.no_grad():
|
209 |
+
x_batch = model.encode(train_data, return_torch=return_tensor)
|
210 |
+
x_batch_test = model.encode(test_data, return_torch=return_tensor)
|
211 |
+
|
212 |
+
elif model_type == "mol-xl":
|
213 |
+
model = AutoModel.from_pretrained("ibm/MoLFormer-XL-both-10pct", deterministic_eval=True,
|
214 |
+
trust_remote_code=True)
|
215 |
+
tokenizer = AutoTokenizer.from_pretrained("ibm/MoLFormer-XL-both-10pct", trust_remote_code=True)
|
216 |
+
|
217 |
+
if type(train_data) == list:
|
218 |
+
inputs = tokenizer(train_data, padding=True, return_tensors="pt")
|
219 |
+
else:
|
220 |
+
inputs = tokenizer(list(train_data.values), padding=True, return_tensors="pt")
|
221 |
+
|
222 |
+
with torch.no_grad():
|
223 |
+
outputs = model(**inputs)
|
224 |
+
|
225 |
+
x_batch = outputs.pooler_output
|
226 |
+
|
227 |
+
if type(test_data) == list:
|
228 |
+
inputs = tokenizer(test_data, padding=True, return_tensors="pt")
|
229 |
+
else:
|
230 |
+
inputs = tokenizer(list(test_data.values), padding=True, return_tensors="pt")
|
231 |
+
|
232 |
+
with torch.no_grad():
|
233 |
+
outputs = model(**inputs)
|
234 |
+
|
235 |
+
x_batch_test = outputs.pooler_output
|
236 |
+
|
237 |
+
if not return_tensor:
|
238 |
+
x_batch = pd.DataFrame(x_batch)
|
239 |
+
x_batch_test = pd.DataFrame(x_batch_test)
|
240 |
+
|
241 |
+
|
242 |
+
return x_batch, x_batch_test
|
243 |
+
|
244 |
+
def single_modal(model,dataset, downstream_model,params):
|
245 |
+
print(model)
|
246 |
+
alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "SMI-TED": "smi-ted"}
|
247 |
+
data = avail_models(raw=True)
|
248 |
+
df = pd.DataFrame(data)
|
249 |
+
print(list(df["Name"].values))
|
250 |
+
if alias[model] in list(df["Name"].values):
|
251 |
+
if model in alias.keys():
|
252 |
+
model_type = alias[model]
|
253 |
+
else:
|
254 |
+
model_type = model
|
255 |
+
else:
|
256 |
+
print("Model not available")
|
257 |
+
return
|
258 |
+
|
259 |
+
data = avail_datasets()
|
260 |
+
df = pd.DataFrame(data)
|
261 |
+
print(list(df["Dataset"].values))
|
262 |
+
|
263 |
+
if dataset in list(df["Dataset"].values):
|
264 |
+
task = dataset
|
265 |
+
with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
|
266 |
+
x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
|
267 |
+
print(f" Representation loaded successfully")
|
268 |
+
else:
|
269 |
+
|
270 |
+
print("Custom Dataset")
|
271 |
+
#return
|
272 |
+
components = dataset.split(",")
|
273 |
+
train_data = pd.read_csv(components[0])[components[2]]
|
274 |
+
test_data = pd.read_csv(components[1])[components[2]]
|
275 |
+
|
276 |
+
y_batch = pd.read_csv(components[0])[components[3]]
|
277 |
+
y_batch_test = pd.read_csv(components[1])[components[3]]
|
278 |
+
|
279 |
+
|
280 |
+
x_batch, x_batch_test = get_representation(train_data,test_data,model_type)
|
281 |
+
|
282 |
+
|
283 |
+
|
284 |
+
print(f" Representation loaded successfully")
|
285 |
+
|
286 |
+
|
287 |
+
|
288 |
+
|
289 |
+
|
290 |
+
print(f" Calculating ROC AUC Score ...")
|
291 |
+
|
292 |
+
if downstream_model == "XGBClassifier":
|
293 |
+
xgb_predict_concat = XGBClassifier(**params) # n_estimators=5000, learning_rate=0.01, max_depth=10
|
294 |
+
xgb_predict_concat.fit(x_batch, y_batch)
|
295 |
+
|
296 |
+
y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
|
297 |
+
|
298 |
+
roc_auc = roc_auc_score(y_batch_test, y_prob)
|
299 |
+
fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
|
300 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
301 |
+
|
302 |
+
try:
|
303 |
+
with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1:
|
304 |
+
class_0,class_1 = pickle.load(f1)
|
305 |
+
except:
|
306 |
+
print("Generating latent plots")
|
307 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
308 |
+
verbose=False)
|
309 |
+
n_samples = np.minimum(1000, len(x_batch))
|
310 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
311 |
+
x = y_batch.values[:n_samples]
|
312 |
+
index_0 = [index for index in range(len(x)) if x[index] == 0]
|
313 |
+
index_1 = [index for index in range(len(x)) if x[index] == 1]
|
314 |
+
|
315 |
+
class_0 = features_umap[index_0]
|
316 |
+
class_1 = features_umap[index_1]
|
317 |
+
print("Generating latent plots : Done")
|
318 |
+
|
319 |
+
#vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
|
320 |
+
|
321 |
+
result = f"ROC-AUC Score: {roc_auc:.4f}"
|
322 |
+
|
323 |
+
return result, roc_auc,fpr, tpr, class_0, class_1
|
324 |
+
|
325 |
+
elif downstream_model == "DefaultClassifier":
|
326 |
+
xgb_predict_concat = XGBClassifier() # n_estimators=5000, learning_rate=0.01, max_depth=10
|
327 |
+
xgb_predict_concat.fit(x_batch, y_batch)
|
328 |
+
|
329 |
+
y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
|
330 |
+
|
331 |
+
roc_auc = roc_auc_score(y_batch_test, y_prob)
|
332 |
+
fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
|
333 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
334 |
+
|
335 |
+
try:
|
336 |
+
with open(f"./plot_emb/{task}_{model_type}.pkl", "rb") as f1:
|
337 |
+
class_0,class_1 = pickle.load(f1)
|
338 |
+
except:
|
339 |
+
print("Generating latent plots")
|
340 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False)
|
341 |
+
n_samples = np.minimum(1000,len(x_batch))
|
342 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
343 |
+
x = y_batch.values[:n_samples]
|
344 |
+
index_0 = [index for index in range(len(x)) if x[index] == 0]
|
345 |
+
index_1 = [index for index in range(len(x)) if x[index] == 1]
|
346 |
+
|
347 |
+
class_0 = features_umap[index_0]
|
348 |
+
class_1 = features_umap[index_1]
|
349 |
+
print("Generating latent plots : Done")
|
350 |
+
|
351 |
+
#vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
|
352 |
+
|
353 |
+
result = f"ROC-AUC Score: {roc_auc:.4f}"
|
354 |
+
|
355 |
+
return result, roc_auc,fpr, tpr, class_0, class_1
|
356 |
+
|
357 |
+
elif downstream_model == "SVR":
|
358 |
+
regressor = SVR(**params)
|
359 |
+
model = TransformedTargetRegressor(regressor= regressor,
|
360 |
+
transformer = MinMaxScaler(feature_range=(-1, 1))
|
361 |
+
).fit(x_batch,y_batch)
|
362 |
+
|
363 |
+
y_prob = model.predict(x_batch_test)
|
364 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
365 |
+
|
366 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
367 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
368 |
+
|
369 |
+
print("Generating latent plots")
|
370 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
371 |
+
verbose=False)
|
372 |
+
n_samples = np.minimum(1000, len(x_batch))
|
373 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
374 |
+
x = y_batch.values[:n_samples]
|
375 |
+
#index_0 = [index for index in range(len(x)) if x[index] == 0]
|
376 |
+
#index_1 = [index for index in range(len(x)) if x[index] == 1]
|
377 |
+
|
378 |
+
class_0 = features_umap#[index_0]
|
379 |
+
class_1 = features_umap#[index_1]
|
380 |
+
print("Generating latent plots : Done")
|
381 |
+
|
382 |
+
return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
|
383 |
+
|
384 |
+
elif downstream_model == "Kernel Ridge":
|
385 |
+
regressor = KernelRidge(**params)
|
386 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
387 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
388 |
+
).fit(x_batch, y_batch)
|
389 |
+
|
390 |
+
y_prob = model.predict(x_batch_test)
|
391 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
392 |
+
|
393 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
394 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
395 |
+
|
396 |
+
print("Generating latent plots")
|
397 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
398 |
+
verbose=False)
|
399 |
+
n_samples = np.minimum(1000, len(x_batch))
|
400 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
401 |
+
x = y_batch.values[:n_samples]
|
402 |
+
# index_0 = [index for index in range(len(x)) if x[index] == 0]
|
403 |
+
# index_1 = [index for index in range(len(x)) if x[index] == 1]
|
404 |
+
|
405 |
+
class_0 = features_umap#[index_0]
|
406 |
+
class_1 = features_umap#[index_1]
|
407 |
+
print("Generating latent plots : Done")
|
408 |
+
|
409 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
410 |
+
|
411 |
+
|
412 |
+
elif downstream_model == "Linear Regression":
|
413 |
+
regressor = LinearRegression(**params)
|
414 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
415 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
416 |
+
).fit(x_batch, y_batch)
|
417 |
+
|
418 |
+
y_prob = model.predict(x_batch_test)
|
419 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
420 |
+
|
421 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
422 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
423 |
+
|
424 |
+
print("Generating latent plots")
|
425 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
426 |
+
verbose=False)
|
427 |
+
n_samples = np.minimum(1000, len(x_batch))
|
428 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
429 |
+
x = y_batch.values[:n_samples]
|
430 |
+
# index_0 = [index for index in range(len(x)) if x[index] == 0]
|
431 |
+
# index_1 = [index for index in range(len(x)) if x[index] == 1]
|
432 |
+
|
433 |
+
class_0 = features_umap#[index_0]
|
434 |
+
class_1 = features_umap#[index_1]
|
435 |
+
print("Generating latent plots : Done")
|
436 |
+
|
437 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
438 |
+
|
439 |
+
|
440 |
+
elif downstream_model == "DefaultRegressor":
|
441 |
+
regressor = SVR(kernel="rbf", degree=3, C=5, gamma="scale", epsilon=0.01)
|
442 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
443 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
444 |
+
).fit(x_batch, y_batch)
|
445 |
+
|
446 |
+
y_prob = model.predict(x_batch_test)
|
447 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
448 |
+
|
449 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
450 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
451 |
+
|
452 |
+
print("Generating latent plots")
|
453 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
454 |
+
verbose=False)
|
455 |
+
n_samples = np.minimum(1000, len(x_batch))
|
456 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
457 |
+
x = y_batch.values[:n_samples]
|
458 |
+
# index_0 = [index for index in range(len(x)) if x[index] == 0]
|
459 |
+
# index_1 = [index for index in range(len(x)) if x[index] == 1]
|
460 |
+
|
461 |
+
class_0 = features_umap#[index_0]
|
462 |
+
class_1 = features_umap#[index_1]
|
463 |
+
print("Generating latent plots : Done")
|
464 |
+
|
465 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
466 |
+
|
467 |
+
|
468 |
+
def multi_modal(model_list,dataset, downstream_model,params):
|
469 |
+
print(model_list)
|
470 |
+
data = avail_datasets()
|
471 |
+
df = pd.DataFrame(data)
|
472 |
+
list(df["Dataset"].values)
|
473 |
+
|
474 |
+
if dataset in list(df["Dataset"].values):
|
475 |
+
task = dataset
|
476 |
+
predefined = True
|
477 |
+
else:
|
478 |
+
predefined = False
|
479 |
+
components = dataset.split(",")
|
480 |
+
train_data = pd.read_csv(components[0])[components[2]]
|
481 |
+
test_data = pd.read_csv(components[1])[components[2]]
|
482 |
+
|
483 |
+
y_batch = pd.read_csv(components[0])[components[3]]
|
484 |
+
y_batch_test = pd.read_csv(components[1])[components[3]]
|
485 |
+
|
486 |
+
print("Custom Dataset loaded")
|
487 |
+
|
488 |
+
|
489 |
+
data = avail_models(raw=True)
|
490 |
+
df = pd.DataFrame(data)
|
491 |
+
list(df["Name"].values)
|
492 |
+
|
493 |
+
alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "SMI-TED":"smi-ted"}
|
494 |
+
#if set(model_list).issubset(list(df["Name"].values)):
|
495 |
+
if set(model_list).issubset(list(alias.keys())):
|
496 |
+
for i, model in enumerate(model_list):
|
497 |
+
if model in alias.keys():
|
498 |
+
model_type = alias[model]
|
499 |
+
else:
|
500 |
+
model_type = model
|
501 |
+
|
502 |
+
if i == 0:
|
503 |
+
if predefined:
|
504 |
+
with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
|
505 |
+
x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
|
506 |
+
print(f" Loaded representation/{task}_{model_type}.pkl")
|
507 |
+
else:
|
508 |
+
x_batch, x_batch_test = get_representation(train_data, test_data, model_type)
|
509 |
+
x_batch = pd.DataFrame(x_batch)
|
510 |
+
x_batch_test = pd.DataFrame(x_batch_test)
|
511 |
+
|
512 |
+
else:
|
513 |
+
if predefined:
|
514 |
+
with open(f"./representation/{task}_{model_type}.pkl", "rb") as f1:
|
515 |
+
x_batch_1, y_batch_1, x_batch_test_1, y_batch_test_1 = pickle.load(f1)
|
516 |
+
print(f" Loaded representation/{task}_{model_type}.pkl")
|
517 |
+
else:
|
518 |
+
x_batch_1, x_batch_test_1 = get_representation(train_data, test_data, model_type)
|
519 |
+
x_batch_1 = pd.DataFrame(x_batch_1)
|
520 |
+
x_batch_test_1 = pd.DataFrame(x_batch_test_1)
|
521 |
+
|
522 |
+
x_batch = pd.concat([x_batch, x_batch_1], axis=1)
|
523 |
+
x_batch_test = pd.concat([x_batch_test, x_batch_test_1], axis=1)
|
524 |
+
|
525 |
+
|
526 |
+
else:
|
527 |
+
print("Model not available")
|
528 |
+
return
|
529 |
+
|
530 |
+
num_columns = x_batch_test.shape[1]
|
531 |
+
x_batch_test.columns = [f'{i + 1}' for i in range(num_columns)]
|
532 |
+
|
533 |
+
num_columns = x_batch.shape[1]
|
534 |
+
x_batch.columns = [f'{i + 1}' for i in range(num_columns)]
|
535 |
+
|
536 |
+
|
537 |
+
print(f"Representations loaded successfully")
|
538 |
+
try:
|
539 |
+
with open(f"./plot_emb/{task}_multi.pkl", "rb") as f1:
|
540 |
+
class_0, class_1 = pickle.load(f1)
|
541 |
+
except:
|
542 |
+
print("Generating latent plots")
|
543 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
544 |
+
verbose=False)
|
545 |
+
n_samples = np.minimum(1000, len(x_batch))
|
546 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
547 |
+
|
548 |
+
if "Classifier" in downstream_model:
|
549 |
+
x = y_batch.values[:n_samples]
|
550 |
+
index_0 = [index for index in range(len(x)) if x[index] == 0]
|
551 |
+
index_1 = [index for index in range(len(x)) if x[index] == 1]
|
552 |
+
|
553 |
+
class_0 = features_umap[index_0]
|
554 |
+
class_1 = features_umap[index_1]
|
555 |
+
|
556 |
+
else:
|
557 |
+
class_0 = features_umap
|
558 |
+
class_1 = features_umap
|
559 |
+
|
560 |
+
print("Generating latent plots : Done")
|
561 |
+
|
562 |
+
print(f" Calculating ROC AUC Score ...")
|
563 |
+
|
564 |
+
|
565 |
+
if downstream_model == "XGBClassifier":
|
566 |
+
xgb_predict_concat = XGBClassifier(**params)#n_estimators=5000, learning_rate=0.01, max_depth=10)
|
567 |
+
xgb_predict_concat.fit(x_batch, y_batch)
|
568 |
+
|
569 |
+
y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
|
570 |
+
|
571 |
+
|
572 |
+
roc_auc = roc_auc_score(y_batch_test, y_prob)
|
573 |
+
fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
|
574 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
575 |
+
|
576 |
+
#vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
|
577 |
+
|
578 |
+
#vizualize(x_batch_test, y_batch_test)
|
579 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
580 |
+
result = f"ROC-AUC Score: {roc_auc:.4f}"
|
581 |
+
|
582 |
+
return result, roc_auc,fpr, tpr, class_0, class_1
|
583 |
+
|
584 |
+
elif downstream_model == "DefaultClassifier":
|
585 |
+
xgb_predict_concat = XGBClassifier()#n_estimators=5000, learning_rate=0.01, max_depth=10)
|
586 |
+
xgb_predict_concat.fit(x_batch, y_batch)
|
587 |
+
|
588 |
+
y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
|
589 |
+
|
590 |
+
|
591 |
+
roc_auc = roc_auc_score(y_batch_test, y_prob)
|
592 |
+
fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
|
593 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
594 |
+
|
595 |
+
#vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
|
596 |
+
|
597 |
+
#vizualize(x_batch_test, y_batch_test)
|
598 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
599 |
+
result = f"ROC-AUC Score: {roc_auc:.4f}"
|
600 |
+
|
601 |
+
return result, roc_auc,fpr, tpr, class_0, class_1
|
602 |
+
|
603 |
+
elif downstream_model == "SVR":
|
604 |
+
regressor = SVR(**params)
|
605 |
+
model = TransformedTargetRegressor(regressor= regressor,
|
606 |
+
transformer = MinMaxScaler(feature_range=(-1, 1))
|
607 |
+
).fit(x_batch,y_batch)
|
608 |
+
|
609 |
+
y_prob = model.predict(x_batch_test)
|
610 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
611 |
+
|
612 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
613 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
614 |
+
|
615 |
+
return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
|
616 |
+
|
617 |
+
elif downstream_model == "Linear Regression":
|
618 |
+
regressor = LinearRegression(**params)
|
619 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
620 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
621 |
+
).fit(x_batch, y_batch)
|
622 |
+
|
623 |
+
y_prob = model.predict(x_batch_test)
|
624 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
625 |
+
|
626 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
627 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
628 |
+
|
629 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
630 |
+
|
631 |
+
elif downstream_model == "Kernel Ridge":
|
632 |
+
regressor = KernelRidge(**params)
|
633 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
634 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
635 |
+
).fit(x_batch, y_batch)
|
636 |
+
|
637 |
+
y_prob = model.predict(x_batch_test)
|
638 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
639 |
+
|
640 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
641 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
642 |
+
|
643 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
644 |
+
|
645 |
+
elif downstream_model == "DefaultRegressor":
|
646 |
+
regressor = SVR(kernel="rbf", degree=3, C=5, gamma="scale", epsilon=0.01)
|
647 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
648 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
649 |
+
).fit(x_batch, y_batch)
|
650 |
+
|
651 |
+
y_prob = model.predict(x_batch_test)
|
652 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
653 |
+
|
654 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
655 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
656 |
+
|
657 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
658 |
+
|
659 |
+
|
660 |
+
|
661 |
+
|
662 |
+
|
663 |
+
|
models/mhg_model/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
models/mhg_model/README.md
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# mhg-gnn
|
2 |
+
|
3 |
+
This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
|
4 |
+
|
5 |
+
**Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
|
6 |
+
|
7 |
+
![mhg-gnn](images/mhg_example1.png)
|
8 |
+
|
9 |
+
## Introduction
|
10 |
+
|
11 |
+
We present MHG-GNN, an autoencoder architecture
|
12 |
+
that has an encoder based on GNN and a decoder based on a sequential model with MHG.
|
13 |
+
Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
|
14 |
+
demonstrate high predictive performance on molecular graph data.
|
15 |
+
In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
|
16 |
+
|
17 |
+
## Table of Contents
|
18 |
+
|
19 |
+
1. [Getting Started](#getting-started)
|
20 |
+
1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
|
21 |
+
2. [Installation](#installation)
|
22 |
+
2. [Feature Extraction](#feature-extraction)
|
23 |
+
|
24 |
+
## Getting Started
|
25 |
+
|
26 |
+
**This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
|
27 |
+
|
28 |
+
### Pretrained Models and Training Logs
|
29 |
+
|
30 |
+
We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
|
31 |
+
|
32 |
+
Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
|
33 |
+
|
34 |
+
### Installation
|
35 |
+
|
36 |
+
We recommend to create a virtual environment. For example:
|
37 |
+
|
38 |
+
```
|
39 |
+
python3 -m venv .venv
|
40 |
+
. .venv/bin/activate
|
41 |
+
```
|
42 |
+
|
43 |
+
Type the following command once the virtual environment is activated:
|
44 |
+
|
45 |
+
```
|
46 |
+
git clone [email protected]:CMD-TRL/mhg-gnn.git
|
47 |
+
cd ./mhg-gnn
|
48 |
+
pip install .
|
49 |
+
```
|
50 |
+
|
51 |
+
## Feature Extraction
|
52 |
+
|
53 |
+
The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
|
54 |
+
|
55 |
+
To load mhg-gnn, you can simply use:
|
56 |
+
|
57 |
+
```python
|
58 |
+
import torch
|
59 |
+
import load
|
60 |
+
|
61 |
+
model = load.load()
|
62 |
+
```
|
63 |
+
|
64 |
+
To encode SMILES into embeddings, you can use:
|
65 |
+
|
66 |
+
```python
|
67 |
+
with torch.no_grad():
|
68 |
+
repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
|
69 |
+
```
|
70 |
+
|
71 |
+
For decoder, you can use the function, so you can return from embeddings to SMILES strings:
|
72 |
+
|
73 |
+
```python
|
74 |
+
orig = model.decode(repr)
|
75 |
+
```
|
models/mhg_model/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# Rhizome
|
3 |
+
# Version beta 0.0, August 2023
|
4 |
+
# Property of IBM Research, Accelerated Discovery
|
5 |
+
#
|
models/mhg_model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (224 Bytes). View file
|
|
models/mhg_model/__pycache__/load.cpython-310.pyc
ADDED
Binary file (3.16 kB). View file
|
|
models/mhg_model/graph_grammar/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
"""
|
8 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
9 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
10 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
11 |
+
"""
|
12 |
+
|
13 |
+
""" Title """
|
14 |
+
|
15 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
16 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
17 |
+
__version__ = "0.1"
|
18 |
+
__date__ = "Jan 1 2018"
|
19 |
+
|
models/mhg_model/graph_grammar/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (676 Bytes). View file
|
|
models/mhg_model/graph_grammar/__pycache__/hypergraph.cpython-310.pyc
ADDED
Binary file (15.3 kB). View file
|
|
models/mhg_model/graph_grammar/algo/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
models/mhg_model/graph_grammar/algo/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (681 Bytes). View file
|
|
models/mhg_model/graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc
ADDED
Binary file (19.5 kB). View file
|
|
models/mhg_model/graph_grammar/algo/tree_decomposition.py
ADDED
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Dec 11 2017"
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from itertools import combinations
|
23 |
+
from ..hypergraph import Hypergraph
|
24 |
+
import networkx as nx
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
|
28 |
+
class CliqueTree(nx.Graph):
|
29 |
+
''' clique tree object
|
30 |
+
|
31 |
+
Attributes
|
32 |
+
----------
|
33 |
+
hg : Hypergraph
|
34 |
+
This hypergraph will be decomposed.
|
35 |
+
root_hg : Hypergraph
|
36 |
+
Hypergraph on the root node.
|
37 |
+
ident_node_dict : dict
|
38 |
+
ident_node_dict[key_node] gives a list of nodes that are identical (i.e., the adjacent hyperedges are common)
|
39 |
+
'''
|
40 |
+
def __init__(self, hg=None, **kwargs):
|
41 |
+
self.hg = deepcopy(hg)
|
42 |
+
if self.hg is not None:
|
43 |
+
self.ident_node_dict = self.hg.get_identical_node_dict()
|
44 |
+
else:
|
45 |
+
self.ident_node_dict = {}
|
46 |
+
super().__init__(**kwargs)
|
47 |
+
|
48 |
+
@property
|
49 |
+
def root_hg(self):
|
50 |
+
''' return the hypergraph on the root node
|
51 |
+
'''
|
52 |
+
return self.nodes[0]['subhg']
|
53 |
+
|
54 |
+
@root_hg.setter
|
55 |
+
def root_hg(self, hypergraph):
|
56 |
+
''' set the hypergraph on the root node
|
57 |
+
'''
|
58 |
+
self.nodes[0]['subhg'] = hypergraph
|
59 |
+
|
60 |
+
def insert_subhg(self, subhypergraph: Hypergraph) -> None:
|
61 |
+
''' insert a subhypergraph, which is extracted from a root hypergraph, into the tree.
|
62 |
+
|
63 |
+
Parameters
|
64 |
+
----------
|
65 |
+
subhg : Hypergraph
|
66 |
+
'''
|
67 |
+
num_nodes = self.number_of_nodes()
|
68 |
+
self.add_node(num_nodes, subhg=subhypergraph)
|
69 |
+
self.add_edge(num_nodes, 0)
|
70 |
+
adj_nodes = deepcopy(list(self.adj[0].keys()))
|
71 |
+
for each_node in adj_nodes:
|
72 |
+
if len(self.nodes[each_node]["subhg"].nodes.intersection(
|
73 |
+
self.nodes[num_nodes]["subhg"].nodes)\
|
74 |
+
- self.root_hg.nodes) != 0 and each_node != num_nodes:
|
75 |
+
self.remove_edge(0, each_node)
|
76 |
+
self.add_edge(each_node, num_nodes)
|
77 |
+
|
78 |
+
def to_irredundant(self) -> None:
|
79 |
+
''' convert the clique tree to be irredundant
|
80 |
+
'''
|
81 |
+
for each_node in self.hg.nodes:
|
82 |
+
subtree = self.subgraph([
|
83 |
+
each_tree_node for each_tree_node in self.nodes()\
|
84 |
+
if each_node in self.nodes[each_tree_node]["subhg"].nodes]).copy()
|
85 |
+
leaf_node_list = [x for x in subtree.nodes() if subtree.degree(x)==1]
|
86 |
+
redundant_leaf_node_list = []
|
87 |
+
for each_leaf_node in leaf_node_list:
|
88 |
+
if len(self.nodes[each_leaf_node]["subhg"].adj_edges(each_node)) == 0:
|
89 |
+
redundant_leaf_node_list.append(each_leaf_node)
|
90 |
+
for each_red_leaf_node in redundant_leaf_node_list:
|
91 |
+
current_node = each_red_leaf_node
|
92 |
+
while subtree.degree(current_node) == 1 \
|
93 |
+
and len(subtree.nodes[current_node]["subhg"].adj_edges(each_node)) == 0:
|
94 |
+
self.nodes[current_node]["subhg"].remove_node(each_node)
|
95 |
+
remove_node = current_node
|
96 |
+
current_node = list(dict(subtree[remove_node]).keys())[0]
|
97 |
+
subtree.remove_node(remove_node)
|
98 |
+
|
99 |
+
fixed_node_set = deepcopy(self.nodes)
|
100 |
+
for each_node in fixed_node_set:
|
101 |
+
if self.nodes[each_node]["subhg"].num_edges == 0:
|
102 |
+
if len(self[each_node]) == 1:
|
103 |
+
self.remove_node(each_node)
|
104 |
+
elif len(self[each_node]) == 2:
|
105 |
+
self.add_edge(*self[each_node])
|
106 |
+
self.remove_node(each_node)
|
107 |
+
else:
|
108 |
+
pass
|
109 |
+
else:
|
110 |
+
pass
|
111 |
+
|
112 |
+
redundant = True
|
113 |
+
while redundant:
|
114 |
+
redundant = False
|
115 |
+
fixed_edge_set = deepcopy(self.edges)
|
116 |
+
remove_node_set = set()
|
117 |
+
for node_1, node_2 in fixed_edge_set:
|
118 |
+
if node_1 in remove_node_set or node_2 in remove_node_set:
|
119 |
+
pass
|
120 |
+
else:
|
121 |
+
if self.nodes[node_1]['subhg'].is_subhg(self.nodes[node_2]['subhg']):
|
122 |
+
redundant = True
|
123 |
+
adj_node_list = set(self.adj[node_1]) - {node_2}
|
124 |
+
self.remove_node(node_1)
|
125 |
+
remove_node_set.add(node_1)
|
126 |
+
for each_node in adj_node_list:
|
127 |
+
self.add_edge(node_2, each_node)
|
128 |
+
|
129 |
+
elif self.nodes[node_2]['subhg'].is_subhg(self.nodes[node_1]['subhg']):
|
130 |
+
redundant = True
|
131 |
+
adj_node_list = set(self.adj[node_2]) - {node_1}
|
132 |
+
self.remove_node(node_2)
|
133 |
+
remove_node_set.add(node_2)
|
134 |
+
for each_node in adj_node_list:
|
135 |
+
self.add_edge(node_1, each_node)
|
136 |
+
|
137 |
+
def node_update(self, key_node: str, subhg) -> None:
|
138 |
+
""" given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
|
139 |
+
|
140 |
+
Parameters
|
141 |
+
----------
|
142 |
+
key_node : str
|
143 |
+
key node that must be removed.
|
144 |
+
subhg : Hypegraph
|
145 |
+
"""
|
146 |
+
for each_edge in subhg.edges:
|
147 |
+
self.root_hg.remove_edge(each_edge)
|
148 |
+
self.root_hg.remove_nodes(self.ident_node_dict[key_node])
|
149 |
+
|
150 |
+
adj_node_list = list(subhg.nodes)
|
151 |
+
for each_node in subhg.nodes:
|
152 |
+
if each_node not in self.ident_node_dict[key_node]:
|
153 |
+
if set(self.root_hg.adj_edges(each_node)).issubset(subhg.edges):
|
154 |
+
self.root_hg.remove_node(each_node)
|
155 |
+
adj_node_list.remove(each_node)
|
156 |
+
else:
|
157 |
+
adj_node_list.remove(each_node)
|
158 |
+
|
159 |
+
for each_node_1, each_node_2 in combinations(adj_node_list, 2):
|
160 |
+
if not self.root_hg.is_adj(each_node_1, each_node_2):
|
161 |
+
self.root_hg.add_edge(set([each_node_1, each_node_2]), attr_dict=dict(tmp=True))
|
162 |
+
|
163 |
+
subhg.remove_edges_with_attr({'tmp' : True})
|
164 |
+
self.insert_subhg(subhg)
|
165 |
+
|
166 |
+
def update(self, subhg, remove_nodes=False):
|
167 |
+
""" given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
|
168 |
+
|
169 |
+
Parameters
|
170 |
+
----------
|
171 |
+
subhg : Hypegraph
|
172 |
+
"""
|
173 |
+
for each_edge in subhg.edges:
|
174 |
+
self.root_hg.remove_edge(each_edge)
|
175 |
+
if remove_nodes:
|
176 |
+
remove_edge_list = []
|
177 |
+
for each_edge in self.root_hg.edges:
|
178 |
+
if set(self.root_hg.nodes_in_edge(each_edge)).issubset(subhg.nodes)\
|
179 |
+
and self.root_hg.edge_attr(each_edge).get('tmp', False):
|
180 |
+
remove_edge_list.append(each_edge)
|
181 |
+
self.root_hg.remove_edges(remove_edge_list)
|
182 |
+
|
183 |
+
adj_node_list = list(subhg.nodes)
|
184 |
+
for each_node in subhg.nodes:
|
185 |
+
if self.root_hg.degree(each_node) == 0:
|
186 |
+
self.root_hg.remove_node(each_node)
|
187 |
+
adj_node_list.remove(each_node)
|
188 |
+
|
189 |
+
if len(adj_node_list) != 1 and not remove_nodes:
|
190 |
+
self.root_hg.add_edge(set(adj_node_list), attr_dict=dict(tmp=True))
|
191 |
+
'''
|
192 |
+
else:
|
193 |
+
for each_node_1, each_node_2 in combinations(adj_node_list, 2):
|
194 |
+
if not self.root_hg.is_adj(each_node_1, each_node_2):
|
195 |
+
self.root_hg.add_edge(
|
196 |
+
[each_node_1, each_node_2], attr_dict=dict(tmp=True))
|
197 |
+
'''
|
198 |
+
subhg.remove_edges_with_attr({'tmp':True})
|
199 |
+
self.insert_subhg(subhg)
|
200 |
+
|
201 |
+
|
202 |
+
def _get_min_deg_node(hg, ident_node_dict: dict, mode='mol'):
|
203 |
+
if mode == 'standard':
|
204 |
+
degree_dict = hg.degrees()
|
205 |
+
min_deg_node = min(degree_dict, key=degree_dict.get)
|
206 |
+
min_deg_subhg = hg.adj_subhg(min_deg_node, ident_node_dict)
|
207 |
+
return min_deg_node, min_deg_subhg
|
208 |
+
elif mode == 'mol':
|
209 |
+
degree_dict = hg.degrees()
|
210 |
+
min_deg = min(degree_dict.values())
|
211 |
+
min_deg_node_list = [each_node for each_node in hg.nodes if degree_dict[each_node]==min_deg]
|
212 |
+
min_deg_subhg_list = [hg.adj_subhg(each_min_deg_node, ident_node_dict)
|
213 |
+
for each_min_deg_node in min_deg_node_list]
|
214 |
+
best_score = np.inf
|
215 |
+
best_idx = -1
|
216 |
+
for each_idx in range(len(min_deg_subhg_list)):
|
217 |
+
if min_deg_subhg_list[each_idx].num_nodes < best_score:
|
218 |
+
best_idx = each_idx
|
219 |
+
return min_deg_node_list[each_idx], min_deg_subhg_list[each_idx]
|
220 |
+
else:
|
221 |
+
raise ValueError
|
222 |
+
|
223 |
+
|
224 |
+
def tree_decomposition(hg, irredundant=True):
|
225 |
+
""" compute a tree decomposition of the input hypergraph
|
226 |
+
|
227 |
+
Parameters
|
228 |
+
----------
|
229 |
+
hg : Hypergraph
|
230 |
+
hypergraph to be decomposed
|
231 |
+
irredundant : bool
|
232 |
+
if True, irredundant tree decomposition will be computed.
|
233 |
+
|
234 |
+
Returns
|
235 |
+
-------
|
236 |
+
clique_tree : nx.Graph
|
237 |
+
each node contains a subhypergraph of `hg`
|
238 |
+
"""
|
239 |
+
org_hg = hg.copy()
|
240 |
+
ident_node_dict = hg.get_identical_node_dict()
|
241 |
+
clique_tree = CliqueTree(org_hg)
|
242 |
+
clique_tree.add_node(0, subhg=org_hg)
|
243 |
+
while True:
|
244 |
+
degree_dict = org_hg.degrees()
|
245 |
+
min_deg_node = min(degree_dict, key=degree_dict.get)
|
246 |
+
min_deg_subhg = org_hg.adj_subhg(min_deg_node, ident_node_dict)
|
247 |
+
if org_hg.nodes == min_deg_subhg.nodes:
|
248 |
+
break
|
249 |
+
|
250 |
+
# org_hg and min_deg_subhg are divided
|
251 |
+
clique_tree.node_update(min_deg_node, min_deg_subhg)
|
252 |
+
|
253 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
254 |
+
|
255 |
+
if irredundant:
|
256 |
+
clique_tree.to_irredundant()
|
257 |
+
return clique_tree
|
258 |
+
|
259 |
+
|
260 |
+
def tree_decomposition_with_hrg(hg, hrg, irredundant=True, return_root=False):
|
261 |
+
''' compute a tree decomposition given a hyperedge replacement grammar.
|
262 |
+
the resultant clique tree should induce a less compact HRG.
|
263 |
+
|
264 |
+
Parameters
|
265 |
+
----------
|
266 |
+
hg : Hypergraph
|
267 |
+
hypergraph to be decomposed
|
268 |
+
hrg : HyperedgeReplacementGrammar
|
269 |
+
current HRG
|
270 |
+
irredundant : bool
|
271 |
+
if True, irredundant tree decomposition will be computed.
|
272 |
+
|
273 |
+
Returns
|
274 |
+
-------
|
275 |
+
clique_tree : nx.Graph
|
276 |
+
each node contains a subhypergraph of `hg`
|
277 |
+
'''
|
278 |
+
org_hg = hg.copy()
|
279 |
+
ident_node_dict = hg.get_identical_node_dict()
|
280 |
+
clique_tree = CliqueTree(org_hg)
|
281 |
+
clique_tree.add_node(0, subhg=org_hg)
|
282 |
+
root_node = 0
|
283 |
+
|
284 |
+
# construct a clique tree using HRG
|
285 |
+
success_any = True
|
286 |
+
while success_any:
|
287 |
+
success_any = False
|
288 |
+
for each_prod_rule in hrg.prod_rule_list:
|
289 |
+
org_hg, success, subhg = each_prod_rule.revert(org_hg, True)
|
290 |
+
if success:
|
291 |
+
if each_prod_rule.is_start_rule: root_node = clique_tree.number_of_nodes()
|
292 |
+
success_any = True
|
293 |
+
subhg.remove_edges_with_attr({'terminal' : False})
|
294 |
+
clique_tree.root_hg = org_hg
|
295 |
+
clique_tree.insert_subhg(subhg)
|
296 |
+
|
297 |
+
clique_tree.root_hg = org_hg
|
298 |
+
|
299 |
+
for each_edge in deepcopy(org_hg.edges):
|
300 |
+
if not org_hg.edge_attr(each_edge)['terminal']:
|
301 |
+
node_list = org_hg.nodes_in_edge(each_edge)
|
302 |
+
org_hg.remove_edge(each_edge)
|
303 |
+
|
304 |
+
for each_node_1, each_node_2 in combinations(node_list, 2):
|
305 |
+
if not org_hg.is_adj(each_node_1, each_node_2):
|
306 |
+
org_hg.add_edge([each_node_1, each_node_2], attr_dict=dict(tmp=True))
|
307 |
+
|
308 |
+
# construct a clique tree using the existing algorithm
|
309 |
+
degree_dict = org_hg.degrees()
|
310 |
+
if degree_dict:
|
311 |
+
while True:
|
312 |
+
min_deg_node, min_deg_subhg = _get_min_deg_node(org_hg, ident_node_dict)
|
313 |
+
if org_hg.nodes == min_deg_subhg.nodes: break
|
314 |
+
|
315 |
+
# org_hg and min_deg_subhg are divided
|
316 |
+
clique_tree.node_update(min_deg_node, min_deg_subhg)
|
317 |
+
|
318 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
319 |
+
if irredundant:
|
320 |
+
clique_tree.to_irredundant()
|
321 |
+
|
322 |
+
if return_root:
|
323 |
+
if root_node == 0 and 0 not in clique_tree.nodes:
|
324 |
+
root_node = clique_tree.number_of_nodes()
|
325 |
+
while root_node not in clique_tree.nodes:
|
326 |
+
root_node -= 1
|
327 |
+
elif root_node not in clique_tree.nodes:
|
328 |
+
while root_node not in clique_tree.nodes:
|
329 |
+
root_node -= 1
|
330 |
+
else:
|
331 |
+
pass
|
332 |
+
return clique_tree, root_node
|
333 |
+
else:
|
334 |
+
return clique_tree
|
335 |
+
|
336 |
+
|
337 |
+
def tree_decomposition_from_leaf(hg, irredundant=True):
|
338 |
+
""" compute a tree decomposition of the input hypergraph
|
339 |
+
|
340 |
+
Parameters
|
341 |
+
----------
|
342 |
+
hg : Hypergraph
|
343 |
+
hypergraph to be decomposed
|
344 |
+
irredundant : bool
|
345 |
+
if True, irredundant tree decomposition will be computed.
|
346 |
+
|
347 |
+
Returns
|
348 |
+
-------
|
349 |
+
clique_tree : nx.Graph
|
350 |
+
each node contains a subhypergraph of `hg`
|
351 |
+
"""
|
352 |
+
def apply_normal_decomposition(clique_tree):
|
353 |
+
degree_dict = clique_tree.root_hg.degrees()
|
354 |
+
min_deg_node = min(degree_dict, key=degree_dict.get)
|
355 |
+
min_deg_subhg = clique_tree.root_hg.adj_subhg(min_deg_node, clique_tree.ident_node_dict)
|
356 |
+
if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
|
357 |
+
return clique_tree, False
|
358 |
+
clique_tree.node_update(min_deg_node, min_deg_subhg)
|
359 |
+
return clique_tree, True
|
360 |
+
|
361 |
+
def apply_min_edge_deg_decomposition(clique_tree):
|
362 |
+
edge_degree_dict = clique_tree.root_hg.edge_degrees()
|
363 |
+
non_tmp_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
|
364 |
+
if not clique_tree.root_hg.edge_attr(each_edge).get('tmp')]
|
365 |
+
if not non_tmp_edge_list:
|
366 |
+
return clique_tree, False
|
367 |
+
min_deg_edge = None
|
368 |
+
min_deg = np.inf
|
369 |
+
for each_edge in non_tmp_edge_list:
|
370 |
+
if min_deg > edge_degree_dict[each_edge]:
|
371 |
+
min_deg_edge = each_edge
|
372 |
+
min_deg = edge_degree_dict[each_edge]
|
373 |
+
node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
|
374 |
+
min_deg_subhg = clique_tree.root_hg.get_subhg(
|
375 |
+
node_list, [min_deg_edge], clique_tree.ident_node_dict)
|
376 |
+
if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
|
377 |
+
return clique_tree, False
|
378 |
+
clique_tree.update(min_deg_subhg)
|
379 |
+
return clique_tree, True
|
380 |
+
|
381 |
+
org_hg = hg.copy()
|
382 |
+
clique_tree = CliqueTree(org_hg)
|
383 |
+
clique_tree.add_node(0, subhg=org_hg)
|
384 |
+
|
385 |
+
success = True
|
386 |
+
while success:
|
387 |
+
clique_tree, success = apply_min_edge_deg_decomposition(clique_tree)
|
388 |
+
if not success:
|
389 |
+
clique_tree, success = apply_normal_decomposition(clique_tree)
|
390 |
+
|
391 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
392 |
+
if irredundant:
|
393 |
+
clique_tree.to_irredundant()
|
394 |
+
return clique_tree
|
395 |
+
|
396 |
+
def topological_tree_decomposition(
|
397 |
+
hg, irredundant=True, rip_labels=True, shrink_cycle=False, contract_cycles=False):
|
398 |
+
''' compute a tree decomposition of the input hypergraph
|
399 |
+
|
400 |
+
Parameters
|
401 |
+
----------
|
402 |
+
hg : Hypergraph
|
403 |
+
hypergraph to be decomposed
|
404 |
+
irredundant : bool
|
405 |
+
if True, irredundant tree decomposition will be computed.
|
406 |
+
|
407 |
+
Returns
|
408 |
+
-------
|
409 |
+
clique_tree : CliqueTree
|
410 |
+
each node contains a subhypergraph of `hg`
|
411 |
+
'''
|
412 |
+
def _contract_tree(clique_tree):
|
413 |
+
''' contract a single leaf
|
414 |
+
|
415 |
+
Parameters
|
416 |
+
----------
|
417 |
+
clique_tree : CliqueTree
|
418 |
+
|
419 |
+
Returns
|
420 |
+
-------
|
421 |
+
CliqueTree, bool
|
422 |
+
bool represents whether this operation succeeds or not.
|
423 |
+
'''
|
424 |
+
edge_degree_dict = clique_tree.root_hg.edge_degrees()
|
425 |
+
leaf_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
|
426 |
+
if (not clique_tree.root_hg.edge_attr(each_edge).get('tmp'))\
|
427 |
+
and edge_degree_dict[each_edge] == 1]
|
428 |
+
if not leaf_edge_list:
|
429 |
+
return clique_tree, False
|
430 |
+
min_deg_edge = leaf_edge_list[0]
|
431 |
+
node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
|
432 |
+
min_deg_subhg = clique_tree.root_hg.get_subhg(
|
433 |
+
node_list, [min_deg_edge], clique_tree.ident_node_dict)
|
434 |
+
if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
|
435 |
+
return clique_tree, False
|
436 |
+
clique_tree.update(min_deg_subhg)
|
437 |
+
return clique_tree, True
|
438 |
+
|
439 |
+
def _rip_labels_from_cycles(clique_tree, org_hg):
|
440 |
+
''' rip hyperedge-labels off
|
441 |
+
|
442 |
+
Parameters
|
443 |
+
----------
|
444 |
+
clique_tree : CliqueTree
|
445 |
+
org_hg : Hypergraph
|
446 |
+
|
447 |
+
Returns
|
448 |
+
-------
|
449 |
+
CliqueTree, bool
|
450 |
+
bool represents whether this operation succeeds or not.
|
451 |
+
'''
|
452 |
+
ident_node_dict = clique_tree.ident_node_dict #hg.get_identical_node_dict()
|
453 |
+
for each_edge in clique_tree.root_hg.edges:
|
454 |
+
if each_edge in org_hg.edges:
|
455 |
+
if org_hg.in_cycle(each_edge):
|
456 |
+
node_list = clique_tree.root_hg.nodes_in_edge(each_edge)
|
457 |
+
subhg = clique_tree.root_hg.get_subhg(
|
458 |
+
node_list, [each_edge], ident_node_dict)
|
459 |
+
if clique_tree.root_hg.nodes == subhg.nodes:
|
460 |
+
return clique_tree, False
|
461 |
+
clique_tree.update(subhg)
|
462 |
+
'''
|
463 |
+
in_cycle_dict = {each_node: org_hg.node_attr(each_node)['is_in_ring'] for each_node in node_list}
|
464 |
+
if not all(in_cycle_dict.values()):
|
465 |
+
node_not_in_cycle = [each_node for each_node in in_cycle_dict.keys() if not in_cycle_dict[each_node]][0]
|
466 |
+
node_list = [node_not_in_cycle]
|
467 |
+
node_list.extend(clique_tree.root_hg.adj_nodes(node_not_in_cycle))
|
468 |
+
edge_list = clique_tree.root_hg.adj_edges(node_not_in_cycle)
|
469 |
+
import pdb; pdb.set_trace()
|
470 |
+
subhg = clique_tree.root_hg.get_subhg(
|
471 |
+
node_list, edge_list, ident_node_dict)
|
472 |
+
|
473 |
+
clique_tree.update(subhg)
|
474 |
+
'''
|
475 |
+
return clique_tree, True
|
476 |
+
return clique_tree, False
|
477 |
+
|
478 |
+
def _shrink_cycle(clique_tree):
|
479 |
+
''' shrink a cycle
|
480 |
+
|
481 |
+
Parameters
|
482 |
+
----------
|
483 |
+
clique_tree : CliqueTree
|
484 |
+
|
485 |
+
Returns
|
486 |
+
-------
|
487 |
+
CliqueTree, bool
|
488 |
+
bool represents whether this operation succeeds or not.
|
489 |
+
'''
|
490 |
+
def filter_subhg(subhg, hg, key_node):
|
491 |
+
num_nodes_cycle = 0
|
492 |
+
nodes_in_cycle_list = []
|
493 |
+
for each_node in subhg.nodes:
|
494 |
+
if hg.in_cycle(each_node):
|
495 |
+
num_nodes_cycle += 1
|
496 |
+
if each_node != key_node:
|
497 |
+
nodes_in_cycle_list.append(each_node)
|
498 |
+
if num_nodes_cycle > 3:
|
499 |
+
break
|
500 |
+
if num_nodes_cycle != 3:
|
501 |
+
return False
|
502 |
+
else:
|
503 |
+
for each_edge in hg.edges:
|
504 |
+
if set(nodes_in_cycle_list).issubset(hg.nodes_in_edge(each_edge)):
|
505 |
+
return False
|
506 |
+
return True
|
507 |
+
|
508 |
+
#ident_node_dict = hg.get_identical_node_dict()
|
509 |
+
ident_node_dict = clique_tree.ident_node_dict
|
510 |
+
for each_node in clique_tree.root_hg.nodes:
|
511 |
+
if clique_tree.root_hg.in_cycle(each_node)\
|
512 |
+
and filter_subhg(clique_tree.root_hg.adj_subhg(each_node, ident_node_dict),
|
513 |
+
clique_tree.root_hg,
|
514 |
+
each_node):
|
515 |
+
target_node = each_node
|
516 |
+
target_subhg = clique_tree.root_hg.adj_subhg(target_node, ident_node_dict)
|
517 |
+
if clique_tree.root_hg.nodes == target_subhg.nodes:
|
518 |
+
return clique_tree, False
|
519 |
+
clique_tree.update(target_subhg)
|
520 |
+
return clique_tree, True
|
521 |
+
return clique_tree, False
|
522 |
+
|
523 |
+
def _contract_cycles(clique_tree):
|
524 |
+
'''
|
525 |
+
remove a subhypergraph that looks like a cycle on a leaf.
|
526 |
+
|
527 |
+
Parameters
|
528 |
+
----------
|
529 |
+
clique_tree : CliqueTree
|
530 |
+
|
531 |
+
Returns
|
532 |
+
-------
|
533 |
+
CliqueTree, bool
|
534 |
+
bool represents whether this operation succeeds or not.
|
535 |
+
'''
|
536 |
+
def _divide_hg(hg):
|
537 |
+
''' divide a hypergraph into subhypergraphs such that
|
538 |
+
each subhypergraph is connected to each other in a tree-like way.
|
539 |
+
|
540 |
+
Parameters
|
541 |
+
----------
|
542 |
+
hg : Hypergraph
|
543 |
+
|
544 |
+
Returns
|
545 |
+
-------
|
546 |
+
list of Hypergraphs
|
547 |
+
each element corresponds to a subhypergraph of `hg`
|
548 |
+
'''
|
549 |
+
for each_node in hg.nodes:
|
550 |
+
if hg.is_dividable(each_node):
|
551 |
+
adj_edges_dict = {each_edge: hg.in_cycle(each_edge) for each_edge in hg.adj_edges(each_node)}
|
552 |
+
'''
|
553 |
+
if any(adj_edges_dict.values()):
|
554 |
+
import pdb; pdb.set_trace()
|
555 |
+
edge_in_cycle = [each_key for each_key, each_val in adj_edges_dict.items() if each_val][0]
|
556 |
+
subhg1, subhg2, subhg3 = hg.divide(each_node, edge_in_cycle)
|
557 |
+
return _divide_hg(subhg1) + _divide_hg(subhg2) + _divide_hg(subhg3)
|
558 |
+
else:
|
559 |
+
'''
|
560 |
+
subhg1, subhg2 = hg.divide(each_node)
|
561 |
+
return _divide_hg(subhg1) + _divide_hg(subhg2)
|
562 |
+
return [hg]
|
563 |
+
|
564 |
+
def _is_leaf(hg, divided_subhg) -> bool:
|
565 |
+
''' judge whether subhg is a leaf-like in the original hypergraph
|
566 |
+
|
567 |
+
Parameters
|
568 |
+
----------
|
569 |
+
hg : Hypergraph
|
570 |
+
divided_subhg : Hypergraph
|
571 |
+
`divided_subhg` is a subhypergraph of `hg`
|
572 |
+
|
573 |
+
Returns
|
574 |
+
-------
|
575 |
+
bool
|
576 |
+
'''
|
577 |
+
'''
|
578 |
+
adj_edges_set = set([])
|
579 |
+
for each_node in divided_subhg.nodes:
|
580 |
+
adj_edges_set.update(set(hg.adj_edges(each_node)))
|
581 |
+
|
582 |
+
|
583 |
+
_hg = deepcopy(hg)
|
584 |
+
_hg.remove_subhg(divided_subhg)
|
585 |
+
if nx.is_connected(_hg.hg) != (len(adj_edges_set - divided_subhg.edges) == 1):
|
586 |
+
import pdb; pdb.set_trace()
|
587 |
+
return len(adj_edges_set - divided_subhg.edges) == 1
|
588 |
+
'''
|
589 |
+
_hg = deepcopy(hg)
|
590 |
+
_hg.remove_subhg(divided_subhg)
|
591 |
+
return nx.is_connected(_hg.hg)
|
592 |
+
|
593 |
+
subhg_list = _divide_hg(clique_tree.root_hg)
|
594 |
+
if len(subhg_list) == 1:
|
595 |
+
return clique_tree, False
|
596 |
+
else:
|
597 |
+
while len(subhg_list) > 1:
|
598 |
+
max_leaf_subhg = None
|
599 |
+
for each_subhg in subhg_list:
|
600 |
+
if _is_leaf(clique_tree.root_hg, each_subhg):
|
601 |
+
if max_leaf_subhg is None:
|
602 |
+
max_leaf_subhg = each_subhg
|
603 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
604 |
+
max_leaf_subhg = each_subhg
|
605 |
+
clique_tree.update(max_leaf_subhg)
|
606 |
+
subhg_list.remove(max_leaf_subhg)
|
607 |
+
return clique_tree, True
|
608 |
+
|
609 |
+
org_hg = hg.copy()
|
610 |
+
clique_tree = CliqueTree(org_hg)
|
611 |
+
clique_tree.add_node(0, subhg=org_hg)
|
612 |
+
|
613 |
+
success = True
|
614 |
+
while success:
|
615 |
+
'''
|
616 |
+
clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
|
617 |
+
if not success:
|
618 |
+
clique_tree, success = _contract_cycles(clique_tree)
|
619 |
+
'''
|
620 |
+
clique_tree, success = _contract_tree(clique_tree)
|
621 |
+
if not success:
|
622 |
+
if rip_labels:
|
623 |
+
clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
|
624 |
+
if not success:
|
625 |
+
if shrink_cycle:
|
626 |
+
clique_tree, success = _shrink_cycle(clique_tree)
|
627 |
+
if not success:
|
628 |
+
if contract_cycles:
|
629 |
+
clique_tree, success = _contract_cycles(clique_tree)
|
630 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
631 |
+
if irredundant:
|
632 |
+
clique_tree.to_irredundant()
|
633 |
+
return clique_tree
|
634 |
+
|
635 |
+
def molecular_tree_decomposition(hg, irredundant=True):
|
636 |
+
""" compute a tree decomposition of the input molecular hypergraph
|
637 |
+
|
638 |
+
Parameters
|
639 |
+
----------
|
640 |
+
hg : Hypergraph
|
641 |
+
molecular hypergraph to be decomposed
|
642 |
+
irredundant : bool
|
643 |
+
if True, irredundant tree decomposition will be computed.
|
644 |
+
|
645 |
+
Returns
|
646 |
+
-------
|
647 |
+
clique_tree : CliqueTree
|
648 |
+
each node contains a subhypergraph of `hg`
|
649 |
+
"""
|
650 |
+
def _divide_hg(hg):
|
651 |
+
''' divide a hypergraph into subhypergraphs such that
|
652 |
+
each subhypergraph is connected to each other in a tree-like way.
|
653 |
+
|
654 |
+
Parameters
|
655 |
+
----------
|
656 |
+
hg : Hypergraph
|
657 |
+
|
658 |
+
Returns
|
659 |
+
-------
|
660 |
+
list of Hypergraphs
|
661 |
+
each element corresponds to a subhypergraph of `hg`
|
662 |
+
'''
|
663 |
+
is_ring = False
|
664 |
+
for each_node in hg.nodes:
|
665 |
+
if hg.node_attr(each_node)['is_in_ring']:
|
666 |
+
is_ring = True
|
667 |
+
if not hg.node_attr(each_node)['is_in_ring'] \
|
668 |
+
and hg.degree(each_node) == 2:
|
669 |
+
subhg1, subhg2 = hg.divide(each_node)
|
670 |
+
return _divide_hg(subhg1) + _divide_hg(subhg2)
|
671 |
+
|
672 |
+
if is_ring:
|
673 |
+
subhg_list = []
|
674 |
+
remove_edge_list = []
|
675 |
+
remove_node_list = []
|
676 |
+
for each_edge in hg.edges:
|
677 |
+
node_list = hg.nodes_in_edge(each_edge)
|
678 |
+
subhg = hg.get_subhg(node_list, [each_edge], hg.get_identical_node_dict())
|
679 |
+
subhg_list.append(subhg)
|
680 |
+
remove_edge_list.append(each_edge)
|
681 |
+
for each_node in node_list:
|
682 |
+
if not hg.node_attr(each_node)['is_in_ring']:
|
683 |
+
remove_node_list.append(each_node)
|
684 |
+
hg.remove_edges(remove_edge_list)
|
685 |
+
hg.remove_nodes(remove_node_list, False)
|
686 |
+
return subhg_list + [hg]
|
687 |
+
else:
|
688 |
+
return [hg]
|
689 |
+
|
690 |
+
org_hg = hg.copy()
|
691 |
+
clique_tree = CliqueTree(org_hg)
|
692 |
+
clique_tree.add_node(0, subhg=org_hg)
|
693 |
+
|
694 |
+
subhg_list = _divide_hg(deepcopy(clique_tree.root_hg))
|
695 |
+
#_subhg_list = deepcopy(subhg_list)
|
696 |
+
if len(subhg_list) == 1:
|
697 |
+
pass
|
698 |
+
else:
|
699 |
+
while len(subhg_list) > 1:
|
700 |
+
max_leaf_subhg = None
|
701 |
+
for each_subhg in subhg_list:
|
702 |
+
if _is_leaf(clique_tree.root_hg, each_subhg) and not _is_ring(each_subhg):
|
703 |
+
if max_leaf_subhg is None:
|
704 |
+
max_leaf_subhg = each_subhg
|
705 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
706 |
+
max_leaf_subhg = each_subhg
|
707 |
+
|
708 |
+
if max_leaf_subhg is None:
|
709 |
+
for each_subhg in subhg_list:
|
710 |
+
if _is_ring_label(clique_tree.root_hg, each_subhg):
|
711 |
+
if max_leaf_subhg is None:
|
712 |
+
max_leaf_subhg = each_subhg
|
713 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
714 |
+
max_leaf_subhg = each_subhg
|
715 |
+
if max_leaf_subhg is not None:
|
716 |
+
clique_tree.update(max_leaf_subhg)
|
717 |
+
subhg_list.remove(max_leaf_subhg)
|
718 |
+
else:
|
719 |
+
for each_subhg in subhg_list:
|
720 |
+
if _is_leaf(clique_tree.root_hg, each_subhg):
|
721 |
+
if max_leaf_subhg is None:
|
722 |
+
max_leaf_subhg = each_subhg
|
723 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
724 |
+
max_leaf_subhg = each_subhg
|
725 |
+
if max_leaf_subhg is not None:
|
726 |
+
clique_tree.update(max_leaf_subhg, True)
|
727 |
+
subhg_list.remove(max_leaf_subhg)
|
728 |
+
else:
|
729 |
+
break
|
730 |
+
if len(subhg_list) > 1:
|
731 |
+
'''
|
732 |
+
for each_idx, each_subhg in enumerate(subhg_list):
|
733 |
+
each_subhg.draw(f'{each_idx}', True)
|
734 |
+
clique_tree.root_hg.draw('root', True)
|
735 |
+
import pickle
|
736 |
+
with open('buggy_hg.pkl', 'wb') as f:
|
737 |
+
pickle.dump(hg, f)
|
738 |
+
return clique_tree, subhg_list, _subhg_list
|
739 |
+
'''
|
740 |
+
raise RuntimeError('bug in tree decomposition algorithm')
|
741 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
742 |
+
|
743 |
+
'''
|
744 |
+
for each_tree_node in clique_tree.adj[0]:
|
745 |
+
subhg = clique_tree.nodes[each_tree_node]['subhg']
|
746 |
+
for each_edge in subhg.edges:
|
747 |
+
if set(subhg.nodes_in_edge(each_edge)).issubset(clique_tree.root_hg.nodes):
|
748 |
+
clique_tree.root_hg.add_edge(set(subhg.nodes_in_edge(each_edge)), attr_dict=dict(tmp=True))
|
749 |
+
'''
|
750 |
+
if irredundant:
|
751 |
+
clique_tree.to_irredundant()
|
752 |
+
return clique_tree #, _subhg_list
|
753 |
+
|
754 |
+
def _is_leaf(hg, subhg) -> bool:
|
755 |
+
''' judge whether subhg is a leaf-like in the original hypergraph
|
756 |
+
|
757 |
+
Parameters
|
758 |
+
----------
|
759 |
+
hg : Hypergraph
|
760 |
+
subhg : Hypergraph
|
761 |
+
`subhg` is a subhypergraph of `hg`
|
762 |
+
|
763 |
+
Returns
|
764 |
+
-------
|
765 |
+
bool
|
766 |
+
'''
|
767 |
+
if len(subhg.edges) == 0:
|
768 |
+
adj_edge_set = set([])
|
769 |
+
subhg_edge_set = set([])
|
770 |
+
for each_edge in hg.edges:
|
771 |
+
if set(hg.nodes_in_edge(each_edge)).issubset(subhg.nodes) and hg.edge_attr(each_edge).get('tmp', False):
|
772 |
+
subhg_edge_set.add(each_edge)
|
773 |
+
for each_node in subhg.nodes:
|
774 |
+
adj_edge_set.update(set(hg.adj_edges(each_node)))
|
775 |
+
if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
|
776 |
+
return True
|
777 |
+
else:
|
778 |
+
return False
|
779 |
+
elif len(subhg.edges) == 1:
|
780 |
+
adj_edge_set = set([])
|
781 |
+
subhg_edge_set = subhg.edges
|
782 |
+
for each_node in subhg.nodes:
|
783 |
+
for each_adj_edge in hg.adj_edges(each_node):
|
784 |
+
adj_edge_set.add(each_adj_edge)
|
785 |
+
if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
|
786 |
+
return True
|
787 |
+
else:
|
788 |
+
return False
|
789 |
+
else:
|
790 |
+
raise ValueError('subhg should be nodes only or one-edge hypergraph.')
|
791 |
+
|
792 |
+
def _is_ring_label(hg, subhg):
|
793 |
+
if len(subhg.edges) != 1:
|
794 |
+
return False
|
795 |
+
edge_name = list(subhg.edges)[0]
|
796 |
+
#assert edge_name in hg.edges, f'{edge_name}'
|
797 |
+
is_in_ring = False
|
798 |
+
for each_node in subhg.nodes:
|
799 |
+
if subhg.node_attr(each_node)['is_in_ring']:
|
800 |
+
is_in_ring = True
|
801 |
+
else:
|
802 |
+
adj_edge_list = list(hg.adj_edges(each_node))
|
803 |
+
adj_edge_list.remove(edge_name)
|
804 |
+
if len(adj_edge_list) == 1:
|
805 |
+
if not hg.edge_attr(adj_edge_list[0]).get('tmp', False):
|
806 |
+
return False
|
807 |
+
elif len(adj_edge_list) == 0:
|
808 |
+
pass
|
809 |
+
else:
|
810 |
+
raise ValueError
|
811 |
+
if is_in_ring:
|
812 |
+
return True
|
813 |
+
else:
|
814 |
+
return False
|
815 |
+
|
816 |
+
def _is_ring(hg):
|
817 |
+
for each_node in hg.nodes:
|
818 |
+
if not hg.node_attr(each_node)['is_in_ring']:
|
819 |
+
return False
|
820 |
+
return True
|
821 |
+
|
models/mhg_model/graph_grammar/graph_grammar/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
models/mhg_model/graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (690 Bytes). View file
|
|
models/mhg_model/graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc
ADDED
Binary file (1.19 kB). View file
|
|
models/mhg_model/graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc
ADDED
Binary file (4.73 kB). View file
|
|
models/mhg_model/graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc
ADDED
Binary file (29.1 kB). View file
|
|
models/mhg_model/graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc
ADDED
Binary file (5.39 kB). View file
|
|
models/mhg_model/graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.65 kB). View file
|
|
models/mhg_model/graph_grammar/graph_grammar/base.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Dec 11 2017"
|
20 |
+
|
21 |
+
from abc import ABCMeta, abstractmethod
|
22 |
+
|
23 |
+
class GraphGrammarBase(metaclass=ABCMeta):
|
24 |
+
@abstractmethod
|
25 |
+
def learn(self):
|
26 |
+
pass
|
27 |
+
|
28 |
+
@abstractmethod
|
29 |
+
def sample(self):
|
30 |
+
pass
|
models/mhg_model/graph_grammar/graph_grammar/corpus.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jun 4 2018"
|
20 |
+
|
21 |
+
from collections import Counter
|
22 |
+
from functools import partial
|
23 |
+
from .utils import _easy_node_match, _edge_match, _node_match, common_node_list, _node_match_prod_rule
|
24 |
+
from networkx.algorithms.isomorphism import GraphMatcher
|
25 |
+
import os
|
26 |
+
|
27 |
+
|
28 |
+
class CliqueTreeCorpus(object):
|
29 |
+
|
30 |
+
''' clique tree corpus
|
31 |
+
|
32 |
+
Attributes
|
33 |
+
----------
|
34 |
+
clique_tree_list : list of CliqueTree
|
35 |
+
subhg_list : list of Hypergraph
|
36 |
+
'''
|
37 |
+
|
38 |
+
def __init__(self):
|
39 |
+
self.clique_tree_list = []
|
40 |
+
self.subhg_list = []
|
41 |
+
|
42 |
+
@property
|
43 |
+
def size(self):
|
44 |
+
return len(self.subhg_list)
|
45 |
+
|
46 |
+
def add_clique_tree(self, clique_tree):
|
47 |
+
for each_node in clique_tree.nodes:
|
48 |
+
subhg = clique_tree.nodes[each_node]['subhg']
|
49 |
+
subhg_idx = self.add_subhg(subhg)
|
50 |
+
clique_tree.nodes[each_node]['subhg_idx'] = subhg_idx
|
51 |
+
self.clique_tree_list.append(clique_tree)
|
52 |
+
|
53 |
+
def add_to_subhg_list(self, clique_tree, root_node):
|
54 |
+
parent_node_dict = {}
|
55 |
+
current_node = None
|
56 |
+
parent_node_dict[root_node] = None
|
57 |
+
stack = [root_node]
|
58 |
+
while stack:
|
59 |
+
current_node = stack.pop()
|
60 |
+
current_subhg = clique_tree.nodes[current_node]['subhg']
|
61 |
+
for each_child in clique_tree.adj[current_node]:
|
62 |
+
if each_child != parent_node_dict[current_node]:
|
63 |
+
stack.append(each_child)
|
64 |
+
parent_node_dict[each_child] = current_node
|
65 |
+
if parent_node_dict[current_node] is not None:
|
66 |
+
parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
|
67 |
+
common, _ = common_node_list(parent_subhg, current_subhg)
|
68 |
+
parent_subhg.add_edge(set(common), attr_dict={'tmp': True})
|
69 |
+
|
70 |
+
parent_node_dict = {}
|
71 |
+
current_node = None
|
72 |
+
parent_node_dict[root_node] = None
|
73 |
+
stack = [root_node]
|
74 |
+
while stack:
|
75 |
+
current_node = stack.pop()
|
76 |
+
current_subhg = clique_tree.nodes[current_node]['subhg']
|
77 |
+
for each_child in clique_tree.adj[current_node]:
|
78 |
+
if each_child != parent_node_dict[current_node]:
|
79 |
+
stack.append(each_child)
|
80 |
+
parent_node_dict[each_child] = current_node
|
81 |
+
if parent_node_dict[current_node] is not None:
|
82 |
+
parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
|
83 |
+
common, _ = common_node_list(parent_subhg, current_subhg)
|
84 |
+
for each_idx, each_node in enumerate(common):
|
85 |
+
current_subhg.set_node_attr(each_node, {'ext_id': each_idx})
|
86 |
+
|
87 |
+
subhg_idx, is_new = self.add_subhg(current_subhg)
|
88 |
+
clique_tree.nodes[current_node]['subhg_idx'] = subhg_idx
|
89 |
+
return clique_tree
|
90 |
+
|
91 |
+
def add_subhg(self, subhg):
|
92 |
+
if len(self.subhg_list) == 0:
|
93 |
+
node_dict = {}
|
94 |
+
for each_node in subhg.nodes:
|
95 |
+
node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
|
96 |
+
node_list = []
|
97 |
+
for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
|
98 |
+
node_list.append(each_key)
|
99 |
+
for each_idx, each_node in enumerate(node_list):
|
100 |
+
subhg.node_attr(each_node)['order4hrg'] = each_idx
|
101 |
+
self.subhg_list.append(subhg)
|
102 |
+
return 0, True
|
103 |
+
else:
|
104 |
+
match = False
|
105 |
+
subhg_bond_symbol_counter \
|
106 |
+
= Counter([subhg.node_attr(each_node)['symbol'] \
|
107 |
+
for each_node in subhg.nodes])
|
108 |
+
subhg_atom_symbol_counter \
|
109 |
+
= Counter([subhg.edge_attr(each_edge).get('symbol', None) \
|
110 |
+
for each_edge in subhg.edges])
|
111 |
+
for each_idx, each_subhg in enumerate(self.subhg_list):
|
112 |
+
each_bond_symbol_counter \
|
113 |
+
= Counter([each_subhg.node_attr(each_node)['symbol'] \
|
114 |
+
for each_node in each_subhg.nodes])
|
115 |
+
each_atom_symbol_counter \
|
116 |
+
= Counter([each_subhg.edge_attr(each_edge).get('symbol', None) \
|
117 |
+
for each_edge in each_subhg.edges])
|
118 |
+
if not match \
|
119 |
+
and (subhg.num_nodes == each_subhg.num_nodes
|
120 |
+
and subhg.num_edges == each_subhg.num_edges
|
121 |
+
and subhg_bond_symbol_counter == each_bond_symbol_counter
|
122 |
+
and subhg_atom_symbol_counter == each_atom_symbol_counter):
|
123 |
+
gm = GraphMatcher(each_subhg.hg,
|
124 |
+
subhg.hg,
|
125 |
+
node_match=_easy_node_match,
|
126 |
+
edge_match=_edge_match)
|
127 |
+
try:
|
128 |
+
isomap = next(gm.isomorphisms_iter())
|
129 |
+
match = True
|
130 |
+
for each_node in each_subhg.nodes:
|
131 |
+
subhg.node_attr(isomap[each_node])['order4hrg'] \
|
132 |
+
= each_subhg.node_attr(each_node)['order4hrg']
|
133 |
+
if 'ext_id' in each_subhg.node_attr(each_node):
|
134 |
+
subhg.node_attr(isomap[each_node])['ext_id'] \
|
135 |
+
= each_subhg.node_attr(each_node)['ext_id']
|
136 |
+
return each_idx, False
|
137 |
+
except StopIteration:
|
138 |
+
match = False
|
139 |
+
if not match:
|
140 |
+
node_dict = {}
|
141 |
+
for each_node in subhg.nodes:
|
142 |
+
node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
|
143 |
+
node_list = []
|
144 |
+
for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
|
145 |
+
node_list.append(each_key)
|
146 |
+
for each_idx, each_node in enumerate(node_list):
|
147 |
+
subhg.node_attr(each_node)['order4hrg'] = each_idx
|
148 |
+
|
149 |
+
#for each_idx, each_node in enumerate(subhg.nodes):
|
150 |
+
# subhg.node_attr(each_node)['order4hrg'] = each_idx
|
151 |
+
self.subhg_list.append(subhg)
|
152 |
+
return len(self.subhg_list) - 1, True
|
models/mhg_model/graph_grammar/graph_grammar/hrg.py
ADDED
@@ -0,0 +1,1065 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Dec 11 2017"
|
20 |
+
|
21 |
+
from .corpus import CliqueTreeCorpus
|
22 |
+
from .base import GraphGrammarBase
|
23 |
+
from .symbols import TSymbol, NTSymbol, BondSymbol
|
24 |
+
from .utils import _node_match, _node_match_prod_rule, _edge_match, masked_softmax, common_node_list
|
25 |
+
from ..hypergraph import Hypergraph
|
26 |
+
from collections import Counter
|
27 |
+
from copy import deepcopy
|
28 |
+
from ..algo.tree_decomposition import (
|
29 |
+
tree_decomposition,
|
30 |
+
tree_decomposition_with_hrg,
|
31 |
+
tree_decomposition_from_leaf,
|
32 |
+
topological_tree_decomposition,
|
33 |
+
molecular_tree_decomposition)
|
34 |
+
from functools import partial
|
35 |
+
from networkx.algorithms.isomorphism import GraphMatcher
|
36 |
+
from typing import List, Dict, Tuple
|
37 |
+
import networkx as nx
|
38 |
+
import numpy as np
|
39 |
+
import torch
|
40 |
+
import os
|
41 |
+
import random
|
42 |
+
|
43 |
+
DEBUG = False
|
44 |
+
|
45 |
+
|
46 |
+
class ProductionRule(object):
|
47 |
+
""" A class of a production rule
|
48 |
+
|
49 |
+
Attributes
|
50 |
+
----------
|
51 |
+
lhs : Hypergraph or None
|
52 |
+
the left hand side of the production rule.
|
53 |
+
if None, the rule is a starting rule.
|
54 |
+
rhs : Hypergraph
|
55 |
+
the right hand side of the production rule.
|
56 |
+
"""
|
57 |
+
def __init__(self, lhs, rhs):
|
58 |
+
self.lhs = lhs
|
59 |
+
self.rhs = rhs
|
60 |
+
|
61 |
+
@property
|
62 |
+
def is_start_rule(self) -> bool:
|
63 |
+
return self.lhs.num_nodes == 0
|
64 |
+
|
65 |
+
@property
|
66 |
+
def ext_node(self) -> Dict[int, str]:
|
67 |
+
""" return a dict of external nodes
|
68 |
+
"""
|
69 |
+
if self.is_start_rule:
|
70 |
+
return {}
|
71 |
+
else:
|
72 |
+
ext_node_dict = {}
|
73 |
+
for each_node in self.lhs.nodes:
|
74 |
+
ext_node_dict[self.lhs.node_attr(each_node)["ext_id"]] = each_node
|
75 |
+
return ext_node_dict
|
76 |
+
|
77 |
+
@property
|
78 |
+
def lhs_nt_symbol(self) -> NTSymbol:
|
79 |
+
if self.is_start_rule:
|
80 |
+
return NTSymbol(degree=0, is_aromatic=False, bond_symbol_list=[])
|
81 |
+
else:
|
82 |
+
return self.lhs.edge_attr(list(self.lhs.edges)[0])['symbol']
|
83 |
+
|
84 |
+
def rhs_adj_mat(self, node_edge_list):
|
85 |
+
''' return the adjacency matrix of rhs of the production rule
|
86 |
+
'''
|
87 |
+
return nx.adjacency_matrix(self.rhs.hg, node_edge_list)
|
88 |
+
|
89 |
+
def draw(self, file_path=None):
|
90 |
+
return self.rhs.draw(file_path)
|
91 |
+
|
92 |
+
def is_same(self, prod_rule, ignore_order=False):
|
93 |
+
""" judge whether this production rule is
|
94 |
+
the same as the input one, `prod_rule`
|
95 |
+
|
96 |
+
Parameters
|
97 |
+
----------
|
98 |
+
prod_rule : ProductionRule
|
99 |
+
production rule to be compared
|
100 |
+
|
101 |
+
Returns
|
102 |
+
-------
|
103 |
+
is_same : bool
|
104 |
+
isomap : dict
|
105 |
+
isomorphism of nodes and hyperedges.
|
106 |
+
ex) {'bond_42': 'bond_37', 'bond_2': 'bond_1',
|
107 |
+
'e36': 'e11', 'e16': 'e12', 'e25': 'e18',
|
108 |
+
'bond_40': 'bond_38', 'e26': 'e21', 'bond_41': 'bond_39'}.
|
109 |
+
key comes from `prod_rule`, value comes from `self`.
|
110 |
+
"""
|
111 |
+
if self.is_start_rule:
|
112 |
+
if not prod_rule.is_start_rule:
|
113 |
+
return False, {}
|
114 |
+
else:
|
115 |
+
if prod_rule.is_start_rule:
|
116 |
+
return False, {}
|
117 |
+
else:
|
118 |
+
if prod_rule.lhs.num_nodes != self.lhs.num_nodes:
|
119 |
+
return False, {}
|
120 |
+
|
121 |
+
if prod_rule.rhs.num_nodes != self.rhs.num_nodes:
|
122 |
+
return False, {}
|
123 |
+
if prod_rule.rhs.num_edges != self.rhs.num_edges:
|
124 |
+
return False, {}
|
125 |
+
|
126 |
+
subhg_bond_symbol_counter \
|
127 |
+
= Counter([prod_rule.rhs.node_attr(each_node)['symbol'] \
|
128 |
+
for each_node in prod_rule.rhs.nodes])
|
129 |
+
each_bond_symbol_counter \
|
130 |
+
= Counter([self.rhs.node_attr(each_node)['symbol'] \
|
131 |
+
for each_node in self.rhs.nodes])
|
132 |
+
if subhg_bond_symbol_counter != each_bond_symbol_counter:
|
133 |
+
return False, {}
|
134 |
+
|
135 |
+
subhg_atom_symbol_counter \
|
136 |
+
= Counter([prod_rule.rhs.edge_attr(each_edge)['symbol'] \
|
137 |
+
for each_edge in prod_rule.rhs.edges])
|
138 |
+
each_atom_symbol_counter \
|
139 |
+
= Counter([self.rhs.edge_attr(each_edge)['symbol'] \
|
140 |
+
for each_edge in self.rhs.edges])
|
141 |
+
if subhg_atom_symbol_counter != each_atom_symbol_counter:
|
142 |
+
return False, {}
|
143 |
+
|
144 |
+
gm = GraphMatcher(prod_rule.rhs.hg,
|
145 |
+
self.rhs.hg,
|
146 |
+
partial(_node_match_prod_rule,
|
147 |
+
ignore_order=ignore_order),
|
148 |
+
partial(_edge_match,
|
149 |
+
ignore_order=ignore_order))
|
150 |
+
try:
|
151 |
+
return True, next(gm.isomorphisms_iter())
|
152 |
+
except StopIteration:
|
153 |
+
return False, {}
|
154 |
+
|
155 |
+
def applied_to(self,
|
156 |
+
hg: Hypergraph,
|
157 |
+
edge: str) -> Tuple[Hypergraph, List[str]]:
|
158 |
+
""" augment `hg` by replacing `edge` with `self.rhs`.
|
159 |
+
|
160 |
+
Parameters
|
161 |
+
----------
|
162 |
+
hg : Hypergraph
|
163 |
+
edge : str
|
164 |
+
`edge` must belong to `hg`
|
165 |
+
|
166 |
+
Returns
|
167 |
+
-------
|
168 |
+
hg : Hypergraph
|
169 |
+
resultant hypergraph
|
170 |
+
nt_edge_list : list
|
171 |
+
list of non-terminal edges
|
172 |
+
"""
|
173 |
+
nt_edge_dict = {}
|
174 |
+
if self.is_start_rule:
|
175 |
+
if (edge is not None) or (hg is not None):
|
176 |
+
ValueError("edge and hg must be None for this prod rule.")
|
177 |
+
hg = Hypergraph()
|
178 |
+
node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
|
179 |
+
for num_idx, each_node in enumerate(self.rhs.nodes):
|
180 |
+
hg.add_node(f"bond_{num_idx}",
|
181 |
+
#attr_dict=deepcopy(self.rhs.node_attr(each_node)))
|
182 |
+
attr_dict=self.rhs.node_attr(each_node))
|
183 |
+
node_map_rhs[each_node] = f"bond_{num_idx}"
|
184 |
+
for each_edge in self.rhs.edges:
|
185 |
+
node_list = []
|
186 |
+
for each_node in self.rhs.nodes_in_edge(each_edge):
|
187 |
+
node_list.append(node_map_rhs[each_node])
|
188 |
+
if isinstance(self.rhs.nodes_in_edge(each_edge), set):
|
189 |
+
node_list = set(node_list)
|
190 |
+
edge_id = hg.add_edge(
|
191 |
+
node_list,
|
192 |
+
#attr_dict=deepcopy(self.rhs.edge_attr(each_edge)))
|
193 |
+
attr_dict=self.rhs.edge_attr(each_edge))
|
194 |
+
if "nt_idx" in hg.edge_attr(edge_id):
|
195 |
+
nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
|
196 |
+
nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
|
197 |
+
return hg, nt_edge_list
|
198 |
+
else:
|
199 |
+
if edge not in hg.edges:
|
200 |
+
raise ValueError("the input hyperedge does not exist.")
|
201 |
+
if hg.edge_attr(edge)["terminal"]:
|
202 |
+
raise ValueError("the input hyperedge is terminal.")
|
203 |
+
if hg.edge_attr(edge)['symbol'] != self.lhs_nt_symbol:
|
204 |
+
print(hg.edge_attr(edge)['symbol'], self.lhs_nt_symbol)
|
205 |
+
raise ValueError("the input hyperedge and lhs have inconsistent number of nodes.")
|
206 |
+
if DEBUG:
|
207 |
+
for node_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
|
208 |
+
other_node = self.lhs.nodes_in_edge(list(self.lhs.edges)[0])[node_idx]
|
209 |
+
attr = deepcopy(self.lhs.node_attr(other_node))
|
210 |
+
attr.pop('ext_id')
|
211 |
+
if hg.node_attr(each_node) != attr:
|
212 |
+
raise ValueError('node attributes are inconsistent.')
|
213 |
+
|
214 |
+
# order of nodes that belong to the non-terminal edge in hg
|
215 |
+
nt_order_dict = {} # hg_node -> order ("bond_17" : 1)
|
216 |
+
nt_order_dict_inv = {} # order -> hg_node
|
217 |
+
for each_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
|
218 |
+
nt_order_dict[each_node] = each_idx
|
219 |
+
nt_order_dict_inv[each_idx] = each_node
|
220 |
+
|
221 |
+
# construct a node_map_rhs: rhs -> new hg
|
222 |
+
node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
|
223 |
+
node_idx = hg.num_nodes
|
224 |
+
for each_node in self.rhs.nodes:
|
225 |
+
if "ext_id" in self.rhs.node_attr(each_node):
|
226 |
+
node_map_rhs[each_node] \
|
227 |
+
= nt_order_dict_inv[
|
228 |
+
self.rhs.node_attr(each_node)["ext_id"]]
|
229 |
+
else:
|
230 |
+
node_map_rhs[each_node] = f"bond_{node_idx}"
|
231 |
+
node_idx += 1
|
232 |
+
|
233 |
+
# delete non-terminal
|
234 |
+
hg.remove_edge(edge)
|
235 |
+
|
236 |
+
# add nodes to hg
|
237 |
+
for each_node in self.rhs.nodes:
|
238 |
+
hg.add_node(node_map_rhs[each_node],
|
239 |
+
attr_dict=self.rhs.node_attr(each_node))
|
240 |
+
|
241 |
+
# add hyperedges to hg
|
242 |
+
for each_edge in self.rhs.edges:
|
243 |
+
node_list_hg = []
|
244 |
+
for each_node in self.rhs.nodes_in_edge(each_edge):
|
245 |
+
node_list_hg.append(node_map_rhs[each_node])
|
246 |
+
edge_id = hg.add_edge(
|
247 |
+
node_list_hg,
|
248 |
+
attr_dict=self.rhs.edge_attr(each_edge))#deepcopy(self.rhs.edge_attr(each_edge)))
|
249 |
+
if "nt_idx" in hg.edge_attr(edge_id):
|
250 |
+
nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
|
251 |
+
nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
|
252 |
+
return hg, nt_edge_list
|
253 |
+
|
254 |
+
def revert(self, hg: Hypergraph, return_subhg=False):
|
255 |
+
''' revert applying this production rule.
|
256 |
+
i.e., if there exists a subhypergraph that matches the r.h.s. of this production rule,
|
257 |
+
this method replaces the subhypergraph with a non-terminal hyperedge.
|
258 |
+
|
259 |
+
Parameters
|
260 |
+
----------
|
261 |
+
hg : Hypergraph
|
262 |
+
hypergraph to be reverted
|
263 |
+
return_subhg : bool
|
264 |
+
if True, the removed subhypergraph will be returned.
|
265 |
+
|
266 |
+
Returns
|
267 |
+
-------
|
268 |
+
hg : Hypergraph
|
269 |
+
the resultant hypergraph. if it cannot be reverted, the original one is returned without any replacement.
|
270 |
+
success : bool
|
271 |
+
this indicates whether reverting is successed or not.
|
272 |
+
'''
|
273 |
+
gm = GraphMatcher(hg.hg, self.rhs.hg, node_match=_node_match_prod_rule,
|
274 |
+
edge_match=_edge_match)
|
275 |
+
try:
|
276 |
+
# in case when the matched subhg is connected to the other part via external nodes and more.
|
277 |
+
not_iso = True
|
278 |
+
while not_iso:
|
279 |
+
isomap = next(gm.subgraph_isomorphisms_iter())
|
280 |
+
adj_node_set = set([]) # reachable nodes from the internal nodes
|
281 |
+
subhg_node_set = set(isomap.keys()) # nodes in subhg
|
282 |
+
for each_node in subhg_node_set:
|
283 |
+
adj_node_set.add(each_node)
|
284 |
+
if isomap[each_node] not in self.ext_node.values():
|
285 |
+
adj_node_set.update(hg.hg.adj[each_node])
|
286 |
+
if adj_node_set == subhg_node_set:
|
287 |
+
not_iso = False
|
288 |
+
else:
|
289 |
+
if return_subhg:
|
290 |
+
return hg, False, Hypergraph()
|
291 |
+
else:
|
292 |
+
return hg, False
|
293 |
+
inv_isomap = {v: k for k, v in isomap.items()}
|
294 |
+
'''
|
295 |
+
isomap = {'e35': 'e8', 'bond_13': 'bond_18', 'bond_14': 'bond_19',
|
296 |
+
'bond_15': 'bond_17', 'e29': 'e23', 'bond_12': 'bond_20'}
|
297 |
+
where keys come from `hg` and values come from `self.rhs`
|
298 |
+
'''
|
299 |
+
except StopIteration:
|
300 |
+
if return_subhg:
|
301 |
+
return hg, False, Hypergraph()
|
302 |
+
else:
|
303 |
+
return hg, False
|
304 |
+
|
305 |
+
if return_subhg:
|
306 |
+
subhg = Hypergraph()
|
307 |
+
for each_node in hg.nodes:
|
308 |
+
if each_node in isomap:
|
309 |
+
subhg.add_node(each_node, attr_dict=hg.node_attr(each_node))
|
310 |
+
for each_edge in hg.edges:
|
311 |
+
if each_edge in isomap:
|
312 |
+
subhg.add_edge(hg.nodes_in_edge(each_edge),
|
313 |
+
attr_dict=hg.edge_attr(each_edge),
|
314 |
+
edge_name=each_edge)
|
315 |
+
subhg.edge_idx = hg.edge_idx
|
316 |
+
|
317 |
+
# remove subhg except for the externael nodes
|
318 |
+
for each_key, each_val in isomap.items():
|
319 |
+
if each_key.startswith('e'):
|
320 |
+
hg.remove_edge(each_key)
|
321 |
+
for each_key, each_val in isomap.items():
|
322 |
+
if each_key.startswith('bond_'):
|
323 |
+
if each_val not in self.ext_node.values():
|
324 |
+
hg.remove_node(each_key)
|
325 |
+
|
326 |
+
# add non-terminal hyperedge
|
327 |
+
nt_node_list = []
|
328 |
+
for each_ext_id in self.ext_node.keys():
|
329 |
+
nt_node_list.append(inv_isomap[self.ext_node[each_ext_id]])
|
330 |
+
|
331 |
+
hg.add_edge(nt_node_list,
|
332 |
+
attr_dict=dict(
|
333 |
+
terminal=False,
|
334 |
+
symbol=self.lhs_nt_symbol))
|
335 |
+
if return_subhg:
|
336 |
+
return hg, True, subhg
|
337 |
+
else:
|
338 |
+
return hg, True
|
339 |
+
|
340 |
+
|
341 |
+
class ProductionRuleCorpus(object):
|
342 |
+
|
343 |
+
'''
|
344 |
+
A corpus of production rules.
|
345 |
+
This class maintains
|
346 |
+
(i) list of unique production rules,
|
347 |
+
(ii) list of unique edge symbols (both terminal and non-terminal), and
|
348 |
+
(iii) list of unique node symbols.
|
349 |
+
|
350 |
+
Attributes
|
351 |
+
----------
|
352 |
+
prod_rule_list : list
|
353 |
+
list of unique production rules
|
354 |
+
edge_symbol_list : list
|
355 |
+
list of unique symbols (including both terminal and non-terminal)
|
356 |
+
node_symbol_list : list
|
357 |
+
list of node symbols
|
358 |
+
nt_symbol_list : list
|
359 |
+
list of unique lhs symbols
|
360 |
+
ext_id_list : list
|
361 |
+
list of ext_ids
|
362 |
+
lhs_in_prod_rule : array
|
363 |
+
a matrix of lhs vs prod_rule (= lhs_in_prod_rule)
|
364 |
+
'''
|
365 |
+
|
366 |
+
def __init__(self):
|
367 |
+
self.prod_rule_list = []
|
368 |
+
self.edge_symbol_list = []
|
369 |
+
self.edge_symbol_dict = {}
|
370 |
+
self.node_symbol_list = []
|
371 |
+
self.node_symbol_dict = {}
|
372 |
+
self.nt_symbol_list = []
|
373 |
+
self.ext_id_list = []
|
374 |
+
self._lhs_in_prod_rule = None
|
375 |
+
self.lhs_in_prod_rule_row_list = []
|
376 |
+
self.lhs_in_prod_rule_col_list = []
|
377 |
+
|
378 |
+
@property
|
379 |
+
def lhs_in_prod_rule(self):
|
380 |
+
if self._lhs_in_prod_rule is None:
|
381 |
+
self._lhs_in_prod_rule = torch.sparse.FloatTensor(
|
382 |
+
torch.LongTensor(list(zip(self.lhs_in_prod_rule_row_list, self.lhs_in_prod_rule_col_list))).t(),
|
383 |
+
torch.FloatTensor([1.0]*len(self.lhs_in_prod_rule_col_list)),
|
384 |
+
torch.Size([len(self.nt_symbol_list), len(self.prod_rule_list)])
|
385 |
+
).to_dense()
|
386 |
+
return self._lhs_in_prod_rule
|
387 |
+
|
388 |
+
@property
|
389 |
+
def num_prod_rule(self):
|
390 |
+
''' return the number of production rules
|
391 |
+
|
392 |
+
Returns
|
393 |
+
-------
|
394 |
+
int : the number of unique production rules
|
395 |
+
'''
|
396 |
+
return len(self.prod_rule_list)
|
397 |
+
|
398 |
+
@property
|
399 |
+
def start_rule_list(self):
|
400 |
+
''' return a list of start rules
|
401 |
+
|
402 |
+
Returns
|
403 |
+
-------
|
404 |
+
list : list of start rules
|
405 |
+
'''
|
406 |
+
start_rule_list = []
|
407 |
+
for each_prod_rule in self.prod_rule_list:
|
408 |
+
if each_prod_rule.is_start_rule:
|
409 |
+
start_rule_list.append(each_prod_rule)
|
410 |
+
return start_rule_list
|
411 |
+
|
412 |
+
@property
|
413 |
+
def num_edge_symbol(self):
|
414 |
+
return len(self.edge_symbol_list)
|
415 |
+
|
416 |
+
@property
|
417 |
+
def num_node_symbol(self):
|
418 |
+
return len(self.node_symbol_list)
|
419 |
+
|
420 |
+
@property
|
421 |
+
def num_ext_id(self):
|
422 |
+
return len(self.ext_id_list)
|
423 |
+
|
424 |
+
def construct_feature_vectors(self):
|
425 |
+
''' this method constructs feature vectors for the production rules collected so far.
|
426 |
+
currently, NTSymbol and TSymbol are treated in the same manner.
|
427 |
+
'''
|
428 |
+
feature_id_dict = {}
|
429 |
+
feature_id_dict['TSymbol'] = 0
|
430 |
+
feature_id_dict['NTSymbol'] = 1
|
431 |
+
feature_id_dict['BondSymbol'] = 2
|
432 |
+
for each_edge_symbol in self.edge_symbol_list:
|
433 |
+
for each_attr in each_edge_symbol.__dict__.keys():
|
434 |
+
each_val = each_edge_symbol.__dict__[each_attr]
|
435 |
+
if isinstance(each_val, list):
|
436 |
+
each_val = tuple(each_val)
|
437 |
+
if (each_attr, each_val) not in feature_id_dict:
|
438 |
+
feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
|
439 |
+
|
440 |
+
for each_node_symbol in self.node_symbol_list:
|
441 |
+
for each_attr in each_node_symbol.__dict__.keys():
|
442 |
+
each_val = each_node_symbol.__dict__[each_attr]
|
443 |
+
if isinstance(each_val, list):
|
444 |
+
each_val = tuple(each_val)
|
445 |
+
if (each_attr, each_val) not in feature_id_dict:
|
446 |
+
feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
|
447 |
+
for each_ext_id in self.ext_id_list:
|
448 |
+
feature_id_dict[('ext_id', each_ext_id)] = len(feature_id_dict)
|
449 |
+
dim = len(feature_id_dict)
|
450 |
+
|
451 |
+
feature_dict = {}
|
452 |
+
for each_edge_symbol in self.edge_symbol_list:
|
453 |
+
idx_list = []
|
454 |
+
idx_list.append(feature_id_dict[each_edge_symbol.__class__.__name__])
|
455 |
+
for each_attr in each_edge_symbol.__dict__.keys():
|
456 |
+
each_val = each_edge_symbol.__dict__[each_attr]
|
457 |
+
if isinstance(each_val, list):
|
458 |
+
each_val = tuple(each_val)
|
459 |
+
idx_list.append(feature_id_dict[(each_attr, each_val)])
|
460 |
+
feature = torch.sparse.LongTensor(
|
461 |
+
torch.LongTensor([idx_list]),
|
462 |
+
torch.ones(len(idx_list)),
|
463 |
+
torch.Size([len(feature_id_dict)])
|
464 |
+
)
|
465 |
+
feature_dict[each_edge_symbol] = feature
|
466 |
+
|
467 |
+
for each_node_symbol in self.node_symbol_list:
|
468 |
+
idx_list = []
|
469 |
+
idx_list.append(feature_id_dict[each_node_symbol.__class__.__name__])
|
470 |
+
for each_attr in each_node_symbol.__dict__.keys():
|
471 |
+
each_val = each_node_symbol.__dict__[each_attr]
|
472 |
+
if isinstance(each_val, list):
|
473 |
+
each_val = tuple(each_val)
|
474 |
+
idx_list.append(feature_id_dict[(each_attr, each_val)])
|
475 |
+
feature = torch.sparse.LongTensor(
|
476 |
+
torch.LongTensor([idx_list]),
|
477 |
+
torch.ones(len(idx_list)),
|
478 |
+
torch.Size([len(feature_id_dict)])
|
479 |
+
)
|
480 |
+
feature_dict[each_node_symbol] = feature
|
481 |
+
for each_ext_id in self.ext_id_list:
|
482 |
+
idx_list = [feature_id_dict[('ext_id', each_ext_id)]]
|
483 |
+
feature_dict[('ext_id', each_ext_id)] \
|
484 |
+
= torch.sparse.LongTensor(
|
485 |
+
torch.LongTensor([idx_list]),
|
486 |
+
torch.ones(len(idx_list)),
|
487 |
+
torch.Size([len(feature_id_dict)])
|
488 |
+
)
|
489 |
+
return feature_dict, dim
|
490 |
+
|
491 |
+
def edge_symbol_idx(self, symbol):
|
492 |
+
return self.edge_symbol_dict[symbol]
|
493 |
+
|
494 |
+
def node_symbol_idx(self, symbol):
|
495 |
+
return self.node_symbol_dict[symbol]
|
496 |
+
|
497 |
+
def append(self, prod_rule: ProductionRule) -> Tuple[int, ProductionRule]:
|
498 |
+
""" return whether the input production rule is new or not, and its production rule id.
|
499 |
+
Production rules are regarded as the same if
|
500 |
+
i) there exists a one-to-one mapping of nodes and edges, and
|
501 |
+
ii) all the attributes associated with nodes and hyperedges are the same.
|
502 |
+
|
503 |
+
Parameters
|
504 |
+
----------
|
505 |
+
prod_rule : ProductionRule
|
506 |
+
|
507 |
+
Returns
|
508 |
+
-------
|
509 |
+
prod_rule_id : int
|
510 |
+
production rule index. if new, a new index will be assigned.
|
511 |
+
prod_rule : ProductionRule
|
512 |
+
"""
|
513 |
+
num_lhs = len(self.nt_symbol_list)
|
514 |
+
for each_idx, each_prod_rule in enumerate(self.prod_rule_list):
|
515 |
+
is_same, isomap = prod_rule.is_same(each_prod_rule)
|
516 |
+
if is_same:
|
517 |
+
# we do not care about edge and node names, but care about the order of non-terminal edges.
|
518 |
+
for key, val in isomap.items(): # key : edges & nodes in each_prod_rule.rhs , val : those in prod_rule.rhs
|
519 |
+
if key.startswith("bond_"):
|
520 |
+
continue
|
521 |
+
|
522 |
+
# rewrite `nt_idx` in `prod_rule` for further processing
|
523 |
+
if "nt_idx" in prod_rule.rhs.edge_attr(val).keys():
|
524 |
+
if "nt_idx" not in each_prod_rule.rhs.edge_attr(key).keys():
|
525 |
+
raise ValueError
|
526 |
+
prod_rule.rhs.set_edge_attr(
|
527 |
+
val,
|
528 |
+
{'nt_idx': each_prod_rule.rhs.edge_attr(key)["nt_idx"]})
|
529 |
+
return each_idx, prod_rule
|
530 |
+
self.prod_rule_list.append(prod_rule)
|
531 |
+
self._update_edge_symbol_list(prod_rule)
|
532 |
+
self._update_node_symbol_list(prod_rule)
|
533 |
+
self._update_ext_id_list(prod_rule)
|
534 |
+
|
535 |
+
lhs_idx = self.nt_symbol_list.index(prod_rule.lhs_nt_symbol)
|
536 |
+
self.lhs_in_prod_rule_row_list.append(lhs_idx)
|
537 |
+
self.lhs_in_prod_rule_col_list.append(len(self.prod_rule_list)-1)
|
538 |
+
self._lhs_in_prod_rule = None
|
539 |
+
return len(self.prod_rule_list)-1, prod_rule
|
540 |
+
|
541 |
+
def get_prod_rule(self, prod_rule_idx: int) -> ProductionRule:
|
542 |
+
return self.prod_rule_list[prod_rule_idx]
|
543 |
+
|
544 |
+
def sample(self, unmasked_logit_array, nt_symbol, deterministic=False):
|
545 |
+
''' sample a production rule whose lhs is `nt_symbol`, followihng `unmasked_logit_array`.
|
546 |
+
|
547 |
+
Parameters
|
548 |
+
----------
|
549 |
+
unmasked_logit_array : array-like, length `num_prod_rule`
|
550 |
+
nt_symbol : NTSymbol
|
551 |
+
'''
|
552 |
+
if not isinstance(unmasked_logit_array, np.ndarray):
|
553 |
+
unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
|
554 |
+
if deterministic:
|
555 |
+
prob = masked_softmax(unmasked_logit_array,
|
556 |
+
self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
|
557 |
+
return self.prod_rule_list[np.argmax(prob)]
|
558 |
+
else:
|
559 |
+
return np.random.choice(
|
560 |
+
self.prod_rule_list, 1,
|
561 |
+
p=masked_softmax(unmasked_logit_array,
|
562 |
+
self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)))[0]
|
563 |
+
|
564 |
+
def masked_logprob(self, unmasked_logit_array, nt_symbol):
|
565 |
+
if not isinstance(unmasked_logit_array, np.ndarray):
|
566 |
+
unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
|
567 |
+
prob = masked_softmax(unmasked_logit_array,
|
568 |
+
self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
|
569 |
+
return np.log(prob)
|
570 |
+
|
571 |
+
def _update_edge_symbol_list(self, prod_rule: ProductionRule):
|
572 |
+
''' update edge symbol list
|
573 |
+
|
574 |
+
Parameters
|
575 |
+
----------
|
576 |
+
prod_rule : ProductionRule
|
577 |
+
'''
|
578 |
+
if prod_rule.lhs_nt_symbol not in self.nt_symbol_list:
|
579 |
+
self.nt_symbol_list.append(prod_rule.lhs_nt_symbol)
|
580 |
+
|
581 |
+
for each_edge in prod_rule.rhs.edges:
|
582 |
+
if prod_rule.rhs.edge_attr(each_edge)['symbol'] not in self.edge_symbol_dict:
|
583 |
+
edge_symbol_idx = len(self.edge_symbol_list)
|
584 |
+
self.edge_symbol_list.append(prod_rule.rhs.edge_attr(each_edge)['symbol'])
|
585 |
+
self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']] = edge_symbol_idx
|
586 |
+
else:
|
587 |
+
edge_symbol_idx = self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']]
|
588 |
+
prod_rule.rhs.edge_attr(each_edge)['symbol_idx'] = edge_symbol_idx
|
589 |
+
pass
|
590 |
+
|
591 |
+
def _update_node_symbol_list(self, prod_rule: ProductionRule):
|
592 |
+
''' update node symbol list
|
593 |
+
|
594 |
+
Parameters
|
595 |
+
----------
|
596 |
+
prod_rule : ProductionRule
|
597 |
+
'''
|
598 |
+
for each_node in prod_rule.rhs.nodes:
|
599 |
+
if prod_rule.rhs.node_attr(each_node)['symbol'] not in self.node_symbol_dict:
|
600 |
+
node_symbol_idx = len(self.node_symbol_list)
|
601 |
+
self.node_symbol_list.append(prod_rule.rhs.node_attr(each_node)['symbol'])
|
602 |
+
self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']] = node_symbol_idx
|
603 |
+
else:
|
604 |
+
node_symbol_idx = self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']]
|
605 |
+
prod_rule.rhs.node_attr(each_node)['symbol_idx'] = node_symbol_idx
|
606 |
+
|
607 |
+
def _update_ext_id_list(self, prod_rule: ProductionRule):
|
608 |
+
for each_node in prod_rule.rhs.nodes:
|
609 |
+
if 'ext_id' in prod_rule.rhs.node_attr(each_node):
|
610 |
+
if prod_rule.rhs.node_attr(each_node)['ext_id'] not in self.ext_id_list:
|
611 |
+
self.ext_id_list.append(prod_rule.rhs.node_attr(each_node)['ext_id'])
|
612 |
+
|
613 |
+
|
614 |
+
class HyperedgeReplacementGrammar(GraphGrammarBase):
|
615 |
+
"""
|
616 |
+
Learn a hyperedge replacement grammar from a set of hypergraphs.
|
617 |
+
|
618 |
+
Attributes
|
619 |
+
----------
|
620 |
+
prod_rule_list : list of ProductionRule
|
621 |
+
production rules learned from the input hypergraphs
|
622 |
+
"""
|
623 |
+
def __init__(self,
|
624 |
+
tree_decomposition=molecular_tree_decomposition,
|
625 |
+
ignore_order=False, **kwargs):
|
626 |
+
from functools import partial
|
627 |
+
self.prod_rule_corpus = ProductionRuleCorpus()
|
628 |
+
self.clique_tree_corpus = CliqueTreeCorpus()
|
629 |
+
self.ignore_order = ignore_order
|
630 |
+
self.tree_decomposition = partial(tree_decomposition, **kwargs)
|
631 |
+
|
632 |
+
@property
|
633 |
+
def num_prod_rule(self):
|
634 |
+
''' return the number of production rules
|
635 |
+
|
636 |
+
Returns
|
637 |
+
-------
|
638 |
+
int : the number of unique production rules
|
639 |
+
'''
|
640 |
+
return self.prod_rule_corpus.num_prod_rule
|
641 |
+
|
642 |
+
@property
|
643 |
+
def start_rule_list(self):
|
644 |
+
''' return a list of start rules
|
645 |
+
|
646 |
+
Returns
|
647 |
+
-------
|
648 |
+
list : list of start rules
|
649 |
+
'''
|
650 |
+
return self.prod_rule_corpus.start_rule_list
|
651 |
+
|
652 |
+
@property
|
653 |
+
def prod_rule_list(self):
|
654 |
+
return self.prod_rule_corpus.prod_rule_list
|
655 |
+
|
656 |
+
def learn(self, hg_list, logger=print, max_mol=np.inf, print_freq=500):
|
657 |
+
""" learn from a list of hypergraphs
|
658 |
+
|
659 |
+
Parameters
|
660 |
+
----------
|
661 |
+
hg_list : list of Hypergraph
|
662 |
+
|
663 |
+
Returns
|
664 |
+
-------
|
665 |
+
prod_rule_seq_list : list of integers
|
666 |
+
each element corresponds to a sequence of production rules to generate each hypergraph.
|
667 |
+
"""
|
668 |
+
prod_rule_seq_list = []
|
669 |
+
idx = 0
|
670 |
+
for each_idx, each_hg in enumerate(hg_list):
|
671 |
+
clique_tree = self.tree_decomposition(each_hg)
|
672 |
+
|
673 |
+
# get a pair of myself and children
|
674 |
+
root_node = _find_root(clique_tree)
|
675 |
+
clique_tree = self.clique_tree_corpus.add_to_subhg_list(clique_tree, root_node)
|
676 |
+
prod_rule_seq = []
|
677 |
+
stack = []
|
678 |
+
|
679 |
+
children = sorted(list(clique_tree[root_node].keys()))
|
680 |
+
|
681 |
+
# extract a temporary production rule
|
682 |
+
prod_rule = extract_prod_rule(
|
683 |
+
None,
|
684 |
+
clique_tree.nodes[root_node]["subhg"],
|
685 |
+
[clique_tree.nodes[each_child]["subhg"]
|
686 |
+
for each_child in children],
|
687 |
+
clique_tree.nodes[root_node].get('subhg_idx', None))
|
688 |
+
|
689 |
+
# update the production rule list
|
690 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
691 |
+
children = reorder_children(root_node,
|
692 |
+
children,
|
693 |
+
prod_rule,
|
694 |
+
clique_tree)
|
695 |
+
stack.extend([(root_node, each_child) for each_child in children[::-1]])
|
696 |
+
prod_rule_seq.append(prod_rule_id)
|
697 |
+
|
698 |
+
while len(stack) != 0:
|
699 |
+
# get a triple of parent, myself, and children
|
700 |
+
parent, myself = stack.pop()
|
701 |
+
children = sorted(list(dict(clique_tree[myself]).keys()))
|
702 |
+
children.remove(parent)
|
703 |
+
|
704 |
+
# extract a temp prod rule
|
705 |
+
prod_rule = extract_prod_rule(
|
706 |
+
clique_tree.nodes[parent]["subhg"],
|
707 |
+
clique_tree.nodes[myself]["subhg"],
|
708 |
+
[clique_tree.nodes[each_child]["subhg"]
|
709 |
+
for each_child in children],
|
710 |
+
clique_tree.nodes[myself].get('subhg_idx', None))
|
711 |
+
|
712 |
+
# update the prod rule list
|
713 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
714 |
+
children = reorder_children(myself,
|
715 |
+
children,
|
716 |
+
prod_rule,
|
717 |
+
clique_tree)
|
718 |
+
stack.extend([(myself, each_child)
|
719 |
+
for each_child in children[::-1]])
|
720 |
+
prod_rule_seq.append(prod_rule_id)
|
721 |
+
prod_rule_seq_list.append(prod_rule_seq)
|
722 |
+
if (each_idx+1) % print_freq == 0:
|
723 |
+
msg = f'#(molecules processed)={each_idx+1}\t'\
|
724 |
+
f'#(production rules)={self.prod_rule_corpus.num_prod_rule}\t#(subhg in corpus)={self.clique_tree_corpus.size}'
|
725 |
+
logger(msg)
|
726 |
+
if each_idx > max_mol:
|
727 |
+
break
|
728 |
+
|
729 |
+
print(f'corpus_size = {self.clique_tree_corpus.size}')
|
730 |
+
return prod_rule_seq_list
|
731 |
+
|
732 |
+
def sample(self, z, deterministic=False):
|
733 |
+
""" sample a new hypergraph from HRG.
|
734 |
+
|
735 |
+
Parameters
|
736 |
+
----------
|
737 |
+
z : array-like, shape (len, num_prod_rule)
|
738 |
+
logit
|
739 |
+
deterministic : bool
|
740 |
+
if True, deterministic sampling
|
741 |
+
|
742 |
+
Returns
|
743 |
+
-------
|
744 |
+
Hypergraph
|
745 |
+
"""
|
746 |
+
seq_idx = 0
|
747 |
+
stack = []
|
748 |
+
z = z[:, :-1]
|
749 |
+
init_prod_rule = self.prod_rule_corpus.sample(z[0], NTSymbol(degree=0,
|
750 |
+
is_aromatic=False,
|
751 |
+
bond_symbol_list=[]),
|
752 |
+
deterministic=deterministic)
|
753 |
+
hg, nt_edge_list = init_prod_rule.applied_to(None, None)
|
754 |
+
stack = deepcopy(nt_edge_list[::-1])
|
755 |
+
while len(stack) != 0 and seq_idx < z.shape[0]-1:
|
756 |
+
seq_idx += 1
|
757 |
+
nt_edge = stack.pop()
|
758 |
+
nt_symbol = hg.edge_attr(nt_edge)['symbol']
|
759 |
+
prod_rule = self.prod_rule_corpus.sample(z[seq_idx], nt_symbol, deterministic=deterministic)
|
760 |
+
hg, nt_edge_list = prod_rule.applied_to(hg, nt_edge)
|
761 |
+
stack.extend(nt_edge_list[::-1])
|
762 |
+
if len(stack) != 0:
|
763 |
+
raise RuntimeError(f'{len(stack)} non-terminals are left.')
|
764 |
+
return hg
|
765 |
+
|
766 |
+
def construct(self, prod_rule_seq):
|
767 |
+
""" construct a hypergraph following `prod_rule_seq`
|
768 |
+
|
769 |
+
Parameters
|
770 |
+
----------
|
771 |
+
prod_rule_seq : list of integers
|
772 |
+
a sequence of production rules.
|
773 |
+
|
774 |
+
Returns
|
775 |
+
-------
|
776 |
+
UndirectedHypergraph
|
777 |
+
"""
|
778 |
+
seq_idx = 0
|
779 |
+
init_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
|
780 |
+
hg, nt_edge_list = init_prod_rule.applied_to(None, None)
|
781 |
+
stack = deepcopy(nt_edge_list[::-1])
|
782 |
+
while len(stack) != 0:
|
783 |
+
seq_idx += 1
|
784 |
+
nt_edge = stack.pop()
|
785 |
+
hg, nt_edge_list = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx]).applied_to(hg, nt_edge)
|
786 |
+
stack.extend(nt_edge_list[::-1])
|
787 |
+
return hg
|
788 |
+
|
789 |
+
def update_prod_rule_list(self, prod_rule):
|
790 |
+
""" return whether the input production rule is new or not, and its production rule id.
|
791 |
+
Production rules are regarded as the same if
|
792 |
+
i) there exists a one-to-one mapping of nodes and edges, and
|
793 |
+
ii) all the attributes associated with nodes and hyperedges are the same.
|
794 |
+
|
795 |
+
Parameters
|
796 |
+
----------
|
797 |
+
prod_rule : ProductionRule
|
798 |
+
|
799 |
+
Returns
|
800 |
+
-------
|
801 |
+
is_new : bool
|
802 |
+
if True, this production rule is new
|
803 |
+
prod_rule_id : int
|
804 |
+
production rule index. if new, a new index will be assigned.
|
805 |
+
"""
|
806 |
+
return self.prod_rule_corpus.append(prod_rule)
|
807 |
+
|
808 |
+
|
809 |
+
class IncrementalHyperedgeReplacementGrammar(HyperedgeReplacementGrammar):
|
810 |
+
'''
|
811 |
+
This class learns HRG incrementally leveraging the previously obtained production rules.
|
812 |
+
'''
|
813 |
+
def __init__(self, tree_decomposition=tree_decomposition_with_hrg, ignore_order=False):
|
814 |
+
self.prod_rule_list = []
|
815 |
+
self.tree_decomposition = tree_decomposition
|
816 |
+
self.ignore_order = ignore_order
|
817 |
+
|
818 |
+
def learn(self, hg_list):
|
819 |
+
""" learn from a list of hypergraphs
|
820 |
+
|
821 |
+
Parameters
|
822 |
+
----------
|
823 |
+
hg_list : list of UndirectedHypergraph
|
824 |
+
|
825 |
+
Returns
|
826 |
+
-------
|
827 |
+
prod_rule_seq_list : list of integers
|
828 |
+
each element corresponds to a sequence of production rules to generate each hypergraph.
|
829 |
+
"""
|
830 |
+
prod_rule_seq_list = []
|
831 |
+
for each_hg in hg_list:
|
832 |
+
clique_tree, root_node = tree_decomposition_with_hrg(each_hg, self, return_root=True)
|
833 |
+
|
834 |
+
prod_rule_seq = []
|
835 |
+
stack = []
|
836 |
+
|
837 |
+
# get a pair of myself and children
|
838 |
+
children = sorted(list(clique_tree[root_node].keys()))
|
839 |
+
|
840 |
+
# extract a temporary production rule
|
841 |
+
prod_rule = extract_prod_rule(None, clique_tree.nodes[root_node]["subhg"],
|
842 |
+
[clique_tree.nodes[each_child]["subhg"] for each_child in children])
|
843 |
+
|
844 |
+
# update the production rule list
|
845 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
846 |
+
children = reorder_children(root_node, children, prod_rule, clique_tree)
|
847 |
+
stack.extend([(root_node, each_child) for each_child in children[::-1]])
|
848 |
+
prod_rule_seq.append(prod_rule_id)
|
849 |
+
|
850 |
+
while len(stack) != 0:
|
851 |
+
# get a triple of parent, myself, and children
|
852 |
+
parent, myself = stack.pop()
|
853 |
+
children = sorted(list(dict(clique_tree[myself]).keys()))
|
854 |
+
children.remove(parent)
|
855 |
+
|
856 |
+
# extract a temp prod rule
|
857 |
+
prod_rule = extract_prod_rule(
|
858 |
+
clique_tree.nodes[parent]["subhg"], clique_tree.nodes[myself]["subhg"],
|
859 |
+
[clique_tree.nodes[each_child]["subhg"] for each_child in children])
|
860 |
+
|
861 |
+
# update the prod rule list
|
862 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
863 |
+
children = reorder_children(myself, children, prod_rule, clique_tree)
|
864 |
+
stack.extend([(myself, each_child) for each_child in children[::-1]])
|
865 |
+
prod_rule_seq.append(prod_rule_id)
|
866 |
+
prod_rule_seq_list.append(prod_rule_seq)
|
867 |
+
self._compute_stats()
|
868 |
+
return prod_rule_seq_list
|
869 |
+
|
870 |
+
|
871 |
+
def reorder_children(myself, children, prod_rule, clique_tree):
|
872 |
+
""" reorder children so that they match the order in `prod_rule`.
|
873 |
+
|
874 |
+
Parameters
|
875 |
+
----------
|
876 |
+
myself : int
|
877 |
+
children : list of int
|
878 |
+
prod_rule : ProductionRule
|
879 |
+
clique_tree : nx.Graph
|
880 |
+
|
881 |
+
Returns
|
882 |
+
-------
|
883 |
+
new_children : list of str
|
884 |
+
reordered children
|
885 |
+
"""
|
886 |
+
perm = {} # key : `nt_idx`, val : child
|
887 |
+
for each_edge in prod_rule.rhs.edges:
|
888 |
+
if "nt_idx" in prod_rule.rhs.edge_attr(each_edge).keys():
|
889 |
+
for each_child in children:
|
890 |
+
common_node_set = set(
|
891 |
+
common_node_list(clique_tree.nodes[myself]["subhg"],
|
892 |
+
clique_tree.nodes[each_child]["subhg"])[0])
|
893 |
+
if set(prod_rule.rhs.nodes_in_edge(each_edge)) == common_node_set:
|
894 |
+
assert prod_rule.rhs.edge_attr(each_edge)["nt_idx"] not in perm
|
895 |
+
perm[prod_rule.rhs.edge_attr(each_edge)["nt_idx"]] = each_child
|
896 |
+
new_children = []
|
897 |
+
assert len(perm) == len(children)
|
898 |
+
for i in range(len(perm)):
|
899 |
+
new_children.append(perm[i])
|
900 |
+
return new_children
|
901 |
+
|
902 |
+
|
903 |
+
def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None):
|
904 |
+
""" extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`.
|
905 |
+
|
906 |
+
Parameters
|
907 |
+
----------
|
908 |
+
parent_hg : Hypergraph
|
909 |
+
myself_hg : Hypergraph
|
910 |
+
children_hg_list : list of Hypergraph
|
911 |
+
|
912 |
+
Returns
|
913 |
+
-------
|
914 |
+
ProductionRule, consisting of
|
915 |
+
lhs : Hypergraph or None
|
916 |
+
rhs : Hypergraph
|
917 |
+
"""
|
918 |
+
def _add_ext_node(hg, ext_nodes):
|
919 |
+
""" mark nodes to be external (ordered ids are assigned)
|
920 |
+
|
921 |
+
Parameters
|
922 |
+
----------
|
923 |
+
hg : UndirectedHypergraph
|
924 |
+
ext_nodes : list of str
|
925 |
+
list of external nodes
|
926 |
+
|
927 |
+
Returns
|
928 |
+
-------
|
929 |
+
hg : Hypergraph
|
930 |
+
nodes in `ext_nodes` are marked to be external
|
931 |
+
"""
|
932 |
+
ext_id = 0
|
933 |
+
ext_id_exists = []
|
934 |
+
for each_node in ext_nodes:
|
935 |
+
ext_id_exists.append('ext_id' in hg.node_attr(each_node))
|
936 |
+
if ext_id_exists and any(ext_id_exists) != all(ext_id_exists):
|
937 |
+
raise ValueError
|
938 |
+
if not all(ext_id_exists):
|
939 |
+
for each_node in ext_nodes:
|
940 |
+
hg.node_attr(each_node)['ext_id'] = ext_id
|
941 |
+
ext_id += 1
|
942 |
+
return hg
|
943 |
+
|
944 |
+
def _check_aromatic(hg, node_list):
|
945 |
+
is_aromatic = False
|
946 |
+
node_aromatic_list = []
|
947 |
+
for each_node in node_list:
|
948 |
+
if hg.node_attr(each_node)['symbol'].is_aromatic:
|
949 |
+
is_aromatic = True
|
950 |
+
node_aromatic_list.append(True)
|
951 |
+
else:
|
952 |
+
node_aromatic_list.append(False)
|
953 |
+
return is_aromatic, node_aromatic_list
|
954 |
+
|
955 |
+
def _check_ring(hg):
|
956 |
+
for each_edge in hg.edges:
|
957 |
+
if not ('tmp' in hg.edge_attr(each_edge) or (not hg.edge_attr(each_edge)['terminal'])):
|
958 |
+
return False
|
959 |
+
return True
|
960 |
+
|
961 |
+
if parent_hg is None:
|
962 |
+
lhs = Hypergraph()
|
963 |
+
node_list = []
|
964 |
+
else:
|
965 |
+
lhs = Hypergraph()
|
966 |
+
node_list, edge_exists = common_node_list(parent_hg, myself_hg)
|
967 |
+
for each_node in node_list:
|
968 |
+
lhs.add_node(each_node,
|
969 |
+
deepcopy(myself_hg.node_attr(each_node)))
|
970 |
+
is_aromatic, _ = _check_aromatic(parent_hg, node_list)
|
971 |
+
for_ring = _check_ring(myself_hg)
|
972 |
+
bond_symbol_list = []
|
973 |
+
for each_node in node_list:
|
974 |
+
bond_symbol_list.append(parent_hg.node_attr(each_node)['symbol'])
|
975 |
+
lhs.add_edge(
|
976 |
+
node_list,
|
977 |
+
attr_dict=dict(
|
978 |
+
terminal=False,
|
979 |
+
edge_exists=edge_exists,
|
980 |
+
symbol=NTSymbol(
|
981 |
+
degree=len(node_list),
|
982 |
+
is_aromatic=is_aromatic,
|
983 |
+
bond_symbol_list=bond_symbol_list,
|
984 |
+
for_ring=for_ring)))
|
985 |
+
try:
|
986 |
+
lhs = _add_ext_node(lhs, node_list)
|
987 |
+
except ValueError:
|
988 |
+
import pdb; pdb.set_trace()
|
989 |
+
|
990 |
+
rhs = remove_tmp_edge(deepcopy(myself_hg))
|
991 |
+
#rhs = remove_ext_node(rhs)
|
992 |
+
#rhs = remove_nt_edge(rhs)
|
993 |
+
try:
|
994 |
+
rhs = _add_ext_node(rhs, node_list)
|
995 |
+
except ValueError:
|
996 |
+
import pdb; pdb.set_trace()
|
997 |
+
|
998 |
+
nt_idx = 0
|
999 |
+
if children_hg_list is not None:
|
1000 |
+
for each_child_hg in children_hg_list:
|
1001 |
+
node_list, edge_exists = common_node_list(myself_hg, each_child_hg)
|
1002 |
+
is_aromatic, _ = _check_aromatic(myself_hg, node_list)
|
1003 |
+
for_ring = _check_ring(each_child_hg)
|
1004 |
+
bond_symbol_list = []
|
1005 |
+
for each_node in node_list:
|
1006 |
+
bond_symbol_list.append(myself_hg.node_attr(each_node)['symbol'])
|
1007 |
+
rhs.add_edge(
|
1008 |
+
node_list,
|
1009 |
+
attr_dict=dict(
|
1010 |
+
terminal=False,
|
1011 |
+
nt_idx=nt_idx,
|
1012 |
+
edge_exists=edge_exists,
|
1013 |
+
symbol=NTSymbol(degree=len(node_list),
|
1014 |
+
is_aromatic=is_aromatic,
|
1015 |
+
bond_symbol_list=bond_symbol_list,
|
1016 |
+
for_ring=for_ring)))
|
1017 |
+
nt_idx += 1
|
1018 |
+
prod_rule = ProductionRule(lhs, rhs)
|
1019 |
+
prod_rule.subhg_idx = subhg_idx
|
1020 |
+
if DEBUG:
|
1021 |
+
if sorted(list(prod_rule.ext_node.keys())) \
|
1022 |
+
!= list(np.arange(len(prod_rule.ext_node))):
|
1023 |
+
raise RuntimeError('ext_id is not continuous')
|
1024 |
+
return prod_rule
|
1025 |
+
|
1026 |
+
|
1027 |
+
def _find_root(clique_tree):
|
1028 |
+
max_node = None
|
1029 |
+
num_nodes_max = -np.inf
|
1030 |
+
for each_node in clique_tree.nodes:
|
1031 |
+
if clique_tree.nodes[each_node]['subhg'].num_nodes > num_nodes_max:
|
1032 |
+
max_node = each_node
|
1033 |
+
num_nodes_max = clique_tree.nodes[each_node]['subhg'].num_nodes
|
1034 |
+
'''
|
1035 |
+
children = sorted(list(clique_tree[each_node].keys()))
|
1036 |
+
prod_rule = extract_prod_rule(None,
|
1037 |
+
clique_tree.nodes[each_node]["subhg"],
|
1038 |
+
[clique_tree.nodes[each_child]["subhg"]
|
1039 |
+
for each_child in children])
|
1040 |
+
for each_start_rule in start_rule_list:
|
1041 |
+
if prod_rule.is_same(each_start_rule):
|
1042 |
+
return each_node
|
1043 |
+
'''
|
1044 |
+
return max_node
|
1045 |
+
|
1046 |
+
def remove_ext_node(hg):
|
1047 |
+
for each_node in hg.nodes:
|
1048 |
+
hg.node_attr(each_node).pop('ext_id', None)
|
1049 |
+
return hg
|
1050 |
+
|
1051 |
+
def remove_nt_edge(hg):
|
1052 |
+
remove_edge_list = []
|
1053 |
+
for each_edge in hg.edges:
|
1054 |
+
if not hg.edge_attr(each_edge)['terminal']:
|
1055 |
+
remove_edge_list.append(each_edge)
|
1056 |
+
hg.remove_edges(remove_edge_list)
|
1057 |
+
return hg
|
1058 |
+
|
1059 |
+
def remove_tmp_edge(hg):
|
1060 |
+
remove_edge_list = []
|
1061 |
+
for each_edge in hg.edges:
|
1062 |
+
if hg.edge_attr(each_edge).get('tmp', False):
|
1063 |
+
remove_edge_list.append(each_edge)
|
1064 |
+
hg.remove_edges(remove_edge_list)
|
1065 |
+
return hg
|
models/mhg_model/graph_grammar/graph_grammar/symbols.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
|
15 |
+
""" Title """
|
16 |
+
|
17 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
18 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
19 |
+
__version__ = "0.1"
|
20 |
+
__date__ = "Jan 1 2018"
|
21 |
+
|
22 |
+
from typing import List
|
23 |
+
|
24 |
+
class TSymbol(object):
|
25 |
+
|
26 |
+
''' terminal symbol
|
27 |
+
|
28 |
+
Attributes
|
29 |
+
----------
|
30 |
+
degree : int
|
31 |
+
the number of nodes in a hyperedge
|
32 |
+
is_aromatic : bool
|
33 |
+
whether or not the hyperedge is in an aromatic ring
|
34 |
+
symbol : str
|
35 |
+
atomic symbol
|
36 |
+
num_explicit_Hs : int
|
37 |
+
the number of hydrogens associated to this hyperedge
|
38 |
+
formal_charge : int
|
39 |
+
charge
|
40 |
+
chirality : int
|
41 |
+
chirality
|
42 |
+
'''
|
43 |
+
|
44 |
+
def __init__(self, degree, is_aromatic,
|
45 |
+
symbol, num_explicit_Hs, formal_charge, chirality):
|
46 |
+
self.degree = degree
|
47 |
+
self.is_aromatic = is_aromatic
|
48 |
+
self.symbol = symbol
|
49 |
+
self.num_explicit_Hs = num_explicit_Hs
|
50 |
+
self.formal_charge = formal_charge
|
51 |
+
self.chirality = chirality
|
52 |
+
|
53 |
+
@property
|
54 |
+
def terminal(self):
|
55 |
+
return True
|
56 |
+
|
57 |
+
def __eq__(self, other):
|
58 |
+
if not isinstance(other, TSymbol):
|
59 |
+
return False
|
60 |
+
if self.degree != other.degree:
|
61 |
+
return False
|
62 |
+
if self.is_aromatic != other.is_aromatic:
|
63 |
+
return False
|
64 |
+
if self.symbol != other.symbol:
|
65 |
+
return False
|
66 |
+
if self.num_explicit_Hs != other.num_explicit_Hs:
|
67 |
+
return False
|
68 |
+
if self.formal_charge != other.formal_charge:
|
69 |
+
return False
|
70 |
+
if self.chirality != other.chirality:
|
71 |
+
return False
|
72 |
+
return True
|
73 |
+
|
74 |
+
def __hash__(self):
|
75 |
+
return self.__str__().__hash__()
|
76 |
+
|
77 |
+
def __str__(self):
|
78 |
+
return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
|
79 |
+
f'symbol={self.symbol}, '\
|
80 |
+
f'num_explicit_Hs={self.num_explicit_Hs}, '\
|
81 |
+
f'formal_charge={self.formal_charge}, chirality={self.chirality}'
|
82 |
+
|
83 |
+
|
84 |
+
class NTSymbol(object):
|
85 |
+
|
86 |
+
''' non-terminal symbol
|
87 |
+
|
88 |
+
Attributes
|
89 |
+
----------
|
90 |
+
degree : int
|
91 |
+
degree of the hyperedge
|
92 |
+
is_aromatic : bool
|
93 |
+
if True, at least one of the associated bonds must be aromatic.
|
94 |
+
node_aromatic_list : list of bool
|
95 |
+
indicate whether each of the nodes is aromatic or not.
|
96 |
+
bond_type_list : list of int
|
97 |
+
bond type of each node"
|
98 |
+
'''
|
99 |
+
|
100 |
+
def __init__(self, degree: int, is_aromatic: bool,
|
101 |
+
bond_symbol_list: list,
|
102 |
+
for_ring=False):
|
103 |
+
self.degree = degree
|
104 |
+
self.is_aromatic = is_aromatic
|
105 |
+
self.for_ring = for_ring
|
106 |
+
self.bond_symbol_list = bond_symbol_list
|
107 |
+
|
108 |
+
@property
|
109 |
+
def terminal(self) -> bool:
|
110 |
+
return False
|
111 |
+
|
112 |
+
@property
|
113 |
+
def symbol(self):
|
114 |
+
return f'NT{self.degree}'
|
115 |
+
|
116 |
+
def __eq__(self, other) -> bool:
|
117 |
+
if not isinstance(other, NTSymbol):
|
118 |
+
return False
|
119 |
+
|
120 |
+
if self.degree != other.degree:
|
121 |
+
return False
|
122 |
+
if self.is_aromatic != other.is_aromatic:
|
123 |
+
return False
|
124 |
+
if self.for_ring != other.for_ring:
|
125 |
+
return False
|
126 |
+
if len(self.bond_symbol_list) != len(other.bond_symbol_list):
|
127 |
+
return False
|
128 |
+
for each_idx in range(len(self.bond_symbol_list)):
|
129 |
+
if self.bond_symbol_list[each_idx] != other.bond_symbol_list[each_idx]:
|
130 |
+
return False
|
131 |
+
return True
|
132 |
+
|
133 |
+
def __hash__(self):
|
134 |
+
return self.__str__().__hash__()
|
135 |
+
|
136 |
+
def __str__(self) -> str:
|
137 |
+
return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
|
138 |
+
f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}'\
|
139 |
+
f'for_ring={self.for_ring}'
|
140 |
+
|
141 |
+
|
142 |
+
class BondSymbol(object):
|
143 |
+
|
144 |
+
|
145 |
+
''' Bond symbol
|
146 |
+
|
147 |
+
Attributes
|
148 |
+
----------
|
149 |
+
is_aromatic : bool
|
150 |
+
if True, at least one of the associated bonds must be aromatic.
|
151 |
+
bond_type : int
|
152 |
+
bond type of each node"
|
153 |
+
'''
|
154 |
+
|
155 |
+
def __init__(self, is_aromatic: bool,
|
156 |
+
bond_type: int,
|
157 |
+
stereo: int):
|
158 |
+
self.is_aromatic = is_aromatic
|
159 |
+
self.bond_type = bond_type
|
160 |
+
self.stereo = stereo
|
161 |
+
|
162 |
+
def __eq__(self, other) -> bool:
|
163 |
+
if not isinstance(other, BondSymbol):
|
164 |
+
return False
|
165 |
+
|
166 |
+
if self.is_aromatic != other.is_aromatic:
|
167 |
+
return False
|
168 |
+
if self.bond_type != other.bond_type:
|
169 |
+
return False
|
170 |
+
if self.stereo != other.stereo:
|
171 |
+
return False
|
172 |
+
return True
|
173 |
+
|
174 |
+
def __hash__(self):
|
175 |
+
return self.__str__().__hash__()
|
176 |
+
|
177 |
+
def __str__(self) -> str:
|
178 |
+
return f'is_aromatic={self.is_aromatic}, '\
|
179 |
+
f'bond_type={self.bond_type}, '\
|
180 |
+
f'stereo={self.stereo}, '
|
models/mhg_model/graph_grammar/graph_grammar/utils.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jun 4 2018"
|
20 |
+
|
21 |
+
from ..hypergraph import Hypergraph
|
22 |
+
from copy import deepcopy
|
23 |
+
from typing import List
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
|
27 |
+
def common_node_list(hg1: Hypergraph, hg2: Hypergraph) -> List[str]:
|
28 |
+
""" return a list of common nodes
|
29 |
+
|
30 |
+
Parameters
|
31 |
+
----------
|
32 |
+
hg1, hg2 : Hypergraph
|
33 |
+
|
34 |
+
Returns
|
35 |
+
-------
|
36 |
+
list of str
|
37 |
+
list of common nodes
|
38 |
+
"""
|
39 |
+
if hg1 is None or hg2 is None:
|
40 |
+
return [], False
|
41 |
+
else:
|
42 |
+
node_set = hg1.nodes.intersection(hg2.nodes)
|
43 |
+
node_dict = {}
|
44 |
+
if 'order4hrg' in hg1.node_attr(list(hg1.nodes)[0]):
|
45 |
+
for each_node in node_set:
|
46 |
+
node_dict[each_node] = hg1.node_attr(each_node)['order4hrg']
|
47 |
+
else:
|
48 |
+
for each_node in node_set:
|
49 |
+
node_dict[each_node] = hg1.node_attr(each_node)['symbol'].__hash__()
|
50 |
+
node_list = []
|
51 |
+
for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
|
52 |
+
node_list.append(each_key)
|
53 |
+
edge_name = hg1.has_edge(node_list, ignore_order=True)
|
54 |
+
if edge_name:
|
55 |
+
if not hg1.edge_attr(edge_name).get('terminal', True):
|
56 |
+
node_list = hg1.nodes_in_edge(edge_name)
|
57 |
+
return node_list, True
|
58 |
+
else:
|
59 |
+
return node_list, False
|
60 |
+
|
61 |
+
|
62 |
+
def _node_match(node1, node2):
|
63 |
+
# if the nodes are hyperedges, `atom_attr` determines the match
|
64 |
+
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
65 |
+
return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
|
66 |
+
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
67 |
+
# bond_symbol
|
68 |
+
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
69 |
+
else:
|
70 |
+
return False
|
71 |
+
|
72 |
+
def _easy_node_match(node1, node2):
|
73 |
+
# if the nodes are hyperedges, `atom_attr` determines the match
|
74 |
+
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
75 |
+
return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None)
|
76 |
+
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
77 |
+
# bond_symbol
|
78 |
+
return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\
|
79 |
+
and node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
80 |
+
else:
|
81 |
+
return False
|
82 |
+
|
83 |
+
|
84 |
+
def _node_match_prod_rule(node1, node2, ignore_order=False):
|
85 |
+
# if the nodes are hyperedges, `atom_attr` determines the match
|
86 |
+
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
87 |
+
return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
|
88 |
+
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
89 |
+
# ext_id, order4hrg, bond_symbol
|
90 |
+
if ignore_order:
|
91 |
+
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
92 |
+
else:
|
93 |
+
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\
|
94 |
+
and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)
|
95 |
+
else:
|
96 |
+
return False
|
97 |
+
|
98 |
+
|
99 |
+
def _edge_match(edge1, edge2, ignore_order=False):
|
100 |
+
#return True
|
101 |
+
if ignore_order:
|
102 |
+
return True
|
103 |
+
else:
|
104 |
+
return edge1["order"] == edge2["order"]
|
105 |
+
|
106 |
+
def masked_softmax(logit, mask):
|
107 |
+
''' compute a probability distribution from logit
|
108 |
+
|
109 |
+
Parameters
|
110 |
+
----------
|
111 |
+
logit : array-like, length D
|
112 |
+
each element indicates how each dimension is likely to be chosen
|
113 |
+
(the larger, the more likely)
|
114 |
+
mask : array-like, length D
|
115 |
+
each element is either 0 or 1.
|
116 |
+
if 0, the dimension is ignored
|
117 |
+
when computing the probability distribution.
|
118 |
+
|
119 |
+
Returns
|
120 |
+
-------
|
121 |
+
prob_dist : array, length D
|
122 |
+
probability distribution computed from logit.
|
123 |
+
if `mask[d] = 0`, `prob_dist[d] = 0`.
|
124 |
+
'''
|
125 |
+
if logit.shape != mask.shape:
|
126 |
+
raise ValueError('logit and mask must have the same shape')
|
127 |
+
c = np.max(logit)
|
128 |
+
exp_logit = np.exp(logit - c) * mask
|
129 |
+
sum_exp_logit = exp_logit @ mask
|
130 |
+
return exp_logit / sum_exp_logit
|
models/mhg_model/graph_grammar/hypergraph.py
ADDED
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 31 2018"
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from typing import List, Dict, Tuple
|
23 |
+
import networkx as nx
|
24 |
+
import numpy as np
|
25 |
+
import os
|
26 |
+
|
27 |
+
|
28 |
+
class Hypergraph(object):
|
29 |
+
'''
|
30 |
+
A class of a hypergraph.
|
31 |
+
Each hyperedge can be ordered. For the ordered case,
|
32 |
+
edges adjacent to the hyperedge node are labeled by their orders.
|
33 |
+
|
34 |
+
Attributes
|
35 |
+
----------
|
36 |
+
hg : nx.Graph
|
37 |
+
a bipartite graph representation of a hypergraph
|
38 |
+
edge_idx : int
|
39 |
+
total number of hyperedges that exist so far
|
40 |
+
'''
|
41 |
+
def __init__(self):
|
42 |
+
self.hg = nx.Graph()
|
43 |
+
self.edge_idx = 0
|
44 |
+
self.nodes = set([])
|
45 |
+
self.num_nodes = 0
|
46 |
+
self.edges = set([])
|
47 |
+
self.num_edges = 0
|
48 |
+
self.nodes_in_edge_dict = {}
|
49 |
+
|
50 |
+
def add_node(self, node: str, attr_dict=None):
|
51 |
+
''' add a node to hypergraph
|
52 |
+
|
53 |
+
Parameters
|
54 |
+
----------
|
55 |
+
node : str
|
56 |
+
node name
|
57 |
+
attr_dict : dict
|
58 |
+
dictionary of node attributes
|
59 |
+
'''
|
60 |
+
self.hg.add_node(node, bipartite='node', attr_dict=attr_dict)
|
61 |
+
if node not in self.nodes:
|
62 |
+
self.num_nodes += 1
|
63 |
+
self.nodes.add(node)
|
64 |
+
|
65 |
+
def add_edge(self, node_list: List[str], attr_dict=None, edge_name=None):
|
66 |
+
''' add an edge consisting of nodes `node_list`
|
67 |
+
|
68 |
+
Parameters
|
69 |
+
----------
|
70 |
+
node_list : list
|
71 |
+
ordered list of nodes that consist the edge
|
72 |
+
attr_dict : dict
|
73 |
+
dictionary of edge attributes
|
74 |
+
'''
|
75 |
+
if edge_name is None:
|
76 |
+
edge = 'e{}'.format(self.edge_idx)
|
77 |
+
else:
|
78 |
+
assert edge_name not in self.edges
|
79 |
+
edge = edge_name
|
80 |
+
self.hg.add_node(edge, bipartite='edge', attr_dict=attr_dict)
|
81 |
+
if edge not in self.edges:
|
82 |
+
self.num_edges += 1
|
83 |
+
self.edges.add(edge)
|
84 |
+
self.nodes_in_edge_dict[edge] = node_list
|
85 |
+
if type(node_list) == list:
|
86 |
+
for node_idx, each_node in enumerate(node_list):
|
87 |
+
self.hg.add_edge(edge, each_node, order=node_idx)
|
88 |
+
if each_node not in self.nodes:
|
89 |
+
self.num_nodes += 1
|
90 |
+
self.nodes.add(each_node)
|
91 |
+
|
92 |
+
elif type(node_list) == set:
|
93 |
+
for each_node in node_list:
|
94 |
+
self.hg.add_edge(edge, each_node, order=-1)
|
95 |
+
if each_node not in self.nodes:
|
96 |
+
self.num_nodes += 1
|
97 |
+
self.nodes.add(each_node)
|
98 |
+
else:
|
99 |
+
raise ValueError
|
100 |
+
self.edge_idx += 1
|
101 |
+
return edge
|
102 |
+
|
103 |
+
def remove_node(self, node: str, remove_connected_edges=True):
|
104 |
+
''' remove a node
|
105 |
+
|
106 |
+
Parameters
|
107 |
+
----------
|
108 |
+
node : str
|
109 |
+
node name
|
110 |
+
remove_connected_edges : bool
|
111 |
+
if True, remove edges that are adjacent to the node
|
112 |
+
'''
|
113 |
+
if remove_connected_edges:
|
114 |
+
connected_edges = deepcopy(self.adj_edges(node))
|
115 |
+
for each_edge in connected_edges:
|
116 |
+
self.remove_edge(each_edge)
|
117 |
+
self.hg.remove_node(node)
|
118 |
+
self.num_nodes -= 1
|
119 |
+
self.nodes.remove(node)
|
120 |
+
|
121 |
+
def remove_nodes(self, node_iter, remove_connected_edges=True):
|
122 |
+
''' remove a set of nodes
|
123 |
+
|
124 |
+
Parameters
|
125 |
+
----------
|
126 |
+
node_iter : iterator of strings
|
127 |
+
nodes to be removed
|
128 |
+
remove_connected_edges : bool
|
129 |
+
if True, remove edges that are adjacent to the node
|
130 |
+
'''
|
131 |
+
for each_node in node_iter:
|
132 |
+
self.remove_node(each_node, remove_connected_edges)
|
133 |
+
|
134 |
+
def remove_edge(self, edge: str):
|
135 |
+
''' remove an edge
|
136 |
+
|
137 |
+
Parameters
|
138 |
+
----------
|
139 |
+
edge : str
|
140 |
+
edge to be removed
|
141 |
+
'''
|
142 |
+
self.hg.remove_node(edge)
|
143 |
+
self.edges.remove(edge)
|
144 |
+
self.num_edges -= 1
|
145 |
+
self.nodes_in_edge_dict.pop(edge)
|
146 |
+
|
147 |
+
def remove_edges(self, edge_iter):
|
148 |
+
''' remove a set of edges
|
149 |
+
|
150 |
+
Parameters
|
151 |
+
----------
|
152 |
+
edge_iter : iterator of strings
|
153 |
+
edges to be removed
|
154 |
+
'''
|
155 |
+
for each_edge in edge_iter:
|
156 |
+
self.remove_edge(each_edge)
|
157 |
+
|
158 |
+
def remove_edges_with_attr(self, edge_attr_dict):
|
159 |
+
remove_edge_list = []
|
160 |
+
for each_edge in self.edges:
|
161 |
+
satisfy = True
|
162 |
+
for each_key, each_val in edge_attr_dict.items():
|
163 |
+
if not satisfy:
|
164 |
+
break
|
165 |
+
try:
|
166 |
+
if self.edge_attr(each_edge)[each_key] != each_val:
|
167 |
+
satisfy = False
|
168 |
+
except KeyError:
|
169 |
+
satisfy = False
|
170 |
+
if satisfy:
|
171 |
+
remove_edge_list.append(each_edge)
|
172 |
+
self.remove_edges(remove_edge_list)
|
173 |
+
|
174 |
+
def remove_subhg(self, subhg):
|
175 |
+
''' remove subhypergraph.
|
176 |
+
all of the hyperedges are removed.
|
177 |
+
each node of subhg is removed if its degree becomes 0 after removing hyperedges.
|
178 |
+
|
179 |
+
Parameters
|
180 |
+
----------
|
181 |
+
subhg : Hypergraph
|
182 |
+
'''
|
183 |
+
for each_edge in subhg.edges:
|
184 |
+
self.remove_edge(each_edge)
|
185 |
+
for each_node in subhg.nodes:
|
186 |
+
if self.degree(each_node) == 0:
|
187 |
+
self.remove_node(each_node)
|
188 |
+
|
189 |
+
def nodes_in_edge(self, edge):
|
190 |
+
''' return an ordered list of nodes in a given edge.
|
191 |
+
|
192 |
+
Parameters
|
193 |
+
----------
|
194 |
+
edge : str
|
195 |
+
edge whose nodes are returned
|
196 |
+
|
197 |
+
Returns
|
198 |
+
-------
|
199 |
+
list or set
|
200 |
+
ordered list or set of nodes that belong to the edge
|
201 |
+
'''
|
202 |
+
if edge.startswith('e'):
|
203 |
+
return self.nodes_in_edge_dict[edge]
|
204 |
+
else:
|
205 |
+
adj_node_list = self.hg.adj[edge]
|
206 |
+
adj_node_order_list = []
|
207 |
+
adj_node_name_list = []
|
208 |
+
for each_node in adj_node_list:
|
209 |
+
adj_node_order_list.append(adj_node_list[each_node]['order'])
|
210 |
+
adj_node_name_list.append(each_node)
|
211 |
+
if adj_node_order_list == [-1] * len(adj_node_order_list):
|
212 |
+
return set(adj_node_name_list)
|
213 |
+
else:
|
214 |
+
return [adj_node_name_list[each_idx] for each_idx
|
215 |
+
in np.argsort(adj_node_order_list)]
|
216 |
+
|
217 |
+
def adj_edges(self, node):
|
218 |
+
''' return a dict of adjacent hyperedges
|
219 |
+
|
220 |
+
Parameters
|
221 |
+
----------
|
222 |
+
node : str
|
223 |
+
|
224 |
+
Returns
|
225 |
+
-------
|
226 |
+
set
|
227 |
+
set of edges that are adjacent to `node`
|
228 |
+
'''
|
229 |
+
return self.hg.adj[node]
|
230 |
+
|
231 |
+
def adj_nodes(self, node):
|
232 |
+
''' return a set of adjacent nodes
|
233 |
+
|
234 |
+
Parameters
|
235 |
+
----------
|
236 |
+
node : str
|
237 |
+
|
238 |
+
Returns
|
239 |
+
-------
|
240 |
+
set
|
241 |
+
set of nodes that are adjacent to `node`
|
242 |
+
'''
|
243 |
+
node_set = set([])
|
244 |
+
for each_adj_edge in self.adj_edges(node):
|
245 |
+
node_set.update(set(self.nodes_in_edge(each_adj_edge)))
|
246 |
+
node_set.discard(node)
|
247 |
+
return node_set
|
248 |
+
|
249 |
+
def has_edge(self, node_list, ignore_order=False):
|
250 |
+
for each_edge in self.edges:
|
251 |
+
if ignore_order:
|
252 |
+
if set(self.nodes_in_edge(each_edge)) == set(node_list):
|
253 |
+
return each_edge
|
254 |
+
else:
|
255 |
+
if self.nodes_in_edge(each_edge) == node_list:
|
256 |
+
return each_edge
|
257 |
+
return False
|
258 |
+
|
259 |
+
def degree(self, node):
|
260 |
+
return len(self.hg.adj[node])
|
261 |
+
|
262 |
+
def degrees(self):
|
263 |
+
return {each_node: self.degree(each_node) for each_node in self.nodes}
|
264 |
+
|
265 |
+
def edge_degree(self, edge):
|
266 |
+
return len(self.nodes_in_edge(edge))
|
267 |
+
|
268 |
+
def edge_degrees(self):
|
269 |
+
return {each_edge: self.edge_degree(each_edge) for each_edge in self.edges}
|
270 |
+
|
271 |
+
def is_adj(self, node1, node2):
|
272 |
+
return node1 in self.adj_nodes(node2)
|
273 |
+
|
274 |
+
def adj_subhg(self, node, ident_node_dict=None):
|
275 |
+
""" return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
|
276 |
+
if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
|
277 |
+
|
278 |
+
Parameters
|
279 |
+
----------
|
280 |
+
node : str
|
281 |
+
ident_node_dict : dict
|
282 |
+
dict containing identical nodes. see `get_identical_node_dict` for more details
|
283 |
+
|
284 |
+
Returns
|
285 |
+
-------
|
286 |
+
subhg : Hypergraph
|
287 |
+
"""
|
288 |
+
if ident_node_dict is None:
|
289 |
+
ident_node_dict = self.get_identical_node_dict()
|
290 |
+
adj_node_set = set(ident_node_dict[node])
|
291 |
+
adj_edge_set = set([])
|
292 |
+
for each_node in ident_node_dict[node]:
|
293 |
+
adj_edge_set.update(set(self.adj_edges(each_node)))
|
294 |
+
fixed_adj_edge_set = deepcopy(adj_edge_set)
|
295 |
+
for each_edge in fixed_adj_edge_set:
|
296 |
+
other_nodes = self.nodes_in_edge(each_edge)
|
297 |
+
adj_node_set.update(other_nodes)
|
298 |
+
|
299 |
+
# if the adjacent node has self-loop edge, it will be appended to adj_edge_list.
|
300 |
+
for each_node in other_nodes:
|
301 |
+
for other_edge in set(self.adj_edges(each_node)) - set([each_edge]):
|
302 |
+
if len(set(self.nodes_in_edge(other_edge)) \
|
303 |
+
- set(self.nodes_in_edge(each_edge))) == 0:
|
304 |
+
adj_edge_set.update(set([other_edge]))
|
305 |
+
subhg = Hypergraph()
|
306 |
+
for each_node in adj_node_set:
|
307 |
+
subhg.add_node(each_node, attr_dict=self.node_attr(each_node))
|
308 |
+
for each_edge in adj_edge_set:
|
309 |
+
subhg.add_edge(self.nodes_in_edge(each_edge),
|
310 |
+
attr_dict=self.edge_attr(each_edge),
|
311 |
+
edge_name=each_edge)
|
312 |
+
subhg.edge_idx = self.edge_idx
|
313 |
+
return subhg
|
314 |
+
|
315 |
+
def get_subhg(self, node_list, edge_list, ident_node_dict=None):
|
316 |
+
""" return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
|
317 |
+
if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
|
318 |
+
|
319 |
+
Parameters
|
320 |
+
----------
|
321 |
+
node : str
|
322 |
+
ident_node_dict : dict
|
323 |
+
dict containing identical nodes. see `get_identical_node_dict` for more details
|
324 |
+
|
325 |
+
Returns
|
326 |
+
-------
|
327 |
+
subhg : Hypergraph
|
328 |
+
"""
|
329 |
+
if ident_node_dict is None:
|
330 |
+
ident_node_dict = self.get_identical_node_dict()
|
331 |
+
adj_node_set = set([])
|
332 |
+
for each_node in node_list:
|
333 |
+
adj_node_set.update(set(ident_node_dict[each_node]))
|
334 |
+
adj_edge_set = set(edge_list)
|
335 |
+
|
336 |
+
subhg = Hypergraph()
|
337 |
+
for each_node in adj_node_set:
|
338 |
+
subhg.add_node(each_node,
|
339 |
+
attr_dict=deepcopy(self.node_attr(each_node)))
|
340 |
+
for each_edge in adj_edge_set:
|
341 |
+
subhg.add_edge(self.nodes_in_edge(each_edge),
|
342 |
+
attr_dict=deepcopy(self.edge_attr(each_edge)),
|
343 |
+
edge_name=each_edge)
|
344 |
+
subhg.edge_idx = self.edge_idx
|
345 |
+
return subhg
|
346 |
+
|
347 |
+
def copy(self):
|
348 |
+
''' return a copy of the object
|
349 |
+
|
350 |
+
Returns
|
351 |
+
-------
|
352 |
+
Hypergraph
|
353 |
+
'''
|
354 |
+
return deepcopy(self)
|
355 |
+
|
356 |
+
def node_attr(self, node):
|
357 |
+
return self.hg.nodes[node]['attr_dict']
|
358 |
+
|
359 |
+
def edge_attr(self, edge):
|
360 |
+
return self.hg.nodes[edge]['attr_dict']
|
361 |
+
|
362 |
+
def set_node_attr(self, node, attr_dict):
|
363 |
+
for each_key, each_val in attr_dict.items():
|
364 |
+
self.hg.nodes[node]['attr_dict'][each_key] = each_val
|
365 |
+
|
366 |
+
def set_edge_attr(self, edge, attr_dict):
|
367 |
+
for each_key, each_val in attr_dict.items():
|
368 |
+
self.hg.nodes[edge]['attr_dict'][each_key] = each_val
|
369 |
+
|
370 |
+
def get_identical_node_dict(self):
|
371 |
+
''' get identical nodes
|
372 |
+
nodes are identical if they share the same set of adjacent edges.
|
373 |
+
|
374 |
+
Returns
|
375 |
+
-------
|
376 |
+
ident_node_dict : dict
|
377 |
+
ident_node_dict[node] returns a list of nodes that are identical to `node`.
|
378 |
+
'''
|
379 |
+
ident_node_dict = {}
|
380 |
+
for each_node in self.nodes:
|
381 |
+
ident_node_list = []
|
382 |
+
for each_other_node in self.nodes:
|
383 |
+
if each_other_node == each_node:
|
384 |
+
ident_node_list.append(each_other_node)
|
385 |
+
elif self.adj_edges(each_node) == self.adj_edges(each_other_node) \
|
386 |
+
and len(self.adj_edges(each_node)) != 0:
|
387 |
+
ident_node_list.append(each_other_node)
|
388 |
+
ident_node_dict[each_node] = ident_node_list
|
389 |
+
return ident_node_dict
|
390 |
+
'''
|
391 |
+
ident_node_dict = {}
|
392 |
+
for each_node in self.nodes:
|
393 |
+
ident_node_dict[each_node] = [each_node]
|
394 |
+
return ident_node_dict
|
395 |
+
'''
|
396 |
+
|
397 |
+
def get_leaf_edge(self):
|
398 |
+
''' get an edge that is incident only to one edge
|
399 |
+
|
400 |
+
Returns
|
401 |
+
-------
|
402 |
+
if exists, return a leaf edge. otherwise, return None.
|
403 |
+
'''
|
404 |
+
for each_edge in self.edges:
|
405 |
+
if len(self.adj_nodes(each_edge)) == 1:
|
406 |
+
if 'tmp' not in self.edge_attr(each_edge):
|
407 |
+
return each_edge
|
408 |
+
return None
|
409 |
+
|
410 |
+
def get_nontmp_edge(self):
|
411 |
+
for each_edge in self.edges:
|
412 |
+
if 'tmp' not in self.edge_attr(each_edge):
|
413 |
+
return each_edge
|
414 |
+
return None
|
415 |
+
|
416 |
+
def is_subhg(self, hg):
|
417 |
+
''' return whether this hypergraph is a subhypergraph of `hg`
|
418 |
+
|
419 |
+
Returns
|
420 |
+
-------
|
421 |
+
True if self \in hg,
|
422 |
+
False otherwise.
|
423 |
+
'''
|
424 |
+
for each_node in self.nodes:
|
425 |
+
if each_node not in hg.nodes:
|
426 |
+
return False
|
427 |
+
for each_edge in self.edges:
|
428 |
+
if each_edge not in hg.edges:
|
429 |
+
return False
|
430 |
+
return True
|
431 |
+
|
432 |
+
def in_cycle(self, node, visited=None, parent='', root_node='') -> bool:
|
433 |
+
''' if `node` is in a cycle, then return True. otherwise, False.
|
434 |
+
|
435 |
+
Parameters
|
436 |
+
----------
|
437 |
+
node : str
|
438 |
+
node in a hypergraph
|
439 |
+
visited : list
|
440 |
+
list of visited nodes, used for recursion
|
441 |
+
parent : str
|
442 |
+
parent node, used to eliminate a cycle consisting of two nodes and one edge.
|
443 |
+
|
444 |
+
Returns
|
445 |
+
-------
|
446 |
+
bool
|
447 |
+
'''
|
448 |
+
if visited is None:
|
449 |
+
visited = []
|
450 |
+
if parent == '':
|
451 |
+
visited = []
|
452 |
+
if root_node == '':
|
453 |
+
root_node = node
|
454 |
+
visited.append(node)
|
455 |
+
for each_adj_node in self.adj_nodes(node):
|
456 |
+
if each_adj_node not in visited:
|
457 |
+
if self.in_cycle(each_adj_node, visited, node, root_node):
|
458 |
+
return True
|
459 |
+
elif each_adj_node != parent and each_adj_node == root_node:
|
460 |
+
return True
|
461 |
+
return False
|
462 |
+
|
463 |
+
|
464 |
+
def draw(self, file_path=None, with_node=False, with_edge_name=False):
|
465 |
+
''' draw hypergraph
|
466 |
+
'''
|
467 |
+
import graphviz
|
468 |
+
G = graphviz.Graph(format='png')
|
469 |
+
for each_node in self.nodes:
|
470 |
+
if 'ext_id' in self.node_attr(each_node):
|
471 |
+
G.node(each_node, label='',
|
472 |
+
shape='circle', width='0.1', height='0.1', style='filled',
|
473 |
+
fillcolor='black')
|
474 |
+
else:
|
475 |
+
if with_node:
|
476 |
+
G.node(each_node, label='',
|
477 |
+
shape='circle', width='0.1', height='0.1', style='filled',
|
478 |
+
fillcolor='gray')
|
479 |
+
edge_list = []
|
480 |
+
for each_edge in self.edges:
|
481 |
+
if self.edge_attr(each_edge).get('terminal', False):
|
482 |
+
G.node(each_edge,
|
483 |
+
label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
|
484 |
+
else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
|
485 |
+
fontcolor='black', shape='square')
|
486 |
+
elif self.edge_attr(each_edge).get('tmp', False):
|
487 |
+
G.node(each_edge, label='tmp' if not with_edge_name else 'tmp, ' + each_edge,
|
488 |
+
fontcolor='black', shape='square')
|
489 |
+
else:
|
490 |
+
G.node(each_edge,
|
491 |
+
label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
|
492 |
+
else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
|
493 |
+
fontcolor='black', shape='square', style='filled')
|
494 |
+
if with_node:
|
495 |
+
for each_node in self.nodes_in_edge(each_edge):
|
496 |
+
G.edge(each_edge, each_node)
|
497 |
+
else:
|
498 |
+
for each_node in self.nodes_in_edge(each_edge):
|
499 |
+
if 'ext_id' in self.node_attr(each_node)\
|
500 |
+
and set([each_node, each_edge]) not in edge_list:
|
501 |
+
G.edge(each_edge, each_node)
|
502 |
+
edge_list.append(set([each_node, each_edge]))
|
503 |
+
for each_other_edge in self.adj_nodes(each_edge):
|
504 |
+
if set([each_edge, each_other_edge]) not in edge_list:
|
505 |
+
num_bond = 0
|
506 |
+
common_node_set = set(self.nodes_in_edge(each_edge))\
|
507 |
+
.intersection(set(self.nodes_in_edge(each_other_edge)))
|
508 |
+
for each_node in common_node_set:
|
509 |
+
if self.node_attr(each_node)['symbol'].bond_type in [1, 2, 3]:
|
510 |
+
num_bond += self.node_attr(each_node)['symbol'].bond_type
|
511 |
+
elif self.node_attr(each_node)['symbol'].bond_type in [12]:
|
512 |
+
num_bond += 1
|
513 |
+
else:
|
514 |
+
raise NotImplementedError('unsupported bond type')
|
515 |
+
for _ in range(num_bond):
|
516 |
+
G.edge(each_edge, each_other_edge)
|
517 |
+
edge_list.append(set([each_edge, each_other_edge]))
|
518 |
+
if file_path is not None:
|
519 |
+
G.render(file_path, cleanup=True)
|
520 |
+
#os.remove(file_path)
|
521 |
+
return G
|
522 |
+
|
523 |
+
def is_dividable(self, node):
|
524 |
+
_hg = deepcopy(self.hg)
|
525 |
+
_hg.remove_node(node)
|
526 |
+
return (not nx.is_connected(_hg))
|
527 |
+
|
528 |
+
def divide(self, node):
|
529 |
+
subhg_list = []
|
530 |
+
|
531 |
+
hg_wo_node = deepcopy(self)
|
532 |
+
hg_wo_node.remove_node(node, remove_connected_edges=False)
|
533 |
+
connected_components = nx.connected_components(hg_wo_node.hg)
|
534 |
+
for each_component in connected_components:
|
535 |
+
node_list = [node]
|
536 |
+
edge_list = []
|
537 |
+
node_list.extend([each_node for each_node in each_component
|
538 |
+
if each_node.startswith('bond_')])
|
539 |
+
edge_list.extend([each_edge for each_edge in each_component
|
540 |
+
if each_edge.startswith('e')])
|
541 |
+
subhg_list.append(self.get_subhg(node_list, edge_list))
|
542 |
+
#subhg_list[-1].set_node_attr(node, {'divided': True})
|
543 |
+
return subhg_list
|
544 |
+
|
models/mhg_model/graph_grammar/io/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
models/mhg_model/graph_grammar/io/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (679 Bytes). View file
|
|
models/mhg_model/graph_grammar/io/__pycache__/smi.cpython-310.pyc
ADDED
Binary file (13 kB). View file
|
|
models/mhg_model/graph_grammar/io/smi.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 12 2018"
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from rdkit import Chem
|
23 |
+
from rdkit import RDLogger
|
24 |
+
import networkx as nx
|
25 |
+
import numpy as np
|
26 |
+
from ..hypergraph import Hypergraph
|
27 |
+
from ..graph_grammar.symbols import TSymbol, BondSymbol
|
28 |
+
|
29 |
+
# supress warnings
|
30 |
+
lg = RDLogger.logger()
|
31 |
+
lg.setLevel(RDLogger.CRITICAL)
|
32 |
+
|
33 |
+
|
34 |
+
class HGGen(object):
|
35 |
+
"""
|
36 |
+
load .smi file and yield a hypergraph.
|
37 |
+
|
38 |
+
Attributes
|
39 |
+
----------
|
40 |
+
path_to_file : str
|
41 |
+
path to .smi file
|
42 |
+
kekulize : bool
|
43 |
+
kekulize or not
|
44 |
+
add_Hs : bool
|
45 |
+
add implicit hydrogens to the molecule or not.
|
46 |
+
all_single : bool
|
47 |
+
if True, all multiple bonds are summarized into a single bond with some attributes
|
48 |
+
|
49 |
+
Yields
|
50 |
+
------
|
51 |
+
Hypergraph
|
52 |
+
"""
|
53 |
+
def __init__(self, path_to_file, kekulize=True, add_Hs=False, all_single=True):
|
54 |
+
self.num_line = 1
|
55 |
+
self.mol_gen = Chem.SmilesMolSupplier(path_to_file, titleLine=False)
|
56 |
+
self.kekulize = kekulize
|
57 |
+
self.add_Hs = add_Hs
|
58 |
+
self.all_single = all_single
|
59 |
+
|
60 |
+
def __iter__(self):
|
61 |
+
return self
|
62 |
+
|
63 |
+
def __next__(self):
|
64 |
+
'''
|
65 |
+
each_mol = None
|
66 |
+
while each_mol is None:
|
67 |
+
each_mol = next(self.mol_gen)
|
68 |
+
'''
|
69 |
+
# not ignoring parse errors
|
70 |
+
each_mol = next(self.mol_gen)
|
71 |
+
if each_mol is None:
|
72 |
+
raise ValueError(f'incorrect smiles in line {self.num_line}')
|
73 |
+
else:
|
74 |
+
self.num_line += 1
|
75 |
+
return mol_to_hg(each_mol, self.kekulize, self.add_Hs)
|
76 |
+
|
77 |
+
|
78 |
+
def mol_to_bipartite(mol, kekulize):
|
79 |
+
"""
|
80 |
+
get a bipartite representation of a molecule.
|
81 |
+
|
82 |
+
Parameters
|
83 |
+
----------
|
84 |
+
mol : rdkit.Chem.rdchem.Mol
|
85 |
+
molecule object
|
86 |
+
|
87 |
+
Returns
|
88 |
+
-------
|
89 |
+
nx.Graph
|
90 |
+
a bipartite graph representing which bond is connected to which atoms.
|
91 |
+
"""
|
92 |
+
try:
|
93 |
+
mol = standardize_stereo(mol)
|
94 |
+
except KeyError:
|
95 |
+
print(Chem.MolToSmiles(mol))
|
96 |
+
raise KeyError
|
97 |
+
|
98 |
+
if kekulize:
|
99 |
+
Chem.Kekulize(mol)
|
100 |
+
|
101 |
+
bipartite_g = nx.Graph()
|
102 |
+
for each_atom in mol.GetAtoms():
|
103 |
+
bipartite_g.add_node(f"atom_{each_atom.GetIdx()}",
|
104 |
+
atom_attr=atom_attr(each_atom, kekulize))
|
105 |
+
|
106 |
+
for each_bond in mol.GetBonds():
|
107 |
+
bond_idx = each_bond.GetIdx()
|
108 |
+
bipartite_g.add_node(
|
109 |
+
f"bond_{bond_idx}",
|
110 |
+
bond_attr=bond_attr(each_bond, kekulize))
|
111 |
+
bipartite_g.add_edge(
|
112 |
+
f"atom_{each_bond.GetBeginAtomIdx()}",
|
113 |
+
f"bond_{bond_idx}")
|
114 |
+
bipartite_g.add_edge(
|
115 |
+
f"atom_{each_bond.GetEndAtomIdx()}",
|
116 |
+
f"bond_{bond_idx}")
|
117 |
+
return bipartite_g
|
118 |
+
|
119 |
+
|
120 |
+
def mol_to_hg(mol, kekulize, add_Hs):
|
121 |
+
"""
|
122 |
+
get a bipartite representation of a molecule.
|
123 |
+
|
124 |
+
Parameters
|
125 |
+
----------
|
126 |
+
mol : rdkit.Chem.rdchem.Mol
|
127 |
+
molecule object
|
128 |
+
kekulize : bool
|
129 |
+
kekulize or not
|
130 |
+
add_Hs : bool
|
131 |
+
add implicit hydrogens to the molecule or not.
|
132 |
+
|
133 |
+
Returns
|
134 |
+
-------
|
135 |
+
Hypergraph
|
136 |
+
"""
|
137 |
+
if add_Hs:
|
138 |
+
mol = Chem.AddHs(mol)
|
139 |
+
|
140 |
+
if kekulize:
|
141 |
+
Chem.Kekulize(mol)
|
142 |
+
|
143 |
+
bipartite_g = mol_to_bipartite(mol, kekulize)
|
144 |
+
hg = Hypergraph()
|
145 |
+
for each_atom in [each_node for each_node in bipartite_g.nodes()
|
146 |
+
if each_node.startswith('atom_')]:
|
147 |
+
node_set = set([])
|
148 |
+
for each_bond in bipartite_g.adj[each_atom]:
|
149 |
+
hg.add_node(each_bond,
|
150 |
+
attr_dict=bipartite_g.nodes[each_bond]['bond_attr'])
|
151 |
+
node_set.add(each_bond)
|
152 |
+
hg.add_edge(node_set,
|
153 |
+
attr_dict=bipartite_g.nodes[each_atom]['atom_attr'])
|
154 |
+
return hg
|
155 |
+
|
156 |
+
|
157 |
+
def hg_to_mol(hg, verbose=False):
|
158 |
+
""" convert a hypergraph into Mol object
|
159 |
+
|
160 |
+
Parameters
|
161 |
+
----------
|
162 |
+
hg : Hypergraph
|
163 |
+
|
164 |
+
Returns
|
165 |
+
-------
|
166 |
+
mol : Chem.RWMol
|
167 |
+
"""
|
168 |
+
mol = Chem.RWMol()
|
169 |
+
atom_dict = {}
|
170 |
+
bond_set = set([])
|
171 |
+
for each_edge in hg.edges:
|
172 |
+
atom = Chem.Atom(hg.edge_attr(each_edge)['symbol'].symbol)
|
173 |
+
atom.SetNumExplicitHs(hg.edge_attr(each_edge)['symbol'].num_explicit_Hs)
|
174 |
+
atom.SetFormalCharge(hg.edge_attr(each_edge)['symbol'].formal_charge)
|
175 |
+
atom.SetChiralTag(
|
176 |
+
Chem.rdchem.ChiralType.values[
|
177 |
+
hg.edge_attr(each_edge)['symbol'].chirality])
|
178 |
+
atom_idx = mol.AddAtom(atom)
|
179 |
+
atom_dict[each_edge] = atom_idx
|
180 |
+
|
181 |
+
for each_node in hg.nodes:
|
182 |
+
edge_1, edge_2 = hg.adj_edges(each_node)
|
183 |
+
if edge_1+edge_2 not in bond_set:
|
184 |
+
if hg.node_attr(each_node)['symbol'].bond_type <= 3:
|
185 |
+
num_bond = hg.node_attr(each_node)['symbol'].bond_type
|
186 |
+
elif hg.node_attr(each_node)['symbol'].bond_type == 12:
|
187 |
+
num_bond = 1
|
188 |
+
else:
|
189 |
+
raise ValueError(f'too many bonds; {hg.node_attr(each_node)["bond_symbol"].bond_type}')
|
190 |
+
_ = mol.AddBond(atom_dict[edge_1],
|
191 |
+
atom_dict[edge_2],
|
192 |
+
order=Chem.rdchem.BondType.values[num_bond])
|
193 |
+
bond_idx = mol.GetBondBetweenAtoms(atom_dict[edge_1], atom_dict[edge_2]).GetIdx()
|
194 |
+
|
195 |
+
# stereo
|
196 |
+
mol.GetBondWithIdx(bond_idx).SetStereo(
|
197 |
+
Chem.rdchem.BondStereo.values[hg.node_attr(each_node)['symbol'].stereo])
|
198 |
+
bond_set.update([edge_1+edge_2])
|
199 |
+
bond_set.update([edge_2+edge_1])
|
200 |
+
mol.UpdatePropertyCache()
|
201 |
+
mol = mol.GetMol()
|
202 |
+
not_stereo_mol = deepcopy(mol)
|
203 |
+
if Chem.MolFromSmiles(Chem.MolToSmiles(not_stereo_mol)) is None:
|
204 |
+
raise RuntimeError('no valid molecule was obtained.')
|
205 |
+
try:
|
206 |
+
mol = set_stereo(mol)
|
207 |
+
is_stereo = True
|
208 |
+
except:
|
209 |
+
import traceback
|
210 |
+
traceback.print_exc()
|
211 |
+
is_stereo = False
|
212 |
+
mol_tmp = deepcopy(mol)
|
213 |
+
Chem.SetAromaticity(mol_tmp)
|
214 |
+
if Chem.MolFromSmiles(Chem.MolToSmiles(mol_tmp)) is not None:
|
215 |
+
mol = mol_tmp
|
216 |
+
else:
|
217 |
+
if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is None:
|
218 |
+
mol = not_stereo_mol
|
219 |
+
mol.UpdatePropertyCache()
|
220 |
+
Chem.GetSymmSSSR(mol)
|
221 |
+
mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
|
222 |
+
if verbose:
|
223 |
+
return mol, is_stereo
|
224 |
+
else:
|
225 |
+
return mol
|
226 |
+
|
227 |
+
def hgs_to_mols(hg_list, ignore_error=False):
|
228 |
+
if ignore_error:
|
229 |
+
mol_list = []
|
230 |
+
for each_hg in hg_list:
|
231 |
+
try:
|
232 |
+
mol = hg_to_mol(each_hg)
|
233 |
+
except:
|
234 |
+
mol = None
|
235 |
+
mol_list.append(mol)
|
236 |
+
else:
|
237 |
+
mol_list = [hg_to_mol(each_hg) for each_hg in hg_list]
|
238 |
+
return mol_list
|
239 |
+
|
240 |
+
def hgs_to_smiles(hg_list, ignore_error=False):
|
241 |
+
mol_list = hgs_to_mols(hg_list, ignore_error)
|
242 |
+
smiles_list = []
|
243 |
+
for each_mol in mol_list:
|
244 |
+
try:
|
245 |
+
smiles_list.append(
|
246 |
+
Chem.MolToSmiles(
|
247 |
+
Chem.MolFromSmiles(
|
248 |
+
Chem.MolToSmiles(
|
249 |
+
each_mol))))
|
250 |
+
except:
|
251 |
+
smiles_list.append(None)
|
252 |
+
return smiles_list
|
253 |
+
|
254 |
+
def atom_attr(atom, kekulize):
|
255 |
+
"""
|
256 |
+
get atom's attributes
|
257 |
+
|
258 |
+
Parameters
|
259 |
+
----------
|
260 |
+
atom : rdkit.Chem.rdchem.Atom
|
261 |
+
kekulize : bool
|
262 |
+
kekulize or not
|
263 |
+
|
264 |
+
Returns
|
265 |
+
-------
|
266 |
+
atom_attr : dict
|
267 |
+
"is_aromatic" : bool
|
268 |
+
the atom is aromatic or not.
|
269 |
+
"smarts" : str
|
270 |
+
SMARTS representation of the atom.
|
271 |
+
"""
|
272 |
+
if kekulize:
|
273 |
+
return {'terminal': True,
|
274 |
+
'is_in_ring': atom.IsInRing(),
|
275 |
+
'symbol': TSymbol(degree=0,
|
276 |
+
#degree=atom.GetTotalDegree(),
|
277 |
+
is_aromatic=False,
|
278 |
+
symbol=atom.GetSymbol(),
|
279 |
+
num_explicit_Hs=atom.GetNumExplicitHs(),
|
280 |
+
formal_charge=atom.GetFormalCharge(),
|
281 |
+
chirality=atom.GetChiralTag().real
|
282 |
+
)}
|
283 |
+
else:
|
284 |
+
return {'terminal': True,
|
285 |
+
'is_in_ring': atom.IsInRing(),
|
286 |
+
'symbol': TSymbol(degree=0,
|
287 |
+
#degree=atom.GetTotalDegree(),
|
288 |
+
is_aromatic=atom.GetIsAromatic(),
|
289 |
+
symbol=atom.GetSymbol(),
|
290 |
+
num_explicit_Hs=atom.GetNumExplicitHs(),
|
291 |
+
formal_charge=atom.GetFormalCharge(),
|
292 |
+
chirality=atom.GetChiralTag().real
|
293 |
+
)}
|
294 |
+
|
295 |
+
def bond_attr(bond, kekulize):
|
296 |
+
"""
|
297 |
+
get atom's attributes
|
298 |
+
|
299 |
+
Parameters
|
300 |
+
----------
|
301 |
+
bond : rdkit.Chem.rdchem.Bond
|
302 |
+
kekulize : bool
|
303 |
+
kekulize or not
|
304 |
+
|
305 |
+
Returns
|
306 |
+
-------
|
307 |
+
bond_attr : dict
|
308 |
+
"bond_type" : int
|
309 |
+
{0: rdkit.Chem.rdchem.BondType.UNSPECIFIED,
|
310 |
+
1: rdkit.Chem.rdchem.BondType.SINGLE,
|
311 |
+
2: rdkit.Chem.rdchem.BondType.DOUBLE,
|
312 |
+
3: rdkit.Chem.rdchem.BondType.TRIPLE,
|
313 |
+
4: rdkit.Chem.rdchem.BondType.QUADRUPLE,
|
314 |
+
5: rdkit.Chem.rdchem.BondType.QUINTUPLE,
|
315 |
+
6: rdkit.Chem.rdchem.BondType.HEXTUPLE,
|
316 |
+
7: rdkit.Chem.rdchem.BondType.ONEANDAHALF,
|
317 |
+
8: rdkit.Chem.rdchem.BondType.TWOANDAHALF,
|
318 |
+
9: rdkit.Chem.rdchem.BondType.THREEANDAHALF,
|
319 |
+
10: rdkit.Chem.rdchem.BondType.FOURANDAHALF,
|
320 |
+
11: rdkit.Chem.rdchem.BondType.FIVEANDAHALF,
|
321 |
+
12: rdkit.Chem.rdchem.BondType.AROMATIC,
|
322 |
+
13: rdkit.Chem.rdchem.BondType.IONIC,
|
323 |
+
14: rdkit.Chem.rdchem.BondType.HYDROGEN,
|
324 |
+
15: rdkit.Chem.rdchem.BondType.THREECENTER,
|
325 |
+
16: rdkit.Chem.rdchem.BondType.DATIVEONE,
|
326 |
+
17: rdkit.Chem.rdchem.BondType.DATIVE,
|
327 |
+
18: rdkit.Chem.rdchem.BondType.DATIVEL,
|
328 |
+
19: rdkit.Chem.rdchem.BondType.DATIVER,
|
329 |
+
20: rdkit.Chem.rdchem.BondType.OTHER,
|
330 |
+
21: rdkit.Chem.rdchem.BondType.ZERO}
|
331 |
+
"""
|
332 |
+
if kekulize:
|
333 |
+
is_aromatic = False
|
334 |
+
if bond.GetBondType().real == 12:
|
335 |
+
bond_type = 1
|
336 |
+
else:
|
337 |
+
bond_type = bond.GetBondType().real
|
338 |
+
else:
|
339 |
+
is_aromatic = bond.GetIsAromatic()
|
340 |
+
bond_type = bond.GetBondType().real
|
341 |
+
return {'symbol': BondSymbol(is_aromatic=is_aromatic,
|
342 |
+
bond_type=bond_type,
|
343 |
+
stereo=int(bond.GetStereo())),
|
344 |
+
'is_in_ring': bond.IsInRing()}
|
345 |
+
|
346 |
+
|
347 |
+
def standardize_stereo(mol):
|
348 |
+
'''
|
349 |
+
0: rdkit.Chem.rdchem.BondDir.NONE,
|
350 |
+
1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
|
351 |
+
2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
|
352 |
+
3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
|
353 |
+
4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
|
354 |
+
|
355 |
+
'''
|
356 |
+
# mol = Chem.AddHs(mol) # this removes CIPRank !!!
|
357 |
+
for each_bond in mol.GetBonds():
|
358 |
+
if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
|
359 |
+
begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
|
360 |
+
end_stereo_atom_idx = each_bond.GetEndAtomIdx()
|
361 |
+
atom_idx_1 = each_bond.GetStereoAtoms()[0]
|
362 |
+
atom_idx_2 = each_bond.GetStereoAtoms()[1]
|
363 |
+
if mol.GetBondBetweenAtoms(atom_idx_1, begin_stereo_atom_idx):
|
364 |
+
begin_atom_idx = atom_idx_1
|
365 |
+
end_atom_idx = atom_idx_2
|
366 |
+
else:
|
367 |
+
begin_atom_idx = atom_idx_2
|
368 |
+
end_atom_idx = atom_idx_1
|
369 |
+
|
370 |
+
begin_another_atom_idx = None
|
371 |
+
assert len(mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()) <= 3
|
372 |
+
for each_neighbor in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors():
|
373 |
+
each_neighbor_idx = each_neighbor.GetIdx()
|
374 |
+
if each_neighbor_idx not in [end_stereo_atom_idx, begin_atom_idx]:
|
375 |
+
begin_another_atom_idx = each_neighbor_idx
|
376 |
+
|
377 |
+
end_another_atom_idx = None
|
378 |
+
assert len(mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()) <= 3
|
379 |
+
for each_neighbor in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors():
|
380 |
+
each_neighbor_idx = each_neighbor.GetIdx()
|
381 |
+
if each_neighbor_idx not in [begin_stereo_atom_idx, end_atom_idx]:
|
382 |
+
end_another_atom_idx = each_neighbor_idx
|
383 |
+
|
384 |
+
'''
|
385 |
+
relationship between begin_atom_idx and end_atom_idx is encoded in GetStereo
|
386 |
+
'''
|
387 |
+
begin_atom_rank = int(mol.GetAtomWithIdx(begin_atom_idx).GetProp('_CIPRank'))
|
388 |
+
end_atom_rank = int(mol.GetAtomWithIdx(end_atom_idx).GetProp('_CIPRank'))
|
389 |
+
try:
|
390 |
+
begin_another_atom_rank = int(mol.GetAtomWithIdx(begin_another_atom_idx).GetProp('_CIPRank'))
|
391 |
+
except:
|
392 |
+
begin_another_atom_rank = np.inf
|
393 |
+
try:
|
394 |
+
end_another_atom_rank = int(mol.GetAtomWithIdx(end_another_atom_idx).GetProp('_CIPRank'))
|
395 |
+
except:
|
396 |
+
end_another_atom_rank = np.inf
|
397 |
+
if begin_atom_rank < begin_another_atom_rank\
|
398 |
+
and end_atom_rank < end_another_atom_rank:
|
399 |
+
pass
|
400 |
+
elif begin_atom_rank < begin_another_atom_rank\
|
401 |
+
and end_atom_rank > end_another_atom_rank:
|
402 |
+
# (begin_atom_idx +) end_another_atom_idx should be in StereoAtoms
|
403 |
+
if each_bond.GetStereo() == 2:
|
404 |
+
# set stereo
|
405 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
|
406 |
+
# set bond dir
|
407 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
408 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
|
409 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
410 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
|
411 |
+
elif each_bond.GetStereo() == 3:
|
412 |
+
# set stereo
|
413 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
|
414 |
+
# set bond dir
|
415 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
416 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
|
417 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
418 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
|
419 |
+
else:
|
420 |
+
raise ValueError
|
421 |
+
each_bond.SetStereoAtoms(begin_atom_idx, end_another_atom_idx)
|
422 |
+
elif begin_atom_rank > begin_another_atom_rank\
|
423 |
+
and end_atom_rank < end_another_atom_rank:
|
424 |
+
# (end_atom_idx +) begin_another_atom_idx should be in StereoAtoms
|
425 |
+
if each_bond.GetStereo() == 2:
|
426 |
+
# set stereo
|
427 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
|
428 |
+
# set bond dir
|
429 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
430 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
431 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
|
432 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
|
433 |
+
elif each_bond.GetStereo() == 3:
|
434 |
+
# set stereo
|
435 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
|
436 |
+
# set bond dir
|
437 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
438 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
439 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
|
440 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
|
441 |
+
else:
|
442 |
+
raise ValueError
|
443 |
+
each_bond.SetStereoAtoms(begin_another_atom_idx, end_atom_idx)
|
444 |
+
elif begin_atom_rank > begin_another_atom_rank\
|
445 |
+
and end_atom_rank > end_another_atom_rank:
|
446 |
+
# begin_another_atom_idx + end_another_atom_idx should be in StereoAtoms
|
447 |
+
if each_bond.GetStereo() == 2:
|
448 |
+
# set bond dir
|
449 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
450 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
451 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
452 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
|
453 |
+
elif each_bond.GetStereo() == 3:
|
454 |
+
# set bond dir
|
455 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
456 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
457 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
458 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
|
459 |
+
else:
|
460 |
+
raise ValueError
|
461 |
+
each_bond.SetStereoAtoms(begin_another_atom_idx, end_another_atom_idx)
|
462 |
+
else:
|
463 |
+
raise RuntimeError
|
464 |
+
return mol
|
465 |
+
|
466 |
+
|
467 |
+
def set_stereo(mol):
|
468 |
+
'''
|
469 |
+
0: rdkit.Chem.rdchem.BondDir.NONE,
|
470 |
+
1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
|
471 |
+
2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
|
472 |
+
3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
|
473 |
+
4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
|
474 |
+
'''
|
475 |
+
_mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
|
476 |
+
Chem.Kekulize(_mol, True)
|
477 |
+
substruct_match = mol.GetSubstructMatch(_mol)
|
478 |
+
if not substruct_match:
|
479 |
+
''' mol and _mol are kekulized.
|
480 |
+
sometimes, the order of '=' and '-' changes, which causes mol and _mol not matched.
|
481 |
+
'''
|
482 |
+
Chem.SetAromaticity(mol)
|
483 |
+
Chem.SetAromaticity(_mol)
|
484 |
+
substruct_match = mol.GetSubstructMatch(_mol)
|
485 |
+
try:
|
486 |
+
atom_match = {substruct_match[_mol_atom_idx]: _mol_atom_idx for _mol_atom_idx in range(_mol.GetNumAtoms())} # mol to _mol
|
487 |
+
except:
|
488 |
+
raise ValueError('two molecules obtained from the same data do not match.')
|
489 |
+
|
490 |
+
for each_bond in mol.GetBonds():
|
491 |
+
begin_atom_idx = each_bond.GetBeginAtomIdx()
|
492 |
+
end_atom_idx = each_bond.GetEndAtomIdx()
|
493 |
+
_bond = _mol.GetBondBetweenAtoms(atom_match[begin_atom_idx], atom_match[end_atom_idx])
|
494 |
+
_bond.SetStereo(each_bond.GetStereo())
|
495 |
+
|
496 |
+
mol = _mol
|
497 |
+
for each_bond in mol.GetBonds():
|
498 |
+
if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
|
499 |
+
begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
|
500 |
+
end_stereo_atom_idx = each_bond.GetEndAtomIdx()
|
501 |
+
begin_atom_idx_set = set([each_neighbor.GetIdx()
|
502 |
+
for each_neighbor
|
503 |
+
in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()
|
504 |
+
if each_neighbor.GetIdx() != end_stereo_atom_idx])
|
505 |
+
end_atom_idx_set = set([each_neighbor.GetIdx()
|
506 |
+
for each_neighbor
|
507 |
+
in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()
|
508 |
+
if each_neighbor.GetIdx() != begin_stereo_atom_idx])
|
509 |
+
if not begin_atom_idx_set:
|
510 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo(0))
|
511 |
+
continue
|
512 |
+
if not end_atom_idx_set:
|
513 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo(0))
|
514 |
+
continue
|
515 |
+
if len(begin_atom_idx_set) == 1:
|
516 |
+
begin_atom_idx = begin_atom_idx_set.pop()
|
517 |
+
begin_another_atom_idx = None
|
518 |
+
if len(end_atom_idx_set) == 1:
|
519 |
+
end_atom_idx = end_atom_idx_set.pop()
|
520 |
+
end_another_atom_idx = None
|
521 |
+
if len(begin_atom_idx_set) == 2:
|
522 |
+
atom_idx_1 = begin_atom_idx_set.pop()
|
523 |
+
atom_idx_2 = begin_atom_idx_set.pop()
|
524 |
+
if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
|
525 |
+
begin_atom_idx = atom_idx_1
|
526 |
+
begin_another_atom_idx = atom_idx_2
|
527 |
+
else:
|
528 |
+
begin_atom_idx = atom_idx_2
|
529 |
+
begin_another_atom_idx = atom_idx_1
|
530 |
+
if len(end_atom_idx_set) == 2:
|
531 |
+
atom_idx_1 = end_atom_idx_set.pop()
|
532 |
+
atom_idx_2 = end_atom_idx_set.pop()
|
533 |
+
if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
|
534 |
+
end_atom_idx = atom_idx_1
|
535 |
+
end_another_atom_idx = atom_idx_2
|
536 |
+
else:
|
537 |
+
end_atom_idx = atom_idx_2
|
538 |
+
end_another_atom_idx = atom_idx_1
|
539 |
+
|
540 |
+
if each_bond.GetStereo() == 2: # same side
|
541 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
542 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
|
543 |
+
each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
|
544 |
+
elif each_bond.GetStereo() == 3: # opposite side
|
545 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
546 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
|
547 |
+
each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
|
548 |
+
else:
|
549 |
+
raise ValueError
|
550 |
+
return mol
|
551 |
+
|
552 |
+
|
553 |
+
def safe_set_bond_dir(mol, atom_idx_1, atom_idx_2, bond_dir_val):
|
554 |
+
if atom_idx_1 is None or atom_idx_2 is None:
|
555 |
+
return mol
|
556 |
+
else:
|
557 |
+
mol.GetBondBetweenAtoms(atom_idx_1, atom_idx_2).SetBondDir(Chem.rdchem.BondDir.values[bond_dir_val])
|
558 |
+
return mol
|
559 |
+
|
models/mhg_model/graph_grammar/nn/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# Rhizome
|
3 |
+
# Version beta 0.0, August 2023
|
4 |
+
# Property of IBM Research, Accelerated Discovery
|
5 |
+
#
|
6 |
+
|
7 |
+
"""
|
8 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
9 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
10 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
11 |
+
"""
|