clean
claying committed May 12, 2023
1 parent 49c5494 commit 70ff2ec
Showing 3 changed files with 1 addition and 16 deletions.
fie/layers.py (6 changes: 0 additions & 6 deletions)
@@ -18,7 +18,6 @@ def __init__(self, input_size, output_size, num_mixtures=8, num_heads=1, residue
                  out_proj='kernel', kernel='exp', out_proj_args=0.5, use_deg=False):
         super().__init__(node_dim=0, aggr='add')#, flow='target_to_source')
         self.input_size = input_size
-        # self.hidden_size = hidden_size
         self.output_size = output_size
         self.num_mixtures = num_mixtures
         self.num_heads = num_heads
@@ -67,9 +66,6 @@ def forward(self, x, edge_index, edge_attr=None, deg_sqrt=None, before_out_proj=
         out = out - rearrange(self.weight, "p h d -> 1 d (p h)")
         out = rearrange(out, "n d p -> n (p d)")
 
-        # attn = torch.sparse_coo_tensor(edge_index, self._attn.squeeze()).to_dense()
-        # print(attn)
-
         if before_out_proj:
             return out
 
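
Note: the rearrange calls in this hunk only permute and reshape axes; no arithmetic happens. A toy-shape illustration of the two patterns used above (variable names are illustrative, not from the repository):

import torch
from einops import rearrange

out = torch.randn(4, 8, 6)                  # (n, d, p*h) with p*h = 6
flat = rearrange(out, "n d p -> n (p d)")   # (4, 48): p varies slowest
w = torch.randn(2, 3, 8)                    # (p, h, d)
row = rearrange(w, "p h d -> 1 d (p h)")    # (1, 8, 6), broadcastable over n
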
@@ -92,7 +88,6 @@ def message(self, x_i, x_j, edge_attr, index, ptr, size_i):
         alpha_ij = utils.softmax(alpha_ij, index, ptr, size_i)
         x_j = rearrange(x_j, "n d -> n d 1")
         alpha_ij = rearrange(alpha_ij, "n p h -> n 1 (p h)")
-        # self._attn = alpha_ij
 
         return x_j * alpha_ij
 
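
Note: the message hook normalizes raw attention scores across the edges that share a destination node before weighting neighbor features. A minimal 1-D sketch of that segment-wise softmax (edge_softmax is a hypothetical stand-in for the utils.softmax(alpha_ij, index, ptr, size_i) call; the real helper also handles multi-head shapes and CSR pointers):

import torch

def edge_softmax(scores, dst, num_nodes):
    # Normalize edge scores over edges that share a destination node.
    scores = scores - scores.max()  # global shift, keeps per-group ratios
    w = scores.exp()
    denom = torch.zeros(num_nodes, device=scores.device).scatter_add_(0, dst, w)
    return w / denom[dst].clamp_min(1e-16)

# Three edges into node 0, two into node 1: each group sums to 1.
alpha = edge_softmax(torch.randn(5), torch.tensor([0, 0, 0, 1, 1]), num_nodes=2)
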
@@ -110,7 +105,6 @@ def forward_proj(self, x, edge_index, deg_sqrt=None):
     def sample(self, x, edge_index, n_samples=1000):
         indices = torch.randperm(edge_index.shape[1])[:min(edge_index.shape[1], n_samples)]
         edge_index = edge_index[:, indices]
-        # x_feat = self.feature_transform(x[edge_index[0]], x[edge_index[1]])
         x_feat = self.feature_transform(x[edge_index[1]], x[edge_index[0]])
         return x_feat
 
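
Note: sample draws at most n_samples edges uniformly without replacement, then builds pair features with the target endpoint first (x[edge_index[1]], then x[edge_index[0]]), presumably so the pair matches the (x_i, x_j) target/source ordering of message() above. The subsampling step on its own, as a toy example with made-up sizes:

import torch

edge_index = torch.randint(0, 10, (2, 40))  # toy (2, E) edge list
n_samples = 8
perm = torch.randperm(edge_index.shape[1])[:min(edge_index.shape[1], n_samples)]
sampled = edge_index[:, perm]               # (2, 8): uniform, no replacement
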
fie/models.py (10 changes: 1 addition & 9 deletions)
@@ -56,14 +56,8 @@ def forward(self, data):
         outputs = [x]
         output = x
         for i, mod in enumerate(self.layers):
-            # if i == self.num_layers - 1:
-            # output = mod(outputs[-1], edge_index, edge_attr, deg_sqrt=deg_sqrt, before_out_proj=True)
-            # else:
-            # output = mod(output, edge_index, edge_attr, deg_sqrt=deg_sqrt)
             output = mod(output, edge_index, edge_attr, deg_sqrt=deg_sqrt)
-            # output = mod(output, edge_index, edge_attr, deg_sqrt=deg_sqrt, before_out_proj=True)
             outputs.append(output)
-            # output = mod.out_proj(output) * deg_sqrt.view(-1, 1)
 
         if self.concat:
             output = torch.cat(outputs, dim=-1)
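
Note: the loop keeps every layer's output and, when self.concat is set, concatenates them along the feature axis (a jumping-knowledge-style readout). A minimal self-contained sketch of the pattern, with nn.Linear standing in for the graph layers (an illustration, not the repository's model):

import torch
import torch.nn as nn

class ConcatStack(nn.Module):
    def __init__(self, dim, num_layers, concat=True):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.concat = concat

    def forward(self, x):
        outputs = [x]
        for mod in self.layers:
            x = mod(x).relu()
            outputs.append(x)
        # Feature dim becomes dim * (num_layers + 1) when concatenating.
        return torch.cat(outputs, dim=-1) if self.concat else x

model = ConcatStack(dim=16, num_layers=2)
y = model(torch.randn(5, 16))  # (5, 48)
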
Expand Down Expand Up @@ -252,7 +246,6 @@ def __init__(self, num_class, input_size, num_layers=2,
self.num_layers = num_layers
self.concat = concat

# self.in_head = nn.Linear(input_size, hidden_size)
self.in_head = KernelLayer(input_size, hidden_size, sigma=out_proj_args)

layers = []
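
Note: the commit keeps KernelLayer, rather than a plain nn.Linear, as the input head; its actual definition lives in fie/layers.py. Purely to illustrate what a kernel-feature head with a sigma bandwidth can look like (an assumption, not the repository's implementation), one common form scores inputs against learned anchor points with a Gaussian kernel:

import torch
import torch.nn as nn

class GaussianKernelHead(nn.Module):
    # Illustrative only: maps x to kernel similarities against
    # hidden_size learned anchor points. The real KernelLayer may differ.
    def __init__(self, input_size, hidden_size, sigma=0.5):
        super().__init__()
        self.anchors = nn.Parameter(torch.randn(hidden_size, input_size))
        self.sigma = sigma

    def forward(self, x):
        d2 = torch.cdist(x, self.anchors).pow(2)      # squared distances
        return torch.exp(-d2 / (2 * self.sigma ** 2))
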
@@ -437,7 +430,7 @@ def forward(self, data):
     def forward_ns(self, x, adjs, batch_size):
         x = self.in_head(x)
 
-        outputs = x#[:batch_size]
+        outputs = x
         for i, (edge_index, _, size) in enumerate(adjs):
             x_target = x[:size[1]]
             x = self.layers[i]((x, x_target), edge_index)
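
Note: forward_ns follows the PyTorch Geometric NeighborSampler convention: each sampled bipartite block carries a size = (num_source, num_target) pair, and the target nodes are the first size[1] rows of x, which is why the layer is called on the tuple (x, x_target). Shape bookkeeping for one hop, with made-up sizes:

import torch

x = torch.randn(100, 64)   # features of all nodes sampled for this hop
size = (100, 25)           # (num_source, num_target) for the bipartite block
x_target = x[:size[1]]     # (25, 64): target nodes are listed first
# self.layers[i]((x, x_target), edge_index) then yields (25, 64) outputs
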
@@ -451,7 +444,6 @@ def forward_ns(self, x, adjs, batch_size):
     def inference_ns(self, x_all, subgraph_loader):
         device = x_all.device
         x_all = self.in_head(x_all)
-        # outputs = [x_all.cpu()]
         outputs = x_all.cpu()
         for i, mod in enumerate(self.layers):
             xs = []
fie/utils.py (1 change: 0 additions & 1 deletion)
@@ -89,7 +89,6 @@ def spherical_kmeans(x, n_clusters, max_iters=100, verbose=True,
         for j in range(n_clusters):
             index = assign == j
             if index.sum() == 0:
-                # clusters[j] = x[random.randrange(n_samples)]
                 idx = tmp.argmin()
                 clusters[j] = x[idx]
                 tmp[idx] = 1
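
Note: spherical k-means clusters L2-normalized vectors by cosine similarity, so the assignment step is a matrix product followed by an argmax; the surviving branch above reseeds an empty cluster deterministically from the sample at tmp.argmin() instead of the deleted random draw. A hedged sketch of the assignment step only (hypothetical helper, not the repository's exact code):

import torch
import torch.nn.functional as F

def spherical_assign(x, clusters):
    # Cosine similarity reduces to a matrix product on unit vectors.
    sim = F.normalize(x, dim=1) @ F.normalize(clusters, dim=1).t()
    return sim.argmax(dim=1), sim.max(dim=1).values  # labels, best similarity
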
