Skip to content

Commit

Permalink
Merge branch 'master' into removelocal
Browse files Browse the repository at this point in the history
  • Loading branch information
deckerla authored May 13, 2022
2 parents 2e5273d + f6dd212 commit 830beba
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 42 deletions.
113 changes: 73 additions & 40 deletions src/Devito.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,44 +100,8 @@ function Base.fill!(x::DevitoMPIAbstractArray, v)
x
end

function in_range(i::Int, ranges)
for rang in enumerate(ranges)
if i rang[2]
return rang[1]
end
end
error("Outside Valid Ranges")
end

function helix_helper(tup::NTuple{N,Int}) where {N}
wrapper = (1,)
for i in 2:N
wrapper = (wrapper..., wrapper[1]*tup[i-1])
end
return wrapper
end

function find_rank(x::DevitoMPIAbstractArray{T,N}, I::Vararg{Int,N}) where {T,N}
decomp = decomposition(x)
rank_position = in_range.(I,decomp)
helper = helix_helper(topology(x))
rank = sum((rank_position .- 1) .* helper)
return rank
end

shift_localindicies(i::Int, indices::UnitRange{Int}) = i - indices[1] + 1

function Base.getindex(x::DevitoMPIAbstractArray{T,N}, I::Vararg{Int,N}) where {T,N}
v = nothing
if all(ntuple(idim->I[idim] localindices(x)[idim], N))
J = ntuple(idim-> shift_localindicies( I[idim], localindices(x)[idim]), N)
v = getindex(x.p, J...)
end
v = MPI.bcast(v, find_rank(x, I...), MPI.COMM_WORLD)
v
end

Base.setindex!(x::DevitoMPIAbstractArray{T,N}, v, i) where {T,N} = error("not implemented")
Base.setindex!(x::DevitoMPIAbstractArray{T,N}, v, I::Vararg{Int,N}) where {T,N} = error("not implemented")
Base.IndexStyle(::Type{<:DevitoMPIAbstractArray}) = IndexCartesian()

struct DevitoMPIArray{T,N,A<:AbstractArray{T,N},D} <: DevitoMPIAbstractArray{T,N}
Expand Down Expand Up @@ -183,11 +147,11 @@ function Base.convert(::Type{Array}, x::DevitoMPIAbstractArray{T,N}) where {T,N}
y = Array{T}(undef, ntuple(_->0, N))
y_vbuffer = VBuffer(nothing)
end

_x = zeros(T, size(parent(x)))
copyto!(_x, parent(x))
MPI.Gatherv!(_x, y_vbuffer, 0, MPI.COMM_WORLD)

if MPI.Comm_rank(MPI.COMM_WORLD) == 0
_y = convert_resort_array!(Array{T,N}(undef, size(x)), y, x.topology, x.decomposition)
else
Expand All @@ -211,7 +175,7 @@ end

function Base.copyto!(dst::DevitoMPIArray{T,N}, src::AbstractArray{T,N}) where {T,N}
_counts = counts(dst)

if MPI.Comm_rank(MPI.COMM_WORLD) == 0
_y = copyto_resort_array!(Vector{T}(undef, length(src)), src, dst.topology, dst.decomposition)
data_vbuffer = VBuffer(_y, _counts)
Expand Down Expand Up @@ -355,6 +319,75 @@ function Base.copyto!(dst::DevitoMPISparseTimeArray{T,N}, src::Array{T,N}) where
copyto!(parent(dst), _dst)
end

function in_range(i::Int, ranges)
for rang in enumerate(ranges)
if i rang[2]
return rang[1]
end
end
error("Outside Valid Ranges")
end

function helix_helper(tup::NTuple{N,Int}) where {N}
wrapper = (1,)
for i in 2:N
wrapper = (wrapper..., wrapper[1]*tup[i-1])
end
return wrapper
end

function find_rank(x::DevitoMPIArray{T,N}, I::Vararg{Int,N}) where {T,N}
decomp = decomposition(x)
rank_position = in_range.(I,decomp)
helper = helix_helper(topology(x))
rank = sum((rank_position .- 1) .* helper)
return rank
end

