Skip to content

Commit 8bdd5b7

Browse files
committed
Update test to check that coords and dim_lengths are preserved
1 parent beccca4 commit 8bdd5b7

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/model/transform/test_basic.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,14 @@ def test_remove_minibatches():
3838
data_size = 100
3939
data = np.zeros((data_size,))
4040
batch_size = 10
41-
with pm.Model() as m1:
41+
with pm.Model(coords={"d": range(5)}) as m1:
4242
mb = pm.Minibatch(data, batch_size=batch_size)
43+
mu = pm.Normal("mu", dims="d")
4344
x = pm.Normal("x")
4445
y = pm.Normal("y", x, observed=mb, total_size=100)
4546

4647
m2 = remove_minibatched_nodes(m1)
4748
assert m1.y.shape[0].eval() == batch_size
4849
assert m2.y.shape[0].eval() == data_size
50+
assert m1.coords == m2.coords
51+
assert m1.dim_lengths["d"].eval() == m2.dim_lengths["d"].eval()

0 commit comments

Comments
 (0)