diff --git a/web_api/app/alignment.py b/web_api/app/alignment.py
index fa72588eb..15f5454e8 100644
--- a/web_api/app/alignment.py
+++ b/web_api/app/alignment.py
@@ -68,6 +68,16 @@ class ApplyCorrespondencesRequest(BaseModel):
         description="Mask interpretation: 'tissue' (1=tissue, 0=non-tissue) or "
         "'defect' (1=defect, 0=non-defect). Defect masks are inverted internally.",
     )
+    tgt_image: list[list[list[list[float]]]] | None = Field(
+        None,
+        description="Target section image tensor (C, H, W, 1). When provided with "
+        "mse_weight > 0, enables image-based MSE optimization.",
+    )
+    mse_weight: float = Field(
+        0.0,
+        description="Weight for MSE loss between warped source and target images. "
+        "Only used when tgt_image is provided.",
+    )
 
 
 class ApplyCorrespondencesResponse(BaseModel):
@@ -105,6 +115,7 @@ def _parse_json_request(body: dict, device: torch.device):
 
     src_mask_tensor = None
     tgt_mask_tensor = None
+    tgt_image_tensor = None
 
     if req.src_mask is not None:
         src_mask_tensor = torch.tensor(req.src_mask, dtype=torch.float32, device=device)
@@ -116,12 +127,20 @@
     if req.mask_type == "defect":
         tgt_mask_tensor = 1.0 - tgt_mask_tensor
 
-    return correspondences_dict, image_tensor, src_mask_tensor, tgt_mask_tensor, {
+    if req.tgt_image is not None:
+        tgt_image_tensor = torch.tensor(req.tgt_image, dtype=torch.float32, device=device)
+
+    params = {
         "num_iter": req.num_iter,
         "rig": req.rig,
         "lr": req.lr,
         "optimizer_type": req.optimizer_type,
+        "mse_weight": req.mse_weight,
     }
+    return (
+        correspondences_dict, image_tensor, src_mask_tensor,
+        tgt_mask_tensor, tgt_image_tensor, params,
+    )
 
 
 async def _read_form_field_bytes(form, field_name: str) -> bytes:
@@ -219,12 +238,30 @@
     if mask_type == "defect":
         tgt_mask_tensor = 1.0 - tgt_mask_tensor
 
-    return correspondences_dict, image_tensor, src_mask_tensor, tgt_mask_tensor, {
+    tgt_image_tensor = None
+    if "tgt-image-data" in form:
+        if "tgt_image_shape" not in metadata:
+            raise HTTPException(
+                status_code=400,
+                detail="'tgt_image_shape' required in metadata when 'tgt-image-data' is provided",
+            )
+        tgt_img_bytes = await _read_form_field_bytes(form, "tgt-image-data")
+        tgt_img_np = _read_tensor_from_bytes(
+            tgt_img_bytes, metadata["tgt_image_shape"], "tgt-image-data"
+        )
+        tgt_image_tensor = torch.tensor(tgt_img_np, dtype=torch.float32, device=device)
+
+    params = {
         "num_iter": metadata.get("num_iter", 200),
         "rig": metadata.get("rig", 1000),
         "lr": metadata.get("lr", 1e-3),
         "optimizer_type": metadata.get("optimizer_type", "adam"),
+        "mse_weight": metadata.get("mse_weight", 0.0),
     }
+    return (
+        correspondences_dict, image_tensor, src_mask_tensor,
+        tgt_mask_tensor, tgt_image_tensor, params,
+    )
 
 
 def _build_json_response(relaxed_field_np: np.ndarray, warped_image_np: np.ndarray):
@@ -306,20 +343,23 @@
     content_type = request.headers.get("content-type", "")
 
     if "multipart/form-data" in content_type:
-        correspondences_dict, image_tensor, src_mask_tensor, tgt_mask_tensor, params = (
-            await _parse_multipart_request(request, device)
-        )
+        (
+            correspondences_dict, image_tensor, src_mask_tensor,
+            tgt_mask_tensor, tgt_image_tensor, params,
+        ) = await _parse_multipart_request(request, device)
     else:
         body = await request.json()
-        correspondences_dict, image_tensor, src_mask_tensor, tgt_mask_tensor, params = (
-            _parse_json_request(body, device)
-        )
+        (
+            correspondences_dict, image_tensor, src_mask_tensor,
+            tgt_mask_tensor, tgt_image_tensor, params,
+        ) = _parse_json_request(body, device)
 
     relaxed_field, warped_image = apply_correspondences_to_image(
         correspondences_dict=correspondences_dict,
         image=image_tensor,
         src_mask=src_mask_tensor,
         tgt_mask=tgt_mask_tensor,
+        tgt_image=tgt_image_tensor,
         **params,
     )
 