Skip to content

Commit 77fa781

Browse files
goggletimholy
authored andcommitted
Allow indexing of SparseMatrixCSC by array of CartesianIndex (#33225)
1 parent 04234fb commit 77fa781

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

base/multidimensional.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ module IteratorsMD
156156
return I
157157
end
158158

159+
Base._ind2sub(t::Tuple, ind::CartesianIndex) = Tuple(ind)
160+
159161
# Iteration over the elements of CartesianIndex cannot be supported until its length can be inferred,
160162
# see #23719
161163
Base.iterate(::CartesianIndex) =

stdlib/SparseArrays/src/sparsevector.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,7 @@ end
663663

664664
function getindex(A::AbstractSparseMatrixCSC{Tv,Ti}, I::AbstractVector) where {Tv,Ti}
665665
require_one_based_indexing(A, I)
666+
@boundscheck checkbounds(A, I)
666667
szA = size(A)
667668
nA = szA[1]*szA[2]
668669
colptrA = getcolptr(A)
@@ -676,7 +677,6 @@ function getindex(A::AbstractSparseMatrixCSC{Tv,Ti}, I::AbstractVector) where {T
676677

677678
idxB = 1
678679
for i in 1:n
679-
((I[i] < 1) | (I[i] > nA)) && throw(BoundsError(A, I))
680680
row,col = Base._ind2sub(szA, I[i])
681681
for r in colptrA[col]:(colptrA[col+1]-1)
682682
@inbounds if rowvalA[r] == row

stdlib/SparseArrays/test/sparse.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,21 @@ end
870870
end
871871
end
872872

873+
# indexing by array of CartesianIndex (issue #30981)
874+
S = sprand(10, 10, 0.4)
875+
inds_sparse = S[findall(S .> 0.2)]
876+
M = Matrix(S)
877+
inds_dense = M[findall(M .> 0.2)]
878+
@test Array(inds_sparse) == inds_dense
879+
inds_out = Array([CartesianIndex(1, 1), CartesianIndex(0, 1)])
880+
@test_throws BoundsError S[inds_out]
881+
pop!(inds_out); push!(inds_out, CartesianIndex(1, 0))
882+
@test_throws BoundsError S[inds_out]
883+
pop!(inds_out); push!(inds_out, CartesianIndex(11, 1))
884+
@test_throws BoundsError S[inds_out]
885+
pop!(inds_out); push!(inds_out, CartesianIndex(1, 11))
886+
@test_throws BoundsError S[inds_out]
887+
873888
# workaround issue #7197: comment out let-block
874889
#let S = SparseMatrixCSC(3, 3, UInt8[1,1,1,1], UInt8[], Int64[])
875890
S1290 = SparseMatrixCSC(3, 3, UInt8[1,1,1,1], UInt8[], Int64[])

0 commit comments

Comments
 (0)