From ce2bfe432198d5a36b776af33a9d07c959242334 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Mon, 14 Apr 2025 23:15:57 +0400 Subject: [PATCH 1/4] deformable conv2d kernel for mps --- test/test_ops.py | 57 +++++-- .../csrc/ops/mps/deform_conv2d_kernel.mm | 151 +++++++++++++++++ torchvision/csrc/ops/mps/mps_kernels.h | 159 ++++++++++++++++++ 3 files changed, 357 insertions(+), 10 deletions(-) create mode 100644 torchvision/csrc/ops/mps/deform_conv2d_kernel.mm diff --git a/test/test_ops.py b/test/test_ops.py index 88124f7ba17..98629998848 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -929,6 +929,7 @@ def test_batched_nms_implementations(self, seed): class TestDeformConv: dtype = torch.float64 + mps_dtype = torch.float32 def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1): stride_h, stride_w = _pair(stride) @@ -1050,12 +1051,11 @@ def test_is_leaf_node(self, device): assert len(graph_node_names[0]) == len(graph_node_names[1]) assert len(graph_node_names[0]) == 1 + op_obj.n_inputs - @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("device", cpu_and_cuda_and_mps()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("batch_sz", (0, 33)) - @pytest.mark.opcheck_only_one() def test_forward(self, device, contiguous, batch_sz, dtype=None): - dtype = dtype or self.dtype + dtype = self.mps_dtype if device == "mps" else dtype or self.dtype x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) in_channels = 6 out_channels = 2 @@ -1201,13 +1201,50 @@ def test_forward_scriptability(self): torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3)) -optests.generate_opcheck_tests( - testcase=TestDeformConv, - namespaces=["torchvision"], - failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"), - additional_decorators=[], - test_utils=OPTESTS, -) +@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64)) +@pytest.mark.parametrize("device", cpu_and_cuda()) +@pytest.mark.parametrize("requires_grad", (True, False)) +def test_deform_conv2d_opcheck(dtype, device, requires_grad): + batch_size, channels_in, height, width = 1, 6, 10, 10 + kernel_size = (3, 3) + stride = (1, 1) + padding = (1, 1) + dilation = (1, 1) + groups = 2 + out_channels = 4 + out_h = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1 + out_w = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1 + x = torch.randn(batch_size, channels_in, height, width, dtype=dtype, device=device, requires_grad=requires_grad) + offset = torch.randn(batch_size, 2 * kernel_size[0] * kernel_size[1], out_h, out_w, + dtype=dtype, device=device, requires_grad=requires_grad) + weight = torch.randn(out_channels, channels_in // groups, kernel_size[0], kernel_size[1], + dtype=dtype, device=device, requires_grad=requires_grad) + bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad) + use_mask = True + mask = torch.sigmoid(torch.randn( + batch_size, + kernel_size[0] * kernel_size[1], + out_h, + out_w, + dtype=dtype, device=device, requires_grad=requires_grad + )) + kwargs = { + "offset": offset, + "weight": weight, + "bias": bias, + "stride_h": stride[0], + "stride_w": stride[1], + "pad_h": padding[0], + "pad_w": padding[1], + "dilation_h": dilation[0], + "dilation_w": dilation[1], + "groups": groups, + "offset_groups": 1, + "use_mask": use_mask, + "mask": mask, # no modulation in this test + } + optests.opcheck(torch.ops.torchvision.deform_conv2d, args=(x,), kwargs=kwargs) + class TestFrozenBNT: diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm new file mode 100644 index 00000000000..a9febf0481c --- /dev/null +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -0,0 +1,151 @@ +#include +#include +#include +#include "mps_kernels.h" + +namespace vision { +namespace ops { + +namespace { + +at::Tensor deform_conv2d_forward_kernel( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + int64_t stride_h, + int64_t stride_w, + int64_t pad_h, + int64_t pad_w, + int64_t dilation_h, + int64_t dilation_w, + int64_t n_weight_grps, + int64_t n_offset_grps, + bool use_mask) { + using namespace at::native::mps; + at::Tensor input_c = input.contiguous(); + at::Tensor weight_c = weight.contiguous(); + at::Tensor offset_c = offset.contiguous(); + at::Tensor mask_c = mask.contiguous(); + at::Tensor bias_c = bias.contiguous(); + + TORCH_CHECK(input_c.ndimension() == 4, "Input tensor must be 4D"); + TORCH_CHECK(weight_c.ndimension() == 4, "Weight tensor must be 4D"); + TORCH_CHECK(offset_c.ndimension() == 4, "Offset tensor must be 4D"); + TORCH_CHECK(!use_mask || mask_c.ndimension() == 4, "Mask tensor must be 4D if use_mask is true"); + TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor"); + + at::DeviceGuard guard(input_c.device()); + + int batch = input_c.size(0); + int in_channels = input_c.size(1); + int in_h = input_c.size(2); + int in_w = input_c.size(3); + int weight_h = weight_c.size(2); + int weight_w = weight_c.size(3); + int out_channels = weight_c.size(0); + int ker_h = dilation_h * (weight_h - 1) + 1; + int ker_w = dilation_w * (weight_w - 1) + 1; + int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + + TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels, + "Input channels (", in_channels, + ") must equal weight.size(1) * n_weight_grps (", weight_c.size(1), " * ", n_weight_grps, ")"); + TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0, + "Weight tensor's out channels (", weight_c.size(0), + ") must be divisible by n_weight_grps (", n_weight_grps, ")"); + TORCH_CHECK(offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w, + "Offset tensor shape[1] is invalid: got ", offset_c.size(1), + ", expected ", n_offset_grps * 2 * weight_h * weight_w); + TORCH_CHECK(!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w, + "Mask tensor shape[1] is invalid: got ", mask_c.size(1), + ", expected ", n_offset_grps * weight_h * weight_w); + TORCH_CHECK(in_channels % n_offset_grps == 0, + "Input tensor channels (", in_channels, + ") must be divisible by n_offset_grps (", n_offset_grps, ")"); + TORCH_CHECK(offset_c.size(0) == batch, + "Offset tensor batch size (", offset_c.size(0), + ") must match input tensor batch size (", batch, ")"); + TORCH_CHECK(offset_c.size(2) == out_h && offset_c.size(3) == out_w, + "Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3), + ") must match calculated output dimensions (", out_h, ", ", out_w, ")"); + TORCH_CHECK(!use_mask || mask_c.size(0) == batch, + "Mask tensor batch size (", mask_c.size(0), + ") must match input tensor batch size (", batch, ")"); + TORCH_CHECK(!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w), + "Mask tensor spatial dimensions (", mask_c.size(2), ", ", mask_c.size(3), + ") must match calculated output dimensions (", out_h, ", ", out_w, ")"); + TORCH_CHECK(out_h > 0 && out_w > 0, + "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w); + + auto columns = at::empty({in_channels * weight_h * weight_w, batch * out_h * out_w}, input_c.options()); + + id inputBuffer = getMTLBufferStorage(input_c); + id offsetBuffer = getMTLBufferStorage(offset_c); + id maskBuffer = use_mask ? getMTLBufferStorage(mask_c) : nil; + id outputBuffer = getMTLBufferStorage(columns); + + id device = MPSDevice::getInstance()->device(); + std::string kernelName = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type()); + id pipelineState = mps::visionPipelineState(device, kernelName); + + int num_kernels = in_channels * out_h * out_w * batch; + NSUInteger threadsPerThreadgroup = pipelineState.maxTotalThreadsPerThreadgroup; + NSUInteger threadgroups = (num_kernels + threadsPerThreadgroup - 1) / threadsPerThreadgroup; + MTLSize threadGroupSize = MTLSizeMake(threadsPerThreadgroup, 1, 1); + MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^{ + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + [computeEncoder setComputePipelineState:pipelineState]; + [computeEncoder setBuffer:inputBuffer offset:0 atIndex:0]; + [computeEncoder setBuffer:offsetBuffer offset:0 atIndex:1]; + [computeEncoder setBuffer:maskBuffer offset:0 atIndex:2]; + [computeEncoder setBytes:&in_h length:sizeof(int) atIndex:3]; + [computeEncoder setBytes:&in_w length:sizeof(int) atIndex:4]; + [computeEncoder setBytes:&weight_h length:sizeof(int) atIndex:5]; + [computeEncoder setBytes:&weight_w length:sizeof(int) atIndex:6]; + [computeEncoder setBytes:&pad_h length:sizeof(int) atIndex:7]; + [computeEncoder setBytes:&pad_w length:sizeof(int) atIndex:8]; + [computeEncoder setBytes:&stride_h length:sizeof(int) atIndex:9]; + [computeEncoder setBytes:&stride_w length:sizeof(int) atIndex:10]; + [computeEncoder setBytes:&dilation_h length:sizeof(int) atIndex:11]; + [computeEncoder setBytes:&dilation_w length:sizeof(int) atIndex:12]; + [computeEncoder setBytes:&batch length:sizeof(int) atIndex:13]; + [computeEncoder setBytes:&in_channels length:sizeof(int) atIndex:14]; + [computeEncoder setBytes:&n_offset_grps length:sizeof(int) atIndex:15]; + [computeEncoder setBytes:&out_h length:sizeof(int) atIndex:16]; + [computeEncoder setBytes:&out_w length:sizeof(int) atIndex:17]; + [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:18]; + [computeEncoder setBuffer:outputBuffer offset:0 atIndex:19]; + + [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; + } + }); + int in_channels_per_grp = in_channels / n_weight_grps; + int out_channels_per_grp = out_channels / n_weight_grps; + auto weight_grouped = weight_c.view({n_weight_grps, out_channels_per_grp, in_channels_per_grp, weight_h, weight_w}); + auto columns_grouped = columns.view({n_weight_grps, + (in_channels * weight_h * weight_w) / n_weight_grps, + batch * out_h * out_w}); + auto weight_reshaped = weight_grouped.reshape({n_weight_grps, out_channels_per_grp, -1}); + auto out_grouped = at::bmm(weight_reshaped, columns_grouped); + auto out = out_grouped.reshape({n_weight_grps * out_channels_per_grp, batch, out_h, out_w}) + .transpose(0, 1); + return out + bias_c.view({1, out_channels, 1, 1}); +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, MPS, m) { + m.impl( + TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"), + TORCH_FN(deform_conv2d_forward_kernel)); +} + +} // namespace ops +} // namespace vision \ No newline at end of file diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index f85546a6c41..2f24c86c6bf 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -91,6 +91,52 @@ inline T bilinear_interpolate( return val; } +template +inline T bilinear_interpolate_deformable_conv2d( + constant T* input, + integer_t height, + integer_t width, + T y, + T x, + uint index /* index for debug only*/) { + if (y <= -1.0 || y >= height || x <= -1.0 || x >= width) { + return 0; + } + integer_t y_low = static_cast(floor(y)); + integer_t x_low = static_cast(floor(x)); + integer_t y_high = y_low + 1; + integer_t x_high = x_low + 1; + + T ly = y - static_cast(y_low); + T lx = x - static_cast(x_low); + T hh = 1.0 - ly; + T hw = 1.0 - lx; + + T v1 = 0; + if (y_low >= 0 && x_low >= 0) + v1 = input[y_low * width + x_low]; + + T v2 = 0; + if (y_low >= 0 && x_high <= width - 1) + v2 = input[y_low * width + x_high]; + + T v3 = 0; + if (y_high <= height - 1 && x_low >= 0) + v3 = input[y_high * width + x_low]; + + T v4 = 0; + if (y_high <= height - 1 && x_high <= width - 1) + v4 = input[y_high * width + x_high]; + + T w1 = hh * hw; + T w2 = hh * lx; + T w3 = ly * hw; + T w4 = ly * lx; + + T val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + return val; +} + template inline void bilinear_interpolate_gradient( integer_t height, @@ -225,6 +271,117 @@ kernel void nms( \ uint2 tgid [[threadgroup_position_in_grid]], \ uint2 tid2 [[thread_position_in_threadgroup]]); + +template +kernel void deformable_im2col_kernel( + constant T* input_ptr [[ buffer(0) ]], + constant T* offset_ptr [[ buffer(1) ]], + constant T* mask_ptr [[ buffer(2) ]], + constant int& height [[ buffer(3) ]], + constant int& width [[ buffer(4) ]], + constant int& weight_h [[ buffer(5) ]], + constant int& weight_w [[ buffer(6) ]], + constant int& pad_h [[ buffer(7) ]], + constant int& pad_w [[ buffer(8) ]], + constant int& stride_h [[ buffer(9) ]], + constant int& stride_w [[ buffer(10)]], + constant int& dilation_h [[ buffer(11)]], + constant int& dilation_w [[ buffer(12)]], + constant int& batch_size [[ buffer(13)]], + constant int& n_in_channels [[ buffer(14)]], + constant int& n_offset_grps [[ buffer(15)]], + constant int& out_h [[ buffer(16)]], + constant int& out_w [[ buffer(17)]], + constant bool& use_mask [[ buffer(18)]], + device T* columns_ptr [[ buffer(19)]], + uint tid [[ thread_position_in_grid ]], + uint tpg [[ threads_per_grid ]]) +{ + int total = out_w * out_h * batch_size * n_in_channels; + int gridSize = tpg; + if (tid >= total) { + return; + } + + int out_x = tid % out_w; + int out_y = (tid / out_w) % out_h; + int out_b = (tid / (out_w * out_h)) % batch_size; + int in_c = tid / (out_w * out_h * batch_size); + int out_c = in_c * weight_h * weight_w; + + int c_per_offset_grp = n_in_channels / n_offset_grps; + int grp_idx = in_c / c_per_offset_grp; + + int col_offset = out_c * (batch_size * out_h * out_w) + + out_b * (out_h * out_w) + + out_y * out_w + out_x; + device T* local_columns_ptr = columns_ptr + col_offset; + + int input_offset = out_b * (n_in_channels * height * width) + + in_c * (height * width); + constant T* local_input_ptr = input_ptr + input_offset; + + int offset_offset = (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w; + constant T* local_offset_ptr = offset_ptr + offset_offset; + + constant T* local_mask_ptr = nullptr; + if (use_mask) { + int mask_offset = (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w; + local_mask_ptr = mask_ptr + mask_offset; + } + + for (int i = 0; i < weight_h; ++i) { + for (int j = 0; j < weight_w; ++j) { + int mask_index = i * weight_w + j; + int offset_index = 2 * mask_index; + + T mask_value = 1; + if (use_mask) { + mask_value = local_mask_ptr[mask_index * (out_h * out_w) + out_y * out_w + out_x]; + } + + T offset_h_val = local_offset_ptr[offset_index * (out_h * out_w) + out_y * out_w + out_x]; + T offset_w_val = local_offset_ptr[(offset_index + 1) * (out_h * out_w) + out_y * out_w + out_x]; + + T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h_val; + T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w_val; + + T interp = bilinear_interpolate_deformable_conv2d(local_input_ptr, height, width, y, x, tid); + + *local_columns_ptr = mask_value * interp; + + local_columns_ptr += batch_size * out_h * out_w; + } + } +} + +#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE) \ +template \ +[[host_name("deformable_im2col_" #DTYPE)]] \ +kernel void deformable_im2col_kernel( \ + constant DTYPE* input_ptr [[ buffer(0) ]], \ + constant DTYPE* offset_ptr [[ buffer(1) ]], \ + constant DTYPE* mask_ptr [[ buffer(2) ]], \ + constant int& height [[ buffer(3) ]], \ + constant int& width [[ buffer(4) ]], \ + constant int& weight_h [[ buffer(5) ]], \ + constant int& weight_w [[ buffer(6) ]], \ + constant int& pad_h [[ buffer(7) ]], \ + constant int& pad_w [[ buffer(8) ]], \ + constant int& stride_h [[ buffer(9) ]], \ + constant int& stride_w [[ buffer(10)]], \ + constant int& dilation_h [[ buffer(11)]], \ + constant int& dilation_w [[ buffer(12)]], \ + constant int& batch_sz [[ buffer(13)]], \ + constant int& n_in_channels[[ buffer(14)]], \ + constant int& n_offset_grps[[ buffer(15)]], \ + constant int& out_h [[ buffer(16)]], \ + constant int& out_w [[ buffer(17)]], \ + constant bool& use_mask [[ buffer(18)]], \ + device DTYPE* columns_ptr [[ buffer(19)]], \ + uint tid [[ thread_position_in_grid ]], \ + uint tpg [[ threads_per_grid ]]); + template kernel void roi_align( constant T * input [[buffer(0)]], @@ -1013,6 +1170,8 @@ kernel void ps_roi_pool_backward( \ REGISTER_NMS_OP(float); REGISTER_NMS_OP(half); +REGISTER_DEFORMABLE_IM2COL_OP(float); +REGISTER_DEFORMABLE_IM2COL_OP(half); REGISTER_ROI_ALIGN_OP(float, int64_t); REGISTER_ROI_ALIGN_OP(half, int64_t); REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t); From 66a65222919c028bf01ab5a0c95fdc79c66cbbc0 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Wed, 23 Apr 2025 23:47:45 +0400 Subject: [PATCH 2/4] use mtl set args --- .../csrc/ops/mps/deform_conv2d_kernel.mm | 25 +++---------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index a9febf0481c..1d390a37f43 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -102,27 +102,10 @@ @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); [computeEncoder setComputePipelineState:pipelineState]; - [computeEncoder setBuffer:inputBuffer offset:0 atIndex:0]; - [computeEncoder setBuffer:offsetBuffer offset:0 atIndex:1]; - [computeEncoder setBuffer:maskBuffer offset:0 atIndex:2]; - [computeEncoder setBytes:&in_h length:sizeof(int) atIndex:3]; - [computeEncoder setBytes:&in_w length:sizeof(int) atIndex:4]; - [computeEncoder setBytes:&weight_h length:sizeof(int) atIndex:5]; - [computeEncoder setBytes:&weight_w length:sizeof(int) atIndex:6]; - [computeEncoder setBytes:&pad_h length:sizeof(int) atIndex:7]; - [computeEncoder setBytes:&pad_w length:sizeof(int) atIndex:8]; - [computeEncoder setBytes:&stride_h length:sizeof(int) atIndex:9]; - [computeEncoder setBytes:&stride_w length:sizeof(int) atIndex:10]; - [computeEncoder setBytes:&dilation_h length:sizeof(int) atIndex:11]; - [computeEncoder setBytes:&dilation_w length:sizeof(int) atIndex:12]; - [computeEncoder setBytes:&batch length:sizeof(int) atIndex:13]; - [computeEncoder setBytes:&in_channels length:sizeof(int) atIndex:14]; - [computeEncoder setBytes:&n_offset_grps length:sizeof(int) atIndex:15]; - [computeEncoder setBytes:&out_h length:sizeof(int) atIndex:16]; - [computeEncoder setBytes:&out_w length:sizeof(int) atIndex:17]; - [computeEncoder setBytes:&use_mask length:sizeof(bool) atIndex:18]; - [computeEncoder setBuffer:outputBuffer offset:0 atIndex:19]; - + at::native::mps::mtl_setArgs(computeEncoder, inputBuffer, offsetBuffer, maskBuffer, + in_h, in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, batch, in_channels, n_offset_grps, out_h, out_w, + use_mask, outputBuffer); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; } }); From 731351672c7899a4d920fb17b9b949ce0717ae8c Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Fri, 13 Jun 2025 23:53:44 +0400 Subject: [PATCH 3/4] resolve pr comments and collate ints to int2 --- .../csrc/ops/mps/deform_conv2d_kernel.mm | 41 ++++++--- torchvision/csrc/ops/mps/mps_kernels.h | 85 +++++++++---------- 2 files changed, 68 insertions(+), 58 deletions(-) diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index 1d390a37f43..4630efe429b 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -35,20 +35,30 @@ TORCH_CHECK(offset_c.ndimension() == 4, "Offset tensor must be 4D"); TORCH_CHECK(!use_mask || mask_c.ndimension() == 4, "Mask tensor must be 4D if use_mask is true"); TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(weight.is_mps(), "weight must be a MPS tensor"); + TORCH_CHECK(offset.is_mps(), "offset must be a MPS tensor"); + TORCH_CHECK(mask.is_mps(), "mask must be a MPS tensor"); + TORCH_CHECK(bias.is_mps(), "bias must be a MPS tensor"); at::DeviceGuard guard(input_c.device()); - int batch = input_c.size(0); - int in_channels = input_c.size(1); - int in_h = input_c.size(2); - int in_w = input_c.size(3); - int weight_h = weight_c.size(2); - int weight_w = weight_c.size(3); - int out_channels = weight_c.size(0); - int ker_h = dilation_h * (weight_h - 1) + 1; - int ker_w = dilation_w * (weight_w - 1) + 1; - int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; - int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + uint32_t batch = input_c.size(0); + uint32_t in_channels = input_c.size(1); + uint32_t in_h = input_c.size(2); + uint32_t in_w = input_c.size(3); + uint32_t weight_h = weight_c.size(2); + uint32_t weight_w = weight_c.size(3); + uint32_t out_channels = weight_c.size(0); + uint32_t ker_h = dilation_h * (weight_h - 1) + 1; + uint32_t ker_w = dilation_w * (weight_w - 1) + 1; + uint32_t out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; + uint32_t out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1; + uint32_t pad_h_u = static_cast(pad_h); + uint32_t pad_w_u = static_cast(pad_w); + uint32_t stride_h_u = static_cast(stride_h); + uint32_t stride_w_u = static_cast(stride_w); + uint32_t dilation_h_u = static_cast(dilation_h); + uint32_t dilation_w_u = static_cast(dilation_w); TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels, "Input channels (", in_channels, @@ -103,8 +113,13 @@ id computeEncoder = mpsStream->commandEncoder(); [computeEncoder setComputePipelineState:pipelineState]; at::native::mps::mtl_setArgs(computeEncoder, inputBuffer, offsetBuffer, maskBuffer, - in_h, in_w, weight_h, weight_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, batch, in_channels, n_offset_grps, out_h, out_w, + std::array{in_h, in_w}, + std::array{weight_h, weight_w}, + std::array{pad_h_u, pad_w_u}, + std::array{stride_h_u, stride_w_u}, + std::array{dilation_h_u, dilation_w_u}, + batch, in_channels, n_offset_grps, + std::array{out_h, out_w}, use_mask, outputBuffer); [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize]; } diff --git a/torchvision/csrc/ops/mps/mps_kernels.h b/torchvision/csrc/ops/mps/mps_kernels.h index 2f24c86c6bf..35c60fa0064 100644 --- a/torchvision/csrc/ops/mps/mps_kernels.h +++ b/torchvision/csrc/ops/mps/mps_kernels.h @@ -277,28 +277,29 @@ kernel void deformable_im2col_kernel( constant T* input_ptr [[ buffer(0) ]], constant T* offset_ptr [[ buffer(1) ]], constant T* mask_ptr [[ buffer(2) ]], - constant int& height [[ buffer(3) ]], - constant int& width [[ buffer(4) ]], - constant int& weight_h [[ buffer(5) ]], - constant int& weight_w [[ buffer(6) ]], - constant int& pad_h [[ buffer(7) ]], - constant int& pad_w [[ buffer(8) ]], - constant int& stride_h [[ buffer(9) ]], - constant int& stride_w [[ buffer(10)]], - constant int& dilation_h [[ buffer(11)]], - constant int& dilation_w [[ buffer(12)]], - constant int& batch_size [[ buffer(13)]], - constant int& n_in_channels [[ buffer(14)]], - constant int& n_offset_grps [[ buffer(15)]], - constant int& out_h [[ buffer(16)]], - constant int& out_w [[ buffer(17)]], - constant bool& use_mask [[ buffer(18)]], - device T* columns_ptr [[ buffer(19)]], + constant int2& input_size [[ buffer(3) ]], // (height, width) + constant int2& weight_size [[ buffer(4) ]], // (weight_h, weight_w) + constant int2& pad [[ buffer(5) ]], // (pad_h, pad_w) + constant int2& stride [[ buffer(6) ]], // (stride_h, stride_w) + constant int2& dilation [[ buffer(7) ]], // (dilation_h, dilation_w) + constant int& batch_size [[ buffer(8) ]], + constant int& n_in_channels [[ buffer(9) ]], + constant int& n_offset_grps [[ buffer(10)]], + constant int2& out_size [[ buffer(11)]], // (out_h, out_w) + constant bool& use_mask [[ buffer(12)]], + device T* columns_ptr [[ buffer(13)]], uint tid [[ thread_position_in_grid ]], - uint tpg [[ threads_per_grid ]]) + uint tpg [[ threads_per_grid ]] +) { + int height = input_size.x, width = input_size.y; + int weight_h = weight_size.x, weight_w = weight_size.y; + int pad_h = pad.x, pad_w = pad.y; + int stride_h = stride.x, stride_w = stride.y; + int dilation_h = dilation.x, dilation_w = dilation.y; + int out_h = out_size.x, out_w = out_size.y; + int total = out_w * out_h * batch_size * n_in_channels; - int gridSize = tpg; if (tid >= total) { return; } @@ -355,32 +356,26 @@ kernel void deformable_im2col_kernel( } } -#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE) \ -template \ -[[host_name("deformable_im2col_" #DTYPE)]] \ -kernel void deformable_im2col_kernel( \ - constant DTYPE* input_ptr [[ buffer(0) ]], \ - constant DTYPE* offset_ptr [[ buffer(1) ]], \ - constant DTYPE* mask_ptr [[ buffer(2) ]], \ - constant int& height [[ buffer(3) ]], \ - constant int& width [[ buffer(4) ]], \ - constant int& weight_h [[ buffer(5) ]], \ - constant int& weight_w [[ buffer(6) ]], \ - constant int& pad_h [[ buffer(7) ]], \ - constant int& pad_w [[ buffer(8) ]], \ - constant int& stride_h [[ buffer(9) ]], \ - constant int& stride_w [[ buffer(10)]], \ - constant int& dilation_h [[ buffer(11)]], \ - constant int& dilation_w [[ buffer(12)]], \ - constant int& batch_sz [[ buffer(13)]], \ - constant int& n_in_channels[[ buffer(14)]], \ - constant int& n_offset_grps[[ buffer(15)]], \ - constant int& out_h [[ buffer(16)]], \ - constant int& out_w [[ buffer(17)]], \ - constant bool& use_mask [[ buffer(18)]], \ - device DTYPE* columns_ptr [[ buffer(19)]], \ - uint tid [[ thread_position_in_grid ]], \ - uint tpg [[ threads_per_grid ]]); +#define REGISTER_DEFORMABLE_IM2COL_OP(DTYPE) \ +template \ +[[host_name("deformable_im2col_" #DTYPE)]] \ +kernel void deformable_im2col_kernel( \ + constant DTYPE* input_ptr [[ buffer(0) ]], \ + constant DTYPE* offset_ptr [[ buffer(1) ]], \ + constant DTYPE* mask_ptr [[ buffer(2) ]], \ + constant int2& input_size [[ buffer(3) ]], /* (h, w) */ \ + constant int2& weight_size [[ buffer(4) ]], /* (h, w) */ \ + constant int2& pad [[ buffer(5) ]], /* (h, w) */ \ + constant int2& stride [[ buffer(6) ]], /* (h, w) */ \ + constant int2& dilation [[ buffer(7) ]], /* (h, w) */ \ + constant int& batch_size [[ buffer(8) ]], \ + constant int& n_in_channels [[ buffer(9) ]], \ + constant int& n_offset_grps [[ buffer(10)]], \ + constant int2& out_size [[ buffer(11)]], /* (h, w) */ \ + constant bool& use_mask [[ buffer(12)]], \ + device DTYPE* columns_ptr [[ buffer(13)]], \ + uint tid [[ thread_position_in_grid ]], \ + uint tpg [[ threads_per_grid ]]); template kernel void roi_align( From f48400998c33407efc9dea9db0772016acc22daa Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Thu, 19 Jun 2025 21:30:17 +0400 Subject: [PATCH 4/4] linter, thought I ran it :/ --- test/test_ops.py | 39 +++++++++++++------ .../csrc/ops/mps/deform_conv2d_kernel.mm | 2 +- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index c7c415e0ab3..eeed3345834 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1215,19 +1215,37 @@ def test_deform_conv2d_opcheck(dtype, device, requires_grad): out_h = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1 out_w = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1 x = torch.randn(batch_size, channels_in, height, width, dtype=dtype, device=device, requires_grad=requires_grad) - offset = torch.randn(batch_size, 2 * kernel_size[0] * kernel_size[1], out_h, out_w, - dtype=dtype, device=device, requires_grad=requires_grad) - weight = torch.randn(out_channels, channels_in // groups, kernel_size[0], kernel_size[1], - dtype=dtype, device=device, requires_grad=requires_grad) - bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad) - use_mask = True - mask = torch.sigmoid(torch.randn( + offset = torch.randn( batch_size, - kernel_size[0] * kernel_size[1], + 2 * kernel_size[0] * kernel_size[1], out_h, out_w, - dtype=dtype, device=device, requires_grad=requires_grad - )) + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + weight = torch.randn( + out_channels, + channels_in // groups, + kernel_size[0], + kernel_size[1], + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad) + use_mask = True + mask = torch.sigmoid( + torch.randn( + batch_size, + kernel_size[0] * kernel_size[1], + out_h, + out_w, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + ) kwargs = { "offset": offset, "weight": weight, @@ -1246,7 +1264,6 @@ def test_deform_conv2d_opcheck(dtype, device, requires_grad): optests.opcheck(torch.ops.torchvision.deform_conv2d, args=(x,), kwargs=kwargs) - class TestFrozenBNT: def test_frozenbatchnorm2d_repr(self): num_features = 32 diff --git a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm index 4630efe429b..63371365655 100644 --- a/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm +++ b/torchvision/csrc/ops/mps/deform_conv2d_kernel.mm @@ -146,4 +146,4 @@ } } // namespace ops -} // namespace vision \ No newline at end of file +} // namespace vision