function find_rank(x::DevitoMPITimeArray{T,N}, I::Vararg{Int,N}) where {T,N}
decomp = decomposition(x)[1:end-1]
J = I[1:end-1]
rank_position = in_range.(J,decomp)
helper = helix_helper(topology(x))
rank = sum((rank_position .- 1) .* helper)
return rank
end

function find_rank(x::DevitoMPISparseTimeArray{T,N}, I::Vararg{Int,2}) where {T,N}
decomp = decomposition(x)[1:end-1]
J = I[1]
rank_position = in_range.(J,decomp)
helper = helix_helper(topology(x))
rank = sum((rank_position .- 1) .* helper)
return rank
end

shift_localindicies(i::Int, indices::UnitRange{Int}) = i - indices[1] + 1

shift_localindicies(i::Int, indices::Int) = i - indices + 1

function Base.getindex(x::Union{DevitoMPIArray{T,N},DevitoMPITimeArray{T,N}}, I::Vararg{Int,N}) where {T,N}
v = nothing
wanted_rank = find_rank(x, I...)
if MPI.Comm_rank(MPI.COMM_WORLD) == wanted_rank
J = ntuple(idim-> shift_localindicies( I[idim], localindices(x)[idim]), N)
v = getindex(x.p, J...)
end
v = MPI.bcast(v, wanted_rank, MPI.COMM_WORLD)
v
end

function Base.getindex(x::DevitoMPISparseTimeArray{T,N}, I::Vararg{Int,2}) where {T,N}
v = nothing
wanted_rank = find_rank(x, I...)
if MPI.Comm_rank(MPI.COMM_WORLD) == wanted_rank
J = (shift_localindicies( I[1], localindices(x)[1]), I[2])
v = getindex(x.p, J...)
end
v = MPI.bcast(v, wanted_rank, MPI.COMM_WORLD)
v
end

#
# Dimension
#
Expand Down
60 changes: 59 additions & 1 deletion test/mpitests_2ranks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,13 @@ end
end
end

@testset "MPI Getindex" for n in ( (11,10), (5,4), (7,2), (4,5,6), (2,3,4) )
@testset "MPI Setindex Not Implemented" begin
grid = Grid(shape=(5,6,7))
f = Devito.Function(name="f", grid=grid)
@test_throws ErrorException("not implemented") data(f)[2,2,2] = 1.0
end

@testset "MPI Getindex for Function n=$n" for n in ( (11,10), (5,4), (7,2), (4,5,6), (2,3,4) )
N = length(n)
rnk = MPI.Comm_rank(MPI.COMM_WORLD)
grid = Grid(shape=n, dtype=Float32)
Expand All @@ -646,3 +652,55 @@ end
@test data(f)[1:div(n[1],2),div(n[2],3):2*div(n[2],3),:] arr[1:div(n[1],2),div(n[2],3):2*div(n[2],3),:]
end
end

@testset "MPI Getindex for TimeFunction n=$n" for n in ( (11,10), (5,4), (7,2), (4,5,6), (2,3,4) )
N = length(n)
nt = 5
rnk = MPI.Comm_rank(MPI.COMM_WORLD)
grid = Grid(shape=n, dtype=Float32)
f = TimeFunction(name="f", grid=grid, save=nt)
arr = reshape(1f0*[1:prod(size(data(f)));], size(data(f)))
copy!(data(f), arr)
nchecks = 10
Random.seed!(1234);
for check in 1:nchecks
i = rand((1:n[1]))
j = rand((1:n[2]))
I = (i,j)
if N == 3
k = rand((1:n[3]))
I = (i,j,k)
end
m = rand((1:nt))
I = (I...,m)
@test data(f)[I...] == arr[I...]
end
if N == 2
@test data(f)[1:div(n[1],2),:,1:div(nt,2)] arr[1:div(n[1],2),:,1:div(nt,2)]
else
@test data(f)[1:div(n[1],2),div(n[2],3):2*div(n[2],3),:,1:div(nt,2)] arr[1:div(n[1],2),div(n[2],3):2*div(n[2],3),:,1:div(nt,2)]
end
end

