Skip to content

Commit d65df86

Browse files
Can-Zhaopre-commit-ci[bot]binliunls
authored
New Accelerated MAISI, inference only (#726)
Fixes # . ### Description New Accelerated MAISI, inference only ### Status **Ready/Work in progress/Hold** ### Please ensure all the checkboxes: <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Codeformat tests passed locally by running `./runtests.sh --codeformat`. - [ ] In-line docstrings updated. - [x] Update `version` and `changelog` in `metadata.json` if changing an existing bundle. - [x] Please ensure the naming rules in config files meet our requirements (please refer to: `CONTRIBUTING.md`). - [ ] Ensure versions of packages such as `monai`, `pytorch` and `numpy` are correct in `metadata.json`. - [ ] Descriptions should be consistent with the content, such as `eval_metrics` of the provided weights and TorchScript modules. - [ ] Files larger than 25MB are excluded and replaced by providing download links in `large_file.yml`. - [ ] Avoid using path that contains personal information within config files (such as use `/home/your_name/` for `"bundle_root"`). --------- Signed-off-by: Can Zhao <[email protected]> Signed-off-by: Can-Zhao <[email protected]> Signed-off-by: binliu <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: binliu <[email protected]> Co-authored-by: binliunls <[email protected]>
1 parent 170be7d commit d65df86

File tree

11 files changed

+420
-244
lines changed

11 files changed

+420
-244
lines changed

ci/unit_tests/test_maisi_ct_generative.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,6 @@
8585
}
8686
]
8787

