Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Tests

on:
push:
branches: [main]
branches: [main, develop]
tags: ["*"]
pull_request:
workflow_dispatch:
Expand Down Expand Up @@ -32,6 +32,6 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev,jax,web]'
pip install -e '.[dev,jax,torch,web]'
- name: Test
run: pytest
12 changes: 10 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import pytest
import numpy

BACKENDS: set[str] = set(('cpu', 'jax', 'cuda'))
BACKENDS: set[str] = set(('cpu', 'jax', 'cupy', 'torch'))
AVAILABLE_BACKENDS: set[str] = set(('cpu',))

try:
import cupy # pyright: ignore[reportMissingImports]
if cupy.cuda.runtime.getDeviceCount() > 0:
AVAILABLE_BACKENDS.add('cuda')
AVAILABLE_BACKENDS.add('cupy')
except ImportError:
pass

Expand All @@ -19,6 +19,14 @@
pass


try:
import torch # pyright: ignore[reportMissingImports] # noqa: F401
torch.asarray([1, 2, 3, 4]).numpy(force=True) # ensures torch is loaded, fixes a strange error with pytest
AVAILABLE_BACKENDS.add('torch')
except ImportError:
pass


def pytest_addoption(parser: pytest.Parser):
parser.addoption("--save-expected", action="store_true", dest='save-expected', default=False,
help="Overwrite expected files with the results of tests.")
Expand Down
7 changes: 4 additions & 3 deletions examples/mos2_epie.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
---
name: "mos2_epie"
backend: cupy
backend: torch

# raw data source
raw_data:
type: empad
path: "sample_data/simulated_mos2/mos2_0.00_dstep1.0.json"
path: "~/Downloads/mos2/1/mos2/mos2_0.00_dstep1.0.json"

post_load:
- type: poisson
Expand All @@ -32,7 +32,8 @@ engines:
beta_probe: 0.5

group_constraints: []
iter_constraints: []
iter_constraints:
- type: remove_phase_ramp

update_probe: {after: 5}

Expand Down
19 changes: 19 additions & 0 deletions notebooks/conventions.ipynb

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion phaser/engines/common/noise_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from phaser.hooks.solver import NoiseModel
from phaser.plan import AmplitudeNoisePlan, AnscombeNoisePlan, PoissonNoisePlan
from phaser.utils.num import get_array_module, Float
from phaser.utils.num import get_array_module, Float, to_numpy
from phaser.state import ReconsState


Expand Down
24 changes: 13 additions & 11 deletions phaser/engines/common/regularizers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial
import logging
from math import prod
import typing as t

