-
-
Notifications
You must be signed in to change notification settings - Fork 983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add more tests for collapse #2702
base: dev
Are you sure you want to change the base?
Conversation
with pyro.plate("data", T, dim=-1): | ||
expand_shape = (d,) if num_particles == 1 else (num_particles, 1, d) | ||
y = pyro.sample("y", dist.Normal(x, 1.).expand(expand_shape).to_event(1)) | ||
pyro.sample("z", dist.Normal(y, 1.).to_event(1), obs=data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This fails due to 2 reasons:
y.output
is Reals[d], which will raise error while infer_param_domain in funsor:Output mismatch: Reals[2] vs Real
- to_event is not available for Funsor distributions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm I wonder if there is always enough information available in the args of expanded distributions to automatically determine the event dim, so that we could make .to_event()
a no-op on funsors. For example here we could deduce the event shape from y.output
.
@eb8680 would it be possible to support this in to_funsor()
and to_data()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would it be possible to support this in to_funsor() and to_data()?
It should be possible, at least in principle, and if we want collapse
to work seamlessly with models that use .to_event
we'll need something like that. I think it will require changing the way type inference works in funsor.distribution.Distribution
, though.
The point of failure in Funsor is the to_funsor
conversion of parameters in funsor.distribution.DistributionMeta.__call__
:
class DistributionMeta(FunsorMeta):
def __call__(cls, *args, **kwargs):
kwargs.update(zip(cls._ast_fields, args))
value = kwargs.pop('value', 'value')
# Failure occurs here -------v
kwargs = OrderedDict(
(k, to_funsor(kwargs[k], output=cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))))
for k in cls._ast_fields if k != 'value')
# this is also incorrect
value = to_funsor(value, output=cls._infer_value_domain(**{k: v.output for k, v in kwargs.items()}))
args = numbers_to_tensors(*(tuple(kwargs.values()) + (value,)))
return super(DistributionMeta, cls).__call__(*args)
In this test case, kwargs[k]
is loc = Variable("y", Reals[2])
, and Normal._infer_param_domain(...)
thinks the output should be Real
instead of Reals[2]
.
I think a general solution (at least when all parameters have the same broadcasted .output
) is to compute unbroadcasted parameter and value shapes up front, then broadcast:
# compute unbroadcasted domains
domains = {k: cls._infer_param_domain(k, getattr(kwargs[k], "shape")) for k in kwargs if k != "value"}
domains["value"] = cls._infer_value_domain(cls, **domains)
# broadcast individual param domains with Funsor inputs
# this avoids .expand-ing underlying parameter tensors
for k, v in kwargs.items():
if isinstance(v, Funsor):
domains[k] = Reals[broadcast_shapes(v.shape, domains[k].shape)[0]] # assume all Reals for exposition
# broadcast value domain, which depends on param domains, with broadcasted param domains
domains["value"] = Reals[broadcast_shapes(domains["value"], *domains.values())[0]]
Now the previously incorrect to_funsor
conversions reduce to
kwargs = OrderedDict((k, to_funsor(kwargs[k], output=domains[k])) for k in kwargs)
value = kwargs["value"]
WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds reasonable to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I put up a draft Funsor PR here: pyro-ppl/funsor#402
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, Eli! testing now...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe what's in that PR after my recent push is sufficient for this test case, although it's not ready to merge because of some edge cases in the distributions. Let me know if it's not working.
beta0 = pyro.sample("beta0", dist.Normal(x, 1.).expand(expand_shape).to_event(1)) | ||
beta = pyro.sample("beta", dist.MultivariateNormal(beta0, torch.eye(S))) | ||
|
||
mean = torch.ones((T, d)) @ beta |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This fails because beta.output
is Reals[S]
while we need it to be Reals[d, S]
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this seems like a major incompatibility between funsor-style operations which have a clear batch/event split and numpy-style operations where everything is effectively an event dim. In particular there's no way for ops.matmul
to touch the batch dimension "d".
One workaround in this model would be to move beta0
and beta
out of the plate and instead use .to_event(1)
, and I think that kind of workaround will be needed whenever we exit a plate and treat a formerly-batch dimension as an event dimension. Conversely I conjecture that no workarounds are needed in tree-plated models, that is in models where "no variable outside a plate ever depends on an upstream variable inside a plate"; this class is similar to our condition for tractable tensor variable elimination.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems like if we had a way of knowing the plate context of a value both at the time of its creation and each time it was accessed, we could handle this smoothly using funsor.terms.Lambda
. Suppose we could overload variable access/assignment, e.g. by using an overloadable environment data structure env
rather than locals()
:
with plate("plate_var", d, dim=-1):
beta0 = pyro.sample("beta0", dist.Normal(x, 1.).expand(expand_shape).to_event(1))
# setting env.beta also records the current plate context of beta
env.beta = pyro.sample("beta", dist.MultivariateNormal(beta0, torch.eye(S)))
# use funsor.terms.Lambda to convert the dead plate dimension to an output:
# now reading env.beta returns funsor.Lambda(plate_var, beta)
# where plate_var is the difference between the current and original plate contexts
mean = torch.ones((T, d)) @ env.beta
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I didn't know about Lambda
. I am not sure why we need env
but it seems from your code that in the collapse
code, when the output does not match plate
infos, we can use Lambda
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure why we need env
I was suggesting it as a layer of automation on top of Lambda
that would help keep track of which plate dimensions need to be converted to event dimensions via Lambda
, just like contrib.funsor.to_funsor
/to_data
automatically track the name_to_dim
mapping for enumeration. You're right that it's also possible to use Lambda
directly in user code.
|
||
mean = torch.ones((T, d)) @ beta | ||
with pyro.plate("data", T, dim=-1): | ||
pyro.sample("obs", dist.MultivariateNormal(mean, torch.eye(S)), obs=data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This fails due to a similar reason to the diag_normal_plate_normal
test above:
- funsor requires
mean.output
isReals[S]
while the output ofmean
after takingmatmul
isReals[T, S]
Ported failing tests from pyro-ppl/numpyro#809 to discuss.
Here in Pyro, we don't face issues like in NumPyro. The only issue is the finalEdit: the same issues happen (I got the wrong impression due to trace back mechanism in collapse - which is fixed in this PR)log_prob
is a Contraction, not a Tensor. I guess we need some pattern here to make it work.