88-
TEST_CASE_INFER_ERROR = [
89-
{
90-
"bundle_root": "models/maisi_ct_generative",
91-
"num_output_samples": 1,
92-
"output_size": [256, 256, 256],
93-
"body_region": ["head"],
94-
"anatomy_list": ["colon cancer primaries"],
95-
},
96-
"Cannot find body region with given anatomy list.",
97-
]
98-
9988
TEST_CASE_INFER_ERROR_2 = [
10089
{
10190
"bundle_root": "models/maisi_ct_generative",
@@ -277,7 +266,7 @@ def test_infer_config(self, override):
277266
else:
278267
self.assertTrue(output_file.endswith(".nii.gz"))
279268

280-
@parameterized.expand([TEST_CASE_INFER_ERROR, TEST_CASE_INFER_ERROR_7])
269+
@parameterized.expand([TEST_CASE_INFER_ERROR_7])
281270
def test_infer_config_error_input(self, override, expected_error):
282271
# update override
283272
override["output_dir"] = self.output_dir

models/maisi_ct_generative/configs/inference.json

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
"output_dir": "$@bundle_root + '/output'",
1010
"create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
1111
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
12-
"trained_autoencoder_path": "$@model_dir + '/autoencoder_epoch273.pt'",
13-
"trained_diffusion_path": "$@model_dir + '/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt'",
14-
"trained_controlnet_path": "$@model_dir + '/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt'",
12+
"trained_autoencoder_path": "$@model_dir + '/autoencoder.pt'",
13+
"trained_diffusion_path": "$@model_dir + '/diffusion_unet.pt'",
14+
"trained_controlnet_path": "$@model_dir + '/controlnet.pt'",
1515
"trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'",
1616
"trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'",
1717
"all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'",
@@ -21,14 +21,13 @@
2121
"label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'",
2222
"real_img_median_statistics_file": "$@bundle_root + '/configs/image_median_statistics.json'",
2323
"num_output_samples": 1,
24-
"body_region": [
25-
"abdomen"
26-
],
24+
"body_region": [],
2725
"anatomy_list": [
2826
"liver"
2927
],
28+
"modality": "ct",
3029
"controllable_anatomy_size": [],
31-
"num_inference_steps": 1000,
30+
"num_inference_steps": 30,
3231
"mask_generation_num_inference_steps": 1000,
3332
"random_seed": null,
3433
"spatial_dims": 3,
@@ -63,11 +62,11 @@
6362
64
6463
],
6564
"autoencoder_sliding_window_infer_size": [
66-
96,
67-
96,
68-
96
65+
80,
66+
80,
67+
80
6968
],
70-
"autoencoder_sliding_window_infer_overlap": 0.6667,
69+
"autoencoder_sliding_window_infer_overlap": 0.4,
7170
"autoencoder_def": {
7271
"_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
7372
"spatial_dims": "@spatial_dims",
@@ -96,7 +95,7 @@
9695
"use_checkpointing": false,
9796
"use_convtranspose": false,
9897
"norm_float16": true,
99-
"num_splits": 8,
98+
"num_splits": 2,
10099
"dim_split": 1
101100
},
102101
"diffusion_unet_def": {
@@ -124,9 +123,12 @@
124123
],
125124
"num_res_blocks": 2,
126125
"use_flash_attention": true,
127-
"include_top_region_index_input": true,
128-
"include_bottom_region_index_input": true,
129-
"include_spacing_input": true
126+
"include_top_region_index_input": false,
127+
"include_bottom_region_index_input": false,
128+
"include_spacing_input": true,
129+
"num_class_embeds": 128,
130+
"resblock_updown": true,
131+
"include_fc": true
130132
},
131133
"controlnet_def": {
132134
"_target_": "monai.apps.generation.maisi.networks.controlnet_maisi.ControlNetMaisi",
@@ -157,7 +159,10 @@
157159
8,
158160
32,
159161
64
160-
]
162+
],
163+
"num_class_embeds": 128,
164+
"resblock_updown": true,
165+
"include_fc": true
161166
},
162167
"mask_generation_autoencoder_def": {
163168
"_target_": "monai.apps.generation.maisi.networks.autoencoderkl_maisi.AutoencoderKlMaisi",
@@ -239,12 +244,11 @@
239244
"load_mask_generation_diffusion": "$@mask_generation_diffusion_unet.load_state_dict(@checkpoint_mask_generation_diffusion_unet['unet_state_dict'], strict=True)",
240245
"mask_generation_scale_factor": "$@checkpoint_mask_generation_diffusion_unet['scale_factor']",
241246
"noise_scheduler": {
242-
"_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
247+
"_target_": "scripts.rectified_flow.RFlowScheduler",
243248
"num_train_timesteps": 1000,
244-
"beta_start": 0.0015,
245-
"beta_end": 0.0195,
246-
"schedule": "scaled_linear_beta",
247-
"clip_sample": false
249+
"use_discrete_timesteps": false,
250+
"use_timestep_transform": true,
251+
"sample_method": "uniform"
248252
},
249253
"mask_generation_noise_scheduler": {
250254
"_target_": "monai.networks.schedulers.ddpm.DDPMScheduler",
@@ -269,6 +273,7 @@
269273
],
270274
"body_region": "@body_region",
271275
"anatomy_list": "@anatomy_list",
276+
"modality": "@modality",
272277
"all_mask_files_json": "@all_mask_files_json",
273278
"all_anatomy_size_condtions_json": "@all_anatomy_size_condtions_json",
274279
"all_mask_files_base_dir": "@all_mask_files_base_dir",
@@ -300,6 +305,7 @@
300305
"autoencoder_sliding_window_infer_overlap": "@autoencoder_sliding_window_infer_overlap"
301306
},
302307
"run": [
308+
"$monai.utils.set_determinism(seed=@random_seed)",
303309
"$@ldm_sampler.sample_multiple_images(@num_output_samples)"
304310
],
305311
"evaluator": null

