Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 156 additions & 48 deletions src/Nonlinear/SymbolicAD/SymbolicAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,19 @@ function simplify!(f::MOI.ScalarAffineFunction{T}) where {T}
if isempty(f.terms)
return f.constant
end
if iszero(f.constant) && length(f.terms) == 1
term = only(f.terms)
if isone(term.coefficient)
return term.variable
end
end
return f
end

function simplify!(f::MOI.ScalarQuadraticFunction{T}) where {T}
f = MOI.Utilities.canonicalize!(f)
if isempty(f.quadratic_terms)
if isempty(f.affine_terms)
return f.constant
end
return MOI.ScalarAffineFunction(f.affine_terms, f.constant)
return simplify!(MOI.ScalarAffineFunction(f.affine_terms, f.constant))
end
return f
end
Expand Down Expand Up @@ -117,7 +120,7 @@ function simplify!(f::MOI.ScalarNonlinearFunction)
push!(result_stack, arg)
end
end
return _simplify_if_affine!(only(result_stack))
return _simplify_if_quadratic!(only(result_stack))
end

function simplify!(f::MOI.VectorAffineFunction{T}) where {T}
Expand All @@ -140,10 +143,12 @@ function simplify!(f::MOI.VectorQuadraticFunction{T}) where {T}
end

function simplify!(f::MOI.VectorNonlinearFunction)
for (i, row) in enumerate(f.rows)
f.rows[i] = simplify!(row)
rows = simplify!.(f.rows)
Y = reduce(promote_type, typeof.(rows))
if isconcretetype(Y)
return MOI.Utilities.vectorize(convert(Vector{Y}, rows))
end
return f
return MOI.VectorNonlinearFunction(rows)
end

# If a ScalarNonlinearFunction has only constant arguments, we should return
Expand Down Expand Up @@ -1507,100 +1512,203 @@ function MOI.eval_hessian_lagrangian(model::Evaluator, H, x, σ, μ)
end

# A default fallback for all types
_add_to_affine!(::Any, ::Any, ::T) where {T} = nothing
_add_to_quadratic!(::Any, ::Real, ::Any) = nothing
_add_to_quadratic!(::Any, ::Real, ::Any, ::Any) = nothing

# The creation of `ret::MOI.ScalarAffineFunction` has been delayed until now.
function _add_to_affine!(
::Nothing,
f::Union{Real,MOI.VariableIndex,MOI.ScalarAffineFunction},
# The creation of `ret::MOI.ScalarQuadraticFunction` has been delayed until now.
function _add_to_quadratic!(
::Missing,
scale::T,
) where {T}
return _add_to_affine!(zero(MOI.ScalarAffineFunction{T}), f, scale)
end

function _add_to_affine!(
ret::MOI.ScalarAffineFunction{T},
x::S,
f::Union{
Real,
MOI.VariableIndex,
MOI.ScalarAffineFunction,
MOI.ScalarQuadraticFunction,
}...,
) where {T<:Real}
return _add_to_quadratic!(zero(MOI.ScalarQuadraticFunction{T}), scale, f...)
end

function _add_to_quadratic!(
ret::MOI.ScalarQuadraticFunction{T},
scale::T,
) where {T,S<:Real}
x::S,
) where {T<:Real,S<:Real}
if promote_type(T, S) != T
return # We can't store `S` in `T`.
end
ret.constant += scale * convert(T, x)
return ret
end

function _add_to_affine!(
ret::MOI.ScalarAffineFunction{T},
x::MOI.VariableIndex,
function _add_to_quadratic!(
ret::MOI.ScalarQuadraticFunction{T},
scale::T,
) where {T}
push!(ret.terms, MOI.ScalarAffineTerm(scale, x))
f::MOI.ScalarAffineTerm{S},
) where {T<:Real,S}
@assert promote_type(T, S) == T
push!(
ret.affine_terms,
MOI.ScalarAffineTerm{T}(scale * f.coefficient, f.variable),
)
return ret
end

function _add_to_affine!(
ret::MOI.ScalarAffineFunction{T},
f::MOI.ScalarAffineFunction{S},
function _add_to_quadratic!(
ret::MOI.ScalarQuadraticFunction{T},
scale::T,
) where {T,S}
f::MOI.ScalarQuadraticTerm{S},
) where {T<:Real,S}
@assert promote_type(T, S) == T
push!(
ret.quadratic_terms,
MOI.ScalarQuadraticTerm{T}(
scale * f.coefficient,
f.variable_1,
f.variable_2,
),
)
return ret
end

function _add_to_quadratic!(
ret::MOI.ScalarQuadraticFunction{T},
scale::T,
x::MOI.VariableIndex,
) where {T<:Real}
return _add_to_quadratic!(ret, scale, MOI.ScalarAffineTerm(one(T), x))
end

