Add object mode fallback for Numba RandomVariables#1249
Conversation
When we find a RandomVariable that doesn't have a Numba implementation, we now fallback to object mode instead of failing with NotImplementedError. This provides a more graceful degradation path for random variables that don't yet have specialized Numba implementations. - Added rv_fallback_impl function to create object mode implementation - Modified numba_funcify_RandomVariable to catch NotImplementedError - Added test for unsupported random variable fallback 🤖 Generated with Claude Code Co-Authored-By: Claude <noreply@anthropic.com>
|
Would be good if it referenced the original issues (there's a PR template you could tell it to fill). We shouldn't use it for beginner friendly issues, that's the point of marking them as beginner friendly? Fine if you're just testing. I would be much more excited if it tackled docs issues. Like ask it to fix and finish the PR related to: #292 , #830 |
ricardoV94
left a comment
There was a problem hiding this comment.
There were also some RVs that weren't being tested because we were not falling back to objmode. Test them now
| [rv_node] = op.fgraph.apply_nodes | ||
| rv_op: RandomVariable = rv_node.op | ||
|
|
||
| warnings.warn( |
There was a problem hiding this comment.
We already have a generic fallback implementation function can't we just use it like we do for other Ops?
May just need to do the unboxing of the RV that the other function is doing
| inplace = rv_op.inplace | ||
|
|
||
| try: | ||
| core_rv_fn = numba_core_rv_funcify(rv_op, rv_node) |
There was a problem hiding this comment.
only this line should be in the try except
| # Create a mock random variable that doesn't have a numba implementation | ||
| class CustomRV(ptr.RandomVariable): | ||
| name = "custom" | ||
| signature = "(d)->(d)" # We need a parameter for test to pass |
There was a problem hiding this comment.
create a univariate rv which will be a simpler test
| x = custom_rv(value, rng=rng) | ||
|
|
||
| # Capture warnings to check for the fallback warning | ||
| with warnings.catch_warnings(record=True) as w: |
| # Run again to make sure the compiled function works properly | ||
| result2 = fn() | ||
| assert isinstance(result2, np.ndarray) | ||
| assert not np.array_equal( |
There was a problem hiding this comment.
This will fail because updates were not set
There was a problem hiding this comment.
Actually test with and without updates, in which case it should change or stay the same
There was a problem hiding this comment.
Also set seed twice and compare to make sure it's following it
|
Top post does not include Related to or Closes # Edit: I repeated myself |
Fixes https://github.com/pymc-devs/pytensor/issues/1245\n\nSummary:\n- When a RandomVariable is not implemented in Numba, it now gracefully falls back to object mode.\n- Added tests to verify that unsupported RandomVariables correctly trigger the object mode fallback.\n- This update ensures a smoother degradation experience and improves testing coverage.\n\nCloses #1245\n\nTest Plan:\n- Run the test suite using pytest to ensure no regressions occur.\n\nAcknowledgements:\n- Thanks to ricardoV94 for the feedback and review comments.