Skip to content

Commit 8a84518

Browse files
authored
Fix graph neighbors (#65)
* Fix graph neighbors * Change variable name
1 parent 89bbadb commit 8a84518

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

torchhd/structures.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -794,13 +794,10 @@ class Graph:
794794
795795
Args:
796796
dimensions (int): number of dimensions of the graph.
797-
directed (bool): decides if the graph will be directed or not.
797+
directed (bool, optional): specify if the graph is directed or not. Default: ``False``.
798798
dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``None``, uses a global default (see ``torch.set_default_tensor_type()``).
799799
device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
800-
801-
Args:
802800
input (Tensor): tensor representing a graph hypervector.
803-
directed (bool): decides if the graph will be directed or not.
804801
805802
Examples::
806803
@@ -817,7 +814,7 @@ def __init__(self, input: Tensor, *, directed=False):
817814
...
818815

819816
def __init__(self, dim_or_input: int, **kwargs):
820-
self.directed = kwargs.get("directed", False)
817+
self.is_directed = kwargs.get("directed", False)
821818
if torch.is_tensor(dim_or_input):
822819
self.value = dim_or_input
823820
else:
@@ -861,28 +858,31 @@ def encode_edge(self, node1: Tensor, node2: Tensor) -> Tensor:
861858
tensor([-1., 1., -1., ..., 1., -1., -1.])
862859
863860
"""
864-
if self.directed:
865-
return functional.bind(node1, node2)
866-
else:
861+
if self.is_directed:
867862
return functional.bind(node1, functional.permute(node2))
863+
else:
864+
return functional.bind(node1, node2)
868865

869866
def node_neighbors(self, input: Tensor, outgoing=True) -> Tensor:
870867
"""Returns the multiset of node neighbors of the input node.
871868
872869
Args:
873870
input (Tensor): Hypervector representing the node.
871+
outgoing (bool, optional): if ``True``, returns the neighboring nodes that ``input`` has an edge to. If ``False``, returns the neighboring nodes that ``input`` has an edge from. This only has effect for directed graphs. Default: ``True``.
874872
875873
Examples::
876874
877875
>>> G.node_neighbors(letters_hv[0])
878876
tensor([ 1., 1., 1., ..., -1., -1., 1.])
879877
880878
"""
881-
if self.directed:
879+
if self.is_directed:
882880
if outgoing:
883-
return functional.permute(functional.bind(self.value, input), shifts=-1)
881+
permuted_neighbors = functional.bind(self.value, input)
882+
return functional.permute(permuted_neighbors, shifts=-1)
884883
else:
885-
return functional.bind(self.value, functional.permute(input, shifts=1))
884+
permuted_node = functional.permute(input, shifts=1)
885+
return functional.bind(self.value, permuted_node)
886886
else:
887887
return functional.bind(self.value, input)
888888

0 commit comments

Comments
 (0)