models/maisi_ct_generative/configs/metadata.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
{
22
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_ldm_20240318.json",
3-
"version": "0.4.6",
3+
"version": "1.0.0",
44
"changelog": {
5+
"1.0.0": "accelerated maisi, inference only, not compatible with previous maisi diffusion model weights",
56
"0.4.6": "add TensorRT support",
67
"0.4.5": "update README",
78
"0.4.4": "update issue for IgniteInfo",

models/maisi_ct_generative/docs/README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ This bundle is for Nvidia MAISI (Medical AI for Synthetic Imaging), a 3D Latent
44
The inference workflow of MAISI is depicted in the figure below. It first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Then it decodes the denoised latent features into images using the trained autoencoder.
55

66
<p align="center">
7-
<img src="https://developer.download.nvidia.com/assets/Clara/Images/monai_maisi_ct_generative_workflow.png" alt="MAISI inference scheme">
7+
<img src="https://developer.download.nvidia.com/assets/Clara/Images/maisi_workflow_1.0.1.png" alt="MAISI inference scheme">
88
</p>
99

1010
MAISI is based on the following papers:
@@ -13,6 +13,8 @@ MAISI is based on the following papers:
1313

1414
[**ControlNet:** Lvmin Zhang, Anyi Rao, Maneesh Agrawala; “Adding Conditional Control to Text-to-Image Diffusion Models.” ICCV 2023.](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhang_Adding_Conditional_Control_to_Text-to-Image_Diffusion_Models_ICCV_2023_paper.pdf)
1515

16+
[**Rectified Flow:** Liu, Xingchao, Chengyue Gong, and Qiang Liu. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." ICLR 2023.](https://arxiv.org/pdf/2209.03003)
17+
1618
#### Example synthetic image
1719
An example result from inference is shown below:
1820
![Example synthetic image](https://developer.download.nvidia.com/assets/Clara/Images/monai_maisi_ct_generative_example_synthetic_data.png)
@@ -27,11 +29,11 @@ The information for the inference input, like body region and anatomy to generat
2729

2830
- `"num_output_samples"`: int, the number of output image/mask pairs it will generate.
2931
- `"spacing"`: voxel size of generated images. E.g., if set to `[1.5, 1.5, 2.0]`, it will generate images with a resolution of 1.5&times;1.5&times;2.0 mm. The spacing for x and y axes has to be between 0.5 and 3.0 mm and the spacing for the z axis has to be between 0.5 and 5.0 mm.
30-
- `"output_size"`: volume size of generated images. E.g., if set to `[512, 512, 256]`, it will generate images with size of 512&times;512&times;256. They need to be divisible by 16. If you have a small GPU memory size, you should adjust it to small numbers. Note that `"spacing"` and `"output_size"` together decide the output field of view (FOV). For eample, if set them to `[1.5, 1.5, 2.0]`mm and `[512, 512, 256]`, the FOV is 768&times;768&times;512 mm. We recommend output_size is the FOV in x and y axis are same and to be at least 256mm for head, and at least 384mm for other body regions like abdomen. The output size for the x and y axes can be selected from [256, 384, 512], while for the z axis, it can be chosen from [128, 256, 384, 512, 640, 768].
32+
- `"output_size"`: volume size of generated images. E.g., if set to `[512, 512, 256]`, it will generate images with size of 512&times;512&times;256. They need to be divisible by 16. If you have a small GPU memory size, you should adjust it to small numbers. Note that `"spacing"` and `"output_size"` together decide the output field of view (FOV). For example, if they are set to `[1.5, 1.5, 2.0]` mm and `[512, 512, 256]`, the FOV is 768&times;768&times;512 mm. We recommend choosing `"output_size"` so that the FOV is the same along the x and y axes and is at least 256 mm for head, at least 384 mm for other body regions like abdomen, and no larger than 640 mm. The output size for the x and y axes can be selected from [256, 384, 512], while for the z axis, it can be chosen from [128, 256, 384, 512, 640, 768].
3133
- `"controllable_anatomy_size"`: a list of controllable anatomies and their size scales (0--1). E.g., if set to `[["liver", 0.5],["hepatic tumor", 0.3]]`, the generated image will contain a liver of median size (around the 50th percentile) and a hepatic tumor that is relatively small (around the 30th percentile). In addition, if the size scale is set to -1, it indicates that the organ does not exist or should be removed. The output will contain paired image and segmentation mask for the controllable anatomy.
3234
The following organs support generation with a controllable size: ``["liver", "gallbladder", "stomach", "pancreas", "colon", "lung tumor", "bone lesion", "hepatic tumor", "colon cancer primaries", "pancreatic tumor"]``.
3335
The raw output of the current mask generation model has a fixed size of $256^3$ voxels with a spacing of $1.5^3$ mm. If the "output_size" differs from this default, the generated masks will be resampled to the desired `"output_size"` and `"spacing"`. Note that resampling may degrade the quality of the generated masks and could trigger multiple inference attempts if the images fail to pass the [image quality check](../scripts/quality_check.py).
34-
- `"body_region"`: If "controllable_anatomy_size" is not specified, "body_region" will be used to constrain the region of generated images. It needs to be chosen from "head", "chest", "thorax", "abdomen", "pelvis", "lower".
36+
- `"body_region"`: Deprecated; please leave it as an empty list `[]`.
3537
- `"anatomy_list"`: If "controllable_anatomy_size" is not specified, the output will contain paired image and segmentation mask for the anatomy in "./configs/label_dict.json".
3638
- `"autoencoder_sliding_window_infer_size"`: in order to save GPU memory, we use sliding window inference when decoding latents to image when `"output_size"` is large. This is the patch size of the sliding window. A smaller value will reduce GPU memory usage but increase time cost. The values need to be divisible by 16.
3739
- `"autoencoder_sliding_window_infer_overlap"`: float between 0 and 1. A larger value will reduce the stitching artifacts when stitching patches during sliding window inference, but increase time cost. If you do not observe seam lines in the generated image result, you can use a smaller value to save inference time.

models/maisi_ct_generative/large_files.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
large_files:
2-
- path: "models/autoencoder_epoch273.pt"
2+
- path: "models/autoencoder.pt"
33
url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_autoencoder_epoch273_alternative.pt"
44
hash_val: "917cfb1e49631c8a713e3bb7c758fbca"
55
hash_type: "md5"
@@ -11,6 +11,14 @@ large_files:
1111
url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_maisi_controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current_alternative.pt"
1212
hash_val: "6c36572335372f405a0e85c760fa6dee"
1313
hash_type: "md5"
14+
- path: "models/diffusion_unet.pt"
15+
url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/diff_unet_ckpt_rflow_epoch19350.pt"
16+
hash_val: "10501d59a3066802087c82ebd7a71719"
17+
hash_type: "md5"
18+
- path: "models/controlnet.pt"
19+
url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/controlnet_rflow_epoch208.pt"
20+
hash_val: "49933da32826c0f7ca17016ccd13e23b"
21+
hash_type: "md5"
1422
- path: "models/mask_generation_autoencoder.pt"
1523
url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/mask_generation_autoencoder.pt"
1624
hash_val: "b177778820f412abc9218cdb7ce3b653"

models/maisi_ct_generative/scripts/augmentation.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def dilate3d(input_tensor, erosion=3):
6060
return output.squeeze(0).squeeze(0)
6161

6262

63-
def augmentation_tumor_bone(pt_nda, output_size):
63+
def augmentation_tumor_bone(pt_nda, output_size, random_seed):
6464
volume = pt_nda.squeeze(0)
6565
real_l_volume_ = torch.zeros_like(volume)
6666
real_l_volume_[volume == 128] = 1
@@ -74,6 +74,7 @@ def augmentation_tumor_bone(pt_nda, output_size):
7474
scale_range=(0.15, 0.15, 0),
7575
padding_mode="zeros",
7676
)
77+
elastic.set_random_state(seed=random_seed)
7778

7879
tumor_szie = torch.sum((real_l_volume_ > 0).float())
7980
###########################
@@ -112,7 +113,7 @@ def augmentation_tumor_bone(pt_nda, output_size):
112113
return pt_nda
113114

114115

115-
def augmentation_tumor_liver(pt_nda, output_size):
116+
def augmentation_tumor_liver(pt_nda, output_size, random_seed):
116117
volume = pt_nda.squeeze(0)
117118
real_l_volume_ = torch.zeros_like(volume)
118119
real_l_volume_[volume == 1] = 1
@@ -129,6 +130,7 @@ def augmentation_tumor_liver(pt_nda, output_size):
129130
scale_range=(0.2, 0.2, 0.2),
130131
padding_mode="zeros",
131132
)
133+
elastic.set_random_state(seed=random_seed)
132134

133135
tumor_szie = torch.sum(real_l_volume_ == 2)
134136
###########################
@@ -161,7 +163,7 @@ def augmentation_tumor_liver(pt_nda, output_size):
161163
return pt_nda
162164

163165

164-
def augmentation_tumor_lung(pt_nda, output_size):
166+
def augmentation_tumor_lung(pt_nda, output_size, random_seed):
165167
volume = pt_nda.squeeze(0)
166168
real_l_volume_ = torch.zeros_like(volume)
167169
real_l_volume_[volume == 23] = 1
@@ -177,6 +179,7 @@ def augmentation_tumor_lung(pt_nda, output_size):
177179
scale_range=(0.15, 0.15, 0.15),
178180
padding_mode="zeros",
179181
)
182+
elastic.set_random_state(seed=random_seed)
180183

181184
tumor_szie = torch.sum(real_l_volume_)
182185
# before moving the lung tumor mask, fill the original location with lung labels
@@ -224,7 +227,7 @@ def augmentation_tumor_lung(pt_nda, output_size):
224227
return pt_nda
225228

226229

227-
def augmentation_tumor_pancreas(pt_nda, output_size):
230+
def augmentation_tumor_pancreas(pt_nda, output_size, random_seed):
228231
volume = pt_nda.squeeze(0)
229232
real_l_volume_ = torch.zeros_like(volume)
230233
real_l_volume_[volume == 4] = 1
@@ -241,6 +244,7 @@ def augmentation_tumor_pancreas(pt_nda, output_size):
241244
scale_range=(0.1, 0.1, 0.1),
242245
padding_mode="zeros",
243246
)
247+
elastic.set_random_state(seed=random_seed)
244248

