
Commit 5922e1e

clean stagate
1 parent df94f93 commit 5922e1e

16 files changed  (+3434, -74 lines)

.gitignore  (-1)

@@ -20,7 +20,6 @@ site
 outputs
 multirun
 lightning_logs
-novae_*
 
 # Data files
 data/*

benchmark/model/build.py  (-32)

This file was deleted.

novae_benchmark/__init__.py  (+1)

@@ -0,0 +1 @@
+from .model import MODEL_DICT, get_model
File renamed without changes.
File renamed without changes.
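With the re-export added to novae_benchmark/__init__.py above, downstream code can pull the model registry straight from the package root instead of from the old benchmark.model layout. The snippet below is illustrative only: get_model's signature is not shown in this commit, and treating MODEL_DICT as a name-to-constructor mapping is an assumption of mine, not something the diff states.

    # Illustrative only: MODEL_DICT is assumed to be a dict keyed by model name;
    # get_model's exact signature is not part of this commit.
    from novae_benchmark import MODEL_DICT, get_model

    print(sorted(MODEL_DICT))  # assumed: lists the registered benchmark models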

benchmark/model/STAGATE_pyG/gat_conv.py → novae_benchmark/model/STAGATE_pyG/gat_conv.py  (+43, -32)

@@ -1,17 +1,16 @@
-from typing import Union, Tuple, Optional
-from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
-                                    OptTensor)
+from typing import Optional, Tuple, Union
 
 import torch
-from torch import Tensor
+import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from torch.nn import Parameter
-import torch.nn as nn
-from torch_sparse import SparseTensor, set_diag
-from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.conv import MessagePassing
-from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
 
+# from torch_sparse import SparseTensor, set_diag
+from torch_geometric.nn.dense.linear import Linear
+from torch_geometric.typing import Adj, NoneType, OptPairTensor, OptTensor, Size
+from torch_geometric.utils import add_self_loops, remove_self_loops, softmax
 
 
 class GATConv(MessagePassing):

@@ -58,13 +57,22 @@ class GATConv(MessagePassing):
         **kwargs (optional): Additional arguments of
             :class:`torch_geometric.nn.conv.MessagePassing`.
     """
+
     _alpha: OptTensor
 
-    def __init__(self, in_channels: Union[int, Tuple[int, int]],
-                 out_channels: int, heads: int = 1, concat: bool = True,
-                 negative_slope: float = 0.2, dropout: float = 0.0,
-                 add_self_loops: bool = True, bias: bool = True, **kwargs):
-        kwargs.setdefault('aggr', 'add')
+    def __init__(
+        self,
+        in_channels: Union[int, Tuple[int, int]],
+        out_channels: int,
+        heads: int = 1,
+        concat: bool = True,
+        negative_slope: float = 0.2,
+        dropout: float = 0.0,
+        add_self_loops: bool = True,
+        bias: bool = True,
+        **kwargs,
+    ):
+        kwargs.setdefault("aggr", "add")
         super(GATConv, self).__init__(node_dim=0, **kwargs)
 
         self.in_channels = in_channels

@@ -91,7 +99,6 @@ def __init__(self, in_channels: Union[int, Tuple[int, int]],
         nn.init.xavier_normal_(self.lin_src.data, gain=1.414)
         self.lin_dst = self.lin_src
 
-
         # The learnable parameters to compute attention coefficients:
         self.att_src = Parameter(torch.Tensor(1, heads, out_channels))
         self.att_dst = Parameter(torch.Tensor(1, heads, out_channels))

@@ -117,12 +124,15 @@ def __init__(self, in_channels: Union[int, Tuple[int, int]],
     # glorot(self.att_dst)
     # # zeros(self.bias)
 
-    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
-                size: Size = None, return_attention_weights=None, attention=True, tied_attention = None):
-        # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
-        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
-        # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
-        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
+    def forward(
+        self,
+        x: Union[Tensor, OptPairTensor],
+        edge_index: Adj,
+        size: Size = None,
+        return_attention_weights=None,
+        attention=True,
+        tied_attention=None,
+    ):
         r"""
         Args:
             return_attention_weights (bool, optional): If set to :obj:`True`,

@@ -161,7 +171,6 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
         else:
             alpha = tied_attention
 
-
         if self.add_self_loops:
             if isinstance(edge_index, Tensor):
                 # We only want to add self-loops for nodes that appear both as

@@ -172,8 +181,10 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                 num_nodes = min(size) if size is not None else num_nodes
                 edge_index, _ = remove_self_loops(edge_index)
                 edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
-            elif isinstance(edge_index, SparseTensor):
-                edge_index = set_diag(edge_index)
+            else:
+                raise ValueError(f"Received invalid type {type(edge_index)}")
+            # elif isinstance(edge_index, SparseTensor):
+            #     edge_index = set_diag(edge_index)
 
         # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
         out = self.propagate(edge_index, x=x, alpha=alpha, size=size)

@@ -193,26 +204,26 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
         if isinstance(return_attention_weights, bool):
             if isinstance(edge_index, Tensor):
                 return out, (edge_index, alpha)
-            elif isinstance(edge_index, SparseTensor):
-                return out, edge_index.set_value(alpha, layout='coo')
+            else:
+                raise ValueError(f"Received invalid type {type(edge_index)}")
+            # elif isinstance(edge_index, SparseTensor):
+            #     return out, edge_index.set_value(alpha, layout="coo")
         else:
             return out
 
-    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
-                index: Tensor, ptr: OptTensor,
-                size_i: Optional[int]) -> Tensor:
+    def message(
+        self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]
+    ) -> Tensor:
         # Given egel-level attention coefficients for source and target nodes,
         # we simply need to sum them up to "emulate" concatenation:
         alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
 
-        #alpha = F.leaky_relu(alpha, self.negative_slope)
+        # alpha = F.leaky_relu(alpha, self.negative_slope)
         alpha = torch.sigmoid(alpha)
         alpha = softmax(alpha, index, ptr, size_i)
         self._alpha = alpha  # Save for later use.
         alpha = F.dropout(alpha, p=self.dropout, training=self.training)
         return x_j * alpha.unsqueeze(-1)
 
     def __repr__(self):
-        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
-                                             self.in_channels,
-                                             self.out_channels, self.heads)
+        return "{}({}, {}, heads={})".format(self.__class__.__name__, self.in_channels, self.out_channels, self.heads)

novae_benchmark/model/__init__.py  (+2)

@@ -0,0 +1,2 @@
+from . import STAGATE_pyG as STAGATE
+from .trainer import MODEL_DICT, get_model

benchmark/trainer.py → novae_benchmark/model/trainer.py  (+4, -9)

@@ -1,10 +1,10 @@
 import numpy as np
 import pandas as pd
 import scanpy as sc
-import STAGATE_pyG as STAGATE
-import torch
 from anndata import AnnData
 
+from . import STAGATE
+
 
 class Model:
     def __init__(self, model_name: str, hidden_dim: int) -> None:

@@ -32,18 +32,13 @@ def train(self, adata: AnnData, batch_key: str | None, device: str = "cpu"):
         if batch_key is None:
             STAGATE.Cal_Spatial_Net(adata, rad_cutoff=self.RAD_CUTOFF)
         else:
-            adatas = [
-                adata[adata.obs[batch_key] == b].copy()
-                for b in adata.obs[batch_key].unique()
-            ]
+            adatas = [adata[adata.obs[batch_key] == b].copy() for b in adata.obs[batch_key].unique()]
             for adata_ in adatas:
                 print("Batch:", adata_.obs[batch_key][0])
                 STAGATE.Cal_Spatial_Net(adata_, rad_cutoff=self.RAD_CUTOFF)
 
             adata = sc.concat(adatas)
-            adata.uns["Spatial_Net"] = pd.concat(
-                [adata_.uns["Spatial_Net"] for adata_ in adatas]
-            )
+            adata.uns["Spatial_Net"] = pd.concat([adata_.uns["Spatial_Net"] for adata_ in adatas])
             print("\nConcatenated:", adata)
 
         adata = STAGATE.train_STAGATE(adata, key_added="STAGATE", device=device)
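The multi-batch branch builds one spatial graph per batch and then reattaches the concatenated Spatial_Net tables by hand, because sc.concat does not carry .uns across. A toy sketch of that pattern follows; the column names and the stand-in for STAGATE.Cal_Spatial_Net are mine, not from the commit:

    # Toy sketch of the per-batch pattern in Model.train; Cal_Spatial_Net is
    # replaced by a stand-in that writes a dummy edge table into .uns.
    import anndata as ad
    import numpy as np
    import pandas as pd
    import scanpy as sc

    adata = ad.AnnData(X=np.random.rand(6, 3).astype(np.float32))
    adata.obs["batch"] = ["a", "a", "a", "b", "b", "b"]  # hypothetical batch column

    adatas = [adata[adata.obs["batch"] == b].copy() for b in adata.obs["batch"].unique()]
    for adata_ in adatas:
        # stand-in for STAGATE.Cal_Spatial_Net(adata_, rad_cutoff=...)
        adata_.uns["Spatial_Net"] = pd.DataFrame(
            {"Cell1": adata_.obs_names[:-1], "Cell2": adata_.obs_names[1:]}
        )

    merged = sc.concat(adatas)                  # .uns is dropped by sc.concat
    merged.uns["Spatial_Net"] = pd.concat([a.uns["Spatial_Net"] for a in adatas])
    print(merged.uns["Spatial_Net"])            # one edge table covering both batches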

novae_benchmark/utils.py

Whitespace-only changes.
