diff --git a/src/transformations/linear.jl b/src/transformations/linear.jl index 658302299..8315959a1 100644 --- a/src/transformations/linear.jl +++ b/src/transformations/linear.jl @@ -3,12 +3,15 @@ Base.@kwdef struct LinearAnalysis{I} dropcollinear::Bool=false interval::I=automatic level::Float64=0.95 + degree::Int=1 end -function add_intercept_column(x::AbstractVector{T}) where {T} - mat = similar(x, float(T), (length(x), 2)) +function make_design_matrix(x::AbstractVector{T}, d::Int) where {T} + mat = similar(x, float(T), (length(x), d+1)) fill!(view(mat, :, 1), 1) - copyto!(view(mat, :, 2), x) + for i in 1:d + copyto!(view(mat, :, i+1), x .^ i) + end return mat end @@ -20,9 +23,9 @@ function (l::LinearAnalysis)(input::ProcessedLayer) default_interval = length(weights) > 0 ? nothing : :confidence interval = l.interval === automatic ? default_interval : l.interval # FIXME: handle collinear case gracefully - lin_model = GLM.lm(add_intercept_column(x), y; wts=weights, l.dropcollinear) + lin_model = GLM.lm(make_design_matrix(x, l.degree), y; wts=weights, l.dropcollinear) x̂ = range(extrema(x)..., length=l.npoints) - pred = GLM.predict(lin_model, add_intercept_column(x̂); interval, l.level) + pred = GLM.predict(lin_model, make_design_matrix(x̂, l.degree); interval, l.level) return if !isnothing(interval) ŷ, lower, upper = pred (x̂, ŷ), (; lower, upper)