Skip to content

DenseMetric and Component arrays (solve #344)#345

Closed
erathorn wants to merge 4 commits intoTuringLang:mainfrom
erathorn:master
Closed

DenseMetric and Component arrays (solve #344)#345
erathorn wants to merge 4 commits intoTuringLang:mainfrom
erathorn:master

Conversation

@erathorn
Copy link
Copy Markdown

@erathorn erathorn commented Aug 3, 2023

This PR attempts to solve #344

I went for the solution to preallocate the result in ∂H∂r such that the type of the input r matches the type of the output.
I added tests that not only check the correct numerical output but also check the type.

Additionally, I found a typo in the inner constructor of PhasePoint, where it originally was length(ℓπ.gradient) == length(ℓπ.gradient).

Copy link
Copy Markdown
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @erathorn ! And nice catch on the constructor for Phasepoint!

I'm wondering if we should just make compat with ComponentArrays.jl an extension instead of complicating the existing code, and then we can just overload whatever we need there. Thoughts?

Comment thread src/hamiltonian.jl
Comment on lines +45 to +49
function ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat)
out = similar(r) # Make sure the output of this function is of the same type as r
mul!(out, h.metric.M⁻¹, r)
out
end
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit uncertain about this change as it "complicates" code to mainly just stay compatible with ComponentArrays.jl, and thus I'd be more in favour of just making it an extension instead, I think 😕 Then in the extension, we just overload whatever we need to be compatible.

Also, will this code break if, say, h.metric.M⁻¹ has eltype Float64 but r has eltype Float32, rather than just promoting, as is current behavior?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

p1 = ComponentArray(m=one(Float32), s = one(Float32))
r = similar(p1)
M = diagm(randn(Float64, 2))
mul!(r, M, p1)

This works on my machine, and returns r as a component array of eltype Float32 as expected.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AHMC supports vectorised sampling, when passing arguments in a suitable type. In this case, r::AbstractVecOrMat could be a single momentum realization or a vector of momentum realizations. Therefore, the new code needs to be able to handle the vectorized sampling mode for the tests to pass.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the silence. Thank you for the suggestion, it totally makes sense to me. However, I looked into this a bit more and am honestly slightly lost. The call to the rand function, which fails in the tests only works in the test case. Calling this function in a plain Julia session fails for me (on the main branch). A brute force solution, which dispatches on r::AbtractVecOrMat{AbstractVecOrMat}, does unfortunately not do the trick either.

Comment thread test/hamiltonian.jl Outdated
Comment thread src/hamiltonian.jl
@erathorn
Copy link
Copy Markdown
Author

erathorn commented Aug 3, 2023

I'm wondering if we should just make compat with ComponentArrays.jl an extension instead of complicating the existing code, and then we can just overload whatever we need there. Thoughts?

I went for this "complication" because of the comment next to safe_rsimilar and the phasepoint function taking different types. Which I understood as "workarounds" without explicit dependence.

Co-authored-by: Tor Erlend Fjelde <[email protected]>
@erathorn
Copy link
Copy Markdown
Author

erathorn commented Aug 3, 2023

The tests fail at this line:
https://github.com/TuringLang/AdvancedHMC.jl/blob/eb9b2e0d60ef3dd85768d6e6a9f19de15b8f7130/test/metric.jl#L13C1-L13C34

The problem seems to be, that the implementations of rand do not have the correct signature.

https://github.com/TuringLang/AdvancedHMC.jl/blob/eb9b2e0d60ef3dd85768d6e6a9f19de15b8f7130/src/metric.jl#L128C1-L136C38

Base.rand(rng::AbstractVector{<:AbstractRNG}) = rand.(rng)

https://github.com/TuringLang/AdvancedHMC.jl/blob/eb9b2e0d60ef3dd85768d6e6a9f19de15b8f7130/src/utilities.jl#L9C1-L12C4

However, I have not touched any of this at all. 🤔

@torfjelde
Copy link
Copy Markdown
Member

The tests fail at this line:

This is indeed strange given that the CI on master is working just fine 😕

@yebai yebai requested a review from ErikQQY March 16, 2025 16:58
@yebai yebai requested review from torfjelde and removed request for ErikQQY and torfjelde March 16, 2025 16:58
@yebai
Copy link
Copy Markdown
Member

yebai commented Apr 8, 2025

Closed in favour of #407

@yebai yebai closed this Apr 8, 2025
@yebai yebai mentioned this pull request Apr 9, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants