Skip to content

Commit ec6f9bb

Browse files
committed
Add lie costs
1 parent 935a270 commit ec6f9bb

File tree

1 file changed

+363
-1
lines changed

1 file changed

+363
-1
lines changed

src/lie_costs.jl

Lines changed: 363 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,366 @@ function Rotations.jacobian(::Type{RodriguesParam}, q::UnitQuaternion)
218218
]
219219
end
220220

221-
Rotations.jacobian(::Type{R}, q::R) where R = I
221+
Rotations.jacobian(::Type{R}, q::R) where R = I
222+
223+
224+
225+
############################################################################################
226+
# QUADRATIC QUATERNION COST FUNCTION
227+
############################################################################################
228+
"""
229+
DiagonalQuatCost
230+
231+
Quadratic cost function for states that includes a 3D rotation, that penalizes deviations
232+
from a provided 3D rotation, represented as a Unit Quaternion.
233+
234+
The cost function penalizes geodesic distance between unit quaternions:
235+
236+
``\\min 1 \\pm q_d^T q```
237+
238+
We've found this perform better than penalizing a quadratic on the quaternion error
239+
state ([`ErrorQuadratic`](@ref)). This cost should still be considered experimental.
240+
241+
# Constructors
242+
* `DiagonalQuatCost(Q::Diagonal, R::Diagonal, q, r, c, w, q_ref, q_ind; terminal)`
243+
* `DiagonalQuatCost(Q::Diagonal, R::Diagonal; q, r, c, w, q_ref, q_ind, terminal)`
244+
245+
where `q_ref` is the reference quaternion (provided as a `SVector{4}`), and
246+
`q_ind::SVector{4,Int}` provides the indices of the quaternion in the state vector
247+
(default = `SA[4,5,6,7]`). Note that `Q` and `q` are the size of the full state,
248+
so `Q.diag[q_ind]` and `q[qind]` should typically be zero.
249+
"""
250+
struct DiagonalQuatCost{N,M,T,N4} <: QuadraticCostFunction{N,M,T}
251+
Q::Diagonal{T,SVector{N,T}}
252+
R::Diagonal{T,SVector{M,T}}
253+
q::SVector{N,T}
254+
r::SVector{M,T}
255+
c::T
256+
w::T
257+
q_ref::SVector{4,T}
258+
q_ind::SVector{4,Int}
259+
Iq::SMatrix{N,4,T,N4}
260+
terminal::Bool
261+
function DiagonalQuatCost(Q::Diagonal{T,SVector{N,T}}, R::Diagonal{T,SVector{M,T}},
262+
q::SVector{N,T}, r::SVector{M,T}, c::T, w::T,
263+
q_ref::SVector{4,T}, q_ind::SVector{4,Int}; terminal::Bool=false) where {T,N,M}
264+
Iq = @MMatrix zeros(N,4)
265+
for i = 1:4
266+
Iq[q_ind[i],i] = 1
267+
end
268+
Iq = SMatrix{N,4}(Iq)
269+
return new{N,M,T,N*4}(Q, R, q, r, c, w, q_ref, q_ind, Iq, terminal)
270+
end
271+
end
272+
273+
state_dim(::DiagonalQuatCost{N,M,T}) where {T,N,M} = N
274+
control_dim(::DiagonalQuatCost{N,M,T}) where {T,N,M} = M
275+
is_blockdiag(::DiagonalQuatCost) = true
276+
is_diag(::DiagonalQuatCost) = true
277+
278+
function DiagonalQuatCost(Q::Diagonal{T,SVector{N,T}}, R::Diagonal{T,SVector{M,T}};
279+
q=(@SVector zeros(N)), r=(@SVector zeros(M)), c=zero(T), w=one(T),
280+
q_ref=(@SVector [1.0,0,0,0]), q_ind=(@SVector [4,5,6,7])) where {T,N,M}
281+
DiagonalQuatCost(Q, R, q, r, c, q_ref, q_ind)
282+
end
283+
284+
function stage_cost(cost::DiagonalQuatCost, x::SVector, u::SVector)
285+
stage_cost(cost, x) + 0.5*u'cost.R*u + cost.r'u
286+
end
287+
288+
function stage_cost(cost::DiagonalQuatCost, x::SVector)
289+
J = 0.5*x'cost.Q*x + cost.q'x + cost.c
290+
q = x[cost.q_ind]
291+
dq = cost.q_ref'q
292+
J += cost.w*min(1+dq, 1-dq)
293+
end
294+
295+
function gradient!(E::QuadraticCostFunction, cost::DiagonalQuatCost{T,N,M},
296+
x::SVector) where {T,N,M}
297+
Qx = cost.Q*x + cost.q
298+
q = x[cost.q_ind]
299+
dq = cost.q_ref'q
300+
if dq < 0
301+
Qx += cost.w*cost.Iq*cost.q_ref
302+
else
303+
Qx -= cost.w*cost.Iq*cost.q_ref
304+
end
305+
E.q .= Qx
306+
return false
307+
end
308+
309+
"""
310+
QuatLQRCost(Q, R, xf, [uf; w, quat_ind])
311+
312+
Defines a cost function that uses a quadratic penalty on deviations from a reference state,
313+
including a quaratic penalty on the geodesic distance between a quaternion and a
314+
reference quaternion. See [`DiagonalQuatCost`](@ref).
315+
"""
316+
function QuatLQRCost(Q::Diagonal{T,SVector{N,T}}, R::Diagonal{T,SVector{M,T}}, xf,
317+
uf=(@SVector zeros(M)); w=one(T), quat_ind=(@SVector [4,5,6,7])) where {T,N,M}
318+
r = -R*uf
319+
q = -Q*xf
320+
c = 0.5*xf'Q*xf + 0.5*uf'R*uf
321+
q_ref = xf[quat_ind]
322+
return DiagonalQuatCost(Q, R, q, r, c, w, q_ref, quat_ind)
323+
end
324+
325+
function change_dimension(cost::DiagonalQuatCost, n, m, ix, iu)
326+
Qd = zeros(n)
327+
Rd = zeros(m)
328+
q = zeros(n)
329+
r = zeros(m)
330+
Qd[ix] = diag(cost.Q)
331+
Rd[iu] = diag(cost.R)
332+
q[ix] = cost.q
333+
r[iu] = cost.r
334+
qind = (1:n)[ix[cost.q_ind]]
335+
DiagonalQuatCost(Diagonal(SVector{n}(Qd)), Diagonal(SVector{m}(Rd)),
336+
SVector{n}(q), SVector{m}(r), cost.c, cost.w, cost.q_ref, qind)
337+
end
338+
339+
function (+)(cost1::DiagonalQuatCost, cost2::QuadraticCostFunction)
340+
@assert state_dim(cost1) == state_dim(cost2)
341+
@assert control_dim(cost1) == control_dim(cost2)
342+
is_diag(cost2) || @assert norm(cost2.H) 0
343+
DiagonalQuatCost(cost1.Q + cost2.Q, cost1.R + cost2.R,
344+
cost1.q + cost2.q, cost1.r + cost2.r, cost1.c + cost2.c,
345+
cost1.w, cost1.q_ref, cost1.q_ind)
346+
end
347+
348+
(+)(cost1::QuadraticCostFunction, cost2::DiagonalQuatCost) = cost2 + cost1
349+
350+
function Base.copy(c::DiagonalQuatCost)
351+
DiagonalQuatCost(c.Q, c.R, c.q, c.r, c.c, c.w, c.q_ref, c.q_ind)
352+
end
353+
354+
355+
############################################################################################
356+
# Error Quadratic
357+
############################################################################################
358+
359+
"""
360+
ErrorQuadratic{Rot,N,M}
361+
362+
Cost function that uses a quadratic penalty on the error state, for a state that includes
363+
a single 3D rotation.
364+
"""
365+
struct ErrorQuadratic{Rot,N,M} <: CostFunction
366+
model::RD.RigidBody{Rot}
367+
Q::Diagonal{Float64,SVector{12,Float64}}
368+
R::Diagonal{Float64,SVector{M,Float64}}
369+
r::SVector{M,Float64}
370+
c::Float64
371+
x_ref::SVector{N,Float64}
372+
q_ind::SVector{4,Int}
373+
end
374+
function Base.copy(c::ErrorQuadratic)
375+
ErrorQuadratic(c.model, c.Q, c.R, c.r, c.c, c.x_ref, c.q_ind)
376+
end
377+
378+
state_dim(::ErrorQuadratic{Rot,N,M}) where {Rot,N,M} = N
379+
control_dim(::ErrorQuadratic{Rot,N,M}) where {Rot,N,M} = M
380+
381+
function ErrorQuadratic(model::RD.RigidBody{Rot}, Q::Diagonal{T,<:SVector{N0}},
382+
R::Diagonal{T,<:SVector{M}},
383+
x_ref::SVector{N},
384+
u_ref=(@SVector zeros(T,M));
385+
r=(@SVector zeros(T,M)),
386+
c=zero(T),
387+
q_ind=(@SVector [4,5,6,7])
388+
) where {T,N,N0,M,Rot}
389+
if Rot <: UnitQuaternion && N0 == N
390+
Qd = deleteat(Q.diag, 4)
391+
Q = Diagonal(Qd)
392+
end
393+
r += -R*u_ref
394+
c += 0.5*u_ref'R*u_ref
395+
return ErrorQuadratic{Rot,N,M}(model, Q, R, r, c, x_ref, q_ind)
396+
end
397+
398+
399+
function stage_cost(cost::ErrorQuadratic, x::SVector)
400+
dx = RD.state_diff(cost.model, x, cost.x_ref, Rotations.ExponentialMap())
401+
return 0.5*dx'cost.Q*dx + cost.c
402+
end
403+
404+
function stage_cost(cost::ErrorQuadratic, x::SVector, u::SVector)
405+
stage_cost(cost, x) + 0.5*u'cost.R*u + cost.r'u
406+
end
407+
408+
409+
function gradient!(E::QuadraticCostFunction, cost::ErrorQuadratic, x)
410+
f(x) = stage_cost(cost, x)
411+
ForwardDiff.gradient!(E.q, f, x)
412+
return false
413+
414+
model = cost.model
415+
Q = cost.Q
416+
q = RD.orientation(model, x)
417+
q_ref = RD.orientation(model, cost.x_ref)
418+
dq = Rotations.params(q_ref \ q)
419+
err = RD.state_diff(model, x, cost.x_ref)
420+
dx = @SVector [err[1], err[2], err[3],
421+
dq[1], dq[2], dq[3], dq[4],
422+
err[7], err[8], err[9],
423+
err[10], err[11], err[12]]
424+
# G = state_diff_jacobian(model, dx) # n × dn
425+
426+
# Gradient
427+
dmap = inverse_map_jacobian(model, dx) # dn × n
428+
# Qx = G'dmap'Q*err
429+
Qx = dmap'Q*err
430+
E.q = Qx
431+
return false
432+
end
433+
function gradient!(E::QuadraticCostFunction, cost::ErrorQuadratic, x, u)
434+
gradient!(E, cost, x)
435+
Qu = cost.R*u
436+
E.r .= Qu
437+
return false
438+
end
439+
440+
function hessian!(E::QuadraticCostFunction, cost::ErrorQuadratic, x)
441+
f(x) = stage_cost(cost, x)
442+
ForwardDiff.hessian!(E.Q, f, x)
443+
return false
444+
445+
model = cost.model
446+
Q = cost.Q
447+
q = RD.orientation(model, x)
448+
q_ref = RD.orientation(model, cost.x_ref)
449+
dq = Rotations.params(q_ref\q)
450+
err = RD.state_diff(model, x, cost.x_ref)
451+
dx = @SVector [err[1], err[2], err[3],
452+
dq[1], dq[2], dq[3], dq[4],
453+
err[7], err[8], err[9],
454+
err[10], err[11], err[12]]
455+
# G = state_diff_jacobian(model, dx) # n × dn
456+
457+
# Gradient
458+
dmap = inverse_map_jacobian(model, dx) # dn × n
459+
460+
# Hessian
461+
∇jac = inverse_map_∇jacobian(model, dx, Q*err)
462+
# Qxx = G'dmap'Q*dmap*G + G'∇jac*G + ∇²differential(model, x, dmap'Q*err)
463+
Qxx = dmap'Q*dmap + ∇jac #+ ∇²differential(model, x, dmap'Q*err)
464+
E.Q = Qxx
465+
E.H .*= 0
466+
return false
467+
end
468+
469+
function hessian!(E::QuadraticCostFunction, cost::ErrorQuadratic, x, u)
470+
hessian!(E, cost, x)
471+
E.R .= cost.R
472+
return false
473+
end
474+
475+
476+
function change_dimension(cost::ErrorQuadratic, n, m)
477+
n0,m0 = state_dim(cost), control_dim(cost)
478+
Q_diag = diag(cost.Q)
479+
R_diag = diag(cost.R)
480+
r = cost.r
481+
if n0 != n
482+
dn = n - n0 # assumes n > n0
483+
pad = @SVector zeros(dn) # assume the new states don't have quaternions
484+
Q_diag = [Q_diag; pad]
485+
end
486+
if m0 != m
487+
dm = m - m0 # assumes m > m0
488+
pad = @SVector zeros(dm)
489+
R_diag = [R_diag; pad]
490+
r = [r; pad]
491+
end
492+
ErrorQuadratic(cost.model, Diagonal(Q_diag), Diagonal(R_diag), r, cost.c,
493+
cost.x_ref, cost.q_ind)
494+
end
495+
496+
function (+)(cost1::ErrorQuadratic, cost2::QuadraticCost)
497+
@assert control_dim(cost1) == control_dim(cost2)
498+
@assert norm(cost2.H) 0
499+
@assert norm(cost2.q) 0
500+
if state_dim(cost2) == 13
501+
rm_quat = @SVector [1,2,3,4,5,6,8,9,10,11,12,13]
502+
Q2 = Diagonal(diag(cost2.Q)[rm_quat])
503+
else
504+
Q2 = cost2.Q
505+
end
506+
ErrorQuadratic(cost1.model, cost1.Q + Q2, cost1.R + cost2.R,
507+
cost1.r + cost2.r, cost1.c + cost2.c,
508+
cost1.x_ref, cost1.q_ind)
509+
end
510+
511+
(+)(cost1::QuadraticCost, cost2::ErrorQuadratic) = cost2 + cost1
512+
513+
@generated function RD.state_diff_jacobian(model::RD.RigidBody{<:UnitQuaternion},
514+
x0::SVector{N,T}, errmap::D=Rotations.CayleyMap()) where {N,T,D}
515+
if D <: IdentityMap
516+
:(I)
517+
else
518+
quote
519+
q0 = RD.orientation(model, x0)
520+
# G = TrajectoryOptimization.∇differential(q0)
521+
G = Rotations.∇differential(q0)
522+
I1 = @SMatrix [1 0 0 0 0 0 0 0 0 0 0 0;
523+
0 1 0 0 0 0 0 0 0 0 0 0;
524+
0 0 1 0 0 0 0 0 0 0 0 0;
525+
0 0 0 G[1] G[5] G[ 9] 0 0 0 0 0 0;
526+
0 0 0 G[2] G[6] G[10] 0 0 0 0 0 0;
527+
0 0 0 G[3] G[7] G[11] 0 0 0 0 0 0;
528+
0 0 0 G[4] G[8] G[12] 0 0 0 0 0 0;
529+
0 0 0 0 0 0 1 0 0 0 0 0;
530+
0 0 0 0 0 0 0 1 0 0 0 0;
531+
0 0 0 0 0 0 0 0 1 0 0 0;
532+
0 0 0 0 0 0 0 0 0 1 0 0;
533+
0 0 0 0 0 0 0 0 0 0 1 0;
534+
0 0 0 0 0 0 0 0 0 0 0 1.]
535+
end
536+
end
537+
end
538+
function inverse_map_jacobian(model::RD.RigidBody{<:UnitQuaternion},
539+
x::SVector, errmap=Rotations.CayleyMap())
540+
q = RD.orientation(model, x)
541+
# G = TrajectoryOptimization.inverse_map_jacobian(q)
542+
G = Rotations.jacobian(inv(errmap), q)
543+
return @SMatrix [
544+
1 0 0 0 0 0 0 0 0 0 0 0 0;
545+
0 1 0 0 0 0 0 0 0 0 0 0 0;
546+
0 0 1 0 0 0 0 0 0 0 0 0 0;
547+
0 0 0 G[1] G[4] G[7] G[10] 0 0 0 0 0 0;
548+
0 0 0 G[2] G[5] G[8] G[11] 0 0 0 0 0 0;
549+
0 0 0 G[3] G[6] G[9] G[12] 0 0 0 0 0 0;
550+
0 0 0 0 0 0 0 1 0 0 0 0 0;
551+
0 0 0 0 0 0 0 0 1 0 0 0 0;
552+
0 0 0 0 0 0 0 0 0 1 0 0 0;
553+
0 0 0 0 0 0 0 0 0 0 1 0 0;
554+
0 0 0 0 0 0 0 0 0 0 0 1 0;
555+
0 0 0 0 0 0 0 0 0 0 0 0 1;
556+
]
557+
end
558+
559+
function inverse_map_∇jacobian(model::RD.RigidBody{<:UnitQuaternion},
560+
x::SVector, b::SVector, errmap=Rotations.CayleyMap())
561+
q = RD.orientation(model, x)
562+
bq = @SVector [b[4], b[5], b[6]]
563+
# ∇G = TrajectoryOptimization.inverse_map_∇jacobian(q, bq)
564+
∇G = Rotations.∇jacobian(inv(errmap), q, bq)
565+
return @SMatrix [
566+
0 0 0 0 0 0 0 0 0 0 0 0 0;
567+
0 0 0 0 0 0 0 0 0 0 0 0 0;
568+
0 0 0 0 0 0 0 0 0 0 0 0 0;
569+
0 0 0 ∇G[1] ∇G[5] ∇G[ 9] ∇G[13] 0 0 0 0 0 0;
570+
0 0 0 ∇G[2] ∇G[6] ∇G[10] ∇G[14] 0 0 0 0 0 0;
571+
0 0 0 ∇G[3] ∇G[7] ∇G[11] ∇G[15] 0 0 0 0 0 0;
572+
0 0 0 ∇G[4] ∇G[8] ∇G[12] ∇G[16] 0 0 0 0 0 0;
573+
0 0 0 0 0 0 0 0 0 0 0 0 0;
574+
0 0 0 0 0 0 0 0 0 0 0 0 0;
575+
0 0 0 0 0 0 0 0 0 0 0 0 0;
576+
0 0 0 0 0 0 0 0 0 0 0 0 0;
577+
0 0 0 0 0 0 0 0 0 0 0 0 0;
578+
0 0 0 0 0 0 0 0 0 0 0 0 0;
579+
]
580+
end
581+
582+
583+

0 commit comments

Comments
 (0)