Skip to content

Commit 5fd2c37

Browse files
committed
fix dtype
Signed-off-by: sewon.jeon <[email protected]>
1 parent d8fb2d2 commit 5fd2c37

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

monai/transforms/post/array.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -850,19 +850,17 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
850850
center_int = torch.minimum(torch.maximum(center_int, torch.zeros_like(center_int)), bounds_max)
851851
# Place impulse (use maximum in case of overlapping landmarks)
852852
current_val = heatmap[idx][tuple(center_int)]
853-
heatmap[idx][tuple(center_int)] = torch.maximum(current_val, torch.tensor(1.0, dtype=self.torch_dtype, device=device))
853+
heatmap[idx][tuple(center_int)] = torch.maximum(
854+
current_val, torch.tensor(1.0, dtype=self.torch_dtype, device=device)
855+
)
854856

855857
# Apply Gaussian blur using GaussianFilter
856858
# Reshape to (num_points, 1, *spatial) for per-channel filtering
857859
heatmap_input = heatmap.unsqueeze(1) # Add channel dimension
858860

859861
gaussian_filter = GaussianFilter(
860-
spatial_dims=spatial_dims,
861-
sigma=sigma,
862-
truncated=self.truncated,
863-
approx="erf",
864-
requires_grad=False
865-
).to(device)
862+
spatial_dims=spatial_dims, sigma=sigma, truncated=self.truncated, approx="erf", requires_grad=False
863+
).to(device=device, dtype=self.torch_dtype)
866864

867865
heatmap_blurred = gaussian_filter(heatmap_input)
868866
heatmap = heatmap_blurred.squeeze(1) # Remove channel dimension

0 commit comments

Comments
 (0)