Skip to content

Commit 20608cc

Browse files
NuClick + Classification (user experience) (#1092)
* Draft for nuclei classification model Signed-off-by: Sachidanand Alle <[email protected]> * sync local changes Signed-off-by: Sachidanand Alle <[email protected]> * sync local changes Signed-off-by: Sachidanand Alle <[email protected]> * sync local changes Signed-off-by: Sachidanand Alle <[email protected]> * sync local changes Signed-off-by: Sachidanand Alle <[email protected]> * sync local changes Signed-off-by: Sachidanand Alle <[email protected]> * sync local changes Signed-off-by: Sachidanand Alle <[email protected]> * Sync up local changes Signed-off-by: Sachidanand Alle <[email protected]> * fix inference Signed-off-by: Sachidanand Alle <[email protected]> * Add QuPath Fixes Signed-off-by: SACHIDANAND ALLE <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sachidanand Alle <[email protected]> Signed-off-by: SACHIDANAND ALLE <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5882ccf commit 20608cc

File tree

24 files changed

+1054
-196
lines changed

24 files changed

+1054
-196
lines changed

monailabel/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class Settings(BaseSettings):
5454
MONAI_LABEL_DICOMWEB_READ_TIMEOUT: float = 5.0
5555

5656
MONAI_LABEL_DATASTORE_AUTO_RELOAD: bool = True
57+
MONAI_LABEL_DATASTORE_READ_ONLY: bool = False
5758
MONAI_LABEL_DATASTORE_FILE_EXT: List[str] = [
5859
"*.nii.gz",
5960
"*.nii",

monailabel/datastore/local.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
datastore_config: str = "datastore_v2.json",
104104
extensions=("*.nii.gz", "*.nii"),
105105
auto_reload=False,
106+
read_only=False,
106107
):
107108
"""
108109
Creates a `LocalDataset` object
@@ -142,7 +143,8 @@ def __init__(
142143
os.makedirs(self._datastore.label_path(DefaultLabelTag.ORIGINAL), exist_ok=True)
143144

144145
# reconcile the loaded datastore file with any existing files in the path
145-
self._reconcile_datastore()
146+
if not read_only:
147+
self._reconcile_datastore()
146148

147149
if auto_reload:
148150
logger.info("Start observing external modifications on datastore (AUTO RELOAD)")

monailabel/interfaces/app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def init_datastore(self) -> Datastore:
135135
self.studies,
136136
extensions=settings.MONAI_LABEL_DATASTORE_FILE_EXT,
137137
auto_reload=settings.MONAI_LABEL_DATASTORE_AUTO_RELOAD,
138+
read_only=settings.MONAI_LABEL_DATASTORE_READ_ONLY,
138139
)
139140

140141
def init_remote_datastore(self) -> Datastore:

monailabel/tasks/infer/basic_infer.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import os
1515
import time
1616
from abc import abstractmethod
17+
from enum import Enum
1718
from typing import Any, Callable, Dict, Sequence, Tuple, Union
1819

1920
import torch
@@ -29,6 +30,14 @@
2930
logger = logging.getLogger(__name__)
3031

3132

33+
class CallBackTypes(str, Enum):
34+
PRE_TRANSFORMS = "PRE_TRANSFORMS"
35+
INFERER = "INFERER"
36+
INVERT_TRANSFORMS = "INVERT_TRANSFORMS"
37+
POST_TRANSFORMS = "POST_TRANSFORMS"
38+
WRITER = "WRITER"
39+
40+
3241
class BasicInferTask(InferTask):
3342
"""
3443
Basic Inference Task Helper
@@ -71,12 +80,11 @@ def __init__(
7180
:param train_mode: Run in Train mode instead of eval (when network has dropouts)
7281
:param skip_writer: Skip Writer and return data dictionary
7382
"""
83+
84+
super().__init__(type, labels, dimension, description, config)
85+
7486
self.path = [] if not path else [path] if isinstance(path, str) else path
7587
self.network = network
76-
self.type = type
77-
self.labels = [] if labels is None else [labels] if isinstance(labels, str) else labels
78-
self.dimension = dimension
79-
self.description = description
8088
self.model_state_dict = model_state_dict
8189
self.input_key = input_key
8290
self.output_label_key = output_label_key
@@ -88,15 +96,18 @@ def __init__(
8896

8997
self._networks: Dict = {}
9098

91-
self._config: Dict[str, Any] = {
92-
"device": device_list(),
93-
# "result_extension": None,
94-
# "result_dtype": None,
95-
# "result_compress": False
96-
# "roi_size": self.roi_size,
97-
# "sw_batch_size": 1,
98-
# "sw_overlap": 0.25,
99-
}
99+
self._config.update(
100+
{
101+
"device": device_list(),
102+
# "result_extension": None,
103+
# "result_dtype": None,
104+
# "result_compress": False
105+
# "roi_size": self.roi_size,
106+
# "sw_batch_size": 1,
107+
# "sw_overlap": 0.25,
108+
}
109+
)
110+
100111
if config:
101112
self._config.update(config)
102113

@@ -232,14 +243,20 @@ def inferer(self, data=None) -> Inferer:
232243
)
233244
return SimpleInferer()
234245

235-
def __call__(self, request) -> Union[Dict, Tuple[str, Dict[str, Any]]]:
246+
def __call__(
247+
self, request, callbacks: Union[Dict[CallBackTypes, Any], None] = None
248+
) -> Union[Dict, Tuple[str, Dict[str, Any]]]:
236249
"""
237250
It provides basic implementation to run the following in order
238251
- Run Pre Transforms
239252
- Run Inferer
253+
- Run Invert Transforms
240254
- Run Post Transforms
241255
- Run Writer to save the label mask and result params
242256
257+
You can provide callbacks which can be useful while writing pipelines to consume intermediate outputs
258+
Callback function should consume data and return data (modified/updated) e.g. `def my_cb(data): return data`
259+
243260
Returns: Label (File Path) and Result Params (JSON)
244261
"""
245262
begin = time.time()
@@ -262,28 +279,47 @@ def __call__(self, request) -> Union[Dict, Tuple[str, Dict[str, Any]]]:
262279
dump_data(req, logger.level)
263280
data = req
264281

282+
# callbacks useful in case of pipeliens to consume intermediate output from each of the following stages
283+
# callback function should consume data and returns data (modified/updated)
284+
callbacks = callbacks if callbacks else {}
285+
callback_run_pre_transforms = callbacks.get(CallBackTypes.PRE_TRANSFORMS)
286+
callback_run_inferer = callbacks.get(CallBackTypes.INFERER)
287+
callback_run_invert_transforms = callbacks.get(CallBackTypes.INVERT_TRANSFORMS)
288+
callback_run_post_transforms = callbacks.get(CallBackTypes.POST_TRANSFORMS)
289+
callback_writer = callbacks.get(CallBackTypes.WRITER)
290+
265291
start = time.time()
266292
pre_transforms = self.pre_transforms(data)
267293
data = self.run_pre_transforms(data, pre_transforms)
294+
if callback_run_pre_transforms:
295+
data = callback_run_pre_transforms(data)
268296
latency_pre = time.time() - start
269297

270298
start = time.time()
271299
data = self.run_inferer(data, device=device)
300+
if callback_run_inferer:
301+
data = callback_run_inferer(data)
272302
latency_inferer = time.time() - start
273303

274304
start = time.time()
275305
data = self.run_invert_transforms(data, pre_transforms, self.inverse_transforms(data))
306+
if callback_run_invert_transforms:
307+
data = callback_run_invert_transforms(data)
276308
latency_invert = time.time() - start
277309

278310
start = time.time()
279311
data = self.run_post_transforms(data, self.post_transforms(data))
312+
if callback_run_post_transforms:
313+
data = callback_run_post_transforms(data)
280314
latency_post = time.time() - start
281315

282316
if self.skip_writer:
283317
return dict(data)
284318

285319
start = time.time()
286320
result_file_name, result_json = self.writer(data)
321+
if callback_writer:
322+
data = callback_writer(data)
287323
latency_write = time.time() - start
288324

289325
latency_total = time.time() - begin
@@ -467,7 +503,8 @@ def writer(self, data, extension=None, dtype=None) -> Tuple[Any, Any]:
467503
if isinstance(self.labels, dict):
468504
label_names = {v: k for k, v in self.labels.items()}
469505
else:
470-
label_names = {v: k for v, k in enumerate(self.labels)}
506+
label_names = {v: k for v, k in enumerate(self.labels)} if isinstance(self.labels, Sequence) else None
507+
471508
cw = ClassificationWriter(label=self.output_label_key, label_names=label_names)
472509
return cw(data)
473510

monailabel/tasks/train/handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def tensor_to_list(d):
3232
return r
3333

3434
stats: Dict[str, Any] = dict()
35-
stats.update(trainer.get_train_stats())
35+
stats.update(trainer.get_stats())
3636
stats["epoch"] = trainer.state.epoch
3737
stats["start_ts"] = int(start_ts)
3838

plugins/qupath/src/main/java/qupath/lib/extension/monailabel/MonaiLabelClient.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,19 @@ public static class Strategy {
9898
public String description;
9999
}
100100

101+
public static class Trainer {
102+
public String description;
103+
public Map<String, Object> config;
104+
}
105+
101106
public static class ResponseInfo {
102107
public String name;
103108
public String description;
104109
public String version;
105110
public Labels labels;
106111
public Map<String, Model> models;
107112
public Map<String, Strategy> strategies;
113+
public Map<String, Trainer> trainers;
108114
}
109115

110116
public static class ImageInfo {

plugins/qupath/src/main/java/qupath/lib/extension/monailabel/commands/RunInference.java

Lines changed: 62 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919
import java.nio.file.Path;
2020
import java.util.ArrayList;
2121
import java.util.Arrays;
22-
import java.util.HashMap;
2322
import java.util.HashSet;
2423
import java.util.List;
25-
import java.util.Map;
2624
import java.util.Set;
2725

2826
import javax.xml.parsers.ParserConfigurationException;
@@ -50,6 +48,7 @@
5048
import qupath.lib.plugins.parameters.ParameterList;
5149
import qupath.lib.regions.ImagePlane;
5250
import qupath.lib.regions.RegionRequest;
51+
import qupath.lib.roi.PointsROI;
5352
import qupath.lib.roi.ROIs;
5453
import qupath.lib.roi.interfaces.ROI;
5554
import qupath.lib.scripting.QP;
@@ -80,38 +79,27 @@ public void run() {
8079
}
8180

8281
ResponseInfo info = MonaiLabelClient.info();
83-
List<String> names = new ArrayList<String>();
84-
Map<String, String[]> labels = new HashMap<String, String[]>();
85-
for (String n : info.models.keySet()) {
86-
names.add(n);
87-
labels.put(n, info.models.get(n).labels.labels());
88-
}
82+
List<String> names = Arrays.asList(info.models.keySet().toArray(new String[0]));
8983

90-
ParameterList list = new ParameterList();
9184
if (selectedModel == null || selectedModel.isEmpty()) {
9285
selectedModel = names.isEmpty() ? "" : names.get(0);
9386
}
9487

88+
ParameterList list = new ParameterList();
9589
list.addChoiceParameter("Model", "Model Name", selectedModel, names);
9690
list.addStringParameter("Location", "Location (x,y,w,h)", Arrays.toString(bbox));
9791
list.addIntParameter("TileSize", "TileSize", tileSize);
9892

99-
boolean override = !info.models.get(selectedModel).nuclick;
100-
list.addBooleanParameter("Override", "Override", override);
101-
10293
if (Dialogs.showParameterDialog("MONAILabel", list)) {
10394
String model = (String) list.getChoiceParameterValue("Model");
10495
bbox = Utils.parseStringArray(list.getStringParameterValue("Location"));
105-
override = list.getBooleanParameterValue("Override").booleanValue();
10696
tileSize = list.getIntParameterValue("TileSize").intValue();
10797

10898
selectedModel = model;
10999
selectedBBox = bbox;
110100
selectedTileSize = tileSize;
111101

112-
boolean validateClicks = info.models.get(selectedModel).nuclick;
113-
runInference(model, new HashSet<String>(Arrays.asList(labels.get(model))), bbox, tileSize, imageData,
114-
override, validateClicks);
102+
runInference(model, info, bbox, tileSize, imageData);
115103
}
116104
} catch (Exception ex) {
117105
ex.printStackTrace();
@@ -124,13 +112,16 @@ ArrayList<Point2> getClicks(String name, ImageData<BufferedImage> imageData, ROI
124112
List<PathObject> objs = imageData.getHierarchy().getFlattenedObjectList(null);
125113
ArrayList<Point2> clicks = new ArrayList<Point2>();
126114
for (int i = 0; i < objs.size(); i++) {
127-
String pname = objs.get(i).getPathClass() == null ? "" : objs.get(i).getPathClass().getName();
128-
if (pname.equalsIgnoreCase(name)) {
129-
ROI r = objs.get(i).getROI();
130-
List<Point2> points = r.getAllPoints();
131-
for (Point2 p : points) {
132-
if (monaiLabelROI.contains(p.getX(), p.getY())) {
133-
clicks.add(new Point2(p.getX() - offsetX, p.getY() - offsetY));
115+
var obj = objs.get(i);
116+
String pname = obj.getPathClass() == null ? "" : obj.getPathClass().getName();
117+
if (name.isEmpty() || pname.equalsIgnoreCase(name)) {
118+
ROI r = obj.getROI();
119+
if (r instanceof PointsROI) {
120+
List<Point2> points = r.getAllPoints();
121+
for (Point2 p : points) {
122+
if (monaiLabelROI.contains(p.getX(), p.getY())) {
123+
clicks.add(new Point2(p.getX() - offsetX, p.getY() - offsetY));
124+
}
134125
}
135126
}
136127
}
@@ -140,12 +131,17 @@ ArrayList<Point2> getClicks(String name, ImageData<BufferedImage> imageData, ROI
140131
return clicks;
141132
}
142133

143-
private void runInference(String model, Set<String> labels, int[] bbox, int tileSize,
144-
ImageData<BufferedImage> imageData, boolean override, boolean validateClicks)
134+
private void runInference(String model, ResponseInfo info, int[] bbox, int tileSize,
135+
ImageData<BufferedImage> imageData)
145136
throws SAXException, IOException, ParserConfigurationException, InterruptedException {
146137
logger.info("MONAILabel:: Running Inference...");
147-
logger.info("MONAILabel:: Model: " + model + "; override: " + override + "; clicks:" + validateClicks
148-
+ "; Labels: " + labels);
138+
139+
boolean isNuClick = info.models.get(model).nuclick;
140+
boolean override = !isNuClick;
141+
boolean validateClicks = isNuClick;
142+
var labels = new HashSet<String>(Arrays.asList(info.models.get(model).labels.labels()));
143+
144+
logger.info("MONAILabel:: Model: " + model + "; Labels: " + labels);
149145

150146
Path imagePatch = null;
151147
try {
@@ -178,30 +174,38 @@ private void runInference(String model, Set<String> labels, int[] bbox, int tile
178174
req.location[0] = req.location[1] = 0;
179175
req.size[0] = req.size[1] = 0;
180176

181-
var fg = getClicks("Positive", imageData, roi, offsetX, offsetY);
182-
var bg = getClicks("Negative", imageData, roi, offsetX, offsetY);
183-
if (validateClicks) {
184-
if (fg.size() == 0 && bg.size() == 0) {
185-
Dialogs.showErrorMessage("MONAILabel",
186-
"Need atleast one Postive/Negative annotation/click point within the ROI");
187-
return;
188-
}
189-
if (roi.getBoundsHeight() < 128 || roi.getBoundsWidth() < 128) {
190-
Dialogs.showErrorMessage("MONAILabel",
191-
"Min Height/Width of ROI should be more than 128");
192-
return;
193-
}
194-
}
195-
196-
req.params.addClicks(fg, true);
197-
req.params.addClicks(bg, false);
198177

199178
imagePatch = java.nio.file.Files.createTempFile("patch", ".png");
200179
imageFile = imagePatch.toString();
201180
var requestROI = RegionRequest.createInstance(imageData.getServer().getPath(), 1, roi);
202181
ImageWriterTools.writeImageRegion(imageData.getServer(), requestROI, imageFile);
203182
}
204183

184+
ArrayList<Point2> fg = new ArrayList<>();
185+
ArrayList<Point2> bg = new ArrayList<>();
186+
if (isNuClick) {
187+
fg = getClicks("", imageData, roi, offsetX, offsetY);
188+
} else {
189+
fg = getClicks("Positive", imageData, roi, offsetX, offsetY);
190+
bg = getClicks("Negative", imageData, roi, offsetX, offsetY);
191+
}
192+
193+
if (validateClicks) {
194+
if (fg.size() == 0 && bg.size() == 0) {
195+
Dialogs.showErrorMessage("MONAILabel",
196+
"Need atleast one Postive/Negative annotation/click point within the ROI");
197+
return;
198+
}
199+
if (roi.getBoundsHeight() < 128 || roi.getBoundsWidth() < 128) {
200+
Dialogs.showErrorMessage("MONAILabel",
201+
"Min Height/Width of ROI should be more than 128");
202+
return;
203+
}
204+
}
205+
req.params.addClicks(fg, true);
206+
req.params.addClicks(bg, false);
207+
208+
205209
Document dom = MonaiLabelClient.infer(model, image, imageFile, sessionId, req);
206210
NodeList annotation_list = dom.getElementsByTagName("Annotation");
207211
int count = updateAnnotations(labels, annotation_list, roi, imageData, override, offsetX, offsetY);
@@ -227,6 +231,20 @@ private int updateAnnotations(Set<String> labels, NodeList annotation_list, ROI
227231
}
228232
}
229233
}
234+
} else {
235+
List<PathObject> objs = imageData.getHierarchy().getFlattenedObjectList(null);
236+
for (int i = 0; i < objs.size(); i++) {
237+
var obj = objs.get(i);
238+
ROI r = obj.getROI();
239+
if (r instanceof PointsROI) {
240+
String pname = obj.getPathClass() == null ? "" : obj.getPathClass().getName();
241+
if (pname.equalsIgnoreCase("Positive") || pname.equalsIgnoreCase("Negative")) {
242+
continue;
243+
}
244+
imageData.getHierarchy().removeObjectWithoutUpdate(obj, false);
245+
}
246+
}
247+
QP.fireHierarchyUpdate(imageData.getHierarchy());
230248
}
231249

232250
int count = 0;

0 commit comments

Comments
 (0)