Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions phaser/engines/common/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def output_state(state: ReconsState, out_dir: Path, options: SaveOptions):


def _save_probe(state: ReconsState, out_path: Path, options: SaveOptions):
probe = to_numpy(state.probe.data)
probe = to_numpy(state.probe.data[0])
write_opts = tiff_write_opts(state.probe.sampling, n_slices=probe.shape[0])

if options.img_dtype == 'float':
Expand All @@ -65,7 +65,7 @@ def _save_probe(state: ReconsState, out_path: Path, options: SaveOptions):


def _save_probe_mag(state: ReconsState, out_path: Path, options: SaveOptions):
probe_mag = abs2(state.probe.data)
probe_mag = abs2(state.probe.data[0])
write_opts = tiff_write_opts(state.probe.sampling, n_slices=probe_mag.shape[0])

if options.img_dtype != 'float':
Expand All @@ -78,7 +78,7 @@ def _save_probe_mag(state: ReconsState, out_path: Path, options: SaveOptions):

def _save_probe_recip(state: ReconsState, out_path: Path, options: SaveOptions):
xp = get_array_module(state.probe.data)
probe = to_numpy(xp.fft.fftshift(fft2(state.probe.data), axes=(-1, -2)))
probe = to_numpy(xp.fft.fftshift(fft2(state.probe.data[0]), axes=(-1, -2)))
write_opts = tiff_write_opts_recip(state.probe.sampling, n_slices=probe.shape[0])

if options.img_dtype == 'float':
Expand All @@ -98,7 +98,7 @@ def _save_probe_recip(state: ReconsState, out_path: Path, options: SaveOptions):

def _save_probe_recip_mag(state: ReconsState, out_path: Path, options: SaveOptions):
xp = get_array_module(state.probe.data)
probe_mag = to_numpy(abs2(xp.fft.fftshift(fft2(state.probe.data), axes=(-1, -2))))
probe_mag = to_numpy(abs2(xp.fft.fftshift(fft2(state.probe.data[0]), axes=(-1, -2))))
write_opts = tiff_write_opts_recip(state.probe.sampling, n_slices=probe_mag.shape[0])

if options.img_dtype != 'float':
Expand Down
15 changes: 8 additions & 7 deletions phaser/engines/common/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,31 +54,32 @@ def __len__(self) -> int:


def stream_patterns(
groups: t.Iterable[NDArray[numpy.int64]], patterns: NDArray[numpy.floating],
groups: t.Iterable[NDArray[numpy.int64]], patterns: NDArray[numpy.floating], patterns_id: NDArray[numpy.integer],
xp: t.Any, buf_n: int = 1
) -> t.Iterator[t.Tuple[NDArray[numpy.int64], NDArray[numpy.floating]]]:
) -> t.Iterator[t.Tuple[NDArray[numpy.int64], NDArray[numpy.floating], NDArray[numpy.integer]]]:
if buf_n == 0:
for group in groups:
group_patterns = xp.asarray(patterns[tuple(group)])
yield group, block_until_ready(group_patterns)
group_patterns_id = xp.asarray(patterns_id[tuple(group)])
yield group, block_until_ready(group_patterns), block_until_ready(group_patterns_id)
return

buf = collections.deque()
it = iter(groups)

for group in it:
buf.append((group, xp.asarray(patterns[tuple(group)])))
buf.append((group, xp.asarray(patterns[tuple(group)]), xp.asarray(patterns_id[tuple(group)])))
if len(buf) >= buf_n:
break

while len(buf) > 0:
(group, group_patterns) = buf.popleft()
yield group, block_until_ready(group_patterns)
(group, group_patterns, group_patterns_id) = buf.popleft()
yield group, block_until_ready(group_patterns), block_until_ready(group_patterns_id)

# attempt to feed queue
try:
group = next(it)
buf.append((group, xp.asarray(patterns[tuple(group)])))
buf.append((group, xp.asarray(patterns[tuple(group)]), xp.asarray(patterns_id[tuple(group)])))
except StopIteration:
continue

Expand Down
21 changes: 13 additions & 8 deletions phaser/engines/gradient/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState:
seed = args['seed']
patterns = args['data'].patterns
pattern_mask = args['data'].pattern_mask
patterns_id = args['data'].patterns_id

noise_model = props.noise_model(None)

Expand Down Expand Up @@ -195,9 +196,9 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState:

