fix sequential enumeration #3238
base: dev
Conversation
@qinqian thanks for looking at this. Can you turn the example above into a unit test?
Yes @eb8680, I turned this example into a unit test.
Does anyone know why GitHub Actions are not running after new commits?
I think we may need permission to kick off the GitHub Actions. Thanks for your interest in the pull request. I tested it locally; it should work this time.
```python
if msg["is_observed"] or msg["infer"].get("enumerate") not in [
    "parallel",
    "sequential",
]:
```
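For context, here is a minimal plain-Python sketch of what a guard like the one above does: skip observed sites and sites whose `infer` dict does not request a supported enumeration strategy. The function name `should_enumerate` is hypothetical, not Pyro's actual internal API.

```python
# Hypothetical sketch of the guard in the diff above: decide whether a
# sample-site message should be enumerated. The `msg` layout mirrors the
# diff; this is not Pyro's actual implementation.

def should_enumerate(msg):
    """Return True only for unobserved sites that request a supported
    enumeration strategy ("parallel" or "sequential")."""
    if msg["is_observed"]:
        return False
    return msg["infer"].get("enumerate") in ("parallel", "sequential")

# Observed sites and sites without an `enumerate` hint are skipped:
print(should_enumerate({"is_observed": True, "infer": {"enumerate": "parallel"}}))    # False
print(should_enumerate({"is_observed": False, "infer": {}}))                          # False
print(should_enumerate({"is_observed": False, "infer": {"enumerate": "sequential"}})) # True
```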
This looks suspicious. Wouldn't this change merely treat "sequential" like "parallel" enumeration? Of course the results would then agree.
Sorry about the GitHub bug; sometimes I've needed to close a PR and open another. Aside from tests, can you explain your diagnosis of the problem and your proposed solution? From what I can tell, this PR amounts to "if the user says 'sequential', pretend they said 'parallel'", which seems wrong. But maybe I'm missing something.
Yes @fritzo. The question is that sequential enumeration generates different results from parallel enumeration. I used two approaches to diagnose the problem. The first was to add a breakpoint and step through this function with the simple example above; I found that the key difference between the two enumeration modes comes from pyro/pyro/poutine/enum_messenger.py (line 123, commit bca99e2).
The second was to compare Pyro's EnumMessenger with funsor, which uses a similar approach to the one I proposed above. Please let me know if there is some misunderstanding of the problem.
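To illustrate the expectation behind this discussion: sequential enumeration replays the computation once per value of the discrete variable, while parallel enumeration evaluates all values in one vectorized pass, and both must yield the same marginal. A toy sketch in plain Python (no Pyro; the model and numbers are illustrative assumptions, not the PR's example):

```python
import math

# Toy discrete model: z ~ Categorical(prior), x | z ~ Normal(mu[z], 1).
# The marginal log p(x) computed by the two enumeration styles must agree.
prior = [0.4, 0.6]
mus = [0.0, 2.0]
x = 1.0

def log_normal(x, mu):
    # log density of Normal(mu, 1) at x
    return -0.5 * (x - mu) ** 2 - 0.5 * math.log(2 * math.pi)

# "Sequential" enumeration: one pass of the model per value of z.
seq_terms = []
for z in range(len(prior)):
    seq_terms.append(math.log(prior[z]) + log_normal(x, mus[z]))
seq_logp = math.log(sum(math.exp(t) for t in seq_terms))

# "Parallel" enumeration: all values of z evaluated in one batched pass.
par_terms = [math.log(p) + log_normal(x, mu) for p, mu in zip(prior, mus)]
par_logp = math.log(sum(math.exp(t) for t in par_terms))

print(seq_logp, par_logp)  # the two marginals should match
```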
This is a pull request to fix the bug in the linked GitHub issue. With the same code, sequential enumeration generates a mean of 0.6269999742507935 over 10000 `infer_discrete` operations with `temperature = 1`. Changing the `enum` variable to `parallel` generates a mean of 0.6294000148773193. Using `temperature = 0` for MAP estimation of `y_pre` generates `mean = 1` for both `parallel` and `sequential` enumeration. These tests were run on a GCP VM with an Ubuntu Docker image.
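The temperature semantics used above can be illustrated without Pyro: with `temperature = 1`, `infer_discrete` samples the discrete latent from its posterior, so the empirical mean over many runs approaches the posterior probability, while `temperature = 0` deterministically takes the MAP value. A hedged toy sketch (the posterior probability 0.63 is an illustrative assumption, not the PR's model):

```python
import random

random.seed(0)

# Toy posterior over a binary latent: p(z=1 | data) = 0.63 (illustrative).
posterior_p1 = 0.63

# temperature = 1: sample z from its posterior; the mean over many draws
# estimates p(z=1 | data), matching the ~0.63 means reported above.
draws = [1 if random.random() < posterior_p1 else 0 for _ in range(10000)]
mean_t1 = sum(draws) / len(draws)

# temperature = 0: take the MAP value, which is 1 since p(z=1) > 0.5.
map_z = 1 if posterior_p1 > 0.5 else 0

print(mean_t1, map_z)
```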