mgalkin commited on
Commit
c6a2ce9
1 Parent(s): 9024eb5

modeling script

Browse files
.gitignore ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ output/
10
+ .vscode/
11
+ .DS_Store
12
+ datasets/
13
+ ckpts/
14
+ *.csv
15
+ *.txt
modeling.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from transformers import PretrainedConfig, PreTrainedModel
4
+ #sys.path.append(os.path.dirname(os.path.dirname(__file__)))
5
+ from ultra.models import Ultra
6
+ from ultra.datasets import WN18RR, CoDExSmall, FB15k237, FB15k237Inductive
7
+ from ultra.eval import test
8
+
9
+
10
+ class UltraConfig(PretrainedConfig):
11
+
12
+ model_type = "ultra"
13
+
14
+ def __init__(
15
+ self,
16
+ relation_model_layers: int = 6,
17
+ relation_model_dim: int = 64,
18
+ entity_model_layers: int = 6,
19
+ entity_model_dim: int = 64,
20
+ **kwargs):
21
+
22
+ self.relation_model_cfg = dict(
23
+ input_dim=relation_model_dim,
24
+ hidden_dims=[relation_model_dim]*relation_model_layers,
25
+ message_func="distmult",
26
+ aggregate_func="sum",
27
+ short_cut=True,
28
+ layer_norm=True
29
+ )
30
+
31
+ self.entity_model_cfg = dict(
32
+ input_dim=entity_model_dim,
33
+ hidden_dims=[entity_model_dim]*entity_model_layers,
34
+ message_func="distmult",
35
+ aggregate_func="sum",
36
+ short_cut=True,
37
+ layer_norm=True
38
+ )
39
+
40
+ super().__init__(**kwargs)
41
+
42
+ class UltraLinkPrediction(PreTrainedModel):
43
+
44
+ config_class = UltraConfig
45
+
46
+ def __init__(self, config):
47
+ super().__init__(config)
48
+
49
+ self.model = Ultra(
50
+ rel_model_cfg=config.relation_model_cfg,
51
+ entity_model_cfg=config.entity_model_cfg,
52
+ )
53
+
54
+ def forward(self, data, batch):
55
+ # data: PyG data object
56
+ # batch shape: (bs, 1+num_negs, 3)
57
+ return self.model.forward(data, batch)
58
+
59
+
60
+ if __name__ == "__main__":
61
+
62
+ model = UltraLinkPrediction.from_pretrained("mgalkin/ultra_4g")
63
+ dataset = CoDExSmall(root="./datasets/")
64
+ test(model, mode="test", dataset=dataset, gpus=None)
65
+ # mrr: 0.463971
66
+ # hits@10: 0.666028
ultra/__init__.py ADDED
File without changes
ultra/base_nbfnet.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from collections.abc import Sequence
3
+
4
+ import torch
5
+ from torch import nn, autograd
6
+
7
+ from torch_scatter import scatter_add
8
+ from . import tasks, layers
9
+
10
+
11
+ class BaseNBFNet(nn.Module):
12
+
13
+ def __init__(self, input_dim, hidden_dims, num_relation, message_func="distmult", aggregate_func="sum",
14
+ short_cut=False, layer_norm=False, activation="relu", concat_hidden=False, num_mlp_layer=2,
15
+ dependent=False, remove_one_hop=False, num_beam=10, path_topk=10, **kwargs):
16
+ super(BaseNBFNet, self).__init__()
17
+
18
+ if not isinstance(hidden_dims, Sequence):
19
+ hidden_dims = [hidden_dims]
20
+
21
+ self.dims = [input_dim] + list(hidden_dims)
22
+ self.num_relation = num_relation
23
+ self.short_cut = short_cut # whether to use residual connections between GNN layers
24
+ self.concat_hidden = concat_hidden # whether to compute final states as a function of all layer outputs or last
25
+ self.remove_one_hop = remove_one_hop # whether to dynamically remove one-hop edges from edge_index
26
+ self.num_beam = num_beam
27
+ self.path_topk = path_topk
28
+
29
+ self.message_func = message_func
30
+ self.aggregate_func = aggregate_func
31
+ self.layer_norm = layer_norm
32
+ self.activation = activation
33
+ self.num_mlp_layers = num_mlp_layer
34
+
35
+ # self.layers = nn.ModuleList()
36
+ # for i in range(len(self.dims) - 1):
37
+ # self.layers.append(layers.GeneralizedRelationalConv(self.dims[i], self.dims[i + 1], num_relation,
38
+ # self.dims[0], message_func, aggregate_func, layer_norm,
39
+ # activation, dependent))
40
+
41
+ # feature_dim = (sum(hidden_dims) if concat_hidden else hidden_dims[-1]) + input_dim
42
+
43
+ # # additional relation embedding which serves as an initial 'query' for the NBFNet forward pass
44
+ # # each layer has its own learnable relations matrix, so we send the total number of relations, too
45
+ # self.query = nn.Embedding(num_relation, input_dim)
46
+ # self.mlp = nn.Sequential()
47
+ # mlp = []
48
+ # for i in range(num_mlp_layer - 1):
49
+ # mlp.append(nn.Linear(feature_dim, feature_dim))
50
+ # mlp.append(nn.ReLU())
51
+ # mlp.append(nn.Linear(feature_dim, 1))
52
+ # self.mlp = nn.Sequential(*mlp)
53
+
54
+ def remove_easy_edges(self, data, h_index, t_index, r_index=None):
55
+ # we remove training edges (we need to predict them at training time) from the edge index
56
+ # think of it as a dynamic edge dropout
57
+ h_index_ext = torch.cat([h_index, t_index], dim=-1)
58
+ t_index_ext = torch.cat([t_index, h_index], dim=-1)
59
+ r_index_ext = torch.cat([r_index, r_index + data.num_relations // 2], dim=-1)
60
+ if self.remove_one_hop:
61
+ # we remove all existing immediate edges between heads and tails in the batch
62
+ edge_index = data.edge_index
63
+ easy_edge = torch.stack([h_index_ext, t_index_ext]).flatten(1)
64
+ index = tasks.edge_match(edge_index, easy_edge)[0]
65
+ mask = ~index_to_mask(index, data.num_edges)
66
+ else:
67
+ # we remove existing immediate edges between heads and tails in the batch with the given relation
68
+ edge_index = torch.cat([data.edge_index, data.edge_type.unsqueeze(0)])
69
+ # note that here we add relation types r_index_ext to the matching query
70
+ easy_edge = torch.stack([h_index_ext, t_index_ext, r_index_ext]).flatten(1)
71
+ index = tasks.edge_match(edge_index, easy_edge)[0]
72
+ mask = ~index_to_mask(index, data.num_edges)
73
+
74
+ data = copy.copy(data)
75
+ data.edge_index = data.edge_index[:, mask]
76
+ data.edge_type = data.edge_type[mask]
77
+ return data
78
+
79
+ def negative_sample_to_tail(self, h_index, t_index, r_index, num_direct_rel):
80
+ # convert p(h | t, r) to p(t' | h', r')
81
+ # h' = t, r' = r^{-1}, t' = h
82
+ is_t_neg = (h_index == h_index[:, [0]]).all(dim=-1, keepdim=True)
83
+ new_h_index = torch.where(is_t_neg, h_index, t_index)
84
+ new_t_index = torch.where(is_t_neg, t_index, h_index)
85
+ new_r_index = torch.where(is_t_neg, r_index, r_index + num_direct_rel)
86
+ return new_h_index, new_t_index, new_r_index
87
+
88
+ def bellmanford(self, data, h_index, r_index, separate_grad=False):
89
+ batch_size = len(r_index)
90
+
91
+ # initialize queries (relation types of the given triples)
92
+ query = self.query(r_index)
93
+ index = h_index.unsqueeze(-1).expand_as(query)
94
+
95
+ # initial (boundary) condition - initialize all node states as zeros
96
+ boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device)
97
+ # by the scatter operation we put query (relation) embeddings as init features of source (index) nodes
98
+ boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))
99
+ size = (data.num_nodes, data.num_nodes)
100
+ edge_weight = torch.ones(data.num_edges, device=h_index.device)
101
+
102
+ hiddens = []
103
+ edge_weights = []
104
+ layer_input = boundary
105
+
106
+ for layer in self.layers:
107
+ if separate_grad:
108
+ edge_weight = edge_weight.clone().requires_grad_()
109
+ # Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
110
+ hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
111
+ if self.short_cut and hidden.shape == layer_input.shape:
112
+ # residual connection here
113
+ hidden = hidden + layer_input
114
+ hiddens.append(hidden)
115
+ edge_weights.append(edge_weight)
116
+ layer_input = hidden
117
+
118
+ # original query (relation type) embeddings
119
+ node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1) # (batch_size, num_nodes, input_dim)
120
+ if self.concat_hidden:
121
+ output = torch.cat(hiddens + [node_query], dim=-1)
122
+ else:
123
+ output = torch.cat([hiddens[-1], node_query], dim=-1)
124
+
125
+ return {
126
+ "node_feature": output,
127
+ "edge_weights": edge_weights,
128
+ }
129
+
130
+ def forward(self, data, batch):
131
+ h_index, t_index, r_index = batch.unbind(-1)
132
+ if self.training:
133
+ # Edge dropout in the training mode
134
+ # here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
135
+ # to make NBFNet iteration learn non-trivial paths
136
+ data = self.remove_easy_edges(data, h_index, t_index, r_index, data.num_relations // 2)
137
+
138
+ shape = h_index.shape
139
+ # turn all triples in a batch into a tail prediction mode
140
+ h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index, num_direct_rel=data.num_relations // 2)
141
+ assert (h_index[:, [0]] == h_index).all()
142
+ assert (r_index[:, [0]] == r_index).all()
143
+
144
+ # message passing and updated node representations
145
+ output = self.bellmanford(data, h_index[:, 0], r_index[:, 0]) # (num_nodes, batch_size, feature_dim)
146
+ feature = output["node_feature"]
147
+ index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
148
+ # extract representations of tail entities from the updated node states
149
+ feature = feature.gather(1, index) # (batch_size, num_negative + 1, feature_dim)
150
+
151
+ # probability logit for each tail node in the batch
152
+ # (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1)
153
+ score = self.mlp(feature).squeeze(-1)
154
+ return score.view(shape)
155
+
156
+ def visualize(self, data, batch):
157
+ assert batch.shape == (1, 3)
158
+ h_index, t_index, r_index = batch.unbind(-1)
159
+
160
+ output = self.bellmanford(data, h_index, r_index, separate_grad=True)
161
+ feature = output["node_feature"]
162
+ edge_weights = output["edge_weights"]
163
+
164
+ index = t_index.unsqueeze(0).unsqueeze(-1).expand(-1, -1, feature.shape[-1])
165
+ feature = feature.gather(1, index).squeeze(0)
166
+ score = self.mlp(feature).squeeze(-1)
167
+
168
+ edge_grads = autograd.grad(score, edge_weights)
169
+ distances, back_edges = self.beam_search_distance(data, edge_grads, h_index, t_index, self.num_beam)
170
+ paths, weights = self.topk_average_length(distances, back_edges, t_index, self.path_topk)
171
+
172
+ return paths, weights
173
+
174
+ @torch.no_grad()
175
+ def beam_search_distance(self, data, edge_grads, h_index, t_index, num_beam=10):
176
+ # beam search the top-k distance from h to t (and to every other node)
177
+ num_nodes = data.num_nodes
178
+ input = torch.full((num_nodes, num_beam), float("-inf"), device=h_index.device)
179
+ input[h_index, 0] = 0
180
+ edge_mask = data.edge_index[0, :] != t_index
181
+
182
+ distances = []
183
+ back_edges = []
184
+ for edge_grad in edge_grads:
185
+ # we don't allow any path goes out of t once it arrives at t
186
+ node_in, node_out = data.edge_index[:, edge_mask]
187
+ relation = data.edge_type[edge_mask]
188
+ edge_grad = edge_grad[edge_mask]
189
+
190
+ message = input[node_in] + edge_grad.unsqueeze(-1) # (num_edges, num_beam)
191
+ # (num_edges, num_beam, 3)
192
+ msg_source = torch.stack([node_in, node_out, relation], dim=-1).unsqueeze(1).expand(-1, num_beam, -1)
193
+
194
+ # (num_edges, num_beam)
195
+ is_duplicate = torch.isclose(message.unsqueeze(-1), message.unsqueeze(-2)) & \
196
+ (msg_source.unsqueeze(-2) == msg_source.unsqueeze(-3)).all(dim=-1)
197
+ # pick the first occurrence as the ranking in the previous node's beam
198
+ # this makes deduplication easier later
199
+ # and store it in msg_source
200
+ is_duplicate = is_duplicate.float() - \
201
+ torch.arange(num_beam, dtype=torch.float, device=message.device) / (num_beam + 1)
202
+ prev_rank = is_duplicate.argmax(dim=-1, keepdim=True)
203
+ msg_source = torch.cat([msg_source, prev_rank], dim=-1) # (num_edges, num_beam, 4)
204
+
205
+ node_out, order = node_out.sort()
206
+ node_out_set = torch.unique(node_out)
207
+ # sort messages w.r.t. node_out
208
+ message = message[order].flatten() # (num_edges * num_beam)
209
+ msg_source = msg_source[order].flatten(0, -2) # (num_edges * num_beam, 4)
210
+ size = node_out.bincount(minlength=num_nodes)
211
+ msg2out = size_to_index(size[node_out_set] * num_beam)
212
+ # deduplicate messages that are from the same source and the same beam
213
+ is_duplicate = (msg_source[1:] == msg_source[:-1]).all(dim=-1)
214
+ is_duplicate = torch.cat([torch.zeros(1, dtype=torch.bool, device=message.device), is_duplicate])
215
+ message = message[~is_duplicate]
216
+ msg_source = msg_source[~is_duplicate]
217
+ msg2out = msg2out[~is_duplicate]
218
+ size = msg2out.bincount(minlength=len(node_out_set))
219
+
220
+ if not torch.isinf(message).all():
221
+ # take the topk messages from the neighborhood
222
+ # distance: (len(node_out_set) * num_beam)
223
+ distance, rel_index = scatter_topk(message, size, k=num_beam)
224
+ abs_index = rel_index + (size.cumsum(0) - size).unsqueeze(-1)
225
+ # store msg_source for backtracking
226
+ back_edge = msg_source[abs_index] # (len(node_out_set) * num_beam, 4)
227
+ distance = distance.view(len(node_out_set), num_beam)
228
+ back_edge = back_edge.view(len(node_out_set), num_beam, 4)
229
+ # scatter distance / back_edge back to all nodes
230
+ distance = scatter_add(distance, node_out_set, dim=0, dim_size=num_nodes) # (num_nodes, num_beam)
231
+ back_edge = scatter_add(back_edge, node_out_set, dim=0, dim_size=num_nodes) # (num_nodes, num_beam, 4)
232
+ else:
233
+ distance = torch.full((num_nodes, num_beam), float("-inf"), device=message.device)
234
+ back_edge = torch.zeros(num_nodes, num_beam, 4, dtype=torch.long, device=message.device)
235
+
236
+ distances.append(distance)
237
+ back_edges.append(back_edge)
238
+ input = distance
239
+
240
+ return distances, back_edges
241
+
242
+ def topk_average_length(self, distances, back_edges, t_index, k=10):
243
+ # backtrack distances and back_edges to generate the paths
244
+ paths = []
245
+ average_lengths = []
246
+
247
+ for i in range(len(distances)):
248
+ distance, order = distances[i][t_index].flatten(0, -1).sort(descending=True)
249
+ back_edge = back_edges[i][t_index].flatten(0, -2)[order]
250
+ for d, (h, t, r, prev_rank) in zip(distance[:k].tolist(), back_edge[:k].tolist()):
251
+ if d == float("-inf"):
252
+ break
253
+ path = [(h, t, r)]
254
+ for j in range(i - 1, -1, -1):
255
+ h, t, r, prev_rank = back_edges[j][h, prev_rank].tolist()
256
+ path.append((h, t, r))
257
+ paths.append(path[::-1])
258
+ average_lengths.append(d / len(path))
259
+
260
+ if paths:
261
+ average_lengths, paths = zip(*sorted(zip(average_lengths, paths), reverse=True)[:k])
262
+
263
+ return paths, average_lengths
264
+
265
+
266
+ def index_to_mask(index, size):
267
+ index = index.view(-1)
268
+ size = int(index.max()) + 1 if size is None else size
269
+ mask = index.new_zeros(size, dtype=torch.bool)
270
+ mask[index] = True
271
+ return mask
272
+
273
+
274
+ def size_to_index(size):
275
+ range = torch.arange(len(size), device=size.device)
276
+ index2sample = range.repeat_interleave(size)
277
+ return index2sample
278
+
279
+
280
+ def multi_slice_mask(starts, ends, length):
281
+ values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
282
+ slices = torch.cat([starts, ends])
283
+ mask = scatter_add(values, slices, dim=0, dim_size=length + 1)[:-1]
284
+ mask = mask.cumsum(0).bool()
285
+ return mask
286
+
287
+
288
+ def scatter_extend(data, size, input, input_size):
289
+ new_size = size + input_size
290
+ new_cum_size = new_size.cumsum(0)
291
+ new_data = torch.zeros(new_cum_size[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
292
+ starts = new_cum_size - new_size
293
+ ends = starts + size
294
+ index = multi_slice_mask(starts, ends, new_cum_size[-1])
295
+ new_data[index] = data
296
+ new_data[~index] = input
297
+ return new_data, new_size
298
+
299
+
300
+ def scatter_topk(input, size, k, largest=True):
301
+ index2graph = size_to_index(size)
302
+ index2graph = index2graph.view([-1] + [1] * (input.ndim - 1))
303
+
304
+ mask = ~torch.isinf(input)
305
+ max = input[mask].max().item()
306
+ min = input[mask].min().item()
307
+ safe_input = input.clamp(2 * min - max, 2 * max - min)
308
+ offset = (max - min) * 4
309
+ if largest:
310
+ offset = -offset
311
+ input_ext = safe_input + offset * index2graph
312
+ index_ext = input_ext.argsort(dim=0, descending=largest)
313
+ num_actual = size.clamp(max=k)
314
+ num_padding = k - num_actual
315
+ starts = size.cumsum(0) - size
316
+ ends = starts + num_actual
317
+ mask = multi_slice_mask(starts, ends, len(index_ext)).nonzero().flatten()
318
+
319
+ if (num_padding > 0).any():
320
+ # special case: size < k, pad with the last valid index
321
+ padding = ends - 1
322
+ padding2graph = size_to_index(num_padding)
323
+ mask = scatter_extend(mask, num_actual, padding[padding2graph], num_padding)[0]
324
+
325
+ index = index_ext[mask] # (N * k, ...)
326
+ value = input.gather(0, index)
327
+ if isinstance(k, torch.Tensor) and k.shape == size.shape:
328
+ value = value.view(-1, *input.shape[1:])
329
+ index = index.view(-1, *input.shape[1:])
330
+ index = index - (size.cumsum(0) - size).repeat_interleave(k).view([-1] + [1] * (index.ndim - 1))
331
+ else:
332
+ value = value.view(-1, k, *input.shape[1:])
333
+ index = index.view(-1, k, *input.shape[1:])
334
+ index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1))
335
+
336
+ return value, index
ultra/datasets.py ADDED
@@ -0,0 +1,1095 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import shutil
4
+ import torch
5
+ from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
6
+ from torch_geometric.datasets import RelLinkPredDataset, WordNet18RR
7
+
8
+ from ultra.tasks import build_relation_graph
9
+
10
+
11
+ class GrailInductiveDataset(InMemoryDataset):
12
+
13
+ def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, merge_valid_test=True):
14
+ self.version = version
15
+ assert version in ["v1", "v2", "v3", "v4"]
16
+
17
+ # by default, most models on Grail datasets merge inductive valid and test splits as the final test split
18
+ # with this choice, the validation set is that of the transductive train (on the seen graph)
19
+ # by default it's turned on but you can experiment with turning this option off
20
+ # you'll need to delete the processed datasets then and re-run to cache a new dataset
21
+ self.merge_valid_test = merge_valid_test
22
+ super().__init__(root, transform, pre_transform)
23
+ self.data, self.slices = torch.load(self.processed_paths[0])
24
+
25
+ @property
26
+ def num_relations(self):
27
+ return int(self.data.edge_type.max()) + 1
28
+
29
+ @property
30
+ def raw_dir(self):
31
+ return os.path.join(self.root, "grail", self.name, self.version, "raw")
32
+
33
+ @property
34
+ def processed_dir(self):
35
+ return os.path.join(self.root, "grail", self.name, self.version, "processed")
36
+
37
+ @property
38
+ def processed_file_names(self):
39
+ return "data.pt"
40
+
41
+ @property
42
+ def raw_file_names(self):
43
+ return [
44
+ "train_ind.txt", "valid_ind.txt", "test_ind.txt", "train.txt", "valid.txt"
45
+ ]
46
+
47
+ def download(self):
48
+ for url, path in zip(self.urls, self.raw_paths):
49
+ download_path = download_url(url % self.version, self.raw_dir)
50
+ os.rename(download_path, path)
51
+
52
+ def process(self):
53
+ test_files = self.raw_paths[:3]
54
+ train_files = self.raw_paths[3:]
55
+
56
+ inv_train_entity_vocab = {}
57
+ inv_test_entity_vocab = {}
58
+ inv_relation_vocab = {}
59
+ triplets = []
60
+ num_samples = []
61
+
62
+ for txt_file in train_files:
63
+ with open(txt_file, "r") as fin:
64
+ num_sample = 0
65
+ for line in fin:
66
+ h_token, r_token, t_token = line.strip().split("\t")
67
+ if h_token not in inv_train_entity_vocab:
68
+ inv_train_entity_vocab[h_token] = len(inv_train_entity_vocab)
69
+ h = inv_train_entity_vocab[h_token]
70
+ if r_token not in inv_relation_vocab:
71
+ inv_relation_vocab[r_token] = len(inv_relation_vocab)
72
+ r = inv_relation_vocab[r_token]
73
+ if t_token not in inv_train_entity_vocab:
74
+ inv_train_entity_vocab[t_token] = len(inv_train_entity_vocab)
75
+ t = inv_train_entity_vocab[t_token]
76
+ triplets.append((h, t, r))
77
+ num_sample += 1
78
+ num_samples.append(num_sample)
79
+
80
+ for txt_file in test_files:
81
+ with open(txt_file, "r") as fin:
82
+ num_sample = 0
83
+ for line in fin:
84
+ h_token, r_token, t_token = line.strip().split("\t")
85
+ if h_token not in inv_test_entity_vocab:
86
+ inv_test_entity_vocab[h_token] = len(inv_test_entity_vocab)
87
+ h = inv_test_entity_vocab[h_token]
88
+ assert r_token in inv_relation_vocab
89
+ r = inv_relation_vocab[r_token]
90
+ if t_token not in inv_test_entity_vocab:
91
+ inv_test_entity_vocab[t_token] = len(inv_test_entity_vocab)
92
+ t = inv_test_entity_vocab[t_token]
93
+ triplets.append((h, t, r))
94
+ num_sample += 1
95
+ num_samples.append(num_sample)
96
+ triplets = torch.tensor(triplets)
97
+
98
+ edge_index = triplets[:, :2].t()
99
+ edge_type = triplets[:, 2]
100
+ num_relations = int(edge_type.max()) + 1
101
+
102
+ # creating fact graphs - those are graphs sent to a model, based on which we'll predict missing facts
103
+ # also, those fact graphs will be used for filtered evaluation
104
+ train_fact_slice = slice(None, sum(num_samples[:1]))
105
+ test_fact_slice = slice(sum(num_samples[:2]), sum(num_samples[:3]))
106
+ train_fact_index = edge_index[:, train_fact_slice]
107
+ train_fact_type = edge_type[train_fact_slice]
108
+ test_fact_index = edge_index[:, test_fact_slice]
109
+ test_fact_type = edge_type[test_fact_slice]
110
+
111
+ # add flipped triplets for the fact graphs
112
+ train_fact_index = torch.cat([train_fact_index, train_fact_index.flip(0)], dim=-1)
113
+ train_fact_type = torch.cat([train_fact_type, train_fact_type + num_relations])
114
+ test_fact_index = torch.cat([test_fact_index, test_fact_index.flip(0)], dim=-1)
115
+ test_fact_type = torch.cat([test_fact_type, test_fact_type + num_relations])
116
+
117
+ train_slice = slice(None, sum(num_samples[:1]))
118
+ valid_slice = slice(sum(num_samples[:1]), sum(num_samples[:2]))
119
+ # by default, SOTA models on Grail datasets merge inductive valid and test splits as the final test split
120
+ # with this choice, the validation set is that of the transductive train (on the seen graph)
121
+ # by default it's turned on but you can experiment with turning this option off
122
+ test_slice = slice(sum(num_samples[:3]), sum(num_samples)) if self.merge_valid_test else slice(sum(num_samples[:4]), sum(num_samples))
123
+
124
+ train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=len(inv_train_entity_vocab),
125
+ target_edge_index=edge_index[:, train_slice], target_edge_type=edge_type[train_slice], num_relations=num_relations*2)
126
+ valid_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=len(inv_train_entity_vocab),
127
+ target_edge_index=edge_index[:, valid_slice], target_edge_type=edge_type[valid_slice], num_relations=num_relations*2)
128
+ test_data = Data(edge_index=test_fact_index, edge_type=test_fact_type, num_nodes=len(inv_test_entity_vocab),
129
+ target_edge_index=edge_index[:, test_slice], target_edge_type=edge_type[test_slice], num_relations=num_relations*2)
130
+
131
+ if self.pre_transform is not None:
132
+ train_data = self.pre_transform(train_data)
133
+ valid_data = self.pre_transform(valid_data)
134
+ test_data = self.pre_transform(test_data)
135
+
136
+ torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
137
+
138
+ def __repr__(self):
139
+ return "%s(%s)" % (self.name, self.version)
140
+
141
+
142
+ class FB15k237Inductive(GrailInductiveDataset):
143
+
144
+ urls = [
145
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/train.txt",
146
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/valid.txt",
147
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/test.txt",
148
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/train.txt",
149
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/valid.txt"
150
+ ]
151
+
152
+ name = "IndFB15k237"
153
+
154
+ def __init__(self, root, version):
155
+ super().__init__(root, version)
156
+
157
+ class WN18RRInductive(GrailInductiveDataset):
158
+
159
+ urls = [
160
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/train.txt",
161
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/valid.txt",
162
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/test.txt",
163
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/train.txt",
164
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/valid.txt"
165
+ ]
166
+
167
+ name = "IndWN18RR"
168
+
169
+ def __init__(self, root, version):
170
+ super().__init__(root, version)
171
+
172
+ class NELLInductive(GrailInductiveDataset):
173
+ urls = [
174
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/train.txt",
175
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/valid.txt",
176
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/test.txt",
177
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s/train.txt",
178
+ "https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s/valid.txt"
179
+ ]
180
+ name = "IndNELL"
181
+
182
+ def __init__(self, root, version):
183
+ super().__init__(root, version)
184
+
185
+
186
+ def FB15k237(root):
187
+ dataset = RelLinkPredDataset(name="FB15k-237", root=root+"/fb15k237/")
188
+ data = dataset.data
189
+ train_data = Data(edge_index=data.edge_index, edge_type=data.edge_type, num_nodes=data.num_nodes,
190
+ target_edge_index=data.train_edge_index, target_edge_type=data.train_edge_type,
191
+ num_relations=dataset.num_relations)
192
+ valid_data = Data(edge_index=data.edge_index, edge_type=data.edge_type, num_nodes=data.num_nodes,
193
+ target_edge_index=data.valid_edge_index, target_edge_type=data.valid_edge_type,
194
+ num_relations=dataset.num_relations)
195
+ test_data = Data(edge_index=data.edge_index, edge_type=data.edge_type, num_nodes=data.num_nodes,
196
+ target_edge_index=data.test_edge_index, target_edge_type=data.test_edge_type,
197
+ num_relations=dataset.num_relations)
198
+
199
+ # build relation graphs
200
+ train_data = build_relation_graph(train_data)
201
+ valid_data = build_relation_graph(valid_data)
202
+ test_data = build_relation_graph(test_data)
203
+
204
+ dataset.data, dataset.slices = dataset.collate([train_data, valid_data, test_data])
205
+ return dataset
206
+
207
+ def WN18RR(root):
208
+ dataset = WordNet18RR(root=root+"/wn18rr/")
209
+ # convert wn18rr into the same format as fb15k-237
210
+ data = dataset.data
211
+ num_nodes = int(data.edge_index.max()) + 1
212
+ num_relations = int(data.edge_type.max()) + 1
213
+ edge_index = data.edge_index[:, data.train_mask]
214
+ edge_type = data.edge_type[data.train_mask]
215
+ edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1)
216
+ edge_type = torch.cat([edge_type, edge_type + num_relations])
217
+ train_data = Data(edge_index=edge_index, edge_type=edge_type, num_nodes=num_nodes,
218
+ target_edge_index=data.edge_index[:, data.train_mask],
219
+ target_edge_type=data.edge_type[data.train_mask],
220
+ num_relations=num_relations*2)
221
+ valid_data = Data(edge_index=edge_index, edge_type=edge_type, num_nodes=num_nodes,
222
+ target_edge_index=data.edge_index[:, data.val_mask],
223
+ target_edge_type=data.edge_type[data.val_mask],
224
+ num_relations=num_relations*2)
225
+ test_data = Data(edge_index=edge_index, edge_type=edge_type, num_nodes=num_nodes,
226
+ target_edge_index=data.edge_index[:, data.test_mask],
227
+ target_edge_type=data.edge_type[data.test_mask],
228
+ num_relations=num_relations*2)
229
+
230
+ # build relation graphs
231
+ train_data = build_relation_graph(train_data)
232
+ valid_data = build_relation_graph(valid_data)
233
+ test_data = build_relation_graph(test_data)
234
+
235
+ dataset.data, dataset.slices = dataset.collate([train_data, valid_data, test_data])
236
+ dataset.num_relations = num_relations * 2
237
+ return dataset
238
+
239
+
240
+ class TransductiveDataset(InMemoryDataset):
241
+
242
+ delimiter = None
243
+
244
+ def __init__(self, root, transform=None, pre_transform=build_relation_graph, **kwargs):
245
+
246
+ super().__init__(root, transform, pre_transform)
247
+ self.data, self.slices = torch.load(self.processed_paths[0])
248
+
249
+ @property
250
+ def raw_file_names(self):
251
+ return ["train.txt", "valid.txt", "test.txt"]
252
+
253
+ def download(self):
254
+ for url, path in zip(self.urls, self.raw_paths):
255
+ download_path = download_url(url, self.raw_dir)
256
+ os.rename(download_path, path)
257
+
258
+ def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}):
259
+
260
+ triplets = []
261
+ entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
262
+
263
+ with open(triplet_file, "r", encoding="utf-8") as fin:
264
+ for l in fin:
265
+ u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
266
+ if u not in inv_entity_vocab:
267
+ inv_entity_vocab[u] = entity_cnt
268
+ entity_cnt += 1
269
+ if v not in inv_entity_vocab:
270
+ inv_entity_vocab[v] = entity_cnt
271
+ entity_cnt += 1
272
+ if r not in inv_rel_vocab:
273
+ inv_rel_vocab[r] = rel_cnt
274
+ rel_cnt += 1
275
+ u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
276
+
277
+ triplets.append((u, v, r))
278
+
279
+ return {
280
+ "triplets": triplets,
281
+ "num_node": len(inv_entity_vocab), #entity_cnt,
282
+ "num_relation": rel_cnt,
283
+ "inv_entity_vocab": inv_entity_vocab,
284
+ "inv_rel_vocab": inv_rel_vocab
285
+ }
286
+
287
+ # default loading procedure: process train/valid/test files, create graphs from them
288
+ def process(self):
289
+
290
+ train_files = self.raw_paths[:3]
291
+
292
+ train_results = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
293
+ valid_results = self.load_file(train_files[1],
294
+ train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
295
+ test_results = self.load_file(train_files[2],
296
+ train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
297
+
298
+ # in some datasets, there are several new nodes in the test set, eg 123,143 YAGO train adn 123,182 in YAGO test
299
+ # for consistency with other experimental results, we'll include those in the full vocab and num nodes
300
+ num_node = test_results["num_node"]
301
+ # the same for rels: in most cases train == test for transductive
302
+ # for AristoV4 train rels 1593, test 1604
303
+ num_relations = test_results["num_relation"]
304
+
305
+ train_triplets = train_results["triplets"]
306
+ valid_triplets = valid_results["triplets"]
307
+ test_triplets = test_results["triplets"]
308
+
309
+ train_target_edges = torch.tensor([[t[0], t[1]] for t in train_triplets], dtype=torch.long).t()
310
+ train_target_etypes = torch.tensor([t[2] for t in train_triplets])
311
+
312
+ valid_edges = torch.tensor([[t[0], t[1]] for t in valid_triplets], dtype=torch.long).t()
313
+ valid_etypes = torch.tensor([t[2] for t in valid_triplets])
314
+
315
+ test_edges = torch.tensor([[t[0], t[1]] for t in test_triplets], dtype=torch.long).t()
316
+ test_etypes = torch.tensor([t[2] for t in test_triplets])
317
+
318
+ train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
319
+ train_etypes = torch.cat([train_target_etypes, train_target_etypes+num_relations])
320
+
321
+ train_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
322
+ target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_relations*2)
323
+ valid_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
324
+ target_edge_index=valid_edges, target_edge_type=valid_etypes, num_relations=num_relations*2)
325
+ test_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
326
+ target_edge_index=test_edges, target_edge_type=test_etypes, num_relations=num_relations*2)
327
+
328
+ # build graphs of relations
329
+ if self.pre_transform is not None:
330
+ train_data = self.pre_transform(train_data)
331
+ valid_data = self.pre_transform(valid_data)
332
+ test_data = self.pre_transform(test_data)
333
+
334
+ torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
335
+
336
+ def __repr__(self):
337
+ return "%s()" % (self.name)
338
+
339
+ @property
340
+ def num_relations(self):
341
+ return int(self.data.edge_type.max()) + 1
342
+
343
+ @property
344
+ def raw_dir(self):
345
+ return os.path.join(self.root, self.name, "raw")
346
+
347
+ @property
348
+ def processed_dir(self):
349
+ return os.path.join(self.root, self.name, "processed")
350
+
351
+ @property
352
+ def processed_file_names(self):
353
+ return "data.pt"
354
+
355
+
356
+
357
+ class CoDEx(TransductiveDataset):
358
+
359
+ name = "codex"
360
+ urls = [
361
+ "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/train.txt",
362
+ "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/valid.txt",
363
+ "https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/test.txt",
364
+ ]
365
+
366
+ def download(self):
367
+ for url, path in zip(self.urls, self.raw_paths):
368
+ download_path = download_url(url % self.name, self.raw_dir)
369
+ os.rename(download_path, path)
370
+
371
+
372
+ class CoDExSmall(CoDEx):
373
+ """
374
+ #node: 2034
375
+ #edge: 36543
376
+ #relation: 42
377
+ """
378
+ url = "https://zenodo.org/record/4281094/files/codex-s.tar.gz"
379
+ md5 = "63cd8186fc2aeddc154e20cf4a10087e"
380
+ name = "codex-s"
381
+
382
+ def __init__(self, root):
383
+ super(CoDExSmall, self).__init__(root=root, size='s')
384
+
385
+
386
+ class CoDExMedium(CoDEx):
387
+ """
388
+ #node: 17050
389
+ #edge: 206205
390
+ #relation: 51
391
+ """
392
+ url = "https://zenodo.org/record/4281094/files/codex-m.tar.gz"
393
+ md5 = "43e561cfdca1c6ad9cc2f5b1ca4add76"
394
+ name = "codex-m"
395
+ def __init__(self, root):
396
+ super(CoDExMedium, self).__init__(root=root, size='m')
397
+
398
+
399
+ class CoDExLarge(CoDEx):
400
+ """
401
+ #node: 77951
402
+ #edge: 612437
403
+ #relation: 69
404
+ """
405
+ url = "https://zenodo.org/record/4281094/files/codex-l.tar.gz"
406
+ md5 = "9a10f4458c4bd2b16ef9b92b677e0d71"
407
+ name = "codex-l"
408
+ def __init__(self, root):
409
+ super(CoDExLarge, self).__init__(root=root, size='l')
410
+
411
+
412
+ class NELL995(TransductiveDataset):
413
+
414
+ # from the RED-GNN paper https://github.com/LARS-research/RED-GNN/tree/main/transductive/data/nell
415
+ # the OG dumps were found to have test set leakages
416
+ # training set is made out of facts+train files, so we sum up their samples to build one training graph
417
+
418
+ urls = [
419
+ "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/facts.txt",
420
+ "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/train.txt",
421
+ "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/valid.txt",
422
+ "https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/test.txt",
423
+ ]
424
+ name = "nell995"
425
+
426
+ @property
427
+ def raw_file_names(self):
428
+ return ["facts.txt", "train.txt", "valid.txt", "test.txt"]
429
+
430
+
431
+ def process(self):
432
+ train_files = self.raw_paths[:4]
433
+
434
+ facts_results = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
435
+ train_results = self.load_file(train_files[1], facts_results["inv_entity_vocab"], facts_results["inv_rel_vocab"])
436
+ valid_results = self.load_file(train_files[2], train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
437
+ test_results = self.load_file(train_files[3], train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
438
+
439
+ num_node = valid_results["num_node"]
440
+ num_relations = train_results["num_relation"]
441
+
442
+ train_triplets = facts_results["triplets"] + train_results["triplets"]
443
+ valid_triplets = valid_results["triplets"]
444
+ test_triplets = test_results["triplets"]
445
+
446
+ train_target_edges = torch.tensor([[t[0], t[1]] for t in train_triplets], dtype=torch.long).t()
447
+ train_target_etypes = torch.tensor([t[2] for t in train_triplets])
448
+
449
+ valid_edges = torch.tensor([[t[0], t[1]] for t in valid_triplets], dtype=torch.long).t()
450
+ valid_etypes = torch.tensor([t[2] for t in valid_triplets])
451
+
452
+ test_edges = torch.tensor([[t[0], t[1]] for t in test_triplets], dtype=torch.long).t()
453
+ test_etypes = torch.tensor([t[2] for t in test_triplets])
454
+
455
+ train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
456
+ train_etypes = torch.cat([train_target_etypes, train_target_etypes+num_relations])
457
+
458
+ train_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
459
+ target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_relations*2)
460
+ valid_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
461
+ target_edge_index=valid_edges, target_edge_type=valid_etypes, num_relations=num_relations*2)
462
+ test_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
463
+ target_edge_index=test_edges, target_edge_type=test_etypes, num_relations=num_relations*2)
464
+
465
+ # build graphs of relations
466
+ if self.pre_transform is not None:
467
+ train_data = self.pre_transform(train_data)
468
+ valid_data = self.pre_transform(valid_data)
469
+ test_data = self.pre_transform(test_data)
470
+
471
+ torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
472
+
473
+
474
+ class ConceptNet100k(TransductiveDataset):
475
+
476
+ urls = [
477
+ "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/train",
478
+ "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/valid",
479
+ "https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/test",
480
+ ]
481
+ name = "cnet100k"
482
+ delimiter = "\t"
483
+
484
+
485
+ class DBpedia100k(TransductiveDataset):
486
+ urls = [
487
+ "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_train.txt",
488
+ "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_valid.txt",
489
+ "https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_test.txt",
490
+ ]
491
+ name = "dbp100k"
492
+
493
+
494
+ class YAGO310(TransductiveDataset):
495
+
496
+ urls = [
497
+ "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/train.txt",
498
+ "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/valid.txt",
499
+ "https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/test.txt",
500
+ ]
501
+ name = "yago310"
502
+
503
+
504
+ class Hetionet(TransductiveDataset):
505
+
506
+ urls = [
507
+ "https://www.dropbox.com/s/y47bt9oq57h6l5k/train.txt?dl=1",
508
+ "https://www.dropbox.com/s/a0pbrx9tz3dgsff/valid.txt?dl=1",
509
+ "https://www.dropbox.com/s/4dhrvg3fyq5tnu4/test.txt?dl=1",
510
+ ]
511
+ name = "hetionet"
512
+
513
+
514
+ class AristoV4(TransductiveDataset):
515
+
516
+ url = "https://zenodo.org/record/5942560/files/aristo-v4.zip"
517
+
518
+ name = "aristov4"
519
+ delimiter = "\t"
520
+
521
+ def download(self):
522
+ download_path = download_url(self.url, self.raw_dir)
523
+ extract_zip(download_path, self.raw_dir)
524
+ os.unlink(download_path)
525
+ for oldname, newname in zip(['train', 'valid', 'test'], self.raw_paths):
526
+ os.rename(os.path.join(self.raw_dir, oldname), newname)
527
+
528
+
529
+ class SparserKG(TransductiveDataset):
530
+
531
+ # 5 datasets based on FB/NELL/WD, introduced in https://github.com/THU-KEG/DacKGR
532
+ # re-writing the loading function because dumps are in the format (h, t, r) while the standard is (h, r, t)
533
+
534
+ url = "https://raw.githubusercontent.com/THU-KEG/DacKGR/master/data.zip"
535
+ delimiter = "\t"
536
+ base_name = "SparseKG"
537
+
538
+ @property
539
+ def raw_dir(self):
540
+ return os.path.join(self.root, self.base_name, self.name, "raw")
541
+
542
+ @property
543
+ def processed_dir(self):
544
+ return os.path.join(self.root, self.base_name, self.name, "processed")
545
+
546
+ def download(self):
547
+ base_path = os.path.join(self.root, self.base_name)
548
+ download_path = download_url(self.url, base_path)
549
+ extract_zip(download_path, base_path)
550
+ for dsname in ['NELL23K', 'WD-singer', 'FB15K-237-10', 'FB15K-237-20', 'FB15K-237-50']:
551
+ for oldname, newname in zip(['train.triples', 'dev.triples', 'test.triples'], self.raw_file_names):
552
+ os.renames(os.path.join(base_path, "data", dsname, oldname), os.path.join(base_path, dsname, "raw", newname))
553
+ shutil.rmtree(os.path.join(base_path, "data"))
554
+
555
+ def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}):
556
+
557
+ triplets = []
558
+ entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
559
+
560
+ with open(triplet_file, "r", encoding="utf-8") as fin:
561
+ for l in fin:
562
+ u, v, r = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
563
+ if u not in inv_entity_vocab:
564
+ inv_entity_vocab[u] = entity_cnt
565
+ entity_cnt += 1
566
+ if v not in inv_entity_vocab:
567
+ inv_entity_vocab[v] = entity_cnt
568
+ entity_cnt += 1
569
+ if r not in inv_rel_vocab:
570
+ inv_rel_vocab[r] = rel_cnt
571
+ rel_cnt += 1
572
+ u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
573
+
574
+ triplets.append((u, v, r))
575
+
576
+ return {
577
+ "triplets": triplets,
578
+ "num_node": len(inv_entity_vocab), #entity_cnt,
579
+ "num_relation": rel_cnt,
580
+ "inv_entity_vocab": inv_entity_vocab,
581
+ "inv_rel_vocab": inv_rel_vocab
582
+ }
583
+
584
+ class WDsinger(SparserKG):
585
+ name = "WD-singer"
586
+
587
+ class NELL23k(SparserKG):
588
+ name = "NELL23K"
589
+
590
+ class FB15k237_10(SparserKG):
591
+ name = "FB15K-237-10"
592
+
593
+ class FB15k237_20(SparserKG):
594
+ name = "FB15K-237-20"
595
+
596
+ class FB15k237_50(SparserKG):
597
+ name = "FB15K-237-50"
598
+
599
+
600
+ class InductiveDataset(InMemoryDataset):
601
+
602
+ delimiter = None
603
+ # some datasets (4 from Hamaguchi et al and Indigo) have validation set based off the train graph, not inference
604
+ valid_on_inf = True #
605
+
606
+ def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, **kwargs):
607
+
608
+ self.version = str(version)
609
+ super().__init__(root, transform, pre_transform)
610
+ self.data, self.slices = torch.load(self.processed_paths[0])
611
+
612
+ def download(self):
613
+ for url, path in zip(self.urls, self.raw_paths):
614
+ download_path = download_url(url % self.version, self.raw_dir)
615
+ os.rename(download_path, path)
616
+
617
+ def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}):
618
+
619
+ triplets = []
620
+ entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
621
+
622
+ with open(triplet_file, "r", encoding="utf-8") as fin:
623
+ for l in fin:
624
+ u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
625
+ if u not in inv_entity_vocab:
626
+ inv_entity_vocab[u] = entity_cnt
627
+ entity_cnt += 1
628
+ if v not in inv_entity_vocab:
629
+ inv_entity_vocab[v] = entity_cnt
630
+ entity_cnt += 1
631
+ if r not in inv_rel_vocab:
632
+ inv_rel_vocab[r] = rel_cnt
633
+ rel_cnt += 1
634
+ u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
635
+
636
+ triplets.append((u, v, r))
637
+
638
+ return {
639
+ "triplets": triplets,
640
+ "num_node": len(inv_entity_vocab), #entity_cnt,
641
+ "num_relation": rel_cnt,
642
+ "inv_entity_vocab": inv_entity_vocab,
643
+ "inv_rel_vocab": inv_rel_vocab
644
+ }
645
+
646
+ def process(self):
647
+
648
+ train_files = self.raw_paths[:4]
649
+
650
+ train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
651
+ inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
652
+ valid_res = self.load_file(
653
+ train_files[2],
654
+ inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
655
+ inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"]
656
+ )
657
+ test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])
658
+
659
+ num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
660
+ inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]
661
+
662
+ train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]
663
+
664
+ train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
665
+ train_target_etypes = torch.tensor([t[2] for t in train_edges])
666
+
667
+ train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
668
+ train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])
669
+
670
+ inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
671
+ inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
672
+ inf_etypes = torch.tensor([t[2] for t in inf_graph])
673
+ inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])
674
+
675
+ inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
676
+ inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)
677
+
678
+ train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
679
+ target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
680
+ valid_data = Data(edge_index=inf_edges if self.valid_on_inf else train_fact_index,
681
+ edge_type=inf_etypes if self.valid_on_inf else train_fact_type,
682
+ num_nodes=inference_num_nodes if self.valid_on_inf else num_train_nodes,
683
+ target_edge_index=inf_valid_edges[:, :2].T,
684
+ target_edge_type=inf_valid_edges[:, 2],
685
+ num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
686
+ test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
687
+ target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)
688
+
689
+ if self.pre_transform is not None:
690
+ train_data = self.pre_transform(train_data)
691
+ valid_data = self.pre_transform(valid_data)
692
+ test_data = self.pre_transform(test_data)
693
+
694
+ torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
695
+
696
+ @property
697
+ def num_relations(self):
698
+ return int(self.data.edge_type.max()) + 1
699
+
700
+ @property
701
+ def raw_dir(self):
702
+ return os.path.join(self.root, self.name, self.version, "raw")
703
+
704
+ @property
705
+ def processed_dir(self):
706
+ return os.path.join(self.root, self.name, self.version, "processed")
707
+
708
+ @property
709
+ def raw_file_names(self):
710
+ return [
711
+ "transductive_train.txt", "inference_graph.txt", "inf_valid.txt", "inf_test.txt"
712
+ ]
713
+
714
+ @property
715
+ def processed_file_names(self):
716
+ return "data.pt"
717
+
718
+ def __repr__(self):
719
+ return "%s(%s)" % (self.name, self.version)
720
+
721
+
722
+ class IngramInductive(InductiveDataset):
723
+
724
+ @property
725
+ def raw_dir(self):
726
+ return os.path.join(self.root, "ingram", self.name, self.version, "raw")
727
+
728
+ @property
729
+ def processed_dir(self):
730
+ return os.path.join(self.root, "ingram", self.name, self.version, "processed")
731
+
732
+
733
+ class FBIngram(IngramInductive):
734
+
735
+ urls = [
736
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/train.txt",
737
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/msg.txt",
738
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/valid.txt",
739
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/test.txt",
740
+ ]
741
+ name = "fb"
742
+
743
+
744
+ class WKIngram(IngramInductive):
745
+
746
+ urls = [
747
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/train.txt",
748
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/msg.txt",
749
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/valid.txt",
750
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/test.txt",
751
+ ]
752
+ name = "wk"
753
+
754
+ class NLIngram(IngramInductive):
755
+
756
+ urls = [
757
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/train.txt",
758
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/msg.txt",
759
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/valid.txt",
760
+ "https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/test.txt",
761
+ ]
762
+ name = "nl"
763
+
764
+
765
+ class ILPC2022(InductiveDataset):
766
+
767
+ urls = [
768
+ "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/train.txt",
769
+ "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference.txt",
770
+ "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference_validation.txt",
771
+ "https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference_test.txt",
772
+ ]
773
+
774
+ name = "ilpc2022"
775
+
776
+
777
+ class HM(InductiveDataset):
778
+ # benchmarks from Hamaguchi et al and Indigo BM
779
+
780
+ urls = [
781
+ "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/train/train.txt",
782
+ "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/test/test-graph.txt",
783
+ "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/train/valid.txt",
784
+ "https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/test/test-fact.txt",
785
+ ]
786
+
787
+ name = "hm"
788
+ versions = {
789
+ '1k': "Hamaguchi-BM_both-1000",
790
+ '3k': "Hamaguchi-BM_both-3000",
791
+ '5k': "Hamaguchi-BM_both-5000",
792
+ 'indigo': "INDIGO-BM"
793
+ }
794
+ # in 4 HM graphs, the validation set is based off the training graph, so we'll adjust the dataset creation accordingly
795
+ valid_on_inf = False
796
+
797
+ def __init__(self, root, version, **kwargs):
798
+ version = self.versions[version]
799
+ super().__init__(root, version, **kwargs)
800
+
801
+ # HM datasets are a bit weird: validation set (based off the train graph) has a few hundred new nodes, so we need a custom processing
802
+ def process(self):
803
+
804
+ train_files = self.raw_paths[:4]
805
+
806
+ train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
807
+ inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
808
+ valid_res = self.load_file(
809
+ train_files[2],
810
+ inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
811
+ inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"]
812
+ )
813
+ test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])
814
+
815
+ num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
816
+ inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]
817
+
818
+ train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]
819
+
820
+ train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
821
+ train_target_etypes = torch.tensor([t[2] for t in train_edges])
822
+
823
+ train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
824
+ train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])
825
+
826
+ inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
827
+ inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
828
+ inf_etypes = torch.tensor([t[2] for t in inf_graph])
829
+ inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])
830
+
831
+ inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
832
+ inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)
833
+
834
+ train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
835
+ target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
836
+ valid_data = Data(edge_index=train_fact_index,
837
+ edge_type=train_fact_type,
838
+ num_nodes=valid_res["num_node"], # the only fix in this function
839
+ target_edge_index=inf_valid_edges[:, :2].T,
840
+ target_edge_type=inf_valid_edges[:, 2],
841
+ num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
842
+ test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
843
+ target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)
844
+
845
+ if self.pre_transform is not None:
846
+ train_data = self.pre_transform(train_data)
847
+ valid_data = self.pre_transform(valid_data)
848
+ test_data = self.pre_transform(test_data)
849
+
850
+ torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
851
+
852
+
853
+ class MTDEAInductive(InductiveDataset):
854
+
855
+ valid_on_inf = False
856
+ url = "https://reltrans.s3.us-east-2.amazonaws.com/MTDEA_data.zip"
857
+ base_name = "mtdea"
858
+
859
+ def __init__(self, root, version, **kwargs):
860
+
861
+ assert version in self.versions, f"unknown version {version} for {self.name}, available: {self.versions}"
862
+ super().__init__(root, version, **kwargs)
863
+
864
+ @property
865
+ def raw_dir(self):
866
+ return os.path.join(self.root, self.base_name, self.name, self.version, "raw")
867
+
868
+ @property
869
+ def processed_dir(self):
870
+ return os.path.join(self.root, self.base_name, self.name, self.version, "processed")
871
+
872
+ @property
873
+ def raw_file_names(self):
874
+ return [
875
+ "transductive_train.txt", "inference_graph.txt", "transductive_valid.txt", "inf_test.txt"
876
+ ]
877
+
878
+ def download(self):
879
+ base_path = os.path.join(self.root, self.base_name)
880
+ download_path = download_url(self.url, base_path)
881
+ extract_zip(download_path, base_path)
882
+ # unzip all datasets at once
883
+ for dsname in ['FBNELL', 'Metafam', 'WikiTopics-MT1', 'WikiTopics-MT2', 'WikiTopics-MT3', 'WikiTopics-MT4']:
884
+ cl = globals()[dsname.replace("-","")]
885
+ versions = cl.versions
886
+ for version in versions:
887
+ for oldname, newname in zip(['train.txt', 'observe.txt', 'valid.txt', 'test.txt'], self.raw_file_names):
888
+ foldername = cl.prefix % version + "-trans" if "transductive" in newname else cl.prefix % version + "-ind"
889
+ os.renames(
890
+ os.path.join(base_path, "MTDEA_datasets", dsname, foldername, oldname),
891
+ os.path.join(base_path, dsname, version, "raw", newname)
892
+ )
893
+ shutil.rmtree(os.path.join(base_path, "MTDEA_datasets"))
894
+
895
+ def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}, limit_vocab=False):
896
+
897
+ triplets = []
898
+ entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
899
+
900
+ # limit_vocab is for dropping triples with unseen head/tail not seen in the main entity_vocab
901
+ # can be used for FBNELL and MT3:art, other datasets seem to be ok and share num_nodes/num_relations in the train/inference graph
902
+ with open(triplet_file, "r", encoding="utf-8") as fin:
903
+ for l in fin:
904
+ u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
905
+ if u not in inv_entity_vocab:
906
+ if limit_vocab:
907
+ continue
908
+ inv_entity_vocab[u] = entity_cnt
909
+ entity_cnt += 1
910
+ if v not in inv_entity_vocab:
911
+ if limit_vocab:
912
+ continue
913
+ inv_entity_vocab[v] = entity_cnt
914
+ entity_cnt += 1
915
+ if r not in inv_rel_vocab:
916
+ if limit_vocab:
917
+ continue
918
+ inv_rel_vocab[r] = rel_cnt
919
+ rel_cnt += 1
920
+ u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
921
+
922
+ triplets.append((u, v, r))
923
+
924
+ return {
925
+ "triplets": triplets,
926
+ "num_node": entity_cnt,
927
+ "num_relation": rel_cnt,
928
+ "inv_entity_vocab": inv_entity_vocab,
929
+ "inv_rel_vocab": inv_rel_vocab
930
+ }
931
+
932
+ # special processes for MTDEA datasets for one particular fix in the validation set loading
933
+ def process(self):
934
+
935
+ train_files = self.raw_paths[:4]
936
+
937
+ train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
938
+ inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
939
+ valid_res = self.load_file(
940
+ train_files[2],
941
+ inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
942
+ inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"],
943
+ limit_vocab=True, # the 1st fix in this function compared to the superclass processor
944
+ )
945
+ test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])
946
+
947
+ num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
948
+ inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]
949
+
950
+ train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]
951
+
952
+ train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
953
+ train_target_etypes = torch.tensor([t[2] for t in train_edges])
954
+
955
+ train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
956
+ train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])
957
+
958
+ inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
959
+ inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
960
+ inf_etypes = torch.tensor([t[2] for t in inf_graph])
961
+ inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])
962
+
963
+ inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
964
+ inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)
965
+
966
+ train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
967
+ target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
968
+ valid_data = Data(edge_index=train_fact_index,
969
+ edge_type=train_fact_type,
970
+ num_nodes=valid_res["num_node"], # the 2nd fix in this function
971
+ target_edge_index=inf_valid_edges[:, :2].T,
972
+ target_edge_type=inf_valid_edges[:, 2],
973
+ num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
974
+ test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
975
+ target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)
976
+
977
+ if self.pre_transform is not None:
978
+ train_data = self.pre_transform(train_data)
979
+ valid_data = self.pre_transform(valid_data)
980
+ test_data = self.pre_transform(test_data)
981
+
982
+ torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
983
+
984
+
985
+ class FBNELL(MTDEAInductive):
986
+
987
+ name = "FBNELL"
988
+ prefix = "%s"
989
+ versions = ["FBNELL_v1"]
990
+
991
+ def __init__(self, **kwargs):
992
+ kwargs.pop("version")
993
+ kwargs['version'] = self.versions[0]
994
+ super(FBNELL, self).__init__(**kwargs)
995
+
996
+
997
+ class Metafam(MTDEAInductive):
998
+
999
+ name = "Metafam"
1000
+ prefix = "%s"
1001
+ versions = ["Metafam"]
1002
+
1003
+ def __init__(self, **kwargs):
1004
+ kwargs.pop("version")
1005
+ kwargs['version'] = self.versions[0]
1006
+ super(Metafam, self).__init__(**kwargs)
1007
+
1008
+
1009
+ class WikiTopicsMT1(MTDEAInductive):
1010
+
1011
+ name = "WikiTopics-MT1"
1012
+ prefix = "wikidata_%sv1"
1013
+ versions = ['mt', 'health', 'tax']
1014
+
1015
+ def __init__(self, **kwargs):
1016
+ assert kwargs['version'] in self.versions, f"unknown version {kwargs['version']}, available: {self.versions}"
1017
+ super(WikiTopicsMT1, self).__init__(**kwargs)
1018
+
1019
+
1020
+ class WikiTopicsMT2(MTDEAInductive):
1021
+
1022
+ name = "WikiTopics-MT2"
1023
+ prefix = "wikidata_%sv1"
1024
+ versions = ['mt2', 'org', 'sci']
1025
+
1026
+ def __init__(self, **kwargs):
1027
+ super(WikiTopicsMT2, self).__init__(**kwargs)
1028
+
1029
+
1030
+ class WikiTopicsMT3(MTDEAInductive):
1031
+
1032
+ name = "WikiTopics-MT3"
1033
+ prefix = "wikidata_%sv2"
1034
+ versions = ['mt3', 'art', 'infra']
1035
+
1036
+ def __init__(self, **kwargs):
1037
+ super(WikiTopicsMT3, self).__init__(**kwargs)
1038
+
1039
+
1040
+ class WikiTopicsMT4(MTDEAInductive):
1041
+
1042
+ name = "WikiTopics-MT4"
1043
+ prefix = "wikidata_%sv2"
1044
+ versions = ['mt4', 'sci', 'health']
1045
+
1046
+ def __init__(self, **kwargs):
1047
+ super(WikiTopicsMT4, self).__init__(**kwargs)
1048
+
1049
+
1050
+ # a joint dataset for pre-training ULTRA on several graphs
1051
+ class JointDataset(InMemoryDataset):
1052
+
1053
+ datasets_map = {
1054
+ 'FB15k237': FB15k237,
1055
+ 'WN18RR': WN18RR,
1056
+ 'CoDExSmall': CoDExSmall,
1057
+ 'CoDExMedium': CoDExMedium,
1058
+ 'CoDExLarge': CoDExLarge,
1059
+ 'NELL995': NELL995,
1060
+ 'ConceptNet100k': ConceptNet100k,
1061
+ 'DBpedia100k': DBpedia100k,
1062
+ 'YAGO310': YAGO310,
1063
+ 'AristoV4': AristoV4,
1064
+ }
1065
+
1066
+ def __init__(self, root, graphs, transform=None, pre_transform=None):
1067
+
1068
+
1069
+ self.graphs = [self.datasets_map[ds](root=root) for ds in graphs]
1070
+ self.num_graphs = len(graphs)
1071
+ super().__init__(root, transform, pre_transform)
1072
+ self.data = torch.load(self.processed_paths[0])
1073
+
1074
+ @property
1075
+ def raw_dir(self):
1076
+ return os.path.join(self.root, "joint", f'{self.num_graphs}g', "raw")
1077
+
1078
+ @property
1079
+ def processed_dir(self):
1080
+ return os.path.join(self.root, "joint", f'{self.num_graphs}g', "processed")
1081
+
1082
+ @property
1083
+ def processed_file_names(self):
1084
+ return "data.pt"
1085
+
1086
+ def process(self):
1087
+
1088
+ train_data = [g[0] for g in self.graphs]
1089
+ valid_data = [g[1] for g in self.graphs]
1090
+ test_data = [g[2] for g in self.graphs]
1091
+ # filter_data = [
1092
+ # Data(edge_index=g.data.target_edge_index, edge_type=g.data.target_edge_type, num_nodes=g[0].num_nodes) for g in self.graphs
1093
+ # ]
1094
+
1095
+ torch.save((train_data, valid_data, test_data), self.processed_paths[0])
ultra/eval.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import distributed as dist
5
+ from torch.utils import data as torch_data
6
+ from torch_geometric.data import Data
7
+
8
+ from ultra import tasks, util
9
+
10
+
11
+ TRANSDUCTIVE = ("WordNet18RR", "RelLinkPredDataset", "CoDExSmall", "CoDExMedium", "CoDExLarge",
12
+ "YAGO310", "NELL995", "ConceptNet100k", "DBpedia100k", "Hetionet", "AristoV4",
13
+ "WDsinger", "NELL23k", "FB15k237_10", "FB15k237_20", "FB15k237_50")
14
+
15
+
16
+ def get_filtered_data(dataset, mode):
17
+ train_data, valid_data, test_data = dataset[0], dataset[1], dataset[2]
18
+ ds_name = dataset.__class__.__name__
19
+
20
+ if ds_name in TRANSDUCTIVE:
21
+ filtered_data = Data(edge_index=dataset._data.target_edge_index, edge_type=dataset._data.target_edge_type, num_nodes=dataset[0].num_nodes)
22
+ else:
23
+ if "ILPC" in ds_name or "Ingram" in ds_name:
24
+ full_inference_edges = torch.cat([valid_data.edge_index, valid_data.target_edge_index, test_data.target_edge_index], dim=1)
25
+ full_inference_etypes = torch.cat([valid_data.edge_type, valid_data.target_edge_type, test_data.target_edge_type])
26
+ filtered_data = Data(edge_index=full_inference_edges, edge_type=full_inference_etypes, num_nodes=test_data.num_nodes)
27
+ else:
28
+ # test filtering graph: inference edges + test edges
29
+ full_inference_edges = torch.cat([test_data.edge_index, test_data.target_edge_index], dim=1)
30
+ full_inference_etypes = torch.cat([test_data.edge_type, test_data.target_edge_type])
31
+ if mode == "test":
32
+ filtered_data = Data(edge_index=full_inference_edges, edge_type=full_inference_etypes, num_nodes=test_data.num_nodes)
33
+ else:
34
+ # validation filtering graph: train edges + validation edges
35
+ filtered_data = Data(
36
+ edge_index=torch.cat([train_data.edge_index, valid_data.target_edge_index], dim=1),
37
+ edge_type=torch.cat([train_data.edge_type, valid_data.target_edge_type])
38
+ )
39
+
40
+ return filtered_data
41
+
42
+
43
+ @torch.no_grad()
44
+ def test(model, mode, dataset, batch_size=32, eval_metrics=["mrr", "hits@10"], gpus=None, return_metrics=False):
45
+ logger = util.get_root_logger()
46
+ test_data = dataset[1] if mode == "valid" else dataset[2]
47
+ filtered_data = get_filtered_data(dataset, mode)
48
+
49
+ device = util.get_devices(gpus)
50
+ world_size = util.get_world_size()
51
+ rank = util.get_rank()
52
+
53
+ test_triplets = torch.cat([test_data.target_edge_index, test_data.target_edge_type.unsqueeze(0)]).t()
54
+ sampler = torch_data.DistributedSampler(test_triplets, world_size, rank)
55
+ test_loader = torch_data.DataLoader(test_triplets, batch_size, sampler=sampler)
56
+
57
+ model.eval()
58
+ rankings = []
59
+ num_negatives = []
60
+ tail_rankings, num_tail_negs = [], [] # for explicit tail-only evaluation needed for 5 datasets
61
+ for batch in test_loader:
62
+ t_batch, h_batch = tasks.all_negative(test_data, batch)
63
+ t_pred = model(test_data, t_batch)
64
+ h_pred = model(test_data, h_batch)
65
+
66
+ if filtered_data is None:
67
+ t_mask, h_mask = tasks.strict_negative_mask(test_data, batch)
68
+ else:
69
+ t_mask, h_mask = tasks.strict_negative_mask(filtered_data, batch)
70
+ pos_h_index, pos_t_index, pos_r_index = batch.t()
71
+ t_ranking = tasks.compute_ranking(t_pred, pos_t_index, t_mask)
72
+ h_ranking = tasks.compute_ranking(h_pred, pos_h_index, h_mask)
73
+ num_t_negative = t_mask.sum(dim=-1)
74
+ num_h_negative = h_mask.sum(dim=-1)
75
+
76
+ rankings += [t_ranking, h_ranking]
77
+ num_negatives += [num_t_negative, num_h_negative]
78
+
79
+ tail_rankings += [t_ranking]
80
+ num_tail_negs += [num_t_negative]
81
+
82
+ ranking = torch.cat(rankings)
83
+ num_negative = torch.cat(num_negatives)
84
+ all_size = torch.zeros(world_size, dtype=torch.long, device=device)
85
+ all_size[rank] = len(ranking)
86
+
87
+ # ugly repetitive code for tail-only ranks processing
88
+ tail_ranking = torch.cat(tail_rankings)
89
+ num_tail_neg = torch.cat(num_tail_negs)
90
+ all_size_t = torch.zeros(world_size, dtype=torch.long, device=device)
91
+ all_size_t[rank] = len(tail_ranking)
92
+ if world_size > 1:
93
+ dist.all_reduce(all_size, op=dist.ReduceOp.SUM)
94
+ dist.all_reduce(all_size_t, op=dist.ReduceOp.SUM)
95
+
96
+ # obtaining all ranks
97
+ cum_size = all_size.cumsum(0)
98
+ all_ranking = torch.zeros(all_size.sum(), dtype=torch.long, device=device)
99
+ all_ranking[cum_size[rank] - all_size[rank]: cum_size[rank]] = ranking
100
+ all_num_negative = torch.zeros(all_size.sum(), dtype=torch.long, device=device)
101
+ all_num_negative[cum_size[rank] - all_size[rank]: cum_size[rank]] = num_negative
102
+
103
+ # the same for tails-only ranks
104
+ cum_size_t = all_size_t.cumsum(0)
105
+ all_ranking_t = torch.zeros(all_size_t.sum(), dtype=torch.long, device=device)
106
+ all_ranking_t[cum_size_t[rank] - all_size_t[rank]: cum_size_t[rank]] = tail_ranking
107
+ all_num_negative_t = torch.zeros(all_size_t.sum(), dtype=torch.long, device=device)
108
+ all_num_negative_t[cum_size_t[rank] - all_size_t[rank]: cum_size_t[rank]] = num_tail_neg
109
+ if world_size > 1:
110
+ dist.all_reduce(all_ranking, op=dist.ReduceOp.SUM)
111
+ dist.all_reduce(all_num_negative, op=dist.ReduceOp.SUM)
112
+ dist.all_reduce(all_ranking_t, op=dist.ReduceOp.SUM)
113
+ dist.all_reduce(all_num_negative_t, op=dist.ReduceOp.SUM)
114
+
115
+ metrics = {}
116
+ if rank == 0:
117
+ for metric in eval_metrics:
118
+ if "-tail" in metric:
119
+ _metric_name, direction = metric.split("-")
120
+ if direction != "tail":
121
+ raise ValueError("Only tail metric is supported in this mode")
122
+ _ranking = all_ranking_t
123
+ _num_neg = all_num_negative_t
124
+ else:
125
+ _ranking = all_ranking
126
+ _num_neg = all_num_negative
127
+ _metric_name = metric
128
+
129
+ if _metric_name == "mr":
130
+ score = _ranking.float().mean()
131
+ elif _metric_name == "mrr":
132
+ score = (1 / _ranking.float()).mean()
133
+ elif _metric_name.startswith("hits@"):
134
+ values = _metric_name[5:].split("_")
135
+ threshold = int(values[0])
136
+ if len(values) > 1:
137
+ num_sample = int(values[1])
138
+ # unbiased estimation
139
+ fp_rate = (_ranking - 1).float() / _num_neg
140
+ score = 0
141
+ for i in range(threshold):
142
+ # choose i false positive from num_sample - 1 negatives
143
+ num_comb = math.factorial(num_sample - 1) / \
144
+ math.factorial(i) / math.factorial(num_sample - i - 1)
145
+ score += num_comb * (fp_rate ** i) * ((1 - fp_rate) ** (num_sample - i - 1))
146
+ score = score.mean()
147
+ else:
148
+ score = (_ranking <= threshold).float().mean()
149
+ logger.warning("%s: %g" % (metric, score))
150
+ metrics[metric] = score
151
+ mrr = (1 / all_ranking.float()).mean()
152
+
153
+ return mrr if not return_metrics else metrics
ultra/layers.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from torch_scatter import scatter
5
+
6
+ from torch_geometric.nn.conv import MessagePassing
7
+ from torch_geometric.utils import degree
8
+ from typing import Tuple
9
+
10
+
11
+ class GeneralizedRelationalConv(MessagePassing):
12
+
13
+ eps = 1e-6
14
+
15
+ message2mul = {
16
+ "transe": "add",
17
+ "distmult": "mul",
18
+ }
19
+
20
+ # TODO for compile() - doesn't work currently
21
+ # propagate_type = {"edge_index": torch.LongTensor, "size": Tuple[int, int]}
22
+
23
+ def __init__(self, input_dim, output_dim, num_relation, query_input_dim, message_func="distmult",
24
+ aggregate_func="pna", layer_norm=False, activation="relu", dependent=False, project_relations=False):
25
+ super(GeneralizedRelationalConv, self).__init__()
26
+ self.input_dim = input_dim
27
+ self.output_dim = output_dim
28
+ self.num_relation = num_relation
29
+ self.query_input_dim = query_input_dim
30
+ self.message_func = message_func
31
+ self.aggregate_func = aggregate_func
32
+ self.dependent = dependent
33
+ self.project_relations = project_relations
34
+
35
+ if layer_norm:
36
+ self.layer_norm = nn.LayerNorm(output_dim)
37
+ else:
38
+ self.layer_norm = None
39
+ if isinstance(activation, str):
40
+ self.activation = getattr(F, activation)
41
+ else:
42
+ self.activation = activation
43
+
44
+ if self.aggregate_func == "pna":
45
+ self.linear = nn.Linear(input_dim * 13, output_dim)
46
+ else:
47
+ self.linear = nn.Linear(input_dim * 2, output_dim)
48
+
49
+ if dependent:
50
+ # obtain relation embeddings as a projection of the query relation
51
+ self.relation_linear = nn.Linear(query_input_dim, num_relation * input_dim)
52
+ else:
53
+ if not self.project_relations:
54
+ # relation embeddings as an independent embedding matrix per each layer
55
+ self.relation = nn.Embedding(num_relation, input_dim)
56
+ else:
57
+ # will be initialized after the pass over relation graph
58
+ self.relation = None
59
+ self.relation_projection = nn.Sequential(
60
+ nn.Linear(input_dim, input_dim),
61
+ nn.ReLU(),
62
+ nn.Linear(input_dim, input_dim)
63
+ )
64
+
65
+
66
+ def forward(self, input, query, boundary, edge_index, edge_type, size, edge_weight=None):
67
+ batch_size = len(query)
68
+
69
+ if self.dependent:
70
+ # layer-specific relation features as a projection of input "query" (relation) embeddings
71
+ relation = self.relation_linear(query).view(batch_size, self.num_relation, self.input_dim)
72
+ else:
73
+ if not self.project_relations:
74
+ # layer-specific relation features as a special embedding matrix unique to each layer
75
+ relation = self.relation.weight.expand(batch_size, -1, -1)
76
+ else:
77
+ # NEW and only change:
78
+ # projecting relation features to unique features for this layer, then resizing for the current batch
79
+ relation = self.relation_projection(self.relation)
80
+ if edge_weight is None:
81
+ edge_weight = torch.ones(len(edge_type), device=input.device)
82
+
83
+ # note that we send the initial boundary condition (node states at layer0) to the message passing
84
+ # correspond to Eq.6 on p5 in https://arxiv.org/pdf/2106.06935.pdf
85
+ output = self.propagate(input=input, relation=relation, boundary=boundary, edge_index=edge_index,
86
+ edge_type=edge_type, size=size, edge_weight=edge_weight)
87
+ return output
88
+
89
+ def propagate(self, edge_index, size=None, **kwargs):
90
+ if kwargs["edge_weight"].requires_grad or self.message_func == "rotate":
91
+ # the rspmm cuda kernel only works for TransE and DistMult message functions
92
+ # otherwise we invoke separate message & aggregate functions
93
+ return super(GeneralizedRelationalConv, self).propagate(edge_index, size, **kwargs)
94
+
95
+ for hook in self._propagate_forward_pre_hooks.values():
96
+ res = hook(self, (edge_index, size, kwargs))
97
+ if res is not None:
98
+ edge_index, size, kwargs = res
99
+
100
+ # in newer PyG,
101
+ # __check_input__ -> _check_input()
102
+ # __collect__ -> _collect()
103
+ # __fused_user_args__ -> _fuser_user_args
104
+ size = self._check_input(edge_index, size)
105
+ coll_dict = self._collect(self._fused_user_args, edge_index, size, kwargs)
106
+
107
+ msg_aggr_kwargs = self.inspector.distribute("message_and_aggregate", coll_dict)
108
+ for hook in self._message_and_aggregate_forward_pre_hooks.values():
109
+ res = hook(self, (edge_index, msg_aggr_kwargs))
110
+ if res is not None:
111
+ edge_index, msg_aggr_kwargs = res
112
+ out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
113
+ for hook in self._message_and_aggregate_forward_hooks.values():
114
+ res = hook(self, (edge_index, msg_aggr_kwargs), out)
115
+ if res is not None:
116
+ out = res
117
+
118
+ update_kwargs = self.inspector.distribute("update", coll_dict)
119
+ out = self.update(out, **update_kwargs)
120
+
121
+ for hook in self._propagate_forward_hooks.values():
122
+ res = hook(self, (edge_index, size, kwargs), out)
123
+ if res is not None:
124
+ out = res
125
+
126
+ return out
127
+
128
+ def message(self, input_j, relation, boundary, edge_type):
129
+ relation_j = relation.index_select(self.node_dim, edge_type)
130
+
131
+ if self.message_func == "transe":
132
+ message = input_j + relation_j
133
+ elif self.message_func == "distmult":
134
+ message = input_j * relation_j
135
+ elif self.message_func == "rotate":
136
+ x_j_re, x_j_im = input_j.chunk(2, dim=-1)
137
+ r_j_re, r_j_im = relation_j.chunk(2, dim=-1)
138
+ message_re = x_j_re * r_j_re - x_j_im * r_j_im
139
+ message_im = x_j_re * r_j_im + x_j_im * r_j_re
140
+ message = torch.cat([message_re, message_im], dim=-1)
141
+ else:
142
+ raise ValueError("Unknown message function `%s`" % self.message_func)
143
+
144
+ # augment messages with the boundary condition
145
+ message = torch.cat([message, boundary], dim=self.node_dim) # (num_edges + num_nodes, batch_size, input_dim)
146
+
147
+ return message
148
+
149
+ def aggregate(self, input, edge_weight, index, dim_size):
150
+ # augment aggregation index with self-loops for the boundary condition
151
+ index = torch.cat([index, torch.arange(dim_size, device=input.device)]) # (num_edges + num_nodes,)
152
+ edge_weight = torch.cat([edge_weight, torch.ones(dim_size, device=input.device)])
153
+ shape = [1] * input.ndim
154
+ shape[self.node_dim] = -1
155
+ edge_weight = edge_weight.view(shape)
156
+
157
+ if self.aggregate_func == "pna":
158
+ mean = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="mean")
159
+ sq_mean = scatter(input ** 2 * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="mean")
160
+ max = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="max")
161
+ min = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="min")
162
+ std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt()
163
+ features = torch.cat([mean.unsqueeze(-1), max.unsqueeze(-1), min.unsqueeze(-1), std.unsqueeze(-1)], dim=-1)
164
+ features = features.flatten(-2)
165
+ degree_out = degree(index, dim_size).unsqueeze(0).unsqueeze(-1)
166
+ scale = degree_out.log()
167
+ scale = scale / scale.mean()
168
+ scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1)
169
+ output = (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2)
170
+ else:
171
+ output = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size,
172
+ reduce=self.aggregate_func)
173
+
174
+ return output
175
+
176
+ def message_and_aggregate(self, edge_index, input, relation, boundary, edge_type, edge_weight, index, dim_size):
177
+ # fused computation of message and aggregate steps with the custom rspmm cuda kernel
178
+ # speed up computation by several times
179
+ # reduce memory complexity from O(|E|d) to O(|V|d), so we can apply it to larger graphs
180
+ from ultra.rspmm.rspmm import generalized_rspmm
181
+
182
+ batch_size, num_node = input.shape[:2]
183
+ input = input.transpose(0, 1).flatten(1)
184
+ relation = relation.transpose(0, 1).flatten(1)
185
+ boundary = boundary.transpose(0, 1).flatten(1)
186
+ degree_out = degree(index, dim_size).unsqueeze(-1) + 1
187
+
188
+ if self.message_func in self.message2mul:
189
+ mul = self.message2mul[self.message_func]
190
+ else:
191
+ raise ValueError("Unknown message function `%s`" % self.message_func)
192
+ if self.aggregate_func == "sum":
193
+ update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul)
194
+ update = update + boundary
195
+ elif self.aggregate_func == "mean":
196
+ update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul)
197
+ update = (update + boundary) / degree_out
198
+ elif self.aggregate_func == "max":
199
+ update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="max", mul=mul)
200
+ update = torch.max(update, boundary)
201
+ elif self.aggregate_func == "pna":
202
+ # we use PNA with 4 aggregators (mean / max / min / std)
203
+ # and 3 scalars (identity / log degree / reciprocal of log degree)
204
+ sum = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul)
205
+ sq_sum = generalized_rspmm(edge_index, edge_type, edge_weight, relation ** 2, input ** 2, sum="add",
206
+ mul=mul)
207
+ max = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="max", mul=mul)
208
+ min = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="min", mul=mul)
209
+ mean = (sum + boundary) / degree_out
210
+ sq_mean = (sq_sum + boundary ** 2) / degree_out
211
+ max = torch.max(max, boundary)
212
+ min = torch.min(min, boundary) # (node, batch_size * input_dim)
213
+ std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt()
214
+ features = torch.cat([mean.unsqueeze(-1), max.unsqueeze(-1), min.unsqueeze(-1), std.unsqueeze(-1)], dim=-1)
215
+ features = features.flatten(-2) # (node, batch_size * input_dim * 4)
216
+ scale = degree_out.log()
217
+ scale = scale / scale.mean()
218
+ scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1) # (node, 3)
219
+ update = (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2) # (node, batch_size * input_dim * 4 * 3)
220
+ else:
221
+ raise ValueError("Unknown aggregation function `%s`" % self.aggregate_func)
222
+
223
+ update = update.view(num_node, batch_size, -1).transpose(0, 1)
224
+ return update
225
+
226
+ def update(self, update, input):
227
+ # node update as a function of old states (input) and this layer output (update)
228
+ output = self.linear(torch.cat([input, update], dim=-1))
229
+ if self.layer_norm:
230
+ output = self.layer_norm(output)
231
+ if self.activation:
232
+ output = self.activation(output)
233
+ return output
234
+
ultra/models.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from . import tasks, layers
5
+ from ultra.base_nbfnet import BaseNBFNet
6
+
7
+ class Ultra(nn.Module):
8
+
9
+ def __init__(self, rel_model_cfg, entity_model_cfg):
10
+ # kept that because super Ultra sounds cool
11
+ super(Ultra, self).__init__()
12
+
13
+ self.relation_model = RelNBFNet(**rel_model_cfg)
14
+ self.entity_model = EntityNBFNet(**entity_model_cfg)
15
+
16
+
17
+ def forward(self, data, batch):
18
+
19
+ # batch shape: (bs, 1+num_negs, 3)
20
+ # relations are the same all positive and negative triples, so we can extract only one from the first triple among 1+nug_negs
21
+ query_rels = batch[:, 0, 2]
22
+ relation_representations = self.relation_model(data.relation_graph, query=query_rels)
23
+ score = self.entity_model(data, relation_representations, batch)
24
+
25
+ return score
26
+
27
+
28
+ # NBFNet to work on the graph of relations with 4 fundamental interactions
29
+ # Doesn't have the final projection MLP from hidden dim -> 1, returns all node representations
30
+ # of shape [bs, num_rel, hidden]
31
+ class RelNBFNet(BaseNBFNet):
32
+
33
+ def __init__(self, input_dim, hidden_dims, num_relation=4, **kwargs):
34
+ super().__init__(input_dim, hidden_dims, num_relation, **kwargs)
35
+
36
+ self.layers = nn.ModuleList()
37
+ for i in range(len(self.dims) - 1):
38
+ self.layers.append(
39
+ layers.GeneralizedRelationalConv(
40
+ self.dims[i], self.dims[i + 1], num_relation,
41
+ self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
42
+ self.activation, dependent=False)
43
+ )
44
+
45
+ if self.concat_hidden:
46
+ feature_dim = sum(hidden_dims) + input_dim
47
+ self.mlp = nn.Sequential(
48
+ nn.Linear(feature_dim, feature_dim),
49
+ nn.ReLU(),
50
+ nn.Linear(feature_dim, input_dim)
51
+ )
52
+
53
+
54
+ def bellmanford(self, data, h_index, separate_grad=False):
55
+ batch_size = len(h_index)
56
+
57
+ # initialize initial nodes (relations of interest in the batcj) with all ones
58
+ query = torch.ones(h_index.shape[0], self.dims[0], device=h_index.device, dtype=torch.float)
59
+ index = h_index.unsqueeze(-1).expand_as(query)
60
+
61
+ # initial (boundary) condition - initialize all node states as zeros
62
+ boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device)
63
+ #boundary = torch.zeros(data.num_nodes, *query.shape, device=h_index.device)
64
+ # Indicator function: by the scatter operation we put ones as init features of source (index) nodes
65
+ boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))
66
+ size = (data.num_nodes, data.num_nodes)
67
+ edge_weight = torch.ones(data.num_edges, device=h_index.device)
68
+
69
+ hiddens = []
70
+ edge_weights = []
71
+ layer_input = boundary
72
+
73
+ for layer in self.layers:
74
+ # Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
75
+ hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
76
+ if self.short_cut and hidden.shape == layer_input.shape:
77
+ # residual connection here
78
+ hidden = hidden + layer_input
79
+ hiddens.append(hidden)
80
+ edge_weights.append(edge_weight)
81
+ layer_input = hidden
82
+
83
+ # original query (relation type) embeddings
84
+ node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1) # (batch_size, num_nodes, input_dim)
85
+ if self.concat_hidden:
86
+ output = torch.cat(hiddens + [node_query], dim=-1)
87
+ output = self.mlp(output)
88
+ else:
89
+ output = hiddens[-1]
90
+
91
+ return {
92
+ "node_feature": output,
93
+ "edge_weights": edge_weights,
94
+ }
95
+
96
+ def forward(self, rel_graph, query):
97
+
98
+ # message passing and updated node representations (that are in fact relations)
99
+ output = self.bellmanford(rel_graph, h_index=query)["node_feature"] # (batch_size, num_nodes, hidden_dim)
100
+
101
+ return output
102
+
103
+
104
+ class EntityNBFNet(BaseNBFNet):
105
+
106
+ def __init__(self, input_dim, hidden_dims, num_relation=1, **kwargs):
107
+
108
+ # dummy num_relation = 1 as we won't use it in the NBFNet layer
109
+ super().__init__(input_dim, hidden_dims, num_relation, **kwargs)
110
+
111
+ self.layers = nn.ModuleList()
112
+ for i in range(len(self.dims) - 1):
113
+ self.layers.append(
114
+ layers.GeneralizedRelationalConv(
115
+ self.dims[i], self.dims[i + 1], num_relation,
116
+ self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
117
+ self.activation, dependent=False, project_relations=True)
118
+ )
119
+
120
+ feature_dim = (sum(hidden_dims) if self.concat_hidden else hidden_dims[-1]) + input_dim
121
+ self.mlp = nn.Sequential()
122
+ mlp = []
123
+ for i in range(self.num_mlp_layers - 1):
124
+ mlp.append(nn.Linear(feature_dim, feature_dim))
125
+ mlp.append(nn.ReLU())
126
+ mlp.append(nn.Linear(feature_dim, 1))
127
+ self.mlp = nn.Sequential(*mlp)
128
+
129
+
130
+ def bellmanford(self, data, h_index, r_index, separate_grad=False):
131
+ batch_size = len(r_index)
132
+
133
+ # initialize queries (relation types of the given triples)
134
+ query = self.query[torch.arange(batch_size, device=r_index.device), r_index]
135
+ index = h_index.unsqueeze(-1).expand_as(query)
136
+
137
+ # initial (boundary) condition - initialize all node states as zeros
138
+ boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device)
139
+ # by the scatter operation we put query (relation) embeddings as init features of source (index) nodes
140
+ boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))
141
+
142
+ size = (data.num_nodes, data.num_nodes)
143
+ edge_weight = torch.ones(data.num_edges, device=h_index.device)
144
+
145
+ hiddens = []
146
+ edge_weights = []
147
+ layer_input = boundary
148
+
149
+ for layer in self.layers:
150
+
151
+ # for visualization
152
+ if separate_grad:
153
+ edge_weight = edge_weight.clone().requires_grad_()
154
+
155
+ # Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
156
+ hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
157
+ if self.short_cut and hidden.shape == layer_input.shape:
158
+ # residual connection here
159
+ hidden = hidden + layer_input
160
+ hiddens.append(hidden)
161
+ edge_weights.append(edge_weight)
162
+ layer_input = hidden
163
+
164
+ # original query (relation type) embeddings
165
+ node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1) # (batch_size, num_nodes, input_dim)
166
+ if self.concat_hidden:
167
+ output = torch.cat(hiddens + [node_query], dim=-1)
168
+ else:
169
+ output = torch.cat([hiddens[-1], node_query], dim=-1)
170
+
171
+ return {
172
+ "node_feature": output,
173
+ "edge_weights": edge_weights,
174
+ }
175
+
176
+ def forward(self, data, relation_representations, batch):
177
+ h_index, t_index, r_index = batch.unbind(-1)
178
+
179
+ # initial query representations are those from the relation graph
180
+ self.query = relation_representations
181
+
182
+ # initialize relations in each NBFNet layer (with uinque projection internally)
183
+ for layer in self.layers:
184
+ layer.relation = relation_representations
185
+
186
+ if self.training:
187
+ # Edge dropout in the training mode
188
+ # here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
189
+ # to make NBFNet iteration learn non-trivial paths
190
+ data = self.remove_easy_edges(data, h_index, t_index, r_index)
191
+
192
+ shape = h_index.shape
193
+ # turn all triples in a batch into a tail prediction mode
194
+ h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index, num_direct_rel=data.num_relations // 2)
195
+ assert (h_index[:, [0]] == h_index).all()
196
+ assert (r_index[:, [0]] == r_index).all()
197
+
198
+ # message passing and updated node representations
199
+ output = self.bellmanford(data, h_index[:, 0], r_index[:, 0]) # (num_nodes, batch_size, feature_dim)
200
+ feature = output["node_feature"]
201
+ index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
202
+ # extract representations of tail entities from the updated node states
203
+ feature = feature.gather(1, index) # (batch_size, num_negative + 1, feature_dim)
204
+
205
+ # probability logit for each tail node in the batch
206
+ # (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1)
207
+ score = self.mlp(feature).squeeze(-1)
208
+ return score.view(shape)
209
+
210
+
211
+
212
+
213
+
214
+
ultra/rspmm/rspmm.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import torch.backends.openmp
5
+ from torch import autograd
6
+ from torch.utils import cpp_extension
7
+
8
+ module = sys.modules[__name__]
9
+
10
+
11
+ class RSPMMAddMulFunction(autograd.Function):
12
+
13
+ @staticmethod
14
+ def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
15
+ node_in, node_out = edge_index
16
+ key = node_in * (node_out.max() + 1) + node_out
17
+ assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
18
+
19
+ if input.device.type == "cuda":
20
+ forward = rspmm.rspmm_add_mul_forward_cuda
21
+ else:
22
+ forward = rspmm.rspmm_add_mul_forward_cpu
23
+ output = forward(edge_index, edge_type, edge_weight, relation, input)
24
+ ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
25
+ return output
26
+
27
+ @staticmethod
28
+ def backward(ctx, output_grad):
29
+ if output_grad.device.type == "cuda":
30
+ backward = rspmm.rspmm_add_mul_backward_cuda
31
+ else:
32
+ backward = rspmm.rspmm_add_mul_backward_cpu
33
+ weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
34
+ return None, None, weight_grad, relation_grad, input_grad
35
+
36
+
37
+ class RSPMMMinMulFunction(autograd.Function):
38
+
39
+ @staticmethod
40
+ def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
41
+ node_in, node_out = edge_index
42
+ key = node_in * (node_out.max() + 1) + node_out
43
+ assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
44
+
45
+ if input.device.type == "cuda":
46
+ forward = rspmm.rspmm_min_mul_forward_cuda
47
+ else:
48
+ forward = rspmm.rspmm_min_mul_forward_cpu
49
+ output = forward(edge_index, edge_type, edge_weight, relation, input)
50
+ ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
51
+ return output
52
+
53
+ @staticmethod
54
+ def backward(ctx, output_grad):
55
+ if output_grad.device.type == "cuda":
56
+ backward = rspmm.rspmm_min_mul_backward_cuda
57
+ else:
58
+ backward = rspmm.rspmm_min_mul_backward_cpu
59
+ weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
60
+ return None, None, weight_grad, relation_grad, input_grad
61
+
62
+
63
+ class RSPMMMaxMulFunction(autograd.Function):
64
+
65
+ @staticmethod
66
+ def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
67
+ node_in, node_out = edge_index
68
+ key = node_in * (node_out.max() + 1) + node_out
69
+ assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
70
+
71
+ if input.device.type == "cuda":
72
+ forward = rspmm.rspmm_max_mul_forward_cuda
73
+ else:
74
+ forward = rspmm.rspmm_max_mul_forward_cpu
75
+ output = forward(edge_index, edge_type, edge_weight, relation, input)
76
+ ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
77
+ return output
78
+
79
+ @staticmethod
80
+ def backward(ctx, output_grad):
81
+ if output_grad.device.type == "cuda":
82
+ backward = rspmm.rspmm_max_mul_backward_cuda
83
+ else:
84
+ backward = rspmm.rspmm_max_mul_backward_cpu
85
+ weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
86
+ return None, None, weight_grad, relation_grad, input_grad
87
+
88
+
89
+ class RSPMMAddAddFunction(autograd.Function):
90
+
91
+ @staticmethod
92
+ def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
93
+ node_in, node_out = edge_index
94
+ key = node_in * (node_out.max() + 1) + node_out
95
+ assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
96
+
97
+ if input.device.type == "cuda":
98
+ forward = rspmm.rspmm_add_add_forward_cuda
99
+ else:
100
+ forward = rspmm.rspmm_add_add_forward_cpu
101
+ output = forward(edge_index, edge_type, edge_weight, relation, input)
102
+ ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
103
+ return output
104
+
105
+ @staticmethod
106
+ def backward(ctx, output_grad):
107
+ if output_grad.device.type == "cuda":
108
+ backward = rspmm.rspmm_add_add_backward_cuda
109
+ else:
110
+ backward = rspmm.rspmm_add_add_backward_cpu
111
+ weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
112
+ return None, None, weight_grad, relation_grad, input_grad
113
+
114
+
115
+ class RSPMMMinAddFunction(autograd.Function):
116
+
117
+ @staticmethod
118
+ def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
119
+ node_in, node_out = edge_index
120
+ key = node_in * (node_out.max() + 1) + node_out
121
+ assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
122
+
123
+ if input.device.type == "cuda":
124
+ forward = rspmm.rspmm_min_add_forward_cuda
125
+ else:
126
+ forward = rspmm.rspmm_min_add_forward_cpu
127
+ output = forward(edge_index, edge_type, edge_weight, relation, input)
128
+ ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
129
+ return output
130
+
131
+ @staticmethod
132
+ def backward(ctx, output_grad):
133
+ if output_grad.device.type == "cuda":
134
+ backward = rspmm.rspmm_min_add_backward_cuda
135
+ else:
136
+ backward = rspmm.rspmm_min_add_backward_cpu
137
+ weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
138
+ return None, None, weight_grad, relation_grad, input_grad
139
+
140
+
141
+ class RSPMMMaxAddFunction(autograd.Function):
142
+
143
+ @staticmethod
144
+ def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
145
+ node_in, node_out = edge_index
146
+ key = node_in * (node_out.max() + 1) + node_out
147
+ assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
148
+
149
+ if input.device.type == "cuda":
150
+ forward = rspmm.rspmm_max_add_forward_cuda
151
+ else:
152
+ forward = rspmm.rspmm_max_add_forward_cpu
153
+ output = forward(edge_index, edge_type, edge_weight, relation, input)
154
+ ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
155
+ return output
156
+
157
+ @staticmethod
158
+ def backward(ctx, output_grad):
159
+ if output_grad.device.type == "cuda":
160
+ backward = rspmm.rspmm_max_add_backward_cuda
161
+ else:
162
+ backward = rspmm.rspmm_max_add_backward_cpu
163
+ weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
164
+ return None, None, weight_grad, relation_grad, input_grad
165
+
166
+
167
+ def generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul="mul"):
168
+ name = "RSPMM%s%sFunction" % (sum.capitalize(), mul.capitalize())
169
+ if not hasattr(module, name):
170
+ raise ValueError("No generalized rspmm implementation found for summation `%s` and multiplication `%s`"
171
+ % (sum, mul))
172
+ Function = getattr(module, name)
173
+
174
+ node_in, node_out = edge_index
175
+ key = node_in * (node_out.max() + 1) + node_out
176
+ order = key.argsort()
177
+
178
+ return Function.apply(edge_index[:, order], edge_type[order], edge_weight[order], relation, input)
179
+
180
+
181
+ def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs):
182
+ if extra_cflags is None:
183
+ extra_cflags = ["-Ofast"]
184
+ if torch.backends.openmp.is_available():
185
+ extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"]
186
+ else:
187
+ extra_cflags.append("-DAT_PARALLEL_NATIVE")
188
+ if extra_cuda_cflags is None:
189
+ if torch.cuda.is_available():
190
+ extra_cuda_cflags = ["-O3"]
191
+ extra_cflags.append("-DCUDA_OP")
192
+ else:
193
+ new_sources = []
194
+ for source in sources:
195
+ if not cpp_extension._is_cuda_file(source):
196
+ new_sources.append(source)
197
+ sources = new_sources
198
+
199
+ return cpp_extension.load(name, sources, extra_cflags, extra_cuda_cflags, **kwargs)
200
+
201
+
202
+ print("Load rspmm extension. This may take a while...")
203
+ path = os.path.join(os.path.dirname(__file__), "source")
204
+ rspmm = load_extension("rspmm", [os.path.join(path, "rspmm.cpp"), os.path.join(path, "rspmm.cu")])
ultra/rspmm/source/operator.cuh ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <limits>
4
+
5
+ #ifdef __CUDA_ARCH__
6
+ #define HOST_DEVICE __host__ __device__
7
+ #else
8
+ #define HOST_DEVICE
9
+ #endif
10
+
11
+ namespace at {
12
+
13
+ template <class scalar_t>
14
+ struct BinaryAdd {
15
+ HOST_DEVICE static scalar_t forward(scalar_t x, scalar_t y) {
16
+ return x + y;
17
+ }
18
+
19
+ HOST_DEVICE static scalar_t backward_lhs(scalar_t x, scalar_t y) {
20
+ return 1;
21
+ }
22
+
23
+ HOST_DEVICE static scalar_t backward_rhs(scalar_t x, scalar_t y) {
24
+ return 1;
25
+ }
26
+ };
27
+
28
+ template <class scalar_t>
29
+ struct BinaryMul {
30
+ HOST_DEVICE static scalar_t forward(scalar_t x, scalar_t y) {
31
+ return x * y;
32
+ }
33
+
34
+ HOST_DEVICE static scalar_t backward_lhs(scalar_t x, scalar_t y) {
35
+ return y;
36
+ }
37
+
38
+ HOST_DEVICE static scalar_t backward_rhs(scalar_t x, scalar_t y) {
39
+ return x;
40
+ }
41
+ };
42
+
43
+ template <class scalar_t>
44
+ struct NaryAdd {
45
+ HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
46
+ return result + x;
47
+ }
48
+
49
+ HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
50
+ return 1;
51
+ }
52
+
53
+ static constexpr scalar_t zero = 0;
54
+ };
55
+
56
+ template <class scalar_t>
57
+ struct NaryMin {
58
+ HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
59
+ return result < x ? result : x;
60
+ }
61
+
62
+ HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
63
+ return result == x ? 1 : 0;
64
+ }
65
+
66
+ static constexpr scalar_t zero = std::numeric_limits<scalar_t>::max();
67
+ };
68
+
69
+ template <class scalar_t>
70
+ struct NaryMax {
71
+ HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
72
+ return result > x ? result : x;
73
+ }
74
+
75
+ HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
76
+ return result == x ? 1 : 0;
77
+ }
78
+
79
+ static constexpr scalar_t zero = std::numeric_limits<scalar_t>::lowest();
80
+ };
81
+
82
+ } // namespace at
ultra/rspmm/source/rspmm.cpp ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <mutex>
2
+
3
+ #include <ATen/Parallel.h>
4
+
5
+ #include "operator.cuh"
6
+ #include "rspmm.h"
7
+
8
+ namespace at {
9
+
10
+ // In PyTorch 1.4.0, parallel_for depends on some functions from at::internal in ATen/Parallel.h
11
+ // which are not explicitly included
12
+ // This is fixed in some new PyTorch release
13
+ using namespace at::internal;
14
+
15
+ void rspmm_forward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
16
+ const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg) {
17
+ checkDim(c, edge_index_arg, 2);
18
+ checkDim(c, edge_type_arg, 1);
19
+ checkDim(c, edge_weight_arg, 1);
20
+ checkDim(c, relation_arg, 2);
21
+ checkDim(c, input_arg, 2);
22
+ checkSameType(c, edge_index_arg, edge_type_arg);
23
+ checkAllSameType(c, {edge_weight_arg, relation_arg, input_arg});
24
+ checkSize(c, edge_index_arg, 0, 2);
25
+ checkSize(c, edge_type_arg, {edge_index_arg->size(1)});
26
+ checkSize(c, edge_weight_arg, {edge_index_arg->size(1)});
27
+ checkSize(c, relation_arg, 1, input_arg->size(1));
28
+ }
29
+
30
+ void rspmm_backward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
31
+ const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg,
32
+ const TensorArg &output_arg, const TensorArg &output_grad_arg) {
33
+ rspmm_forward_check(c, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg);
34
+ checkDim(c, output_arg, 2);
35
+ checkSameSize(c, output_arg, output_grad_arg);
36
+ checkAllSameType(c, {input_arg, output_arg, output_grad_arg});
37
+ checkSize(c, output_arg, 1, input_arg->size(1));
38
+ }
39
+
40
+ Tensor ind2ptr(const Tensor &index, int size) {
41
+ // scatter_add is super slow for int64, due to non-hardware atomic operations
42
+ // use int32 instead
43
+ Tensor num_per_index = at::zeros({size}, index.options().dtype(at::ScalarType::Int));
44
+ num_per_index.scatter_add_(0, index, at::ones(index.sizes(), num_per_index.options()));
45
+ num_per_index = num_per_index.toType(at::ScalarType::Long);
46
+ Tensor pointer = num_per_index.cumsum(0) - num_per_index;
47
+ return pointer;
48
+ }
49
+
50
+ template <class scalar_t, class NaryOp, class BinaryOp>
51
+ void rspmm_forward_out_cpu(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
52
+ const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
53
+ scalar_t *output,
54
+ int64_t num_row, int64_t nnz, int64_t dim) {
55
+ parallel_for(0, num_row, 0, [&](int64_t row_start, int64_t row_end) {
56
+ for (int64_t row = row_start; row < row_end; row++) {
57
+ for (int64_t d = 0; d < dim; d++)
58
+ output[row * dim + d] = NaryOp::zero;
59
+
60
+ int64_t ptr_start = row_ptr[row];
61
+ int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
62
+ for (int64_t ptr = ptr_start; ptr < ptr_end; ptr++) {
63
+ int64_t col = col_ind[ptr];
64
+ int64_t layer = layer_ind[ptr];
65
+ scalar_t w = weight[ptr];
66
+ for (int64_t d = 0; d < dim; d++) {
67
+ scalar_t x = BinaryOp::forward(relation[layer * dim + d], input[col * dim + d]);
68
+ scalar_t y = w * x;
69
+ scalar_t &out = output[row * dim + d];
70
+ out = NaryOp::forward(out, y);
71
+ }
72
+ }
73
+ }
74
+ });
75
+ }
76
+
77
+ template <class scalar_t, class NaryOp, class BinaryOp>
78
+ void rspmm_backward_out_cpu(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
79
+ const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
80
+ const scalar_t *output, const scalar_t *output_grad,
81
+ scalar_t *weight_grad, scalar_t *relation_grad, scalar_t *input_grad,
82
+ int64_t num_row, int64_t nnz, int64_t dim,
83
+ std::vector<std::mutex> &relation_mutex, std::vector<std::mutex> &input_mutex) {
84
+ parallel_for(0, num_row, 0, [&](int64_t row_start, int64_t row_end) {
85
+ for (int64_t row = row_start; row < row_end; row++) {
86
+ int64_t ptr_start = row_ptr[row];
87
+ int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
88
+ for (int64_t ptr = ptr_start; ptr < ptr_end; ptr++) {
89
+ int64_t col = col_ind[ptr];
90
+ int64_t layer = layer_ind[ptr];
91
+ scalar_t w = weight[ptr];
92
+ scalar_t w_grad = 0;
93
+ for (int64_t d = 0; d < dim; d++) {
94
+ scalar_t rel = relation[layer * dim + d];
95
+ scalar_t in = input[col * dim + d];
96
+ scalar_t out = output[row * dim + d];
97
+ scalar_t out_grad = output_grad[row * dim + d];
98
+ scalar_t x = BinaryOp::forward(rel, in);
99
+ scalar_t y = w * x;
100
+ scalar_t dx_drel = BinaryOp::backward_lhs(rel, in);
101
+ scalar_t dx_din = BinaryOp::backward_rhs(rel, in);
102
+ scalar_t dout_dy = NaryOp::backward(out, y);
103
+ scalar_t dy_dw = x;
104
+ scalar_t dy_dx = w;
105
+ w_grad += out_grad * dout_dy * dy_dw;
106
+ {
107
+ std::lock_guard<std::mutex> lock(relation_mutex[layer * dim + d]);
108
+ relation_grad[layer * dim + d] += out_grad * dout_dy * dy_dx * dx_drel;
109
+ }
110
+ {
111
+ std::lock_guard<std::mutex> lock(input_mutex[col * dim + d]);
112
+ input_grad[col * dim + d] += out_grad * dout_dy * dy_dx * dx_din;
113
+ }
114
+ }
115
+ weight_grad[ptr] = w_grad;
116
+ }
117
+ }
118
+ });
119
+ }
120
+
121
+ template <template<class> class NaryOp, template<class> class BinaryOp>
122
+ Tensor rspmm_forward_cpu(const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
123
+ const Tensor &relation_, const Tensor &input_) {
124
+ constexpr const char *fn_name = "rspmm_forward_cpu";
125
+ TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
126
+ edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
127
+ input_arg(input_, "input", 5);
128
+
129
+ rspmm_forward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg);
130
+ checkDeviceType(fn_name, {edge_index_, edge_type_, edge_weight_, relation_, input_}, kCPU);
131
+
132
+ const Tensor edge_index = edge_index_.contiguous();
133
+ const Tensor edge_type = edge_type_.contiguous();
134
+ const Tensor edge_weight = edge_weight_.contiguous();
135
+ const Tensor relation = relation_.contiguous();
136
+ const Tensor input = input_.contiguous();
137
+
138
+ int64_t nnz = edge_index.size(0);
139
+ int64_t num_row = input.size(0);
140
+ int64_t dim = input.size(1);
141
+ Tensor output = at::empty({num_row, dim}, input.options());
142
+
143
+ Tensor row_ind = edge_index.select(0, 0);
144
+ Tensor row_ptr = ind2ptr(row_ind, num_row);
145
+ Tensor col_ind = edge_index.select(0, 1);
146
+ Tensor layer_ind = edge_type;
147
+
148
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_forward_cpu", [&] {
149
+ rspmm_forward_out_cpu<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>(
150
+ row_ptr.data_ptr<int64_t>(),
151
+ col_ind.data_ptr<int64_t>(),
152
+ layer_ind.data_ptr<int64_t>(),
153
+ edge_weight.data_ptr<scalar_t>(),
154
+ relation.data_ptr<scalar_t>(),
155
+ input.data_ptr<scalar_t>(),
156
+ output.data_ptr<scalar_t>(),
157
+ num_row, nnz, dim
158
+ );
159
+ });
160
+
161
+ return output;
162
+ }
163
+
164
+ template <template<class> class NaryOp, template<class> class BinaryOp>
165
+ std::tuple<Tensor, Tensor, Tensor> rspmm_backward_cpu(
166
+ const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
167
+ const Tensor &relation_, const Tensor &input_, const Tensor &output_, const Tensor &output_grad_) {
168
+ constexpr const char *fn_name = "rspmm_backward_cpu";
169
+ TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
170
+ edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
171
+ input_arg(input_, "input", 5), output_arg(output_, "output", 6),
172
+ output_grad_arg(output_grad_, "output_grad", 7);
173
+
174
+ rspmm_backward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg,
175
+ output_arg, output_grad_arg);
176
+ checkDeviceType(fn_name, {edge_index_, edge_type_, edge_weight_, relation_, input_, output_, output_grad_}, kCPU);
177
+
178
+ const Tensor edge_index = edge_index_.contiguous();
179
+ const Tensor edge_type = edge_type_.contiguous();
180
+ const Tensor edge_weight = edge_weight_.contiguous();
181
+ const Tensor relation = relation_.contiguous();
182
+ const Tensor input = input_.contiguous();
183
+ const Tensor output = output_.contiguous();
184
+ const Tensor output_grad = output_grad_.contiguous();
185
+
186
+ int64_t nnz = edge_index.size(0);
187
+ int64_t num_row = input.size(0);
188
+ int64_t dim = input.size(1);
189
+ Tensor weight_grad = at::zeros_like(edge_weight);
190
+ Tensor relation_grad = at::zeros_like(relation);
191
+ Tensor input_grad = at::zeros_like(input);
192
+
193
+ Tensor row_ind = edge_index.select(0, 0);
194
+ Tensor row_ptr = ind2ptr(row_ind, num_row);
195
+ Tensor col_ind = edge_index.select(0, 1);
196
+ Tensor layer_ind = edge_type;
197
+ std::vector<std::mutex> relation_mutex(relation.numel());
198
+ std::vector<std::mutex> input_mutex(input.numel());
199
+
200
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cpu", [&] {
201
+ rspmm_backward_out_cpu<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>(
202
+ row_ptr.data_ptr<int64_t>(),
203
+ col_ind.data_ptr<int64_t>(),
204
+ layer_ind.data_ptr<int64_t>(),
205
+ edge_weight.data_ptr<scalar_t>(),
206
+ relation.data_ptr<scalar_t>(),
207
+ input.data_ptr<scalar_t>(),
208
+ output.data_ptr<scalar_t>(),
209
+ output_grad.data_ptr<scalar_t>(),
210
+ weight_grad.data_ptr<scalar_t>(),
211
+ relation_grad.data_ptr<scalar_t>(),
212
+ input_grad.data_ptr<scalar_t>(),
213
+ num_row, nnz, dim,
214
+ relation_mutex, input_mutex
215
+ );
216
+ });
217
+
218
+ return std::make_tuple(weight_grad, relation_grad, input_grad);
219
+ }
220
+
221
+ #define DECLARE_FORWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
222
+ Tensor rspmm_##ADD##_##MUL##_forward_cpu( \
223
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
224
+ const Tensor &relation, const Tensor &input) { \
225
+ return rspmm_forward_cpu<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input); \
226
+ }
227
+
228
+ #define DECLARE_BACKWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
229
+ std::tuple<Tensor, Tensor, Tensor> rspmm_##ADD##_##MUL##_backward_cpu( \
230
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
231
+ const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad) { \
232
+ return rspmm_backward_cpu<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input, \
233
+ output, output_grad); \
234
+ }
235
+
236
+ DECLARE_FORWARD_IMPL(add, mul, NaryAdd, BinaryMul)
237
+ DECLARE_BACKWARD_IMPL(add, mul, NaryAdd, BinaryMul)
238
+
239
+ DECLARE_FORWARD_IMPL(min, mul, NaryMin, BinaryMul)
240
+ DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
241
+
242
+ DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
243
+ DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
244
+
245
+ DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
246
+ DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
247
+
248
+ DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
249
+ DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
250
+
251
+ DECLARE_FORWARD_IMPL(max, add, NaryMax, BinaryAdd)
252
+ DECLARE_BACKWARD_IMPL(max, add, NaryMax, BinaryAdd)
253
+
254
+ } // namespace at
255
+
256
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
257
+ m.def("rspmm_add_mul_forward_cpu", &at::rspmm_add_mul_forward_cpu);
258
+ m.def("rspmm_add_mul_backward_cpu", &at::rspmm_add_mul_backward_cpu);
259
+ m.def("rspmm_min_mul_forward_cpu", &at::rspmm_min_mul_forward_cpu);
260
+ m.def("rspmm_min_mul_backward_cpu", &at::rspmm_min_mul_backward_cpu);
261
+ m.def("rspmm_max_mul_forward_cpu", &at::rspmm_max_mul_forward_cpu);
262
+ m.def("rspmm_max_mul_backward_cpu", &at::rspmm_max_mul_backward_cpu);
263
+ m.def("rspmm_add_add_forward_cpu", &at::rspmm_add_add_forward_cpu);
264
+ m.def("rspmm_add_add_backward_cpu", &at::rspmm_add_add_backward_cpu);
265
+ m.def("rspmm_min_add_forward_cpu", &at::rspmm_min_add_forward_cpu);
266
+ m.def("rspmm_min_add_backward_cpu", &at::rspmm_min_add_backward_cpu);
267
+ m.def("rspmm_max_add_forward_cpu", &at::rspmm_max_add_forward_cpu);
268
+ m.def("rspmm_max_add_backward_cpu", &at::rspmm_max_add_backward_cpu);
269
+ #ifdef CUDA_OP
270
+ m.def("rspmm_add_mul_forward_cuda", &at::rspmm_add_mul_forward_cuda);
271
+ m.def("rspmm_add_mul_backward_cuda", &at::rspmm_add_mul_backward_cuda);
272
+ m.def("rspmm_min_mul_forward_cuda", &at::rspmm_min_mul_forward_cuda);
273
+ m.def("rspmm_min_mul_backward_cuda", &at::rspmm_min_mul_backward_cuda);
274
+ m.def("rspmm_max_mul_forward_cuda", &at::rspmm_max_mul_forward_cuda);
275
+ m.def("rspmm_max_mul_backward_cuda", &at::rspmm_max_mul_backward_cuda);
276
+ m.def("rspmm_add_add_forward_cuda", &at::rspmm_add_add_forward_cuda);
277
+ m.def("rspmm_add_add_backward_cuda", &at::rspmm_add_add_backward_cuda);
278
+ m.def("rspmm_min_add_forward_cuda", &at::rspmm_min_add_forward_cuda);
279
+ m.def("rspmm_min_add_backward_cuda", &at::rspmm_min_add_backward_cuda);
280
+ m.def("rspmm_max_add_forward_cuda", &at::rspmm_max_add_forward_cuda);
281
+ m.def("rspmm_max_add_backward_cuda", &at::rspmm_max_add_backward_cuda);
282
+ #endif
283
+ }
ultra/rspmm/source/rspmm.cu ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/cuda/CUDAContext.h>
2
+ #include <THC/THCAtomics.cuh>
3
+
4
+ #include "util.cuh"
5
+ #include "operator.cuh"
6
+ #include "rspmm.h"
7
+
8
+ namespace at {
9
+
10
+ // Memory & time efficient implementation of generalized spmm
11
+ // Much of the code is inspired by GE-SpMM
12
+ // https://github.com/hgyhungry/ge-spmm
13
+
14
+ namespace {
15
+
16
+ const int kCoarseningFactor = 2;
17
+ const int kThreadPerBlock = 256;
18
+
19
+ } // namespace anonymous
20
+
21
+ template <class scalar_t, class NaryOp, class BinaryOp>
22
+ __global__
23
+ void rspmm_forward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
24
+ const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
25
+ scalar_t *output,
26
+ int64_t num_row, int64_t nnz, int64_t dim) {
27
+ // for best optimization, the following code is compiled with constant warpSize
28
+ assert(blockDim.x == warpSize);
29
+
30
+ extern __shared__ int64_t buffer[];
31
+ int64_t *col_ind_buf = buffer;
32
+ int64_t *layer_ind_buf = buffer + blockDim.y * warpSize;
33
+ scalar_t *weight_buf = reinterpret_cast<scalar_t *>(layer_ind_buf + blockDim.y * warpSize);
34
+ col_ind_buf += threadIdx.y * warpSize;
35
+ layer_ind_buf += threadIdx.y * warpSize;
36
+ weight_buf += threadIdx.y * warpSize;
37
+
38
+ int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
39
+ if (row >= num_row)
40
+ return;
41
+ int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
42
+ int64_t ptr_start = row_ptr[row];
43
+ int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
44
+ scalar_t out[kCoarseningFactor];
45
+ #pragma unroll
46
+ for (int64_t i = 0; i < kCoarseningFactor; i++)
47
+ out[i] = NaryOp::zero;
48
+
49
+ for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) {
50
+ int64_t ptr = block_ptr + threadIdx.x;
51
+ if (ptr < ptr_end) {
52
+ col_ind_buf[threadIdx.x] = col_ind[ptr];
53
+ layer_ind_buf[threadIdx.x] = layer_ind[ptr];
54
+ weight_buf[threadIdx.x] = weight[ptr];
55
+ }
56
+ __syncwarp();
57
+
58
+ int64_t max_offset = warpSize < ptr_end - block_ptr ? warpSize : ptr_end - block_ptr;
59
+ for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) {
60
+ int64_t col = col_ind_buf[offset_ptr];
61
+ int64_t layer = layer_ind_buf[offset_ptr];
62
+ scalar_t w = weight_buf[offset_ptr];
63
+ #pragma unroll
64
+ for (int64_t i = 0; i < kCoarseningFactor; i++) {
65
+ int64_t d = d_start + i * warpSize;
66
+ if (d >= dim)
67
+ break;
68
+ scalar_t x = BinaryOp::forward(relation[layer * dim + d], input[col * dim + d]);
69
+ scalar_t y = w * x;
70
+ out[i] = NaryOp::forward(out[i], y);
71
+ }
72
+ }
73
+ __syncwarp();
74
+ }
75
+
76
+ #pragma unroll
77
+ for (int64_t i = 0; i < kCoarseningFactor; i++) {
78
+ int64_t d = d_start + i * warpSize;
79
+ if (d >= dim)
80
+ break;
81
+ output[row * dim + d] = out[i];
82
+ }
83
+ }
84
+
85
+ template <class scalar_t, class NaryOp, class BinaryOp>
86
+ __global__
87
+ void rspmm_backward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
88
+ const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
89
+ const scalar_t *output, const scalar_t *output_grad,
90
+ scalar_t *weight_grad, scalar_t *relation_grad, scalar_t *input_grad,
91
+ int64_t num_row, int64_t nnz, int64_t dim) {
92
+ // for best optimization, the following code is compiled with constant warpSize
93
+ assert(blockDim.x == warpSize);
94
+
95
+ extern __shared__ int64_t buffer[];
96
+ int64_t *col_ind_buf = buffer;
97
+ int64_t *layer_ind_buf = col_ind_buf + blockDim.y * warpSize;
98
+ scalar_t *weight_buf = reinterpret_cast<scalar_t *>(layer_ind_buf + blockDim.y * warpSize);
99
+ col_ind_buf += threadIdx.y * warpSize;
100
+ layer_ind_buf += threadIdx.y * warpSize;
101
+ weight_buf += threadIdx.y * warpSize;
102
+
103
+ int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
104
+ if (row >= num_row)
105
+ return;
106
+ int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
107
+ int64_t ptr_start = row_ptr[row];
108
+ int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
109
+
110
+ for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) {
111
+ int64_t ptr = block_ptr + threadIdx.x;
112
+ if (ptr < ptr_end) {
113
+ col_ind_buf[threadIdx.x] = col_ind[ptr];
114
+ layer_ind_buf[threadIdx.x] = layer_ind[ptr];
115
+ weight_buf[threadIdx.x] = weight[ptr];
116
+ }
117
+ __syncwarp();
118
+
119
+ int64_t max_offset = warpSize < ptr_end - block_ptr ? warpSize : ptr_end - block_ptr;
120
+ for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) {
121
+ int64_t col = col_ind_buf[offset_ptr];
122
+ int64_t layer = layer_ind_buf[offset_ptr];
123
+ scalar_t w = weight_buf[offset_ptr];
124
+ scalar_t w_grad = 0;
125
+ #pragma unroll
126
+ for (int64_t i = 0; i < kCoarseningFactor; i++) {
127
+ int64_t d = d_start + i * warpSize;
128
+ if (d >= dim)
129
+ break;
130
+ scalar_t rel = relation[layer * dim + d];
131
+ scalar_t in = input[col * dim + d];
132
+ scalar_t out = output[row * dim + d];
133
+ scalar_t out_grad = output_grad[row * dim + d];
134
+ scalar_t x = BinaryOp::forward(rel, in);
135
+ scalar_t y = w * x;
136
+ scalar_t dx_drel = BinaryOp::backward_lhs(rel, in);
137
+ scalar_t dx_din = BinaryOp::backward_rhs(rel, in);
138
+ scalar_t dout_dy = NaryOp::backward(out, y);
139
+ scalar_t dy_dw = x;
140
+ scalar_t dy_dx = w;
141
+ w_grad += out_grad * dout_dy * dy_dw;
142
+ atomicAdd(&relation_grad[layer * dim + d], out_grad * dout_dy * dy_dx * dx_drel);
143
+ atomicAdd(&input_grad[col * dim + d], out_grad * dout_dy * dy_dx * dx_din);
144
+ }
145
+ w_grad = warp_reduce(w_grad);
146
+ if (threadIdx.x == 0)
147
+ atomicAdd(&weight_grad[block_ptr + offset_ptr], w_grad);
148
+ }
149
+ __syncwarp();
150
+ }
151
+ }
152
+
153
+ // only relation & input require gradients
154
+ template <class scalar_t, class NaryOp, class BinaryOp>
155
+ __global__
156
+ void rspmm_backward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
157
+ const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
158
+ const scalar_t *output, const scalar_t *output_grad,
159
+ scalar_t *relation_grad, scalar_t *input_grad,
160
+ int64_t num_row, int64_t nnz, int64_t dim) {
161
+ // for best optimization, the following code is compiled with constant warpSize
162
+ assert(blockDim.x == warpSize);
163
+
164
+ extern __shared__ int64_t buffer[];
165
+ int64_t *col_ind_buf = buffer;
166
+ int64_t *layer_ind_buf = col_ind_buf + blockDim.y * warpSize;
167
+ scalar_t *weight_buf = reinterpret_cast<scalar_t *>(layer_ind_buf + blockDim.y * warpSize);
168
+ col_ind_buf += threadIdx.y * warpSize;
169
+ layer_ind_buf += threadIdx.y * warpSize;
170
+ weight_buf += threadIdx.y * warpSize;
171
+
172
+ int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
173
+ if (row >= num_row)
174
+ return;
175
+ int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
176
+ int64_t ptr_start = row_ptr[row];
177
+ int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
178
+
179
+ for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) {
180
+ int64_t ptr = block_ptr + threadIdx.x;
181
+ if (ptr < ptr_end) {
182
+ col_ind_buf[threadIdx.x] = col_ind[ptr];
183
+ layer_ind_buf[threadIdx.x] = layer_ind[ptr];
184
+ weight_buf[threadIdx.x] = weight[ptr];
185
+ }
186
+ __syncwarp();
187
+
188
+ int64_t max_offset = warpSize < ptr_end - block_ptr ? warpSize : ptr_end - block_ptr;
189
+ for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) {
190
+ int64_t col = col_ind_buf[offset_ptr];
191
+ int64_t layer = layer_ind_buf[offset_ptr];
192
+ scalar_t w = weight_buf[offset_ptr];
193
+ #pragma unroll
194
+ for (int64_t i = 0; i < kCoarseningFactor; i++) {
195
+ int64_t d = d_start + i * warpSize;
196
+ if (d >= dim)
197
+ break;
198
+ scalar_t rel = relation[layer * dim + d];
199
+ scalar_t in = input[col * dim + d];
200
+ scalar_t out = output[row * dim + d];
201
+ scalar_t out_grad = output_grad[row * dim + d];
202
+ scalar_t x = BinaryOp::forward(rel, in);
203
+ scalar_t y = w * x;
204
+ scalar_t dx_drel = BinaryOp::backward_lhs(rel, in);
205
+ scalar_t dx_din = BinaryOp::backward_rhs(rel, in);
206
+ scalar_t dout_dy = NaryOp::backward(out, y);
207
+ scalar_t dy_dx = w;
208
+ atomicAdd(&relation_grad[layer * dim + d], out_grad * dout_dy * dy_dx * dx_drel);
209
+ atomicAdd(&input_grad[col * dim + d], out_grad * dout_dy * dy_dx * dx_din);
210
+ }
211
+ }
212
+ __syncwarp();
213
+ }
214
+ }
215
+
216
+ template <template<class> class NaryOp, template<class> class BinaryOp>
217
+ Tensor rspmm_forward_cuda(const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
218
+ const Tensor &relation_, const Tensor &input_) {
219
+ constexpr const char *fn_name = "rspmm_forward_cuda";
220
+ TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
221
+ edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
222
+ input_arg(input_, "input", 5);
223
+
224
+ rspmm_forward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg);
225
+ checkAllSameGPU(fn_name, {edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg});
226
+
227
+ const Tensor edge_index = edge_index_.contiguous();
228
+ const Tensor edge_type = edge_type_.contiguous();
229
+ const Tensor edge_weight = edge_weight_.contiguous();
230
+ const Tensor relation = relation_.contiguous();
231
+ const Tensor input = input_.contiguous();
232
+
233
+ int64_t nnz = edge_index.size(0);
234
+ int64_t num_row = input.size(0);
235
+ int64_t dim = input.size(1);
236
+ Tensor output = at::empty({num_row, dim}, input.options());
237
+
238
+ Tensor row_ind = edge_index.select(0, 0);
239
+ Tensor row_ptr = ind2ptr(row_ind, num_row);
240
+ Tensor col_ind = edge_index.select(0, 1);
241
+ Tensor layer_ind = edge_type;
242
+
243
+ cudaSetDevice(input.get_device());
244
+ auto stream = at::cuda::getCurrentCUDAStream();
245
+
246
+ const int dim_per_block = 32; // warpSize
247
+ const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor);
248
+ const int row_per_block = kThreadPerBlock / dim_per_block;
249
+ const int num_row_block = (num_row + row_per_block - 1) / row_per_block;
250
+
251
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_forward_cuda", [&] {
252
+ const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
253
+ rspmm_forward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
254
+ <<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
255
+ row_ptr.data_ptr<int64_t>(),
256
+ col_ind.data_ptr<int64_t>(),
257
+ layer_ind.data_ptr<int64_t>(),
258
+ edge_weight.data_ptr<scalar_t>(),
259
+ relation.data_ptr<scalar_t>(),
260
+ input.data_ptr<scalar_t>(),
261
+ output.data_ptr<scalar_t>(),
262
+ num_row, nnz, dim
263
+ );
264
+ });
265
+
266
+ return output;
267
+ }
268
+
269
+ template <template<class> class NaryOp, template<class> class BinaryOp>
270
+ std::tuple<Tensor, Tensor, Tensor> rspmm_backward_cuda(
271
+ const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
272
+ const Tensor &relation_, const Tensor &input_, const Tensor &output_, const Tensor &output_grad_) {
273
+ constexpr const char *fn_name = "rspmm_backward_cuda";
274
+ TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
275
+ edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
276
+ input_arg(input_, "input", 5), output_arg(output_, "output", 6),
277
+ output_grad_arg(output_grad_, "output_grad", 7);
278
+
279
+ rspmm_backward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg,
280
+ output_arg, output_grad_arg);
281
+ checkAllSameGPU(fn_name, {edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg, output_arg,
282
+ output_grad_arg});
283
+
284
+ const Tensor edge_index = edge_index_.contiguous();
285
+ const Tensor edge_type = edge_type_.contiguous();
286
+ const Tensor edge_weight = edge_weight_.contiguous();
287
+ const Tensor relation = relation_.contiguous();
288
+ const Tensor input = input_.contiguous();
289
+ const Tensor output = output_.contiguous();
290
+ const Tensor output_grad = output_grad_.contiguous();
291
+
292
+ int64_t nnz = edge_index.size(0);
293
+ int64_t num_row = input.size(0);
294
+ int64_t dim = input.size(1);
295
+ Tensor weight_grad = at::zeros_like(edge_weight);
296
+ Tensor relation_grad = at::zeros_like(relation);
297
+ Tensor input_grad = at::zeros_like(input);
298
+
299
+ Tensor row_ind = edge_index.select(0, 0);
300
+ Tensor row_ptr = ind2ptr(row_ind, num_row);
301
+ Tensor col_ind = edge_index.select(0, 1);
302
+ Tensor layer_ind = edge_type;
303
+
304
+ cudaSetDevice(input.get_device());
305
+ auto stream = at::cuda::getCurrentCUDAStream();
306
+
307
+ const int dim_per_block = 32; // warpSize
308
+ const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor);
309
+ const int row_per_block = kThreadPerBlock / dim_per_block;
310
+ const int num_row_block = (num_row + row_per_block - 1) / row_per_block;
311
+
312
+ if (edge_weight.requires_grad())
313
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cuda", [&] {
314
+ const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
315
+ rspmm_backward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
316
+ <<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
317
+ row_ptr.data_ptr<int64_t>(),
318
+ col_ind.data_ptr<int64_t>(),
319
+ layer_ind.data_ptr<int64_t>(),
320
+ edge_weight.data_ptr<scalar_t>(),
321
+ relation.data_ptr<scalar_t>(),
322
+ input.data_ptr<scalar_t>(),
323
+ output.data_ptr<scalar_t>(),
324
+ output_grad.data_ptr<scalar_t>(),
325
+ weight_grad.data_ptr<scalar_t>(),
326
+ relation_grad.data_ptr<scalar_t>(),
327
+ input_grad.data_ptr<scalar_t>(),
328
+ num_row, nnz, dim
329
+ );
330
+ });
331
+ else
332
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cuda", [&] {
333
+ const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
334
+ rspmm_backward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
335
+ <<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
336
+ row_ptr.data_ptr<int64_t>(),
337
+ col_ind.data_ptr<int64_t>(),
338
+ layer_ind.data_ptr<int64_t>(),
339
+ edge_weight.data_ptr<scalar_t>(),
340
+ relation.data_ptr<scalar_t>(),
341
+ input.data_ptr<scalar_t>(),
342
+ output.data_ptr<scalar_t>(),
343
+ output_grad.data_ptr<scalar_t>(),
344
+ relation_grad.data_ptr<scalar_t>(),
345
+ input_grad.data_ptr<scalar_t>(),
346
+ num_row, nnz, dim
347
+ );
348
+ });
349
+
350
+ return std::make_tuple(weight_grad, relation_grad, input_grad);
351
+ }
352
+
353
+ #define DECLARE_FORWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
354
+ Tensor rspmm_##ADD##_##MUL##_forward_cuda( \
355
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
356
+ const Tensor &relation, const Tensor &input) { \
357
+ return rspmm_forward_cuda<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input); \
358
+ }
359
+
360
+ #define DECLARE_BACKWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
361
+ std::tuple<Tensor, Tensor, Tensor> rspmm_##ADD##_##MUL##_backward_cuda( \
362
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
363
+ const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad) { \
364
+ return rspmm_backward_cuda<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input, \
365
+ output, output_grad); \
366
+ }
367
+
368
+ DECLARE_FORWARD_IMPL(add, mul, NaryAdd, BinaryMul)
369
+ DECLARE_BACKWARD_IMPL(add, mul, NaryAdd, BinaryMul)
370
+
371
+ DECLARE_FORWARD_IMPL(min, mul, NaryMin, BinaryMul)
372
+ DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
373
+
374
+ DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
375
+ DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
376
+
377
+ DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
378
+ DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
379
+
380
+ DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
381
+ DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
382
+
383
+ DECLARE_FORWARD_IMPL(max, add, NaryMax, BinaryAdd)
384
+ DECLARE_BACKWARD_IMPL(max, add, NaryMax, BinaryAdd)
385
+
386
+ } // namespace at
ultra/rspmm/source/rspmm.h ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <tuple>
4
+
5
+ #include <torch/extension.h>
6
+ //#include <ATen/SparseTensorUtils.h>
7
+ #include <ATen/native/SparseTensorUtils.h>
8
+
9
+ namespace at {
10
+
11
+ using namespace at::sparse;
12
+
13
+ void rspmm_forward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
14
+ const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg);
15
+
16
+ void rspmm_backward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
17
+ const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg,
18
+ const TensorArg &output_arg, const TensorArg &output_grad_arg);
19
+
20
+ Tensor ind2ptr(const Tensor &index, int size);
21
+
22
+ Tensor rspmm_add_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
23
+ const Tensor &relation, const Tensor &input);
24
+
25
+ std::tuple<Tensor, Tensor, Tensor> rspmm_add_mul_backward_cpu(
26
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
27
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
28
+
29
+ Tensor rspmm_min_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
30
+ const Tensor &relation, const Tensor &input);
31
+
32
+ std::tuple<Tensor, Tensor, Tensor> rspmm_min_mul_backward_cpu(
33
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
34
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
35
+
36
+ Tensor rspmm_max_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
37
+ const Tensor &relation, const Tensor &input);
38
+
39
+ std::tuple<Tensor, Tensor, Tensor> rspmm_max_mul_backward_cpu(
40
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
41
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
42
+
43
+ Tensor rspmm_add_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
44
+ const Tensor &relation, const Tensor &input);
45
+
46
+ std::tuple<Tensor, Tensor, Tensor> rspmm_add_add_backward_cpu(
47
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
48
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
49
+
50
+ Tensor rspmm_min_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
51
+ const Tensor &relation, const Tensor &input);
52
+
53
+ std::tuple<Tensor, Tensor, Tensor> rspmm_min_add_backward_cpu(
54
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
55
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
56
+
57
+ Tensor rspmm_max_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
58
+ const Tensor &relation, const Tensor &input);
59
+
60
+ std::tuple<Tensor, Tensor, Tensor> rspmm_max_add_backward_cpu(
61
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
62
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
63
+
64
+ #ifdef CUDA_OP
65
+ Tensor rspmm_add_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
66
+ const Tensor &relation, const Tensor &input);
67
+
68
+ std::tuple<Tensor, Tensor, Tensor> rspmm_add_mul_backward_cuda(
69
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
70
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
71
+
72
+ Tensor rspmm_min_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
73
+ const Tensor &relation, const Tensor &input);
74
+
75
+ std::tuple<Tensor, Tensor, Tensor> rspmm_min_mul_backward_cuda(
76
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
77
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
78
+
79
+ Tensor rspmm_max_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
80
+ const Tensor &relation, const Tensor &input);
81
+
82
+ std::tuple<Tensor, Tensor, Tensor> rspmm_max_mul_backward_cuda(
83
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
84
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
85
+
86
+ Tensor rspmm_add_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
87
+ const Tensor &relation, const Tensor &input);
88
+
89
+ std::tuple<Tensor, Tensor, Tensor> rspmm_add_add_backward_cuda(
90
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
91
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
92
+
93
+ Tensor rspmm_min_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
94
+ const Tensor &relation, const Tensor &input);
95
+
96
+ std::tuple<Tensor, Tensor, Tensor> rspmm_min_add_backward_cuda(
97
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
98
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
99
+
100
+ Tensor rspmm_max_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
101
+ const Tensor &relation, const Tensor &input);
102
+
103
+ std::tuple<Tensor, Tensor, Tensor> rspmm_max_add_backward_cuda(
104
+ const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
105
+ const Tensor &input, const Tensor &output, const Tensor &output_grad);
106
+ #endif
107
+
108
+ } // namespace at
ultra/rspmm/source/util.cuh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ namespace at {
4
+
5
+ const unsigned kFullMask = 0xFFFFFFFF;
6
+
7
+ template <class scalar_t>
8
+ __device__ scalar_t warp_reduce(scalar_t value) {
9
+ #pragma unroll
10
+ for (int delta = 1; delta < warpSize; delta *= 2)
11
+ #if __CUDACC_VER_MAJOR__ >= 9
12
+ value += __shfl_down_sync(kFullMask, value, delta);
13
+ #else
14
+ value += __shfl_down(value, delta);
15
+ #endif
16
+ return value;
17
+ }
18
+
19
+ template<class scalar_t>
20
+ __device__ scalar_t warp_broadcast(scalar_t value, int lane_id) {
21
+ #if __CUDACC_VER_MAJOR__ >= 9
22
+ return __shfl_sync(kFullMask, value, lane_id);
23
+ #else
24
+ return __shfl(value, lane_id);
25
+ #endif
26
+ }
27
+
28
+ } // namespace at
ultra/tasks.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ from torch_scatter import scatter_add
3
+ from torch_geometric.data import Data
4
+ import torch
5
+
6
+
7
+ def edge_match(edge_index, query_index):
8
+ # O((n + q)logn) time
9
+ # O(n) memory
10
+ # edge_index: big underlying graph
11
+ # query_index: edges to match
12
+
13
+ # preparing unique hashing of edges, base: (max_node, max_relation) + 1
14
+ base = edge_index.max(dim=1)[0] + 1
15
+ # we will map edges to long ints, so we need to make sure the maximum product is less than MAX_LONG_INT
16
+ # idea: max number of edges = num_nodes * num_relations
17
+ # e.g. for a graph of 10 nodes / 5 relations, edge IDs 0...9 mean all possible outgoing edge types from node 0
18
+ # given a tuple (h, r), we will search for all other existing edges starting from head h
19
+ assert reduce(int.__mul__, base.tolist()) < torch.iinfo(torch.long).max
20
+ scale = base.cumprod(0)
21
+ scale = scale[-1] // scale
22
+
23
+ # hash both the original edge index and the query index to unique integers
24
+ edge_hash = (edge_index * scale.unsqueeze(-1)).sum(dim=0)
25
+ edge_hash, order = edge_hash.sort()
26
+ query_hash = (query_index * scale.unsqueeze(-1)).sum(dim=0)
27
+
28
+ # matched ranges: [start[i], end[i])
29
+ start = torch.bucketize(query_hash, edge_hash)
30
+ end = torch.bucketize(query_hash, edge_hash, right=True)
31
+ # num_match shows how many edges satisfy the (h, r) pattern for each query in the batch
32
+ num_match = end - start
33
+
34
+ # generate the corresponding ranges
35
+ offset = num_match.cumsum(0) - num_match
36
+ range = torch.arange(num_match.sum(), device=edge_index.device)
37
+ range = range + (start - offset).repeat_interleave(num_match)
38
+
39
+ return order[range], num_match
40
+
41
+
42
+ def negative_sampling(data, batch, num_negative, strict=True):
43
+ batch_size = len(batch)
44
+ pos_h_index, pos_t_index, pos_r_index = batch.t()
45
+
46
+ # strict negative sampling vs random negative sampling
47
+ if strict:
48
+ t_mask, h_mask = strict_negative_mask(data, batch)
49
+ t_mask = t_mask[:batch_size // 2]
50
+ neg_t_candidate = t_mask.nonzero()[:, 1]
51
+ num_t_candidate = t_mask.sum(dim=-1)
52
+ # draw samples for negative tails
53
+ rand = torch.rand(len(t_mask), num_negative, device=batch.device)
54
+ index = (rand * num_t_candidate.unsqueeze(-1)).long()
55
+ index = index + (num_t_candidate.cumsum(0) - num_t_candidate).unsqueeze(-1)
56
+ neg_t_index = neg_t_candidate[index]
57
+
58
+ h_mask = h_mask[batch_size // 2:]
59
+ neg_h_candidate = h_mask.nonzero()[:, 1]
60
+ num_h_candidate = h_mask.sum(dim=-1)
61
+ # draw samples for negative heads
62
+ rand = torch.rand(len(h_mask), num_negative, device=batch.device)
63
+ index = (rand * num_h_candidate.unsqueeze(-1)).long()
64
+ index = index + (num_h_candidate.cumsum(0) - num_h_candidate).unsqueeze(-1)
65
+ neg_h_index = neg_h_candidate[index]
66
+ else:
67
+ neg_index = torch.randint(data.num_nodes, (batch_size, num_negative), device=batch.device)
68
+ neg_t_index, neg_h_index = neg_index[:batch_size // 2], neg_index[batch_size // 2:]
69
+
70
+ h_index = pos_h_index.unsqueeze(-1).repeat(1, num_negative + 1)
71
+ t_index = pos_t_index.unsqueeze(-1).repeat(1, num_negative + 1)
72
+ r_index = pos_r_index.unsqueeze(-1).repeat(1, num_negative + 1)
73
+ t_index[:batch_size // 2, 1:] = neg_t_index
74
+ h_index[batch_size // 2:, 1:] = neg_h_index
75
+
76
+ return torch.stack([h_index, t_index, r_index], dim=-1)
77
+
78
+
79
+ def all_negative(data, batch):
80
+ pos_h_index, pos_t_index, pos_r_index = batch.t()
81
+ r_index = pos_r_index.unsqueeze(-1).expand(-1, data.num_nodes)
82
+ # generate all negative tails for this batch
83
+ all_index = torch.arange(data.num_nodes, device=batch.device)
84
+ h_index, t_index = torch.meshgrid(pos_h_index, all_index, indexing="ij") # indexing "xy" would return transposed
85
+ t_batch = torch.stack([h_index, t_index, r_index], dim=-1)
86
+ # generate all negative heads for this batch
87
+ all_index = torch.arange(data.num_nodes, device=batch.device)
88
+ t_index, h_index = torch.meshgrid(pos_t_index, all_index, indexing="ij")
89
+ h_batch = torch.stack([h_index, t_index, r_index], dim=-1)
90
+
91
+ return t_batch, h_batch
92
+
93
+
94
+ def strict_negative_mask(data, batch):
95
+ # this function makes sure that for a given (h, r) batch we will NOT sample true tails as random negatives
96
+ # similarly, for a given (t, r) we will NOT sample existing true heads as random negatives
97
+
98
+ pos_h_index, pos_t_index, pos_r_index = batch.t()
99
+
100
+ # part I: sample hard negative tails
101
+ # edge index of all (head, relation) edges from the underlying graph
102
+ edge_index = torch.stack([data.edge_index[0], data.edge_type])
103
+ # edge index of current batch (head, relation) for which we will sample negatives
104
+ query_index = torch.stack([pos_h_index, pos_r_index])
105
+ # search for all true tails for the given (h, r) batch
106
+ edge_id, num_t_truth = edge_match(edge_index, query_index)
107
+ # build an index from the found edges
108
+ t_truth_index = data.edge_index[1, edge_id]
109
+ sample_id = torch.arange(len(num_t_truth), device=batch.device).repeat_interleave(num_t_truth)
110
+ t_mask = torch.ones(len(num_t_truth), data.num_nodes, dtype=torch.bool, device=batch.device)
111
+ # assign 0s to the mask with the found true tails
112
+ t_mask[sample_id, t_truth_index] = 0
113
+ t_mask.scatter_(1, pos_t_index.unsqueeze(-1), 0)
114
+
115
+ # part II: sample hard negative heads
116
+ # edge_index[1] denotes tails, so the edge index becomes (t, r)
117
+ edge_index = torch.stack([data.edge_index[1], data.edge_type])
118
+ # edge index of current batch (tail, relation) for which we will sample heads
119
+ query_index = torch.stack([pos_t_index, pos_r_index])
120
+ # search for all true heads for the given (t, r) batch
121
+ edge_id, num_h_truth = edge_match(edge_index, query_index)
122
+ # build an index from the found edges
123
+ h_truth_index = data.edge_index[0, edge_id]
124
+ sample_id = torch.arange(len(num_h_truth), device=batch.device).repeat_interleave(num_h_truth)
125
+ h_mask = torch.ones(len(num_h_truth), data.num_nodes, dtype=torch.bool, device=batch.device)
126
+ # assign 0s to the mask with the found true heads
127
+ h_mask[sample_id, h_truth_index] = 0
128
+ h_mask.scatter_(1, pos_h_index.unsqueeze(-1), 0)
129
+
130
+ return t_mask, h_mask
131
+
132
+
133
+ def compute_ranking(pred, target, mask=None):
134
+ pos_pred = pred.gather(-1, target.unsqueeze(-1))
135
+ if mask is not None:
136
+ # filtered ranking
137
+ ranking = torch.sum((pos_pred <= pred) & mask, dim=-1) + 1
138
+ else:
139
+ # unfiltered ranking
140
+ ranking = torch.sum(pos_pred <= pred, dim=-1) + 1
141
+ return ranking
142
+
143
+
144
+ def build_relation_graph(graph):
145
+
146
+ # expect the graph is already with inverse edges
147
+
148
+ edge_index, edge_type = graph.edge_index, graph.edge_type
149
+ num_nodes, num_rels = graph.num_nodes, graph.num_relations
150
+ device = edge_index.device
151
+
152
+ Eh = torch.vstack([edge_index[0], edge_type]).T.unique(dim=0) # (num_edges, 2)
153
+ Dh = scatter_add(torch.ones_like(Eh[:, 1]), Eh[:, 0])
154
+
155
+ EhT = torch.sparse_coo_tensor(
156
+ torch.flip(Eh, dims=[1]).T,
157
+ torch.ones(Eh.shape[0], device=device) / Dh[Eh[:, 0]],
158
+ (num_rels, num_nodes)
159
+ )
160
+ Eh = torch.sparse_coo_tensor(
161
+ Eh.T,
162
+ torch.ones(Eh.shape[0], device=device),
163
+ (num_nodes, num_rels)
164
+ )
165
+ Et = torch.vstack([edge_index[1], edge_type]).T.unique(dim=0) # (num_edges, 2)
166
+
167
+ Dt = scatter_add(torch.ones_like(Et[:, 1]), Et[:, 0])
168
+ assert not (Dt[Et[:, 0]] == 0).any()
169
+
170
+ EtT = torch.sparse_coo_tensor(
171
+ torch.flip(Et, dims=[1]).T,
172
+ torch.ones(Et.shape[0], device=device) / Dt[Et[:, 0]],
173
+ (num_rels, num_nodes)
174
+ )
175
+ Et = torch.sparse_coo_tensor(
176
+ Et.T,
177
+ torch.ones(Et.shape[0], device=device),
178
+ (num_nodes, num_rels)
179
+ )
180
+
181
+ Ahh = torch.sparse.mm(EhT, Eh).coalesce()
182
+ Att = torch.sparse.mm(EtT, Et).coalesce()
183
+ Aht = torch.sparse.mm(EhT, Et).coalesce()
184
+ Ath = torch.sparse.mm(EtT, Eh).coalesce()
185
+
186
+ hh_edges = torch.cat([Ahh.indices().T, torch.zeros(Ahh.indices().T.shape[0], 1, dtype=torch.long).fill_(0)], dim=1) # head to head
187
+ tt_edges = torch.cat([Att.indices().T, torch.zeros(Att.indices().T.shape[0], 1, dtype=torch.long).fill_(1)], dim=1) # tail to tail
188
+ ht_edges = torch.cat([Aht.indices().T, torch.zeros(Aht.indices().T.shape[0], 1, dtype=torch.long).fill_(2)], dim=1) # head to tail
189
+ th_edges = torch.cat([Ath.indices().T, torch.zeros(Ath.indices().T.shape[0], 1, dtype=torch.long).fill_(3)], dim=1) # tail to head
190
+
191
+ rel_graph = Data(
192
+ edge_index=torch.cat([hh_edges[:, [0, 1]].T, tt_edges[:, [0, 1]].T, ht_edges[:, [0, 1]].T, th_edges[:, [0, 1]].T], dim=1),
193
+ edge_type=torch.cat([hh_edges[:, 2], tt_edges[:, 2], ht_edges[:, 2], th_edges[:, 2]], dim=0),
194
+ num_nodes=num_rels,
195
+ num_relations=4
196
+ )
197
+
198
+ graph.relation_graph = rel_graph
199
+ return graph
200
+
201
+
ultra/util.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import ast
4
+ import copy
5
+ import time
6
+ import logging
7
+ import argparse
8
+
9
+ import yaml
10
+ import jinja2
11
+ from jinja2 import meta
12
+ import easydict
13
+
14
+ import torch
15
+ from torch import distributed as dist
16
+ from torch_geometric.data import Data
17
+ from torch_geometric.datasets import RelLinkPredDataset, WordNet18RR
18
+
19
+ from ultra import models, datasets
20
+
21
+
22
+ logger = logging.getLogger(__file__)
23
+
24
+
25
+ def detect_variables(cfg_file):
26
+ with open(cfg_file, "r") as fin:
27
+ raw = fin.read()
28
+ env = jinja2.Environment()
29
+ tree = env.parse(raw)
30
+ vars = meta.find_undeclared_variables(tree)
31
+ return vars
32
+
33
+
34
+ def load_config(cfg_file, context=None):
35
+ with open(cfg_file, "r") as fin:
36
+ raw = fin.read()
37
+ template = jinja2.Template(raw)
38
+ instance = template.render(context)
39
+ cfg = yaml.safe_load(instance)
40
+ cfg = easydict.EasyDict(cfg)
41
+ return cfg
42
+
43
+
44
+ def literal_eval(string):
45
+ try:
46
+ return ast.literal_eval(string)
47
+ except (ValueError, SyntaxError):
48
+ return string
49
+
50
+
51
+ def parse_args():
52
+ parser = argparse.ArgumentParser()
53
+ parser.add_argument("-c", "--config", help="yaml configuration file", required=True)
54
+ parser.add_argument("-s", "--seed", help="random seed for PyTorch", type=int, default=1024)
55
+
56
+ args, unparsed = parser.parse_known_args()
57
+ # get dynamic arguments defined in the config file
58
+ vars = detect_variables(args.config)
59
+ parser = argparse.ArgumentParser()
60
+ for var in vars:
61
+ parser.add_argument("--%s" % var, required=True)
62
+ vars = parser.parse_known_args(unparsed)[0]
63
+ vars = {k: literal_eval(v) for k, v in vars._get_kwargs()}
64
+
65
+ return args, vars
66
+
67
+
68
+ def get_root_logger(file=True):
69
+ format = "%(asctime)-10s %(message)s"
70
+ datefmt = "%H:%M:%S"
71
+ logging.basicConfig(format=format, datefmt=datefmt)
72
+ logger = logging.getLogger("")
73
+ logger.setLevel(logging.INFO)
74
+
75
+ if file:
76
+ handler = logging.FileHandler("log.txt")
77
+ format = logging.Formatter(format, datefmt)
78
+ handler.setFormatter(format)
79
+ logger.addHandler(handler)
80
+
81
+ return logger
82
+
83
+
84
+ def get_rank():
85
+ if dist.is_initialized():
86
+ return dist.get_rank()
87
+ if "RANK" in os.environ:
88
+ return int(os.environ["RANK"])
89
+ return 0
90
+
91
+
92
+ def get_world_size():
93
+ if dist.is_initialized():
94
+ return dist.get_world_size()
95
+ if "WORLD_SIZE" in os.environ:
96
+ return int(os.environ["WORLD_SIZE"])
97
+ return 1
98
+
99
+
100
+ def synchronize():
101
+ if get_world_size() > 1:
102
+ dist.barrier()
103
+
104
+
105
+ def get_device(cfg):
106
+ if cfg.train.gpus:
107
+ device = torch.device(cfg.train.gpus[get_rank()])
108
+ else:
109
+ device = torch.device("cpu")
110
+ return device
111
+
112
+ def get_devices(gpus):
113
+ if gpus is not None:
114
+ device = torch.device(gpus[get_rank()])
115
+ else:
116
+ device = torch.device("cpu")
117
+ return device
118
+
119
+
120
+ def create_working_directory(cfg):
121
+ file_name = "working_dir.tmp"
122
+ world_size = get_world_size()
123
+ if cfg.train.gpus is not None and len(cfg.train.gpus) != world_size:
124
+ error_msg = "World size is %d but found %d GPUs in the argument"
125
+ if world_size == 1:
126
+ error_msg += ". Did you launch with `python -m torch.distributed.launch`?"
127
+ raise ValueError(error_msg % (world_size, len(cfg.train.gpus)))
128
+ if world_size > 1 and not dist.is_initialized():
129
+ dist.init_process_group("nccl", init_method="env://")
130
+
131
+ working_dir = os.path.join(os.path.expanduser(cfg.output_dir),
132
+ cfg.model["class"], cfg.dataset["class"], time.strftime("%Y-%m-%d-%H-%M-%S"))
133
+
134
+ # synchronize working directory
135
+ if get_rank() == 0:
136
+ with open(file_name, "w") as fout:
137
+ fout.write(working_dir)
138
+ os.makedirs(working_dir)
139
+ synchronize()
140
+ if get_rank() != 0:
141
+ with open(file_name, "r") as fin:
142
+ working_dir = fin.read()
143
+ synchronize()
144
+ if get_rank() == 0:
145
+ os.remove(file_name)
146
+
147
+ os.chdir(working_dir)
148
+ return working_dir
149
+
150
+
151
+ def build_dataset(cfg):
152
+ data_config = copy.deepcopy(cfg.dataset)
153
+ cls = data_config.pop("class")
154
+
155
+ ds_cls = getattr(datasets, cls)
156
+ dataset = ds_cls(**data_config)
157
+
158
+ if get_rank() == 0:
159
+ logger.warning("%s dataset" % (cls if "version" not in cfg.dataset else f'{cls}({cfg.dataset.version})'))
160
+ if cls != "JointDataset":
161
+ logger.warning("#train: %d, #valid: %d, #test: %d" %
162
+ (dataset[0].target_edge_index.shape[1], dataset[1].target_edge_index.shape[1],
163
+ dataset[2].target_edge_index.shape[1]))
164
+ else:
165
+ logger.warning("#train: %d, #valid: %d, #test: %d" %
166
+ (sum(d.target_edge_index.shape[1] for d in dataset._data[0]),
167
+ sum(d.target_edge_index.shape[1] for d in dataset._data[1]),
168
+ sum(d.target_edge_index.shape[1] for d in dataset._data[2]),
169
+ ))
170
+
171
+ return dataset
172
+