-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8ed265d
commit 74fcc68
Showing
35 changed files
with
427 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,29 @@ | ||
import hydra # noqa: F401 | ||
import torch | ||
from omegaconf import DictConfig # noqa: F401 | ||
from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper | ||
from topobenchmarkx.models.wrappers.graph import GNNWrapper | ||
from topobenchmarkx.models.wrappers.hypergraph import HypergraphWrapper | ||
from topobenchmarkx.models.wrappers.simplicial import SANWrapper, SCNWrapper, SCCNNWrapper, SCCNWrapper | ||
from topobenchmarkx.models.wrappers.cell import CANWrapper, CWNDCMWrapper, CWNWrapper, CCXNWrapper | ||
|
||
# ... import other readout classes here | ||
# For example: | ||
# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 | ||
# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 | ||
|
||
class DefaultLoss: | ||
"""Abstract class that provides an interface to loss logic within | ||
netowrk.""" | ||
|
||
def __init__(self, task): | ||
if task == "classification": | ||
self.criterion = torch.nn.CrossEntropyLoss() | ||
|
||
elif task == "regression": | ||
self.criterion = torch.nn.mse() | ||
else: | ||
raise Exception("Loss is not defined") | ||
|
||
def __call__(self, model_output): | ||
"""Loss logic based on model_output.""" | ||
|
||
logits = model_output["logits"] | ||
target = model_output["labels"] | ||
model_output["loss"] = self.criterion(logits, target) | ||
|
||
return model_output | ||
# Export all wrappers | ||
__all__ = [ | ||
"DefaultWrapper", | ||
"GNNWrapper", | ||
"HypergraphWrapper", | ||
"SANWrapper", | ||
"SCNWrapper", | ||
"SCCNNWrapper", | ||
"SCCNWrapper", | ||
"CANWrapper", | ||
"CWNDCMWrapper", | ||
"CWNWrapper", | ||
"CCXNWrapper", | ||
# "OtherWrapper1", | ||
# "OtherWrapper2", | ||
# ... add other readout classes here | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from topobenchmarkx.models.wrappers.cell.can_wrapper import CANWrapper | ||
from topobenchmarkx.models.wrappers.cell.cwndcm_wrapper import CWNDCMWrapper | ||
from topobenchmarkx.models.wrappers.cell.cwn_wrapper import CWNWrapper | ||
from topobenchmarkx.models.wrappers.cell.ccxn_wrapper import CCXNWrapper | ||
|
||
# ... import other readout classes here | ||
# For example: | ||
# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 | ||
# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 | ||
|
||
__all__ = [ | ||
"CANWrapper", | ||
"CWNDCMWrapper", | ||
"CWNWrapper", | ||
"CCXNWrapper", | ||
|
||
# "OtherWrapper1", | ||
# "OtherWrapper2", | ||
# ... add other readout classes here | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import torch | ||
from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper | ||
|
||
class CANWrapper(DefaultWrapper): | ||
"""Abstract class that provides an interface to loss logic within | ||
network.""" | ||
|
||
def forward(self, batch): | ||
"""Define logic for forward pass.""" | ||
|
||
x_1 = self.backbone( | ||
x_0=batch.x_0, | ||
x_1=batch.x_1, | ||
adjacency_0=batch.adjacency_0.coalesce(), | ||
down_laplacian_1=batch.down_laplacian_1.coalesce(), | ||
up_laplacian_1=batch.up_laplacian_1.coalesce(), | ||
) | ||
|
||
model_out = {"labels": batch.y, "batch_0": batch.batch_0} | ||
model_out["x_1"] = x_1 | ||
model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) | ||
return model_out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper | ||
|
||
class CCXNWrapper(DefaultWrapper): | ||
"""Abstract class that provides an interface to loss logic within | ||
network.""" | ||
|
||
def forward(self, batch): | ||
"""Define logic for forward pass.""" | ||
|
||
x_0, x_1, x_2 = self.backbone( | ||
x_0=batch.x_0, | ||
x_1=batch.x_1, | ||
adjacency_0=batch.adjacency_0, | ||
incidence_2_t=batch.incidence_2.T, | ||
) | ||
|
||
model_out = {"labels": batch.y, "batch_0": batch.batch_0} | ||
model_out["x_0"] = x_0 | ||
model_out["x_1"] = x_1 | ||
model_out["x_2"] = x_2 | ||
return model_out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper | ||
|
||
class CWNWrapper(DefaultWrapper): | ||
"""Abstract class that provides an interface to loss logic within | ||
network.""" | ||
|
||
def forward(self, batch): | ||
"""Define logic for forward pass.""" | ||
|
||
x_0, x_1, x_2 = self.backbone( | ||
x_0=batch.x_0, | ||
x_1=batch.x_1, | ||
x_2=batch.x_2, | ||
incidence_1_t=batch.incidence_1.T, | ||
adjacency_0=batch.adjacency_1, | ||
incidence_2=batch.incidence_2, | ||
) | ||
|
||
model_out = {"labels": batch.y, "batch_0": batch.batch_0} | ||
model_out["x_0"] = x_0 | ||
model_out["x_1"] = x_1 | ||
model_out["x_2"] = x_2 | ||
return model_out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper | ||
|
||
class CWNDCMWrapper(DefaultWrapper): | ||
"""Abstract class that provides an interface to loss logic within | ||
network.""" | ||
|
||
def forward(self, batch): | ||
"""Define logic for forward pass.""" | ||
|
||
x_1 = self.backbone( | ||
batch.x_1, | ||
batch.down_laplacian_1.coalesce(), | ||
batch.up_laplacian_1.coalesce(), | ||
) | ||
|
||
model_out = {"labels": batch.y, "batch_0": batch.batch_0} | ||
|
||
model_out["x_1"] = x_1 | ||
model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) | ||
return model_out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from topobenchmarkx.models.wrappers.graph.gnn_wrapper import GNNWrapper | ||
|
||
# ... import other readout classes here | ||
# For example: | ||
# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 | ||
# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 | ||
|
||
# Export all wrappers | ||
__all__ = [ | ||
"GNNWrapper", | ||
|
||
# "OtherWrapper1", | ||
# "OtherWrapper2", | ||
# ... add other readout classes here | ||
] |
Oops, something went wrong.