Making smf ForwardDiff-friendly #242

Open
cscherrer opened this issue Oct 18, 2022 · 2 comments
@cscherrer
Collaborator

cscherrer commented Oct 18, 2022

We should be able to tell ForwardDiff that the derivative of smf is logdensityof. Here are timings for these separately for a StdNormal:

julia> using MeasureBase, ForwardDiff, BenchmarkTools

julia> using ForwardDiff: Dual

julia> @btime smf(StdNormal(), x) setup=(x=randn())
  5.660 ns (0 allocations: 0 bytes)
0.373722

julia> @btime logdensityof(StdNormal(), x) setup=(x=randn())
  2.294 ns (0 allocations: 0 bytes)
-0.935558

If we compute both together, there's no extra overhead, and we even save a little compared to the two separate calls:

julia> @btime (smf(StdNormal(), x), logdensityof(StdNormal(), x)) setup=(x=randn())
  7.151 ns (0 allocations: 0 bytes)
(0.206069, -1.25525)

It would be good to "teach" ForwardDiff to do this, because it's currently much slower:

julia> @btime ForwardDiff.derivative(x -> smf(StdNormal(), x), x)  setup=(x=randn())
  13.181 ns (0 allocations: 0 bytes)
0.122241

julia> @btime smf(StdNormal(), Dual{}(x, one(x)))  setup=(x=randn())
  14.506 ns (0 allocations: 0 bytes)
Dual{Nothing}(0.572796,0.392282)

I'd think this ought to do it:

julia> function MeasureBase.smf(μ::StdNormal, x::Dual{TAG}) where TAG
           val = ForwardDiff.value(x)
           Δ = ForwardDiff.partials(x)
           Dual{TAG}(smf(μ, val), Δ * densityof(μ, val))
       end

But it doesn't:

julia> @btime ForwardDiff.derivative(x -> smf(StdNormal(), x), x)  setup=(x=randn())
  13.081 ns (0 allocations: 0 bytes)
0.370831

julia> @btime smf(StdNormal(), Dual{}(x, one(x)))  setup=(x=randn())
  14.898 ns (0 allocations: 0 bytes)
Dual{Nothing}(0.724177,0.334163)

How can we get this working properly?

It should also work if we call ForwardDiff.derivative(smf(StdNormal()), x)

@cscherrer
Collaborator Author

Any ideas, @oschulz?

@cscherrer
Collaborator Author

Oh wait, I'm doing logdensityof in some places, which is much cheaper 🤦‍♂️
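
For reference, here is the relationship behind that mix-up, written out as a small sketch. It assumes `smf(StdNormal(), x)` is the standard normal CDF, written Φ(x) below, which is consistent with the `densityof` call in the method above. The symbols Φ and φ are notation introduced here, not names from MeasureBase.

```latex
\[
\frac{d}{dx}\,\Phi(x) \;=\; \varphi(x) \;=\; \frac{1}{\sqrt{2\pi}}\, e^{-x^{2}/2},
\qquad
\log \varphi(x) \;=\; -\frac{x^{2}}{2} \;-\; \frac{1}{2}\log(2\pi).
\]
```

So the derivative of `smf` is `densityof`, i.e. `exp(logdensityof(...))`, not `logdensityof` itself. The `logdensityof` timings above skip the `exp`, which likely makes them an optimistic baseline for what a custom `Dual` method can achieve.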
