We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent beccca4 commit 8bdd5b7Copy full SHA for 8bdd5b7
tests/model/transform/test_basic.py
@@ -38,11 +38,14 @@ def test_remove_minibatches():
38
data_size = 100
39
data = np.zeros((data_size,))
40
batch_size = 10
41
- with pm.Model() as m1:
+ with pm.Model(coords={"d": range(5)}) as m1:
42
mb = pm.Minibatch(data, batch_size=batch_size)
43
+ mu = pm.Normal("mu", dims="d")
44
x = pm.Normal("x")
45
y = pm.Normal("y", x, observed=mb, total_size=100)
46
47
m2 = remove_minibatched_nodes(m1)
48
assert m1.y.shape[0].eval() == batch_size
49
assert m2.y.shape[0].eval() == data_size
50
+ assert m1.coords == m2.coords
51
+ assert m1.dim_lengths["d"].eval() == m2.dim_lengths["d"].eval()
0 commit comments