From bf3b7ae8b0f5360f394fc3901545ea2087071f0f Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 4 Feb 2026 16:07:42 +1300 Subject: [PATCH 1/2] Fix isapprox for ScalarNonlinearFunction and comparison to Number --- src/functions.jl | 49 ++++++++++++++++++++++++++++++++---- test/General/functions.jl | 52 +++++++++++++++++++++++++++++++++------ 2 files changed, 89 insertions(+), 12 deletions(-) diff --git a/src/functions.jl b/src/functions.jl index c1a5c36b53..f5bf984029 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -912,6 +912,23 @@ function Base.isapprox( ) end +Base.isapprox(::AbstractScalarFunction, ::Number; kwargs...) = false + +function Base.isapprox(f::ScalarAffineFunction, g::Number; kwargs...) + return all(t -> iszero(t.coefficient), f.terms) && + isapprox(f.constant, g; kwargs...) +end + +function Base.isapprox(f::ScalarQuadraticFunction, g::Number; kwargs...) + return all(t -> iszero(t.coefficient), f.quadratic_terms) && + all(t -> iszero(t.coefficient), f.affine_terms) && + isapprox(f.constant, g; kwargs...) +end + +function Base.isapprox(f::Number, g::AbstractScalarFunction; kwargs...) + return isapprox(g, f; kwargs...) +end + # This method is used by CBF in testing. function Base.isapprox(f::VectorOfVariables, g::VectorAffineFunction; kwargs...) return isapprox(convert(typeof(g), f), g; kwargs...) @@ -924,18 +941,40 @@ function _is_approx(x::AbstractArray, y::AbstractArray; kwargs...) all(z -> _is_approx(z[1], z[2]; kwargs...), zip(x, y)) end +# This method is not very robust. For example, it doesn't return `true` if `f` +# could be simplified to `g` (or vice versa). +# +# We additionally need a three-way switch to check +# +# * Both arguments are ScalarNonlinearFunction: add to the stack +# * Neither arguments are ScalarNonlinearFunction: use _is_approx +# * Exactly one argument is ScalarNonlinearFunction: return false +_num_snf(::ScalarNonlinearFunction, ::ScalarNonlinearFunction) = 2 +_num_snf(::Any, ::ScalarNonlinearFunction) = 1 +_num_snf(::ScalarNonlinearFunction, ::Any) = 1 +_num_snf(::Any, ::Any) = 0 + function Base.isapprox( f::ScalarNonlinearFunction, g::ScalarNonlinearFunction; kwargs..., ) - if f.head != g.head || length(f.args) != length(g.args) - return false - end - for (fi, gi) in zip(f.args, g.args) - if !_is_approx(fi, gi; kwargs...) + stack = Any[(f, g)] + while !isempty(stack) + fi, gi = pop!(stack) + if fi.head != gi.head || length(fi.args) != length(gi.args) return false end + for i in 1:length(fi.args) + x, y = fi.args[i], gi.args[i] + if _num_snf(x, y) == 2 + push!(stack, (x, y)) + elseif _num_snf(x, y) == 1 + return false + elseif !_is_approx(x, y; kwargs...) + return false + end + end end return true end diff --git a/test/General/functions.jl b/test/General/functions.jl index d6143d161b..56dbf19ed0 100644 --- a/test/General/functions.jl +++ b/test/General/functions.jl @@ -9,6 +9,17 @@ module TestFunctions using Test import MathOptInterface as MOI +function runtests() + for name in names(@__MODULE__; all = true) + if startswith("$name", "test_") + @testset "$(name)" begin + getfield(@__MODULE__, name)() + end + end + end + return +end + """ test_isbits() @@ -511,16 +522,43 @@ function test_copy_ScalarNonlinearFunction_with_arg() return end -function runtests() - for name in names(@__MODULE__; all = true) - if startswith("$name", "test_") - @testset "$(name)" begin - getfield(@__MODULE__, name)() - end - end +function test_isapprox_Number() + x = MOI.VariableIndex(1) + for f in Any[ + x, + 1.0*x+1.0, + 0.0*x+1.1, + 1.0*x*x+1.0, + 0.0*x*x+1.1, + 0.0*x*x+0.0*x+1.1, + MOI.ScalarNonlinearFunction(:log, Any[x]), + ] + @test !isapprox(f, 1.0) + @test !isapprox(1.0, f) + end + for f in Any[0.0*x+1.0, 0.0*x*x+1.0, 0.0*x*x+0.0*x+1.0] + @test isapprox(f, 1.0) + @test isapprox(1.0, f) end + return end +function test_isapprox_ScalarNonlinearFunction() + x = MOI.VariableIndex(1) + y = MOI.VariableIndex(2) + f = Any[ + MOI.ScalarNonlinearFunction(:+, Any[1.0]), + MOI.ScalarNonlinearFunction(:+, Any[0.5, 0.5]), + MOI.ScalarNonlinearFunction(:-, Any[1.0]), + MOI.ScalarNonlinearFunction(:+, Any[x]), + MOI.ScalarNonlinearFunction(:+, Any[y]), + ] + for i in 1:length(f), j in 1:length(f) + @test isapprox(f[i], f[j]) == (i == j) + end + return end +end # TestFunctions + TestFunctions.runtests() From caadd184951984e9abbb29f21c4f14018beda604 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Wed, 4 Feb 2026 16:39:20 +1300 Subject: [PATCH 2/2] Update --- test/General/functions.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/test/General/functions.jl b/test/General/functions.jl index 56dbf19ed0..fdea07e6c1 100644 --- a/test/General/functions.jl +++ b/test/General/functions.jl @@ -546,12 +546,17 @@ end function test_isapprox_ScalarNonlinearFunction() x = MOI.VariableIndex(1) y = MOI.VariableIndex(2) + op(head, args...) = MOI.ScalarNonlinearFunction(head, Any[args...]) f = Any[ - MOI.ScalarNonlinearFunction(:+, Any[1.0]), - MOI.ScalarNonlinearFunction(:+, Any[0.5, 0.5]), - MOI.ScalarNonlinearFunction(:-, Any[1.0]), - MOI.ScalarNonlinearFunction(:+, Any[x]), - MOI.ScalarNonlinearFunction(:+, Any[y]), + op(:+, 1.0), + op(:+, 0.5, 0.5), + op(:-, 1.0), + op(:+, x), + op(:+, y), + op(:+, op(:sin, x)), + op(:+, op(:sin, x), op(:cos, x)), + op(:+, op(:sin, x), 1.0 * x), + op(:+, 1.0 * x, op(:sin, x)), ] for i in 1:length(f), j in 1:length(f) @test isapprox(f[i], f[j]) == (i == j)