Skip to content

Commit cd24703

Browse files
authored
Fix conic objective reverse diff (#237)
* Fix conic objective reverse diff * Fix
1 parent 794dd56 commit cd24703

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

src/ConicProgram/ConicProgram.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ end
314314
function MOI.get(model::Model, ::DiffOpt.ReverseObjectiveFunction)
315315
g = model.back_grad_cache.g
316316
πz = model.back_grad_cache.πz
317-
dc = DiffOpt.lazy_combination(-, πz, g, length(g))
317+
dc = DiffOpt.lazy_combination(-, πz, g, length(g), eachindex(model.x))
318318
return DiffOpt.VectorScalarAffineFunction(dc, 0.0)
319319
end
320320

src/diff_opt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,13 +270,13 @@ function lazy_combination(op::F, α, a, β, b) where {F<:Function}
270270
)
271271
end
272272

273-
function lazy_combination(op::F, α, a, β, b, I::UnitRange) where {F<:Function}
273+
function lazy_combination(op::F, α, a, β, b, I::AbstractUnitRange) where {F<:Function}
274274
return lazy_combination(op, α, view(a, I), β, view(b, I))
275275
end
276276
function lazy_combination(op::F, a, b, i::Integer, args::Vararg{Any,N}) where {F<:Function,N}
277277
return lazy_combination(op, a[i], b, b[i], a, args...)
278278
end
279-
function lazy_combination(op::F, a, b, i::UnitRange, I::UnitRange) where {F<:Function}
279+
function lazy_combination(op::F, a, b, i::AbstractUnitRange, I::AbstractUnitRange) where {F<:Function}
280280
return lazy_combination(op, view(a, i), b, view(b, i), a, I)
281281
end
282282

@@ -412,7 +412,7 @@ function _push_term(I::Vector, J::Vector, V::Vector, neg::Bool, r::Integer, term
412412
push!(J, term.variable.value)
413413
push!(V, neg ? -term.coefficient : term.coefficient)
414414
end
415-
function _push_term(I::Vector, J::Vector, V::Vector, neg::Bool, r::UnitRange, term::MOI.VectorAffineTerm)
415+
function _push_term(I::Vector, J::Vector, V::Vector, neg::Bool, r::AbstractUnitRange, term::MOI.VectorAffineTerm)
416416
_push_term(I, J, V, neg, r[term.output_index], term.scalar_term)
417417
end
418418

test/utils.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Check our results for QPs using the notations of [AK17].
2828
"""
2929
function qp_test(
3030
solver,
31+
diff_model,
3132
lt::Bool,
3233
set_zero::Bool,
3334
canonicalize::Bool;
@@ -79,6 +80,9 @@ function qp_test(
7980
atol = ATOL,
8081
rtol = RTOL,
8182
)
83+
if !all(iszero, Q) && diff_model == DiffOpt.ConicProgram.Model
84+
return # TODO https://github.com/jump-dev/DiffOpt.jl/pull/231
85+
end
8286
n = length(q)
8387
@assert n == LinearAlgebra.checksquare(Q)
8488
@assert n == size(A, 2)
@@ -142,6 +146,8 @@ function qp_test(
142146
end
143147
@_test(convert(Vector{Float64}, _λ), λ)
144148

149+
MOI.set(model, DiffOpt.ModelConstructor(), diff_model)
150+
145151
#dobjb = v' * (dQb / 2.0) * v + dqb' * v
146152
# TODO, it should .-
147153
#dleb = dGb * v .+ dhb
@@ -267,11 +273,16 @@ function qp_test(
267273
@test pprod pprod atol = ATOL rtol = RTOL
268274
end
269275

276+
function qp_test(solver, lt, set_zero, canonicalize; kws...)
277+
@testset "With $diff_model" for diff_model in [DiffOpt.ConicProgram.Model, DiffOpt.QuadraticProgram.Model]
278+
qp_test(solver, diff_model, lt, set_zero, canonicalize; kws...)
279+
end
280+
end
281+
270282
function qp_test(solver; kws...)
271283
@testset "With $(lt ? "LessThan" : "GreaterThan") constraints" for lt in [true, false]
272284
@testset "With$(set_zero ? "" : "out") setting zero tangents" for set_zero in [true, false]
273285
@testset "With$(canonicalize ? "" : "out") canonicalization" for canonicalize in [true, false]
274-
qp_test(solver, lt, set_zero, canonicalize; kws...)
275286
end
276287
end
277288
end

0 commit comments

Comments
 (0)