diff --git a/Project.toml b/Project.toml index f53d33729..7a5bbf39c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicUtils" uuid = "d1185830-fcd6-423d-90d6-eec64667417b" authors = ["Shashi Gowda"] -version = "3.8.1" +version = "3.8.2" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index fb13f50b4..3e707228c 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -30,6 +30,12 @@ include("types.jl") # Methods on symbolic objects using SpecialFunctions, NaNMath + +# NaNMath.pow does not handle x::Int ^ y::Int -> ::Int +# Use this instead as a wrapper over NaNMath.pow +pow(x,y) = NaNMath.pow(x,y) +pow(x::Int, y::Int) = x^y + import IfElse: ifelse # need to not bring IfElse name in or it will clash with Rewriters.IfElse include("methods.jl") diff --git a/src/code.jl b/src/code.jl index 8b1953b20..34baeb84f 100644 --- a/src/code.jl +++ b/src/code.jl @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, - symtype, sorted_arguments, metadata, isterm, term, maketerm + symtype, sorted_arguments, metadata, isterm, term, maketerm, pow import SymbolicIndexingInterface: symbolic_type, NotSymbolic ##== state management ==## @@ -146,12 +146,12 @@ function function_to_expr(op::typeof(^), O, st) return toexpr(Term(inv, Any[ex]), st) else args = Any[Term(inv, Any[ex]), -args[2]] - op = get(st.rewrites, :nanmath, false) ? op : NaNMath.pow + op = get(st.rewrites, :nanmath, false) === true ? pow : op return toexpr(Term(op, args), st) end end get(st.rewrites, :nanmath, false) === true || return nothing - return toexpr(Term(NaNMath.pow, args), st) + return toexpr(Term(pow, args), st) end function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st) diff --git a/src/methods.jl b/src/methods.jl index 2baef6424..665e65321 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -20,7 +20,7 @@ const monadic = [deg2rad, rad2deg, transpose, asind, log1p, acsch, const diadic = [max, min, hypot, atan, NaNMath.atanh, mod, rem, copysign, besselj, bessely, besseli, besselk, hankelh1, hankelh2, - polygamma, beta, logbeta, NaNMath.pow] + polygamma, beta, logbeta, pow] const previously_declared_for = Set([]) const basic_monadic = [-, +] diff --git a/test/fuzzlib.jl b/test/fuzzlib.jl index 163467d6a..7a495a794 100644 --- a/test/fuzzlib.jl +++ b/test/fuzzlib.jl @@ -42,7 +42,7 @@ const num_spec = let ()->rand([a b c d e f])] binops = SymbolicUtils.diadic - nopow = setdiff(binops, [(^), NaNMath.pow, besselj0, besselj1, bessely0, bessely1, besselj, bessely, besseli, besselk]) + nopow = setdiff(binops, [(^), SymbolicUtils.pow, besselj0, besselj1, bessely0, bessely1, besselj, bessely, besseli, besselk]) twoargfns = vcat(nopow, (x,y)->x isa Union{Int, Rational, Complex{<:Rational}} ? x * y : x^y) fns = vcat(1 .=> vcat(SymbolicUtils.monadic, [one, zero]), 2 .=> vcat(twoargfns, fill(+, 5), [-,-], fill(*, 5), fill(/, 40)),