Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,23 @@ function Base.isapprox(
)
end

Base.isapprox(::AbstractScalarFunction, ::Number; kwargs...) = false

function Base.isapprox(f::ScalarAffineFunction, g::Number; kwargs...)
return all(t -> iszero(t.coefficient), f.terms) &&
isapprox(f.constant, g; kwargs...)
end

function Base.isapprox(f::ScalarQuadraticFunction, g::Number; kwargs...)
return all(t -> iszero(t.coefficient), f.quadratic_terms) &&
all(t -> iszero(t.coefficient), f.affine_terms) &&
isapprox(f.constant, g; kwargs...)
end

function Base.isapprox(f::Number, g::AbstractScalarFunction; kwargs...)
return isapprox(g, f; kwargs...)
end

# This method is used by CBF in testing.
function Base.isapprox(f::VectorOfVariables, g::VectorAffineFunction; kwargs...)
return isapprox(convert(typeof(g), f), g; kwargs...)
Expand All @@ -924,18 +941,40 @@ function _is_approx(x::AbstractArray, y::AbstractArray; kwargs...)
all(z -> _is_approx(z[1], z[2]; kwargs...), zip(x, y))
end

# This method is not very robust. For example, it doesn't return `true` if `f`
# could be simplified to `g` (or vice versa).
#
# We additionally need a three-way switch to check
#
# * Both arguments are ScalarNonlinearFunction: add to the stack
# * Neither arguments are ScalarNonlinearFunction: use _is_approx
# * Exactly one argument is ScalarNonlinearFunction: return false
_num_snf(::ScalarNonlinearFunction, ::ScalarNonlinearFunction) = 2
_num_snf(::Any, ::ScalarNonlinearFunction) = 1
_num_snf(::ScalarNonlinearFunction, ::Any) = 1
_num_snf(::Any, ::Any) = 0

function Base.isapprox(
f::ScalarNonlinearFunction,
g::ScalarNonlinearFunction;
kwargs...,
)
if f.head != g.head || length(f.args) != length(g.args)
return false
end
for (fi, gi) in zip(f.args, g.args)
if !_is_approx(fi, gi; kwargs...)
stack = Any[(f, g)]
while !isempty(stack)
fi, gi = pop!(stack)
if fi.head != gi.head || length(fi.args) != length(gi.args)
return false
end
for i in 1:length(fi.args)
x, y = fi.args[i], gi.args[i]
if _num_snf(x, y) == 2
push!(stack, (x, y))
elseif _num_snf(x, y) == 1
return false
elseif !_is_approx(x, y; kwargs...)
return false
end
end
end
return true
end
Expand Down
57 changes: 50 additions & 7 deletions test/General/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@ module TestFunctions
using Test
import MathOptInterface as MOI

function runtests()
for name in names(@__MODULE__; all = true)
if startswith("$name", "test_")
@testset "$(name)" begin
getfield(@__MODULE__, name)()
end
end
end
return
end

"""
test_isbits()

Expand Down Expand Up @@ -511,16 +522,48 @@ function test_copy_ScalarNonlinearFunction_with_arg()
return
end

function runtests()
for name in names(@__MODULE__; all = true)
if startswith("$name", "test_")
@testset "$(name)" begin
getfield(@__MODULE__, name)()
end
end
function test_isapprox_Number()
x = MOI.VariableIndex(1)
for f in Any[
x,
1.0*x+1.0,
0.0*x+1.1,
1.0*x*x+1.0,
0.0*x*x+1.1,
0.0*x*x+0.0*x+1.1,
MOI.ScalarNonlinearFunction(:log, Any[x]),
]
@test !isapprox(f, 1.0)
@test !isapprox(1.0, f)
end
for f in Any[0.0*x+1.0, 0.0*x*x+1.0, 0.0*x*x+0.0*x+1.0]
@test isapprox(f, 1.0)
@test isapprox(1.0, f)
end
return
end

function test_isapprox_ScalarNonlinearFunction()
x = MOI.VariableIndex(1)
y = MOI.VariableIndex(2)
op(head, args...) = MOI.ScalarNonlinearFunction(head, Any[args...])
f = Any[
op(:+, 1.0),
op(:+, 0.5, 0.5),
op(:-, 1.0),
op(:+, x),
op(:+, y),
op(:+, op(:sin, x)),
op(:+, op(:sin, x), op(:cos, x)),
op(:+, op(:sin, x), 1.0 * x),
op(:+, 1.0 * x, op(:sin, x)),
]
for i in 1:length(f), j in 1:length(f)
@test isapprox(f[i], f[j]) == (i == j)
end
return
end

end # TestFunctions

TestFunctions.runtests()
Loading