Skip to content

Commit cc1bed8

Browse files
williambdeantwiecki
authored andcommitted
run pre-commit
1 parent 9b5aa8e commit cc1bed8

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

pymc_extras/prior.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ class VariableFactory(Protocol):
335335
336336
class PowerSumDistribution:
337337
"""Create a distribution that is the sum of powers of a base distribution."""
338+
338339
def __init__(self, distribution: VariableFactory, n: int):
339340
self.distribution = distribution
340341
self.n = n
@@ -345,7 +346,12 @@ def dims(self):
345346
346347
def create_variable(self, name: str) -> "TensorVariable":
347348
raw = self.distribution.create_variable(f"{name}_raw")
348-
return pm.Deterministic(name, pt.sum([raw ** n for n in range(1, self.n + 1)], axis=0), dims=self.dims,)
349+
return pm.Deterministic(
350+
name,
351+
pt.sum([raw**n for n in range(1, self.n + 1)], axis=0),
352+
dims=self.dims,
353+
)
354+
349355
350356
cubic = PowerSumDistribution(Prior("Normal"), n=3)
351357
samples = sample_prior(cubic)
@@ -533,8 +539,10 @@ class Prior:
533539
534540
from pymc_extras.prior import register_tensor_transform
535541
542+
536543
def custom_transform(x):
537-
return x ** 2
544+
return x**2
545+
538546
539547
register_tensor_transform("square", custom_transform)
540548

0 commit comments

Comments
 (0)