From dd7df3bb1fd37608509140ad3ace8e1284bf9b17 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Tue, 16 May 2023 21:46:17 +0000 Subject: [PATCH 1/3] Initial pallas docs --- docs/guides/pallas.md | 84 +++++++++++++++++++++++++++++++ docs/{ => guides}/triton_call.md | 2 +- docs/index.md | 2 +- jax_triton/pallas/registration.py | 12 +++++ mkdocs.yml | 4 +- tests/docs_test.py | 10 ++++ 6 files changed, 111 insertions(+), 3 deletions(-) create mode 100644 docs/guides/pallas.md rename docs/{ => guides}/triton_call.md (64%) create mode 100644 tests/docs_test.py diff --git a/docs/guides/pallas.md b/docs/guides/pallas.md new file mode 100644 index 0000000..dd89ec0 --- /dev/null +++ b/docs/guides/pallas.md @@ -0,0 +1,84 @@ +# Writing custom kernels using Pallas, a JAX kernel sublanguage + +Pallas allows you to write your own custom kernels using JAX directly! + +## Your first Pallas kernel: `add_one` + +Let's try to write a kernel that adds one to a vector. We'll first do some basic imports. +```python +from functools import partial +import jax.numpy as jnp +import numpy as np +import jax + +from jax_triton import pallas as pl +``` + +First we'll write a kernel. A kernel is a program that will be executed (potentially multiple times) on an accelerator. Our `add_one_kernel` function should read from inputs, perform the computation, then write to the outputs. +```python +def add_one_kernel(x_ref, o_ref, *, block_size: int): + i = pl.program_id(0) + offsets = i * block_size + jnp.arange(block_size) + o_ref[offsets] = x_ref[offsets] + 1 +``` +We perform indexed reads from and in-place indexed writes to `Ref`s using NumPy-style indexing. + +We now write a JAX function that runs the kernel using `pallas_call`. The `grid` argument indicates how many times the kernel will be invoked. +```python +@jax.jit +def add_one(x): + return pl.pallas_call( + partial(add_one_kernel, block_size=8), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=pl.cdiv(x.shape[0], 8))(x) + +``` + +We can now call this JAX function like any other. +```python +x = jnp.arange(32) +np.testing.assert_allclose(add_one(x), x + 1) +``` +We can also even `jax.vmap` it! +```python +x = jnp.arange(4 * 32).reshape((4, 32)) +np.testing.assert_allclose(jax.vmap(add_one)(x), x + 1) +``` + +## Pallas programming guide + +We'll now cover some in writing Pallas kernels. + +### Launching programs in a *grid* + +Next, we'll write a *kernel*. A *kernel* is a JAX function that takes in `Ref` objects (mutable JAX types) corresponding to inputs and outputs. In this case, we'll have one `Ref` for the input (`x_ref`) and one for the output (`o_ref`). + +Conceptually, this kernel function will be executed multiple times, each on a different chunk, or block, of the inputs and outputs. We'll parameterize our kernel by a static integer `block_size`, which will determine the size of the "chunks" or "blocks" of our input that each instance of the kernel will operate on. + + + + + + + + + + + + + + +
Program indices
01234567
diff --git a/docs/triton_call.md b/docs/guides/triton_call.md similarity index 64% rename from docs/triton_call.md rename to docs/guides/triton_call.md index 30f320f..c8298ae 100644 --- a/docs/triton_call.md +++ b/docs/guides/triton_call.md @@ -1,6 +1,6 @@ # Calling Triton kernels from JAX -The primary way of using JAX Triton is using `jax_triton.triton_call` to call handwritten Triton kernels +The simplest way of using JAX Triton is using `jax_triton.triton_call` to call handwritten Triton kernels from inside JIT-ted JAX programs. ::: jax_triton.triton_call diff --git a/docs/index.md b/docs/index.md index 5d45376..613e6fd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,4 +1,4 @@ -# JAX-Triton documentation +# JAX-Triton JAX-Triton is a repository containing containing integrations between [JAX](https://github.com/google/jax) and [Triton](https://github.com/openai/triton). diff --git a/jax_triton/pallas/registration.py b/jax_triton/pallas/registration.py index 741ee61..c09c4fd 100644 --- a/jax_triton/pallas/registration.py +++ b/jax_triton/pallas/registration.py @@ -13,6 +13,17 @@ # limitations under the License. """Lowering registrations for pallas_call""" +import functools +from jax._src.interpreters import mlir +from jax_triton.pallas import pallas_call_p + +def _pallas_call_cpu_lowering_rule(ctx: mlir.LoweringRuleContext, *args, + interpret: bool, + **kwargs): + del interpret + return mlir.lower_fun(pallas_call_p.impl)(ctx, *args, interpret=True, **kwargs) +mlir.register_lowering(pallas_call_p, _pallas_call_cpu_lowering_rule, + platform="cpu") try: from jax_triton.pallas import triton_lowering @@ -20,3 +31,4 @@ except (ImportError, ModuleNotFoundError): pass # trailer + diff --git a/mkdocs.yml b/mkdocs.yml index 918ad06..ffa2196 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -44,6 +44,8 @@ plugins: nav: - "index.md" - - "triton_call.md" + - Guides: + - guides/triton_call.md + - guides/pallas.md strict: true diff --git a/tests/docs_test.py b/tests/docs_test.py new file mode 100644 index 0000000..3c6ba47 --- /dev/null +++ b/tests/docs_test.py @@ -0,0 +1,10 @@ +import pathlib +import pytest + +import mktestdocs + +# Note the use of `str`, makes for pretty output +@pytest.mark.parametrize('fpath', [pathlib.Path("docs") / "guides" / + "pallas.md"], ids=str) +def test_guides(fpath): + mktestdocs.check_md_file(fpath=fpath, memory=True) From 4323f8641bd9bdc42b7f1310f9de5a58cc2fe999 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Thu, 18 May 2023 00:11:03 +0000 Subject: [PATCH 2/3] Add initial docs --- docs/guides/pallas/concepts.md | 162 ++++++++++++++++++ docs/guides/pallas/index.md | 12 ++ .../{pallas.md => pallas/quickstart.md} | 41 ----- mkdocs.yml | 14 +- 4 files changed, 186 insertions(+), 43 deletions(-) create mode 100644 docs/guides/pallas/concepts.md create mode 100644 docs/guides/pallas/index.md rename docs/guides/{pallas.md => pallas/quickstart.md} (54%) diff --git a/docs/guides/pallas/concepts.md b/docs/guides/pallas/concepts.md new file mode 100644 index 0000000..29b1f94 --- /dev/null +++ b/docs/guides/pallas/concepts.md @@ -0,0 +1,162 @@ +We'll now cover some important concepts when writing your Pallas kernels. + +## Programming model + +### Single-program, multiple data (SPMD) + +Pallas has a [single-program, multiple data (SPMD)](https://en.wikipedia.org/wiki/Single_program,_multiple_data) programming paradigm. This means you write a single function that describes your computation and it'll be executed many times with different inputs. On a GPU, this bottoms out in the function being executed in parallel over many threads. + +If you're familiar with JAX, you may have seen the term SPMD before. In fact, it is the programming model for `jax.pmap` and `jax.experimental.shard_map`! However, in those cases, we parallelize computations over many different accelerators. In Pallas, we want to parallelize computations *within* an accelerator. + + +### Launching kernels in a "grid" + +In Pallas, after we write our SPMD function, which we'll call a **kernel**, we'll need to specify how we execute our kernel using a **grid**. A grid is a tuple of integers that specifies how many times we'll invoke the kernel. + +!!! info "What's a grid, more specifically?" + + If we think of this grid as a "shape" (think NumPy array or JAX array shape), it encodes a set of possible indices. For example, the grid `(8,)` encodes the indices `[(0,), (1,), (2,), ... , (7,)]` and the grid `(1, 2, 3)` encodes the indices `[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]`. + +The Pallas grid behaves much like the CUDA grid in that we'll be executing our kernel once per index in the grid. + +``` +# Pseudocode for how kernels are executed +for ndindex in np.ndenumerate(grid): + run_kernel(ndindex, kernel) +``` +However, on GPUs, this for loop will be *parallelized*. + +Each instance of the kernel, which we'll call a **program**, can select out which part of our input data it will operate on using its `program_id`, i.e. which index in the grid it corresponds to. + +### Stateful JAX (i.e. working with `Ref`s) + +How do we actually write our kernels? A Pallas kernel is a function that operates on arrays in "fast memory", i.e memory that is very close to our compute (on GPUs, this corresponds to L1 caches). In Pallas, we explicitly control how we interact with this memory -- specifically, we control where/when we load and store to and from memory. + +This level of control over memory isn't available in vanilla JAX. JAX doesn't even offer mutation semantics at all! However, JAX recently added support for a *state* side-effect via mutable references to arrays. Pallas repurposes `Ref`s to explicitly control how memory is accessed within kernels. + +When we write our kernel, it will take `Ref`s, not JAX arrays, as inputs. In addition, we will have `Ref`s for the outputs that we are responsible for writing the final values to. + +For example, here is a kernel that reads from an input `Ref` and copies its value to an output `Ref`: +```python +def identity_kernel(input_ref, output_ref): + value = input_ref[...] # Uses NumPy-like indexing semantics to read values + output_ref[...] = value # Writes value to `output_ref` +``` + +!!! warning + When writing to `Ref`s, remember that we might be executing our kernel in parallel, so there may be race conditions when writing to the same location in memory. + + diff --git a/docs/guides/pallas/index.md b/docs/guides/pallas/index.md new file mode 100644 index 0000000..144f097 --- /dev/null +++ b/docs/guides/pallas/index.md @@ -0,0 +1,12 @@ +# Pallas, a JAX kernel sublanguage + +Pallas allows you to write your own custom kernels using JAX directly! +Some benefits of using Pallas include: + +* Provides familiar JAX APIs (`jax.numpy`, `jax.lax`, etc.) +* Compatible with JAX transformations (e.g. `jax.vmap`) + +Guides: + +* [Pallas Quickstart](quickstart.md) +* [Pallas Concepts](concepts.md) diff --git a/docs/guides/pallas.md b/docs/guides/pallas/quickstart.md similarity index 54% rename from docs/guides/pallas.md rename to docs/guides/pallas/quickstart.md index dd89ec0..f65e68f 100644 --- a/docs/guides/pallas.md +++ b/docs/guides/pallas/quickstart.md @@ -1,5 +1,3 @@ -# Writing custom kernels using Pallas, a JAX kernel sublanguage - Pallas allows you to write your own custom kernels using JAX directly! ## Your first Pallas kernel: `add_one` @@ -31,7 +29,6 @@ def add_one(x): partial(add_one_kernel, block_size=8), out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), grid=pl.cdiv(x.shape[0], 8))(x) - ``` We can now call this JAX function like any other. @@ -44,41 +41,3 @@ We can also even `jax.vmap` it! x = jnp.arange(4 * 32).reshape((4, 32)) np.testing.assert_allclose(jax.vmap(add_one)(x), x + 1) ``` - -## Pallas programming guide - -We'll now cover some in writing Pallas kernels. - -### Launching programs in a *grid* - -Next, we'll write a *kernel*. A *kernel* is a JAX function that takes in `Ref` objects (mutable JAX types) corresponding to inputs and outputs. In this case, we'll have one `Ref` for the input (`x_ref`) and one for the output (`o_ref`). - -Conceptually, this kernel function will be executed multiple times, each on a different chunk, or block, of the inputs and outputs. We'll parameterize our kernel by a static integer `block_size`, which will determine the size of the "chunks" or "blocks" of our input that each instance of the kernel will operate on. - - - - - - - - - - - - - - -
Program indices
01234567
diff --git a/mkdocs.yml b/mkdocs.yml index ffa2196..76307ad 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,13 +24,20 @@ theme: name: Switch to light mode features: + - navigation.instant + - navigation.tabs + - navigation.path - navigation.sections - - toc.integrate + - navigation.indexes + - toc.follow + # - toc.integrate - header.autohide markdown_extensions: +- admonition - pymdownx.highlight: anchor_linenums: true +- pymdownx.details - pymdownx.inlinehilite - pymdownx.snippets - pymdownx.superfences @@ -46,6 +53,9 @@ nav: - "index.md" - Guides: - guides/triton_call.md - - guides/pallas.md + - Pallas: + - guides/pallas/index.md + - guides/pallas/quickstart.md + - guides/pallas/concepts.md strict: true From 96bae60547dce46112b0be4ef567233752486f12 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 19 May 2023 02:55:23 +0000 Subject: [PATCH 3/3] More pallas docs --- docs/guides/pallas/concepts.md | 139 +++++------------------------- docs/guides/pallas/index.md | 4 + docs/guides/pallas/pallas_call.md | 10 +++ docs/guides/pallas/quickstart.md | 2 + jax_triton/pallas/pallas_call.py | 27 +++++- mkdocs.yml | 3 +- tests/docs_test.py | 5 +- 7 files changed, 66 insertions(+), 124 deletions(-) create mode 100644 docs/guides/pallas/pallas_call.md diff --git a/docs/guides/pallas/concepts.md b/docs/guides/pallas/concepts.md index 29b1f94..ed9b226 100644 --- a/docs/guides/pallas/concepts.md +++ b/docs/guides/pallas/concepts.md @@ -1,15 +1,15 @@ -We'll now cover some important concepts when writing your Pallas kernels. +# Concepts -## Programming model +We'll now cover some important concepts in Pallas. -### Single-program, multiple data (SPMD) +## Single-program, multiple data (SPMD) Pallas has a [single-program, multiple data (SPMD)](https://en.wikipedia.org/wiki/Single_program,_multiple_data) programming paradigm. This means you write a single function that describes your computation and it'll be executed many times with different inputs. On a GPU, this bottoms out in the function being executed in parallel over many threads. If you're familiar with JAX, you may have seen the term SPMD before. In fact, it is the programming model for `jax.pmap` and `jax.experimental.shard_map`! However, in those cases, we parallelize computations over many different accelerators. In Pallas, we want to parallelize computations *within* an accelerator. -### Launching kernels in a "grid" +## Launching kernels in a "grid" In Pallas, after we write our SPMD function, which we'll call a **kernel**, we'll need to specify how we execute our kernel using a **grid**. A grid is a tuple of integers that specifies how many times we'll invoke the kernel. @@ -28,135 +28,38 @@ However, on GPUs, this for loop will be *parallelized*. Each instance of the kernel, which we'll call a **program**, can select out which part of our input data it will operate on using its `program_id`, i.e. which index in the grid it corresponds to. -### Stateful JAX (i.e. working with `Ref`s) +## Stateful JAX (i.e. working with `Ref`s) How do we actually write our kernels? A Pallas kernel is a function that operates on arrays in "fast memory", i.e memory that is very close to our compute (on GPUs, this corresponds to L1 caches). In Pallas, we explicitly control how we interact with this memory -- specifically, we control where/when we load and store to and from memory. This level of control over memory isn't available in vanilla JAX. JAX doesn't even offer mutation semantics at all! However, JAX recently added support for a *state* side-effect via mutable references to arrays. Pallas repurposes `Ref`s to explicitly control how memory is accessed within kernels. -When we write our kernel, it will take `Ref`s, not JAX arrays, as inputs. In addition, we will have `Ref`s for the outputs that we are responsible for writing the final values to. +When we write our kernel, it will take `Ref`s, not JAX arrays, as inputs. In addition, we will have `Ref`s for the outputs that we are responsible for writing the final values to. -For example, here is a kernel that reads from an input `Ref` and copies its value to an output `Ref`: +`Ref`s, or "references", are mutable wrappers around Array values. A common pattern when working with `Ref`s is to write functions that take in both "input" and "output" `Ref`s. Usually you read from the input `Ref`s and write into the output `Ref`s. + +For example, here is a function that reads from an input `Ref` and copies its value to an output `Ref`: ```python -def identity_kernel(input_ref, output_ref): +def identity_stateful(input_ref, output_ref): value = input_ref[...] # Uses NumPy-like indexing semantics to read values output_ref[...] = value # Writes value to `output_ref` ``` -!!! warning - When writing to `Ref`s, remember that we might be executing our kernel in parallel, so there may be race conditions when writing to the same location in memory. - - +On GPU, when we read from a `Ref`, we are loading values from GPU HBM into GPU shared memory and conversely when we are writing to a `Ref`, we are writing into GPU HBM. diff --git a/docs/guides/pallas/index.md b/docs/guides/pallas/index.md index 144f097..e424a57 100644 --- a/docs/guides/pallas/index.md +++ b/docs/guides/pallas/index.md @@ -6,7 +6,11 @@ Some benefits of using Pallas include: * Provides familiar JAX APIs (`jax.numpy`, `jax.lax`, etc.) * Compatible with JAX transformations (e.g. `jax.vmap`) +!!! warning + Pallas is experimental and may not support all JAX ops and transformations! If you find any unexpected errors, please [file an issue on Github](https://github.com/jax-ml/jax-triton/issues/new). Also, Pallas APIs aren't promised to be stable. + Guides: * [Pallas Quickstart](quickstart.md) * [Pallas Concepts](concepts.md) +* [`pallas_call`](pallas_call.md) diff --git a/docs/guides/pallas/pallas_call.md b/docs/guides/pallas/pallas_call.md new file mode 100644 index 0000000..df31c35 --- /dev/null +++ b/docs/guides/pallas/pallas_call.md @@ -0,0 +1,10 @@ +# `pallas_call` + +`pallas_call` is the function by which we run our kernels inside of larger JAX functions. + +::: jax_triton.pallas.pallas_call.pallas_call + options: + show_root_heading: true + show_root_full_path: false + show_source: false + show_signature: true diff --git a/docs/guides/pallas/quickstart.md b/docs/guides/pallas/quickstart.md index f65e68f..03712c1 100644 --- a/docs/guides/pallas/quickstart.md +++ b/docs/guides/pallas/quickstart.md @@ -1,3 +1,5 @@ +# Quickstart + Pallas allows you to write your own custom kernels using JAX directly! ## Your first Pallas kernel: `add_one` diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py index c25c097..7392c7a 100644 --- a/jax_triton/pallas/pallas_call.py +++ b/jax_triton/pallas/pallas_call.py @@ -319,14 +319,37 @@ def _compute_shape_from_block_spec(block_spec: Optional[BlockSpec], return arg_shape return tuple(s for s in block_spec.block_shape if s is not None) -def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False, +def pallas_call(f: Callable, *, out_shape: Any, debug: bool = False, grid: Optional[Grid] = None, in_specs: Optional[Sequence[Optional[BlockSpec]]] = None, out_specs: Optional[Sequence[Optional[BlockSpec]]] = None, input_output_aliases: Dict[int, int] = {}, interpret: bool = False, name: Optional[str] = None, - **compiler_params: Any): + **compiler_params: Any) -> Any: + """Executes a Pallas kernel on input JAX arrays. + + Args: + f: A kernel function. It should accept input `Ref`s first, followed by + output `Ref`s afterwards. + out_shape: A Pytree of `ShapeDtypeStruct`s (or values that have `.shape` and + `.dtype` properties). An output `Ref` is passed into `f` for each value + in `out_shape`. + debug: Prints out debugging information about the kernel when `True`. + interpret: Emulates kernel execution in HLO if `True`. + name: A string name for the kernel. + grid: An optional tuple of integers that determines the schedule by which + the kernel is executed. If not provided, the grid is empty. + in_specs: a list of optional `BlockSpecs` for each of the inputs to the + kernel. + out_specs: a list of optional `BlockSpecs` for each of the outputs from the + kernel. + input_output_aliases: A mapping of input indices to output indices, + indicating which inputs should be aliased to which outputs. + **compiler_params: A dictionary mapping compiler name to compiler options. + Returns: + A Pytree of JAX Arrays + """ if grid is None: if in_specs is not None: raise ValueError("Cannot specify `in_specs` with a `None` grid.") diff --git a/mkdocs.yml b/mkdocs.yml index 76307ad..a7b5042 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -26,11 +26,9 @@ theme: features: - navigation.instant - navigation.tabs - - navigation.path - navigation.sections - navigation.indexes - toc.follow - # - toc.integrate - header.autohide markdown_extensions: @@ -57,5 +55,6 @@ nav: - guides/pallas/index.md - guides/pallas/quickstart.md - guides/pallas/concepts.md + - guides/pallas/pallas_call.md strict: true diff --git a/tests/docs_test.py b/tests/docs_test.py index 3c6ba47..d8817b6 100644 --- a/tests/docs_test.py +++ b/tests/docs_test.py @@ -4,7 +4,8 @@ import mktestdocs # Note the use of `str`, makes for pretty output -@pytest.mark.parametrize('fpath', [pathlib.Path("docs") / "guides" / - "pallas.md"], ids=str) +@pytest.mark.parametrize('fpath', (pathlib.Path("docs") / + "guides").glob("**/*.md"), + ids=str) def test_guides(fpath): mktestdocs.check_md_file(fpath=fpath, memory=True)