Skip to content

Commit 89a763d

Browse files
committed
Fix bugs in matrix instruction generation (round wave tiles up; group instructions per dtype and per GPU)
1 parent 548ed1c commit 89a763d

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

tensorflow/tools/hipblaslt/tensile_config_generator.py

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import copy
66
import os
77
import subprocess
8+
import math
89
# Paths to the input and output files
910
parser = argparse.ArgumentParser(description="""Generate Tensile config file""")
1011

@@ -163,10 +164,9 @@ def extract_dtype(match):
163164
unique_gemms_subgroups[i%args.gpus] = [(k, v)]
164165

165166

166-
matmul_instructions = {}
167167
for gpu_idx, unique_gemms_subgroup in enumerate(unique_gemms_subgroups):
168168
gemm_group = {}
169-
matrix_instructions = {}
169+
matmul_instructions = {}
170170
if unique_gemms_subgroup is None:
171171
continue
172172

@@ -182,33 +182,36 @@ def extract_dtype(match):
182182
size = extract_problem_size(match)
183183
dtype = extract_dtype(match)
184184
mfma_instruction = instruction_map(dtype)
185+
dtype_str = json.dumps(dtype)
185186
if mfma_instruction is None:
186187
continue
187188
for m_tiles in range(1, CU+1):
188189
if size[0] // m_tiles > 256:
189190
continue
190-
wave_tile_m = size[0] // m_tiles // mfma_instruction[0]
191-
if wave_tile_m<=0:
191+
192+
wave_tile_m = math.ceil(size[0] // m_tiles / mfma_instruction[0])
193+
if wave_tile_m <= 0:
192194
continue
193195
for n_tiles in range(1, CU+1):
194196
if size[1] // n_tiles > 256:
195197
continue
196-
wave_tile_n = size[1] // n_tiles // mfma_instruction[1]
197-
if wave_tile_n<=0:
198+
wave_tile_n = math.ceil(size[1] // n_tiles / mfma_instruction[1])
199+
if wave_tile_n <= 0:
198200
continue
199201
matmul_instruction = mfma_instruction+[1, 1, 1, 1, 1]
200202
for k in range(3):
201-
if wave_tile_m//(2**k) > 0:
203+
if wave_tile_m // (2**k) > 0:
202204
matmul_instruction[-4] = wave_tile_m//(2**k)
203205
matmul_instruction[-2] = 2**k
204206
for l in range(3):
205-
if wave_tile_n//(2**l) > 0:
207+
if wave_tile_n // (2**l) > 0:
206208
matmul_instruction[-3] = wave_tile_n//(2**l)
207209
matmul_instruction[-1] = 2**l
208-
matmul_instructions[str(matmul_instruction)] = matmul_instruction
209-
210+
if dtype_str not in matmul_instructions:
211+
matmul_instructions[dtype_str] = dict()
212+
matmul_instructions[dtype_str][str(matmul_instruction)] = matmul_instruction
213+
210214

211-
dtype_str = json.dumps(dtype)
212215
if dtype_str in gemm_group:
213216
gemm_group[dtype_str].append({'Exact': size})
214217
else:
@@ -239,7 +242,7 @@ def extract_dtype(match):
239242
data["BenchmarkProblems"][i][1]["BenchmarkFinalParameters"][0]["ProblemSizes"] = gemm_group[dtype_str]
240243
for item in data["BenchmarkProblems"][i][1]["ForkParameters"]:
241244
if "MatrixInstruction" in item:
242-
item["MatrixInstruction"] = [list(item) for item in matmul_instructions.values()]
245+
item["MatrixInstruction"] = [list(item) for item in matmul_instructions[dtype_str].values()]
243246
if "WorkGroupMappingXCCGroup" in item:
244247
item["WorkGroupMappingXCCGroup"] = [CU]
245248
if "WorkGroupMappingXCC" in item:

0 commit comments

Comments
 (0)