Skip to content

Commit c88cf52

Browse files
fix: fix adjoint/transpose shape handling
1 parent 604393f commit c88cf52

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

src/types.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2238,9 +2238,9 @@ end
22382238
is_array_shape(sh::ShapeT) = sh isa Unknown || _ndims_from_shape(sh) > 0
22392239
function _multiplied_shape(shapes)
22402240
first_arr = findfirst(is_array_shape, shapes)
2241-
first_arr === nothing && return ShapeVecT(), first_arr
2241+
first_arr === nothing && return ShapeVecT(), first_arr, nothing
22422242
last_arr::Int = findlast(is_array_shape, shapes)
2243-
first_arr == last_arr && return shapes[first_arr], first_arr
2243+
first_arr == last_arr && return shapes[first_arr], first_arr, last_arr
22442244

22452245
sh1::ShapeT = shapes[first_arr]
22462246
shend::ShapeT = shapes[last_arr]
@@ -2275,7 +2275,7 @@ function _multiplied_shape(shapes)
22752275
cur_shape = sh
22762276
end
22772277

2278-
return result, first_arr
2278+
return result, first_arr, last_arr
22792279
end
22802280

22812281
function promote_shape(::typeof(*), shs::ShapeT...)
@@ -2284,19 +2284,22 @@ end
22842284

22852285
const AdjointOrTranspose = Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}
22862286

2287-
function _check_adjoint_or_transpose(terms, result::ShapeT, first_arr::Union{Int, Nothing})
2288-
@nospecialize first_arr result
2287+
function _check_adjoint_or_transpose(terms, result::ShapeT, first_arr::Union{Int, Nothing}, last_arr::Union{Int, Nothing})
2288+
@nospecialize first_arr result last_arr
22892289
first_arr === nothing && return result
2290+
last_arr = last_arr::Int
2291+
first_arr == last_arr && return result
22902292
farr = terms[first_arr]
2291-
if result isa ShapeVecT && length(result) <= 2 && all(==(1) length, result) && (farr isa AdjointOrTranspose || iscall(farr) && (operation(farr) === adjoint || operation(farr) === transpose))
2293+
ndlarr = ndims(terms[last_arr])
2294+
if result isa ShapeVecT && length(result) <= 2 && all(isone length, result) && (farr isa AdjointOrTranspose || iscall(farr) && (operation(farr) === adjoint || operation(farr) === transpose)) && ndlarr < 2
22922295
return ShapeVecT()
22932296
end
22942297
return result
22952298
end
22962299

22972300
function _multiplied_terms_shape(terms::Tuple)
2298-
result, first_arr = _multiplied_shape(ntuple(shape Base.Fix1(getindex, terms), Val(length(terms))))
2299-
return _check_adjoint_or_transpose(terms, result, first_arr)
2301+
result, first_arr, last_arr = _multiplied_shape(ntuple(shape Base.Fix1(getindex, terms), Val(length(terms))))
2302+
return _check_adjoint_or_transpose(terms, result, first_arr, last_arr)
23002303
end
23012304

23022305
function _multiplied_terms_shape(terms)
@@ -2305,8 +2308,8 @@ function _multiplied_terms_shape(terms)
23052308
for t in terms
23062309
push!(shapes, shape(t))
23072310
end
2308-
result, first_arr = _multiplied_shape(shapes)
2309-
return _check_adjoint_or_transpose(terms, result, first_arr)
2311+
result, first_arr, last_arr = _multiplied_shape(shapes)
2312+
return _check_adjoint_or_transpose(terms, result, first_arr, last_arr)
23102313
end
23112314

23122315
function _split_arrterm_scalar_coeff(::Type{T}, ex::BasicSymbolic{T}) where {T}

0 commit comments

Comments
 (0)