Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions web_api/app/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from zetta_utils.internal.alignment.manual_correspondence import (
apply_correspondences_to_image,
)
from zetta_utils.internal.alignment.sift import compute_sift_correspondences

from .utils import generic_exception_handler

Expand Down Expand Up @@ -331,3 +332,124 @@ async def apply_correspondences(request: Request):
return _build_binary_response(relaxed_field_np, warped_image_np, compress)

return _build_json_response(relaxed_field_np, warped_image_np)


class ComputeSiftCorrespondencesRequest(BaseModel):
    # JSON request body for the /compute_sift_correspondences endpoint.
    # Kept as comments rather than a class docstring so the generated schema
    # description of the model is left exactly as it was.

    # --- Images: base64 payloads plus the shapes used to reshape them ---
    src_image: str = Field(
        ...,
        description="Base64-encoded uint8 image bytes",
    )
    tgt_image: str = Field(
        ...,
        description="Base64-encoded uint8 image bytes",
    )
    src_image_shape: list[int] = Field(
        ...,
        description="Image shape [H, W]",
    )
    tgt_image_shape: list[int] = Field(
        ...,
        description="Image shape [H, W]",
    )

    # --- Tunables forwarded as keyword arguments to the SIFT routine ---
    num_correspondences: int = Field(
        default=200,
        description="Number of output correspondences",
    )
    num_octaves: int = Field(
        default=3,
        description="Number of octaves for SIFT",
    )
    contrast_threshold: float = Field(
        default=0.04,
        description="Contrast threshold for SIFT",
    )
    edge_threshold: float = Field(
        default=10,
        description="Edge threshold for SIFT",
    )
    sigma: float = Field(
        default=1.6,
        description="Sigma for SIFT",
    )
    ratio_test_fraction: float = Field(
        default=0.7,
        description="Lowe's ratio test fraction",
    )
    ransac_threshold: float = Field(
        default=3.0,
        description="RANSAC reprojection threshold",
    )
    spatial_weight: float = Field(
        default=0.7,
        description="Weight for spatial diversity vs match quality",
    )

    # --- Output coordinate ordering ---
    swap_xy: bool = Field(
        default=True,
        description="If true, output [y, x] (Portal convention). If false, output [x, y].",
    )


class ComputeSiftCorrespondencesResponse(BaseModel):
    # Response payload of the /compute_sift_correspondences endpoint.
    # All three fields are required; the endpoint fills them from the
    # "lines", "num_inliers" and "num_matches" entries of the SIFT result.
    lines: list[CorrespondenceLine] = Field(
        ...,
        description="Correspondence lines",
    )
    num_inliers: int = Field(
        ...,
        description="Number of RANSAC inliers",
    )
    num_matches: int = Field(
        ...,
        description="Number of good matches before RANSAC",
    )


def _decode_base64_image_field(encoded: str, field_name: str) -> bytes:
    """Base64-decode a request field, reporting malformed input as HTTP 400."""
    try:
        return base64.b64decode(encoded)
    except ValueError as e:
        # binascii.Error (e.g. incorrect padding) is a ValueError subclass.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid base64 data in '{field_name}': {e}",
        ) from e


def _uint8_image_from_bytes(raw: bytes, shape, field_name: str):
    """View ``raw`` as a uint8 array of ``shape``, reporting mismatches as HTTP 400."""
    try:
        return np.frombuffer(raw, dtype=np.uint8).reshape(shape)
    except (TypeError, ValueError) as e:
        # ValueError: byte count does not match the requested shape.
        # TypeError: the shape (or buffer) is not of a usable type.
        raise HTTPException(
            status_code=400,
            detail=(
                f"Cannot interpret '{field_name}' as a uint8 image "
                f"of shape {shape}: {e}"
            ),
        ) from e


