Commit c204b02

mikeheddes and verges authored
Add structures tests (#66)
* Structures testing, starting with memory
* Structures testing, starting with multiset
* Structures testing, hashtable done
* Structures testing, sequence done, two functions do not work
* Structures testing, graph done, node neighbors not working and contains not working for directed graphs
* tests
* Structures testing, sequence done
* Structures testing, sequence and graph fix. Tree and FSA done
* Structures testing, graph fixed
* Structures testing, distinct sequence; fixed DistinctSequence init
* random hv: incorrect out parameter
* Use identity hv for distinct sequence
* Update formatting
* Fix tests
* Fix embeddings out usage

Co-authored-by: verges <[email protected]>
1 parent 8a84518 commit c204b02

11 files changed (+1,245 −29 lines)

torchhd/embeddings.py (+26 −22)
@@ -44,11 +44,11 @@ def reset_parameters(self):
             "device": self.weight.data.device,
             "dtype": self.weight.data.dtype,
         }
-        functional.identity_hv(
-            self.num_embeddings,
-            self.embedding_dim,
-            out=self.weight.data,
-            **factory_kwargs
+
+        self.weight.data.copy_(
+            functional.identity_hv(
+                self.num_embeddings, self.embedding_dim, **factory_kwargs
+            )
         )
 
         self._fill_padding_idx_with_zero()
@@ -84,11 +84,11 @@ def reset_parameters(self):
             "device": self.weight.data.device,
             "dtype": self.weight.data.dtype,
         }
-        functional.random_hv(
-            self.num_embeddings,
-            self.embedding_dim,
-            out=self.weight.data,
-            **factory_kwargs
+
+        self.weight.data.copy_(
+            functional.random_hv(
+                self.num_embeddings, self.embedding_dim, **factory_kwargs
+            )
         )
 
         self._fill_padding_idx_with_zero()
@@ -140,12 +140,14 @@ def reset_parameters(self):
             "device": self.weight.data.device,
             "dtype": self.weight.data.dtype,
         }
-        functional.level_hv(
-            self.num_embeddings,
-            self.embedding_dim,
-            randomness=self.randomness,
-            out=self.weight.data,
-            **factory_kwargs
+
+        self.weight.data.copy_(
+            functional.level_hv(
+                self.num_embeddings,
+                self.embedding_dim,
+                randomness=self.randomness,
+                **factory_kwargs
+            )
         )
 
         self._fill_padding_idx_with_zero()
@@ -204,12 +206,14 @@ def reset_parameters(self):
             "device": self.weight.data.device,
             "dtype": self.weight.data.dtype,
         }
-        functional.circular_hv(
-            self.num_embeddings,
-            self.embedding_dim,
-            randomness=self.randomness,
-            out=self.weight.data,
-            **factory_kwargs
+
+        self.weight.data.copy_(
+            functional.circular_hv(
+                self.num_embeddings,
+                self.embedding_dim,
+                randomness=self.randomness,
+                **factory_kwargs
+            )
         )
 
         self._fill_padding_idx_with_zero()
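The embeddings change above drops the out= argument, which the hypervector generators were not handling correctly according to the commit message, and instead copies freshly generated hypervectors into the existing parameter storage. A minimal standalone sketch of that reset pattern, assuming the torchhd version from this commit and the identity_hv signature shown in the diff:

import torch
from torchhd import functional

num_embeddings, dimensions = 6, 10000
weight = torch.nn.Parameter(torch.empty(num_embeddings, dimensions))

# Generate the hypervectors first, then copy them in place so the existing
# Parameter object keeps its identity, device, and dtype.
weight.data.copy_(
    functional.identity_hv(num_embeddings, dimensions, dtype=weight.dtype)
)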

torchhd/structures.py (+9 −7)
@@ -30,7 +30,7 @@ class Memory:
 
     """
 
-    def __init__(self, threshold=0.0):
+    def __init__(self, threshold=0.5):
         self.threshold = threshold
         self.keys: List[Tensor] = []
         self.values: List[Any] = []
@@ -82,7 +82,7 @@ def index(self, key: Tensor) -> int:
         value, index = torch.max(sim, 0)
 
         if value.item() < self.threshold:
-            raise IndexError()
+            raise IndexError("No elements in memory")
 
         return index
 
@@ -241,7 +241,7 @@ def clear(self) -> None:
 
     @classmethod
     def from_ngrams(cls, input: Tensor, n=3):
-        """Creates a multiset from the ngrams of a set of hypervectors.
+        r"""Creates a multiset from the ngrams of a set of hypervectors.
 
         See: :func:`~torchhd.functional.ngrams`.
 
@@ -273,7 +273,7 @@ def from_tensor(cls, input: Tensor):
         >>> M = structures.Multiset.from_tensor(x)
 
         """
-        value = functional.multiset(input, dim=-2)
+        value = functional.multiset(input)
         return cls(value, size=input.size(-2))
 
 
@@ -434,7 +434,7 @@ def from_tensors(cls, keys: Tensor, values: Tensor):
 
         """
         value = functional.hash_table(keys, values)
-        return cls(value, size=input.size(-2))
+        return cls(value, size=keys.size(-2))
 
 
 class Sequence:
@@ -663,7 +663,9 @@ def __init__(self, dim_or_input: int, **kwargs):
         else:
             dtype = kwargs.get("dtype", torch.get_default_dtype())
             device = kwargs.get("device", None)
-            self.value = torch.zeros(dim_or_input, dtype=dtype, device=device)
+            self.value = functional.identity_hv(
+                1, dim_or_input, dtype=dtype, device=device
+            ).squeeze(0)
 
     def append(self, input: Tensor) -> None:
         """Appends the input tensor to the right of the sequence.
@@ -766,7 +768,7 @@ def clear(self) -> None:
         >>> DS.clear()
 
         """
-        self.value.fill_(0.0)
+        self.value.fill_(1.0)
         self.size = 0
 
     @classmethod
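Two of the structures.py changes above are behavioral: Memory.index now raises with a message when the best match falls below the default threshold of 0.5, and DistinctSequence starts from (and clears back to) the identity hypervector instead of an all-zeros tensor. A short usage sketch of the DistinctSequence behavior that the new tests rely on, assuming the torchhd version from this commit:

import torch
from torchhd import structures, functional

DS = structures.DistinctSequence(10000)
assert torch.equal(DS.value, torch.ones(10000))  # identity hypervector, not zeros

hv = functional.random_hv(2, 10000)
DS.append(hv[0])
assert functional.cosine_similarity(DS.value, hv)[0] > 0.5  # appended element is recoverable

DS.clear()  # fills the value with ones again, restoring the identity hypervector
assert torch.equal(DS.value, torch.ones(10000))
assert len(DS) == 0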

torchhd/tests/structures/__init__.py (whitespace-only changes)

torchhd/tests/structures/test_distinct_sequence.py (+123)
@@ -0,0 +1,123 @@
+import pytest
+import torch
+import string
+
+from torchhd import structures, functional
+
+seed = 2147483644
+letters = list(string.ascii_lowercase)
+
+
+class TestDistinctSequence:
+    def test_creation_dim(self):
+        S = structures.DistinctSequence(10000)
+        assert torch.equal(S.value, torch.ones(10000))
+
+    def test_creation_tensor(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv = functional.random_hv(len(letters), 10000, generator=generator)
+
+        S = structures.DistinctSequence(hv[0])
+        assert torch.equal(S.value, hv[0])
+
+    def test_generator(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv1 = functional.random_hv(60, 10000, generator=generator)
+
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv2 = functional.random_hv(60, 10000, generator=generator)
+
+        assert (hv1 == hv2).min().item()
+
+    def test_append(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv = functional.random_hv(len(letters), 10000, generator=generator)
+        S = structures.DistinctSequence(10000)
+        S.append(hv[0])
+        assert functional.cosine_similarity(S.value, hv)[0] > 0.5
+
+    def test_appendleft(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv = functional.random_hv(len(letters), 10000, generator=generator)
+        S = structures.DistinctSequence(10000)
+        S.appendleft(hv[0])
+        assert functional.cosine_similarity(S.value, hv)[0] > 0.5
+
+    def test_pop(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv = functional.random_hv(len(letters), 10000, generator=generator)
+        S = structures.DistinctSequence(10000)
+        S.append(hv[0])
+        S.append(hv[1])
+        S.pop(hv[1])
+        assert functional.cosine_similarity(S.value, hv)[0] > 0.5
+        S.pop(hv[0])
+        S.append(hv[2])
+        assert functional.cosine_similarity(S.value, hv)[2] > 0.5
+        S.append(hv[3])
+        S.pop(hv[3])
+        assert functional.cosine_similarity(S.value, hv)[2] > 0.5
+
+    def test_popleft(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv = functional.random_hv(len(letters), 10000, generator=generator)
+        S = structures.DistinctSequence(10000)
+        S.appendleft(hv[0])
+        S.appendleft(hv[1])
+        S.popleft(hv[1])
+        assert functional.cosine_similarity(S.value, hv)[0] > 0.5
+        S.popleft(hv[0])
+        S.appendleft(hv[2])
+        assert functional.cosine_similarity(S.value, hv)[2] > 0.5
+        S.appendleft(hv[3])
+        S.popleft(hv[3])
+        assert functional.cosine_similarity(S.value, hv)[2] > 0.5
+
+    def test_replace(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv = functional.random_hv(len(letters), 10000, generator=generator)
+        S = structures.DistinctSequence(10000)
+        S.append(hv[0])
+        assert functional.cosine_similarity(S.value, hv)[0] > 0.5
+        S.replace(0, hv[0], hv[1])
+        assert functional.cosine_similarity(S.value, hv)[1] > 0.5
+
+    def test_length(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv = functional.random_hv(len(letters), 10000, generator=generator)
+        S = structures.DistinctSequence(10000)
+        S.append(hv[0])
+        S.append(hv[0])
+        S.append(hv[0])
+        S.append(hv[0])
+        assert len(S) == 4
+        S.pop(hv[0])
+        S.pop(hv[0])
+        S.pop(hv[0])
+        assert len(S) == 1
+        S.pop(hv[0])
+        assert len(S) == 0
+        S.append(hv[0])
+        assert len(S) == 1
+
+    def test_clear(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv = functional.random_hv(len(letters), 10000, generator=generator)
+        S = structures.DistinctSequence(10000)
+        S.append(hv[0])
+        S.append(hv[0])
+        S.append(hv[0])
+        S.append(hv[0])
+        assert len(S) == 4
+        S.clear()
+        assert len(S) == 0
torchhd/tests/structures/test_fsa.py (+103)
@@ -0,0 +1,103 @@
+import pytest
+import torch
+import string
+
+from torchhd import structures, functional
+
+seed = 2147483644
+seed1 = 2147483643
+letters = list(string.ascii_lowercase)
+
+
+class TestFSA:
+    def test_creation_dim(self):
+        F = structures.FiniteStateAutomata(10000)
+        assert torch.equal(F.value, torch.zeros(10000))
+
+    def test_generator(self):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv1 = functional.random_hv(60, 10000, generator=generator)
+
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        hv2 = functional.random_hv(60, 10000, generator=generator)
+
+        assert (hv1 == hv2).min().item()
+
+    def test_add_transition(self):
+        generator = torch.Generator()
+        generator1 = torch.Generator()
+        generator.manual_seed(seed)
+        generator1.manual_seed(seed1)
+        tokens = functional.random_hv(10, 10, generator=generator)
+        actions = functional.random_hv(10, 10, generator=generator1)
+
+        F = structures.FiniteStateAutomata(10)
+
+        F.add_transition(tokens[0], actions[1], actions[2])
+        assert torch.equal(
+            F.value,
+            torch.tensor([1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0]),
+        )
+        F.add_transition(tokens[1], actions[1], actions[3])
+        assert torch.equal(
+            F.value, torch.tensor([0.0, 0.0, -2.0, 2.0, 0.0, 2.0, 0.0, -2.0, -2.0, 0.0])
+        )
+        F.add_transition(tokens[2], actions[1], actions[3])
+        assert torch.equal(
+            F.value,
+            torch.tensor([1.0, 1.0, -3.0, 1.0, 1.0, 3.0, -1.0, -1.0, -1.0, 1.0]),
+        )
+
+    def test_transition(self):
+        generator = torch.Generator()
+        generator1 = torch.Generator()
+        generator.manual_seed(seed)
+        generator1.manual_seed(seed1)
+        tokens = functional.random_hv(10, 10, generator=generator)
+        states = functional.random_hv(10, 10, generator=generator1)
+
+        F = structures.FiniteStateAutomata(10)
+
+        F.add_transition(tokens[0], states[1], states[2])
+        F.add_transition(tokens[1], states[1], states[3])
+        F.add_transition(tokens[2], states[1], states[5])
+
+        assert (
+            torch.argmax(
+                functional.cosine_similarity(F.transition(states[1], tokens[0]), states)
+            ).item()
+            == 2
+        )
+        assert (
+            torch.argmax(
+                functional.cosine_similarity(F.transition(states[1], tokens[1]), states)
+            ).item()
+            == 3
+        )
+        assert (
+            torch.argmax(
+                functional.cosine_similarity(F.transition(states[1], tokens[2]), states)
+            ).item()
+            == 5
+        )
+
+    def test_clear(self):
+        generator = torch.Generator()
+        generator1 = torch.Generator()
+        generator.manual_seed(seed)
+        generator1.manual_seed(seed1)
+        tokens = functional.random_hv(10, 10, generator=generator)
+        states = functional.random_hv(10, 10, generator=generator1)
+
+        F = structures.FiniteStateAutomata(10)
+
+        F.add_transition(tokens[0], states[1], states[2])
+        F.add_transition(tokens[1], states[1], states[3])
+        F.add_transition(tokens[2], states[1], states[5])
+
+        F.clear()
+        assert torch.equal(
+            F.value, torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+        )
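The FSA tests above store every transition in a single hypervector and recover a transition by querying with the current state and token, then matching the noisy result against the candidate states with cosine similarity. A minimal usage sketch mirroring those calls, assuming the torchhd version from this commit:

import torch
from torchhd import structures, functional

tokens = functional.random_hv(4, 10000)
states = functional.random_hv(4, 10000)

F = structures.FiniteStateAutomata(10000)
F.add_transition(tokens[0], states[1], states[2])  # reading token 0 in state 1 leads to state 2

# transition() returns a noisy next-state hypervector; clean it up by picking
# the most similar candidate state.
noisy_next = F.transition(states[1], tokens[0])
predicted = torch.argmax(functional.cosine_similarity(noisy_next, states)).item()
assert predicted == 2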
