
Commit aa5f54e

unit test compiles
1 parent 1885ff4 commit aa5f54e

19 files changed (+760, -881 lines)

tensorflow/compiler/xla/debug_options_flags.cc

Lines changed: 11 additions & 2 deletions

@@ -74,7 +74,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
 
   // Note: CublasLt will be used for FP8 GEMMs regardless of the value of this
   // flag.
-  opts.set_xla_gpu_enable_cublaslt(false);
+  opts.set_xla_gpu_enable_cublaslt(true);
 
   // TODO(b/258036887): Enable once CUDA Graphs are fully supported.
   opts.set_xla_gpu_cuda_graph_level(0);
@@ -122,7 +122,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_partitioning_algorithm(
       DebugOptions::PARTITIONING_ALGORITHM_NOOP);
 
-  opts.set_xla_gpu_enable_triton_gemm(true);
+  opts.set_xla_gpu_enable_triton_gemm(false);
   opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
   opts.set_xla_gpu_triton_gemm_any(false);
 
@@ -131,6 +131,15 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_enable_while_loop_reduce_scatter_code_motion(false);
 
   opts.set_xla_gpu_collective_inflation_factor(1);
+
+  // Minimum combined size of matrices in matrix multiplication to
+  // be rewritten to cuBLAS or Triton kernel call.
+  // This threshold is a conservative estimate and has been measured
+  // to be always beneficial (up to generally several times faster)
+  // on V100 and H100 GPUs. See openxla/xla #9319 for details.
+  const int64_t kDefaultMinGemmRewriteSize = 100;
+  opts.set_xla_gpu_gemm_rewrite_size_threshold(kDefaultMinGemmRewriteSize);
+
   return opts;
 }
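As a usage note (not part of this commit), the sketch below shows how a client could still opt back into the Triton GEMM rewriter through the DebugOptions proto while keeping the new size threshold. It assumes the usual xla::GetDebugOptionsFromFlags() entry point from debug_options_flags.h; the wrapper function name is illustrative only.

    #include "tensorflow/compiler/xla/debug_options_flags.h"

    // Hypothetical helper: starts from the flag-derived defaults, re-enables the
    // Triton GEMM rewriter (turned off by default in this commit), and keeps the
    // new rewrite-size threshold.
    xla::DebugOptions MakeTritonGemmOptions() {
      xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
      // Re-enable the Triton GEMM rewriter.
      opts.set_xla_gpu_enable_triton_gemm(true);
      // Matmuls whose combined size is below this value are left alone instead
      // of being rewritten to cuBLAS/Triton calls.
      opts.set_xla_gpu_gemm_rewrite_size_threshold(100);
      return opts;
    }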

tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td

Lines changed: 1 addition & 0 deletions

@@ -168,6 +168,7 @@ def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandS
     Arg<LHLO_Buffer, "", [MemWrite]>:$d,
     Arg<Optional<LHLO_Buffer>, "", [MemRead]>:$bias,
     Arg<Optional<LHLO_Buffer>, "", [MemRead, MemWrite]>:$aux,
+    Arg<Optional<LHLO_Buffer>, "", [MemRead, MemWrite]>:$workspace,
     MHLO_DotDimensionNumbers:$dot_dimension_numbers,
     MHLO_PrecisionConfigAttr:$precision_config,
     F64Attr:$alpha_real,

tensorflow/compiler/xla/service/gpu/cublas_cudnn.cc

Lines changed: 0 additions & 6 deletions

@@ -34,14 +34,8 @@ bool IsCublasLtMatmul(const HloInstruction& hlo) {
          hlo.custom_call_target() == kCublasLtMatmulCallTarget;
 }
 
-bool IsCublasLtMatmulF8(const HloInstruction& hlo) {
-  return hlo.opcode() == HloOpcode::kCustomCall &&
-         hlo.custom_call_target() == kCublasLtMatmulF8CallTarget;
-}
-
 const absl::string_view kGemmCallTarget = "__cublas$gemm";
 const absl::string_view kCublasLtMatmulCallTarget = "__cublas$lt$matmul";
-const absl::string_view kCublasLtMatmulF8CallTarget = "__cublas$lt$matmul$f8";
 const absl::string_view kTriangularSolveCallTarget = "__cublas$triangularSolve";
 
 const absl::string_view kCudnnConvBackwardInputCallTarget =
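For context, here is a minimal sketch (an assumption, not code from this commit) of how a caller might use the remaining predicate now that the FP8-specific variant is gone; the helper name and the exact include paths are illustrative.

    #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h"
    #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"

    // Hypothetical helper: counts cublasLt matmul custom calls in a computation.
    // With IsCublasLtMatmulF8 removed, IsCublasLtMatmul is the single check for
    // the __cublas$lt$matmul target.
    int CountCublasLtMatmuls(const xla::HloComputation& computation) {
      int count = 0;
      for (const xla::HloInstruction* instr : computation.instructions()) {
        if (xla::gpu::IsCublasLtMatmul(*instr)) {
          ++count;
        }
      }
      return count;
    }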

tensorflow/compiler/xla/service/gpu/cublas_cudnn.h

Lines changed: 0 additions & 6 deletions

@@ -61,18 +61,12 @@ bool IsLegacyCublasMatmul(const HloInstruction& hlo);
 // Matrix multiplication that calls into cublasLt.
 bool IsCublasLtMatmul(const HloInstruction& hlo);
 
-// Scaled matrix multiplication in FP8. Calls into cublasLt.
-bool IsCublasLtMatmulF8(const HloInstruction& hlo);
-
 // A call to cuBLAS general matrix multiplication API.
 extern const absl::string_view kGemmCallTarget;
 
 // A call to cuBLAS Lt API matrix multiplication.
 extern const absl::string_view kCublasLtMatmulCallTarget;
 
-// A call to cuBLASLt for scaled matrix multiplication in FP8.
-extern const absl::string_view kCublasLtMatmulF8CallTarget;
-
 // A call to cuBLAS for a triangular solve.
 //
 // Like cudnn convolutions, this op returns a tuple (result, scratch_memory).
