Skip to content
Merged
3 changes: 2 additions & 1 deletion models/vista2d/configs/metadata.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
"version": "0.2.2",
"version": "0.2.3",
"changelog": {
"0.2.3": "update weights link",
"0.2.2": "update to use monai components",
"0.2.1": "initial OSS version"
},
Expand Down
2 changes: 1 addition & 1 deletion models/vista2d/large_files.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ large_files:
- path: "models/sam_vit_b_01ec64.pth"
url: "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
- path: "models/model.pt"
url: "https://drive.google.com/file/d/1odLhoOtlxxbEyRq-gvenP8bC0-mw63ng/view?usp=drive_link"
url: "https://github.com/Project-MONAI/model-zoo/releases/download/model_zoo_bundle_data/vista2d_model.pt"
- path: "datalists.zip"
url: "https://github.com/Project-MONAI/model-zoo/releases/download/model_zoo_bundle_data/vista2d_datalists.zip"
6 changes: 2 additions & 4 deletions models/vista3d/configs/inference.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"$import os",
"$import scripts",
"$import numpy as np",
"$import copy",
"$import json"
],
"bundle_root": "./",
Expand Down Expand Up @@ -47,7 +48,6 @@
128
],
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"use_cfp": false,
"use_point_window": true,
"network_def": "$monai.networks.nets.vista3d132(in_channels=@input_channels)",
"network": "$@network_def.to(@device)",
Expand Down Expand Up @@ -127,7 +127,6 @@
"roi_size": "@patch_size",
"overlap": 0.5,
"sw_batch_size": "@sw_batch_size",
"use_cfp": "@use_cfp",
"use_point_window": "@use_point_window"
},
"postprocessing": {
Expand All @@ -146,7 +145,7 @@
{
"_target_": "Invertd",
"keys": "pred",
"transform": "@preprocessing",
"transform": "$copy.deepcopy(@preprocessing)",
"orig_keys": "@image_key",
"nearest_interp": true,
"to_tensor": true
Expand Down Expand Up @@ -192,7 +191,6 @@
"val_handlers": "@handlers",
"amp": true,
"hyper_kwargs": {
"use_cfp": "@use_cfp",
"user_prompt": true,
"everything_labels": "@everything_labels"
}
Expand Down
3 changes: 2 additions & 1 deletion models/vista3d/configs/metadata.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
"version": "0.4.2",
"version": "0.4.3",
"changelog": {
"0.4.3": "fix CL and batch infer issues",
"0.4.2": "use MONAI components for network and utils",
"0.4.1": "initial OSS version"
},
Expand Down
13 changes: 5 additions & 8 deletions models/vista3d/configs/train.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"early_stop": false,
"fold": 0,
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"epochs": 100,
"epochs": 5,
"val_interval": 1,
"val_at_start": false,
"sw_overlap": 0.625,
Expand All @@ -28,10 +28,9 @@
"max_prompt": null,
"max_backprompt": null,
"max_foreprompt": null,
"drop_label_prob": 0.5,
"drop_point_prob": 0.5,
"drop_label_prob": 0.25,
"drop_point_prob": 0.25,
"exclude_background": true,
"use_cfp": true,
"label_set": null,
"val_label_set": "@label_set",
"amp": true,
Expand Down Expand Up @@ -277,7 +276,6 @@
"drop_label_prob": "@drop_label_prob",
"drop_point_prob": "@drop_point_prob",
"exclude_background": "@exclude_background",
"use_cfp": "@use_cfp",
"label_set": "@label_set",
"patch_size": "@patch_size",
"user_prompt": false
Expand Down Expand Up @@ -314,8 +312,7 @@
"inferer": {
"_target_": "scripts.inferer.Vista3dInferer",
"roi_size": "@patch_size_valid",
"overlap": "@sw_overlap",
"use_cfp": "@use_cfp"
"overlap": "@sw_overlap"
},
"handlers": [
{
Expand Down Expand Up @@ -377,8 +374,8 @@
"drop_label_prob": "@drop_label_prob",
"drop_point_prob": "@drop_point_prob",
"exclude_background": "@exclude_background",
"use_cfp": "@use_cfp",
"label_set": "@label_set",
"val_head": "auto",
"user_prompt": false
}
}
Expand Down
20 changes: 9 additions & 11 deletions models/vista3d/configs/train_continual.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"finetune_model_path": "$@bundle_root + '/models/model.pt'",
"n_train_samples": 10,
"n_val_samples": 10,
"val_interval": 40,
"learning_rate": 0.0001,
"val_interval": 1,
"learning_rate": 5e-05,
"lr_schedule#activate": false,
"loss#smooth_dr": 0.01,
"loss#smooth_nr": 0.0001,
Expand All @@ -18,18 +18,14 @@
"default": [
[
1,
2
],
[
2,
254
3
]
]
},
"patch_size": [
160,
160,
160
128,
128,
128
],
"label_set": "$[0] + list(x[1] for x in @label_mappings#default)",
"val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)",
Expand Down Expand Up @@ -99,11 +95,13 @@
"num_workers": "@num_cache_workers",
"progress": "@show_cache_progress"
},
"validate#evaluator#hyper_kwargs#val_label_set": "$list(range(len(@val_label_set)))",
"validate#preprocessing#transforms": "$@train#deterministic_transforms + [@valid_remap]",
"valid_remap": {
"_target_": "monai.apps.vista3d.transforms.Relabeld",
"keys": "label",
"label_mappings": "${'default': [[c, i] for i, c in enumerate(@val_label_set)]}",
"dtype": "$torch.uint8"
}
},
"validate#handlers#3#key_metric_filename": "model_finetune.pt"
}
26 changes: 19 additions & 7 deletions models/vista3d/docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,9 @@ torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run --config
#### Execute continual learning
When finetuning with new class names, please update `configs/train_continual.json`'s `label_mappings` accordingly.