@api.post("/compute_sift_correspondences")
async def compute_sift_correspondences_endpoint(request: Request):
    """Compute SIFT correspondences between two images.

    Supports two request formats:
    - application/json: JSON body with base64-encoded image bytes
    - multipart/form-data: Binary image data with JSON metadata

    For multipart, send:
    - "src-image-data": binary file upload (uint8 image bytes)
    - "tgt-image-data": binary file upload (uint8 image bytes)
    - "metadata": JSON string with src_image_shape, tgt_image_shape, and SIFT params

    Returns:
        ComputeSiftCorrespondencesResponse built from the "lines",
        "num_inliers" and "num_matches" entries of the SIFT result.

    Raises:
        HTTPException(400): a required multipart field or metadata key is
            missing, "metadata" / the JSON body is not valid JSON, "metadata"
            is not a JSON object, an image is not valid base64, or the image
            bytes cannot be reshaped to the declared shape.
    """
    content_type = request.headers.get("content-type", "")

    # Fallback values for the multipart path, which reads raw metadata rather
    # than going through ComputeSiftCorrespondencesRequest. They mirror that
    # model's defaults and must be kept in sync with it. The dict's key order
    # also defines which parameters are forwarded to the SIFT routine.
    sift_param_defaults = {
        "num_correspondences": 200,
        "num_octaves": 3,
        "contrast_threshold": 0.04,
        "edge_threshold": 10,
        "sigma": 1.6,
        "ratio_test_fraction": 0.7,
        "ransac_threshold": 3.0,
        "spatial_weight": 0.7,
        "swap_xy": True,
    }

    if "multipart/form-data" in content_type:
        form = await request.form()

        # Checked in the same order as before so the first missing field
        # reported to the client is unchanged.
        for required_field in ("metadata", "src-image-data", "tgt-image-data"):
            if required_field not in form:
                raise HTTPException(
                    status_code=400,
                    detail=f"Missing required field: '{required_field}'",
                )

        metadata_str = await _read_form_field_str(form, "metadata")
        try:
            metadata = json.loads(metadata_str)
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid JSON in 'metadata': {e}"
            ) from e

        # json.loads accepts any JSON value; everything below needs a mapping.
        if not isinstance(metadata, dict):
            raise HTTPException(
                status_code=400, detail="'metadata' must be a JSON object"
            )

        missing = [
            k for k in ("src_image_shape", "tgt_image_shape") if k not in metadata
        ]
        if missing:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required metadata fields: {missing}",
            )

        src_bytes = await _read_form_field_bytes(form, "src-image-data")
        tgt_bytes = await _read_form_field_bytes(form, "tgt-image-data")

        src_image = _uint8_image_from_bytes(
            src_bytes, metadata["src_image_shape"], "src-image-data"
        )
        tgt_image = _uint8_image_from_bytes(
            tgt_bytes, metadata["tgt_image_shape"], "tgt-image-data"
        )

        sift_params = {
            key: metadata.get(key, default)
            for key, default in sift_param_defaults.items()
        }
    else:
        try:
            body = await request.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=400, detail=f"Invalid JSON body: {e}"
            ) from e

        # NOTE(review): a pydantic validation failure raised here is not
        # translated into an HTTPException by this function — confirm an
        # app-level exception handler covers it.
        req = ComputeSiftCorrespondencesRequest(**body)

        src_image = _uint8_image_from_bytes(
            _decode_base64_image_field(req.src_image, "src_image"),
            req.src_image_shape,
            "src_image",
        )
        tgt_image = _uint8_image_from_bytes(
            _decode_base64_image_field(req.tgt_image, "tgt_image"),
            req.tgt_image_shape,
            "tgt_image",
        )

        sift_params = {key: getattr(req, key) for key in sift_param_defaults}

    result = compute_sift_correspondences(
        src=src_image, tgt=tgt_image, **sift_params
    )

    return ComputeSiftCorrespondencesResponse(
        lines=[CorrespondenceLine(**line) for line in result["lines"]],
        num_inliers=result["num_inliers"],
        num_matches=result["num_matches"],
    )
8 changes: 5 additions & 3 deletions zetta_utils/api/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,10 @@
get_aced_match_offsets_naive,
perform_aced_relaxation,
)
from zetta_utils.internal.alignment.base_coarsener import BaseCoarsener
from zetta_utils.internal.alignment.base_encoder import BaseEncoder
from zetta_utils.internal.alignment.encoding_coarsener import EncodingCoarsener
from zetta_utils.internal.alignment.deprecated.base_encoder import BaseEncoder
from zetta_utils.internal.alignment.deprecated.encoding_coarsener import (
EncodingCoarsener,
)
from zetta_utils.internal.alignment.field import (
gen_biased_perlin_noise_field,
get_rigidity_map,
Expand All @@ -130,6 +131,7 @@
percentile,
profile_field2d_percentile,
)
from zetta_utils.internal.alignment.image_encoder import ImageEncoder as BaseCoarsener
from zetta_utils.internal.alignment.misalignment_detector import (
MisalignmentDetector,
naive_misd,
Expand Down
2 changes: 1 addition & 1 deletion zetta_utils/internal
Loading