|
29 | 29 | (12, 0): 24,
|
30 | 30 | }
|
31 | 31 |
|
| 32 | +TRITON_VERSION = version.parse(triton.__version__) |
| 33 | + |
32 | 34 |
|
33 | 35 | @functools.lru_cache
|
34 | 36 | def get_device_props(device=None):
|
@@ -57,86 +59,9 @@ def supports_tma():
|
57 | 59 | if not ret:
|
58 | 60 | return False
|
59 | 61 |
|
60 |
| - TRITON_VERSION = version.parse(triton.__version__) |
61 |
| - VALID_VERSION = version.parse('3.2.0') |
62 |
| - return TRITON_VERSION >= VALID_VERSION |
63 |
| - |
64 |
| - |
65 |
| -# Copy from: |
66 |
| -# https://github.com/triton-lang/triton/blob/main/python/triton/tools/experimental_descriptor.py |
67 |
| -class TmaDescKernelParam: |
68 |
| - TMA_DESC_SIZE = 128 |
69 |
| - |
70 |
| - def __init__(self): |
71 |
| - self.desc = torch.empty(self.TMA_DESC_SIZE, dtype=torch.uint8, device='cpu') |
72 |
| - |
73 |
| - def fill_(self, ptr, dims, block_dims, element_size): |
74 |
| - assert len(dims) == len(block_dims) |
75 |
| - assert 1 <= len(dims) <= 2 |
76 |
| - assert self.desc.data_ptr() % 64 == 0 |
77 |
| - |
78 |
| - if len(dims) == 1: |
79 |
| - triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dims[0], block_dims[0], element_size, |
80 |
| - self.desc.data_ptr()) |
81 |
| - else: |
82 |
| - triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dims[0], dims[1], block_dims[0], |
83 |
| - block_dims[1], element_size, self.desc.data_ptr()) |
84 |
| - |
85 |
| - # Return a CUtensorMap* pointer in host memory |
86 |
| - def tma_desc_cpu_ptr(self): |
87 |
| - return self.desc.data_ptr() |
88 |
| - |
89 |
| - |
90 |
| -# Copy from: |
91 |
| -# https://github.com/triton-lang/triton/blob/main/python/triton/tools/experimental_descriptor.py |
92 |
| -def create_1d_tma_descriptor_custom(ptr, dim, block_dim, element_size): |
93 |
| - desc = TmaDescKernelParam() |
94 |
| - desc.fill_(ptr, [dim], [block_dim], element_size) |
95 |
| - return desc |
96 |
| - |
97 |
| - |
98 |
| -# Copy from: |
99 |
| -# https://github.com/triton-lang/triton/blob/main/python/triton/tools/experimental_descriptor.py |
100 |
| -def create_2d_tma_descriptor_custom(ptr, dim1, dim0, block_dim1, block_dim0, element_size): |
101 |
| - desc = TmaDescKernelParam() |
102 |
| - desc.fill_(ptr, [dim1, dim0], [block_dim1, block_dim0], element_size) |
103 |
| - return desc |
104 |
| - |
105 |
| - |
106 |
| -try: |
107 |
| - from triton.tools.experimental_descriptor import create_1d_tma_descriptor, create_2d_tma_descriptor # noqa |
108 |
| -except BaseException: |
109 |
| - create_1d_tma_descriptor = create_1d_tma_descriptor_custom |
110 |
| - create_2d_tma_descriptor = create_2d_tma_descriptor_custom |
111 |
| - |
112 |
| - |
113 |
| -class TmaAutoTuneHelper: |
114 |
| - |
115 |
| - # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 |
116 |
| - class KernelParamWrapper: |
117 |
| - |
118 |
| - def __init__(self, desc): |
119 |
| - self.desc = desc |
120 |
| - |
121 |
| - def tma_desc_cpu_ptr(self): |
122 |
| - return self.desc.data_ptr() |
123 |
| - |
124 |
| - TMA_SIZE = 128 |
125 |
| - |
126 |
| - def __init__(self): |
127 |
| - self.fill_2d_tma_descriptor_inner = (triton.runtime.driver.active.utils.fill_2d_tma_descriptor) |
128 |
| - self.descriptors = {} |
129 |
| - |
130 |
| - # Call this method outside of the lambda function for grid size |
131 |
| - def init_tma_descriptor(self, name): |
132 |
| - self.descriptors[name] = torch.empty(TmaAutoTuneHelper.TMA_SIZE, device='cpu', dtype=torch.int8) |
| 62 | + VALID_VERSION = version.parse('3.4.0') |
| 63 | + return TRITON_VERSION == VALID_VERSION |
133 | 64 |
|
134 |
| - # Call this method inside the lambda function for grid size |
135 |
| - def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size): |
136 |
| - desc_x = self.descriptors[name] |
137 |
| - assert desc_x.data_ptr() % 64 == 0 |
138 |
| - self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()) |
139 | 65 |
|
140 |
| - def get_tma_descriptor_kernel_param(self, name): |
141 |
| - assert self.descriptors[name] is not None |
142 |
| - return self.KernelParamWrapper(self.descriptors[name]) |
| 66 | +if supports_tma(): |
| 67 | + from triton.tools.tensor_descriptor import TensorDescriptor # noqa: F401 |
0 commit comments