Skip to content

Commit

Permalink
add tests for partition function
Browse files Browse the repository at this point in the history
  • Loading branch information
lkct committed Dec 11, 2023
1 parent 227bdf5 commit a71a2e7
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions tests/new/model/test_tensorized_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def test_circuit_output_linear() -> None:
assert output.shape == (16, 1, 1) # shape (B=16, num_out=1, num_cls=1)
# TODO: this is currently not correct. how to fix???
assert floats.allclose(output, _get_circuit_2x2_output())
set_layer_comp_space("log") # TODO: use a with to tmp set default?


def test_circuit_output_log() -> None:
set_layer_comp_space("log")
circuit = _get_circuit_2x2()
_set_circuit_2x2_params(circuit)
all_inputs = torch.tensor(
Expand All @@ -129,5 +129,19 @@ def test_circuit_output_log() -> None:
) # shape (B=16, D=2, C=1).
output = circuit(all_inputs)
assert output.shape == (16, 1, 1) # shape (B=16, num_out=1, num_cls=1)
# TODO: this is currently not correct. how to fix???
assert floats.allclose(output, _get_circuit_2x2_output().log())


def test_circuit_part_func() -> None:
circuit = _get_circuit_2x2()
_set_circuit_2x2_params(circuit)
all_inputs = torch.tensor(
list(itertools.product([0, 1], repeat=4)) # type: ignore[misc]
).unsqueeze(
dim=-1
) # shape (B=16, D=2, C=1).
output = circuit(all_inputs) # shape (B=16, num_out=1, num_cls=1)
sum_output = torch.logsumexp(output, dim=0)
part_func = circuit.partition_func
assert floats.allclose(part_func, sum_output)
assert floats.allclose(part_func, 0.0)

0 comments on commit a71a2e7

Please sign in to comment.