Skip to content

Commit 9dc9769

Browse files
committed
make extract_gradient[_chunk]! GPU compatible
1 parent 6a19554 commit 9dc9769

File tree

4 files changed

+50
-5
lines changed

4 files changed

+50
-5
lines changed

Diff for: Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ julia = "1.6"
3232
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
3333
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
3434
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
35+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
3536
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3637
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3738

3839
[targets]
39-
test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils"]
40+
test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils", "JLArrays"]

Diff for: src/gradient.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ function extract_gradient!(::Type{T}, result::DiffResult, dual::Dual) where {T}
8080
end
8181

8282
extract_gradient!(::Type{T}, result::AbstractArray, y::Real) where {T} = fill!(result, zero(y))
83-
extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T}= copyto!(result, partials(T, dual))
83+
extract_gradient!(::Type{T}, result::AbstractArray, dual::Dual) where {T} =
84+
extract_gradient_chunk!(T, result, dual, 1, npartials(dual))
8485

8586
function extract_gradient_chunk!(::Type{T}, result, dual, index, chunksize) where {T}
86-
offset = index - 1
87-
for i in 1:chunksize
88-
result[i + offset] = partials(T, dual, i)
87+
map!(view(Base.ReshapedArray(result, (length(result),), ()), index:index+chunksize-1), 1:chunksize) do i
88+
@inbounds partials(T, dual, i)
8989
end
9090
return result
9191
end

Diff for: test/AllocationsTest.jl

+22
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,26 @@ convert_test_574() = convert(ForwardDiff.Dual{Nothing,ForwardDiff.Dual{Nothing,F
3737

3838
end
3939

40+
@testset "Test extract_gradient! allocations" begin
41+
T = Float64
42+
@testset "vector-mode size(result)=$size" for size in [(4,), (2,2)]
43+
dual = ForwardDiff.Dual(0, (rand(T, size...)...,))
44+
y = Array{T}(undef, size)
45+
alloc = @allocated ForwardDiff.extract_gradient!(Nothing, y, dual)
46+
alloc = @allocated ForwardDiff.extract_gradient!(Nothing, y, dual)
47+
@test alloc == 0
48+
end
49+
@testset "chunk-mode size(result)=$size" for size in [(DEFAULT_CHUNK_THRESHOLD+1,), (DEFAULT_CHUNK_THRESHOLD+1, DEFAULT_CHUNK_THRESHOLD+1)]
50+
Npartials = DEFAULT_CHUNK_THRESHOLD÷2
51+
dual = ForwardDiff.Dual(0, (rand(T, Npartials...)...,))
52+
y = Array{T}(undef, size)
53+
alloc = @allocated ForwardDiff.extract_gradient_chunk!(Nothing, y, dual, 2, Npartials)
54+
alloc = @allocated ForwardDiff.extract_gradient_chunk!(Nothing, y, dual, 2, Npartials)
55+
@test alloc == 0
56+
alloc = @allocated ForwardDiff.extract_gradient_chunk!(Nothing, y, dual, 2, Npartials-1)
57+
alloc = @allocated ForwardDiff.extract_gradient_chunk!(Nothing, y, dual, 2, Npartials-1)
58+
@test alloc == 0
59+
end
60+
end
61+
4062
end

Diff for: test/GradientTest.jl

+22
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ using ForwardDiff
88
using ForwardDiff: Dual, Tag
99
using StaticArrays
1010
using DiffTests
11+
using JLArrays
12+
JLArrays.allowscalar(false)
1113

1214
include(joinpath(dirname(@__FILE__), "utils.jl"))
1315

@@ -149,6 +151,26 @@ end
149151
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 1.5]), [0.0, 0.0])
150152
end
151153

154+
155+
##############################################
156+
# test GPUArray compatibility (via JLArrays) #
157+
##############################################
158+
159+
println(" ...testing GPUArray compatibility (via JLArrays)")
160+
161+
@testset "size = $(size(x))" for x in JLArray.([
162+
rand(1),
163+
rand(DEFAULT_CHUNK_THRESHOLD+1),
164+
rand(1,1),
165+
rand(DEFAULT_CHUNK_THRESHOLD+1,DEFAULT_CHUNK_THRESHOLD+1),
166+
rand(1,1,1)
167+
])
168+
169+
@test ForwardDiff.gradient(prod, x) isa typeof(x)
170+
171+
end
172+
173+
152174
#############
153175
# bug fixes #
154176
#############

0 commit comments

Comments
 (0)