diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fa4d7a7..63e0ed5 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,7 +2,7 @@ name: Tests on: push: - branches: [main] + branches: [main, develop] tags: ["*"] pull_request: workflow_dispatch: @@ -32,6 +32,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e '.[dev,jax,web]' + pip install -e '.[dev,jax,torch,web]' - name: Test run: pytest \ No newline at end of file diff --git a/conftest.py b/conftest.py index 432e51e..0c941c1 100644 --- a/conftest.py +++ b/conftest.py @@ -2,13 +2,13 @@ import pytest import numpy -BACKENDS: set[str] = set(('cpu', 'jax', 'cuda')) +BACKENDS: set[str] = set(('cpu', 'jax', 'cupy', 'torch')) AVAILABLE_BACKENDS: set[str] = set(('cpu',)) try: import cupy # pyright: ignore[reportMissingImports] if cupy.cuda.runtime.getDeviceCount() > 0: - AVAILABLE_BACKENDS.add('cuda') + AVAILABLE_BACKENDS.add('cupy') except ImportError: pass @@ -19,6 +19,14 @@ pass +try: + import torch # pyright: ignore[reportMissingImports] # noqa: F401 + torch.asarray([1, 2, 3, 4]).numpy(force=True) # ensures torch is loaded, fixes a strange error with pytest + AVAILABLE_BACKENDS.add('torch') +except ImportError: + pass + + def pytest_addoption(parser: pytest.Parser): parser.addoption("--save-expected", action="store_true", dest='save-expected', default=False, help="Overwrite expected files with the results of tests.") diff --git a/examples/mos2_epie.yaml b/examples/mos2_epie.yaml index 6005257..b7fe806 100644 --- a/examples/mos2_epie.yaml +++ b/examples/mos2_epie.yaml @@ -1,11 +1,11 @@ --- name: "mos2_epie" -backend: cupy +backend: torch # raw data source raw_data: type: empad - path: "sample_data/simulated_mos2/mos2_0.00_dstep1.0.json" + path: "~/Downloads/mos2/1/mos2/mos2_0.00_dstep1.0.json" post_load: - type: poisson @@ -32,7 +32,8 @@ engines: beta_probe: 0.5 group_constraints: [] - iter_constraints: [] + iter_constraints: + - type: remove_phase_ramp update_probe: {after: 5} diff --git a/notebooks/conventions.ipynb b/notebooks/conventions.ipynb index ace72a1..6c745f6 100644 --- a/notebooks/conventions.ipynb +++ b/notebooks/conventions.ipynb @@ -74,6 +74,25 @@ "\n", "Nevertheless, pixels are often drawn as \"little squares\", requiring us to choose a convention of where to place these squares. We place each pixel at the center of the sampling point it represents. e.g. the pixel center is at integer coordinates." ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Notes for backend-agnostic coding\n", + "\n", + "Backends are libraries which support the Array API: https://data-apis.org/array-api/2024.12/\n", + "\n", + "https://github.com/pytorch/pytorch/issues/58743\n", + "\n", + "- Use `xp.size()` rather than `arr.size` (for Torch)\n", + "- Use `at()` util for in-place modifications (for Jax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] } ], "metadata": { diff --git a/phaser/engines/common/noise_models.py b/phaser/engines/common/noise_models.py index 66cd3be..c366f24 100644 --- a/phaser/engines/common/noise_models.py +++ b/phaser/engines/common/noise_models.py @@ -6,7 +6,7 @@ from phaser.hooks.solver import NoiseModel from phaser.plan import AmplitudeNoisePlan, AnscombeNoisePlan, PoissonNoisePlan -from phaser.utils.num import get_array_module, Float +from phaser.utils.num import get_array_module, Float, to_numpy from phaser.state import ReconsState diff --git a/phaser/engines/common/regularizers.py b/phaser/engines/common/regularizers.py index 15c9fc1..8db02f3 100644 --- a/phaser/engines/common/regularizers.py +++ b/phaser/engines/common/regularizers.py @@ -1,5 +1,6 @@ from functools import partial import logging +from math import prod import typing as t import numpy @@ -205,7 +206,7 @@ def calc_loss_group( xp = get_array_module(sim.object.data) cost = xp.sum(xp.abs(sim.object.data - 1.0)) - cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -222,8 +223,9 @@ def calc_loss_group( xp = get_array_module(sim.object.data) cost = xp.sum(abs2(sim.object.data - 1.0)) - cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype) - return (cost * cost_scale * self.cost, state) + + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) + return (cost * cost_scale * self.cost, state) # type: ignore class ObjPhaseL1: @@ -239,7 +241,7 @@ def calc_loss_group( xp = get_array_module(sim.object.data) cost = xp.sum(xp.abs(xp.angle(sim.object.data))) - cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -261,7 +263,7 @@ def calc_loss_group( xp.abs(fft2(xp.prod(sim.object.data, axis=0))) ) # scale cost by fraction of the total reconstruction in the group - cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -289,7 +291,7 @@ def calc_loss_group( #) # scale cost by fraction of the total reconstruction in the group # TODO also scale by # of pixels or similar? - cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -311,9 +313,9 @@ def calc_loss_group( xp.sum(abs2(xp.diff(sim.object.data, axis=-2))) ) # scale cost by fraction of the total reconstruction in the group - cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) - return (cost * cost_scale * self.cost, state) + return (cost * cost_scale * self.cost, state) # type: ignore class LayersTotalVariation: @@ -333,7 +335,7 @@ def calc_loss_group( cost = xp.sum(xp.abs(xp.diff(sim.object.data, axis=0))) # scale cost by fraction of the total reconstruction in the group - cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) return (cost * cost_scale * self.cost, state) @@ -355,9 +357,9 @@ def calc_loss_group( cost = xp.sum(abs2(xp.diff(sim.object.data, axis=0))) # scale cost by fraction of the total reconstruction in the group - cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype) + cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype) - return (cost * cost_scale * self.cost, state) + return (cost * cost_scale * self.cost, state) # type: ignore class ProbePhaseTikhonov: diff --git a/phaser/engines/common/simulation.py b/phaser/engines/common/simulation.py index fb88163..836b255 100644 --- a/phaser/engines/common/simulation.py +++ b/phaser/engines/common/simulation.py @@ -3,14 +3,15 @@ import typing as t import numpy -from numpy.typing import NDArray, DTypeLike +from numpy.typing import NDArray from typing_extensions import Self from phaser.utils.num import ( get_array_module, to_real_dtype, to_complex_dtype, fft2, ifft2, is_jax, to_numpy, block_until_ready, ufunc_outer ) -from phaser.utils.misc import FloatKey, jax_dataclass, create_compact_groupings, create_sparse_groupings, shuffled +from phaser.utils.tree import tree_dataclass +from phaser.utils.misc import FloatKey, create_compact_groupings, create_sparse_groupings, shuffled from phaser.utils.optics import fresnel_propagator, fourier_shift_filter from phaser.state import ReconsState from phaser.hooks.solver import NoiseModel @@ -83,7 +84,7 @@ def stream_patterns( continue -@jax_dataclass(init=False, static_fields=('xp', 'dtype', 'noise_model', 'group_constraints', 'iter_constraints'), drop_fields=('ky', 'kx')) +@tree_dataclass(init=False, static_fields=('xp', 'dtype', 'noise_model', 'group_constraints', 'iter_constraints'), drop_fields=('ky', 'kx')) class SimulationState: state: ReconsState @@ -99,7 +100,7 @@ class SimulationState: iter_constraint_states: t.Tuple[t.Any, ...] xp: t.Any - dtype: DTypeLike + dtype: t.Type[numpy.floating] start_iter: int def __init__( @@ -109,7 +110,7 @@ def __init__( group_constraints: t.Tuple[GroupConstraint[t.Any], ...], iter_constraints: t.Tuple[IterConstraint[t.Any], ...], xp: t.Any, - dtype: DTypeLike, + dtype: t.Type[numpy.floating], noise_model_state: t.Optional[t.Any] = None, group_constraint_states: t.Optional[t.Tuple[t.Any, ...]] = None, iter_constraint_states: t.Optional[t.Tuple[t.Any, ...]] = None, diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 578fcdf..1ce2fec 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -3,7 +3,7 @@ import numpy from phaser.utils.misc import mask_fraction_of_groups -from phaser.utils.num import cast_array_module, to_numpy, to_complex_dtype +from phaser.utils.num import assert_dtype, cast_array_module, to_numpy, to_complex_dtype from phaser.observer import Observer from phaser.hooks import EngineArgs from phaser.plan import ConventionalEnginePlan @@ -17,6 +17,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: xp = cast_array_module(args['xp']) dtype = args['dtype'] + cdtype = to_complex_dtype(dtype) observer: Observer = args.get('observer', Observer()) seed = args['seed'] @@ -37,12 +38,12 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: xp=xp, dtype=dtype ) patterns = args['data'].patterns - pattern_mask = xp.array(args['data'].pattern_mask) + pattern_mask = xp.asarray(args['data'].pattern_mask) - assert patterns.dtype == sim.dtype - assert pattern_mask.dtype == sim.dtype - assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype) - assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype) + assert_dtype(patterns, dtype) + assert_dtype(pattern_mask, dtype) + assert_dtype(sim.state.object.data, cdtype) + assert_dtype(sim.state.probe.data, cdtype) solver = props.solver(props) sim = solver.init(sim) @@ -87,8 +88,8 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: calc_error_mask=calc_error_mask, observer=observer, ) - assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype) - assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype) + assert_dtype(sim.state.object.data, cdtype) + assert_dtype(sim.state.probe.data, cdtype) sim = sim.apply_iter_constraints() @@ -104,7 +105,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: update_mag = xp.linalg.norm(pos_update, axis=-1, keepdims=True) logger.info(f"Position update: mean {xp.mean(update_mag)}") sim.state.scan += pos_update - assert sim.state.scan.dtype == sim.dtype + assert_dtype(sim.state.scan, dtype) # check positions are at least overlapping object sim.state.object.sampling.check_scan(sim.state.scan, sim.state.probe.sampling.extent / 2.) diff --git a/phaser/engines/conventional/solvers.py b/phaser/engines/conventional/solvers.py index 7d6cdc5..d689d95 100644 --- a/phaser/engines/conventional/solvers.py +++ b/phaser/engines/conventional/solvers.py @@ -107,8 +107,8 @@ def run_iteration( gamma=gamma, ) check_finite(sim.state.object.data, sim.state.probe.data, context=f"object or probe, group {group_i}") - assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype) - assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype) + #assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype) + #assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype) sim = sim.apply_group_constraints(group) @@ -365,8 +365,8 @@ def run_iteration( update_probe=update_probe, ) check_finite(sim.state.object.data, sim.state.probe.data, context=f"object or probe, group {group_i}") - assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype) - assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype) + #assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype) + #assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype) sim = sim.apply_group_constraints(group) diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index 84a4cce..6b84bc3 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -7,11 +7,11 @@ from typing_extensions import Self from phaser.hooks.solver import NoiseModel -from phaser.utils.misc import jax_dataclass from phaser.utils.num import ( - get_array_module, cast_array_module, jit, - fft2, ifft2, abs2, check_finite, at, Float, to_real_dtype + assert_dtype, get_array_module, cast_array_module, jit, + fft2, ifft2, abs2, check_finite, at, Float, to_complex_dtype, to_numpy, to_real_dtype ) +import phaser.utils.tree as tree from phaser.utils.optics import fourier_shift_filter from phaser.observer import Observer from phaser.state import ReconsState @@ -73,13 +73,14 @@ def process_solvers( ('tilt',): 'tilt' } -def extract_vars(state: ReconsState, vars: t.AbstractSet[ReconsVar], group: t.Optional[NDArray[numpy.integer]] = None) -> t.Tuple[t.Dict[ReconsVar, t.Any], ReconsState]: - import jax.tree_util +def _normalize_path(path: t.Tuple[tree.GetAttrKey, ...]) -> t.Tuple[str, ...]: + return tuple(p.name for p in path) +def extract_vars(state: ReconsState, vars: t.AbstractSet[ReconsVar], group: t.Optional[NDArray[numpy.integer]] = None) -> t.Tuple[t.Dict[ReconsVar, t.Any], ReconsState]: d = {} - def f(path: t.Tuple[str, ...], val: t.Any): - if (var := _PATH_MAP.get(path)) and var in vars: + def f(path: t.Tuple[tree.GetAttrKey, ...], val: t.Any): + if (var := _PATH_MAP.get(_normalize_path(path))) and var in vars: if var in _PER_ITER_VARS and group is not None: d[var] = val[tuple(group)] else: @@ -87,21 +88,19 @@ def f(path: t.Tuple[str, ...], val: t.Any): return None return val - state = jax.tree_util.tree_map_with_path(f, state, is_leaf=lambda x: x is None) + state = tree.map_with_path(f, state, is_leaf=lambda x: x is None) return (d, state) def insert_vars(vars: t.Dict[ReconsVar, t.Any], state: ReconsState, group: t.Optional[NDArray[numpy.integer]] = None) -> ReconsState: - import jax.tree_util - - def f(path: t.Tuple[str, ...], val: t.Any): - if (var := _PATH_MAP.get(path)): + def f(path: t.Tuple[tree.GetAttrKey, ...], val: t.Any): + if (var := _PATH_MAP.get(_normalize_path(path))): if var in vars: return vars[var] if var in _PER_ITER_VARS and val is not None and group is not None: return val[tuple(group)] return val - return jax.tree_util.tree_map_with_path(f, state, is_leaf=lambda x: x is None) + return tree.map_with_path(f, state, is_leaf=lambda x: x is None) def apply_update(state: ReconsState, update: t.Dict[ReconsVar, numpy.ndarray]) -> ReconsState: @@ -128,7 +127,7 @@ def filter_vars(d: t.Dict[ReconsVar, t.Any], vars: t.AbstractSet[ReconsVar]) -> return {k: v for (k, v) in d.items() if k in vars} -@jax_dataclass +@tree.tree_dataclass class SolverStates: noise_model_state: t.Any group_solver_states: t.List[t.Any] @@ -154,18 +153,17 @@ def init_state( def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: - import jax - import jax.numpy - from optax.tree_utils import tree_zeros_like - jax.config.update('jax_traceback_filtering', 'off') - - xp = cast_array_module(jax.numpy) - dtype = t.cast(t.Type[numpy.floating], args['dtype']) + #jax.config.update('jax_traceback_filtering', 'off') + xp = cast_array_module(args['xp']) + dtype = args['dtype'] + cdtype = to_complex_dtype(dtype) observer: Observer = args.get('observer', Observer()) state = args['state'] seed = args['seed'] patterns = args['data'].patterns - pattern_mask = args['data'].pattern_mask + pattern_mask = xp.array(args['data'].pattern_mask) + assert_dtype(patterns, dtype) + assert_dtype(pattern_mask, dtype) noise_model = props.noise_model(None) @@ -204,7 +202,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: rescale_factor = xp.mean(rescale_factors) logger.info("Pre-calculated intensities") - logger.info(f"Rescaling initial probe intensity by {rescale_factor:.2e}") + logger.info(f"Rescaling initial probe intensity by {float(rescale_factor):.2e}") state.probe.data *= xp.sqrt(rescale_factor) probe_int = xp.sum(abs2(state.probe.data)) @@ -215,6 +213,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: iter_constraint_states = [reg.init_state(state) for reg in iter_constraints] #with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): + #with torch.profiler.profile(with_stack=True) as prof: for i in range(1, props.niter+1): state.iter.engine_iter = i state.iter.total_iter = start_i + i @@ -224,7 +223,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: set(k for (k, flag) in flags.items() if flag({'state': state, 'niter': props.niter})) ) # gradients for per-iteration solvers - iter_grads = tree_zeros_like(extract_vars(state, iter_vars & _PER_ITER_VARS)[0]) + iter_grads = tree.zeros_like(extract_vars(state, iter_vars & _PER_ITER_VARS)[0]) # whether to shuffle groups this iteration iter_shuffle_groups = shuffle_groups({'state': state, 'niter': props.niter}) @@ -257,7 +256,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: xp=xp, dtype=dtype ) - losses.append(loss) + losses.append(float(loss)) check_finite(state.object.data, state.probe.data, context=f"object or probe, group {group_i}") observer.update_group(state, props.send_every_group) @@ -278,9 +277,13 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: state, iter_constraint_states[reg_i] ) + assert_dtype(state.object.data, cdtype) + assert_dtype(state.probe.data, cdtype) + if 'positions' in iter_vars: # check positions are at least overlapping object state.object.sampling.check_scan(state.scan, state.probe.sampling.extent / 2.) + assert_dtype(state.scan, dtype) state.progress.iters = numpy.concatenate([state.progress.iters, [i + start_i]]) state.progress.detector_errors = numpy.concatenate([state.progress.detector_errors, [loss]]) @@ -312,25 +315,25 @@ def run_group( xp: t.Any, dtype: t.Type[numpy.floating], ) -> t.Tuple[ReconsState, float, t.Dict[ReconsVar, t.Any], SolverStates]: - import jax xp = cast_array_module(xp) - ((loss, solver_states), grad) = jax.value_and_grad(run_model, has_aux=True)( + ((loss, solver_states), grad) = tree.value_and_grad(run_model, has_aux=True, xp=xp, sign=-1)( *extract_vars(state, vars, group), group=group, props=props, group_patterns=group_patterns, pattern_mask=pattern_mask, noise_model=noise_model, regularizers=regularizers, solver_states=solver_states, xp=xp, dtype=dtype ) - # steepest descent direction - grad = jax.tree.map(lambda v: -v.conj(), grad, is_leaf=lambda x: x is None) for k in grad.keys(): if k == 'probe': grad[k] /= group.shape[-1] else: grad[k] /= probe_int * group.shape[-1] + #print(f"obj grad: {xp.max(abs2(grad['object']))}") + #print(f"probe grad: {xp.max(abs2(grad['probe'])) if 'probe' in grad else None}") + # update iter grads at group - iter_grads = jax.tree.map(lambda v1, v2: at(v1, tuple(group)).set(v2), iter_grads, filter_vars(grad, vars & _PER_ITER_VARS)) + iter_grads = tree.map(lambda v1, v2: at(v1, tuple(group)).set(v2), iter_grads, filter_vars(grad, vars & _PER_ITER_VARS)) for (sol_i, solver) in enumerate(group_solvers): solver_grads = filter_vars(grad, solver.params) diff --git a/phaser/engines/gradient/solvers.py b/phaser/engines/gradient/solvers.py index f0d3425..a9adc4c 100644 --- a/phaser/engines/gradient/solvers.py +++ b/phaser/engines/gradient/solvers.py @@ -1,91 +1,104 @@ -import logging +""" +Gradient-descent solvers + +Much of this is adapted from [Optax](https://github.com/google-deepmind/optax), +but modified to use our generic array and pytree utilities. + +Optax is released under the Apache license: + +> Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +> +> Licensed under the Apache License, Version 2.0 (the "License"); +> you may not use this file except in compliance with the License. +> You may obtain a copy of the License at +> +> http://www.apache.org/licenses/LICENSE-2.0 +> +> Unless required by applicable law or agreed to in writing, software +> distributed under the License is distributed on an "AS IS" BASIS, +> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +> See the License for the specific language governing permissions and +> limitations under the License. +""" + import typing as t import numpy -from numpy.typing import NDArray, ArrayLike +from numpy.typing import NDArray -from phaser.utils.num import as_array, abs2 +from phaser.utils.num import get_array_module +import phaser.utils.tree as tree from phaser.hooks.solver import GradientSolver, GradientSolverArgs -from phaser.hooks.schedule import FlagArgs, ScheduleLike +from phaser.hooks.schedule import ScheduleLike, Schedule from phaser.types import ReconsVar, process_schedule -from phaser.plan import GradientEnginePlan, AdamSolverPlan, PolyakSGDSolverPlan, SGDSolverPlan +from phaser.plan import AdamSolverPlan, PolyakSGDSolverPlan, SGDSolverPlan from phaser.state import ReconsState -from .run import extract_vars, apply_update +from .run import extract_vars -import optax -from optax import GradientTransformation, GradientTransformationExtraArgs -from optax.schedules import StatefulSchedule +OptState: t.TypeAlias = tree.Tree +Params: t.TypeAlias = tree.Tree +Updates: t.TypeAlias = Params -class OptaxScheduleWrapper(StatefulSchedule): - def __init__(self, schedule: ScheduleLike): - self.inner = process_schedule(schedule) +class TransformInitFn(t.Protocol): + def __call__(self, params: Params) -> OptState: + ... - def init(self) -> t.Optional[float]: - return None - def update_for_iter(self, sim: ReconsState, state: t.Optional[float], niter: int) -> float: - return self.inner({'state': sim, 'niter': niter}) +class TransformUpdateFn(t.Protocol): + def __call__( + self, updates: Updates, state: OptState, params: t.Optional[Params] = None, + **extra_args: t.Any, + ) -> t.Tuple[Updates, OptState]: + ... - # mock update from inside jax - def update( - self, state: t.Optional[float], - **extra_args, - ) -> t.Optional[float]: - return state - def __call__( - self, state: t.Optional[float], - **extra_args, - ) -> float: - assert state is not None - return state +class GradientTransformation(t.NamedTuple): + init: TransformInitFn + update: TransformUpdateFn -OptaxSolverState: t.TypeAlias = t.Tuple[t.Any, t.Dict[str, t.Optional[float]]] +ScheduledSolverState: t.TypeAlias = t.Tuple[t.Any, t.Dict[str, t.Optional[float]]] -class OptaxSolver(GradientSolver[OptaxSolverState]): +class ScheduledSolver(GradientSolver[ScheduledSolverState]): def __init__(self, name: str, factory: t.Callable[..., GradientTransformation], hyperparams: t.Mapping[str, ScheduleLike], params: t.Iterable[ReconsVar]): self.factory: t.Callable[..., GradientTransformation] = factory #self.inner: GradientTransformationExtraArgs = optax.with_extra_args_support(solver) - self.hyperparams: t.Dict[str, OptaxScheduleWrapper] = {k: OptaxScheduleWrapper(v) for (k, v) in hyperparams.items()} + self.hyperparams: t.Dict[str, Schedule] = {k: process_schedule(v) for (k, v) in hyperparams.items()} self.params: t.FrozenSet[ReconsVar] = frozenset(params) self.name: str = name # or self.inner.__class__.__name__ - def init_state(self, sim: ReconsState) -> OptaxSolverState: + def init_state(self, sim: ReconsState) -> ScheduledSolverState: return ( None, - {k: v.init() for (k, v) in self.hyperparams.items()}, + {k: None for (k, v) in self.hyperparams.items()}, ) - def _resolve(self, hparams: t.Mapping[str, t.Optional[float]]) -> GradientTransformationExtraArgs: - return optax.with_extra_args_support( - self.factory(**{k: v(hparams[k]) for (k, v) in self.hyperparams.items()}) - ) + def _resolve(self, hparams: t.Mapping[str, t.Optional[float]]) -> GradientTransformation: + return self.factory(**{k: hparams[k] for k in self.hyperparams.keys()}) - def update_for_iter(self, sim: ReconsState, state: OptaxSolverState, niter: int) -> OptaxSolverState: - hparams_state: t.Dict[str, t.Optional[float]] = {k: v.update_for_iter(sim, state[1][k], niter) for (k, v) in self.hyperparams.items()} + def update_for_iter(self, sim: ReconsState, state: ScheduledSolverState, niter: int) -> ScheduledSolverState: + hparams_state: t.Dict[str, t.Optional[float]] = {k: v({'state': sim, 'niter': niter}) for (k, v) in self.hyperparams.items()} return ( self._resolve(hparams_state).init(params=extract_vars(sim, self.params)[0]) if state[0] is None else state[0], hparams_state ) def update( - self, sim: 'ReconsState', state: OptaxSolverState, grad: t.Dict[ReconsVar, numpy.ndarray], loss: float, - ) -> t.Tuple[t.Dict[ReconsVar, numpy.ndarray], OptaxSolverState]: + self, sim: 'ReconsState', state: ScheduledSolverState, grad: t.Dict[ReconsVar, numpy.ndarray], loss: float, + ) -> t.Tuple[t.Dict[ReconsVar, numpy.ndarray], ScheduledSolverState]: (inner_state, hparams_state) = state - hparams_state = {k: v.update(hparams_state[k]) for (k, v) in self.hyperparams.items()} (updates, inner_state) = self._resolve(hparams_state).update( grad, inner_state, params=extract_vars(sim, self.params)[0], value=loss, loss=loss ) return (t.cast(t.Dict[ReconsVar, t.Any], updates), (inner_state, hparams_state)) -class SGDSolver(OptaxSolver): +class SGDSolver(ScheduledSolver): def __init__(self, args: GradientSolverArgs, props: SGDSolverPlan): hparams = { 'learning_rate': props.learning_rate @@ -94,33 +107,33 @@ def __init__(self, args: GradientSolverArgs, props: SGDSolverPlan): if props.momentum is not None: hparams['momentum'] = props.momentum def factory(**kwargs: t.Any) -> GradientTransformation: - return optax.chain( - optax.trace(kwargs['momentum'], props.nesterov), - optax.scale_by_learning_rate(kwargs['learning_rate'], flip_sign=False), + return chain( + trace(kwargs['momentum'], props.nesterov), + scale_by_learning_rate(kwargs['learning_rate']), ) else: def factory(**kwargs: t.Any) -> GradientTransformation: - return optax.scale_by_learning_rate(kwargs['learning_rate'], flip_sign=False) + return scale_by_learning_rate(kwargs['learning_rate']) super().__init__('sgd', factory, hparams, args['params']) -class AdamSolver(OptaxSolver): +class AdamSolver(ScheduledSolver): def __init__(self, args: GradientSolverArgs, props: AdamSolverPlan): hparams = { 'learning_rate': props.learning_rate } def factory(**kwargs) -> GradientTransformation: - return optax.chain( - optax.scale_by_adam(props.b1, props.b2, props.eps, props.eps_root, nesterov=props.nesterov), - optax.scale_by_learning_rate(learning_rate=kwargs['learning_rate'], flip_sign=False), + return chain( + scale_by_adam(props.b1, props.b2, props.eps, props.eps_root, nesterov=props.nesterov), + scale_by_learning_rate(learning_rate=kwargs['learning_rate']), ) super().__init__('adam', factory, hparams, args['params']) -class PolyakSGDSolver(OptaxSolver): +class PolyakSGDSolver(ScheduledSolver): def __init__(self, args: GradientSolverArgs, props: PolyakSGDSolverPlan): hparams = { 'max_learning_rate': props.max_learning_rate, @@ -128,12 +141,157 @@ def __init__(self, args: GradientSolverArgs, props: PolyakSGDSolverPlan): } def factory(**kwargs) -> GradientTransformation: - return optax.chain( - optax.scale_by_learning_rate(kwargs['scaling'], flip_sign=False), - optax.scale_by_polyak( + return chain( + scale_by_learning_rate(kwargs['scaling']), + scale_by_polyak( max_learning_rate=kwargs['max_learning_rate'], f_min=props.f_min, eps=props.eps, #variant='sps', ) ) - super().__init__('polyak_sgd', factory, hparams, args['params']) \ No newline at end of file + super().__init__('polyak_sgd', factory, hparams, args['params']) + + +def chain( + *args: GradientTransformation +) -> GradientTransformation: + init_fns = tuple(arg.init for arg in args) + update_fns = tuple(arg.update for arg in args) + + def init_fn(params: Params): + return tuple(fn(params) for fn in init_fns) + + def update_fn(updates, state, params=None, **extra_args): + new_state = [] + for s, fn in zip(state, update_fns): + updates, new_s = fn(updates, s, params, **extra_args) + new_state.append(new_s) + return updates, tuple(new_state) + + return GradientTransformation(init_fn, update_fn) + + +def trace( + decay: float, + nesterov: bool = False, + accumulator_dtype: t.Optional[t.Any] = None, +) -> GradientTransformation: + + def init_fn(params): + return tree.zeros_like(params, dtype=accumulator_dtype) + + def update_fn(updates: Updates, state: Updates, params=None, **extra_args: t.Any): + del params + f = lambda g, t: g + decay * t # noqa: E731 + new_trace = tree.map( + lambda g, t: None if g is None else f(g, t), + updates, + state.trace, + is_leaf=lambda g: g is None, + ) + updates = tree.map(f, updates, new_trace) if nesterov else new_trace + new_trace = tree.cast(new_trace, accumulator_dtype) + return updates, new_trace + + return GradientTransformation(init_fn, update_fn) + + +def scale_by_learning_rate( + learning_rate: float, *, + flip_sign: bool = False, +) -> GradientTransformation: + if flip_sign: + learning_rate *= -1 + + def update_fn(updates: Updates, state: None, params=None, **extra_args: t.Any): + del params + updates = tree.map(lambda g: learning_rate * g, updates) + return updates, state + + return GradientTransformation(lambda params: None, update_fn) + + +class ScaleByAdamState(t.NamedTuple): + n: NDArray[numpy.int32] # shape () + mu: Updates + nu: Updates + + +def scale_by_adam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + mu_dtype: t.Optional[t.Any] = None, + *, + nesterov: bool = False, +) -> GradientTransformation: + def init_fn(params: Params) -> ScaleByAdamState: + xp = get_array_module(params) + mu = tree.zeros_like(params, dtype=mu_dtype) # First moment + nu = tree.zeros_like(params) # Second moment + return ScaleByAdamState(n=xp.zeros((), dtype=xp.int32), mu=mu, nu=nu) + + def update_fn( + updates: Updates, state: ScaleByAdamState, params: t.Any = None, **kwargs: t.Any + ) -> t.Tuple[Updates, ScaleByAdamState]: + xp = get_array_module(updates) + del params + mu = tree.update_moment(updates, state.mu, b1, 1) + nu = tree.update_moment_per_elem_norm(updates, state.nu, b2, 2) + n_inc = safe_increment(state.n) + + if nesterov: + mu_hat = tree.map( + lambda m, g: b1 * m + (1 - b1) * g, + tree.bias_correction(mu, b1, safe_increment(n_inc)), + tree.bias_correction(updates, b1, n_inc), + ) + else: + mu_hat = tree.bias_correction(mu, b1, n_inc) + + nu_hat = tree.bias_correction(nu, b2, n_inc) + updates = tree.map( + lambda m, v: None if m is None else m / (xp.sqrt(v + eps_root) + eps), + mu_hat, + nu_hat, + is_leaf=lambda x: x is None, + ) + mu = tree.cast(mu, mu_dtype) + return updates, ScaleByAdamState(n=n_inc, mu=mu, nu=nu) + + return GradientTransformation(init_fn, update_fn) + + +def scale_by_polyak( + f_min: float = 0.0, + max_learning_rate: float = 1.0, + eps: float = 0.0 +) -> GradientTransformation: + def update_fn( + updates: Updates, state: None, params: t.Any = None, *, value: float, **kwargs: t.Any + ): + del params + del kwargs + xp = get_array_module(updates) + grad_sq_norm = tree.squared_norm(updates) + gap = xp.array(value - f_min).astype(grad_sq_norm.dtype) + step = xp.where( + grad_sq_norm + eps <= xp.finfo(float).eps, + xp.array(0.0), + xp.minimum(gap / (grad_sq_norm + eps), max_learning_rate), + ) + updates = tree.scale(step, updates) + return updates, state + + return GradientTransformation(lambda params: None, update_fn) + + +def safe_increment(n: NDArray[numpy.int32]) -> NDArray[numpy.int32]: + xp = get_array_module(n) + + max_value = xp.iinfo(n.dtype).max + max_value = xp.array(max_value, dtype=n.dtype) + return xp.where( + n < max_value, n + xp.ones_like(n), max_value + ) \ No newline at end of file diff --git a/phaser/execute.py b/phaser/execute.py index 597d244..00d8b9b 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -8,7 +8,7 @@ import pane from phaser.types import EarlyTermination -from phaser.utils.num import cast_array_module, get_array_module, get_backend_module, xp_is_jax, Sampling, to_complex_dtype +from phaser.utils.num import Device, cast_array_module, get_array_module, get_backend_devices, get_backend_module, set_default_device, to_device, xp_is_jax, Sampling, to_complex_dtype, xp_is_torch from phaser.utils.object import ObjectSampling from phaser.utils.misc import unwrap from .hooks import EngineHook, Hook, ObjectHook, RawData @@ -50,7 +50,7 @@ def execute_engine( engine: EngineHook, ) -> PreparedRecons: xp = get_array_module(recons.state.object.data, recons.state.probe.data) - dtype = recons.patterns.patterns.dtype + dtype = recons.patterns.patterns.dtype.type plan = t.cast(EnginePlan, engine.props) engine_i = recons.state.iter.engine_num @@ -223,16 +223,36 @@ def load_raw_data( def initialize_reconstruction( - plan: ReconsPlan, *, xp: t.Any = None, seed: t.Any = None, - name: t.Optional[str] = None, + plan: ReconsPlan, *, xp: t.Any = None, device: t.Optional[Device] = None, + seed: t.Any = None, name: t.Optional[str] = None, init_state: t.Union[ReconsState, PartialReconsState, None] = None, observers: t.Union[Observer, t.Iterable[Observer], None] = None, override_observers: t.Union[Observer, t.Iterable[Observer], None] = None, ) -> PreparedRecons: - xp = cast_array_module(get_backend_module(plan.backend) if xp is None else xp) + logging.basicConfig(level=logging.INFO) + + if xp is not None: + xp = cast_array_module(xp) + # TODO: nicer output here + logging.info(f"Using manually-specified backend {xp}") + devices = get_backend_devices(xp) + logging.info(f"Available devices: {list(devices)}") + manual = device is not None + device = to_device(device, xp) if device is not None else devices[0] + logging.info(f"Using {'manually-specified ' if manual else ''}device {device}") + else: + xp = get_backend_module(plan.backend) + logging.info(f"Using {'plan-specified' if plan.backend is not None else 'default'} backend {xp}") + devices = get_backend_devices(xp) + logging.info(f"Available devices: {list(devices)}") + + device = to_device(plan.device, xp) if plan.device is not None else devices[0] + logging.info(f"Using {'plan-specified ' if plan.device is not None else ''}device {device}") + + set_default_device(device, xp) + observer = _normalize_observers(observers, override_observers) - logging.basicConfig(level=logging.INFO) logging.info("Executing plan...") observer.init_recons(plan) @@ -360,8 +380,8 @@ def initialize_reconstruction( def prepare_for_engine(patterns: Patterns, state: ReconsState, xp: t.Any, engine: EnginePlan) -> t.Tuple[Patterns, ReconsState]: # TODO: more graceful - if isinstance(engine, GradientEnginePlan) and not xp_is_jax(xp): - raise ValueError("The gradient descent engine requires the jax backend.") + if isinstance(engine, GradientEnginePlan) and not (xp_is_jax(xp) or xp_is_torch(xp)): + raise ValueError("The gradient descent engine requires the 'jax' or 'torch' backend.") state = state.to_xp(xp) diff --git a/phaser/hooks/__init__.py b/phaser/hooks/__init__.py index ff86e08..e8d38a5 100644 --- a/phaser/hooks/__init__.py +++ b/phaser/hooks/__init__.py @@ -202,7 +202,7 @@ class PostInitHook(Hook[PostInitArgs, t.Tuple['Patterns', 'ReconsState']]): class EngineArgs(t.TypedDict): data: 'Patterns' state: 'ReconsState' - dtype: DTypeLike + dtype: t.Type[numpy.floating] xp: t.Any recons_name: str observer: 'Observer' diff --git a/phaser/hooks/preprocessing.py b/phaser/hooks/preprocessing.py index 60b7813..20bd6fe 100644 --- a/phaser/hooks/preprocessing.py +++ b/phaser/hooks/preprocessing.py @@ -59,7 +59,7 @@ def add_poisson_noise(raw_data: RawData, props: PoissonProps) -> RawData: logger.info(f"Mean pattern intensity: {numpy.nanmean(numpy.nansum(patterns, axis=(-1, -2)))}") - raw_data['patterns'] = xp.array(patterns) + raw_data['patterns'] = xp.asarray(patterns) return raw_data @@ -79,7 +79,7 @@ def drop_nan_patterns(args: PostInitArgs, props: DropNanProps) -> t.Tuple[Patter logger.info(f"Dropping {n}/{patterns.shape[0]} patterns which are at least {props.threshold:.1%} NaN values") patterns = patterns[~mask] - if scan.shape[0] == mask.size: + if scan.shape[0] == xp.size(mask): # apply mask to scan as well scan = scan[~mask] elif scan.shape[0] != patterns.shape[0]: @@ -111,7 +111,7 @@ def diffraction_align(args: PostInitArgs, props: t.Any = None) -> t.Tuple[Patter sum_pattern = xp.zeros(patterns.patterns.shape[-2:], dtype=patterns.patterns.dtype) for group in groups: - pats = xp.array(patterns.patterns[tuple(group)]) * xp.array(patterns.pattern_mask) + pats = xp.asarray(patterns.patterns[tuple(group)]) * xp.asarray(patterns.pattern_mask) sum_pattern += t.cast(NDArray[numpy.floating], xp.nansum(pats, axis=tuple(range(pats.ndim - 2)))) mean_pattern = sum_pattern / math.prod(patterns.patterns.shape[:-2]) diff --git a/phaser/hooks/scan.py b/phaser/hooks/scan.py index 6e402fb..1c9f31c 100644 --- a/phaser/hooks/scan.py +++ b/phaser/hooks/scan.py @@ -21,7 +21,7 @@ def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> NDArray[numpy.flo ) if props.affine is not None: - affine = xp.array(props.affine, dtype=scan.dtype) + affine = xp.asarray(props.affine, dtype=scan.dtype) # equivalent to (affine @ scan.T).T (active transformation) scan = scan @ affine.T diff --git a/phaser/hooks/schedule.py b/phaser/hooks/schedule.py index ceac461..017abe0 100644 --- a/phaser/hooks/schedule.py +++ b/phaser/hooks/schedule.py @@ -3,7 +3,7 @@ import numpy -from ..types import Dataclass, Flag, process_schedule +from ..types import Dataclass, SimpleFlag, process_schedule from .hook import Hook if t.TYPE_CHECKING: @@ -23,7 +23,9 @@ class ScheduleHook(Hook[FlagArgs, float]): known = {} -FlagLike: t.TypeAlias = t.Union[bool, Flag, FlagHook] +Flag: t.TypeAlias = t.Callable[['FlagArgs'], bool] +Schedule: t.TypeAlias = t.Callable[['FlagArgs'], float] +FlagLike: t.TypeAlias = t.Union[bool, SimpleFlag, FlagHook] ScheduleLike: t.TypeAlias = t.Union[float, ScheduleHook] diff --git a/phaser/plan.py b/phaser/plan.py index 1326638..8690eae 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -1,7 +1,7 @@ from pathlib import Path import typing as t -from .types import Dataclass, Slices, BackendName, Flag, ReconsVars, IsVersion, EmptyDict +from .types import Dataclass, Slices, BackendName, SimpleFlag, ReconsVars, IsVersion, EmptyDict from .hooks import RawDataHook, ProbeHook, ObjectHook, ScanHook, EngineHook, PostInitHook, PostLoadHook, TiltHook from .hooks.solver import NoiseModelHook, ConventionalSolverHook, PositionSolverHook, GradientSolverHook from .hooks.schedule import FlagLike, ScheduleLike @@ -64,7 +64,7 @@ class EnginePlan(Dataclass, kw_only=True): update_positions: FlagLike = False update_tilt: FlagLike = False - calc_error: FlagLike = Flag(every=1) + calc_error: FlagLike = SimpleFlag(every=1) calc_error_fraction: float = 0.1 save: FlagLike = False @@ -180,6 +180,7 @@ class ReconsPlan(Dataclass, kw_only=True): name: str backend: t.Optional[BackendName] = None + device: t.Optional[str] = None dtype: t.Literal['float32', 'float64'] = 'float32' wavelength: t.Optional[float] = None diff --git a/phaser/state.py b/phaser/state.py index f09b98a..076b211 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -5,16 +5,16 @@ from typing_extensions import Self from phaser.utils.num import Sampling, to_numpy, get_array_module, Float -from phaser.utils.misc import jax_dataclass +from phaser.utils.tree import tree_dataclass from phaser.utils.object import ObjectSampling if t.TYPE_CHECKING: from phaser.utils.io import HdfLike - from phaser.utils.image import _BoundaryMode + from phaser.utils.image import _InterpBoundaryMode from phaser.observer import Observer, ObserverSet -@jax_dataclass +@tree_dataclass class Patterns(): patterns: NDArray[numpy.floating] """Raw diffraction patterns, with 0-frequency sample in corner""" @@ -27,7 +27,7 @@ def to_numpy(self) -> Self: ) -@jax_dataclass +@tree_dataclass class IterState(): engine_num: int """Engine number. 1-indexed (0 means before any reconstruction).""" @@ -57,7 +57,7 @@ def empty() -> 'IterState': return IterState(0, 0, 0) -@jax_dataclass(static_fields=('sampling',)) +@tree_dataclass(static_fields=('sampling',)) class ProbeState(): sampling: Sampling """Probe coordinate system. See `Sampling` for more details.""" @@ -68,7 +68,7 @@ def resample( self, new_samp: Sampling, rotation: float = 0.0, order: int = 1, - mode: '_BoundaryMode' = 'grid-constant', + mode: '_InterpBoundaryMode' = 'grid-constant', ) -> Self: new_data = self.sampling.resample( self.data, new_samp, @@ -80,7 +80,7 @@ def resample( def to_xp(self, xp: t.Any) -> Self: return self.__class__( - self.sampling, xp.array(self.data) + self.sampling, xp.asarray(self.data) ) def to_numpy(self) -> Self: @@ -93,7 +93,7 @@ def copy(self) -> Self: return copy.deepcopy(self) -@jax_dataclass(static_fields=('sampling',)) +@tree_dataclass(static_fields=('sampling',)) class ObjectState(): sampling: ObjectSampling """Object coordinate system. See `ObjectSampling` for more details.""" @@ -107,7 +107,7 @@ class ObjectState(): def to_xp(self, xp: t.Any) -> Self: return self.__class__( - self.sampling, xp.array(self.data), xp.array(self.thicknesses) + self.sampling, xp.asarray(self.data), xp.asarray(self.thicknesses) ) def to_numpy(self) -> Self: @@ -118,7 +118,7 @@ def to_numpy(self) -> Self: def zs(self) -> NDArray[numpy.floating]: xp = get_array_module(self.thicknesses) if len(self.thicknesses) < 2: - return xp.array([0.], dtype=self.thicknesses.dtype) + return xp.asarray([0.], dtype=self.thicknesses.dtype) return xp.cumsum(self.thicknesses) - self.thicknesses def copy(self) -> Self: @@ -126,7 +126,7 @@ def copy(self) -> Self: return copy.deepcopy(self) -@jax_dataclass +@tree_dataclass class ProgressState: iters: NDArray[numpy.integer] """Iterations error measurements were taken at.""" @@ -162,7 +162,7 @@ def __eq__(self, other: t.Any) -> bool: xp.array_equal(self.detector_errors, other.detector_errors) ) -@jax_dataclass(kw_only=True, static_fields=('progress',)) +@tree_dataclass(kw_only=True, static_fields=('progress',)) class ReconsState: iter: IterState wavelength: Float @@ -180,8 +180,8 @@ def to_xp(self, xp: t.Any) -> Self: iter=self.iter, probe=self.probe.to_xp(xp), object=self.object.to_xp(xp), - scan=xp.array(self.scan), - tilt=None if self.tilt is None else xp.array(self.tilt), + scan=xp.asarray(self.scan), + tilt=None if self.tilt is None else xp.asarray(self.tilt), progress=self.progress, wavelength=self.wavelength, ) @@ -211,7 +211,7 @@ def read_hdf5(file: 'HdfLike') -> 'ReconsState': return hdf5_read_state(file).to_complete() -@jax_dataclass(kw_only=True, static_fields=('progress',)) +@tree_dataclass(kw_only=True, static_fields=('progress',)) class PartialReconsState: iter: t.Optional[IterState] = None wavelength: t.Optional[Float] = None @@ -260,7 +260,7 @@ def read_hdf5(file: 'HdfLike') -> 'PartialReconsState': return hdf5_read_state(file) -@jax_dataclass(static_fields=('name', 'observer')) +@tree_dataclass(static_fields=('name', 'observer')) class PreparedRecons: patterns: Patterns state: ReconsState diff --git a/phaser/types.py b/phaser/types.py index 5714612..3f1006f 100644 --- a/phaser/types.py +++ b/phaser/types.py @@ -12,7 +12,7 @@ if t.TYPE_CHECKING: from phaser.state import ReconsState - from phaser.hooks.schedule import FlagArgs, FlagLike, ScheduleLike + from phaser.hooks.schedule import FlagArgs, FlagLike, Flag, ScheduleLike, Schedule T = t.TypeVar('T') @@ -75,7 +75,7 @@ def __hash__(self) -> int: return hash(self.__class__.__name__) -BackendName: t.TypeAlias = t.Literal['cuda', 'cupy', 'jax', 'cpu', 'numpy'] +BackendName: t.TypeAlias = t.Literal['cupy', 'jax', 'torch', 'numpy'] ReconsVar: t.TypeAlias = t.Literal['object', 'probe', 'positions', 'tilt'] ReconsVars: t.TypeAlias = t.Annotated[t.FrozenSet[ReconsVar], _ReconsVarsAnnotation()] @@ -122,7 +122,7 @@ def thicknesses(self) -> t.List[float]: Slices: t.TypeAlias = t.Union[SliceList, SliceStep, SliceTotal] -class Flag(Dataclass): +class SimpleFlag(Dataclass): after: int = 0 every: int = 1 before: t.Optional[int] = None @@ -154,21 +154,21 @@ def __call__(self, args: 'FlagArgs') -> bool: @lru_cache -def process_flag(flag: 'FlagLike') -> t.Callable[['FlagArgs'], bool]: +def process_flag(flag: 'FlagLike') -> 'Flag': if isinstance(flag, bool): return _ConstFlag(flag) return flag @lru_cache -def process_schedule(schedule: 'ScheduleLike') -> t.Callable[['FlagArgs'], float]: +def process_schedule(schedule: 'ScheduleLike') -> 'Schedule': if isinstance(schedule, (int, float)): return lambda _: schedule return schedule def flag_any_true(flag: t.Callable[['FlagArgs'], bool], niter: int) -> bool: - if isinstance(flag, Flag): + if isinstance(flag, SimpleFlag): return flag.any_true(niter) elif isinstance(flag, _ConstFlag): return flag.val diff --git a/phaser/utils/_cuda_kernels.py b/phaser/utils/_cuda_kernels.py index 6c85523..b5e6d44 100644 --- a/phaser/utils/_cuda_kernels.py +++ b/phaser/utils/_cuda_kernels.py @@ -1,10 +1,13 @@ import functools +import re import typing as t import cupy # pyright: ignore[reportMissingImports] import numpy +from phaser.utils.misc import _MockModule + # grid # block # thread @@ -211,3 +214,34 @@ def _get_cutout_kernel(dtype: numpy.dtype, operation: str) -> cupy.RawKernel: """, kernel_name) kernel.compile() return kernel + + +def get_devices() -> t.Tuple[str, ...]: + n: int = cupy.cuda.runtime.getDeviceCount() + return tuple(f'cuda:{i}' for i in range(n)) + + +def to_device(device: t.Union[str, int, cupy.cuda.Device]) -> cupy.cuda.Device: + if isinstance(device, (int, cupy.cuda.Device)): + return cupy.cuda.Device(device) + device = str(device) + if (match := re.fullmatch(r'cuda:(\d)+', device)): + return cupy.cuda.Device(int(match[1])) + raise ValueError(f"Invalid device '{device}'") + + +def set_default_device(device: cupy.cuda.Device): + if not isinstance(device, cupy.cuda.Device): + raise TypeError(f"Invalid device '{device}' for backend cupy") + device.use() + + +def _wrap_call(f, *args: t.Any, **kwargs: t.Any) -> t.Any: + if (device := kwargs.pop('device', None)) is not None: + with to_device(device): + return f(*args, **kwargs) + + return f(*args, **kwargs) + + +mock_cupy = _MockModule(cupy, {}, _wrap_call) \ No newline at end of file diff --git a/phaser/utils/_jax_kernels.py b/phaser/utils/_jax_kernels.py index a0be3be..b0fcc0c 100644 --- a/phaser/utils/_jax_kernels.py +++ b/phaser/utils/_jax_kernels.py @@ -7,6 +7,9 @@ import jax.numpy as jnp # pyright: ignore[reportMissingImports] +Device: t.TypeAlias = t.Any + + def to_nd(arr: jax.Array, n: int) -> jax.Array: if arr.ndim > n: arr = arr.reshape(-1, *arr.shape[arr.ndim - n + 1:]) @@ -99,4 +102,49 @@ def affine_transform( return jax.vmap( lambda a: jax.scipy.ndimage.map_coordinates(a, tuple(coords), order=order, mode=jax_mode, cval=cval), - )(to_nd(input, n_axes + 1)).reshape((*input.shape[:-n_axes], *output_shape)) \ No newline at end of file + )(to_nd(input, n_axes + 1)).reshape((*input.shape[:-n_axes], *output_shape)) + + +def get_devices() -> t.Tuple[Device, ...]: + devices = [] + + for backend in ('gpu', 'tpu', 'cpu'): + try: + devices.extend(jax.devices(backend)) + except RuntimeError: + pass + + return tuple(devices) + + +def to_device(device: t.Union[str, Device]) -> Device: + if isinstance(device, jax.Device): + return device + + split = device.rsplit(':', maxsplit=1) + if len(split) == 1: + [backend] = split + index = 0 + else: + [backend, index] = split + index = int(index) + + try: + backend_devices = jax.devices(backend) + except RuntimeError: + raise RuntimeError(f"Can't use device '{device}': jax backend '{backend}' is unavailable") + + try: + return backend_devices[index] + except IndexError: + pass + if len(backend_devices) == 0: + raise RuntimeError(f"Can't use device '{device}': No available devices on jax backend '{backend}'") + raise RuntimeError(f"Can't use device '{device}': Device index {index} not available" + f" ({len(backend_devices)} device(s) on jax backend '{backend}')") + + +def set_default_device(device: Device): + if not isinstance(device, jax.Device): + raise TypeError(f"Invalid device '{device}' for backend jax") + jax.config.update('jax_default_device', device) \ No newline at end of file diff --git a/phaser/utils/_torch_kernels.py b/phaser/utils/_torch_kernels.py new file mode 100644 index 0000000..1817d1d --- /dev/null +++ b/phaser/utils/_torch_kernels.py @@ -0,0 +1,449 @@ +import functools +import itertools +import operator +from types import ModuleType +import typing as t + +import numpy +from numpy.typing import ArrayLike +import torch + +from phaser.utils.num import _PadMode +from phaser.utils.image import _InterpBoundaryMode +from phaser.utils.misc import _MockModule + + +def get_cutouts(obj: torch.Tensor, start_idxs: torch.Tensor, cutout_shape: t.Tuple[int, int]) -> torch.Tensor: + #out_shape = (*start_idxs.shape[:-1], *obj.shape[:-2], *cutout_shape) + ys, xs = torch.arange(cutout_shape[0]), torch.arange(cutout_shape[1]) + yy, xx = torch.meshgrid(ys, xs, indexing='ij') + yy = start_idxs[..., 0][..., None, None] + yy + xx = start_idxs[..., 1][..., None, None] + xx + + out = obj[..., yy, xx] + if obj.ndim > 2: + # oof + out = torch.permute(out, (*(i + obj.ndim - 2 for i in range(start_idxs.ndim - 1)), *range(obj.ndim - 2), -2, -1)) + #assert out.shape == out_shape + return out + + +class _MockTensor(torch.Tensor): + #@property + #def dtype(self) -> t.Type[numpy.generic]: + # return to_numpy_dtype(super().dtype) + + @property + def T(self) -> '_MockTensor': # pyright: ignore[reportIncompatibleVariableOverride] + if self.ndim <= 2: + return _MockTensor(super().T) + return t.cast(_MockTensor, self.permute(*range(self.ndim - 1, -1, -1))) + + def astype(self, dtype: t.Union[str, torch.dtype, t.Type[numpy.generic]]) -> '_MockTensor': + return t.cast(_MockTensor, self.to(to_torch_dtype(dtype))) + + +_TORCH_TO_NUMPY_DTYPE: t.Dict[torch.dtype, t.Type[numpy.generic]] = { + torch.bool : numpy.bool, + torch.uint8 : numpy.uint8, + torch.int8 : numpy.int8, + torch.int16 : numpy.int16, + torch.int32 : numpy.int32, + torch.int64 : numpy.int64, + torch.float16 : numpy.float16, + torch.float32 : numpy.float32, + torch.float64 : numpy.float64, + torch.complex64 : numpy.complex64, + torch.complex128 : numpy.complex128, +} + +_NUMPY_TO_TORCH_DTYPE: t.Dict[t.Type[numpy.generic], torch.dtype] = { + numpy.bool : torch.bool, + numpy.uint8 : torch.uint8, + numpy.int8 : torch.int8, + numpy.int16 : torch.int16, + numpy.int32 : torch.int32, + numpy.int64 : torch.int64, + numpy.float16 : torch.float16, + numpy.float32 : torch.float32, + numpy.float64 : torch.float64, + numpy.complex64 : torch.complex64, + numpy.complex128 : torch.complex128, +} + + +def to_torch_dtype(dtype: t.Union[str, torch.dtype, numpy.dtype, t.Type[numpy.generic]]) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, numpy.dtype): + dtype = dtype.type + elif not isinstance(dtype, type) or not issubclass(dtype, numpy.generic): + dtype = numpy.dtype(dtype).type + + try: + return _NUMPY_TO_TORCH_DTYPE[dtype] + except KeyError: + raise ValueError(f"Can't convert dtype '{dtype}' to a PyTorch dtype") + + +def to_numpy_dtype(dtype: t.Union[str, torch.dtype, numpy.dtype, t.Type[numpy.generic]]) -> t.Type[numpy.generic]: + if isinstance(dtype, str): + return numpy.dtype(dtype).type + if isinstance(dtype, numpy.dtype): + return dtype.type + if isinstance(dtype, torch.dtype): + return _TORCH_TO_NUMPY_DTYPE[dtype] + return dtype + + +def _mirror(idx: torch.Tensor, size: int) -> torch.Tensor: + s = size -1 + return torch.abs((idx + s) % (2 * s) - s) + + +_BOUNDARY_FNS: t.Dict[str, t.Callable[[torch.Tensor, int], torch.Tensor]] = { + 'nearest': lambda idx, size: torch.clip(idx, 0, size - 1), + 'grid-wrap': lambda idx, size: idx % size, + 'reflect': lambda idx, size: torch.floor_divide(_mirror(2*idx+1, 2*size+1), 2), + 'mirror': _mirror, +} + +_PAD_MODE_MAP: t.Dict[_PadMode, str] = { + 'constant': 'constant', + 'edge': 'replicate', + 'reflect': 'reflect', + 'wrap': 'circular', +} + +def min( + arr: torch.Tensor, axis: t.Union[int, t.Tuple[int, ...], None] = None, *, + keepdims: bool = False +) -> torch.Tensor: + if axis is None: + if keepdims: + return torch.min(arr).reshape((1,) * arr.ndim) + return torch.min(arr) + return torch.amin(arr, axis, keepdim=keepdims) + + +def max( + arr: torch.Tensor, axis: t.Union[int, t.Tuple[int, ...], None] = None, *, + keepdims: bool = False +) -> torch.Tensor: + if axis is None: + if keepdims: + return torch.max(arr).reshape((1,) * arr.ndim) + return torch.max(arr) + return torch.amax(arr, axis, keepdim=keepdims) + + +def nanmin( + arr: torch.Tensor, axis: t.Union[int, t.Tuple[int, ...], None] = None, *, + keepdims: bool = False +) -> torch.Tensor: + return min(torch.nan_to_num(arr, nan=torch.inf), axis, keepdims=keepdims) + + +def nanmax( + arr: torch.Tensor, axis: t.Union[int, t.Tuple[int, ...], None] = None, *, + keepdims: bool = False +) -> torch.Tensor: + return max(torch.nan_to_num(arr, nan=-torch.inf), axis, keepdims=keepdims) + + +def minimum( + x1: ArrayLike, x2: ArrayLike +) -> torch.Tensor: + if not isinstance(x1, torch.Tensor): + x1 = _MockTensor(torch.asarray(x1)) + if not isinstance(x2, torch.Tensor): + x2 = _MockTensor(torch.asarray(x2)) + + return torch.minimum(x1, x2) + + +def maximum( + x1: ArrayLike, x2: ArrayLike +) -> torch.Tensor: + if not isinstance(x1, torch.Tensor): + x1 = _MockTensor(torch.asarray(x1)) + if not isinstance(x2, torch.Tensor): + x2 = _MockTensor(torch.asarray(x2)) + + return torch.maximum(x1, x2) + + +def split( + arr: torch.Tensor, sections: int, *, axis: int = 0 +) -> t.Tuple[torch.Tensor, ...]: + if arr.shape[axis] % sections != 0: + raise ValueError("array split does not result in an equal division") + return torch.split(arr, arr.shape[axis] // sections, axis) + + +def pad( + arr: torch.Tensor, pad_width: t.Union[int, t.Tuple[int, int], t.Sequence[t.Tuple[int, int]]], /, *, + mode: _PadMode = 'constant', cval: float = 0. +) -> torch.Tensor: + if mode not in ('constant', 'edge', 'reflect', 'wrap'): + raise ValueError(f"Unsupported padding mode '{mode}'") + + pad = (pad_width, pad_width) if isinstance(pad_width, int) else pad_width + + if isinstance(pad[0], int): + pad = (pad,) + + if len(pad) == 1: + pad = tuple(pad) * arr.ndim + elif len(pad) != arr.ndim: + raise ValueError(f"Invalid `pad_width` '{pad_width}'.") + + pad = tuple(itertools.chain.from_iterable(t.cast(t.Sequence[t.Tuple[int, int]], reversed(pad)))) + + kwargs = {'value': cval} if mode == 'constant' else {} + return _MockTensor(torch.nn.functional.pad(arr, pad, mode=_PAD_MODE_MAP[mode], **kwargs)) + + +def unwrap(arr: torch.Tensor, discont: t.Optional[float] = None, axis: int = -1, *, + period: float = 2.*numpy.pi) -> torch.Tensor: + if discont is None: + discont = period / 2 + + diff = torch.diff(arr, dim=axis) + dtype = torch.result_type(diff, period) + + if dtype.is_floating_point: + interval_high = period / 2 + boundary_ambiguous = True + else: + interval_high, rem = divmod(period, 2) + boundary_ambiguous = rem == 0 + + interval_low = -interval_high + diffmod = torch.remainder(diff - interval_low, period) + interval_low + if boundary_ambiguous: + diffmod[(diffmod == interval_low) & (diff > 0)] = interval_high + + phase_correct = diffmod - diff + phase_correct[abs(diff) < discont] = 0. + + prepend_shape = list(arr.shape) + prepend_shape[axis] = 1 + return arr + torch.cat([torch.zeros(prepend_shape, dtype=dtype), torch.cumsum(phase_correct, axis)], dim=axis) + + +def indices( + shape: t.Tuple[int, ...], dtype: t.Union[str, None, t.Type[numpy.generic], torch.dtype] = None, sparse: bool = False +) -> t.Union[torch.Tensor, t.Tuple[torch.Tensor, ...]]: + dtype = to_torch_dtype(dtype) if dtype is not None else torch.int64 + + n = len(shape) + + if sparse: + return tuple( + _MockTensor(torch.arange(s, dtype=dtype).reshape((1,) * i + (s,) + (1,) * (n - i - 1))) + for (i, s) in enumerate(shape) + ) + + arrs = tuple(torch.arange(s, dtype=dtype) for s in shape) + return _MockTensor(torch.stack(torch.meshgrid(*arrs, indexing='ij'), dim=0)) + + +def size(arr: torch.Tensor, axis: t.Optional[int]) -> int: + return arr.size(axis) if axis is not None else arr.numel() + + +def asarray( + arr: t.Any, dtype: t.Union[str, torch.dtype, numpy.dtype, t.Type[numpy.generic], None] = None, *, + copy: t.Optional[bool] = None, +) -> _MockTensor: + dtype = to_torch_dtype(dtype) if dtype is not None else None + requires_grad = arr.requires_grad if isinstance(arr, torch.Tensor) else False + + if isinstance(arr, numpy.ndarray) and arr.flags['WRITEABLE'] and not copy: + device = torch.get_default_device() + if device.type == 'cuda': + return _MockTensor(torch.from_numpy(arr).to(device=device, dtype=dtype, non_blocking=True)) + + return _MockTensor(torch.asarray(arr, dtype=dtype, requires_grad=requires_grad, copy=copy)) + + +def affine_transform( + input: torch.Tensor, matrix: ArrayLike, + offset: t.Optional[ArrayLike] = None, + output_shape: t.Optional[t.Tuple[int, ...]] = None, + order: int = 1, mode: _InterpBoundaryMode = 'grid-constant', + cval: ArrayLike = 0.0, +) -> torch.Tensor: + + if output_shape is None: + output_shape = input.shape + n_axes = len(output_shape) # num axes to transform over + + idxs = t.cast(torch.Tensor, indices(output_shape, dtype=torch.float64)) + + matrix = asarray(matrix) + if matrix.size() == (n_axes + 1, n_axes + 1): + # homogenous transform matrix + coords = torch.tensordot( + matrix, torch.stack((*idxs, torch.ones_like(idxs[0])), dim=0), dims=1 + )[:-1] + elif matrix.size() == (n_axes,): + coords = (idxs.T * matrix + asarray(offset)).T + else: + raise ValueError(f"Expected matrix of shape ({n_axes + 1}, {n_axes + 1}) or ({n_axes},), instead got shape {matrix.shape}") + + return _MockTensor(torch.vmap( + lambda a: map_coordinates(a, coords, order=order, mode=mode, cval=cval) + )(input.reshape(-1, *input.shape[-n_axes:])).reshape((*input.shape[:-n_axes], *output_shape))) + + +def map_coordinates( + arr: torch.Tensor, coordinates: torch.Tensor, + order: int = 1, mode: _InterpBoundaryMode = 'grid-constant', + cval: ArrayLike = 0.0 +) -> torch.Tensor: + from phaser.utils.num import to_real_dtype + if arr.ndim != coordinates.shape[0]: + raise ValueError("invalid shape for coordinate array") + + if order not in (0, 1): + raise ValueError(f"Interpolation order {order} not supported (torch currently only supports order=0, 1)") + + if mode == 'grid-constant': + return _map_coordinates_constant( + arr, coordinates, order=order, cval=cval + ) + + remap_fn = _BOUNDARY_FNS.get(mode) + if remap_fn is None: + raise ValueError(f"Interpolation mode '{mode}' not supported (torch supports one of " + "('constant', 'nearest', 'reflect', 'mirror', 'grid-wrap'))") + + weight_dtype = to_torch_dtype(to_real_dtype(to_numpy_dtype(arr.dtype))) + + ax_nodes: t.List[t.Tuple[t.Tuple[torch.Tensor, torch.Tensor], ...]] = [] + + for ax_coords, size in zip(coordinates, arr.shape): + if order == 1: + lower = torch.floor(ax_coords) + upper_weight = ax_coords - lower + lower_idx = lower.type(torch.int32) + ax_nodes.append(( + (remap_fn(lower_idx, size), 1.0 - upper_weight), + (remap_fn(lower_idx + 1, size), upper_weight), + )) + else: + idx = torch.round(ax_coords).type(torch.int32) + ax_nodes.append(((remap_fn(idx, size), torch.ones((), dtype=weight_dtype)),)) + + outputs = [] + for corner in itertools.product(*ax_nodes): + idxs, weights = zip(*corner) + outputs.append(arr[idxs] * functools.reduce(operator.mul, weights)) + + result = functools.reduce(operator.add, outputs) + return _MockTensor(result.type(arr.dtype)) + + +def _map_coordinates_constant( + arr: torch.Tensor, coordinates: torch.Tensor, + order: int = 1, cval: ArrayLike = 0.0 +) -> torch.Tensor: + from phaser.utils.num import to_real_dtype + weight_dtype = to_torch_dtype(to_real_dtype(to_numpy_dtype(arr.dtype))) + cval = torch.tensor(cval) + + is_valid = lambda idx, size: (0 <= idx) & (idx < size) # noqa: E731 + clip = lambda idx, size: torch.clip(idx, 0, size - 1) # noqa: E731 + + ax_nodes: t.List[t.Tuple[t.Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]] = [] + + for ax_coords, size in zip(coordinates, arr.shape): + if order == 1: + lower = torch.floor(ax_coords) + upper_weight = ax_coords - lower + lower_idx = lower.type(torch.int32) + ax_nodes.append(( + (clip(lower_idx, size), is_valid(lower_idx, size), 1.0 - upper_weight), + (clip(lower_idx + 1, size), is_valid(lower_idx + 1, size), upper_weight), + )) + else: + idx = torch.round(ax_coords).type(torch.int32) + ax_nodes.append(((clip(idx, size), is_valid(idx, size), torch.ones((), dtype=weight_dtype)),)) + + outputs = [] + for corner in itertools.product(*ax_nodes): + idxs, valids, weights = zip(*corner) + val = torch.where(functools.reduce(operator.and_, valids), arr[idxs], cval) + outputs.append(val * functools.reduce(operator.mul, weights)) + + result = functools.reduce(operator.add, outputs) + return result.type(arr.dtype) + + +def get_devices() -> t.Tuple[torch.device, ...]: + devices = [] + devices.extend(f'cuda:{i}' for i in range(torch.cuda.device_count())) + + if torch.backends.mps.is_available(): + devices.append('mps') + + return tuple(map(torch.device, devices)) + + +def to_device(device: t.Union[str, torch.device]) -> torch.device: + if isinstance(device, torch.device): + return device + return torch.device(device) + + +def set_default_device(device: torch.device): + if not isinstance(device, torch.device): + raise TypeError(f"Invalid device '{device}' for backend torch") + torch.set_default_device(device) + + +def _wrap_call(f, *args: t.Any, **kwargs: t.Any) -> t.Any: + try: + kwargs['dtype'] = to_torch_dtype(kwargs['dtype']) + except KeyError: + pass + + try: + kwargs['dim'] = kwargs.pop('axes') + except KeyError: + try: + kwargs['dim'] = kwargs.pop('axis') + except KeyError: + pass + + if f is torch.asarray and isinstance(args[0], numpy.ndarray): + if not args[0].flags['W']: + raise ValueError() + + result = f(*args, **kwargs) + # TODO: deal with tuples of output, pytrees, etc. here + # this will result in some nasty bugs + if isinstance(result, torch.Tensor): + return _MockTensor(result) + return result + + +mock_torch = _MockModule(torch, { + 'torch.array': functools.update_wrapper(lambda *args, **kwargs: _MockTensor(_wrap_call(torch.asarray, *args, **kwargs)), torch.asarray), # type: ignore + 'torch.asarray': asarray, + 'torch.mod': functools.update_wrapper(lambda *args, **kwargs: _MockTensor(_wrap_call(torch.remainder, *args, **kwargs)), torch.remainder), # type: ignore + 'torch.split': split, + 'torch.pad': pad, + 'torch.min': min, 'torch.max': max, + 'torch.nanmin': nanmin, 'torch.nanmax': nanmax, + 'torch.minimum': minimum, 'torch.maximum': maximum, + 'torch.unwrap': unwrap, + 'torch.indices': indices, + 'torch.size': size, + 'torch.iscomplexobj': lambda arr: torch.is_complex(arr), + 'torch.isrealobj': lambda arr: not torch.is_complex(arr), +}, _wrap_call) + +mock_torch._MockTensor = _MockTensor # type: ignore \ No newline at end of file diff --git a/phaser/utils/image.py b/phaser/utils/image.py index 45c1655..6b92ada 100644 --- a/phaser/utils/image.py +++ b/phaser/utils/image.py @@ -8,7 +8,7 @@ import numpy from numpy.typing import ArrayLike, NDArray -from .num import get_array_module, get_scipy_module, to_numpy, at, is_jax, abs2 +from .num import get_array_module, get_scipy_module, to_numpy, at, abs2, xp_is_jax, xp_is_torch NumT = t.TypeVar('NumT', bound=numpy.number) @@ -131,7 +131,7 @@ def scale_to_integral_type( return (xp.clip((imax + 1) / (vmax - vmin) * (arr - vmin), 0, imax)).astype(dtype) -_BoundaryMode: t.TypeAlias = t.Literal['constant', 'nearest', 'mirror', 'reflect', 'wrap', 'grid-mirror', 'grid-wrap', 'grid-constant'] +_InterpBoundaryMode: t.TypeAlias = t.Literal['constant', 'nearest', 'mirror', 'reflect', 'wrap', 'grid-mirror', 'grid-wrap', 'grid-constant'] def to_affine_matrix(arr: ArrayLike, ndim: int = 2) -> NDArray[numpy.floating]: @@ -182,19 +182,24 @@ def affine_transform( offset: t.Optional[ArrayLike] = None, output_shape: t.Optional[t.Tuple[int, ...]] = None, order: int = 1, - mode: _BoundaryMode = 'grid-constant', + mode: _InterpBoundaryMode = 'grid-constant', cval: t.Union[NumT, float] = 0.0, ) -> NDArray[NumT]: if mode in ('constant', 'wrap'): # these modes aren't supported by jax raise ValueError(f"Resampling mode '{mode}' not supported (try 'grid-constant' or 'grid-wrap' instead)") - xp = get_array_module(input, matrix, offset) - scipy = get_scipy_module(input, matrix, offset) - if is_jax(input): - if order > 1: + if xp_is_torch(xp): + from ._torch_kernels import affine_transform, torch + return t.cast(NDArray[NumT], affine_transform( + t.cast(torch.Tensor, input), matrix, offset, + output_shape, order, mode, cval + )) + + if xp_is_jax(xp): + if order not in (0, 1): raise ValueError(f"Interpolation order {order} not supported (jax currently only supports order=0, 1)") from ._jax_kernels import affine_transform, jax return t.cast(NDArray[NumT], affine_transform( @@ -202,6 +207,8 @@ def affine_transform( output_shape, order, mode, cval )) + scipy = get_scipy_module(input, matrix, offset) + if offset is None: offset = 0. if output_shape is None: diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 700432a..be5948d 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -208,10 +208,10 @@ def hdf5_write_object_state(state: ObjectState, group: h5py.Group): assert state.data.ndim == 3 assert state.thicknesses.ndim == 1 n_z = state.data.shape[0] - assert state.thicknesses.ndim == 1 - assert state.thicknesses.size == n_z if n_z > 1 else state.thicknesses.size in (0, 1) thick = to_numpy(state.thicknesses) + assert thick.ndim == 1 + assert thick.size == n_z if n_z > 1 else thick.size in (0, 1) group.create_dataset('thicknesses', data=thick) zs = group.create_dataset('zs', data=to_numpy(state.zs())) zs.make_scale("z") diff --git a/phaser/utils/misc.py b/phaser/utils/misc.py index 0c5467a..cdd118f 100644 --- a/phaser/utils/misc.py +++ b/phaser/utils/misc.py @@ -1,11 +1,11 @@ -import dataclasses +import functools import math +from types import ModuleType import typing as t import numpy from numpy.typing import NDArray from numpy.random import SeedSequence, PCG64, BitGenerator, Generator -from typing_extensions import dataclass_transform T = t.TypeVar('T') @@ -217,79 +217,57 @@ def __eq__(self, other: t.Any) -> bool: round(self, 5) == round(other, 5) -@t.overload -@dataclass_transform(kw_only_default=False, frozen_default=False) -def jax_dataclass(cls: t.Type[T], /, *, - init: bool = True, kw_only: bool = False, frozen: bool = False, - static_fields: t.Sequence[str] = (), drop_fields: t.Sequence[str] = (), -) -> t.Type[T]: - ... - -@t.overload -@dataclass_transform(kw_only_default=False, frozen_default=False) -def jax_dataclass(*, - init: bool = True, kw_only: bool = False, frozen: bool = False, - static_fields: t.Sequence[str] = (), drop_fields: t.Sequence[str] = (), -) -> t.Callable[[t.Type[T]], t.Type[T]]: - ... - -def jax_dataclass(cls: t.Optional[t.Type[T]] = None, /, *, - init: bool = True, kw_only: bool = False, frozen: bool = False, - static_fields: t.Sequence[str] = (), drop_fields: t.Sequence[str] = (), -) -> t.Union[t.Type[T], t.Callable[[t.Type[T]], t.Type[T]]]: - if cls is None: - return lambda cls: jax_dataclass(cls, init=init, kw_only=kw_only, frozen=frozen, - static_fields=static_fields, drop_fields=drop_fields) - - cls = dataclasses.dataclass(init=init, kw_only=kw_only, frozen=frozen)(cls) - _register_dataclass(cls, static_fields=static_fields, drop_fields=drop_fields) - return cls - +def unwrap(val: t.Optional[T]) -> T: + assert val is not None + return val -def _register_dataclass(cls: type, static_fields: t.Sequence[str], drop_fields: t.Sequence[str]): - try: - from jax.tree_util import register_pytree_with_keys - except ImportError: - return - fields = dataclasses.fields(cls) - field_names = {field.name for field in fields} +class _MockModule: + def __init__(self, module: ModuleType, rewrites: t.Dict[str, t.Callable], wrap: t.Callable): + self._inner: ModuleType = module + self._rewrites: t.Dict[str, t.Callable] = rewrites + self._wrap: t.Callable = wrap - if (extra := set(static_fields).difference(field_names)): - raise ValueError(f"Unknown field(s) passed to 'static_fields': {', '.join(map(repr, extra))}") - if (extra := set(drop_fields).difference(field_names)): - raise ValueError(f"Unknown field(s) passed to 'drop_fields': {', '.join(map(repr, extra))}") + self.__name__ = module.__name__ + """ + self.__spec__ = module.__spec__ + self.__package__ = module.__package__ + self.__loader__ = module.__loader__ + self.__path__ = module.__path__ + self.__doc__ = module.__doc__ + self.__annotations__ = module.__annotations__ + if hasattr(module, '__file__') and hasattr(module, '__cached__'): + self.__file__ = module.__file__ + self.__cached__ = module.__cached__ + """ - data_fields = tuple(field_names.difference(static_fields).difference(drop_fields)) + self.__setattr__ = lambda name, val: setattr(self._inner, name, val) - def flatten_with_keys(x: t.Any, /) -> tuple[t.Iterable[tuple[str, t.Any]], t.Hashable]: - meta = tuple(getattr(x, name) for name in static_fields) - trees = tuple((name, getattr(x, name)) for name in data_fields) - return trees, meta + def __getattr__(self, name: t.Any) -> t.Any: + fullpath = f"{self.__name__}.{name}" + if (rewrite := self._rewrites.get(fullpath, None)): + if (val := getattr(self._inner, name, None)) is not None: + return functools.update_wrapper(rewrite, val) + return rewrite - def unflatten(meta: t.Hashable, trees: t.Iterable[t.Any], /) -> t.Any: - if not isinstance(meta, tuple): - raise TypeError - static_args = dict(zip(static_fields, meta, strict=True)) - data_args = dict(zip(data_fields, trees, strict=True)) - return cls(**static_args, **data_args) + val = getattr(self._inner, name) - def flatten(x: t.Any, /) -> tuple[t.Iterable[t.Any], t.Hashable]: - hashed = tuple(getattr(x, name) for name in static_fields) - trees = tuple(getattr(x, name) for name in data_fields) - return trees, hashed + if isinstance(val, ModuleType): + return _MockModule(val, self._rewrites, self._wrap) - register_pytree_with_keys(cls, flatten_with_keys, unflatten, flatten) + if hasattr(val, '__call__') and not isinstance(val, type): + def inner(*args, **kwargs): + return self._wrap(val, *args, **kwargs) + return inner + return functools.update_wrapper(inner, val) -def unwrap(val: t.Optional[T]) -> T: - assert val is not None - return val + return val __all__ = [ 'create_rng', 'create_rng_group', 'create_sparse_groupings', 'create_compact_groupings', 'mask_fraction_of_groups', 'FloatKey', - 'jax_dataclass', 'unwrap', + 'unwrap', ] diff --git a/phaser/utils/num.py b/phaser/utils/num.py index db97103..bbddf9a 100644 --- a/phaser/utils/num.py +++ b/phaser/utils/num.py @@ -3,21 +3,25 @@ """ import functools +from itertools import chain import logging import warnings +from types import ModuleType, EllipsisType import typing as t +import sys import numpy from numpy.typing import ArrayLike, DTypeLike, NDArray from phaser.types import BackendName -from .misc import jax_dataclass +from .tree import tree_dataclass if t.TYPE_CHECKING: - from phaser.utils.image import _BoundaryMode + from phaser.utils.image import _InterpBoundaryMode +Device: t.TypeAlias = t.Any Float: t.TypeAlias = t.Union[float, numpy.floating] NumT = t.TypeVar('NumT', bound=numpy.number) FloatT = t.TypeVar('FloatT', bound=numpy.floating) @@ -27,98 +31,220 @@ P = t.ParamSpec('P') IndexLike: t.TypeAlias = t.Union[ - int, + int, slice, EllipsisType, NDArray[numpy.integer[t.Any]], NDArray[numpy.bool_], - t.Tuple[t.Union[int, NDArray[numpy.integer[t.Any]], NDArray[numpy.bool_]], ...], + t.Tuple[t.Union[int, slice, EllipsisType, NDArray[numpy.integer[t.Any]], NDArray[numpy.bool_]], ...], ] - logger = logging.getLogger(__name__) -try: + +def _load_cupy() -> ModuleType: + from ._cuda_kernels import mock_cupy + + with warnings.catch_warnings(): + # https://github.com/cupy/cupy/issues/8718 + warnings.filterwarnings(action='ignore', message=r"cupyx\.jit\.rawkernel is experimental", category=FutureWarning) + import cupyx.scipy.signal # pyright: ignore[reportMissingImports,reportUnusedImport] + import cupyx.scipy.ndimage # pyright: ignore[reportMissingImports,reportUnusedImport] # noqa: F401 + + return t.cast(ModuleType, mock_cupy) + +def _load_jax() -> ModuleType: import jax jax.config.update('jax_enable_x64', jax.default_backend() != 'METAL') - #jax.config.update('jax_log_compiles', True) - #jax.config.update('jax_debug_nans', True) -except ImportError: - pass + import jax.scipy + return jax.numpy -def get_backend_module(backend: t.Optional[BackendName] = None): - """Get the module `xp` associated with a compute backend""" - if backend is None: - return get_default_backend_module() +def _load_torch() -> ModuleType: + from ._torch_kernels import mock_torch + return t.cast(ModuleType, mock_torch) - backend = t.cast(BackendName, backend.lower()) - if backend not in ('cuda', 'cupy', 'jax', 'cpu', 'numpy'): - raise ValueError(f"Unknown backend '{backend}'") - if not t.TYPE_CHECKING: +_NAME_REMAP: t.Dict[BackendName, BackendName] = {} + +_LOAD_FNS: t.Dict[BackendName, t.Callable[[], ModuleType]] = { + 'cupy': _load_cupy, + 'jax': _load_jax, + 'torch': _load_torch, +} + + +class _BackendLoader: + def __init__(self): + self.inner: t.Dict[BackendName, t.Optional[ModuleType]] = {} + + def _normalize(self, backend: BackendName) -> BackendName: + name = t.cast(BackendName, backend.lower()) + name = _NAME_REMAP.get(name, name) + + if name not in ('cupy', 'jax', 'numpy', 'torch'): + raise ValueError(f"Unknown backend '{backend}'") + return name + + def _load(self, name: BackendName): try: - if backend == 'jax': - import jax.numpy - return jax.numpy - if backend in ('cupy', 'cuda'): - import cupy - return cupy + self.inner[name] = _LOAD_FNS[name]() except ImportError: - raise ValueError(f"Backend '{backend}' is not available") + self.inner[name] = None - return numpy + def get(self, name: BackendName): + name = self._normalize(name) + if name == 'numpy': + return numpy + if name not in self.inner: + self._load(name) -def detect_supported_backends() -> t.Dict[BackendName, t.Tuple[str, ...]]: - backends: t.Dict[BackendName, t.Tuple[str, ...]] = {'numpy': ('cpu',)} + return None if t.TYPE_CHECKING else self.inner[name] - try: - import jax.numpy # type: ignore - devices = jax.devices() - backends['jax'] = tuple(f"{device.platform}:{device.id}" for device in devices) - except ImportError: - pass + def __getitem__(self, name: BackendName): + if (backend := self.get(name)) is not None: + return backend - try: - import cupy # type: ignore - n_devices = cupy.cuda.runtime.getDeviceCount() - backends['cupy'] = tuple(f'cuda:{i}' for i in range(n_devices)) - except ImportError: - pass + raise ValueError(f"Backend '{name}' is not available") + +_BACKEND_LOADER = _BackendLoader() + + +def get_backend_module(backend: t.Optional[BackendName] = None): + """Get the module `xp` associated with a compute backend""" + if backend is None: + backend = get_default_backend() + + return _BACKEND_LOADER[backend] - return backends +def get_backend_scipy(backend: BackendName): + """Get the scipy module associated with a compute backend""" + + name = _BACKEND_LOADER._normalize(backend) + # ensure backend is loadable + _BACKEND_LOADER[backend] -def get_default_backend_module(): if not t.TYPE_CHECKING: + if name == 'torch': + raise ValueError("`get_backend_scipy` is not supported for the PyTorch backend") + if name == 'jax': + return sys.modules['jax.scipy'] + if name == 'cupy': + return sys.modules['cupyx.scipy'] + + import scipy + return scipy + + +def get_default_backend() -> BackendName: + # check for jax or torch GPUs first + if _BACKEND_LOADER.get('jax') is not None: + import jax try: - import jax.numpy - return jax.numpy - except ImportError: + if len(jax.devices('gpu')): + return 'jax' + except RuntimeError: pass - try: - import cupy - return cupy - except ImportError: + if len(jax.devices('tpu')): + return 'jax' + except RuntimeError: pass - - return numpy + if _BACKEND_LOADER.get('torch') is not None: + import torch + if torch.get_default_device().type != 'cpu': + return 'torch' + + for backend in ('jax', 'torch', 'cupy'): + if _BACKEND_LOADER.get(backend) is not None: + return backend + return 'numpy' + + +def get_devices() -> t.Tuple[t.Tuple[str, Device], ...]: + devices: t.List[t.Tuple[str, Device]] = [] + + if _BACKEND_LOADER.get('jax') is not None: + from ._jax_kernels import get_devices + devices.extend(('jax', device) for device in get_devices()) + if _BACKEND_LOADER.get('torch') is not None: + from ._torch_kernels import get_devices + devices.extend(('torch', device) for device in get_devices()) + if _BACKEND_LOADER.get('cupy') is not None: + from ._cuda_kernels import get_devices + devices.extend(('cupy', device) for device in get_devices()) + devices.append(('numpy', 'cpu')) + + return tuple(devices) + + +def to_device(device: t.Union[str, Device], xp: t.Any) -> Device: + if xp_is_torch(xp): + from ._torch_kernels import to_device + return to_device(device) + if xp_is_cupy(xp): + from ._cuda_kernels import to_device + return to_device(device) + if xp_is_jax(xp): + from ._jax_kernels import to_device + return to_device(device) + if xp is not numpy: + raise TypeError(f"Expected an array backend, got '{xp}'") + if device != 'cpu': + raise ValueError(f"Invalid device '{device}' for backend 'numpy'") + return device + + +def get_backend_devices(xp: t.Any) -> t.Tuple[Device, ...]: + if xp_is_torch(xp): + from ._torch_kernels import get_devices + return get_devices() + if xp_is_cupy(xp): + from ._cuda_kernels import get_devices + return get_devices() + if xp_is_jax(xp): + from ._jax_kernels import get_devices + return get_devices() + if xp is not numpy: + raise TypeError(f"Expected an array backend, got '{xp}'") + + return ('cpu',) + + +def set_default_device(device: Device, xp: t.Any): + if xp_is_torch(xp): + from ._torch_kernels import set_default_device + set_default_device(device) + elif xp_is_cupy(xp): + from ._cuda_kernels import set_default_device + set_default_device(device) + elif xp_is_jax(xp): + from ._jax_kernels import set_default_device + set_default_device(device) + elif xp is not numpy: + raise TypeError(f"Expected an array backend, got '{xp}'") + elif device != 'cpu': + raise ValueError(f"Invalid device '{device}' for backend 'numpy'") def get_array_module(*arrs: t.Optional[ArrayLike]): - try: - import jax - if any(isinstance(arr, jax.Array) for arr in arrs) \ - and not t.TYPE_CHECKING: - return jax.numpy - except ImportError: - pass - try: - from cupy import get_array_module as f # type: ignore - if not t.TYPE_CHECKING: - return f(*arrs) - except ImportError: - pass + if (xp := _BACKEND_LOADER.get('jax')) is not None: + import jax.tree + if any( + isinstance(arr, xp.ndarray) + for arr in chain.from_iterable(map(jax.tree.leaves, arrs)) + ): + return xp + if (xp := _BACKEND_LOADER.get('torch')) is not None: + from torch.utils._pytree import tree_leaves + if any( + isinstance(arr, (xp._MockTensor, xp._C.TensorBase)) # type: ignore + for arr in chain.from_iterable(map(tree_leaves, arrs)) + ): + return xp + if (xp := _BACKEND_LOADER.get('cupy')) is not None: + if any(isinstance(arr, xp.ndarray) for arr in arrs): + return xp return numpy @@ -131,33 +257,22 @@ def cast_array_module(xp: t.Any): def get_scipy_module(*arrs: t.Optional[ArrayLike]): # pyright: ignore[reportMissingImports,reportUnusedImport] - import scipy - - try: - import jax - if any(isinstance(arr, jax.Array) for arr in arrs) \ - and not t.TYPE_CHECKING: - return jax.scipy - except ImportError: - pass - try: - with warnings.catch_warnings(): - # https://github.com/cupy/cupy/issues/8718 - warnings.filterwarnings(action='ignore', message=r"cupyx\.jit\.rawkernel is experimental", category=FutureWarning) - - import cupyx.scipy.signal # pyright: ignore[reportMissingImports] - import cupyx.scipy.ndimage # pyright: ignore[reportMissingImports] # noqa: F401 - from cupyx.scipy import get_array_module as f # pyright: ignore[reportMissingImports] - - if not t.TYPE_CHECKING: + if not t.TYPE_CHECKING: + if (xp := _BACKEND_LOADER.get('jax')) is not None: + if any(isinstance(arr, xp.ndarray) for arr in arrs): + return sys.modules['jax.scipy'] + if (xp := _BACKEND_LOADER.get('torch')) is not None: + if any(isinstance(arr, (xp._MockTensor, xp._C.TensorBase)) for arr in arrs): # type: ignore + raise ValueError("`get_scipy_module` is not supported for the PyTorch backend") + if (xp := _BACKEND_LOADER.get('cupy')) is not None: + f = sys.modules['cupyx.scipy'].get_array_module return f(*arrs) - except ImportError: - pass + import scipy return scipy -def to_numpy(arr: t.Union[DTypeT, NDArray[DTypeT]], stream=None) -> NDArray[DTypeT]: +def to_numpy(arr: t.Union[DTypeT, NDArray[DTypeT], float, DTypeT], stream=None) -> NDArray[DTypeT]: """ Convert an array to numpy. For cupy backend, this is equivalent to `cupy.asnumpy`. @@ -165,7 +280,8 @@ def to_numpy(arr: t.Union[DTypeT, NDArray[DTypeT]], stream=None) -> NDArray[DTyp if not t.TYPE_CHECKING: if is_jax(arr): return numpy.array(arr) - + if is_torch(arr): + return arr.numpy(force=True) if is_cupy(arr): return arr.get(stream) @@ -180,7 +296,8 @@ def as_numpy(arr: ArrayLike, stream=None) -> NDArray: if not t.TYPE_CHECKING: if is_jax(arr): return numpy.array(arr) - + if is_torch(arr): + return arr.numpy(force=True) if is_cupy(arr): return arr.get(stream) @@ -201,46 +318,57 @@ def as_array(arr: ArrayLike, xp: t.Any = None) -> numpy.ndarray: return numpy.asarray(arr) -def is_cupy(arr: NDArray[DTypeT]) -> bool: - try: - import cupy # pyright: ignore[reportMissingImports] - except ImportError: +def is_cupy(arr: NDArray[numpy.generic]) -> bool: + if (cupy := _BACKEND_LOADER.get('cupy')) is None: return False return isinstance(arr, cupy.ndarray) def is_jax(arr: t.Any) -> bool: - try: - import jax # pyright: ignore[reportMissingImports] - except ImportError: + if (jnp := _BACKEND_LOADER.get('jax')) is None: return False + import jax # pyright[ignoreMissingImports] + return any( - isinstance(arr, jax.Array) for arr in jax.tree_util.tree_leaves(arr) + isinstance(arr, jnp.ndarray) + for arr in jax.tree_util.tree_leaves(arr) ) -def xp_is_cupy(xp: t.Any) -> bool: - try: - import cupy # pyright: ignore[reportMissingImports] - return xp is cupy - except ImportError: +def is_torch(arr: t.Any) -> bool: + if (torch := t.cast(ModuleType, _BACKEND_LOADER.get('torch'))) is None: return False + return any( + isinstance(arr, (torch._MockTensor, torch._C.TensorBase)) + for arr in torch.utils._pytree.tree_leaves(arr) + ) + + +def xp_is_cupy(xp: t.Any) -> bool: + return xp is sys.modules.get('cupy') def xp_is_jax(xp: t.Any) -> bool: - try: - import jax.numpy # pyright: ignore[reportMissingImports] - return xp is jax.numpy - except ImportError: + return xp is sys.modules.get('jax.numpy') + +def xp_is_torch(xp: t.Any) -> bool: + if (torch := _BACKEND_LOADER.get('torch')) is None: return False + return xp is torch def block_until_ready(arr: NDArray[DTypeT]) -> NDArray[DTypeT]: if hasattr(arr, 'block_until_ready'): # jax return arr.block_until_ready() # type: ignore + if is_torch(arr): + import torch + device = torch.get_default_device() + if device.type == 'cuda': + torch.cuda.synchronize(device) + if is_cupy(arr): - import cupy # pyright: ignore[reportMissingImports] + cupy = sys.modules['cupy'] stream = cupy.cuda.get_current_stream() stream.synchronize() @@ -261,20 +389,17 @@ def __init__( self.inner = f functools.update_wrapper(self, f) - if cupy_fuse: - try: - import cupy # pyright: ignore[reportMissingImports] - self.inner = cupy.fuse()(self.inner) - except ImportError: - pass + if cupy_fuse and (cupy := _BACKEND_LOADER.get('cupy')): + self.inner = cupy.fuse()(self.inner) # type: ignore # in jax: self.__call__ -> jax.jit -> jax_f -> f # otherwise: self.__call__ -> f - try: - import jax - except ImportError: - self.jax_jit = None - else: + if _BACKEND_LOADER.get('jax') is not None: + if t.TYPE_CHECKING: + import jax + else: + jax = sys.modules['jax'] + @functools.wraps(f) def jax_f(*args: P.args, **kwargs: P.kwargs) -> T: logger.info(f"JIT-compiling kernel '{self.__qualname__}'...") @@ -285,6 +410,8 @@ def jax_f(*args: P.args, **kwargs: P.kwargs) -> T: donate_argnums=donate_argnums, donate_argnames=donate_argnames, inline=inline, #compiler_options=compiler_options ) + else: + self.jax_jit = None def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: @@ -316,12 +443,8 @@ def fuse(*args, **kwargs) -> t.Callable[[T], T]: """ Equivalent to `cupy.fuse`, if supported. """ - try: - import cupy # pyright: ignore[reportMissingImports] - if not t.TYPE_CHECKING: - return cupy.fuse(*args, **kwargs) - except ImportError: - pass + if (xp := _BACKEND_LOADER.get('cupy')): + return xp.fuse(*args, **kwargs) # type: ignore return lambda x: x @@ -333,6 +456,17 @@ def debug_callback(callback: t.Callable[P, None], *args: P.args, **kwargs: P.kwa callback(*args, **kwargs) +def assert_dtype(arr: numpy.ndarray, dtype: t.Type[numpy.generic]): + if is_torch(arr): + from ._torch_kernels import to_torch_dtype, to_numpy_dtype + + if arr.dtype != to_torch_dtype(dtype): + raise TypeError(f"Expected array to be dtype {dtype}, got dtype {to_numpy_dtype(arr.dtype)} instead") + else: + if arr.dtype != dtype: + raise TypeError(f"Expected array to be dtype {dtype}, got dtype {arr.dtype} instead") + + _COMPLEX_MAP: t.Dict[t.Type[numpy.floating], t.Type[numpy.complexfloating]] = { numpy.floating: numpy.complexfloating, numpy.float32: numpy.complex64, @@ -367,6 +501,9 @@ def to_complex_dtype(dtype: DTypeLike) -> t.Type[numpy.complexfloating]: """ Convert a floating point dtype to a complex version. """ + if _BACKEND_LOADER.get('torch') is not None: + from ._torch_kernels import to_numpy_dtype + dtype = to_numpy_dtype(dtype) # type: ignore if not (isinstance(dtype, type) and issubclass(dtype, numpy.generic)): dtype = numpy.dtype(dtype).type @@ -399,6 +536,9 @@ def to_real_dtype(dtype: DTypeLike) -> t.Type[numpy.floating]: """ Convert a complex dtype to a plain float version. """ + if _BACKEND_LOADER.get('torch') is not None: + from ._torch_kernels import to_numpy_dtype + dtype = to_numpy_dtype(dtype) # type: ignore if not (isinstance(dtype, type) and issubclass(dtype, numpy.generic)): dtype = numpy.dtype(dtype).type @@ -439,6 +579,8 @@ def ifft2(a: ArrayLike) -> NDArray[numpy.complexfloating]: """ xp = get_array_module(a) + if xp_is_torch(xp): + return xp.fft.fftshift(xp.fft.ifft2(a, norm='ortho'), dim=(-2, -1)) # type: ignore return xp.fft.fftshift(xp.fft.ifft2(a, norm='ortho'), axes=(-2, -1)) @t.overload @@ -465,6 +607,8 @@ def fft2(a: ArrayLike) -> NDArray[numpy.complexfloating]: """ xp = get_array_module(a) + if xp_is_torch(xp): + return xp.fft.fft2(xp.fft.ifftshift(a, dim=(-2, -1)), norm='ortho') # type: ignore return xp.fft.fft2(xp.fft.ifftshift(a, axes=(-2, -1)), norm='ortho') @@ -506,10 +650,51 @@ def abs2(x: ArrayLike) -> NDArray[numpy.floating]: """ Return the squared amplitude of a complex array. - This is cheaper than `abs(x)**2.` + This is cheaper than `abs(x)**2` """ - x = get_array_module(x).array(x) - return x.real**2. + x.imag**2. # type: ignore + xp = get_array_module(x) + x = xp.asarray(x) + + if xp_is_torch(xp): + if not xp.is_complex(x): # type: ignore + return x**2 # type: ignore + else: + if not xp.iscomplexobj(x): + return x**2 # type: ignore + + return x.real**2 + x.imag**2 # type: ignore + + +_PadMode: t.TypeAlias = t.Literal['constant', 'edge', 'reflect', 'wrap'] + + +@t.overload +def pad( + arr: NDArray[DTypeT], pad_width: t.Union[int, t.Tuple[int, int], t.Sequence[t.Tuple[int, int]]], /, *, + mode: _PadMode = 'constant', cval: float = 0., +) -> NDArray[DTypeT]: + ... + +@t.overload +def pad( + arr: ArrayLike, pad_width: t.Union[int, t.Tuple[int, int], t.Sequence[t.Tuple[int, int]]], /, *, + mode: _PadMode = 'constant', cval: float = 0., +) -> numpy.ndarray: + ... + +def pad( + arr: ArrayLike, pad_width: t.Union[int, t.Tuple[int, int], t.Sequence[t.Tuple[int, int]]], /, *, + mode: _PadMode = 'constant', cval: float = 0., +) -> numpy.ndarray: + xp = get_array_module(arr) + + if xp_is_torch(xp): + pass + #from ._torch_kernels import pad + #return pad(arr, pad_width, mode=mode, cval=cval) # type: ignore + + return xp.pad(arr, pad_width, mode=mode, constant_values=cval) + @t.overload @@ -529,6 +714,9 @@ def ufunc_outer(ufunc: numpy.ufunc, x: ArrayLike, y: ArrayLike) -> numpy.ndarray from ._jax_kernels import outer return outer(ufunc, x, y) + if not t.TYPE_CHECKING and is_torch(x): + return ufunc(x[(..., *((None,) * y.ndim))], y[(*((None,) * x.ndim), ...)]) + return ufunc.outer(x, y) @@ -541,7 +729,7 @@ def check_finite(*arrs: NDArray[numpy.inexact], context: t.Optional[str] = None) raise ValueError("NaN or inf encountered") -@jax_dataclass(frozen=True, init=False, drop_fields=('extent',)) +@tree_dataclass(frozen=True, init=False, drop_fields=('extent',)) class Sampling: shape: NDArray[numpy.int_] """Sampling shape (n_y, n_x)""" @@ -714,7 +902,7 @@ def resample( self, arr: NDArray[NumT], new_samp: 'Sampling', *, rotation: float = 0.0, order: int = 1, - mode: '_BoundaryMode' = 'grid-constant', + mode: '_InterpBoundaryMode' = 'grid-constant', cval: t.Union[NumT, float] = 0.0, ) -> NDArray[NumT]: from .image import affine_transform, rotation_matrix @@ -736,7 +924,7 @@ def resample_recip( self, arr: NDArray[NumT], new_samp: 'Sampling', *, rotation: float = 0.0, order: int = 1, - mode: '_BoundaryMode' = 'grid-constant', + mode: '_InterpBoundaryMode' = 'grid-constant', cval: t.Union[NumT, float] = 0.0, fftshift: bool = True, ) -> NDArray[NumT]: @@ -811,7 +999,7 @@ def at(arr: NDArray[DTypeT], idx: IndexLike) -> _AtImpl[DTypeT]: __all__ = [ - 'get_backend_module', 'get_default_backend_module', + 'get_backend_module', 'get_default_backend', 'get_array_module', 'cast_array_module', 'get_scipy_module', 'to_numpy', 'as_numpy', 'as_array', 'is_cupy', 'is_jax', 'xp_is_cupy', 'xp_is_jax', diff --git a/phaser/utils/object.py b/phaser/utils/object.py index cadb136..c211523 100644 --- a/phaser/utils/object.py +++ b/phaser/utils/object.py @@ -12,13 +12,14 @@ from numpy.typing import ArrayLike, DTypeLike, NDArray from typing_extensions import Self -from .num import get_array_module, cast_array_module, to_real_dtype, as_numpy, at +from .num import get_array_module, cast_array_module, is_torch, to_real_dtype, as_numpy, at from .num import as_array, is_cupy, is_jax, NumT, ComplexT, DTypeT -from .misc import create_rng, jax_dataclass +from .tree import tree_dataclass +from .misc import create_rng if t.TYPE_CHECKING: - from phaser.utils.image import _BoundaryMode + from phaser.utils.image import _InterpBoundaryMode @t.overload @@ -49,7 +50,7 @@ def random_phase_object(shape: t.Iterable[int], sigma: float = 1e-6, *, seed: t. rng = create_rng(seed, 'random_phase_object') real_dtype = to_real_dtype(dtype) if dtype is not None else numpy.float64 - obj_angle = xp2.array(rng.normal(0., sigma, tuple(shape)), dtype=real_dtype) + obj_angle = xp2.asarray(rng.normal(0., sigma, tuple(shape)), dtype=real_dtype) return xp2.cos(obj_angle) + xp2.sin(obj_angle) * 1.j @@ -98,7 +99,7 @@ def resample_slices( # TODO more options in this case? new_total_thick = numpy.sum(new_thicknesses) - slice_frac = xp.array((new_thicknesses / new_total_thick)[(slice(None), *repeat(None, obj.ndim - 1))]) + slice_frac = xp.asarray((new_thicknesses / new_total_thick)[(slice(None), *repeat(None, obj.ndim - 1))]) return xp.exp((xp.log(obj) * slice_frac).astype(obj.dtype)) if obj.shape[0] != len(old_thicknesses): @@ -178,7 +179,7 @@ def _interp1d(arr: NDArray[NumT], old_zs: NDArray[numpy.floating], new_zs: NDArr else: slice_i = slice_is[i] # linearly interpolate - t = xp.array(float((new_z - old_zs[slice_i]) / delta_zs[slice_i]), dtype=real_dtype) + t = xp.asarray(float((new_z - old_zs[slice_i]) / delta_zs[slice_i]), dtype=real_dtype) slice = ((1-t)*arr[slice_i] + t*arr[slice_i + 1]).astype(arr.dtype) new_arr = at(new_arr, i).set(slice) @@ -186,7 +187,7 @@ def _interp1d(arr: NDArray[NumT], old_zs: NDArray[numpy.floating], new_zs: NDArr return new_arr -@jax_dataclass(frozen=True, init=False) +@tree_dataclass(frozen=True, init=False) class ObjectSampling: shape: NDArray[numpy.int_] """Sampling shape `(n_y, n_x)`""" @@ -286,19 +287,17 @@ def check_scan(self, scan_positions: NDArray[numpy.floating], pad: ArrayLike = 0 (scan_positions[..., 1] < obj_min[1]) | (scan_positions[..., 1] > obj_max[1]) ) if (n_outside := int(xp.sum(outside))): - raise ValueError(f"{n_outside}/{outside.size} probe positions completely outside object") + raise ValueError(f"{n_outside}/{xp.size(outside)} probe positions completely outside object") def _pos_to_object_idx(self, pos: ArrayLike, cutout_shape: t.Tuple[int, ...]) -> NDArray[numpy.float64]: """Return starting index for the cutout closest to centered around `pos` (`(y, x)`)""" - - if not is_jax(pos): # allow jax tracers to work right - pos = as_numpy(pos) + xp = get_array_module(pos) # for a given cutout, shift to the top left pixel of that cutout # e.g. a 2x2 cutout needs shifted by s/2 - shift = -numpy.maximum(0., (numpy.array(cutout_shape[-2:]) - 1.)) / 2. + shift = -xp.maximum(0., (xp.array(cutout_shape[-2:]) - 1.)) / 2. - return ((pos - self.corner) / self.sampling + shift).astype(numpy.float64) # type: ignore + return ((pos - xp.array(self.corner.copy())) / xp.array(self.sampling.copy()) + shift).astype(numpy.float64) # type: ignore def slice_at_pos(self, pos: ArrayLike, cutout_shape: t.Tuple[int, ...]) -> t.Tuple[slice, slice]: """ @@ -312,9 +311,10 @@ def slice_at_pos(self, pos: ArrayLike, cutout_shape: t.Tuple[int, ...]) -> t.Tup Returns slices which can be used to index into an object. E.g. `obj[slice_at_pos(pos, (32, 32))]` will return an array of shape `(32, 32)`. """ + xp = get_array_module(pos) idxs = self._pos_to_object_idx(pos, cutout_shape) - (start_i, start_j) = map(int, numpy.round(idxs).astype(numpy.int64)) + (start_i, start_j) = map(int, xp.round(idxs).astype(numpy.int64)) assert start_i >= 0 and start_j >= 0 return ( slice(start_i, start_i + cutout_shape[-2]), @@ -327,8 +327,10 @@ def get_subpx_shifts(self, pos: ArrayLike, cutout_shape: t.Tuple[int, ...]) -> N Returns the shift from the rounded position towards the actual position, in length units. """ + xp = get_array_module(pos) + pos = self._pos_to_object_idx(as_array(pos), cutout_shape) - return (pos - get_array_module(pos).round(pos)).astype(numpy.float64) * self.sampling + return (pos - xp.round(pos)).astype(numpy.float64) * xp.asarray(self.sampling, copy=True) @t.overload def cutout( # pyright: ignore[reportOverlappingOverload] @@ -342,7 +344,7 @@ def cutout(self, arr: numpy.ndarray, pos: ArrayLike, shape: t.Tuple[int, ...]) - def cutout(self, arr: numpy.ndarray, pos: ArrayLike, shape: t.Tuple[int, ...]) -> ObjectCutout[t.Any]: xp = get_array_module(arr, pos) - return ObjectCutout(self, xp.array(arr), xp.array(pos), shape) + return ObjectCutout(self, xp.asarray(arr), xp.asarray(pos), shape) def get_view_at_pos(self, arr: NDArray[NumT], pos: ArrayLike, shape: t.Tuple[int, ...]) -> NDArray[NumT]: """ @@ -379,8 +381,8 @@ def get_region_crop(self, pad: ArrayLike = 0.) -> t.Tuple[slice, slice]: def get_region_mask(self, pad: ArrayLike = 0., *, xp: t.Any = None) -> NDArray[numpy.bool_]: xp2 = numpy if xp is None else cast_array_module(xp) - mask = xp2.zeros(self.shape, dtype=numpy.bool_) - mask = at(mask, self.get_region_crop(pad=pad)).set(numpy.bool_(1)) # type: ignore + mask = xp2.zeros(tuple(self.shape), dtype=numpy.bool_) + mask = at(mask, self.get_region_crop(pad=pad)).set(t.cast(numpy.bool_, 1)) return mask def get_region_center(self) -> NDArray[numpy.floating]: @@ -429,7 +431,7 @@ def mpl_extent(self, center: bool = True) -> t.Tuple[float, float, float, float] def resample( self, arr: NDArray[NumT], new_samp: 'ObjectSampling', *, - order: int = 1, mode: '_BoundaryMode' = 'grid-constant', + order: int = 1, mode: '_InterpBoundaryMode' = 'grid-constant', cval: t.Union[NumT, float] = 1.0, rotation: t.Optional[float] = None, affine: t.Optional[ArrayLike] = None, @@ -491,8 +493,8 @@ class ObjectCutout(t.Generic[DTypeT]): _start_idxs: NDArray[numpy.int_] = field(init=False) def __post_init__(self): - self._start_idxs = numpy.round(self.sampling._pos_to_object_idx(self.pos, self.cutout_shape)).astype(numpy.int_) # type: ignore - self._start_idxs = get_array_module(self.obj).array(self._start_idxs) + xp = get_array_module(self.pos) + self._start_idxs = xp.round(self.sampling._pos_to_object_idx(self.pos, self.cutout_shape)).astype(numpy.int_) # type: ignore @property def shape(self) -> t.Tuple[int, ...]: @@ -503,6 +505,10 @@ def get(self) -> NDArray[DTypeT]: from ._jax_kernels import get_cutouts return t.cast(NDArray[DTypeT], get_cutouts(self.obj, self._start_idxs, tuple(self.cutout_shape))) + if is_torch(self.obj): + from ._torch_kernels import get_cutouts + return get_cutouts(self.obj, self._start_idxs, tuple(self.cutout_shape)) # type: ignore + if is_cupy(self.obj): try: from ._cuda_kernels import get_cutouts diff --git a/phaser/utils/optics.py b/phaser/utils/optics.py index e027b29..cdb406c 100644 --- a/phaser/utils/optics.py +++ b/phaser/utils/optics.py @@ -165,7 +165,7 @@ def fourier_shift_filter(ky: NDArray[numpy.floating], kx: NDArray[numpy.floating xp = get_array_module(ky, kx) dtype = to_complex_dtype(ky.dtype) - (y, x) = split_array(xp.array(shifts, dtype=ky.dtype), axis=-1) + (y, x) = split_array(xp.asarray(shifts, dtype=ky.dtype), axis=-1) return xp.exp(xp.array(-2.j*numpy.pi, dtype=dtype) * (ufunc_outer(xp.multiply, x, kx) + ufunc_outer(xp.multiply, y, ky))) diff --git a/phaser/utils/scan.py b/phaser/utils/scan.py index 26d26e8..5350023 100644 --- a/phaser/utils/scan.py +++ b/phaser/utils/scan.py @@ -42,17 +42,17 @@ def make_raster_scan(shape: t.Tuple[int, int], scan_step: ArrayLike, dtype = numpy.float64 # TODO actually center this around (0, 0) - yy = xp2.arange(shape[0], dtype=dtype) - xp2.array(shape[0] / 2., dtype=dtype) - xx = xp2.arange(shape[1], dtype=dtype) - xp2.array(shape[1] / 2., dtype=dtype) + yy = xp2.arange(shape[0], dtype=dtype) - xp2.asarray(shape[0] / 2., dtype=dtype) + xx = xp2.arange(shape[1], dtype=dtype) - xp2.asarray(shape[1] / 2., dtype=dtype) pts = xp2.stack(xp2.meshgrid(yy, xx, indexing='ij'), axis=-1) if rotation != 0.: theta = rotation * numpy.pi/180. - mat = xp2.array([[numpy.cos(theta), -numpy.sin(theta)], [numpy.sin(theta), numpy.cos(theta)]], dtype=dtype) + mat = xp2.asarray([[numpy.cos(theta), -numpy.sin(theta)], [numpy.sin(theta), numpy.cos(theta)]], dtype=dtype) pts = (pts @ mat.T) - return pts * xp2.broadcast_to(xp2.array(scan_step), (2,)).astype(dtype) # type: ignore + return pts * xp2.broadcast_to(xp2.asarray(scan_step), (2,)).astype(dtype) # type: ignore __all__ = [ diff --git a/phaser/utils/tree.py b/phaser/utils/tree.py new file mode 100644 index 0000000..6198c36 --- /dev/null +++ b/phaser/utils/tree.py @@ -0,0 +1,416 @@ +import dataclasses +import functools +import typing as t + +import numpy +from numpy.typing import ArrayLike, DTypeLike, NDArray +from typing_extensions import Self, dataclass_transform + +T = t.TypeVar('T') +Leaf: t.TypeAlias = t.Any +Tree: t.TypeAlias = t.Any + +class TreeSpec(t.Protocol): + @property + def num_leaves(self) -> int: + ... + + @property + def num_nodes(self) -> int: + ... + + def unflatten(self, leaves: t.Iterable[Leaf], /) -> Tree: + ... + + def flatten_up_to(self, xs: Tree, /) -> t.List[Tree]: + ... + + def __eq__(self, other: Self, /) -> bool: # pyright: ignore[reportIncompatibleMethodOverride] + ... + + def __ne__(self, other: Self, /) -> bool: # pyright: ignore[reportIncompatibleMethodOverride] + ... + +class Key(t.Protocol): + def __hash__(self) -> int: + ... + + def __eq__(self, other: object) -> bool: + ... + + def __str__(self) -> str: + ... + +class GetAttrKey(Key, t.Protocol): + @property + def name(self) -> str: + ... + + +KeyPath: t.TypeAlias = t.Tuple[Key, ...] + + +def flatten( + tree: Tree, + is_leaf: t.Optional[t.Callable[..., t.Any]] = None, +) -> t.Tuple[t.List[Leaf], TreeSpec]: + from phaser.utils.num import is_torch + + if is_torch(tree): + from torch.utils._pytree import tree_flatten # type: ignore + return tree_flatten(tree, is_leaf) + + import jax.tree # type: ignore + return jax.tree.flatten(tree, is_leaf) + + +def flatten_with_path( + tree: Tree, + is_leaf: t.Optional[t.Callable[..., t.Any]] = None, +) -> t.Tuple[t.List[t.Tuple[KeyPath, Leaf]], TreeSpec]: + from phaser.utils.num import is_torch + + if is_torch(tree): + from torch.utils._pytree import tree_flatten_with_path # type: ignore + return tree_flatten_with_path(tree, is_leaf) # type: ignore + + from jax.tree_util import tree_flatten_with_path + return tree_flatten_with_path(tree, is_leaf) + + +def unflatten( + leaves: t.Iterable[t.Any], + treespec: TreeSpec +) -> Tree: + try: + from torch.utils._pytree import TreeSpec + if isinstance(treespec, TreeSpec): + return treespec.unflatten(leaves) + except ImportError: + pass + try: + from jax.tree_util import PyTreeDef + if isinstance(treespec, PyTreeDef): + return treespec.unflatten(leaves) + except ImportError: + pass + + raise TypeError( + f"tree_unflatten expected `treespec` to be a TreeSpec, " + f"got item of type {type(treespec)} instead." + ) + + +def map( + f: t.Callable[..., t.Any], + tree: Tree, + *rest: Tree, + is_leaf: t.Optional[t.Callable[..., t.Any]] = None, +) -> t.Any: + from phaser.utils.num import is_torch + + if is_torch(tree): + from torch.utils._pytree import tree_map # type: ignore + return tree_map(f, tree, *rest, is_leaf=is_leaf) + + import jax.tree # type: ignore + return jax.tree.map(f, tree, *rest, is_leaf=is_leaf) + + +def reduce( + f: t.Callable[[T, t.Any], T], tree: Tree, initializer: T, *, + is_leaf: t.Optional[t.Callable[..., t.Any]] = None, +) -> T: + return functools.reduce(f, leaves(tree, is_leaf=is_leaf), initializer) + + +def sum(tree: Tree) -> numpy.ndarray: + from phaser.utils.num import get_array_module + + xp = get_array_module(tree) + sums = map(xp.sum, tree) + return reduce(lambda lhs, rhs: lhs + rhs, sums, initializer=0) + + +def map_with_path( + f: t.Callable[..., t.Any], + tree: Tree, + *rest: Tree, + is_leaf: t.Optional[t.Callable[..., t.Any]] = None, +) -> t.Any: + from phaser.utils.num import is_torch + + if is_torch(tree): + from torch.utils._pytree import tree_map_with_path # type: ignore + + def wrapper(path: KeyPath, *leaves: t.Any): + return f(tuple(path), *leaves) + + return tree_map_with_path(wrapper, tree, *rest, is_leaf=is_leaf) + + from jax.tree_util import tree_map_with_path # type: ignore + return tree_map_with_path(f, tree, *rest, is_leaf=is_leaf) + + +def grad( + f: t.Callable, + argnums: t.Union[int, t.Tuple[int, ...]] = 0, + has_aux: bool = False, *, xp: t.Optional[t.Any] = None, +) -> t.Callable[..., Tree]: + from phaser.utils.num import xp_is_torch, xp_is_jax + + if xp is None or xp_is_jax(xp): + import jax # type: ignore + return jax.grad(f, argnums, has_aux=has_aux) + if xp_is_torch(xp): + import torch.func # type: ignore + return torch.func.grad(f, argnums, has_aux=has_aux) + raise ValueError("`grad` is only supported for backends 'jax' and 'torch'") + + +def value_and_grad( + f: t.Callable, + argnums: t.Union[int, t.Tuple[int, ...]] = 0, + has_aux: bool = False, *, xp: t.Optional[t.Any] = None, + sign: float = 1.0, +) -> t.Callable[..., t.Tuple[Tree, Tree]]: + from phaser.utils.num import xp_is_torch, xp_is_jax + + if xp is None or xp_is_jax(xp): + import jax # type: ignore + f = jax.value_and_grad(f, argnums, has_aux=has_aux) + + @functools.wraps(f) + def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Tuple[Tree, Tree]: + (value, grad) = f(*args, **kwargs) + # conjugate to get Wirtinger derivative, multiply by sign + grad = map(lambda arr: arr.conj() * sign, grad, is_leaf=lambda x: x is None) + return (value, grad) + + return wrapper + + if not xp_is_torch(xp): + raise ValueError("`grad` is only supported for backends 'jax' and 'torch'") + + import torch.func # type: ignore + f = torch.func.grad_and_value(f, argnums, has_aux=has_aux) + + @functools.wraps(f) + def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Tuple[Tree, Tree]: + # flip order of return values + (grad, value) = f(*args, **kwargs) + # multiply by sign + grad = map(lambda arr: arr * sign, grad, is_leaf=lambda x: x is None) + return (value, grad) + + return wrapper + + +def leaves( + tree: Tree, + is_leaf: t.Optional[t.Callable[..., t.Any]] = None, +) -> t.List[Leaf]: + return flatten(tree, is_leaf)[0] + + +def structure( + tree: Tree, + is_leaf: t.Optional[t.Callable[..., t.Any]] = None, +) -> TreeSpec: + return flatten(tree, is_leaf)[1] + + +def leaves_with_path( + tree: Tree, + is_leaf: t.Optional[t.Callable[..., t.Any]] = None, +) -> t.List[t.Tuple[KeyPath, Leaf]]: + return flatten_with_path(tree, is_leaf)[0] + + +def zeros_like( + tree: Tree, dtype: DTypeLike = None, +) -> Tree: + from phaser.utils.num import get_array_module + xp = get_array_module(tree) + kwargs: t.Dict[str, t.Any] = {'dtype': dtype} if dtype is not None else {} + return map(lambda x: xp.zeros_like(x, **kwargs), tree) + + +def ones_like( + tree: Tree, dtype: DTypeLike = None, +) -> Tree: + from phaser.utils.num import get_array_module + xp = get_array_module(tree) + kwargs: t.Dict[str, t.Any] = {'dtype': dtype} if dtype is not None else {} + return map(lambda x: xp.ones_like(x, **kwargs), tree) + + +def full_like( + tree: Tree, fill_value: ArrayLike, + dtype: DTypeLike = None, +) -> Tree: + from phaser.utils.num import get_array_module + xp = get_array_module(tree) + kwargs: t.Dict[str, t.Any] = {'dtype': dtype} if dtype is not None else {} + return map(lambda x: xp.full_like(x, fill_value, **kwargs), tree) + + +def cast( + tree: Tree, dtype: t.Optional[DTypeLike], +) -> Tree: + if dtype is None: + return tree + return map(lambda x: x.astype(dtype), tree) + + +def clip( + tree: Tree, + min_value: t.Optional[ArrayLike] = None, + max_value: t.Optional[ArrayLike] = None, +) -> Tree: + from phaser.utils.num import get_array_module + xp = get_array_module(tree) + return map(lambda x: xp.clip(x, min_value, max_value), tree) + + +def conj( + tree: Tree +) -> Tree: + from phaser.utils.num import get_array_module + xp = get_array_module(tree) + return map(xp.conj, tree) + + +def update_moment(updates: Tree, moments: Tree, decay: float, order: int) -> Tree: + return map( + lambda g, t: ( + (1 - decay) * (g**order) + decay * t if g is not None else None + ), + updates, + moments, + is_leaf=lambda x: x is None, + ) + + +def update_moment_per_elem_norm(updates: Tree, moments: Tree, decay: float, order: int) -> Tree: + from phaser.utils.num import get_array_module, abs2 + xp = get_array_module(updates, moments) + + def orderth_norm(g): + if xp.isrealobj(g): + return g ** order + + half_order = order / 2 + # JAX generates different HLO for int and float `order` + if half_order.is_integer(): + half_order = int(half_order) + return abs2(g) ** half_order + + return map( + lambda g, t: ( + (1 - decay) * orderth_norm(g) + decay * t if g is not None else None + ), + updates, + moments, + is_leaf=lambda x: x is None, + ) + + +def bias_correction(moment: Tree, decay: float, count: t.Union[int, NDArray[numpy.integer]]) -> Tree: + bias_correction = t.cast(NDArray[numpy.floating], 1 - decay**count) + return map(lambda t: t / bias_correction.astype(t.dtype), moment) + + +def scale( + scalar: t.Union[float, numpy.floating, NDArray[numpy.floating]], + tree: Tree +) -> Tree: + return map(lambda x: scalar * x, tree) + + +def squared_norm( + tree: Tree +) -> NDArray[numpy.floating]: + return sum(map(lambda x: x**2, tree)) + + +@t.overload +@dataclass_transform(kw_only_default=False, frozen_default=False) +def tree_dataclass(cls: t.Type[T], /, *, + init: bool = True, kw_only: bool = False, frozen: bool = False, + static_fields: t.Sequence[str] = (), drop_fields: t.Sequence[str] = (), +) -> t.Type[T]: + ... + +@t.overload +@dataclass_transform(kw_only_default=False, frozen_default=False) +def tree_dataclass(*, + init: bool = True, kw_only: bool = False, frozen: bool = False, + static_fields: t.Sequence[str] = (), drop_fields: t.Sequence[str] = (), +) -> t.Callable[[t.Type[T]], t.Type[T]]: + ... + +def tree_dataclass(cls: t.Optional[t.Type[T]] = None, /, *, + init: bool = True, kw_only: bool = False, frozen: bool = False, + static_fields: t.Sequence[str] = (), drop_fields: t.Sequence[str] = (), +) -> t.Union[t.Type[T], t.Callable[[t.Type[T]], t.Type[T]]]: + if cls is None: + return lambda cls: tree_dataclass(cls, init=init, kw_only=kw_only, frozen=frozen, + static_fields=static_fields, drop_fields=drop_fields) + + cls = dataclasses.dataclass(init=init, kw_only=kw_only, frozen=frozen)(cls) + _register_dataclass(cls, static_fields=static_fields, drop_fields=drop_fields) + return cls + + +def _register_dataclass(cls: type, static_fields: t.Sequence[str], drop_fields: t.Sequence[str]): + fields = dataclasses.fields(cls) + field_names = {field.name for field in fields} + + if (extra := set(static_fields).difference(field_names)): + raise ValueError(f"Unknown field(s) passed to 'static_fields': {', '.join(map(repr, extra))}") + if (extra := set(drop_fields).difference(field_names)): + raise ValueError(f"Unknown field(s) passed to 'drop_fields': {', '.join(map(repr, extra))}") + + data_fields = tuple(field_names.difference(static_fields).difference(drop_fields)) + + def make_flatten_with_keys( + key_type: t.Callable[[str], Key] + ) -> t.Callable[[t.Any], t.Tuple[t.List[t.Tuple[Key, t.Any]], t.Hashable]]: + def flatten_with_keys(x: t.Any, /) -> tuple[list[tuple[Key, t.Any]], t.Hashable]: + meta = tuple(getattr(x, name) for name in static_fields) + trees = list((key_type(name), getattr(x, name)) for name in data_fields) + return trees, meta + + return flatten_with_keys + + def unflatten(meta: t.Hashable, trees: t.Iterable[t.Any], /) -> t.Any: + if not isinstance(meta, tuple): + raise TypeError + static_args = dict(zip(static_fields, meta, strict=True)) + data_args = dict(zip(data_fields, trees, strict=True)) + return cls(**static_args, **data_args) + + def flatten(x: t.Any, /) -> tuple[list[t.Any], t.Hashable]: + hashed = tuple(getattr(x, name) for name in static_fields) + trees = list(getattr(x, name) for name in data_fields) + return trees, hashed + + try: + from jax.tree_util import register_pytree_with_keys, GetAttrKey + except ImportError: + pass + else: + flatten_with_keys = make_flatten_with_keys(GetAttrKey) + register_pytree_with_keys(cls, flatten_with_keys, unflatten, flatten) + + try: + from torch.utils._pytree import register_pytree_node, GetAttrKey + except ImportError: + pass + else: + flatten_with_keys = make_flatten_with_keys(GetAttrKey) + register_pytree_node( + cls, flatten, lambda trees, meta: unflatten(meta, trees), + flatten_with_keys_fn=flatten_with_keys, # type: ignore + ) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 1835d49..f23e2a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "click~=8.1.0", "rich>=12.0.0,<15", "tifffile>=2023.8.25", + "optree>=0.13.0", "py-pane==0.11.3", "typing_extensions~=4.7", ] @@ -60,6 +61,9 @@ jax = [ "jax>=0.4.25,<0.8", "optax>=0.2.2", ] +torch = [ + "torch>=2.8.0", +] web = [ "Quart>=0.20.0", "backoff==2.2.1", @@ -91,9 +95,10 @@ include = ["phaser*"] [tool.pytest.ini_options] testpaths = ["tests"] markers = [ - "cuda: Run on CUDA backend", + "cupy: Run on Cupy backend", "jax: Run on jax backend", - "cpu: Run on CPU backend", + "torch: Run on PyTorch backend", + "numpy: Run on numpy backend", "slow: mark a test as slow", "expected_filename: Filename to load expected result from", diff --git a/tests/test_image.py b/tests/test_image.py index 1316902..a65b6be 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -5,11 +5,11 @@ from numpy.testing import assert_array_almost_equal import pytest -from .utils import with_backends, get_backend_module, check_array_equals_file +from .utils import with_backends, check_array_equals_file -from phaser.utils.num import to_numpy, Sampling +from phaser.utils.num import get_backend_module, BackendName, to_numpy, Sampling from phaser.utils.image import ( - affine_transform, _BoundaryMode + affine_transform, _InterpBoundaryMode ) @@ -21,7 +21,7 @@ def checkerboard() -> t.Tuple[NDArray[numpy.float32], Sampling]: return (checker, Sampling(checker.shape, sampling=(1.0, 1.0))) -@with_backends('cpu', 'jax', 'cuda') +@with_backends('numpy', 'jax', 'cupy', 'torch') @pytest.mark.parametrize(('mode', 'order', 'expected'), [ ('grid-constant', 0, [ 1.0, 1.0, 1.0, 1.0, -2.0, -2.0, -2.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0]), ('nearest' , 0, [-2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]), @@ -34,19 +34,19 @@ def checkerboard() -> t.Tuple[NDArray[numpy.float32], Sampling]: ('reflect' , 1, [-1.0, -1.4, -1.8, -2.0, -2.0, -2.0, -1.6, -1.2, -0.8, -0.4, -0.0, 0.4, 0.8, 1.2, 1.6, 2.0, 2.0, 2.0, 1.8, 1.4, 1.0]), ('grid-wrap' , 1, [ 1.0, 1.4, 1.8, 1.2, -0.4, -2.0, -1.6, -1.2, -0.8, -0.4, -0.0, 0.4, 0.8, 1.2, 1.6, 2.0, 0.4, -1.2, -1.8, -1.4, -1.0]), ]) -def test_affine_transform_1d(mode: str, order: int, expected: ArrayLike, backend: str): +def test_affine_transform_1d(mode: _InterpBoundaryMode, order: int, expected: ArrayLike, backend: BackendName): xp = get_backend_module(backend) in_ys = numpy.array([-2., -1., 0., 1., 2.]) # interpolates at coords `numpy.linspace(-2., 6., 21, endpoint=True)` assert_array_almost_equal(numpy.array(expected), to_numpy(affine_transform( - xp.array(in_ys), [0.4], -2.0, - mode=t.cast(_BoundaryMode, mode), order=order, cval=1.0, output_shape=(21,) - )), decimal=8) + xp.asarray(in_ys), [0.4], -2.0, + mode=mode, order=order, cval=1.0, output_shape=(21,) + )), decimal=6) -@with_backends('cpu', 'jax', 'cuda') +@with_backends('numpy', 'jax', 'cupy', 'torch') @pytest.mark.parametrize(('name', 'order', 'rotation', 'sampling'), [ ('identity', 1, 0.0, Sampling((16, 16), sampling=(1.0, 1.0))), ('pad', 0, 0.0, Sampling((32, 32), sampling=(1.0, 1.0))), @@ -59,15 +59,16 @@ def test_affine_transform_1d(mode: str, order: int, expected: ArrayLike, backend ]) @check_array_equals_file('resample_{name}_order{order}_rot{rotation:03.1f}.tiff', out_name='resample_{name}_order{order}_rot{rotation:03.1f}_{backend}.tiff') def test_resample( - backend: str, + backend: BackendName, checkerboard: t.Tuple[NDArray[numpy.float32], Sampling], name: str, order: int, rotation: float, sampling: Sampling, ): - if (name, order, rotation, backend) == ('upsample', 0, 0.0, 'jax'): - pytest.xfail("JAX rounding bug?") + if (name, order, rotation) == ('upsample', 0, 0.0) and backend in ('jax', 'torch'): + # TODO: check intermediate dtypes here? + pytest.xfail("Rounding bug?") xp = get_backend_module(backend) diff --git a/tests/test_num.py b/tests/test_num.py index 410ce68..b60756c 100644 --- a/tests/test_num.py +++ b/tests/test_num.py @@ -3,41 +3,50 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal import pytest -from .utils import with_backends, get_backend_module, get_backend_scipy, mock_importerror +from .utils import with_backends # mock_importerror from phaser.utils.num import ( - get_array_module, get_scipy_module, + BackendName, + get_backend_module, to_real_dtype, to_complex_dtype, fft2, ifft2, abs2, to_numpy, as_array, - ufunc_outer + ufunc_outer, ) -@with_backends('cpu', 'jax', 'cuda') -def test_get_array_module(backend: str): +# TODO: this is broken, probably needs to run in a separate process +# the problem is we need to clear the _BackendLoader() cache to get +# proper behavior, but torch can only be imported once +""" +@with_backends('numpy', 'jax', 'cupy', 'torch') +def test_get_array_module(backend: BackendName, monkeypatch: pytest.MonkeyPatch): expected = get_backend_module(backend) mocked_imports = { - # on cpu, pretend cupy and jax don't exist - 'cpu': {'cupy', 'jax'}, + # on numpy, pretend cupy and jax don't exist + 'numpy': {'cupy', 'jax', 'torch'}, 'jax': {}, - 'cuda': {}, + 'cupy': {}, + 'torch': {}, }[backend] - assert get_array_module() is numpy + # re-load backend loader so the effect of mocking takes place + monkeypatch.setattr(phaser.utils.num, '_BACKEND_LOADER', _BackendLoader()) with mock_importerror(mocked_imports): + assert get_array_module() is numpy + assert get_array_module( numpy.array([1., 2., 3.]), - expected.array([1, 2, 3]), + expected.asarray([1, 2, 3]), None, numpy.array([1., 2., 3.]), ) is expected -@with_backends('cpu', 'jax', 'cuda') -def test_get_scipy_module(backend: str): +@with_backends('numpy', 'jax', 'cupy') +def test_get_scipy_module(backend: BackendName, monkeypatch: pytest.MonkeyPatch): import scipy xp = get_backend_module(backend) @@ -45,20 +54,24 @@ def test_get_scipy_module(backend: str): mocked_imports = { # on cpu, pretend cupyx doesn't exist - 'cpu': {'cupyx'}, + 'numpy': {'cupyx'}, 'jax': {}, - 'cuda': {}, + 'cupy': {}, }[backend] - assert get_scipy_module() is scipy + # re-load backend loader so the effect of mocking takes place + # monkeypatch.setattr(phaser.utils.num, '_BACKEND_LOADER', _BackendLoader()) with mock_importerror(mocked_imports): + assert get_scipy_module() is scipy + assert get_scipy_module( numpy.array([1., 2., 3.]), - xp.array([1, 2, 3]), + xp.asarray([1, 2, 3]), None, numpy.array([1., 2., 3.]), ) is expected +""" @pytest.mark.parametrize(('input', 'expected'), [ @@ -99,12 +112,12 @@ def test_to_complex_dtype_invalid(): to_complex_dtype(numpy.int_) -@with_backends('cpu', 'jax', 'cuda') -def test_fft2(backend: str): +@with_backends('numpy', 'jax', 'cupy', 'torch') +def test_fft2(backend: BackendName): xp = get_backend_module(backend) # point input, f = 5 delta(x) delta(y) - a = xp.pad(xp.array([[5.]], dtype=numpy.float32), ((2, 2), (2, 2))) + a = xp.asarray(numpy.pad([[5.]], (2, 2)).astype(numpy.float32)) # even input, so output is real # delta function input, so output is constant @@ -122,16 +135,16 @@ def test_fft2(backend: str): # zero frequency is cornered assert_array_almost_equal( to_numpy(fft2(a)), - numpy.pad([[5.+5.j]], ((0, 4), (0, 4))).astype(numpy.complex64) + numpy.pad([[5.+5.j]], (0, 4)).astype(numpy.complex64) ) -@with_backends('cpu', 'jax', 'cuda') -def test_ifft2(backend: str): +@with_backends('numpy', 'jax', 'cupy', 'torch') +def test_ifft2(backend: BackendName): xp = get_backend_module(backend) # point input, F = delta(k_x) delta(k_y) - a = xp.pad(xp.array([[5.]], dtype=numpy.float32), ((0, 4), (0, 4))) + a = xp.asarray(numpy.pad([[5.]], (0, 4)).astype(numpy.float32)) # even input, so output is real # delta function input, so output is constant @@ -155,30 +168,30 @@ def test_ifft2(backend: str): ) -@with_backends('cpu', 'jax', 'cuda') -def test_abs2(backend: str): +@with_backends('numpy', 'jax', 'cupy', 'torch') +def test_abs2(backend: BackendName): xp = get_backend_module(backend) - if backend == 'cpu': + if backend == 'numpy': assert_array_almost_equal(abs2([1.+1.j, 1.-1.j]), numpy.array([2., 2.])) assert_array_almost_equal( - to_numpy(abs2(xp.array([1.+1.j, 1.-1.j]))), + to_numpy(abs2(xp.asarray([1.+1.j, 1.-1.j]))), numpy.array([2., 2.]), ) assert_array_almost_equal( - to_numpy(abs2(xp.array([1., -2., 5.], dtype=numpy.float32))), + to_numpy(abs2(xp.asarray([1., -2., 5.], dtype=numpy.float32))), numpy.array([1, 4., 25.], dtype=numpy.float32), decimal=5 # this is pretty poor performance ) -@with_backends('cpu', 'jax', 'cuda') -def test_to_numpy(backend: str): +@with_backends('numpy', 'jax', 'cupy', 'torch') +def test_to_numpy(backend: BackendName): xp = get_backend_module(backend) - arr = xp.array([1., 2., 3., 4.]) + arr = xp.asarray([1., 2., 3., 4.]) assert_array_almost_equal( to_numpy(arr), @@ -186,11 +199,11 @@ def test_to_numpy(backend: str): ) -@with_backends('cpu', 'jax', 'cuda') -def test_to_array(backend: str): +@with_backends('numpy', 'jax', 'cupy', 'torch') +def test_to_array(backend: BackendName): xp = get_backend_module(backend) - arr = xp.array([1., 2., 3., 4.]) + arr = xp.asarray([1., 2., 3., 4.]) assert as_array(arr) is arr arr = as_array([1., 2., 3., 4.]) @@ -201,8 +214,8 @@ def test_to_array(backend: str): ) -@with_backends('cpu', 'jax', 'cuda') -def test_ufunc_outer(backend: str): +@with_backends('numpy', 'jax', 'cupy') +def test_ufunc_outer(backend: BackendName): xp = get_backend_module(backend) xs = numpy.arange(12).reshape(4, 3) diff --git a/tests/test_object.py b/tests/test_object.py index 0cca451..f07778c 100644 --- a/tests/test_object.py +++ b/tests/test_object.py @@ -3,19 +3,19 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal import pytest -from .utils import with_backends, get_backend_module, check_array_equals_file +from .utils import with_backends, check_array_equals_file -from phaser.utils.num import to_numpy, abs2 +from phaser.utils.num import get_backend_module, BackendName, to_numpy, abs2 from phaser.utils.object import random_phase_object, ObjectSampling -@with_backends('cpu', 'jax', 'cuda') -def test_random_phase_object(backend: str): +@with_backends('numpy', 'jax', 'cupy', 'torch') +def test_random_phase_object(backend: BackendName): xp = get_backend_module(backend) obj = random_phase_object((8, 8), 1e-4, seed=2620771887, dtype=numpy.complex64, xp=xp) - assert obj.dtype == numpy.complex64 + assert obj.dtype == xp.complex64 assert_array_almost_equal(to_numpy(obj), numpy.array([ [1.-1.5272086e-05j, 1.+1.0225522e-04j, 1.-8.0865902e-05j, 1.-1.7328106e-05j, 1.-1.2898073e-04j, 1.+2.2908196e-05j, 1.+8.1173976e-06j, 1.+2.1377344e-05j], [1.+7.4363430e-05j, 1.-9.1323782e-05j, 1.-2.0272582e-04j, 1.-4.8823396e-05j, 1.+9.3021641e-05j, 1.+1.0718761e-04j, 1.+5.0221975e-06j, 1.-5.5743083e-05j], @@ -144,10 +144,10 @@ def test_object_slicing(): ) -@with_backends('cpu', 'jax', 'cuda') +@with_backends('numpy', 'jax', 'cupy', 'torch') @pytest.mark.parametrize('dtype', ('float', 'complex', 'uint8')) @check_array_equals_file('object_get_views_{dtype}.npy', out_name='object_get_views_{dtype}_{backend}.npy') -def test_get_cutouts(backend: str, dtype: str) -> numpy.ndarray: +def test_get_cutouts(backend: BackendName, dtype: str) -> numpy.ndarray: samp = ObjectSampling((200, 200), (1.0, 1.0)) cutout_shape = (64, 64) @@ -174,25 +174,25 @@ def test_get_cutouts(backend: str, dtype: str) -> numpy.ndarray: return to_numpy(samp.cutout(obj, pos, cutout_shape).get()) -@with_backends('cpu', 'jax', 'cuda') +@with_backends('numpy', 'jax', 'cupy', 'torch') @pytest.mark.parametrize('dtype', ('float', 'complex', 'uint8')) @check_array_equals_file('object_add_views_{dtype}.tiff', out_name='object_add_views_{dtype}_{backend}.tiff', decimal=5) -def test_add_view_at_pos(backend: str, dtype: str) -> numpy.ndarray: +def test_add_view_at_pos(backend: BackendName, dtype: str) -> numpy.ndarray: samp = ObjectSampling((200, 200), (1.0, 1.0)) cutout_shape = (64, 64) xp = get_backend_module(backend) if dtype == 'uint8': - obj = xp.zeros(samp.shape, dtype=numpy.uint8) + obj = xp.zeros(tuple(samp.shape), dtype=numpy.uint8) cutouts = xp.full((30, *cutout_shape), 15, dtype=numpy.uint8) mag = 15 elif dtype == 'float': - obj = xp.zeros(samp.shape, dtype=numpy.float32) + obj = xp.zeros(tuple(samp.shape), dtype=numpy.float32) cutouts = xp.full((30, *cutout_shape), 10., dtype=numpy.float32) mag = 10. elif dtype == 'complex': - obj = xp.zeros(samp.shape, dtype=numpy.complex64) + obj = xp.zeros(tuple(samp.shape), dtype=numpy.complex64) phases = xp.array([ 4.30015617, 5.15367214, 6.13496658, 4.9268498 , 3.60960355, 0.42680191, 5.12820671, 1.3260991 , 2.2065813 , 5.1417133 , @@ -247,13 +247,13 @@ def test_add_view_at_pos(backend: str, dtype: str) -> numpy.ndarray: return to_numpy(obj) -@with_backends('cpu', 'jax', 'cuda') -def test_cutout_2d(backend: str): +@with_backends('numpy', 'jax', 'cupy', 'torch') +def test_cutout_2d(backend: BackendName): samp = ObjectSampling((200, 200), (1.0, 1.0)) cutout_shape = (64, 64) xp = get_backend_module(backend) - obj = xp.zeros(samp.shape, dtype=numpy.float32) + obj = xp.zeros(tuple(samp.shape), dtype=numpy.float32) cutouts = samp.cutout(obj, [[0., 0.], [2., 2.], [4., 4.], [-2., -2.]], cutout_shape) assert cutouts.get().shape == (4, *cutout_shape) @@ -263,8 +263,8 @@ def test_cutout_2d(backend: str): cutouts.set(cutouts.get()) -@with_backends('cpu', 'jax', 'cuda') -def test_cutout_multidim(backend: str): +@with_backends('numpy', 'jax', 'cupy', 'torch') +def test_cutout_multidim(backend: BackendName): samp = ObjectSampling((200, 200), (1.0, 1.0)) cutout_shape = (80, 100) @@ -296,25 +296,25 @@ def test_cutout_multidim(backend: str): cutouts.set(cutouts.get()) -@with_backends('cpu', 'jax', 'cuda') +@with_backends('numpy', 'jax', 'cupy', 'torch') @pytest.mark.parametrize('dtype', ('float', 'complex', 'uint8')) @check_array_equals_file('object_set_views_{dtype}.tiff', out_name='object_set_views_{dtype}_{backend}.tiff') -def test_set_view_at_pos(backend: str, dtype: str) -> numpy.ndarray: +def test_set_view_at_pos(backend: BackendName, dtype: str) -> numpy.ndarray: samp = ObjectSampling((200, 200), (1.0, 1.0)) cutout_shape = (64, 64) xp = get_backend_module(backend) if dtype == 'uint8': - obj = xp.zeros(samp.shape, dtype=numpy.uint8) + obj = xp.zeros(tuple(samp.shape), dtype=numpy.uint8) cutouts = xp.full((30, *cutout_shape), 15, dtype=numpy.uint8) mag = 15 elif dtype == 'float': - obj = xp.zeros(samp.shape, dtype=numpy.float32) + obj = xp.zeros(tuple(samp.shape), dtype=numpy.float32) cutouts = xp.full((30, *cutout_shape), 10., dtype=numpy.float32) mag = 10. elif dtype == 'complex': - obj = xp.zeros(samp.shape, dtype=numpy.complex64) + obj = xp.zeros(tuple(samp.shape), dtype=numpy.complex64) cutouts = xp.full((30, *cutout_shape), 10. + 15.j, dtype=numpy.complex64) mag = abs2(10. + 15.j) else: diff --git a/tests/test_optics.py b/tests/test_optics.py index caeb980..cc85354 100644 --- a/tests/test_optics.py +++ b/tests/test_optics.py @@ -1,26 +1,26 @@ import numpy -from .utils import with_backends, get_backend_module, check_array_equals_file +from .utils import with_backends, check_array_equals_file -from phaser.utils.num import Sampling, to_numpy, fft2, ifft2 +from phaser.utils.num import get_backend_module, BackendName, Sampling, to_numpy, fft2, ifft2 from phaser.utils.optics import make_focused_probe, fresnel_propagator -@with_backends('cpu', 'jax', 'cuda') +@with_backends('numpy', 'jax', 'cupy', 'torch') @check_array_equals_file('probe_10mrad_focused_mag.tiff', decimal=5) -def test_focused_probe(backend: str) -> numpy.ndarray: +def test_focused_probe(backend: BackendName) -> numpy.ndarray: xp = get_backend_module(backend) sampling = Sampling((1024, 1024), extent=(25., 25.)) probe = make_focused_probe(*sampling.recip_grid(dtype=numpy.float32, xp=xp), wavelength=0.0251, aperture=10.) - return to_numpy(numpy.abs(probe)) + return to_numpy(xp.abs(probe)) -@with_backends('cpu', 'jax', 'cuda') +@with_backends('numpy', 'jax', 'cupy', 'torch') @check_array_equals_file('probe_10mrad_20over.tiff', decimal=5) -def test_defocused_probe(backend: str) -> numpy.ndarray: +def test_defocused_probe(backend: BackendName) -> numpy.ndarray: xp = get_backend_module(backend) sampling = Sampling((1024, 1024), extent=(25., 25.)) @@ -29,9 +29,9 @@ def test_defocused_probe(backend: str) -> numpy.ndarray: return to_numpy(probe) -@with_backends('cpu', 'jax', 'cuda') -@check_array_equals_file('fresnel_200kV_1nm_phase.tiff', decimal=8) -def test_fresnel_propagator(backend: str) -> numpy.ndarray: +@with_backends('numpy', 'jax', 'cupy', 'torch') +@check_array_equals_file('fresnel_200kV_1nm_phase.tiff', decimal=5) +def test_fresnel_propagator(backend: BackendName) -> numpy.ndarray: xp = get_backend_module(backend) sampling = Sampling((1024, 1024), extent=(100., 100.)) @@ -41,9 +41,9 @@ def test_fresnel_propagator(backend: str) -> numpy.ndarray: )) -@with_backends('cpu', 'jax', 'cuda') +@with_backends('numpy', 'jax', 'cupy', 'torch') @check_array_equals_file('probe_10mrad_focused_mag.tiff', decimal=5) -def test_propagator_sign(backend: str) -> numpy.ndarray: +def test_propagator_sign(backend: BackendName) -> numpy.ndarray: xp = get_backend_module(backend) sampling = Sampling((1024, 1024), extent=(25., 25.)) @@ -55,4 +55,4 @@ def test_propagator_sign(backend: str) -> numpy.ndarray: prop = fresnel_propagator(ky, kx, wavelength=0.0251, delta_z=200.) probe = ifft2(fft2(probe) * prop) - return to_numpy(numpy.abs(probe)) \ No newline at end of file + return to_numpy(xp.abs(probe)) diff --git a/tests/utils.py b/tests/utils.py index fcb8d08..2cd7bb8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -137,42 +137,6 @@ def wrapper(*args, file_contents_array: numpy.ndarray, **kwargs): return decorator -def get_backend_module(backend: str): - """Get the module `xp` associated with a compute backend""" - backend = backend.lower() - if backend not in ('cuda', 'jax', 'cpu'): - raise ValueError(f"Unknown backend '{backend}'") - - if not t.TYPE_CHECKING: - if backend == 'jax': - import jax.numpy - return jax.numpy - if backend == 'cuda': - import cupy - return cupy - - import numpy - return numpy - - -def get_backend_scipy(backend: str): - """Get the scipy module associated with a compute backend""" - backend = backend.lower() - if backend not in ('cuda', 'jax', 'cpu'): - raise ValueError(f"Unknown backend '{backend}'") - - if not t.TYPE_CHECKING: - if backend == 'jax': - import jax.scipy - return jax.scipy - if backend == 'cuda': - import cupyx.scipy - return cupyx.scipy - - import scipy - return scipy - - _import = builtins.__import__