@testset "MPI Getindex for SparseTimeFunction n=$n npoint=$npoint" for n in ( (5,4),(4,5,6) ), npoint in (1,5,10)
N = length(n)
nt = 5
rnk = MPI.Comm_rank(MPI.COMM_WORLD)
grid = Grid(shape=n, dtype=Float32)
f = SparseTimeFunction(name="f", grid=grid, nt=nt, npoint=npoint)
arr = reshape(1f0*[1:prod(size(data(f)));], size(data(f)))
copy!(data(f), arr)
nchecks = 10
Random.seed!(1234);
for check in 1:nchecks
i = rand((1:npoint))
j = rand((1:nt))
I = (i,j)
@test data(f)[I...] == arr[I...]
end
if npoint > 1
@test data(f)[1:div(npoint,2),2:end-1] arr[1:div(npoint,2),2:end-1]
else
@test data(f)[1,2:end-1] arr[1,2:end-1]
end
end
54 changes: 53 additions & 1 deletion test/mpitests_4ranks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ end
end
end

@testset "MPI Getindex" for n in ( (11,10), (5,4), (7,2), (4,5,6), (2,3,4) )
@testset "MPI Getindex for Function n=$n" for n in ( (11,10), (5,4), (7,2), (4,5,6), (2,3,4) )
N = length(n)
rnk = MPI.Comm_rank(MPI.COMM_WORLD)
grid = Grid(shape=n, dtype=Float32)
Expand All @@ -329,3 +329,55 @@ end
@test data(f)[1:div(n[1],2),div(n[2],3):2*div(n[2],3),:] arr[1:div(n[1],2),div(n[2],3):2*div(n[2],3),:]
end
end

@testset "MPI Getindex for TimeFunction n=$n" for n in ( (11,10), (5,4), (7,2), (4,5,6), (2,3,4) )
N = length(n)
nt = 5
rnk = MPI.Comm_rank(MPI.COMM_WORLD)
grid = Grid(shape=n, dtype=Float32)
f = TimeFunction(name="f", grid=grid, save=nt)
arr = reshape(1f0*[1:prod(size(data(f)));], size(data(f)))
copy!(data(f), arr)
nchecks = 10
Random.seed!(1234);
for check in 1:nchecks
i = rand((1:n[1]))
j = rand((1:n[2]))
I = (i,j)
if N == 3
k = rand((1:n[3]))
I = (i,j,k)
end
m = rand((1:nt))
I = (I...,m)
@test data(f)[I...] == arr[I...]
end
if N == 2
@test data(f)[1:div(n[1],2),:,1:div(nt,2)] arr[1:div(n[1],2),:,1:div(nt,2)]
else
@test data(f)[1:div(n[1],2),div(n[2],3):2*div(n[2],3),:,1:div(nt,2)] arr[1:div(n[1],2),div(n[2],3):2*div(n[2],3),:,1:div(nt,2)]
end
end

@testset "MPI Getindex for SparseTimeFunction n=$n npoint=$npoint" for n in ( (5,4),(4,5,6) ), npoint in (1,5,10)
N = length(n)
nt = 5
rnk = MPI.Comm_rank(MPI.COMM_WORLD)
grid = Grid(shape=n, dtype=Float32)
f = SparseTimeFunction(name="f", grid=grid, nt=nt, npoint=npoint)
arr = reshape(1f0*[1:prod(size(data(f)));], size(data(f)))
copy!(data(f), arr)
nchecks = 10
Random.seed!(1234);
for check in 1:nchecks
i = rand((1:npoint))
j = rand((1:nt))
I = (i,j)
@test data(f)[I...] == arr[I...]
end
if npoint > 1
@test data(f)[1:div(npoint,2),2:end-1] arr[1:div(npoint,2),2:end-1]
else
@test data(f)[1,2:end-1] arr[1,2:end-1]
end
end
1 change: 1 addition & 0 deletions test/serialtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ end
dm = SubDimensionMiddle(name="dm", parent=d, thickness_left=2, thickness_right=3)
for subdim in (dl,dr,dm)
@test parent(subdim) == d
@test PyObject(subdim) == subdim.o
end
@test (thickness(dl)[1][2], thickness(dl)[2][2]) == (2, 0)
@test (thickness(dr)[1][2], thickness(dr)[2][2]) == (0, 3)
Expand Down

0 comments on commit 830beba

Please sign in to comment.