@@ -2238,9 +2238,9 @@ end
22382238is_array_shape (sh:: ShapeT ) = sh isa Unknown || _ndims_from_shape (sh) > 0
22392239function _multiplied_shape (shapes)
22402240 first_arr = findfirst (is_array_shape, shapes)
2241- first_arr === nothing && return ShapeVecT (), first_arr
2241+ first_arr === nothing && return ShapeVecT (), first_arr, nothing
22422242 last_arr:: Int = findlast (is_array_shape, shapes)
2243- first_arr == last_arr && return shapes[first_arr], first_arr
2243+ first_arr == last_arr && return shapes[first_arr], first_arr, last_arr
22442244
22452245 sh1:: ShapeT = shapes[first_arr]
22462246 shend:: ShapeT = shapes[last_arr]
@@ -2275,7 +2275,7 @@ function _multiplied_shape(shapes)
22752275 cur_shape = sh
22762276 end
22772277
2278- return result, first_arr
2278+ return result, first_arr, last_arr
22792279end
22802280
22812281function promote_shape (:: typeof (* ), shs:: ShapeT... )
@@ -2284,19 +2284,22 @@ end
22842284
22852285const AdjointOrTranspose = Union{LinearAlgebra. Adjoint, LinearAlgebra. Transpose}
22862286
2287- function _check_adjoint_or_transpose (terms, result:: ShapeT , first_arr:: Union{Int, Nothing} )
2288- @nospecialize first_arr result
2287+ function _check_adjoint_or_transpose (terms, result:: ShapeT , first_arr:: Union{Int, Nothing} , last_arr :: Union{Int, Nothing} )
2288+ @nospecialize first_arr result last_arr
22892289 first_arr === nothing && return result
2290+ last_arr = last_arr:: Int
2291+ first_arr == last_arr && return result
22902292 farr = terms[first_arr]
2291- if result isa ShapeVecT && length (result) <= 2 && all (== (1 ) ∘ length, result) && (farr isa AdjointOrTranspose || iscall (farr) && (operation (farr) === adjoint || operation (farr) === transpose))
2293+ ndlarr = ndims (terms[last_arr])
2294+ if result isa ShapeVecT && length (result) <= 2 && all (isone ∘ length, result) && (farr isa AdjointOrTranspose || iscall (farr) && (operation (farr) === adjoint || operation (farr) === transpose)) && ndlarr < 2
22922295 return ShapeVecT ()
22932296 end
22942297 return result
22952298end
22962299
22972300function _multiplied_terms_shape (terms:: Tuple )
2298- result, first_arr = _multiplied_shape (ntuple (shape ∘ Base. Fix1 (getindex, terms), Val (length (terms))))
2299- return _check_adjoint_or_transpose (terms, result, first_arr)
2301+ result, first_arr, last_arr = _multiplied_shape (ntuple (shape ∘ Base. Fix1 (getindex, terms), Val (length (terms))))
2302+ return _check_adjoint_or_transpose (terms, result, first_arr, last_arr )
23002303end
23012304
23022305function _multiplied_terms_shape (terms)
@@ -2305,8 +2308,8 @@ function _multiplied_terms_shape(terms)
23052308 for t in terms
23062309 push! (shapes, shape (t))
23072310 end
2308- result, first_arr = _multiplied_shape (shapes)
2309- return _check_adjoint_or_transpose (terms, result, first_arr)
2311+ result, first_arr, last_arr = _multiplied_shape (shapes)
2312+ return _check_adjoint_or_transpose (terms, result, first_arr, last_arr )
23102313end
23112314
23122315function _split_arrterm_scalar_coeff (:: Type{T} , ex:: BasicSymbolic{T} ) where {T}
0 commit comments