Skip to content

Commit

Permalink
strengthen tensor train circuit template unite test
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Feb 5, 2025
1 parent 226c66d commit c733a96
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/templates/test_tensor_factorizations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools

import numpy as np
import pytest

from cirkit.symbolic.dtypes import DataType
Expand Down Expand Up @@ -137,6 +138,19 @@ def test_factorization_tensor_train(rank: int, factor_param: Parameterization |
assert all(isinstance(sl, HadamardLayer) for sl in product_layers)
assert len(product_layers) == (len(shape) - 2) * rank + 1
assert len(sum_layers) == len(shape) - 1
for sl in sum_layers:
assert len(sl.weight.nodes) == 1
weight = sl.weight.nodes[0]
assert isinstance(weight, ConstantParameter)
value = np.reshape(weight.value, shape=(sl.num_output_units, sl.arity, sl.num_input_units))
ones = np.ones(sl.num_input_units)
zeros = np.zeros(sl.num_input_units)
for i in range(sl.num_output_units):
assert np.all(value[i, i] == ones)
for j in range(sl.num_output_units):
if i == j:
continue
assert np.all(value[i, j] == zeros)
if factor_param is not None:
for sl in input_layers:
assert len(sl.weight.nodes) == 1
Expand Down

0 comments on commit c733a96

Please sign in to comment.