DenseMetric and Component arrays (solve #344)#345
DenseMetric and Component arrays (solve #344)#345erathorn wants to merge 4 commits intoTuringLang:mainfrom
Conversation
torfjelde
left a comment
There was a problem hiding this comment.
Thanks @erathorn ! And nice catch on the constructor for Phasepoint!
I'm wondering if we should just make compat with ComponentArrays.jl an extension instead of complicating the existing code, and then we can just overload whatever we need there. Thoughts?
| function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) | ||
| out = similar(r) # Make sure the output of this function is of the same type as r | ||
| mul!(out, h.metric.M⁻¹, r) | ||
| out | ||
| end |
There was a problem hiding this comment.
I'm a bit uncertain about this change as it "complicates" code to mainly just stay compatible with ComponentArrays.jl, and thus I'd be more in favour of just making it an extension instead, I think 😕 Then in the extension, we just overload whatever we need to be compatible.
Also, will this code break if, say, h.metric.M⁻¹ has eltype Float64 but r has eltype Float32, rather than just promoting, as is current behavior?
There was a problem hiding this comment.
p1 = ComponentArray(m=one(Float32), s = one(Float32))
r = similar(p1)
M = diagm(randn(Float64, 2))
mul!(r, M, p1)This works on my machine, and returns r as a component array of eltype Float32 as expected.
There was a problem hiding this comment.
AHMC supports vectorised sampling, when passing arguments in a suitable type. In this case, r::AbstractVecOrMat could be a single momentum realization or a vector of momentum realizations. Therefore, the new code needs to be able to handle the vectorized sampling mode for the tests to pass.
There was a problem hiding this comment.
Sorry for the silence. Thank you for the suggestion, it totally makes sense to me. However, I looked into this a bit more and am honestly slightly lost. The call to the rand function, which fails in the tests only works in the test case. Calling this function in a plain Julia session fails for me (on the main branch). A brute force solution, which dispatches on r::AbtractVecOrMat{AbstractVecOrMat}, does unfortunately not do the trick either.
I went for this "complication" because of the comment next to |
Co-authored-by: Tor Erlend Fjelde <[email protected]>
|
The tests fail at this line: The problem seems to be, that the implementations of AdvancedHMC.jl/src/utilities.jl Line 5 in eb9b2e0 However, I have not touched any of this at all. 🤔 |
This is indeed strange given that the CI on master is working just fine 😕 |
|
Closed in favour of #407 |
This PR attempts to solve #344
I went for the solution to preallocate the result in
∂H∂rsuch that the type of the inputrmatches the type of the output.I added tests that not only check the correct numerical output but also check the type.
Additionally, I found a typo in the inner constructor of PhasePoint, where it originally was
length(ℓπ.gradient) == length(ℓπ.gradient).