Skip to content

Commit 35f270c

Browse files
authored
Patch GLM (#41)
* ensure same eltype for X and y (JuliaStats/GLM.jl#369)
* correctly use QR
* use :cholesky and dropcollinear like in GLM
* test method :cholesky
* create own DensePredQR
* cleanup
1 parent 70b6ecb commit 35f270c

11 files changed

+216
-86
lines changed

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
1010
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
13+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1314
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1415
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
1516
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/RobustModels.jl

+11-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module RobustModels
22

3+
using Pkg: Pkg
4+
35
include("compat.jl")
46

57
# Use README as the docstring of the module and doctest README
@@ -10,10 +12,10 @@ end RobustModels
1012

1113
# Import with `using` to use the module names to prefix the methods
1214
# that are extended from these modules
13-
using GLM
14-
using StatsAPI
15-
using StatsBase
16-
using StatsModels
15+
using GLM: GLM
16+
using StatsAPI: StatsAPI
17+
using StatsBase: StatsBase
18+
using StatsModels: StatsModels
1719

1820
## Import to implement new methods
1921
import Base: show, broadcastable, convert, ==
@@ -73,18 +75,21 @@ using LinearAlgebra:
7375
inv,
7476
diag,
7577
diagm,
78+
rank,
7679
ldiv!
7780

7881
using Random: AbstractRNG, GLOBAL_RNG
7982
using Printf: @printf, @sprintf
80-
using GLM: FPVector, lm, SparsePredChol, DensePredChol, DensePredQR
83+
using GLM: FPVector, lm, SparsePredChol, DensePredChol
8184
using StatsBase:
8285
AbstractWeights, CoefTable, ConvergenceException, median, mad, mad_constant, sample
8386
using StatsModels:
8487
@delegate,
8588
@formula,
89+
formula,
8690
RegressionModel,
8791
FormulaTerm,
92+
InterceptTerm,
8893
ModelFrame,
8994
modelcols,
9095
apply_schema,
@@ -238,6 +243,7 @@ abstract type AbstractRegularizedPred{T} end
238243
# Wrap estimators and loss functions in a `Ref` so that broadcasting
# treats each of them as a single element.
Base.broadcastable(m::AbstractEstimator) = Ref(m)
Base.broadcastable(m::LossFunction) = Ref(m)
240245

246+
241247
include("tools.jl")
242248
include("losses.jl")
243249
include("estimators.jl")

src/compat.jl

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
1-
using LinearAlgebra: cholesky!
1+
using LinearAlgebra: cholesky!, qr!
2+
3+
"""
    get_pkg_version(m::Module)

Return the `VersionNumber` read from the `version` entry of the `Project.toml`
located in the package directory of module `m`.

# Throws

- `ArgumentError` if `pkgdir(m)` is `nothing`, i.e. `m` was not loaded from a package.
- `KeyError` if the `Project.toml` has no `version` entry.
"""
function get_pkg_version(m::Module)
    dir = pkgdir(m)
    # `pkgdir` returns `nothing` for modules such as `Main`; fail with a clear
    # message instead of a `MethodError` raised from `joinpath(nothing, ...)`.
    if dir === nothing
        throw(ArgumentError("module $(m) was not loaded from a package directory"))
    end
    toml = Pkg.TOML.parsefile(joinpath(dir, "Project.toml"))
    return VersionNumber(toml["version"])
end
7+
28

39
## Compatibility layers
410

511
# https://github.com/JuliaStats/GLM.jl/pull/459
612
# Select the pivoting argument of `cholesky!` according to the running Julia
# version: `RowMaximum()` from 1.8.0-DEV.1139 onwards, `Val(true)` before that.
@static if VERSION >= v"1.8.0-DEV.1139"
    using LinearAlgebra: RowMaximum
    function pivoted_cholesky!(A; kwargs...)
        return cholesky!(A, RowMaximum(); kwargs...)
    end
else
    function pivoted_cholesky!(A; kwargs...)
        return cholesky!(A, Val(true); kwargs...)
    end
end
18+
19+
# Select the pivoting argument of `qr!` according to the running Julia
# version: `ColumnNorm()` from 1.7.0 onwards, `Val(true)` before that.
@static if VERSION >= v"1.7.0"
    using LinearAlgebra: ColumnNorm
    function pivoted_qr!(A; kwargs...)
        return qr!(A, ColumnNorm(); kwargs...)
    end
else
    function pivoted_qr!(A; kwargs...)
        return qr!(A, Val(true); kwargs...)
    end
end

src/linpred.jl

+130-46
Original file line numberDiff line numberDiff line change
@@ -50,45 +50,143 @@ leverage_weights(p::LinPred, wt::AbstractVector) = sqrt.(1 .- leverage(p, wt))
5050
# beta0
5151
#end
5252

53-
"""
54-
DensePredQR
5553

56-
A `LinPred` type with a dense, unpivoted QR decomposition of `X`
54+
##########################################
55+
###### DensePredQR
56+
##########################################
5757

58-
# Members
58+
@static if get_pkg_version(GLM) < v"1.9"
59+
@warn(
60+
"GLM.DensePredQR(X::AbstractMatrix, pivot::Bool=true) is not defined, " *
61+
"fallback to unpivoted RobustModels.DensePredQR definition. " *
62+
"To use pivoted QR, GLM version should be greater than or equal to v1.9."
63+
)
5964

60-
- `X`: Model matrix of size `n` × `p` with `n ≥ p`. Should be full column rank.
61-
- `beta0`: base coefficient vector of length `p`
62-
- `delbeta`: increment to coefficient vector, also of length `p`
63-
- `scratchbeta`: scratch vector of length `p`, used in `linpred!` method
64-
- `qr`: a `QRCompactWY` object created from `X`, with optional row weights.
65-
"""
66-
DensePredQR
67-
68-
PRED_QR_WARNING_ISSUED = false
69-
70-
function qrpred(X::AbstractMatrix, pivot::Bool=false)
71-
try
72-
return DensePredCG(Matrix(X), pivot)
73-
catch e
74-
if e isa MethodError
75-
# GLM.DensePredCG(X::AbstractMatrix, pivot::Bool) is not defined
76-
global PRED_QR_WARNING_ISSUED
77-
if !PRED_QR_WARNING_ISSUED
78-
@warn(
79-
"GLM.DensePredCG(X::AbstractMatrix, pivot::Bool) is not defined, " *
80-
"fallback to unpivoted QR. GLM version should be >= 1.9."
81-
)
82-
PRED_QR_WARNING_ISSUED = true
65+
using LinearAlgebra: QRCompactWY, QRPivoted, Diagonal, qr!, qr
66+
67+
"""
68+
DensePredQR
69+
70+
A `LinPred` type with a dense QR decomposition of `X`
71+
72+
# Members
73+
74+
- `X`: Model matrix of size `n` × `p` with `n ≥ p`. Should be full column rank.
75+
- `beta0`: base coefficient vector of length `p`
76+
- `delbeta`: increment to coefficient vector, also of length `p`
77+
- `scratchbeta`: scratch vector of length `p`, used in `linpred!` method
78+
- `qr`: a `QRCompactWY` object created from `X`, with optional row weights.
79+
- `scratchm1`: scratch Matrix{T} of the same size as `X`
80+
- `scratchm2`: scratch Matrix{T} of the same size as `X`
81+
- `scratchR`: scratch Matrix{T} of the same size as `qr.R`, a square matrix.
82+
"""
83+
mutable struct DensePredQR{T<:BlasReal,Q<:Union{QRCompactWY,QRPivoted}} <: DensePred
    X::Matrix{T}            # model matrix
    beta0::Vector{T}        # base coefficient vector
    delbeta::Vector{T}      # coefficient increment
    scratchbeta::Vector{T}  # scratch vector with one entry per column of X
    qr::Q                   # QR factorization, initially computed from X
    scratchm1::Matrix{T}    # scratch matrix with the size of X
    scratchm2::Matrix{T}    # scratch matrix with the size of X
    scratchR::Matrix{T}     # scratch matrix with the size of qr.R

    # Build the predictor from `X`, promoting the element type to a floating
    # point type `T`. The coefficient vectors are zero-initialized.
    #
    # `pivot` is accepted for signature compatibility but is ignored: this
    # fallback definition always computes the unpivoted factorization `qr(X)`.
    function DensePredQR(X::AbstractMatrix, pivot::Bool=false)
        _, p = size(X)
        T = typeof(float(zero(eltype(X))))

        # NOTE(review): the previous code had a comment announcing that the
        # adjoint of X is factorized when X has fewer rows than columns
        # ("so R is square"), but it called `qr(X)` in every case. The call
        # is kept as it was; confirm whether `qr(X')` was intended there.
        F = qr(X)

        return new{T,typeof(F)}(
            Matrix{T}(X),
            zeros(T, p),
            zeros(T, p),
            zeros(T, p),
            F,
            similar(X, T),
            similar(X, T),
            zeros(T, size(F.R)),
        )
    end
end
122+
123+
# GLM.DensePredQR(X::AbstractMatrix, pivot::Bool) is not defined
124+
# Fallback `qrpred`: wrap `Matrix(X)` in the local `DensePredQR`.
# The `pivot` argument is accepted but not forwarded.
qrpred(X::AbstractMatrix, pivot::Bool=false) = DensePredQR(Matrix(X))
127+
128+
# GLM.delbeta!(p::DensePredQR{T}, r::Vector{T}) is ill-defined
129+
# Unweighted update: store in `p.delbeta` the solution of the stored
# factorization against `r`. When `p.X` has at least as many rows as columns
# the factorization itself is used, otherwise its adjoint. Returns `p`.
function delbeta!(p::DensePredQR{T,<:QRCompactWY}, r::Vector{T}) where {T<:BlasReal}
    nrows, ncols = size(p.X)
    p.delbeta = nrows >= ncols ? (p.qr \ r) : (p.qr' \ r)
    return p
end
138+
139+
# GLM.delbeta!(p::DensePredQR{T}, r::Vector{T}, wt::Vector{T}) is not defined
140+
# Weighted update: recompute the QR factorization of the row-weighted model
# matrix and store the resulting coefficient increment in `p.delbeta`.
#
# Arguments
# - `p`: predictor; `p.qr`, `p.scratchm1`, `p.scratchm2`, `p.scratchR` and
#   `p.delbeta` are overwritten.
# - `r`: (unweighted) response/residual vector, one entry per row of `p.X`.
# - `wt`: weights, one entry per row of `p.X`.
#
# Returns `p`.
function delbeta!(
    p::DensePredQR{T,<:QRCompactWY}, r::Vector{T}, wt::Vector{T}
) where {T<:BlasReal}
    X = p.X
    n, m = size(X)
    sqrtW = Diagonal(sqrt.(wt))

    # scratchm1 = W½ X, written into the existing scratch buffer
    scratchm1 = p.scratchm1
    mul!(scratchm1, sqrtW, X)

    if n >= m
        # W½ X = Q R , with Q'Q = I
        # X'WX β = X'W r  =>  R'Q'QR β = X'W r
        #                 =>  β = R⁻¹ R⁻ᵀ X'W r
        qnr = p.qr = qr(scratchm1)
        Rinv = p.scratchR = inv(qnr.R)

        # scratchm2 = W X, so that transpose(scratchm2) * r = X'W r
        scratchm2 = p.scratchm2
        mul!(scratchm2, Diagonal(wt), X)
        mul!(p.delbeta, transpose(scratchm2), r)

        p.delbeta = Rinv * Rinv' * p.delbeta
    else
        # (W½ X)' = Q R , with Q'Q = I
        # W½ X β = W½ r  =>  R'Q' β = W½ r
        #                =>  β = Q . [R⁻ᵀ W½ r; 0]   (minimum-norm solution)
        qnrT = p.qr = qr(scratchm1')
        RTinv = p.scratchR = inv(qnrT.R)'

        # The right-hand side must carry the same W½ factor as the matrix,
        # otherwise the solution satisfies X β = W^(-½) r instead of X β = r.
        delbeta = zeros(T, length(p.delbeta))
        delbeta[1:n] .= RTinv * (sqrtW * r)
        lmul!(qnrT.Q, delbeta)
        p.delbeta = delbeta
    end
    return p
end
177+
178+
179+
## Use DensePredQR from GLM
180+
else
181+
using GLM: DensePredQR
182+
import GLM: qrpred
89183
end
90184

91185

186+
##########################################
187+
###### [Dense/Sparse]PredCG
188+
##########################################
189+
92190
"""
93191
DensePredCG
94192
@@ -109,20 +207,8 @@ mutable struct DensePredCG{T<:BlasReal} <: DensePred
109207
scratchbeta::Vector{T}
110208
scratchm1::Matrix{T}
111209
scratchr1::Vector{T}
112-
function DensePredCG{T}(X::Matrix{T}, beta0::Vector{T}) where {T}
113-
n, p = size(X)
114-
length(beta0) == p || throw(DimensionMismatch("length(β0) ≠ size(X,2)"))
115-
return new{T}(
116-
X,
117-
beta0,
118-
zeros(T, p),
119-
zeros(T, (p, p)),
120-
zeros(T, p),
121-
zeros(T, (n, p)),
122-
zeros(T, n),
123-
)
124-
end
125-
function DensePredCG{T}(X::Matrix{T}) where {T}
210+
211+
function DensePredCG(X::Matrix{T}) where {T<:BlasReal}
126212
n, p = size(X)
127213
return new{T}(
128214
X,
@@ -135,10 +221,8 @@ mutable struct DensePredCG{T<:BlasReal} <: DensePred
135221
)
136222
end
137223
end
138-
DensePredCG(X::Matrix, beta0::Vector) = DensePredCG{eltype(X)}(X, beta0)
139-
DensePredCG(X::Matrix{T}) where {T} = DensePredCG{T}(X, zeros(T, size(X, 2)))
140224
# Converting a `Matrix{T}` to a `DensePredCG{T}` delegates to the
# `DensePredCG(X)` constructor.
Base.convert(::Type{DensePredCG{T}}, X::Matrix{T}) where {T} = DensePredCG(X)
143227

144228
# Compatibility with cholpred(X, pivot)

src/regularizedpred.jl

+3
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ function postupdate_λ!(r::RidgePred)
154154
# Update the extended model matrix with the new value
155155
GG = r.sqrtλ * r.G
156156
@views r.pred.X[(n + 1):(n + m), :] .= GG
157+
158+
# Update other fields
159+
# TODO: update DensePredQR
157160
if isa(r.pred, DensePredChol)
158161
# Recompute the cholesky decomposition
159162
X = r.pred.X

0 commit comments

Comments
 (0)