@@ -850,19 +850,17 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
850850 center_int = torch .minimum (torch .maximum (center_int , torch .zeros_like (center_int )), bounds_max )
851851 # Place impulse (use maximum in case of overlapping landmarks)
852852 current_val = heatmap [idx ][tuple (center_int )]
853- heatmap [idx ][tuple (center_int )] = torch .maximum (current_val , torch .tensor (1.0 , dtype = self .torch_dtype , device = device ))
853+ heatmap [idx ][tuple (center_int )] = torch .maximum (
854+ current_val , torch .tensor (1.0 , dtype = self .torch_dtype , device = device )
855+ )
854856
855857 # Apply Gaussian blur using GaussianFilter
856858 # Reshape to (num_points, 1, *spatial) for per-channel filtering
857859 heatmap_input = heatmap .unsqueeze (1 ) # Add channel dimension
858860
859861 gaussian_filter = GaussianFilter (
860- spatial_dims = spatial_dims ,
861- sigma = sigma ,
862- truncated = self .truncated ,
863- approx = "erf" ,
864- requires_grad = False
865- ).to (device )
862+ spatial_dims = spatial_dims , sigma = sigma , truncated = self .truncated , approx = "erf" , requires_grad = False
863+ ).to (device = device , dtype = self .torch_dtype )
866864
867865 heatmap_blurred = gaussian_filter (heatmap_input )
868866 heatmap = heatmap_blurred .squeeze (1 ) # Remove channel dimension
0 commit comments