Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 61 additions & 8 deletions web_api/app/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ class ApplyCorrespondencesRequest(BaseModel):
description="Mask interpretation: 'tissue' (1=tissue, 0=non-tissue) or "
"'defect' (1=defect, 0=non-defect). Defect masks are inverted internally.",
)
tgt_image: list[list[list[list[float]]]] | None = Field(
None,
description="Target section image tensor (C, H, W, 1). When provided with "
"mse_weight > 0, enables image-based MSE optimization.",
)
mse_weight: float = Field(
0.0,
description="Weight for MSE loss between warped source and target images. "
"Only used when tgt_image is provided.",
)


class ApplyCorrespondencesResponse(BaseModel):
Expand Down Expand Up @@ -105,6 +115,7 @@ def _parse_json_request(body: dict, device: torch.device):

src_mask_tensor = None
tgt_mask_tensor = None
tgt_image_tensor = None

if req.src_mask is not None:
src_mask_tensor = torch.tensor(req.src_mask, dtype=torch.float32, device=device)
Expand All @@ -116,12 +127,20 @@ def _parse_json_request(body: dict, device: torch.device):
if req.mask_type == "defect":
tgt_mask_tensor = 1.0 - tgt_mask_tensor

return correspondences_dict, image_tensor, src_mask_tensor, tgt_mask_tensor, {
if req.tgt_image is not None:
tgt_image_tensor = torch.tensor(req.tgt_image, dtype=torch.float32, device=device)

params = {
"num_iter": req.num_iter,
"rig": req.rig,
"lr": req.lr,
"optimizer_type": req.optimizer_type,
"mse_weight": req.mse_weight,
}
return (
correspondences_dict, image_tensor, src_mask_tensor,
tgt_mask_tensor, tgt_image_tensor, params,
)


async def _read_form_field_bytes(form, field_name: str) -> bytes:
Expand Down Expand Up @@ -219,12 +238,30 @@ async def _parse_multipart_request(request: Request, device: torch.device):
if mask_type == "defect":
tgt_mask_tensor = 1.0 - tgt_mask_tensor

return correspondences_dict, image_tensor, src_mask_tensor, tgt_mask_tensor, {
tgt_image_tensor = None
if "tgt-image-data" in form:
if "tgt_image_shape" not in metadata:
raise HTTPException(
status_code=400,
detail="'tgt_image_shape' required in metadata when 'tgt-image-data' is provided",
)
tgt_img_bytes = await _read_form_field_bytes(form, "tgt-image-data")
tgt_img_np = _read_tensor_from_bytes(
tgt_img_bytes, metadata["tgt_image_shape"], "tgt-image-data"
)
tgt_image_tensor = torch.tensor(tgt_img_np, dtype=torch.float32, device=device)

params = {
"num_iter": metadata.get("num_iter", 200),
"rig": metadata.get("rig", 1000),
"lr": metadata.get("lr", 1e-3),
"optimizer_type": metadata.get("optimizer_type", "adam"),
"mse_weight": metadata.get("mse_weight", 1.0),
}
return (
correspondences_dict, image_tensor, src_mask_tensor,
tgt_mask_tensor, tgt_image_tensor, params,
)


def _build_json_response(relaxed_field_np: np.ndarray, warped_image_np: np.ndarray):
Expand Down Expand Up @@ -306,20 +343,36 @@ async def apply_correspondences(request: Request):

content_type = request.headers.get("content-type", "")
if "multipart/form-data" in content_type:
correspondences_dict, image_tensor, src_mask_tensor, tgt_mask_tensor, params = (
await _parse_multipart_request(request, device)
)
(
correspondences_dict, image_tensor, src_mask_tensor,
tgt_mask_tensor, tgt_image_tensor, params,
) = await _parse_multipart_request(request, device)
else:
body = await request.json()
correspondences_dict, image_tensor, src_mask_tensor, tgt_mask_tensor, params = (
_parse_json_request(body, device)
)
(
correspondences_dict, image_tensor, src_mask_tensor,
tgt_mask_tensor, tgt_image_tensor, params,
) = _parse_json_request(body, device)

print(f"[apply_correspondences] tgt_image_tensor is None: "
f"{tgt_image_tensor is None}")
if tgt_image_tensor is not None:
print(f"[apply_correspondences] tgt_image_tensor "
f"shape: {tgt_image_tensor.shape}, "
f"min: {tgt_image_tensor.min().item():.4f}, "
f"max: {tgt_image_tensor.max().item():.4f}")
print(f"[apply_correspondences] image_tensor "
f"shape: {image_tensor.shape}, "
f"min: {image_tensor.min().item():.4f}, "
f"max: {image_tensor.max().item():.4f}")
print(f"[apply_correspondences] params: {params}")

relaxed_field, warped_image = apply_correspondences_to_image(
correspondences_dict=correspondences_dict,
image=image_tensor,
src_mask=src_mask_tensor,
tgt_mask=tgt_mask_tensor,
tgt_image=tgt_image_tensor,
**params,
)

Expand Down
Loading