2 changes: 2 additions & 0 deletions .gitignore
@@ -5,3 +5,5 @@ model_weights
outputs_*
*.egg-info/*
build
matching/third_party/MatchAnything/weights/*.ckpt
matching/third_party/MatchAnything/imcui/third_party/MatchAnything/weights/*.ckpt
3 changes: 3 additions & 0 deletions .gitmodules
@@ -70,3 +70,6 @@
[submodule "matching/third_party/RIPE"]
path = matching/third_party/RIPE
url = https://github.com/fraunhoferhhi/RIPE.git
[submodule "matching/third_party/MatchAnything"]
path = matching/third_party/MatchAnything
url = https://huggingface.co/spaces/LittleFrog/MatchAnything
29 changes: 29 additions & 0 deletions README.md
@@ -138,6 +138,35 @@ As with matching, you can also run extraction from the command line
python main_extractor.py --matcher sift-lg --device cpu --out_dir output_sift-lg --n_kpts 2048
```

### MatchAnything variants (ELoFTR / RoMa)
MatchAnything (HF Space: https://huggingface.co/spaces/LittleFrog/MatchAnything) is tracked as a git submodule at `matching/third_party/MatchAnything` (code lives under `imcui/third_party/MatchAnything`). Init/update it with:
```bash
git submodule update --init --recursive matching/third_party/MatchAnything
```
Download checkpoints (kept out of git) into the nested MatchAnything folder. PowerShell example:
```powershell
cd matching/third_party/MatchAnything/imcui/third_party/MatchAnything
python -m pip install gdown
# Reviewer (repo owner): gdown is already a dependency, so we can move this download into the
# download_weights() fn so that the checkpoints are auto-downloaded at first install
# (a Python sketch of that is shown after this block).
python -m gdown 12L3g9-w8rR9K2L4rYaGaDJ7NqX1D713d --fuzzy -O weights.zip
tar -xf weights.zip # or: unzip weights.zip
# Reviewer (repo owner): unzipping and removing the zip should also be done in download_weights().
Remove-Item weights.zip
# weights/matchanything_eloftr.ckpt and weights/matchanything_roma.ckpt should now exist
```
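
The two reviewer notes above suggest folding the download into the matcher itself. Below is a minimal sketch of what that could look like (editorial, not part of this PR), reusing the Google Drive id from the PowerShell example and the `self.model_path` attribute defined in `matching/im_models/matchanything.py` further down; the method placement and error handling are assumptions.
```python
import zipfile

import gdown  # already a project dependency per the review note


def download_weights(self):
    """Fetch and unpack the MatchAnything checkpoints on first use (sketch)."""
    if self.model_path.is_file():
        return  # checkpoint already present, nothing to do

    weights_dir = self.model_path.parent           # .../MatchAnything/weights
    weights_dir.mkdir(parents=True, exist_ok=True)
    zip_path = weights_dir.parent / "weights.zip"

    # Download the archive from Google Drive (id taken from the example above).
    gdown.download(id="12L3g9-w8rR9K2L4rYaGaDJ7NqX1D713d", output=str(zip_path), quiet=False)

    # Extract weights/*.ckpt next to the archive, then remove the zip.
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(weights_dir.parent)
    zip_path.unlink()

    if not self.model_path.is_file():
        raise FileNotFoundError(f"Downloaded archive did not contain {self.model_path.name}")
```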
Run either variant via:
```bash
# ELoFTR backbone (defaults to 832px NPE size)
python main_matcher.py --matcher matchanything-eloftr --device cuda --im_size 832 --out_dir outputs_matchanything-eloftr

# RoMa backbone (AMP disabled on CPU automatically)
python main_matcher.py --matcher matchanything-roma --device cuda --im_size 832 --out_dir outputs_matchanything-roma
```
Weights should be at `matching/third_party/MatchAnything/imcui/third_party/MatchAnything/weights/matchanything_eloftr.ckpt` and `matching/third_party/MatchAnything/imcui/third_party/MatchAnything/weights/matchanything_roma.ckpt`.
The RoMa variant uses the vendored ROMA package (inside the submodule); if your env cannot import `roma`, install it in editable mode:
```bash
python -m pip install -e matching/third_party/MatchAnything/imcui/third_party/MatchAnything/third_party/ROMA
# Reviewer (repo owner): will this work if you directly install romatch from PyPI?
```
pytorch-lightning 1.4.9 (the MatchAnything dependency) also expects `torchmetrics==0.6.0`, `wandb==0.15.12`, and `pydantic==1.10.x`; these pins are in `requirements.txt`.


## Available Models
You can choose any of the following methods (input to `get_matcher()`):
12 changes: 12 additions & 0 deletions matching/__init__.py
@@ -26,6 +26,8 @@
"se2loftr",
"xoftr",
"aspanformer",
"matchanything-eloftr",
"matchanything-roma",
"matchformer",
"sift-lg",
"superpoint-lg",
@@ -122,6 +124,16 @@ def get_matcher(

return efficient_loftr.EfficientLoFTRMatcher(device, *args, **kwargs)

if matcher_name in ["matchanything-eloftr", "matchanything_eloftr"]:
from matching.im_models import matchanything

return matchanything.MatchAnythingMatcher(device, variant="eloftr", *args, **kwargs)

if matcher_name in ["matchanything-roma", "matchanything_roma"]:
from matching.im_models import matchanything

return matchanything.MatchAnythingMatcher(device, variant="roma", *args, **kwargs)

if matcher_name == "se2loftr":
from matching.im_models import se2loftr

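
For illustration only (not part of the diff): with the entries above registered, the new variants are selected through the same Python API as the other matchers. A minimal sketch, assuming the repository's usual `get_matcher` / `load_image` / `matcher(img0, img1)` interface and result keys, with placeholder image paths:
```python
from matching import get_matcher

# "matchanything-roma" (or the underscore aliases) works the same way.
matcher = get_matcher("matchanything-eloftr", device="cuda")

# Placeholder paths; 832 matches the NPE size used in the README example.
img0 = matcher.load_image("path/to/img0.png", resize=832)
img1 = matcher.load_image("path/to/img1.png", resize=832)

result = matcher(img0, img1)
print(result["num_inliers"])  # assumes the library's usual result dict
```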
224 changes: 224 additions & 0 deletions matching/im_models/matchanything.py
@@ -0,0 +1,224 @@
import numpy as np
from pathlib import Path
from PIL import Image
import cv2
import torch
import torch.nn.functional as F

from matching import BaseMatcher, THIRD_PARTY_DIR
from matching.utils import add_to_path

# Expose the MatchAnything HF Space code (nested under imcui/third_party/MatchAnything) and its deps.
MATCHANYTHING_DIR = THIRD_PARTY_DIR.joinpath("MatchAnything", "imcui", "third_party", "MatchAnything")
add_to_path(MATCHANYTHING_DIR)

from src.lightning.lightning_loftr import PL_LoFTR # noqa: E402
from src.config.default import get_cfg_defaults # noqa: E402


class MatchAnythingMatcher(BaseMatcher):
"""Wrapper around the MatchAnything checkpoints."""

def __init__(
self,
device="cpu",
variant="eloftr",
match_threshold=0.2,
img_resize=None,
*args,
**kwargs,
):
super().__init__(device, **kwargs)

self.variant = variant.lower()
if self.variant not in ("eloftr", "roma"):
raise ValueError(f"Unsupported MatchAnything variant: {variant}")

self.match_threshold = match_threshold
self.img_resize = img_resize

self.model_name = f"matchanything_{self.variant}"
self.model_path = MATCHANYTHING_DIR.joinpath("weights", f"{self.model_name}.ckpt")
self._load_model()

def _load_model(self):
self.model_path.parent.mkdir(parents=True, exist_ok=True)
self.download_weights()

cfg = get_cfg_defaults()
if self.variant == "eloftr":
cfg.merge_from_file(str(MATCHANYTHING_DIR.joinpath("configs", "models", "eloftr_model.py")))
# The LoFTR config expects an NPE tuple; default to 832 if not provided.
if cfg.DATASET.NPE_NAME is not None:
if cfg.DATASET.NPE_NAME == "megadepth":
target_size = self.img_resize or 832
cfg.LOFTR.COARSE.NPE = [832, 832, target_size, target_size]
else:
cfg.merge_from_file(str(MATCHANYTHING_DIR.joinpath("configs", "models", "roma_model.py")))
if self.device == "cpu":
cfg.LOFTR.FP16 = False
cfg.ROMA.MODEL.AMP = False

cfg.METHOD = self.model_name
cfg.LOFTR.MATCH_COARSE.THR = self.match_threshold

self.net = PL_LoFTR(cfg, pretrained_ckpt=self.model_path, test_mode=True).matcher
self.net.eval().to(self.device)

def download_weights(self):
"""Ensure weights exist locally."""
if self.model_path.is_file():
return

raise FileNotFoundError(
f"Missing weights for {self.model_name}. "
f"Place the checkpoint at {self.model_path}"
)

def _preprocess_single(self, img):
img_np = img.cpu().numpy().squeeze() * 255
img_np = img_np.transpose(1, 2, 0).astype("uint8")

img_size = np.array(img_np.shape[:2])
img_gray = np.array(Image.fromarray(img_np).convert("L"))
img_resized, scale_hw, mask = resize(img_gray, df=32)

img_tensor = torch.from_numpy(img_resized)[None][None] / 255.0
return img_tensor, img_size, scale_hw, mask, img

def _forward(self, img0, img1):
img0_proc, img0_size, img0_scale, mask0, img0_orig = self._preprocess_single(img0)
img1_proc, img1_size, img1_scale, mask1, img1_orig = self._preprocess_single(img1)

batch = {
"image0": img0_proc,
"image1": img1_proc,
# ROMA expects a leading batch dim on RGB images; keep it for both variants
"image0_rgb_origin": img0_orig[None],
"image1_rgb_origin": img1_orig[None],
"origin_img_size0": torch.from_numpy(img0_size)[None],
"origin_img_size1": torch.from_numpy(img1_size)[None],
}

if mask0 is not None and mask1 is not None:
mask0_t = torch.from_numpy(mask0).to(self.device)
mask1_t = torch.from_numpy(mask1).to(self.device)
ts_mask_0, ts_mask_1 = F.interpolate(
torch.stack([mask0_t, mask1_t], dim=0)[None].float(),
scale_factor=0.125,
mode="nearest",
recompute_scale_factor=False,
)[0].bool()
batch["mask0"] = ts_mask_0[None]
batch["mask1"] = ts_mask_1[None]

batch = dict_to_device(batch, device=self.device)

self.net(batch)

mkpts0 = batch["mkpts0_f"].detach().cpu()
mkpts1 = batch["mkpts1_f"].detach().cpu()

if self.variant == "eloftr":
mkpts0 *= torch.tensor(img0_scale)[[1, 0]]
mkpts1 *= torch.tensor(img1_scale)[[1, 0]]

return mkpts0, mkpts1, None, None, None, None


def resize(img, resize=None, df=8, padding=True):
    # Reviewer (repo owner): happy to keep this and the other resize components if we need them,
    # but are they significantly different from the resize already supported in load_image,
    # resize_by, and resize_to_divisible? (A worked example of these helpers appears at the end
    # of this file's diff.)
w, h = img.shape[1], img.shape[0]
w_new, h_new = process_resize(w, h, resize=resize, df=df, resize_no_larger_than=False)
img_new = resize_image(img, (w_new, h_new), interp="pil_LANCZOS").astype("float32")
h_scale, w_scale = img.shape[0] / img_new.shape[0], img.shape[1] / img_new.shape[1]
mask = None
if padding:
img_new, mask = pad_bottom_right(img_new, max(h_new, w_new), ret_mask=True)
return img_new, [h_scale, w_scale], mask


def process_resize(w, h, resize=None, df=None, resize_no_larger_than=False):
if resize is not None:
assert len(resize) > 0 and len(resize) <= 2
if resize_no_larger_than and (max(h, w) <= max(resize)):
w_new, h_new = w, h
else:
if len(resize) == 1 and resize[0] > -1: # resize the larger side
scale = resize[0] / max(h, w)
w_new, h_new = int(round(w * scale)), int(round(h * scale))
elif len(resize) == 1 and resize[0] == -1:
w_new, h_new = w, h
else:
w_new, h_new = resize[0], resize[1]
else:
w_new, h_new = w, h

if df is not None:
w_new, h_new = map(lambda x: int(x // df * df), [w_new, h_new])
return w_new, h_new


def resize_image(image, size, interp):
if interp.startswith("cv2_"):
interp = getattr(cv2, "INTER_" + interp[len("cv2_") :].upper())
h, w = image.shape[:2]
if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]):
interp = cv2.INTER_LINEAR
resized = cv2.resize(image, size, interpolation=interp)
elif interp.startswith("pil_"):
interp = getattr(Image, interp[len("pil_") :].upper())
resized = Image.fromarray(image.astype(np.uint8))
resized = resized.resize(size, resample=interp)
resized = np.asarray(resized, dtype=image.dtype)
else:
raise ValueError(f"Unknown interpolation {interp}.")
return resized


def pad_bottom_right(inp, pad_size, ret_mask=False):
assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
mask = None
if inp.ndim == 2:
padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
padded[: inp.shape[0], : inp.shape[1]] = inp
if ret_mask:
mask = np.zeros((pad_size, pad_size), dtype=bool)
mask[: inp.shape[0], : inp.shape[1]] = True
elif inp.ndim == 3:
padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
padded[:, : inp.shape[1], : inp.shape[2]] = inp
if ret_mask:
mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
mask[:, : inp.shape[1], : inp.shape[2]] = True
mask = mask[0] if mask is not None else None
else:
raise NotImplementedError()
return padded, mask


def dict_to_device(data_dict, device="cuda"):
    # Reviewer (repo owner): should be moved to utils.
data_dict_device = {}
for k, v in data_dict.items():
if isinstance(v, torch.Tensor):
data_dict_device[k] = v.to(device)
elif isinstance(v, dict):
data_dict_device[k] = dict_to_device(v, device=device)
elif isinstance(v, list):
data_dict_device[k] = list_to_device(v, device=device)
else:
data_dict_device[k] = v
return data_dict_device


def list_to_device(data_list, device="cuda"):
    # Reviewer (repo owner): should be moved to utils.
data_list_device = []
for obj in data_list:
if isinstance(obj, torch.Tensor):
data_list_device.append(obj.to(device))
elif isinstance(obj, dict):
data_list_device.append(dict_to_device(obj, device=device))
elif isinstance(obj, list):
data_list_device.append(list_to_device(obj, device=device))
else:
data_list_device.append(obj)
return data_list_device
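
A reviewer comment above asks whether these resize helpers duplicate the existing `load_image` / `resize_by` / `resize_to_divisible` utilities. To make the comparison concrete, here is a small worked example (editorial, not part of the PR) of what `resize()` and `pad_bottom_right()` produce for a hypothetical 700x900 grayscale input:
```python
import numpy as np

from matching.im_models.matchanything import resize  # the helper defined above

img = np.zeros((700, 900), dtype=np.uint8)  # hypothetical grayscale image (H=700, W=900)
img_new, (h_scale, w_scale), mask = resize(img, resize=None, df=32, padding=True)

# Each side is rounded down to a multiple of 32 (700 -> 672, 900 -> 896),
# then the image is padded bottom-right to a square of the larger side.
print(img_new.shape)                                # (896, 896)
print(round(h_scale, 3), round(w_scale, 3))         # 1.042 1.004 (original / resized; used to rescale ELoFTR keypoints)
print(mask[:672, :896].all(), mask[672:, :].any())  # True False: mask marks the valid, un-padded region
```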
1 change: 1 addition & 0 deletions matching/third_party/MatchAnything
Submodule MatchAnything added at 48f422
7 changes: 6 additions & 1 deletion requirements.txt
@@ -27,4 +27,9 @@ loguru
timm
omegaconf
poselib
lightning==2.0.0 # from EDM
pytorch-lightning==1.4.9 # MatchAnything dependency
pynvml # used by MatchAnything lightning wrapper
torchmetrics==0.6.0 # lightning 1.4.9 compatibility
pydantic==1.10.14 # lightning 1.4.9 / wandb compatibility
wandb==0.15.12 # lightning 1.4.9 compatibility