From 1b34b078627503eed2c1ee66602055af5c97103c Mon Sep 17 00:00:00 2001 From: mlz Date: Wed, 3 Sep 2025 16:32:25 -0400 Subject: [PATCH 1/2] working with a shared probe --- phaser/hooks/__init__.py | 6 ++++++ phaser/hooks/scan.py | 40 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/phaser/hooks/__init__.py b/phaser/hooks/__init__.py index ff86e08..f68ae51 100644 --- a/phaser/hooks/__init__.py +++ b/phaser/hooks/__init__.py @@ -120,9 +120,15 @@ class RasterScanProps(Dataclass): affine: t.Optional[t.Annotated[NDArray[numpy.floating], annotations.shape((2, 2))]] = None +class CustomScanProps(Dataclass): + path: str + """Path to .npy file containing scan array matching the size of the scan""" + + class ScanHook(Hook[ScanHookArgs, NDArray[numpy.floating]]): known = { 'raster': ('phaser.hooks.scan:raster_scan', RasterScanProps), + 'custom': ('phaser.hooks.scan:load_custom_scan', CustomScanProps), } diff --git a/phaser/hooks/scan.py b/phaser/hooks/scan.py index 6e402fb..2af16a8 100644 --- a/phaser/hooks/scan.py +++ b/phaser/hooks/scan.py @@ -1,10 +1,11 @@ import numpy from numpy.typing import NDArray +from pathlib import Path from phaser.utils.num import cast_array_module from phaser.utils.scan import make_raster_scan -from . import ScanHookArgs, RasterScanProps +from . import ScanHookArgs, RasterScanProps, CustomScanProps def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> NDArray[numpy.floating]: @@ -25,4 +26,39 @@ def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> NDArray[numpy.flo # equivalent to (affine @ scan.T).T (active transformation) scan = scan @ affine.T - return scan \ No newline at end of file + numpy.save('scan', scan) + return scan + + +def load_custom_scan(args: ScanHookArgs, props: CustomScanProps) -> NDArray[numpy.floating]: +# def load_custom_tilt(args: TiltHookArgs, props: CustomTiltProps) -> NDArray[numpy.floating]: + """ + Load scan array from a .npy file. + + The loaded array can have shape (ny, nx, 2) matching props.shape, + or shape (N, 2) where N == ny*nx, which will be reshaped accordingly. + """ + xp = cast_array_module(args['xp']) + + path = Path(props.path).expanduser() + if not path.exists(): + raise FileNotFoundError(f"Custom scan file not found: {path}") + + scan = numpy.load(path) + + # shape = args['shape'] + # expected_shape_3d = (*shape, 2) + # expected_shape_2d = (numpy.prod(shape), 2) + + # if tilt_data.ndim == 3: + # if tilt_data.shape != expected_shape_3d: + # raise ValueError(f"Loaded tilt data shape {tilt_data.shape} does not match expected shape {expected_shape_3d}") + # result = tilt_data + # elif tilt_data.ndim == 2: + # if tilt_data.shape != expected_shape_2d: + # raise ValueError(f"Loaded tilt data shape {tilt_data.shape} is incompatible with expected 2D shape {expected_shape_2d}") + # result = tilt_data.reshape(expected_shape_3d) + # else: + # raise ValueError(f"Loaded tilt data must be 2D or 3D array, got shape {tilt_data.shape}") + print("loaded scan") + return xp.array(scan, dtype=xp.float32) From a3606211aa4e0efdf0833ae99362832dd9390964 Mon Sep 17 00:00:00 2001 From: mlz Date: Thu, 4 Sep 2025 10:01:53 -0400 Subject: [PATCH 2/2] hard coded I/O --- phaser/engines/common/output.py | 8 +++--- phaser/engines/common/simulation.py | 15 +++++----- phaser/engines/gradient/run.py | 21 ++++++++------ phaser/execute.py | 44 ++++++++++++++++++++--------- phaser/hooks/__init__.py | 1 + phaser/hooks/io/empad.py | 1 + phaser/state.py | 3 +- phaser/utils/io.py | 11 ++++---- 8 files changed, 65 insertions(+), 39 deletions(-) diff --git a/phaser/engines/common/output.py b/phaser/engines/common/output.py index 7100d18..dc18de7 100644 --- a/phaser/engines/common/output.py +++ b/phaser/engines/common/output.py @@ -46,7 +46,7 @@ def output_state(state: ReconsState, out_dir: Path, options: SaveOptions): def _save_probe(state: ReconsState, out_path: Path, options: SaveOptions): - probe = to_numpy(state.probe.data) + probe = to_numpy(state.probe.data[0]) write_opts = tiff_write_opts(state.probe.sampling, n_slices=probe.shape[0]) if options.img_dtype == 'float': @@ -65,7 +65,7 @@ def _save_probe(state: ReconsState, out_path: Path, options: SaveOptions): def _save_probe_mag(state: ReconsState, out_path: Path, options: SaveOptions): - probe_mag = abs2(state.probe.data) + probe_mag = abs2(state.probe.data[0]) write_opts = tiff_write_opts(state.probe.sampling, n_slices=probe_mag.shape[0]) if options.img_dtype != 'float': @@ -78,7 +78,7 @@ def _save_probe_mag(state: ReconsState, out_path: Path, options: SaveOptions): def _save_probe_recip(state: ReconsState, out_path: Path, options: SaveOptions): xp = get_array_module(state.probe.data) - probe = to_numpy(xp.fft.fftshift(fft2(state.probe.data), axes=(-1, -2))) + probe = to_numpy(xp.fft.fftshift(fft2(state.probe.data[0]), axes=(-1, -2))) write_opts = tiff_write_opts_recip(state.probe.sampling, n_slices=probe.shape[0]) if options.img_dtype == 'float': @@ -98,7 +98,7 @@ def _save_probe_recip(state: ReconsState, out_path: Path, options: SaveOptions): def _save_probe_recip_mag(state: ReconsState, out_path: Path, options: SaveOptions): xp = get_array_module(state.probe.data) - probe_mag = to_numpy(abs2(xp.fft.fftshift(fft2(state.probe.data), axes=(-1, -2)))) + probe_mag = to_numpy(abs2(xp.fft.fftshift(fft2(state.probe.data[0]), axes=(-1, -2)))) write_opts = tiff_write_opts_recip(state.probe.sampling, n_slices=probe_mag.shape[0]) if options.img_dtype != 'float': diff --git a/phaser/engines/common/simulation.py b/phaser/engines/common/simulation.py index fb88163..dc1f7be 100644 --- a/phaser/engines/common/simulation.py +++ b/phaser/engines/common/simulation.py @@ -54,31 +54,32 @@ def __len__(self) -> int: def stream_patterns( - groups: t.Iterable[NDArray[numpy.int64]], patterns: NDArray[numpy.floating], + groups: t.Iterable[NDArray[numpy.int64]], patterns: NDArray[numpy.floating], patterns_id: NDArray[numpy.integer], xp: t.Any, buf_n: int = 1 -) -> t.Iterator[t.Tuple[NDArray[numpy.int64], NDArray[numpy.floating]]]: +) -> t.Iterator[t.Tuple[NDArray[numpy.int64], NDArray[numpy.floating], NDArray[numpy.integer]]]: if buf_n == 0: for group in groups: group_patterns = xp.asarray(patterns[tuple(group)]) - yield group, block_until_ready(group_patterns) + group_patterns_id = xp.asarray(patterns_id[tuple(group)]) + yield group, block_until_ready(group_patterns), block_until_ready(group_patterns_id) return buf = collections.deque() it = iter(groups) for group in it: - buf.append((group, xp.asarray(patterns[tuple(group)]))) + buf.append((group, xp.asarray(patterns[tuple(group)]), xp.asarray(patterns_id[tuple(group)]))) if len(buf) >= buf_n: break while len(buf) > 0: - (group, group_patterns) = buf.popleft() - yield group, block_until_ready(group_patterns) + (group, group_patterns, group_patterns_id) = buf.popleft() + yield group, block_until_ready(group_patterns), block_until_ready(group_patterns_id) # attempt to feed queue try: group = next(it) - buf.append((group, xp.asarray(patterns[tuple(group)]))) + buf.append((group, xp.asarray(patterns[tuple(group)]), xp.asarray(patterns_id[tuple(group)]))) except StopIteration: continue diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index 84a4cce..2b41bb7 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -166,6 +166,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: seed = args['seed'] patterns = args['data'].patterns pattern_mask = args['data'].pattern_mask + patterns_id = args['data'].patterns_id noise_model = props.noise_model(None) @@ -195,9 +196,9 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: # runs rescaling rescale_factors = [] - for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan), - patterns, xp=xp, buf_n=props.buffer_n_groups)): - group_rescale_factors = dry_run(state, group, propagators, group_patterns, xp=xp, dtype=dtype) + for (group_i, (group, group_patterns, group_patterns_id)) in enumerate(stream_patterns(groups.iter(state.scan), + patterns, patterns_id, xp=xp, buf_n=props.buffer_n_groups)): + group_rescale_factors = dry_run(state, group, propagators, group_patterns, group_patterns_id, xp=xp, dtype=dtype) rescale_factors.append(group_rescale_factors) rescale_factors = xp.concatenate(rescale_factors, axis=0) @@ -240,8 +241,8 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: ] losses = [] - for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan, i, iter_shuffle_groups), - patterns, xp=xp, buf_n=props.buffer_n_groups)): + for (group_i, (group, group_patterns, group_patterns_id)) in enumerate(stream_patterns(groups.iter(state.scan, i, iter_shuffle_groups), + patterns, patterns_id, xp=xp, buf_n=props.buffer_n_groups)): (state, loss, iter_grads, solver_states) = run_group( state, group=group, vars=iter_vars, noise_model=noise_model, @@ -252,6 +253,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState: solver_states=solver_states, props=propagators, group_patterns=group_patterns, #load_group(group), + group_patterns_id=group_patterns_id, pattern_mask=pattern_mask, probe_int=probe_int, xp=xp, dtype=dtype @@ -307,6 +309,7 @@ def run_group( solver_states: SolverStates, props: t.Optional[NDArray[numpy.complexfloating]], group_patterns: NDArray[numpy.floating], + group_patterns_id: NDArray[numpy.integer], pattern_mask: NDArray[numpy.floating], probe_int: t.Union[float, numpy.floating], xp: t.Any, @@ -317,7 +320,7 @@ def run_group( ((loss, solver_states), grad) = jax.value_and_grad(run_model, has_aux=True)( *extract_vars(state, vars, group), - group=group, props=props, group_patterns=group_patterns, pattern_mask=pattern_mask, + group=group, props=props, group_patterns=group_patterns, group_patterns_id=group_patterns_id, pattern_mask=pattern_mask, noise_model=noise_model, regularizers=regularizers, solver_states=solver_states, xp=xp, dtype=dtype ) @@ -360,6 +363,7 @@ def run_model( group: NDArray[numpy.integer], props: t.Optional[NDArray[numpy.complexfloating]], # base propagator, shape (n_slices-1, ny, nx) group_patterns: NDArray[numpy.floating], + group_patterns_id: NDArray[numpy.integer], pattern_mask: NDArray[numpy.floating], noise_model: NoiseModel[t.Any], regularizers: t.Sequence[CostRegularizer[t.Any]], @@ -380,7 +384,7 @@ def run_model( probes = sim.probe.data group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, group_scan, probes.shape[-2:]) group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(group_scan, probes.shape[-2:]))[:, None, ...] - probes = ifft2(fft2(probes) * group_subpx_filters) + probes = ifft2(fft2(probes[group_patterns_id]) * group_subpx_filters) def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi): # psi: (batch, n_probe, Ny, Nx) @@ -415,6 +419,7 @@ def dry_run( group: NDArray[numpy.integer], props: t.Optional[NDArray[numpy.complexfloating]], group_patterns: NDArray[numpy.floating], + group_patterns_id: NDArray[numpy.integer], xp: t.Any, dtype: t.Type[numpy.floating], ) -> NDArray[numpy.floating]: @@ -423,7 +428,7 @@ def dry_run( probes = sim.probe.data group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, sim.scan[tuple(group)], probes.shape[-2:]) group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(sim.scan[tuple(group)], probes.shape[-2:]))[:, None, ...] - probes = ifft2(fft2(probes) * group_subpx_filters) + probes = ifft2(fft2(probes[group_patterns_id]) * group_subpx_filters) def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi): if prop is not None: diff --git a/phaser/execute.py b/phaser/execute.py index 597d244..2dc6bb3 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -254,7 +254,7 @@ def initialize_reconstruction( raw_data = load_raw_data(plan, xp, seed, init_state=init_state) - data = Patterns(raw_data['patterns'], raw_data['mask']) + data = Patterns(raw_data['patterns'], raw_data['mask'], raw_data['patterns_id']) sampling = raw_data['sampling'] wavelength = unwrap(raw_data.get('wavelength', None)) probe_hook = raw_data.get('probe_hook', None) @@ -279,9 +279,19 @@ def initialize_reconstruction( probe = pane.from_data(probe_hook, ProbeHook)( # type: ignore {'sampling': sampling, 'wavelength': wavelength, 'dtype': dtype, 'seed': seed, 'xp': xp} ) + + ## (num_scan, probe_modes, *probe_shape) + ##TODO hardcoded for now to num_scan=3 + probe.data = probe.data.reshape(1, *probe.data.shape) + probe.data = xp.tile(probe.data, (3, 1, 1, 1)) + if probe.data.ndim == 2: + probe.data = probe.data.reshape((1, 1, *probe.data.shape)) + elif probe.data.ndim == 3: probe.data = probe.data.reshape((1, *probe.data.shape)) + print(probe.data.shape) + if init_state.scan is not None and plan.init.scan is None: logging.info("Re-using scan from initial state...") scan = init_state.scan @@ -399,19 +409,25 @@ def prepare_for_engine(patterns: Patterns, state: ReconsState, xp: t.Any, engine state.object.data = state.object.sampling.resample(state.object.data, obj_sampling) state.object.sampling = obj_sampling - current_probe_modes = state.probe.data.shape[0] - if engine.probe_modes != current_probe_modes: - # fix probe modes - if engine.probe_modes < current_probe_modes: - # TODO: redistribute intensity here - state.probe.data = state.probe.data[:engine.probe_modes] - else: - from phaser.utils.optics import make_hermetian_modes - if current_probe_modes != 1: - logging.info("Summing probe modes (in real-space) before recreating with different # of modes") - - base_mode = xp.sum(state.probe.data, axis=0) - state.probe.data = make_hermetian_modes(base_mode, engine.probe_modes, base_mode_power=engine.base_mode_power) + new_probe_data = [] + + for num_scan in range(state.probe.data.shape[0]): + current_probe_modes = state.probe.data.shape[1] + if engine.probe_modes != current_probe_modes: + # fix probe modes + if engine.probe_modes < current_probe_modes: + # TODO: redistribute intensity here + new_probe_data.append(state.probe.data[num_scan, :engine.probe_modes]) + else: + from phaser.utils.optics import make_hermetian_modes + if current_probe_modes != 1: + logging.info("Summing probe modes (in real-space) before recreating with different # of modes") + + base_mode = xp.sum(state.probe.data[num_scan], axis=0) + new_probe_data.append(make_hermetian_modes(base_mode, engine.probe_modes, base_mode_power=engine.base_mode_power)) + new_probe_data = xp.stack(new_probe_data, axis=0) + state.probe.data = new_probe_data + print(state.probe.data.shape) if engine.slices is not None and (len(engine.slices.thicknesses) != len(state.object.thicknesses) or not numpy.allclose(engine.slices.thicknesses, state.object.thicknesses)): diff --git a/phaser/hooks/__init__.py b/phaser/hooks/__init__.py index f68ae51..5c16eab 100644 --- a/phaser/hooks/__init__.py +++ b/phaser/hooks/__init__.py @@ -18,6 +18,7 @@ class RawData(t.TypedDict): patterns: NDArray[numpy.floating] + patterns_id: NDArray[numpy.integer] mask: NDArray[numpy.floating] sampling: 'Sampling' wavelength: NotRequired[t.Optional[float]] diff --git a/phaser/hooks/io/empad.py b/phaser/hooks/io/empad.py index 488eafc..40b66b3 100644 --- a/phaser/hooks/io/empad.py +++ b/phaser/hooks/io/empad.py @@ -88,6 +88,7 @@ def load_empad(args: None, props: LoadEmpadProps) -> RawData: return { 'patterns': patterns, + 'patterns_id': numpy.load('id.npy').astype(numpy.int8), 'mask': numpy.fft.ifftshift(mask, axes=(-1, -2)), 'sampling': sampling, 'wavelength': wavelength, diff --git a/phaser/state.py b/phaser/state.py index f09b98a..8b76214 100644 --- a/phaser/state.py +++ b/phaser/state.py @@ -20,10 +20,11 @@ class Patterns(): """Raw diffraction patterns, with 0-frequency sample in corner""" pattern_mask: NDArray[numpy.floating] """Mask indicating which portions of the diffraction patterns contain data.""" + patterns_id: NDArray[numpy.integer] def to_numpy(self) -> Self: return self.__class__( - to_numpy(self.patterns), to_numpy(self.pattern_mask) + to_numpy(self.patterns), to_numpy(self.pattern_mask), to_numpy(self.patterns_id) ) diff --git a/phaser/utils/io.py b/phaser/utils/io.py index 700432a..b40e113 100644 --- a/phaser/utils/io.py +++ b/phaser/utils/io.py @@ -121,7 +121,7 @@ def hdf5_read_state(file: HdfLike) -> PartialReconsState: def hdf5_read_probe_state(group: h5py.Group) -> ProbeState: probes = _hdf5_read_dataset(group, 'data', numpy.complexfloating) - assert probes.ndim == 3 + assert probes.ndim == 4 extent = _hdf5_read_dataset_shape(group, 'extent', numpy.float64, (2,)) (n_y, n_x) = probes.shape[-2:] @@ -194,11 +194,12 @@ def hdf5_write_state(state: t.Union[ReconsState, PartialReconsState], file: HdfL def hdf5_write_probe_state(state: ProbeState, group: h5py.Group): - assert state.data.ndim == 3 + assert state.data.ndim == 4 dataset = group.create_dataset('data', data=to_numpy(state.data)) - dataset.dims[0].label = 'mode' - dataset.dims[1].label = 'y' - dataset.dims[2].label = 'x' + dataset.dims[0].label = 'scan_num' + dataset.dims[1].label = 'mode' + dataset.dims[2].label = 'y' + dataset.dims[3].label = 'x' group.create_dataset('sampling', data=state.sampling.sampling.astype(numpy.float64)) group.create_dataset('extent', data=state.sampling.extent.astype(numpy.float64))