The current label mapping `[[1, 2], [2, 254]]` indicates that training labels' class indices `1` and `2`, are mapped
to the VISTA model's class `2` and `254` respectively (format `[[src_class_0, dst_class_0], [src_class_1, dst_class_1], ...]`).
Since `254` is not used by VISTA, it is therefore indicating
training with a new class (the training label's class `2` will be trained as VISTA class `254`).
The current label mapping `[[1, 3]]` indicates that training labels' class indices `1` is mapped
to the VISTA model's class `3` (format `[[src_class_0, dst_class_0], [src_class_1, dst_class_1], ...]`). For new classes, user
can map to any value larger than 132.

`label_set` is used to identify the VISTA model classes for providing training prompts.
`val_label_set` is used to identify the original training label classes for computing foreground/background mask during validation.
Expand All @@ -103,7 +102,10 @@ The default configs for both variables are derived from the `label_mappings` con
"label_set": "$[0] + list(x[1] for x in @label_mappings#default)"
"val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)"
```

`drop_label_prob` and `drop_point_prob` means percentage to remove class prompts and point prompts respectively. If `drop_point_prob=1`, the
model is only finetuning for automatic segmentation, while `drop_label_prob=1` means only finetuning for interactive segmentation. The VISTA3D foundation
model is trained with interactive only (drop_label_prob=1) and then froze the point branch and trained with fully automatic segmentation (`drop_point_prob=1`).
In this bundle, the training is simplified by jointly training with class prompts and point prompts.

Single-GPU:
```
Expand All @@ -117,11 +119,21 @@ torchrun --nnodes=1 --nproc_per_node=8 -m monai.bundle run \
--config_file="['configs/train.json','configs/train_continual.json','configs/multi_gpu_train.json']" --epochs=320 --learning_rate=0.005
```

The patch size parameter is defined in `configs/train_continual.json`: `"patch_size": [160, 160, 160]`, and this works for the use cases
The patch size parameter is defined in `configs/train_continual.json`: `"patch_size": [128, 128, 128]`, and this works for the use cases
of extending the current model to segment a few novel classes. Finetuning all supported classes may require large GPU memory and carefully designed
multi-stage training processes.

Changing `patch_size` to a smaller value such as `"patch_size": [128, 128, 128]` used in `configs/train.json` would reduce the training memory footprint.
Changing `patch_size` to a smaller value such as `"patch_size": [96, 96, 96]` used in `configs/train.json` would reduce the training memory footprint.

In `train_continual.json`, only subset of training and validation data are used, change `n_train_samples` and `n_val_samples` to use full dataset.

In `train.json`, `validate[evaluator][val_head]` can be `auto` and `point`. If `auto`, the validation results will be automatic segmentation. If `point`,
the validation results will be sampling one positive point per object per patch. The validation scheme of combining auto and point is deprecated due to
speed issue.

Note: `valid_remap` is a transform that maps the groundtruth label indexes, e.g. [0,2,3,5,6] to sequential and continuous labels [0,1,2,3,4]. This is
required by monai dice calculation. It is not related to mapping label index to VISTA3D defined global class index. The validation data is not mapped
to the VISTA3D global class index.

#### Execute evaluation
`n_train_samples` and `n_val_samples` are used to specify the number of samples to use for training and validation respectively.
Expand Down
18 changes: 10 additions & 8 deletions models/vista3d/scripts/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from monai.engines.evaluator import SupervisedEvaluator
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
from monai.transforms import Transform, reset_ops_id
from monai.utils import ForwardMode, RankFilter, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -207,6 +207,8 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
label_set = engine.hyper_kwargs.get("label_set", None)
# this validation label set should be consistent with 'labels.unique()', used to generate fg/bg points
val_label_set = engine.hyper_kwargs.get("val_label_set", label_set)
# If user provide prompts in the inference, input image must contain original affine.
# the point coordinates are from the original_affine space, while image here is after preprocess transforms.
if engine.hyper_kwargs["user_prompt"]:
Expand Down Expand Up @@ -242,18 +244,17 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
output_classes = engine.hyper_kwargs["output_classes"]
label_set = np.arange(output_classes).tolist()
label_prompt = torch.tensor(label_set).to(engine.state.device).unsqueeze(-1)
# point prompt is generated withing vista3d,provide empty points
# point prompt is generated withing vista3d, provide empty points
points = torch.zeros(label_prompt.shape[0], 1, 3).to(inputs.device)
point_labels = -1 + torch.zeros(label_prompt.shape[0], 1).to(inputs.device)
if engine.hyper_kwargs["drop_point_prob"] > 0.99:
# validation for either auto or point.
if engine.hyper_kwargs.get("val_head", "auto") == "auto":
# automatic only validation
points = None
point_labels = None
if engine.hyper_kwargs["drop_label_prob"] > 0.99:
# remove val_label_set, vista3d will not sample points from gt labels.
val_label_set = None
else:
# point only validation
label_prompt = None
# this validation label set should be consistent with 'labels.unique()', used to generate fg/bg points
val_label_set = engine.hyper_kwargs.get("val_label_set", label_set)

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: labels}
Expand All @@ -280,6 +281,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
labels=labels,
label_set=val_label_set,
)
inputs = reset_ops_id(inputs)
# Add dim 0 for decollate batch
engine.state.output["label_prompt"] = label_prompt.unsqueeze(0) if label_prompt is not None else None
engine.state.output["points"] = points.unsqueeze(0) if points is not None else None
Expand Down
8 changes: 3 additions & 5 deletions models/vista3d/scripts/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@ class Vista3dInferer(Inferer):
Args:
roi_size: the sliding window patch size.
overlap: sliding window overlap ratio.
use_cfp: use class prompt for point head.
"""

def __init__(self, roi_size, overlap, use_cfp, use_point_window=False, sw_batch_size=1) -> None:
def __init__(self, roi_size, overlap, use_point_window=False, sw_batch_size=1) -> None:
Inferer.__init__(self)
self.roi_size = roi_size
self.overlap = overlap
self.use_cfp = use_cfp
self.sw_batch_size = sw_batch_size
self.use_point_window = use_point_window
self.sliding_window_inferer = point_based_window_inferer if use_point_window else sliding_window_inference
Expand Down Expand Up @@ -91,6 +89,7 @@ def __call__(
roi_size=self.roi_size,
sw_batch_size=self.sw_batch_size,
transpose=True,
with_coord=True,
predictor=network,
mode="gaussian",
sw_device=device,
Expand All @@ -103,7 +102,6 @@ def __call__(
prev_mask=prev_mask,
labels=labels,
label_set=label_set,
use_cfp=self.use_cfp,
)
except Exception:
val_outputs = None
Expand All @@ -113,6 +111,7 @@ def __call__(
roi_size=self.roi_size,
sw_batch_size=self.sw_batch_size,
transpose=True,
with_coord=True,
predictor=network,
mode="gaussian",
sw_device=device,
Expand All @@ -125,6 +124,5 @@ def __call__(
prev_mask=prev_mask,
labels=labels,
label_set=label_set,
use_cfp=self.use_cfp,
)
return val_outputs
6 changes: 1 addition & 5 deletions models/vista3d/scripts/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,7 @@ def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):

def _compute_pred_loss():
outputs = engine.network(
input_images=inputs,
point_coords=point,
point_labels=point_label,
class_vector=label_prompt,
use_cfp=engine.hyper_kwargs["use_cfp"],
input_images=inputs, point_coords=point, point_labels=point_label, class_vector=label_prompt
)
# engine.state.output[Keys.PRED] = outputs
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
Expand Down
Loading