55import copy
66import os
77import subprocess
8+ import math
89# Paths to the input and output files
910parser = argparse .ArgumentParser (description = """Generate Tensile config file""" )
1011
@@ -163,10 +164,9 @@ def extract_dtype(match):
163164 unique_gemms_subgroups [i % args .gpus ] = [(k , v )]
164165
165166
166- matmul_instructions = {}
167167for gpu_idx , unique_gemms_subgroup in enumerate (unique_gemms_subgroups ):
168168 gemm_group = {}
169- matrix_instructions = {}
169+ matmul_instructions = {}
170170 if unique_gemms_subgroup is None :
171171 continue
172172
@@ -182,33 +182,36 @@ def extract_dtype(match):
182182 size = extract_problem_size (match )
183183 dtype = extract_dtype (match )
184184 mfma_instruction = instruction_map (dtype )
185+ dtype_str = json .dumps (dtype )
185186 if mfma_instruction is None :
186187 continue
187188 for m_tiles in range (1 , CU + 1 ):
188189 if size [0 ] // m_tiles > 256 :
189190 continue
190- wave_tile_m = size [0 ] // m_tiles // mfma_instruction [0 ]
191- if wave_tile_m <= 0 :
191+
192+ wave_tile_m = math .ceil (size [0 ] // m_tiles / mfma_instruction [0 ])
193+ if wave_tile_m <= 0 :
192194 continue
193195 for n_tiles in range (1 , CU + 1 ):
194196 if size [1 ] // n_tiles > 256 :
195197 continue
196- wave_tile_n = size [1 ] // n_tiles // mfma_instruction [1 ]
197- if wave_tile_n <= 0 :
198+ wave_tile_n = math . ceil ( size [1 ] // n_tiles / mfma_instruction [1 ])
199+ if wave_tile_n <= 0 :
198200 continue
199201 matmul_instruction = mfma_instruction + [1 , 1 , 1 , 1 , 1 ]
200202 for k in range (3 ):
201- if wave_tile_m // (2 ** k ) > 0 :
203+ if wave_tile_m // (2 ** k ) > 0 :
202204 matmul_instruction [- 4 ] = wave_tile_m // (2 ** k )
203205 matmul_instruction [- 2 ] = 2 ** k
204206 for l in range (3 ):
205- if wave_tile_n // (2 ** l ) > 0 :
207+ if wave_tile_n // (2 ** l ) > 0 :
206208 matmul_instruction [- 3 ] = wave_tile_n // (2 ** l )
207209 matmul_instruction [- 1 ] = 2 ** l
208- matmul_instructions [str (matmul_instruction )] = matmul_instruction
209-
210+ if dtype_str not in matmul_instructions :
211+ matmul_instructions [dtype_str ] = dict ()
212+ matmul_instructions [dtype_str ][str (matmul_instruction )] = matmul_instruction
213+
210214
211- dtype_str = json .dumps (dtype )
212215 if dtype_str in gemm_group :
213216 gemm_group [dtype_str ].append ({'Exact' : size })
214217 else :
@@ -239,7 +242,7 @@ def extract_dtype(match):
239242 data ["BenchmarkProblems" ][i ][1 ]["BenchmarkFinalParameters" ][0 ]["ProblemSizes" ] = gemm_group [dtype_str ]
240243 for item in data ["BenchmarkProblems" ][i ][1 ]["ForkParameters" ]:
241244 if "MatrixInstruction" in item :
242- item ["MatrixInstruction" ] = [list (item ) for item in matmul_instructions .values ()]
245+ item ["MatrixInstruction" ] = [list (item ) for item in matmul_instructions [ dtype_str ] .values ()]
243246 if "WorkGroupMappingXCCGroup" in item :
244247 item ["WorkGroupMappingXCCGroup" ] = [CU ]
245248 if "WorkGroupMappingXCC" in item :
0 commit comments