function _add_to_quadratic!(
ret::MOI.ScalarQuadraticFunction{T},
scale::T,
f::MOI.ScalarAffineFunction{S},
) where {T<:Real,S}
if promote_type(T, S) != T
return # We can't store `S` in `T`.
end
ret = _add_to_affine!(ret, f.constant, scale)
ret = _add_to_quadratic!(ret, scale, f.constant)
for term in f.terms
ret = _add_to_affine!(ret, term.variable, scale * term.coefficient)
ret = _add_to_quadratic!(ret, scale, term)
end
return ret
end

function _add_to_affine!(
ret::Union{Nothing,MOI.ScalarAffineFunction{T}},
f::MOI.ScalarNonlinearFunction,
function _add_to_quadratic!(
ret::MOI.ScalarQuadraticFunction{T},
scale::T,
) where {T}
f::MOI.ScalarQuadraticFunction{S},
) where {T<:Real,S}
if promote_type(T, S) != T
return # We can't store `S` in `T`.
end
ret = _add_to_quadratic!(ret, scale, f.constant)
for term in f.affine_terms
ret = _add_to_quadratic!(ret, scale, term)
end
for q_term in f.quadratic_terms
ret = _add_to_quadratic!(ret, scale, q_term)
end
return ret
end

function _add_to_quadratic!(
ret::MOI.ScalarQuadraticFunction{T},
scale::T,
f::MOI.VariableIndex,
g::MOI.VariableIndex,
) where {T<:Real}
return _add_to_quadratic!(ret, scale, one(T) * f * g)
end

function _add_to_quadratic!(
ret::MOI.ScalarQuadraticFunction{T},
scale::T,
f::MOI.ScalarAffineFunction{F},
g::MOI.ScalarAffineFunction{G},
) where {T<:Real,F,G}
H = MOI.ScalarAffineFunction{promote_type(F, G)}
return _add_to_quadratic!(ret, scale, convert(H, f) * convert(H, g))
end

function _add_to_quadratic!(
ret::MOI.ScalarQuadraticFunction{T},
scale::T,
f::MOI.VariableIndex,
g::MOI.ScalarAffineFunction,
) where {T<:Real}
return _add_to_quadratic!(ret, scale, f * g)
end

function _add_to_quadratic!(
ret::MOI.ScalarQuadraticFunction{T},
scale::T,
f::MOI.ScalarAffineFunction,
g::MOI.VariableIndex,
) where {T<:Real}
return _add_to_quadratic!(ret, scale, g, f)
end

function _add_to_quadratic!(
ret::Union{Missing,MOI.ScalarQuadraticFunction{T}},
scale::T,
f::MOI.ScalarNonlinearFunction,
) where {T<:Real}
if f.head == :+
for arg in f.args
ret = _add_to_affine!(ret, arg, scale)
ret = _add_to_quadratic!(ret, scale, arg)
if ret === nothing
return
end
end
return ret
elseif f.head == :-
if length(f.args) == 1
return _add_to_affine!(ret, only(f.args), -scale)
return _add_to_quadratic!(ret, -scale, only(f.args))
end
@assert length(f.args) == 2
ret = _add_to_affine!(ret, f.args[1], scale)
ret = _add_to_quadratic!(ret, scale, f.args[1])
if ret === nothing
return
end
return _add_to_affine!(ret, f.args[2], -scale)
return _add_to_quadratic!(ret, -scale, f.args[2])
elseif f.head == :*
y = nothing
y1, y2 = nothing, nothing
for arg in f.args
if arg isa Real
scale *= arg
elseif y === nothing
y = arg
elseif y1 === nothing
y1 = arg
elseif y2 === nothing
y2 = arg
else
return # We already have a `y`. Can't multiple factors.
end
end
return _add_to_affine!(ret, something(y, one(T)), convert(T, scale))
if y1 === nothing
@assert y2 === nothing
return _add_to_quadratic!(ret, one(T), scale)
elseif y2 === nothing
return _add_to_quadratic!(ret, scale, y1)
else
return _add_to_quadratic!(ret, scale, y1, y2)
end
elseif f.head == :^ && f.args[2] isa Real && f.args[2] == 2
return _add_to_quadratic!(ret, scale, f.args[1], f.args[1])
elseif f.head == :/ && f.args[2] isa Real
return _add_to_quadratic!(ret, convert(T, scale / f.args[2]), f.args[1])
end
return # An unsupported f.head
end

function _simplify_if_affine!(f::MOI.ScalarNonlinearFunction)
ret = _add_to_affine!(nothing, f, 1.0)
function _simplify_if_quadratic!(f::MOI.ScalarNonlinearFunction)
ret = _add_to_quadratic!(missing, 1.0, f)
if ret === nothing
return f
end
return simplify!(ret::MOI.ScalarAffineFunction{Float64})
return simplify!(ret::MOI.ScalarQuadraticFunction{Float64})
end

_simplify_if_affine!(f::Any) = f
_simplify_if_quadratic!(f::Any) = f

end # module
Loading
Loading