87 changes: 39 additions & 48 deletions jax_triton/triton_lib.py
@@ -20,7 +20,6 @@
import copy
import dataclasses
import functools
-import inspect
import os
import pprint
import tempfile
@@ -163,17 +162,19 @@ def aval_size_bytes(aval):


def get_cuda_backend(device, compute_capability):
-  target = cb.GPUTarget('cuda', compute_capability, 32)
+  target = cb.GPUTarget("cuda", compute_capability, 32)
  backend = cb.CUDABackend(target)
  return backend


def get_hip_backend(device, compute_capability):
  arch = triton_kernel_call_lib.get_arch_details(device)
  arch = arch.split(":")[0]
-  target = hb.GPUTarget('hip', arch, 64)
+  target = hb.GPUTarget("hip", arch, 64)
  backend = hb.HIPBackend(target)
  return backend


@dataclasses.dataclass
class CompilationResult:
  binary: str
@@ -183,32 +184,31 @@ class CompilationResult:
  ttgir: str | None
  llir: str | None


def compile_ttir_inplace(
    ttir,
    backend: [cb.CUDABackend | hb.HIPBackend],
    options: [cb.CUDAOptions | hb.HIPOptions],
    compute_capability,
-    platform
+    platform,
):
-  if platform == 'cuda':
+  if platform == "cuda":
    return compile_ttir_to_ptx_inplace(
-      ttir,
-      backend,
-      options,
-      compute_capability,
+        ttir,
+        backend,
+        options,
+        compute_capability,
    )

-  elif platform == 'rocm':
+  elif platform == "rocm":
    return compile_ttir_to_hsaco_inplace(
-      ttir,
-      backend,
-      options,
-      compute_capability,
+        ttir,
+        backend,
+        options,
+        compute_capability,
    )
  else:
-    raise ValueError(
-        "Unsupported device."
-    )
+    raise ValueError("Unsupported device.")
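[Note: the dispatcher above is driven purely by the `platform` string; everything backend-specific lives in the two `compile_ttir_to_*_inplace` helpers. A minimal usage sketch for the CUDA path follows. The compute capability of 80 and the options dict are illustrative assumptions, and `parse_options` (a standard Triton backend method for filling in defaults) may differ across Triton versions:

# Hedged sketch: compile an already-built Triton IR module for CUDA.
# Assumes `ttir` is an MLIR module produced by Triton's frontend and that
# the current device has compute capability 80 (illustrative value).
backend = get_cuda_backend(device, 80)
options = backend.parse_options({"num_warps": 4, "num_stages": 3})
result = compile_ttir_inplace(ttir, backend, options, 80, "cuda")
ptx = result.binary  # per the CompilationResult dataclass above
]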


def compile_ttir_to_ptx_inplace(
@@ -273,6 +273,7 @@ def compile_ttir_to_ptx_inplace(
      llir=llir,
  )


def compile_ttir_to_hsaco_inplace(
    ttir,
    hip_backend: hb.HIPBackend,
@@ -284,22 +285,14 @@
  try:
    metadata = {}
    opt_ttir = hip_backend.make_ttir(ttir, metadata, hip_options)
-    ttgir = hip_backend.make_ttgir(
-      opt_ttir,
-      metadata,
-      hip_options
-    )
+    ttgir = hip_backend.make_ttgir(opt_ttir, metadata, hip_options)
  except RuntimeError as e:
    ttir.dump()
    raise ValueError("TTIR->TTGIR pass failed!") from e
  if hip_options.debug:
    print(ttgir)
  try:
-    llir = hip_backend.make_llir(
-      ttgir,
-      metadata,
-      hip_options
-    )
+    llir = hip_backend.make_llir(ttgir, metadata, hip_options)
  except RuntimeError as e:
    ttgir.dump()
    raise ValueError("TTGIR->LLIR pass failed!") from e
@@ -332,6 +325,7 @@ def compile_ttir_to_hsaco_inplace(
      llir=llir,
  )


_COMPILED_KERNEL_CACHE = {} # TODO(cjfj): Convert to LRU cache?


@@ -370,30 +364,20 @@ def get_or_create_triton_kernel(
  # We assume that all arrays are aligned to 16 bytes, and Triton may use this
  # assumption, unless array args are included in the `do_not_specialize` list.
  alignments = [16] * len(arg_dtypes)
-  for i, _, value in scalar_args:
-    alignments[i] = value
-  specialize_extra = backend.get_arg_specialization
-  if specialize_impl := getattr(triton.runtime.jit, "specialize_impl", None):
-    # TODO(slebedev): Remove this branch once Triton 3.3 is released.
-    specialize_impl = functools.partial(
-        specialize_impl, specialize_extra=specialize_extra
-    )
-  else:
-    # TODO(rdyro): Remove unnecessary checks with 3.3.0 > release
-    create_specialize_impl = triton.runtime.jit.create_specialize_impl
-    if len(inspect.signature(create_specialize_impl).parameters) == 0:
-      # handle Triton 3.3.0 release
-      specialize_impl = functools.partial(
-          create_specialize_impl(), specialize_extra=specialize_extra
-      )
-    else:
-      # latest Triton head
-      specialize_impl = create_specialize_impl(specialize_extra)
+  for i, _, _ in scalar_args:
+    alignments[i] = 0
+  specialize_impl = _triton.native_specialize_impl
+  is_const = False
+  do_specialize = True
  specialization = [
      specialize_impl(
+          backend,
          types.SimpleNamespace(
              data_ptr=lambda: alignment, dtype=arg_dtype.removeprefix("*")
          ),
+          is_const,
+          do_specialize,
+          alignment > 0,
      )
      for arg_dtype, alignment in zip(arg_dtypes, alignments)
  ]
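[Note: the comprehension above hands Triton's specializer a duck-typed stand-in for each argument; only `data_ptr()` (used to derive alignment/divisibility hints) and `dtype` are consulted. Scalars get alignment 0, so the trailing `alignment > 0` flag disables alignment specialization for them. A rough sketch of one element of `specialization`, mirroring the call above (the values are illustrative):

# One specialization call, spelled out. The SimpleNamespace only needs
# the two attributes the specializer reads: data_ptr() and dtype.
fake_arg = types.SimpleNamespace(data_ptr=lambda: 16, dtype="fp32")
spec = specialize_impl(
    backend,
    fake_arg,
    False,  # is_const: the argument is not a tl.constexpr
    True,   # do_specialize: the argument is not in do_not_specialize
    True,   # align: alignment (16) > 0, so divisibility hints are allowed
)
]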
@@ -683,10 +667,11 @@ def prune_configs(configs, named_args, **kwargs):
      custom_call_target_name,
      api_version=2,
      backend_config=zlib.compress(call_proto),
-      operand_output_aliases=dict(input_output_aliases)
+      operand_output_aliases=dict(input_output_aliases),
  )
  return rule(ctx, *array_args)


mlir.register_lowering(
    triton_kernel_call_p,
    functools.partial(triton_kernel_call_lowering, get_cuda_backend),
@@ -708,6 +693,7 @@ def triton_kernel_call_raise_on_jvp(*args, **kwargs):
"differentiation rule for your kernel."
)


ad.primitive_jvps[triton_kernel_call_p] = triton_kernel_call_raise_on_jvp


@@ -719,6 +705,7 @@ def triton_kernel_call_raise_on_vmap(*args, **kwargs):
"your kernel."
)


batching.primitive_batchers[triton_kernel_call_p] = (
    triton_kernel_call_raise_on_vmap
)
@@ -737,7 +724,11 @@ def dtype(self) -> np.dtype:

def triton_call(
    *args: jax.Array | bool | int | float | np.float32,
-    kernel: triton.JITFunction | triton.runtime.Heuristics | triton.runtime.Autotuner,
+    kernel: (
+        triton.JITFunction
+        | triton.runtime.Heuristics
+        | triton.runtime.Autotuner
+    ),
    out_shape: ShapeDtype | Sequence[ShapeDtype],
    grid: GridOrLambda,
    name: str = "",