Description
I build a differentiable 2D test function (akin to the examples in the docs), sample it, and construct its RBF surrogate using Surrogates. I then evaluate Zygote.gradient of the surrogate at (i) a random point, which works, and (ii) a sample point, which wrongly returns (NaN, NaN):
using Zygote
using Surrogates
num_samples = 1024
# Defining the function
f = x -> log(x[1]) * x[1]^2 + x[1]^3 + x[2]^3*sin(x[2])
# Sampling points from the function
lb = [1.0, 1.0]
ub = [10.0, 10.0]
x = sample(num_samples, lb, ub, SobolSample())
y = f.(x)
# Constructing the surrogate
my_radial_basis = RadialBasis(x, y, lb, ub)
# Compute gradients at a given point
function compute_gradients(my_pt)
df_dx = Zygote.gradient(my_radial_basis, my_pt)[1]
exact_derivative = Zygote.gradient(f, my_pt)[1]
display("Approximate derivative at (x,y)=$(my_pt): $df_dx")
display("Exact derivative at (x,y)=$(my_pt): $exact_derivative")
end
# Derivatives using random point: works
my_pt = Tuple(((ub .- lb) .* rand(2)) .+ lb)
compute_gradients(my_pt)
# Derivatives using point from sample: gives (NaN, NaN)
my_pt = x[1]
compute_gradients(my_pt)
with
julia> versioninfo()
Julia Version 1.12.1
Commit ba1e628ee49 (2025-10-17 13:02 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: macOS (arm64-apple-darwin24.0.0)
CPU: 12 × Apple M3 Pro
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, apple-m3)
GC: Built with stock GC
Threads: 1 default, 1 interactive, 1 GC (on 6 virtual cores)
Environment:
JULIA_MPI_BINARY =
JULIA_PETSC_LIBRARY =
gives
"Approximate derivative at (x,y)=(4.756243123793594, 9.061345878392638): (82.40466826848662, -613.6814366645162)"
"Exact derivative at (x,y)=(4.756243123793594, 9.061345878392638): (87.45611278220835, -607.8476392599746)"
"Approximate derivative at (x,y)=(1.01318359375, 4.38818359375): (NaN, NaN)"
"Exact derivative at (x,y)=(1.01318359375, 4.38818359375): (4.119346813521075, -81.67677763868926)"
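One possible explanation, which I have not verified against the Surrogates source: if the radial kernel computes the distance to each sample as sqrt(sum(abs2, d)) and reverse-mode AD differentiates that step by step, then at a sample point d is the zero vector, the derivative of sqrt is Inf, and Inf * 0 gives NaN. The sketch below reproduces that mechanism in plain Julia, without Zygote or Surrogates. The names naive_norm_gradient and guarded_norm_gradient are hypothetical and are not part of either package.

```julia
# Hypothetical illustration, not taken from Surrogates.jl or Zygote:
# the chain rule for r(d) = sqrt(sum(abs2, d)), applied one step at a time.
function naive_norm_gradient(d::AbstractVector{<:Real})
    s = sum(abs2, d)            # squared distance to the sample
    dr_ds = 1 / (2 * sqrt(s))   # derivative of sqrt; Inf when s == 0
    ds_dd = 2 .* d              # derivative of the sum of squares; zero vector when d == 0
    return dr_ds .* ds_dd       # Inf * 0 evaluates to NaN in floating point
end

# Hypothetical guarded variant (an assumption about one possible fix):
# use the zero subgradient when the distance is exactly zero.
function guarded_norm_gradient(d::AbstractVector{<:Real})
    s = sum(abs2, d)
    return iszero(s) ? zero(float.(d)) : d ./ sqrt(s)
end

x_sample = [1.01318359375, 4.38818359375]  # the sample point x[1] from the output above

# Evaluating at a sample point means the offset to that sample is exactly zero.
naive_at_center = naive_norm_gradient(x_sample .- x_sample)
guarded_at_center = guarded_norm_gradient(x_sample .- x_sample)

# Away from the sample both variants agree with d / norm(d).
naive_off_center = naive_norm_gradient([3.0, 4.0])
guarded_off_center = guarded_norm_gradient([3.0, 4.0])
```

If this is indeed the cause, a single NaN term would propagate through the weighted sum over all samples, which would match both gradient components coming back as NaN at x[1] while the random point is unaffected.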