# runs rescaling
rescale_factors = []
for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan),
patterns, xp=xp, buf_n=props.buffer_n_groups)):
group_rescale_factors = dry_run(state, group, propagators, group_patterns, xp=xp, dtype=dtype)
for (group_i, (group, group_patterns, group_patterns_id)) in enumerate(stream_patterns(groups.iter(state.scan),
patterns, patterns_id, xp=xp, buf_n=props.buffer_n_groups)):
group_rescale_factors = dry_run(state, group, propagators, group_patterns, group_patterns_id, xp=xp, dtype=dtype)
rescale_factors.append(group_rescale_factors)

rescale_factors = xp.concatenate(rescale_factors, axis=0)
Expand Down Expand Up @@ -240,8 +241,8 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState:
]
losses = []

for (group_i, (group, group_patterns)) in enumerate(stream_patterns(groups.iter(state.scan, i, iter_shuffle_groups),
patterns, xp=xp, buf_n=props.buffer_n_groups)):
for (group_i, (group, group_patterns, group_patterns_id)) in enumerate(stream_patterns(groups.iter(state.scan, i, iter_shuffle_groups),
patterns, patterns_id, xp=xp, buf_n=props.buffer_n_groups)):
(state, loss, iter_grads, solver_states) = run_group(
state, group=group, vars=iter_vars,
noise_model=noise_model,
Expand All @@ -252,6 +253,7 @@ def run_engine(args: EngineArgs, props: GradientEnginePlan) -> ReconsState:
solver_states=solver_states,
props=propagators,
group_patterns=group_patterns, #load_group(group),
group_patterns_id=group_patterns_id,
pattern_mask=pattern_mask,
probe_int=probe_int,
xp=xp, dtype=dtype
Expand Down Expand Up @@ -307,6 +309,7 @@ def run_group(
solver_states: SolverStates,
props: t.Optional[NDArray[numpy.complexfloating]],
group_patterns: NDArray[numpy.floating],
group_patterns_id: NDArray[numpy.integer],
pattern_mask: NDArray[numpy.floating],
probe_int: t.Union[float, numpy.floating],
xp: t.Any,
Expand All @@ -317,7 +320,7 @@ def run_group(

((loss, solver_states), grad) = jax.value_and_grad(run_model, has_aux=True)(
*extract_vars(state, vars, group),
group=group, props=props, group_patterns=group_patterns, pattern_mask=pattern_mask,
group=group, props=props, group_patterns=group_patterns, group_patterns_id=group_patterns_id, pattern_mask=pattern_mask,
noise_model=noise_model, regularizers=regularizers, solver_states=solver_states,
xp=xp, dtype=dtype
)
Expand Down Expand Up @@ -360,6 +363,7 @@ def run_model(
group: NDArray[numpy.integer],
props: t.Optional[NDArray[numpy.complexfloating]], # base propagator, shape (n_slices-1, ny, nx)
group_patterns: NDArray[numpy.floating],
group_patterns_id: NDArray[numpy.integer],
pattern_mask: NDArray[numpy.floating],
noise_model: NoiseModel[t.Any],
regularizers: t.Sequence[CostRegularizer[t.Any]],
Expand All @@ -380,7 +384,7 @@ def run_model(
probes = sim.probe.data
group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, group_scan, probes.shape[-2:])
group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(group_scan, probes.shape[-2:]))[:, None, ...]
probes = ifft2(fft2(probes) * group_subpx_filters)
probes = ifft2(fft2(probes[group_patterns_id]) * group_subpx_filters)

def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi):
# psi: (batch, n_probe, Ny, Nx)
Expand Down Expand Up @@ -415,6 +419,7 @@ def dry_run(
group: NDArray[numpy.integer],
props: t.Optional[NDArray[numpy.complexfloating]],
group_patterns: NDArray[numpy.floating],
group_patterns_id: NDArray[numpy.integer],
xp: t.Any,
dtype: t.Type[numpy.floating],
) -> NDArray[numpy.floating]:
Expand All @@ -423,7 +428,7 @@ def dry_run(
probes = sim.probe.data
group_obj = sim.object.sampling.get_view_at_pos(sim.object.data, sim.scan[tuple(group)], probes.shape[-2:])
group_subpx_filters = fourier_shift_filter(ky, kx, sim.object.sampling.get_subpx_shifts(sim.scan[tuple(group)], probes.shape[-2:]))[:, None, ...]
probes = ifft2(fft2(probes) * group_subpx_filters)
probes = ifft2(fft2(probes[group_patterns_id]) * group_subpx_filters)

def sim_slice(slice_i: int, prop: t.Optional[NDArray[numpy.complexfloating]], psi):
if prop is not None:
Expand Down
44 changes: 30 additions & 14 deletions phaser/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def initialize_reconstruction(

raw_data = load_raw_data(plan, xp, seed, init_state=init_state)

data = Patterns(raw_data['patterns'], raw_data['mask'])
data = Patterns(raw_data['patterns'], raw_data['mask'], raw_data['patterns_id'])
sampling = raw_data['sampling']
wavelength = unwrap(raw_data.get('wavelength', None))
probe_hook = raw_data.get('probe_hook', None)
Expand All @@ -279,9 +279,19 @@ def initialize_reconstruction(
probe = pane.from_data(probe_hook, ProbeHook)( # type: ignore
{'sampling': sampling, 'wavelength': wavelength, 'dtype': dtype, 'seed': seed, 'xp': xp}
)

## (num_scan, probe_modes, *probe_shape)
##TODO hardcoded for now to num_scan=3
probe.data = probe.data.reshape(1, *probe.data.shape)
probe.data = xp.tile(probe.data, (3, 1, 1, 1))

if probe.data.ndim == 2:
probe.data = probe.data.reshape((1, 1, *probe.data.shape))
elif probe.data.ndim == 3:
probe.data = probe.data.reshape((1, *probe.data.shape))

print(probe.data.shape)

if init_state.scan is not None and plan.init.scan is None:
logging.info("Re-using scan from initial state...")
scan = init_state.scan
Expand Down Expand Up @@ -399,19 +409,25 @@ def prepare_for_engine(patterns: Patterns, state: ReconsState, xp: t.Any, engine
state.object.data = state.object.sampling.resample(state.object.data, obj_sampling)
state.object.sampling = obj_sampling

current_probe_modes = state.probe.data.shape[0]
if engine.probe_modes != current_probe_modes:
# fix probe modes
if engine.probe_modes < current_probe_modes:
# TODO: redistribute intensity here
state.probe.data = state.probe.data[:engine.probe_modes]
else:
from phaser.utils.optics import make_hermetian_modes
if current_probe_modes != 1:
logging.info("Summing probe modes (in real-space) before recreating with different # of modes")

base_mode = xp.sum(state.probe.data, axis=0)
state.probe.data = make_hermetian_modes(base_mode, engine.probe_modes, base_mode_power=engine.base_mode_power)
new_probe_data = []

for num_scan in range(state.probe.data.shape[0]):
current_probe_modes = state.probe.data.shape[1]
if engine.probe_modes != current_probe_modes:
# fix probe modes
if engine.probe_modes < current_probe_modes:
# TODO: redistribute intensity here
new_probe_data.append(state.probe.data[num_scan, :engine.probe_modes])
else:
from phaser.utils.optics import make_hermetian_modes
if current_probe_modes != 1:
logging.info("Summing probe modes (in real-space) before recreating with different # of modes")

base_mode = xp.sum(state.probe.data[num_scan], axis=0)
new_probe_data.append(make_hermetian_modes(base_mode, engine.probe_modes, base_mode_power=engine.base_mode_power))
new_probe_data = xp.stack(new_probe_data, axis=0)
state.probe.data = new_probe_data
print(state.probe.data.shape)

if engine.slices is not None and (len(engine.slices.thicknesses) != len(state.object.thicknesses)
or not numpy.allclose(engine.slices.thicknesses, state.object.thicknesses)):
Expand Down
7 changes: 7 additions & 0 deletions phaser/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

class RawData(t.TypedDict):
patterns: NDArray[numpy.floating]
patterns_id: NDArray[numpy.integer]
mask: NDArray[numpy.floating]
sampling: 'Sampling'
wavelength: NotRequired[t.Optional[float]]
Expand Down Expand Up @@ -120,9 +121,15 @@ class RasterScanProps(Dataclass):
affine: t.Optional[t.Annotated[NDArray[numpy.floating], annotations.shape((2, 2))]] = None


class CustomScanProps(Dataclass):
path: str
"""Path to .npy file containing scan array matching the size of the scan"""


class ScanHook(Hook[ScanHookArgs, NDArray[numpy.floating]]):
known = {
'raster': ('phaser.hooks.scan:raster_scan', RasterScanProps),
'custom': ('phaser.hooks.scan:load_custom_scan', CustomScanProps),
}


Expand Down
1 change: 1 addition & 0 deletions phaser/hooks/io/empad.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def load_empad(args: None, props: LoadEmpadProps) -> RawData:

return {
'patterns': patterns,
'patterns_id': numpy.load('id.npy').astype(numpy.int8),
'mask': numpy.fft.ifftshift(mask, axes=(-1, -2)),
'sampling': sampling,
'wavelength': wavelength,
Expand Down
40 changes: 38 additions & 2 deletions phaser/hooks/scan.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@

import numpy
from numpy.typing import NDArray
from pathlib import Path

from phaser.utils.num import cast_array_module
from phaser.utils.scan import make_raster_scan
from . import ScanHookArgs, RasterScanProps
from . import ScanHookArgs, RasterScanProps, CustomScanProps


def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> NDArray[numpy.floating]:
Expand All @@ -25,4 +26,39 @@ def raster_scan(args: ScanHookArgs, props: RasterScanProps) -> NDArray[numpy.flo
# equivalent to (affine @ scan.T).T (active transformation)
scan = scan @ affine.T

return scan
numpy.save('scan', scan)
return scan


def load_custom_scan(args: ScanHookArgs, props: CustomScanProps) -> NDArray[numpy.floating]:
# def load_custom_tilt(args: TiltHookArgs, props: CustomTiltProps) -> NDArray[numpy.floating]:
"""
Load scan array from a .npy file.

The loaded array can have shape (ny, nx, 2) matching props.shape,
or shape (N, 2) where N == ny*nx, which will be reshaped accordingly.
"""
xp = cast_array_module(args['xp'])

path = Path(props.path).expanduser()
if not path.exists():
raise FileNotFoundError(f"Custom scan file not found: {path}")

scan = numpy.load(path)

# shape = args['shape']
# expected_shape_3d = (*shape, 2)
# expected_shape_2d = (numpy.prod(shape), 2)

# if tilt_data.ndim == 3:
# if tilt_data.shape != expected_shape_3d:
# raise ValueError(f"Loaded tilt data shape {tilt_data.shape} does not match expected shape {expected_shape_3d}")
# result = tilt_data
# elif tilt_data.ndim == 2:
# if tilt_data.shape != expected_shape_2d:
# raise ValueError(f"Loaded tilt data shape {tilt_data.shape} is incompatible with expected 2D shape {expected_shape_2d}")
# result = tilt_data.reshape(expected_shape_3d)
# else:
# raise ValueError(f"Loaded tilt data must be 2D or 3D array, got shape {tilt_data.shape}")
print("loaded scan")
return xp.array(scan, dtype=xp.float32)
3 changes: 2 additions & 1 deletion phaser/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ class Patterns():
"""Raw diffraction patterns, with 0-frequency sample in corner"""
pattern_mask: NDArray[numpy.floating]
"""Mask indicating which portions of the diffraction patterns contain data."""
patterns_id: NDArray[numpy.integer]

def to_numpy(self) -> Self:
return self.__class__(
to_numpy(self.patterns), to_numpy(self.pattern_mask)
to_numpy(self.patterns), to_numpy(self.pattern_mask), to_numpy(self.patterns_id)
)


Expand Down
11 changes: 6 additions & 5 deletions phaser/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def hdf5_read_state(file: HdfLike) -> PartialReconsState:

def hdf5_read_probe_state(group: h5py.Group) -> ProbeState:
probes = _hdf5_read_dataset(group, 'data', numpy.complexfloating)
assert probes.ndim == 3
assert probes.ndim == 4

extent = _hdf5_read_dataset_shape(group, 'extent', numpy.float64, (2,))
(n_y, n_x) = probes.shape[-2:]
Expand Down Expand Up @@ -194,11 +194,12 @@ def hdf5_write_state(state: t.Union[ReconsState, PartialReconsState], file: HdfL


def hdf5_write_probe_state(state: ProbeState, group: h5py.Group):
assert state.data.ndim == 3
assert state.data.ndim == 4
dataset = group.create_dataset('data', data=to_numpy(state.data))
dataset.dims[0].label = 'mode'
dataset.dims[1].label = 'y'
dataset.dims[2].label = 'x'
dataset.dims[0].label = 'scan_num'
dataset.dims[1].label = 'mode'
dataset.dims[2].label = 'y'
dataset.dims[3].label = 'x'

group.create_dataset('sampling', data=state.sampling.sampling.astype(numpy.float64))
group.create_dataset('extent', data=state.sampling.extent.astype(numpy.float64))
Expand Down
Loading