#include <ATen/ATen.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_kernels.h"

namespace vision {
namespace ops {

namespace {

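// Forward pass of deformable convolution (deform_conv2d) on MPS.
// The computation follows the usual im2col approach: a Metal kernel gathers
// bilinearly interpolated input samples at the learned offsets into a column
// buffer, and the convolution itself then reduces to a batched matrix
// multiply against the (grouped) weight tensor.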
at::Tensor deform_conv2d_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t n_weight_grps,
    int64_t n_offset_grps,
    bool use_mask) {
  using namespace at::native::mps;
  at::Tensor input_c = input.contiguous();
  at::Tensor weight_c = weight.contiguous();
  at::Tensor offset_c = offset.contiguous();
  at::Tensor mask_c = mask.contiguous();
  at::Tensor bias_c = bias.contiguous();

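  // Basic rank and device validation; the mask is only checked when
  // modulated deformable convolution (use_mask) is requested.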
  TORCH_CHECK(input_c.ndimension() == 4, "Input tensor must be 4D");
  TORCH_CHECK(weight_c.ndimension() == 4, "Weight tensor must be 4D");
  TORCH_CHECK(offset_c.ndimension() == 4, "Offset tensor must be 4D");
  TORCH_CHECK(!use_mask || mask_c.ndimension() == 4, "Mask tensor must be 4D if use_mask is true");
  TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor");
  TORCH_CHECK(weight.is_mps(), "weight must be a MPS tensor");
  TORCH_CHECK(offset.is_mps(), "offset must be a MPS tensor");
  TORCH_CHECK(mask.is_mps(), "mask must be a MPS tensor");
  TORCH_CHECK(bias.is_mps(), "bias must be a MPS tensor");

  at::DeviceGuard guard(input_c.device());

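  // Collect dimensions. The output spatial size uses the standard convolution
  // formula with the dilated ("effective") kernel extent ker_h x ker_w.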
  uint32_t batch = input_c.size(0);
  uint32_t in_channels = input_c.size(1);
  uint32_t in_h = input_c.size(2);
  uint32_t in_w = input_c.size(3);
  uint32_t weight_h = weight_c.size(2);
  uint32_t weight_w = weight_c.size(3);
  uint32_t out_channels = weight_c.size(0);
  uint32_t ker_h = dilation_h * (weight_h - 1) + 1;
  uint32_t ker_w = dilation_w * (weight_w - 1) + 1;
  uint32_t out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
  uint32_t out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
  uint32_t pad_h_u = static_cast<uint32_t>(pad_h);
  uint32_t pad_w_u = static_cast<uint32_t>(pad_w);
  uint32_t stride_h_u = static_cast<uint32_t>(stride_h);
  uint32_t stride_w_u = static_cast<uint32_t>(stride_w);
  uint32_t dilation_h_u = static_cast<uint32_t>(dilation_h);
  uint32_t dilation_w_u = static_cast<uint32_t>(dilation_w);

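  // Consistency checks between the input/weight/offset/mask shapes, the
  // group counts, and the computed output size.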
  TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels,
    "Input channels (", in_channels,
    ") must equal weight.size(1) * n_weight_grps (", weight_c.size(1), " * ", n_weight_grps, ")");
  TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0,
    "Weight tensor's out channels (", weight_c.size(0),
    ") must be divisible by n_weight_grps (", n_weight_grps, ")");
  TORCH_CHECK(offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w,
    "Offset tensor shape[1] is invalid: got ", offset_c.size(1),
    ", expected ", n_offset_grps * 2 * weight_h * weight_w);
  TORCH_CHECK(!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w,
    "Mask tensor shape[1] is invalid: got ", mask_c.size(1),
    ", expected ", n_offset_grps * weight_h * weight_w);
  TORCH_CHECK(in_channels % n_offset_grps == 0,
    "Input tensor channels (", in_channels,
    ") must be divisible by n_offset_grps (", n_offset_grps, ")");
  TORCH_CHECK(offset_c.size(0) == batch,
    "Offset tensor batch size (", offset_c.size(0),
    ") must match input tensor batch size (", batch, ")");
  TORCH_CHECK(offset_c.size(2) == out_h && offset_c.size(3) == out_w,
    "Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3),
    ") must match calculated output dimensions (", out_h, ", ", out_w, ")");
  TORCH_CHECK(!use_mask || mask_c.size(0) == batch,
    "Mask tensor batch size (", mask_c.size(0),
    ") must match input tensor batch size (", batch, ")");
  TORCH_CHECK(!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w),
    "Mask tensor spatial dimensions (", mask_c.size(2), ", ", mask_c.size(3),
    ") must match calculated output dimensions (", out_h, ", ", out_w, ")");
  TORCH_CHECK(out_h > 0 && out_w > 0,
    "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w);

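  // im2col buffer: one row per (input channel, kernel position) pair and one
  // column per (batch element, output location). The Metal kernel below fills
  // it with bilinearly sampled input values at the deformed positions.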
  auto columns = at::empty({in_channels * weight_h * weight_w, batch * out_h * out_w}, input_c.options());

  id<MTLBuffer> inputBuffer  = getMTLBufferStorage(input_c);
  id<MTLBuffer> offsetBuffer = getMTLBufferStorage(offset_c);
  id<MTLBuffer> maskBuffer   = use_mask ? getMTLBufferStorage(mask_c) : nil;
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(columns);

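  // The Metal kernel is specialized per scalar type; look up the matching
  // compute pipeline state via the vision pipeline helper.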
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  std::string kernelName = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type());
  id<MTLComputePipelineState> pipelineState = mps::visionPipelineState(device, kernelName);

  int num_kernels = in_channels * out_h * out_w * batch;
  NSUInteger threadsPerThreadgroup = pipelineState.maxTotalThreadsPerThreadgroup;
  NSUInteger threadgroups = (num_kernels + threadsPerThreadgroup - 1) / threadsPerThreadgroup;
  MTLSize threadGroupSize = MTLSizeMake(threadsPerThreadgroup, 1, 1);
  MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1);

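  // Encode the deformable_im2col kernel on the current MPS stream, launching
  // one thread per (batch, input channel, output position) triple
  // (num_kernels in total, rounded up to whole threadgroups).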
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^{
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      [computeEncoder setComputePipelineState:pipelineState];
      at::native::mps::mtl_setArgs(computeEncoder, inputBuffer, offsetBuffer, maskBuffer,
                                   std::array<uint32_t, 2>{in_h, in_w},
                                   std::array<uint32_t, 2>{weight_h, weight_w},
                                   std::array<uint32_t, 2>{pad_h_u, pad_w_u},
                                   std::array<uint32_t, 2>{stride_h_u, stride_w_u},
                                   std::array<uint32_t, 2>{dilation_h_u, dilation_w_u},
                                   batch, in_channels, n_offset_grps,
                                   std::array<uint32_t, 2>{out_h, out_w},
                                   use_mask, outputBuffer);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
    }
  });
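  // With the columns gathered, the convolution reduces to a grouped GEMM:
  // per weight group, multiply [out_channels/G, (in_channels/G)*kh*kw] by
  // [(in_channels/G)*kh*kw, batch*out_h*out_w], then fold the result back
  // into NCHW layout and add the broadcast bias.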
  int in_channels_per_grp = in_channels / n_weight_grps;
  int out_channels_per_grp = out_channels / n_weight_grps;
  auto weight_grouped = weight_c.view({n_weight_grps, out_channels_per_grp, in_channels_per_grp, weight_h, weight_w});
  auto columns_grouped = columns.view({n_weight_grps,
                                      (in_channels * weight_h * weight_w) / n_weight_grps,
                                      batch * out_h * out_w});
  auto weight_reshaped = weight_grouped.reshape({n_weight_grps, out_channels_per_grp, -1});
  auto out_grouped = at::bmm(weight_reshaped, columns_grouped);
  auto out = out_grouped.reshape({n_weight_grps * out_channels_per_grp, batch, out_h, out_w})
              .transpose(0, 1);
  return out + bias_c.view({1, out_channels, 1, 1});
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
      TORCH_FN(deform_conv2d_forward_kernel));
}

} // namespace ops
} // namespace vision