Skip to content

Commit

Permalink
updated wrapper modularity
Browse files Browse the repository at this point in the history
  • Loading branch information
levtelyatnikov committed May 13, 2024
1 parent 8ed265d commit 74fcc68
Show file tree
Hide file tree
Showing 35 changed files with 427 additions and 43 deletions.
3 changes: 2 additions & 1 deletion configs/model/cell/can.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ backbone:
att_lift: False

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.CANWrapper
_target_: topobenchmarkx.models.wrappers.CANWrapper
_partial_: true
wrapper_name: CANWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/cell/ccxn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ backbone_additional_params:
hidden_channels: ${model.feature_encoder.out_channels}

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.CCXNWrapper
_target_: topobenchmarkx.models.wrappers.CCXNWrapper
_partial_: true
wrapper_name: CCXNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/cell/cwn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ backbone:
n_layers: 1

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.CWNWrapper
_target_: topobenchmarkx.models.wrappers.CWNWrapper
_partial_: true
wrapper_name: CWNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/cell/cwn_dcm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ backbone:
dropout: 0.0

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.CWNDCMWrapper
_target_: topobenchmarkx.models.wrappers.CWNDCMWrapper
_partial_: true
wrapper_name: CWNDCMWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}

Expand Down
4 changes: 2 additions & 2 deletions configs/model/graph/gat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ backbone:
concat: true

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.GNNWrapper
_target_: topobenchmarkx.models.wrappers.GNNWrapper
_partial_: true
wrapper_name: GNNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand All @@ -35,7 +36,6 @@ head_model:
out_channels: ${dataset.parameters.num_classes}
pooling_type: sum


loss:
_target_: topobenchmarkx.models.losses.loss.DefaultLoss
task: ${dataset.parameters.task}
Expand Down
3 changes: 2 additions & 1 deletion configs/model/graph/gcn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ backbone:
act: relu

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.GNNWrapper
_target_: topobenchmarkx.models.wrappers.GNNWrapper
_partial_: true
wrapper_name: GNNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand Down
4 changes: 2 additions & 2 deletions configs/model/graph/gin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ backbone:
act: relu

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.GNNWrapper
_target_: topobenchmarkx.models.wrappers.GNNWrapper
_partial_: true
wrapper_name: GNNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand All @@ -32,7 +33,6 @@ head_model:
out_channels: ${dataset.parameters.num_classes}
pooling_type: sum


loss:
_target_: topobenchmarkx.models.losses.loss.DefaultLoss
task: ${dataset.parameters.task}
Expand Down
3 changes: 2 additions & 1 deletion configs/model/hypergraph/alldeepset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ backbone:
#num_features: ${model.backbone.hidden_channels}

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper
_target_: topobenchmarkx.models.wrappers.HypergraphWrapper
_partial_: true
wrapper_name: HypergraphWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/hypergraph/allsettransformer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ backbone:
mlp_dropout: 0.

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper
_target_: topobenchmarkx.models.wrappers.HypergraphWrapper
_partial_: true
wrapper_name: HypergraphWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/hypergraph/edgnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ backbone:
aggregate: 'add'

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper
_target_: topobenchmarkx.models.wrappers.HypergraphWrapper
_partial_: true
wrapper_name: HypergraphWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/hypergraph/unignn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ backbone:
n_layers: 1

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper
_target_: topobenchmarkx.models.wrappers.HypergraphWrapper
_partial_: true
wrapper_name: HypergraphWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/hypergraph/unignn2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ backbone:
layer_drop: 0.2

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper
_target_: topobenchmarkx.models.wrappers.HypergraphWrapper
_partial_: true
wrapper_name: HypergraphWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/simplicial/san.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ backbone:
epsilon_harmonic: 1e-1

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.SANWrapper
_target_: topobenchmarkx.models.wrappers.SANWrapper
_partial_: true
wrapper_name: SANWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/simplicial/sccn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ backbone:
update_func: "sigmoid"

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNWrapper
_target_: topobenchmarkx.models.wrappers.SCCNWrapper
_partial_: true
wrapper_name: SCCNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/simplicial/sccnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ backbone:
n_layers: 1

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNNWrapper
_target_: topobenchmarkx.models.wrappers.SCCNNWrapper
_partial_: true
wrapper_name: SCCNNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/simplicial/sccnn_custom.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ backbone:
n_layers: 1

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNNWrapper
_target_: topobenchmarkx.models.wrappers.SCCNNWrapper
_partial_: true
wrapper_name: SCCNNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}

Expand Down
3 changes: 2 additions & 1 deletion configs/model/simplicial/scn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ backbone:
n_layers: 1

backbone_wrapper:
_target_: topobenchmarkx.models.wrappers.default_wrapper.SCNWrapper
_target_: topobenchmarkx.models.wrappers.SCNWrapper
_partial_: true
wrapper_name: SCNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}

Expand Down
2 changes: 1 addition & 1 deletion configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
defaults:
- _self_
- dataset: PROTEINS_TU #us_country_demos
- model: simplicial/sccn #hypergraph/unignn2 #allsettransformer
- model: graph/gcn #hypergraph/unignn2 #allsettransformer
- evaluator: default
- callbacks: default
- logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
Expand Down
49 changes: 26 additions & 23 deletions topobenchmarkx/models/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
import hydra # noqa: F401
import torch
from omegaconf import DictConfig # noqa: F401
from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper
from topobenchmarkx.models.wrappers.graph import GNNWrapper
from topobenchmarkx.models.wrappers.hypergraph import HypergraphWrapper
from topobenchmarkx.models.wrappers.simplicial import SANWrapper, SCNWrapper, SCCNNWrapper, SCCNWrapper
from topobenchmarkx.models.wrappers.cell import CANWrapper, CWNDCMWrapper, CWNWrapper, CCXNWrapper

