diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 90f5c03f7fcfb..bc1afc8aaaede 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -407,16 +407,16 @@ end const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal} const BiTri = Union{Bidiagonal,Tridiagonal} -@inline mul!(C::AbstractVector, A::BiTriSym, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::BiTriSym, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractVector, A::BiTriSym, B::AbstractVector, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::BiTriSym, B::AbstractMatrix, alpha::Number, beta::Number) 
= @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) function check_A_mul_B!_sizes(C, A, B) mA, nA = size(A) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index ec1bca909ce7b..04a75b391f2b2 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -585,14 +585,14 @@ for Tri in (:UpperTriangular, :LowerTriangular) iszero(α) && return _rmul_or_fill!(C, β) diag′ = iszero(β) ? 
nothing : diag(C) data = mul!(C.data, D, A.data, α, β) - $Tri(_setdiag!(data, MulAddMul(α, β), D.diag, diag′)) + $Tri(@stable_muladdmul _setdiag!(data, MulAddMul(α, β), D.diag, diag′)) end @eval @inline mul!(C::$Tri, A::$Tri, D::Diagonal, α::Number, β::Number) = $Tri(mul!(C.data, A.data, D, α, β)) @eval @inline function mul!(C::$Tri, A::$UTri, D::Diagonal, α::Number, β::Number) iszero(α) && return _rmul_or_fill!(C, β) diag′ = iszero(β) ? nothing : diag(C) data = mul!(C.data, A.data, D, α, β) - $Tri(_setdiag!(data, MulAddMul(α, β), D.diag, diag′)) + $Tri(@stable_muladdmul _setdiag!(data, MulAddMul(α, β), D.diag, diag′)) end end diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index c66f59838e8ba..ccea6800e28e6 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -49,6 +49,76 @@ end end end +""" + @stable_muladdmul + +Replaces a function call, that has a `MulAddMul(alpha, beta)` constructor as an +argument, with a branch over possible values of `isone(alpha)` and `iszero(beta)` +and constructs `MulAddMul{isone(alpha), iszero(beta)}` explicitly in each branch. + +For example, `f(x, y, MulAddMul(alpha, beta))` is transformed into +``` +if isone(alpha) + if iszero(beta) + f(x, y, MulAddMul{true, true, typeof(alpha), typeof(beta)}(alpha, beta)) + else + f(x, y, MulAddMul{true, false, typeof(alpha), typeof(beta)}(alpha, beta)) + end +else + if iszero(beta) + f(x, y, MulAddMul{false, true, typeof(alpha), typeof(beta)}(alpha, beta)) + else + f(x, y, MulAddMul{false, false, typeof(alpha), typeof(beta)}(alpha, beta)) + end +end +``` + +This avoids the type instability of the `MulAddMul(alpha, beta)` constructor, +which causes runtime dispatch in case alpha and beta are not constants. 
+""" +macro stable_muladdmul(expr) + expr.head == :call || throw(ArgumentError("Can only handle function calls.")) + for (i, e) in enumerate(expr.args) + e isa Expr || continue + if e.head == :call && e.args[1] == :MulAddMul && length(e.args) == 3 + e.args[2] isa Symbol || continue + e.args[3] isa Symbol || continue + local asym = e.args[2] + local bsym = e.args[3] + + local e_sub11 = copy(expr) + e_sub11.args[i] = :(MulAddMul{true, true, typeof($asym), typeof($bsym)}($asym, $bsym)) + + local e_sub10 = copy(expr) + e_sub10.args[i] = :(MulAddMul{true, false, typeof($asym), typeof($bsym)}($asym, $bsym)) + + local e_sub01 = copy(expr) + e_sub01.args[i] = :(MulAddMul{false, true, typeof($asym), typeof($bsym)}($asym, $bsym)) + + local e_sub00 = copy(expr) + e_sub00.args[i] = :(MulAddMul{false, false, typeof($asym), typeof($bsym)}($asym, $bsym)) + + local e_out = quote + if isone($asym) + if iszero($bsym) + $e_sub11 + else + $e_sub10 + end + else + if iszero($bsym) + $e_sub01 + else + $e_sub00 + end + end + end + return esc(e_out) + end + end + throw(ArgumentError("No valid MulAddMul expression found.")) +end + MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false) @inline (::MulAddMul{true})(x) = x diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 6d00b950525e6..43d0f991aa063 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -79,7 +79,7 @@ end @inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector, alpha::Number, beta::Number) = - generic_matvecmul!(y, 'N', A, x, MulAddMul(alpha, beta)) + @stable_muladdmul generic_matvecmul!(y, 'N', A, x, MulAddMul(alpha, beta)) function *(tA::Transpose{<:Any,<:StridedMatrix{T}}, x::StridedVector{S}) where {T<:BlasFloat,S} TS = promote_op(matprod, T, S) @@ -94,7 +94,7 @@ end gemv!(y, 'T', tA.parent, x, alpha, beta) @inline mul!(y::AbstractVector, tA::Transpose{<:Any,<:AbstractVecOrMat}, x::AbstractVector, alpha::Number, beta::Number) = 
- generic_matvecmul!(y, 'T', tA.parent, x, MulAddMul(alpha, beta)) + @stable_muladdmul generic_matvecmul!(y, 'T', tA.parent, x, MulAddMul(alpha, beta)) function *(adjA::Adjoint{<:Any,<:StridedMatrix{T}}, x::StridedVector{S}) where {T<:BlasFloat,S} TS = promote_op(matprod, T, S) @@ -113,7 +113,7 @@ end gemv!(y, 'C', adjA.parent, x, alpha, beta) @inline mul!(y::AbstractVector, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, x::AbstractVector, alpha::Number, beta::Number) = - generic_matvecmul!(y, 'C', adjA.parent, x, MulAddMul(alpha, beta)) + @stable_muladdmul generic_matvecmul!(y, 'C', adjA.parent, x, MulAddMul(alpha, beta)) # Vector-Matrix multiplication (*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')' @@ -158,7 +158,7 @@ end @inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, alpha::Number, beta::Number) where {T<:BlasFloat} - return gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) + return gemm_wrapper!(C, 'N', 'N', A, B, alpha, beta) end # Complex Matrix times real matrix: We use that it is generally faster to reinterpret the # first matrix as a real matrix and carry out real matrix matrix multiply @@ -373,9 +373,9 @@ lmul!(A, B) alpha::Number, beta::Number) where {T<:BlasFloat} A = tA.parent if A === B - return syrk_wrapper!(C, 'T', A, MulAddMul(alpha, beta)) + return syrk_wrapper!(C, 'T', A, alpha, beta) else - return gemm_wrapper!(C, 'T', 'N', A, B, MulAddMul(alpha, beta)) + return gemm_wrapper!(C, 'T', 'N', A, B, alpha, beta) end end @inline mul!(C::AbstractMatrix, tA::Transpose{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat, @@ -386,18 +386,18 @@ end alpha::Number, beta::Number) where {T<:BlasFloat} B = tB.parent if A === B - return syrk_wrapper!(C, 'N', A, MulAddMul(alpha, beta)) + return syrk_wrapper!(C, 'N', A, alpha, beta) else - return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta)) + return gemm_wrapper!(C, 'N', 'T', A, B, alpha, beta) end end # Complex matrix times (transposed) real matrix. 
Reinterpret the first matrix to real for efficiency. @inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, alpha::Number, beta::Number) where {T<:BlasReal} = - gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) + gemm_wrapper!(C, 'N', 'N', A, B, alpha, beta) @inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, tB::Transpose{<:Any,<:StridedVecOrMat{T}}, alpha::Number, beta::Number) where {T<:BlasReal} = - gemm_wrapper!(C, 'N', 'T', A, parent(tB), MulAddMul(alpha, beta)) + gemm_wrapper!(C, 'N', 'T', A, parent(tB), alpha, beta) # collapsing the following two defs with C::AbstractVecOrMat yields ambiguities @inline mul!(C::AbstractVector, A::AbstractVecOrMat, tB::Transpose{<:Any,<:AbstractVecOrMat}, @@ -409,14 +409,14 @@ end @inline mul!(C::StridedMatrix{T}, tA::Transpose{<:Any,<:StridedVecOrMat{T}}, tB::Transpose{<:Any,<:StridedVecOrMat{T}}, alpha::Number, beta::Number) where {T<:BlasFloat} = - gemm_wrapper!(C, 'T', 'T', tA.parent, tB.parent, MulAddMul(alpha, beta)) + gemm_wrapper!(C, 'T', 'T', tA.parent, tB.parent, alpha, beta) @inline mul!(C::AbstractMatrix, tA::Transpose{<:Any,<:AbstractVecOrMat}, tB::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = generic_matmatmul!(C, 'T', 'T', tA.parent, tB.parent, MulAddMul(alpha, beta)) @inline mul!(C::StridedMatrix{T}, tA::Transpose{<:Any,<:StridedVecOrMat{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}}, alpha::Number, beta::Number) where {T<:BlasFloat} = - gemm_wrapper!(C, 'T', 'C', tA.parent, adjB.parent, MulAddMul(alpha, beta)) + gemm_wrapper!(C, 'T', 'C', tA.parent, adjB.parent, alpha, beta) @inline mul!(C::AbstractMatrix, tA::Transpose{<:Any,<:AbstractVecOrMat}, tB::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = generic_matmatmul!(C, 'T', 'C', tA.parent, tB.parent, MulAddMul(alpha, beta)) @@ -428,9 +428,9 @@ end alpha::Number, beta::Number) where {T<:BlasComplex} A = adjA.parent if A === B - return 
herk_wrapper!(C, 'C', A, MulAddMul(alpha, beta)) + return herk_wrapper!(C, 'C', A, alpha, beta) else - return gemm_wrapper!(C, 'C', 'N', A, B, MulAddMul(alpha, beta)) + return gemm_wrapper!(C, 'C', 'N', A, B, alpha, beta) end end @inline mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, B::AbstractVecOrMat, @@ -444,9 +444,9 @@ end alpha::Number, beta::Number) where {T<:BlasComplex} B = adjB.parent if A === B - return herk_wrapper!(C, 'N', A, MulAddMul(alpha, beta)) + return herk_wrapper!(C, 'N', A, alpha, beta) else - return gemm_wrapper!(C, 'N', 'C', A, B, MulAddMul(alpha, beta)) + return gemm_wrapper!(C, 'N', 'C', A, B, alpha, beta) end end @inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, adjB::Adjoint{<:Any,<:AbstractVecOrMat}, @@ -455,14 +455,14 @@ end @inline mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}}, alpha::Number, beta::Number) where {T<:BlasFloat} = - gemm_wrapper!(C, 'C', 'C', adjA.parent, adjB.parent, MulAddMul(alpha, beta)) + gemm_wrapper!(C, 'C', 'C', adjA.parent, adjB.parent, alpha, beta) @inline mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, adjB::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = generic_matmatmul!(C, 'C', 'C', adjA.parent, adjB.parent, MulAddMul(alpha, beta)) @inline mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, tB::Transpose{<:Any,<:StridedVecOrMat{T}}, alpha::Number, beta::Number) where {T<:BlasFloat} = - gemm_wrapper!(C, 'C', 'T', adjA.parent, tB.parent, MulAddMul(alpha, beta)) + gemm_wrapper!(C, 'C', 'T', adjA.parent, tB.parent, alpha, beta) @inline mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, tB::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = generic_matmatmul!(C, 'C', 'T', adjA.parent, tB.parent, MulAddMul(alpha, beta)) @@ -502,7 +502,7 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x:: !iszero(stride(x, 1)) # We 
only check input's stride here. return BLAS.gemv!(tA, alpha, A, x, beta, y) else - return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + return @stable_muladdmul generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end end @@ -523,7 +523,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y)) return y else - return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + return @stable_muladdmul generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end end @@ -546,12 +546,12 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :]) return y else - return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + return @stable_muladdmul generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end end -function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, - _add = MulAddMul()) where {T<:BlasFloat} +@inline function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, + α::Number=true, β::Number=false) where {T<:BlasFloat} nC = checksquare(C) if tA == 'T' (nA, mA) = size(A,1), size(A,2) @@ -563,20 +563,25 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat if nC != mA throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)")) end - if mA == 0 || nA == 0 || iszero(_add.alpha) - return _rmul_or_fill!(C, _add.beta) + if mA == 0 || nA == 0 || iszero(α) + return _rmul_or_fill!(C, β) end if mA == 2 && nA == 2 - return matmul2x2!(C, tA, tAt, A, A, _add) + return @stable_muladdmul matmul2x2!(C, tA, tAt, A, A, MulAddMul(α, β)) end if mA == 3 && nA == 3 - return matmul3x3!(C, tA, tAt, A, A, _add) + return @stable_muladdmul matmul3x3!(C, tA, tAt, A, A, MulAddMul(α, β)) end + _syrk_wrapper!(C, tA, tAt, A, α, β) +end + +function _syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, tAt::AbstractChar, 
A::StridedVecOrMat{T}, + α::Number=true, β::Number=false) where {T<:BlasFloat} # BLAS.syrk! only updates symmetric C # alternatively, make non-zero β a show-stopper for BLAS.syrk! - if iszero(_add.beta) || issymmetric(C) - alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + if iszero(β) || issymmetric(C) + alpha, beta = promote(α, β, zero(T)) if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(C, 1) == 1 && @@ -585,11 +590,11 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U') end end - return gemm_wrapper!(C, tA, tAt, A, A, _add) + return gemm_wrapper!(C, tA, tAt, A, A, α, β) end -function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, - _add = MulAddMul()) where {T<:BlasReal} +@inline function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, + α::Number=true, β::Number=false) where {T<:BlasReal} nC = checksquare(C) if tA == 'C' (nA, mA) = size(A,1), size(A,2) @@ -601,21 +606,26 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA if nC != mA throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)")) end - if mA == 0 || nA == 0 || iszero(_add.alpha) - return _rmul_or_fill!(C, _add.beta) + if mA == 0 || nA == 0 || iszero(α) + return _rmul_or_fill!(C, β) end if mA == 2 && nA == 2 - return matmul2x2!(C, tA, tAt, A, A, _add) + return @stable_muladdmul matmul2x2!(C, tA, tAt, A, A, MulAddMul(α, β)) end if mA == 3 && nA == 3 - return matmul3x3!(C, tA, tAt, A, A, _add) + return @stable_muladdmul matmul3x3!(C, tA, tAt, A, A, MulAddMul(α, β)) end + _herk_wrapper!(C, tA, tAt, A, α, β) +end + +function _herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, 
tAt::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, + α::Number=true, β::Number=false) where {T<:BlasReal} # Result array does not need to be initialized as long as beta==0 # C = Matrix{T}(undef, mA, mA) - if iszero(_add.beta) || issymmetric(C) - alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + if iszero(β) || issymmetric(C) + alpha, beta = promote(α, β, zero(T)) if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(C, 1) == 1 && @@ -624,7 +634,7 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true) end end - return gemm_wrapper!(C, tA, tAt, A, A, _add) + return gemm_wrapper!(C, tA, tAt, A, A, α, β) end function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, @@ -636,9 +646,9 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, gemm_wrapper!(C, tA, tB, A, B) end -function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, - A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - _add = MulAddMul()) where {T<:BlasFloat} +@inline function gemm_wrapper!(C::StridedVecOrMat{S}, tA::AbstractChar, tB::AbstractChar, + A::StridedVecOrMat{S}, B::StridedVecOrMat{T}, + α::Number=true, β::Number=false) where {T<:BlasFloat,S<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) @@ -650,21 +660,28 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar throw(ArgumentError("output matrix must not be aliased with input matrix")) end - if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha) + if mA == 0 || nA == 0 || nB == 0 || iszero(α) if size(C) != (mA, nB) throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)")) end - return _rmul_or_fill!(C, _add.beta) + return _rmul_or_fill!(C, β) end if mA == 2 && nA == 2 && nB == 2 - return matmul2x2!(C, tA, tB, A, B, _add) + return @stable_muladdmul matmul2x2!(C, tA, tB, A, B, MulAddMul(α, β)) end 
if mA == 3 && nA == 3 && nB == 3 - return matmul3x3!(C, tA, tB, A, B, _add) + return @stable_muladdmul matmul3x3!(C, tA, tB, A, B, MulAddMul(α, β)) end - alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + # heavier variants (not inlined) + _gemm_wrapper!(C, tA, tB, A, B, α, β) +end + +@noinline function _gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, + A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, + α::Number=true, β::Number=false) where {T<:BlasFloat} + alpha, beta = promote(α, β, zero(T)) if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && @@ -673,38 +690,16 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar stride(C, 2) >= size(C, 1)) return BLAS.gemm!(tA, tB, alpha, A, B, beta, C) end - generic_matmatmul!(C, tA, tB, A, B, _add) + # Not using @stable_muladdmul here deliberately to create an inference + # barrier in case α, β are not compile-time constants. This avoids compiling + # four versions of generic_matmatmul!() in those cases. 
+ generic_matmatmul!(C, tA, tB, A, B, MulAddMul(α, β)) end -function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, +@noinline function _gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - _add = MulAddMul()) where {T<:BlasReal} - mA, nA = lapack_size(tA, A) - mB, nB = lapack_size(tB, B) - - if nA != mB - throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)")) - end - - if C === A || B === C - throw(ArgumentError("output matrix must not be aliased with input matrix")) - end - - if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha) - if size(C) != (mA, nB) - throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)")) - end - return _rmul_or_fill!(C, _add.beta) - end - - if mA == 2 && nA == 2 && nB == 2 - return matmul2x2!(C, tA, tB, A, B, _add) - end - if mA == 3 && nA == 3 && nB == 3 - return matmul3x3!(C, tA, tB, A, B, _add) - end - - alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + α::Number=true, β::Number=false) where {T<:BlasReal} + alpha, beta = promote(α, β, zero(T)) # Make-sure reinterpret-based optimization is BLAS-compatible. if (alpha isa Union{Bool,T} && @@ -716,7 +711,10 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) return C end - generic_matmatmul!(C, tA, tB, A, B, _add) + # Not using @stable_muladdmul here deliberately to create an inference + # barrier in case α, β are not compile-time constants. This avoids compiling + # four versions of generic_matmatmul!() in those cases. 
+ generic_matmatmul!(C, tA, tB, A, B, MulAddMul(α, β)) end # blas.jl defines matmul for floats; other integer and mixed precision diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index f96ca812ea0ec..8f107d880526f 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -519,7 +519,7 @@ end if alpha isa Union{Bool,T} && beta isa Union{Bool,T} return BLAS.symv!(A.uplo, alpha, A.data, x, beta, y) else - return generic_matvecmul!(y, 'N', A, x, MulAddMul(α, β)) + return @stable_muladdmul generic_matvecmul!(y, 'N', A, x, MulAddMul(α, β)) end end @inline function mul!(y::StridedVector{T}, A::Hermitian{T,<:StridedMatrix}, x::StridedVector{T}, @@ -528,7 +528,7 @@ end if alpha isa Union{Bool,T} && beta isa Union{Bool,T} return BLAS.symv!(A.uplo, alpha, A.data, x, beta, y) else - return generic_matvecmul!(y, 'N', A, x, MulAddMul(α, β)) + return @stable_muladdmul generic_matvecmul!(y, 'N', A, x, MulAddMul(α, β)) end end @inline function mul!(y::StridedVector{T}, A::Hermitian{T,<:StridedMatrix}, x::StridedVector{T}, @@ -537,7 +537,7 @@ end if alpha isa Union{Bool,T} && beta isa Union{Bool,T} return BLAS.hemv!(A.uplo, alpha, A.data, x, beta, y) else - return generic_matvecmul!(y, 'N', A, x, MulAddMul(α, β)) + return @stable_muladdmul generic_matvecmul!(y, 'N', A, x, MulAddMul(α, β)) end end ## Matmat diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 248fc048612c8..359c371785e56 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -462,7 +462,7 @@ for (Trig, UnitTrig) in Any[(UpperTriangular, UnitUpperTriangular), (UnitTrig, Number), (Number, UnitTrig)] @eval @inline mul!(A::$Trig, B::$TB, C::$TC, alpha::Number, beta::Number) = - _mul!(A, B, C, MulAddMul(alpha, beta)) + @stable_muladdmul _mul!(A, B, C, MulAddMul(alpha, beta)) end end diff --git a/stdlib/LinearAlgebra/test/matmul.jl 
b/stdlib/LinearAlgebra/test/matmul.jl index 0150c4c2efdc8..61f12b938a3c5 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -667,6 +667,19 @@ Transpose(x::RootInt) = x @test A * a == [56] end +@testset "#46865: mul!() with non-const alpha, beta" begin + f!(C,A,B,alphas,betas) = mul!(C, A, B, alphas[1], betas[1]) + alphas = [1.0] + betas = [0.5] + for d in [2,3,4] # test native small-matrix cases as well as BLAS + A = rand(d,d) + B = copy(A) + C = copy(A) + f!(C, A, B, alphas, betas) + @test (@allocated f!(C, A, B, alphas, betas)) == 0 + end +end + function test_mul(C, A, B) mul!(C, A, B) @test Array(A) * Array(B) ≈ C