From 76e9d5308f15de4ebf7111d8a417c74290dcfa36 Mon Sep 17 00:00:00 2001
From: Sergiy Popovich
Date: Mon, 9 Mar 2026 22:58:03 -0400
Subject: [PATCH] Add SIFT correspondence generation endpoint

---
 web_api/app/alignment.py | 163 +++++++++++++++++++++++++++++++++++++++
 zetta_utils/api/v0.py    |   8 ++-
 zetta_utils/internal     |   2 +-
 3 files changed, 169 insertions(+), 4 deletions(-)

diff --git a/web_api/app/alignment.py b/web_api/app/alignment.py
index fa72588eb..c1858c7a1 100644
--- a/web_api/app/alignment.py
+++ b/web_api/app/alignment.py
@@ -14,6 +14,7 @@
 from zetta_utils.internal.alignment.manual_correspondence import (
     apply_correspondences_to_image,
 )
+from zetta_utils.internal.alignment.sift import compute_sift_correspondences
 
 from .utils import generic_exception_handler
 
@@ -331,3 +332,165 @@ async def apply_correspondences(request: Request):
         return _build_binary_response(relaxed_field_np, warped_image_np, compress)
 
     return _build_json_response(relaxed_field_np, warped_image_np)
+
+
+class SiftParams(BaseModel):
+    """Tunable SIFT matching parameters, shared by the JSON and multipart formats."""
+
+    num_correspondences: int = Field(200, description="Number of output correspondences")
+    num_octaves: int = Field(3, description="Number of octaves for SIFT")
+    contrast_threshold: float = Field(0.04, description="Contrast threshold for SIFT")
+    edge_threshold: float = Field(10, description="Edge threshold for SIFT")
+    sigma: float = Field(1.6, description="Sigma for SIFT")
+    ratio_test_fraction: float = Field(0.7, description="Lowe's ratio test fraction")
+    ransac_threshold: float = Field(3.0, description="RANSAC reprojection threshold")
+    spatial_weight: float = Field(
+        0.7,
+        description="Weight for spatial diversity vs match quality",
+    )
+    swap_xy: bool = Field(
+        True,
+        description="If true, output [y, x] (Portal convention). If false, output [x, y].",
+    )
+
+
+class ComputeSiftCorrespondencesRequest(SiftParams):
+    """JSON request body: base64-encoded images plus the inherited SIFT parameters."""
+
+    src_image: str = Field(..., description="Base64-encoded uint8 image bytes")
+    tgt_image: str = Field(..., description="Base64-encoded uint8 image bytes")
+    src_image_shape: list[int] = Field(..., description="Image shape [H, W]")
+    tgt_image_shape: list[int] = Field(..., description="Image shape [H, W]")
+
+
+class ComputeSiftCorrespondencesResponse(BaseModel):
+    """Response payload of the SIFT correspondence endpoint."""
+
+    lines: list[CorrespondenceLine] = Field(..., description="Correspondence lines")
+    num_inliers: int = Field(..., description="Number of RANSAC inliers")
+    num_matches: int = Field(..., description="Number of good matches before RANSAC")
+
+
+# Keyword arguments forwarded to compute_sift_correspondences (all are SiftParams fields).
+_SIFT_PARAM_KEYS = (
+    "num_correspondences",
+    "num_octaves",
+    "contrast_threshold",
+    "edge_threshold",
+    "sigma",
+    "ratio_test_fraction",
+    "ransac_threshold",
+    "spatial_weight",
+    "swap_xy",
+)
+
+
+def _decode_uint8_image(raw: bytes, shape: list[int], name: str) -> np.ndarray:
+    """View raw bytes as a uint8 array reshaped to ``shape``.
+
+    Raises:
+        HTTPException: 400 if the bytes cannot be reshaped to ``shape``.
+    """
+    try:
+        return np.frombuffer(raw, dtype=np.uint8).reshape(shape)
+    except (TypeError, ValueError) as e:
+        raise HTTPException(
+            status_code=400,
+            detail=f"'{name}' ({len(raw)} bytes) cannot be reshaped to {shape}: {e}",
+        ) from e
+
+
+@api.post("/compute_sift_correspondences")
+async def compute_sift_correspondences_endpoint(request: Request):
+    """Compute SIFT correspondences between two images.
+
+    Supports two request formats:
+    - application/json: JSON body with base64-encoded image bytes
+    - multipart/form-data: Binary image data with JSON metadata
+
+    For multipart, send:
+    - "src-image-data": binary file upload (uint8 image bytes)
+    - "tgt-image-data": binary file upload (uint8 image bytes)
+    - "metadata": JSON string with src_image_shape, tgt_image_shape, and SIFT params
+
+    Malformed input (missing fields, invalid JSON or base64, invalid parameter
+    values, image bytes that do not match the declared shape) is rejected with
+    HTTP 400 in both formats.
+    """
+    content_type = request.headers.get("content-type", "")
+
+    if "multipart/form-data" in content_type:
+        form = await request.form()
+
+        for field in ("metadata", "src-image-data", "tgt-image-data"):
+            if field not in form:
+                raise HTTPException(
+                    status_code=400, detail=f"Missing required field: '{field}'"
+                )
+
+        metadata_str = await _read_form_field_str(form, "metadata")
+        try:
+            metadata = json.loads(metadata_str)
+        except json.JSONDecodeError as e:
+            raise HTTPException(
+                status_code=400, detail=f"Invalid JSON in 'metadata': {e}"
+            ) from e
+        if not isinstance(metadata, dict):
+            raise HTTPException(
+                status_code=400, detail="'metadata' must be a JSON object"
+            )
+
+        missing = [
+            k for k in ("src_image_shape", "tgt_image_shape") if k not in metadata
+        ]
+        if missing:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Missing required metadata fields: {missing}",
+            )
+
+        # Validate through the same model as the JSON format so that defaults
+        # and type checks cannot drift between the two request formats.
+        try:
+            params = SiftParams(
+                **{k: metadata[k] for k in _SIFT_PARAM_KEYS if k in metadata}
+            )
+        except ValueError as e:  # pydantic's ValidationError is a ValueError
+            raise HTTPException(
+                status_code=400,
+                detail=f"Invalid SIFT parameters in 'metadata': {e}",
+            ) from e
+
+        src_bytes = await _read_form_field_bytes(form, "src-image-data")
+        tgt_bytes = await _read_form_field_bytes(form, "tgt-image-data")
+        src_shape = metadata["src_image_shape"]
+        tgt_shape = metadata["tgt_image_shape"]
+    else:
+        # Invalid JSON, pydantic validation failures and bad base64 all raise
+        # ValueError subclasses; a body that is not a JSON object raises TypeError.
+        try:
+            req = ComputeSiftCorrespondencesRequest(**(await request.json()))
+            src_bytes = base64.b64decode(req.src_image)
+            tgt_bytes = base64.b64decode(req.tgt_image)
+        except (TypeError, ValueError) as e:
+            raise HTTPException(
+                status_code=400, detail=f"Invalid request body: {e}"
+            ) from e
+        params = req
+        src_shape = req.src_image_shape
+        tgt_shape = req.tgt_image_shape
+
+    src_image = _decode_uint8_image(src_bytes, src_shape, "src image")
+    tgt_image = _decode_uint8_image(tgt_bytes, tgt_shape, "tgt image")
+
+    result = compute_sift_correspondences(
+        src=src_image,
+        tgt=tgt_image,
+        **{k: getattr(params, k) for k in _SIFT_PARAM_KEYS},
+    )
+
+    return ComputeSiftCorrespondencesResponse(
+        lines=[CorrespondenceLine(**line) for line in result["lines"]],
+        num_inliers=result["num_inliers"],
+        num_matches=result["num_matches"],
+    )
diff --git a/zetta_utils/api/v0.py b/zetta_utils/api/v0.py
index 0c9a1d6ad..8f207ddbb 100644
--- a/zetta_utils/api/v0.py
+++ b/zetta_utils/api/v0.py
@@ -118,9 +118,10 @@
     get_aced_match_offsets_naive,
     perform_aced_relaxation,
 )
-from zetta_utils.internal.alignment.base_coarsener import BaseCoarsener
-from zetta_utils.internal.alignment.base_encoder import BaseEncoder
-from zetta_utils.internal.alignment.encoding_coarsener import EncodingCoarsener
+from zetta_utils.internal.alignment.deprecated.base_encoder import BaseEncoder
+from zetta_utils.internal.alignment.deprecated.encoding_coarsener import (
+    EncodingCoarsener,
+)
 from zetta_utils.internal.alignment.field import (
     gen_biased_perlin_noise_field,
     get_rigidity_map,
@@ -130,6 +131,7 @@
     percentile,
     profile_field2d_percentile,
 )
+from zetta_utils.internal.alignment.image_encoder import ImageEncoder as BaseCoarsener
 from zetta_utils.internal.alignment.misalignment_detector import (
     MisalignmentDetector,
     naive_misd,
diff --git a/zetta_utils/internal b/zetta_utils/internal
index eebc37899..362d5323e 160000
--- a/zetta_utils/internal
+++ b/zetta_utils/internal
@@ -1 +1 @@
-Subproject commit eebc3789986781c5300555a31ef1df6894388287
+Subproject commit 362d5323ee9f86e63ab925eda155547f12715773