# ... import other readout classes here
# For example:
# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1
# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2

class DefaultLoss:
"""Abstract class that provides an interface to loss logic within
netowrk."""

def __init__(self, task):
if task == "classification":
self.criterion = torch.nn.CrossEntropyLoss()

elif task == "regression":
self.criterion = torch.nn.mse()
else:
raise Exception("Loss is not defined")

def __call__(self, model_output):
"""Loss logic based on model_output."""

logits = model_output["logits"]
target = model_output["labels"]
model_output["loss"] = self.criterion(logits, target)

return model_output
# Export all wrappers
__all__ = [
"DefaultWrapper",
"GNNWrapper",
"HypergraphWrapper",
"SANWrapper",
"SCNWrapper",
"SCCNNWrapper",
"SCCNWrapper",
"CANWrapper",
"CWNDCMWrapper",
"CWNWrapper",
"CCXNWrapper",
# "OtherWrapper1",
# "OtherWrapper2",
# ... add other readout classes here
]
20 changes: 20 additions & 0 deletions topobenchmarkx/models/wrappers/cell/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from topobenchmarkx.models.wrappers.cell.can_wrapper import CANWrapper
from topobenchmarkx.models.wrappers.cell.cwndcm_wrapper import CWNDCMWrapper
from topobenchmarkx.models.wrappers.cell.cwn_wrapper import CWNWrapper
from topobenchmarkx.models.wrappers.cell.ccxn_wrapper import CCXNWrapper

# ... import other readout classes here
# For example:
# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1
# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2

__all__ = [
"CANWrapper",
"CWNDCMWrapper",
"CWNWrapper",
"CCXNWrapper",

# "OtherWrapper1",
# "OtherWrapper2",
# ... add other readout classes here
]
22 changes: 22 additions & 0 deletions topobenchmarkx/models/wrappers/cell/can_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch
from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper

class CANWrapper(DefaultWrapper):
"""Abstract class that provides an interface to loss logic within
network."""

def forward(self, batch):
"""Define logic for forward pass."""

x_1 = self.backbone(
x_0=batch.x_0,
x_1=batch.x_1,
adjacency_0=batch.adjacency_0.coalesce(),
down_laplacian_1=batch.down_laplacian_1.coalesce(),
up_laplacian_1=batch.up_laplacian_1.coalesce(),
)

model_out = {"labels": batch.y, "batch_0": batch.batch_0}
model_out["x_1"] = x_1
model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1)
return model_out
21 changes: 21 additions & 0 deletions topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper

class CCXNWrapper(DefaultWrapper):
"""Abstract class that provides an interface to loss logic within
network."""

def forward(self, batch):
"""Define logic for forward pass."""

x_0, x_1, x_2 = self.backbone(
x_0=batch.x_0,
x_1=batch.x_1,
adjacency_0=batch.adjacency_0,
incidence_2_t=batch.incidence_2.T,
)

model_out = {"labels": batch.y, "batch_0": batch.batch_0}
model_out["x_0"] = x_0
model_out["x_1"] = x_1
model_out["x_2"] = x_2
return model_out
23 changes: 23 additions & 0 deletions topobenchmarkx/models/wrappers/cell/cwn_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper

class CWNWrapper(DefaultWrapper):
"""Abstract class that provides an interface to loss logic within
network."""

def forward(self, batch):
"""Define logic for forward pass."""

x_0, x_1, x_2 = self.backbone(
x_0=batch.x_0,
x_1=batch.x_1,
x_2=batch.x_2,
incidence_1_t=batch.incidence_1.T,
adjacency_0=batch.adjacency_1,
incidence_2=batch.incidence_2,
)

model_out = {"labels": batch.y, "batch_0": batch.batch_0}
model_out["x_0"] = x_0
model_out["x_1"] = x_1
model_out["x_2"] = x_2
return model_out
21 changes: 21 additions & 0 deletions topobenchmarkx/models/wrappers/cell/cwndcm_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper

class CWNDCMWrapper(DefaultWrapper):
"""Abstract class that provides an interface to loss logic within
network."""

def forward(self, batch):
"""Define logic for forward pass."""

x_1 = self.backbone(
batch.x_1,
batch.down_laplacian_1.coalesce(),
batch.up_laplacian_1.coalesce(),
)

model_out = {"labels": batch.y, "batch_0": batch.batch_0}

model_out["x_1"] = x_1
model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1)
return model_out
15 changes: 15 additions & 0 deletions topobenchmarkx/models/wrappers/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from topobenchmarkx.models.wrappers.graph.gnn_wrapper import GNNWrapper

# ... import other readout classes here
# For example:
# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1
# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2

# Export all wrappers
__all__ = [
"GNNWrapper",

# "OtherWrapper1",
# "OtherWrapper2",
# ... add other readout classes here
]
Loading

0 comments on commit 74fcc68

Please sign in to comment.