import numpy
Expand Down Expand Up @@ -205,7 +206,7 @@ def calc_loss_group(
xp = get_array_module(sim.object.data)

cost = xp.sum(xp.abs(sim.object.data - 1.0))
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
return (cost * cost_scale * self.cost, state)


Expand All @@ -222,8 +223,9 @@ def calc_loss_group(
xp = get_array_module(sim.object.data)

cost = xp.sum(abs2(sim.object.data - 1.0))
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
return (cost * cost_scale * self.cost, state)

cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
return (cost * cost_scale * self.cost, state) # type: ignore


class ObjPhaseL1:
Expand All @@ -239,7 +241,7 @@ def calc_loss_group(
xp = get_array_module(sim.object.data)

cost = xp.sum(xp.abs(xp.angle(sim.object.data)))
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
return (cost * cost_scale * self.cost, state)


Expand All @@ -261,7 +263,7 @@ def calc_loss_group(
xp.abs(fft2(xp.prod(sim.object.data, axis=0)))
)
# scale cost by fraction of the total reconstruction in the group
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)

Expand Down Expand Up @@ -289,7 +291,7 @@ def calc_loss_group(
#)
# scale cost by fraction of the total reconstruction in the group
# TODO also scale by # of pixels or similar?
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)

Expand All @@ -311,9 +313,9 @@ def calc_loss_group(
xp.sum(abs2(xp.diff(sim.object.data, axis=-2)))
)
# scale cost by fraction of the total reconstruction in the group
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)
return (cost * cost_scale * self.cost, state) # type: ignore


class LayersTotalVariation:
Expand All @@ -333,7 +335,7 @@ def calc_loss_group(

cost = xp.sum(xp.abs(xp.diff(sim.object.data, axis=0)))
# scale cost by fraction of the total reconstruction in the group
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)

Expand All @@ -355,9 +357,9 @@ def calc_loss_group(

cost = xp.sum(abs2(xp.diff(sim.object.data, axis=0)))
# scale cost by fraction of the total reconstruction in the group
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)
return (cost * cost_scale * self.cost, state) # type: ignore


class ProbePhaseTikhonov:
Expand Down
11 changes: 6 additions & 5 deletions phaser/engines/common/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import typing as t

import numpy
from numpy.typing import NDArray, DTypeLike
from numpy.typing import NDArray
from typing_extensions import Self

from phaser.utils.num import (
get_array_module, to_real_dtype, to_complex_dtype,
fft2, ifft2, is_jax, to_numpy, block_until_ready, ufunc_outer
)
from phaser.utils.misc import FloatKey, jax_dataclass, create_compact_groupings, create_sparse_groupings, shuffled
from phaser.utils.tree import tree_dataclass
from phaser.utils.misc import FloatKey, create_compact_groupings, create_sparse_groupings, shuffled
from phaser.utils.optics import fresnel_propagator, fourier_shift_filter
from phaser.state import ReconsState
from phaser.hooks.solver import NoiseModel
Expand Down Expand Up @@ -83,7 +84,7 @@ def stream_patterns(
continue


@jax_dataclass(init=False, static_fields=('xp', 'dtype', 'noise_model', 'group_constraints', 'iter_constraints'), drop_fields=('ky', 'kx'))
@tree_dataclass(init=False, static_fields=('xp', 'dtype', 'noise_model', 'group_constraints', 'iter_constraints'), drop_fields=('ky', 'kx'))
class SimulationState:
state: ReconsState

Expand All @@ -99,7 +100,7 @@ class SimulationState:
iter_constraint_states: t.Tuple[t.Any, ...]

xp: t.Any
dtype: DTypeLike
dtype: t.Type[numpy.floating]
start_iter: int

def __init__(
Expand All @@ -109,7 +110,7 @@ def __init__(
group_constraints: t.Tuple[GroupConstraint[t.Any], ...],
iter_constraints: t.Tuple[IterConstraint[t.Any], ...],
xp: t.Any,
dtype: DTypeLike,
dtype: t.Type[numpy.floating],
noise_model_state: t.Optional[t.Any] = None,
group_constraint_states: t.Optional[t.Tuple[t.Any, ...]] = None,
iter_constraint_states: t.Optional[t.Tuple[t.Any, ...]] = None,
Expand Down
19 changes: 10 additions & 9 deletions phaser/engines/conventional/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy

from phaser.utils.misc import mask_fraction_of_groups
from phaser.utils.num import cast_array_module, to_numpy, to_complex_dtype
from phaser.utils.num import assert_dtype, cast_array_module, to_numpy, to_complex_dtype
from phaser.observer import Observer
from phaser.hooks import EngineArgs
from phaser.plan import ConventionalEnginePlan
Expand All @@ -17,6 +17,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState:

xp = cast_array_module(args['xp'])
dtype = args['dtype']
cdtype = to_complex_dtype(dtype)
observer: Observer = args.get('observer', Observer())
seed = args['seed']

Expand All @@ -37,12 +38,12 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState:
xp=xp, dtype=dtype
)
patterns = args['data'].patterns
pattern_mask = xp.array(args['data'].pattern_mask)
pattern_mask = xp.asarray(args['data'].pattern_mask)

assert patterns.dtype == sim.dtype
assert pattern_mask.dtype == sim.dtype
assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype)
assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype)
assert_dtype(patterns, dtype)
assert_dtype(pattern_mask, dtype)
assert_dtype(sim.state.object.data, cdtype)
assert_dtype(sim.state.probe.data, cdtype)

solver = props.solver(props)
sim = solver.init(sim)
Expand Down Expand Up @@ -87,8 +88,8 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState:
calc_error_mask=calc_error_mask,
observer=observer,
)
assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype)
assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype)
assert_dtype(sim.state.object.data, cdtype)
assert_dtype(sim.state.probe.data, cdtype)

sim = sim.apply_iter_constraints()

Expand All @@ -104,7 +105,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState:
update_mag = xp.linalg.norm(pos_update, axis=-1, keepdims=True)
logger.info(f"Position update: mean {xp.mean(update_mag)}")
sim.state.scan += pos_update
assert sim.state.scan.dtype == sim.dtype
assert_dtype(sim.state.scan, dtype)

# check positions are at least overlapping object
sim.state.object.sampling.check_scan(sim.state.scan, sim.state.probe.sampling.extent / 2.)
Expand Down
8 changes: 4 additions & 4 deletions phaser/engines/conventional/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def run_iteration(
gamma=gamma,
)
check_finite(sim.state.object.data, sim.state.probe.data, context=f"object or probe, group {group_i}")
assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype)
assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype)
#assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype)
#assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype)

sim = sim.apply_group_constraints(group)

Expand Down Expand Up @@ -365,8 +365,8 @@ def run_iteration(
update_probe=update_probe,
)
check_finite(sim.state.object.data, sim.state.probe.data, context=f"object or probe, group {group_i}")
assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype)
assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype)
#assert sim.state.object.data.dtype == to_complex_dtype(sim.dtype)
#assert sim.state.probe.data.dtype == to_complex_dtype(sim.dtype)

sim = sim.apply_group_constraints(group)

Expand Down
Loading