245249
tumor_szie = torch.sum(real_l_volume_ == 2)
246250
###########################
@@ -273,7 +277,7 @@ def augmentation_tumor_pancreas(pt_nda, output_size):
273277
return pt_nda
274278

275279

276-
def augmentation_tumor_colon(pt_nda, output_size):
280+
def augmentation_tumor_colon(pt_nda, output_size, random_seed):
277281
volume = pt_nda.squeeze(0)
278282
real_l_volume_ = torch.zeros_like(volume)
279283
real_l_volume_[volume == 27] = 1
@@ -289,6 +293,7 @@ def augmentation_tumor_colon(pt_nda, output_size):
289293
scale_range=(0.1, 0.1, 0.1),
290294
padding_mode="zeros",
291295
)
296+
elastic.set_random_state(seed=random_seed)
292297

293298
tumor_szie = torch.sum(real_l_volume_)
294299
###########################
@@ -330,37 +335,39 @@ def augmentation_tumor_colon(pt_nda, output_size):
330335
return pt_nda
331336

332337

333-
def augmentation_body(pt_nda):
338+
def augmentation_body(pt_nda, random_seed):
334339
volume = pt_nda.squeeze(0)
335340

336341
zoom = RandZoom(min_zoom=0.99, max_zoom=1.01, mode="nearest", align_corners=None, prob=1.0)
342+
zoom.set_random_state(seed=random_seed)
343+
337344
volume = zoom(volume)
338345

