From 096317d7bfb88f3c478c093621fba280ae289a9f Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Fri, 2 Aug 2024 21:47:22 -0400 Subject: [PATCH 1/3] Make propagate more like a monadic bind by supporting stochastic triple creating functions --- src/propagate.jl | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/propagate.jl b/src/propagate.jl index 5bf5f64..dee9ff5 100644 --- a/src/propagate.jl +++ b/src/propagate.jl @@ -42,12 +42,14 @@ function strip_Δs(arg; use_dual = Val(true)) end """ - propagate(f, args...; keep_deltas = Val(false)) + propagate(f, args...; keep_deltas = Val(false), keep_triples = Val(false)) Propagates `args` through a function `f`, handling stochastic triples by independently running `f` on the primal and the alternatives, rather than by inspecting the internals of `f` (which may possibly be unsupported by `StochasticAD`). Currently handles deterministic functions `f` with any input and output that is `fmap`-able by `Functors.jl`. If `f` has a continuously differentiable component, provide `keep_deltas = Val(true)`. +If `f` may itself return stochastic triples even once perturbations have been stripped from `args`, +(e.g. because `f` closes over stochastic triples), then provide `keep_triples = Val(true)`. This functionality is orthogonal to dispatch: the idea is for this function to be the "backend" for operator overloading rules based on dispatch. For example: @@ -89,6 +91,7 @@ StochasticTriple of Int64: function propagate(f, args...; keep_deltas = Val(false), + keep_triples = Val(false), provided_st_rep = nothing, deriv = nothing) # TODO: support kwargs to f (or just use kwfunc in macro) @@ -118,15 +121,16 @@ function propagate(f, end primal_args = structural_map(get_value, args) - input_args = keep_deltas isa Val{false} ? primal_args : structural_map(strip_Δs, args) - #= - TODO: the below is dangerous is general. - It should be safe so long as f does not close over stochastic triples. 
- (If f is a closure, the parameters of f should be treated like any other parameters; - if they are stochastic triples and we are ignoring them, dangerous in general.) - =# + input_args = if keep_triples isa Val{true} + structural_map(x -> strip_Δs(x; use_dual = Val(false)), args) + elseif keep_deltas isa Val{true} + structural_map(strip_Δs, args) + else + primal_args + end out = f(input_args...) val = structural_map(value, out) + Δs1 = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), out) # TODO: what does the only_vals do in the below and why? Δs_all = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), args; only_vals = Val{true}()) @@ -142,11 +146,14 @@ function propagate(f, alt = f(perturbed_args...) return structural_map((x, y) -> value(x) - y, alt, val) end - Δs = map(map_func, Δs_coupled; out_rep = val, deriv) + Δs2 = map(map_func, Δs_coupled; out_rep = val, deriv) + # TODO: make sure all FI backends support interface needed below - new_out = structural_map(out, scalarize(Δs; out_rep = val)) do leaf_out, leaf_Δs + new_out = structural_map(out, Δs1, scalarize(Δs2; out_rep = val)) do leaf_out, leaf_Δs1, leaf_Δs2 + leaf_Δs = combine(backendtype(st_rep), (leaf_Δs1, leaf_Δs2)) StochasticAD.StochasticTriple{tag(st_rep)}(value(leaf_out), delta(leaf_out), leaf_Δs) end return new_out end + From 0682f44cca2ac25a881f0c908d22535eecba1c19 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Fri, 2 Aug 2024 21:49:28 -0400 Subject: [PATCH 2/3] Format --- src/propagate.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/propagate.jl b/src/propagate.jl index dee9ff5..2ad7d7c 100644 --- a/src/propagate.jl +++ b/src/propagate.jl @@ -123,14 +123,14 @@ function propagate(f, primal_args = structural_map(get_value, args) input_args = if keep_triples isa Val{true} structural_map(x -> strip_Δs(x; use_dual = Val(false)), args) - elseif keep_deltas isa Val{true} + elseif keep_deltas isa Val{true} structural_map(strip_Δs, args) else primal_args end 
out = f(input_args...) val = structural_map(value, out) - Δs1 = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), out) + Δs1 = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), out) # TODO: what does the only_vals do in the below and why? Δs_all = structural_map(Base.Fix2(get_Δs, backendtype(st_rep)), args; only_vals = Val{true}()) @@ -149,11 +149,11 @@ function propagate(f, Δs2 = map(map_func, Δs_coupled; out_rep = val, deriv) # TODO: make sure all FI backends support interface needed below - new_out = structural_map(out, Δs1, scalarize(Δs2; out_rep = val)) do leaf_out, leaf_Δs1, leaf_Δs2 + new_out = structural_map( + out, Δs1, scalarize(Δs2; out_rep = val)) do leaf_out, leaf_Δs1, leaf_Δs2 leaf_Δs = combine(backendtype(st_rep), (leaf_Δs1, leaf_Δs2)) StochasticAD.StochasticTriple{tag(st_rep)}(value(leaf_out), delta(leaf_out), leaf_Δs) end return new_out end - From f0680c7e72bfd25b30ba16989c061948a706e991 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Fri, 2 Aug 2024 22:21:05 -0400 Subject: [PATCH 3/3] Tweak docstring --- src/propagate.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/propagate.jl b/src/propagate.jl index 2ad7d7c..4b533ee 100644 --- a/src/propagate.jl +++ b/src/propagate.jl @@ -46,10 +46,12 @@ end Propagates `args` through a function `f`, handling stochastic triples by independently running `f` on the primal and the alternatives, rather than by inspecting the internals of `f` (which may possibly be unsupported by `StochasticAD`). -Currently handles deterministic functions `f` with any input and output that is `fmap`-able by `Functors.jl`. +Handles functions `f` with any input and output that is `fmap`-able by `Functors.jl`, including functions `f` that are closures +over stochastic triples. If `f` has a continuously differentiable component, provide `keep_deltas = Val(true)`. -If `f` may itself return stochastic triples even once perturbations have been stripped from `args`, -(e.g. 
because `f` closes over stochastic triples), then provide `keep_triples = Val(true)`. +If continuous perturbations to `args` can cause discrete perturbations to be created within `f`, +then provide `keep_triples = Val(true)`. + This functionality is orthogonal to dispatch: the idea is for this function to be the "backend" for operator overloading rules based on dispatch. For example: