diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 931c7009b3..bf9638c473 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -664,6 +664,11 @@ def c_code_cache_version(self):
 tensor_from_scalar = TensorFromScalar()
 
 
+@_vectorize_node.register(TensorFromScalar)
+def vectorize_tensor_from_scalar(op, node, batch_x):
+    return identity(batch_x).owner
+
+
 class ScalarFromTensor(COp):
     __props__ = ()
 
@@ -2046,6 +2051,7 @@ def register_transfer(fn):
 """Create a duplicate of `a` (with duplicated storage)"""
 tensor_copy = Elemwise(ps.identity)
 pprint.assign(tensor_copy, printing.IgnorePrinter())
+identity = tensor_copy
 
 
 class Default(Op):
@@ -4603,6 +4609,7 @@ def ix_(*args):
     "matrix_transpose",
     "default",
     "tensor_copy",
+    "identity",
     "transfer",
     "alloc",
     "identity_like",
diff --git a/pytensor/tensor/optimize.py b/pytensor/tensor/optimize.py
index 2088dd99cd..7d4bdade77 100644
--- a/pytensor/tensor/optimize.py
+++ b/pytensor/tensor/optimize.py
@@ -7,7 +7,7 @@
 
 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, hessian, jacobian
+from pytensor.gradient import grad, jacobian
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -484,6 +484,7 @@ def __init__(
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
+        use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
         if not cast(TensorVariable, objective).ndim == 0:
@@ -496,6 +497,7 @@
         )
 
         self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac
 
         if jac:
             grad_wrt_x = cast(
@@ -505,7 +507,12 @@
 
         if hess:
             hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
             )
             self.fgraph.add_output(hess_wrt_x)
 
@@ -561,7 +568,10 @@ def L_op(self, inputs, outputs, output_grads):
         implicit_f = grad(inner_fx, inner_x)
 
         df_dx, *df_dtheta_columns = jacobian(
-            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+            implicit_f,
+            [inner_x, *inner_args],
+            disconnected_inputs="ignore",
+            vectorize=self.use_vectorized_jac,
         )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
@@ -581,6 +591,7 @@ def minimize(
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -590,18 +601,21 @@
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-
-    x : TensorVariable
+    x: TensorVariable
         The variable with respect to which the objective function is minimized. It must be an input to the
         computational graph of `objective`.
-
-    method : str, optional
+    method: str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-
-    jac : bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac: bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
         Default is True.
-
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the Jacobian (and/or Hessian) matrix. If False, a
+        scan will be used instead. This comes down to a memory/compute trade-off. Vectorized graphs can be faster,
+        but use more memory. Default is False.
     optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize
 
@@ -624,6 +638,7 @@
         method=method,
         jac=jac,
         hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
         optimizer_kwargs=optimizer_kwargs,
     )
 
@@ -804,6 +819,7 @@ def __init__(
         method: str = "hybr",
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
     ):
         if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
             raise ValueError(
@@ -817,7 +833,11 @@
         self.fgraph = FunctionGraph([variables, *args], [equations])
 
         if jac:
-            jac_wrt_x = jacobian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            jac_wrt_x = jacobian(
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
+            )
             self.fgraph.add_output(atleast_2d(jac_wrt_x))
 
         self.jac = jac
@@ -897,8 +917,14 @@ def L_op(
         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]
 
-        df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
-        df_dtheta_columns = jacobian(inner_fx, inner_args, disconnected_inputs="ignore")
+        df_dx = (
+            jacobian(inner_fx, inner_x, vectorize=True)
+            if not self.jac
+            else self.fgraph.outputs[1]
+        )
+        df_dtheta_columns = jacobian(
+            inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True
+        )
 
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
@@ -917,6 +943,7 @@ def root(
     variables: TensorVariable,
     method: str = "hybr",
     jac: bool = True,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -935,6 +962,10 @@ def root(
     jac : bool, optional
         Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
        Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the Jacobian matrix. If False, a scan will be used instead.
+        This comes down to a memory/compute trade-off. Vectorized graphs can be faster, but use more memory.
+        Default is False.
     optimizer_kwargs : dict, optional
         Additional keyword arguments to pass to `scipy.optimize.root`.
 
@@ -958,6 +989,7 @@
         method=method,
         jac=jac,
         optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
    )
 
     solution, success = cast(
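
For context, a minimal sketch of what the new `TensorFromScalar` vectorization dispatch enables. This is an illustration, not part of the patch; it assumes `vectorize_graph` from `pytensor.graph.replace` and imports `tensor_from_scalar` directly from `pytensor.tensor.basic`:

```python
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.basic import tensor_from_scalar

s = ps.float64("s")        # a pytensor.scalar variable
y = tensor_from_scalar(s)  # 0-d tensor wrapping `s`

# Replace the scalar input with a batched (1-d) tensor. Without a dedicated
# rule, vectorization would fall back to the generic path; the new dispatch
# rewrites the TensorFromScalar node into a cheap `identity` (the new alias
# of `tensor_copy`) applied to the batched input.
batch_y = vectorize_graph(y, replace={s: pt.vector("s_batch")})
```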
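
And a usage sketch for the new `use_vectorized_jac` flag on the `minimize` front end, under the API added in this diff (the objective and method here are hypothetical example choices; `Newton-CG` is one scipy method that consumes a Hessian):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.vector("x")
a = pt.vector("a")
objective = ((x - a) ** 2).sum()  # scalar objective

# hess=True builds the Hessian as the Jacobian of the gradient; with
# use_vectorized_jac=True that Jacobian is a vectorized (vmapped) graph
# rather than a scan, trading memory for speed per the new docstring.
x_star, success = minimize(
    objective,
    x,
    method="Newton-CG",
    jac=True,
    hess=True,
    use_vectorized_jac=True,
)

# When compiled, `x` also supplies the optimizer's initial point.
fn = pytensor.function([x, a], [x_star, success])
```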