339346
pt_nda = volume.unsqueeze(0)
340347
return pt_nda
341348

342349

343-
def augmentation(pt_nda, output_size):
350+
def augmentation(pt_nda, output_size, random_seed):
344351
label_list = torch.unique(pt_nda)
345352
label_list = list(label_list.cpu().numpy())
346353

347354
if 128 in label_list:
348355
print("augmenting bone lesion/tumor")
349-
pt_nda = augmentation_tumor_bone(pt_nda, output_size)
356+
pt_nda = augmentation_tumor_bone(pt_nda, output_size, random_seed)
350357
elif 26 in label_list:
351358
print("augmenting liver tumor")
352-
pt_nda = augmentation_tumor_liver(pt_nda, output_size)
359+
pt_nda = augmentation_tumor_liver(pt_nda, output_size, random_seed)
353360
elif 23 in label_list:
354361
print("augmenting lung tumor")
355-
pt_nda = augmentation_tumor_lung(pt_nda, output_size)
362+
pt_nda = augmentation_tumor_lung(pt_nda, output_size, random_seed)
356363
elif 24 in label_list:
357364
print("augmenting pancreas tumor")
358-
pt_nda = augmentation_tumor_pancreas(pt_nda, output_size)
365+
pt_nda = augmentation_tumor_pancreas(pt_nda, output_size, random_seed)
359366
elif 27 in label_list:
360367
print("augmenting colon tumor")
361-
pt_nda = augmentation_tumor_colon(pt_nda, output_size)
368+
pt_nda = augmentation_tumor_colon(pt_nda, output_size, random_seed)
362369
else:
363370
print("augmenting body")
364-
pt_nda = augmentation_body(pt_nda)
371+
pt_nda = augmentation_body(pt_nda, random_seed)
365372

366373
return pt_nda

0 commit comments

Comments
 (0)