65 changes: 65 additions & 0 deletions docs/guides/pallas/concepts.md
@@ -0,0 +1,65 @@
# Concepts

We'll now cover some important concepts in Pallas.

## Single-program, multiple data (SPMD)

Pallas uses a [single-program, multiple data (SPMD)](https://en.wikipedia.org/wiki/Single_program,_multiple_data) programming model. This means you write a single function that describes your computation, and it is executed many times with different inputs. On a GPU, this bottoms out in the function being executed in parallel across many threads.

If you're familiar with JAX, you may have seen the term SPMD before. In fact, it is the programming model for `jax.pmap` and `jax.experimental.shard_map`! However, in those cases, we parallelize computations over many different accelerators. In Pallas, we want to parallelize computations *within* an accelerator.


## Launching kernels in a "grid"

In Pallas, after we write our SPMD function, which we'll call a **kernel**, we'll need to specify how we execute our kernel using a **grid**. A grid is a tuple of integers that specifies how many times we'll invoke the kernel.

!!! info "What's a grid, more specifically?"

    If we think of this grid as a "shape" (think NumPy array or JAX array shape), it encodes a set of possible indices. For example, the grid `(8,)` encodes the indices `[(0,), (1,), (2,), ... , (7,)]` and the grid `(1, 2, 3)` encodes the indices `[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]`.

The Pallas grid behaves much like the CUDA grid in that we'll be executing our kernel once per index in the grid.

```
# Pseudocode for how kernels are executed
for ndindex in np.ndindex(*grid):
  run_kernel(ndindex, kernel)
```
However, on GPUs, this for loop will be *parallelized*.
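
To make the index set concrete, here is a plain-NumPy illustration (not Pallas code) of enumerating the grid `(1, 2, 3)` from the example above:
```python
import numpy as np

grid = (1, 2, 3)
# np.ndindex enumerates every index in the grid, in row-major order.
print(list(np.ndindex(*grid)))
# [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]
```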

Each instance of the kernel, which we'll call a **program**, can select which part of our input data it operates on using its `program_id`, i.e. the index in the grid it corresponds to.

## Stateful JAX (i.e. working with `Ref`s)

How do we actually write our kernels? A Pallas kernel is a function that operates on arrays in "fast memory", i.e. memory that is very close to our compute (on GPUs, this corresponds to the L1/shared memory SRAM). In Pallas, we explicitly control how we interact with this memory -- specifically, we control where and when we load from and store to it.

This level of control over memory isn't available in vanilla JAX, which doesn't offer mutation semantics for arrays. However, JAX recently added support for a *state* side-effect via mutable references to arrays. Pallas repurposes these `Ref`s to explicitly control how memory is accessed within kernels.

When we write our kernel, it will take `Ref`s, not JAX arrays, as inputs. In addition, we will have `Ref`s for the outputs that we are responsible for writing the final values to.

`Ref`s, or "references", are mutable wrappers around Array values. A common pattern when working with `Ref`s is to write functions that take in both "input" and "output" `Ref`s. Usually you read from the input `Ref`s and write into the output `Ref`s.

For example, here is a function that reads from an input `Ref` and copies its value to an output `Ref`:
```python
def identity_stateful(input_ref, output_ref):
  value = input_ref[...]   # Uses NumPy-like indexing semantics to read values
  output_ref[...] = value  # Writes value to `output_ref`
```

!!! note
    The `[...]` notation corresponds to reading the *entire* value of the `Ref`.
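
In plain NumPy terms (an analogy, not Pallas-specific code), `[...]` reads and assigns the entire array:
```python
import numpy as np

a = np.arange(4)
print(a[...])    # [0 1 2 3] -- reads the whole array
a[...] = a + 10  # writes to every element in place
print(a)         # [10 11 12 13]
```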

Here's a function that computes `exp(x)`.
```python
def exp_stateful(input_ref, output_ref):
  output_ref[...] = jnp.exp(input_ref[...])
```

You can also read from and write to the same `Ref`.
```python
def exp_plus_one_stateful(input_ref, output_ref):
  output_ref[...] = jnp.exp(input_ref[...])  # Read from Ref
  output_ref[...] = output_ref[...] + 1      # Read from and write to Ref
```
Conceptually, `exp_plus_one_stateful` updates `output_ref` in-place to compute `exp(x) + 1`.
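
For reference, the pure (non-stateful) JAX function this corresponds to would simply be:
```python
import jax.numpy as jnp

def exp_plus_one(x):
  return jnp.exp(x) + 1
```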

On GPU, when we read from a `Ref`, we load values from GPU HBM into GPU shared memory; conversely, when we write to a `Ref`, we write values back into GPU HBM.
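
Putting the grid, `program_id`, and `Ref`s together, here is a minimal sketch of a complete kernel (it mirrors the `add_one` example from the [quickstart](quickstart.md); the kernel name and block size here are just illustrative choices):
```python
from functools import partial

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

def double_kernel(x_ref, o_ref, *, block_size: int):
  i = pl.program_id(0)                               # this program's index in the grid
  offsets = i * block_size + jnp.arange(block_size)  # the block this program owns
  o_ref[offsets] = x_ref[offsets] * 2                # read a block, write a block

@jax.jit
def double(x):
  block_size = 8
  return pl.pallas_call(
      partial(double_kernel, block_size=block_size),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=pl.cdiv(x.shape[0], block_size))(x)

x = jnp.arange(16)
print(double(x))  # Two programs, each doubling one block of 8 elements.
```
Each of the two programs (`grid=2` here, since `pl.cdiv(16, 8) == 2`) uses its `program_id` to pick the block of eight elements it reads and writes.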
16 changes: 16 additions & 0 deletions docs/guides/pallas/index.md
@@ -0,0 +1,16 @@
# Pallas, a JAX kernel sublanguage

Pallas allows you to write your own custom kernels using JAX directly!
Some benefits of using Pallas include:

* Provides familiar JAX APIs (`jax.numpy`, `jax.lax`, etc.)
* Compatible with JAX transformations (e.g. `jax.vmap`)

!!! warning
    Pallas is experimental and may not support all JAX ops and transformations! If you find any unexpected errors, please [file an issue on GitHub](https://github.com/jax-ml/jax-triton/issues/new). Also, Pallas APIs are not guaranteed to be stable.

Guides:

* [Pallas Quickstart](quickstart.md)
* [Pallas Concepts](concepts.md)
* [`pallas_call`](pallas_call.md)
10 changes: 10 additions & 0 deletions docs/guides/pallas/pallas_call.md
@@ -0,0 +1,10 @@
# `pallas_call`

`pallas_call` is the function by which we run our kernels inside larger JAX functions.
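
For example, here is a minimal usage sketch adapted from the [quickstart](quickstart.md) (the `add_one_kernel`, block size, and shapes are carried over from that page):
```python
from functools import partial

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

def add_one_kernel(x_ref, o_ref, *, block_size: int):
  i = pl.program_id(0)
  offsets = i * block_size + jnp.arange(block_size)
  o_ref[offsets] = x_ref[offsets] + 1

@jax.jit
def add_one(x):
  return pl.pallas_call(
      partial(add_one_kernel, block_size=8),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),  # shape/dtype of the output
      grid=pl.cdiv(x.shape[0], 8))(x)                    # number of kernel invocations

add_one(jnp.arange(32))  # [1, 2, ..., 32]
```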

::: jax_triton.pallas.pallas_call.pallas_call
    options:
      show_root_heading: true
      show_root_full_path: false
      show_source: false
      show_signature: true
45 changes: 45 additions & 0 deletions docs/guides/pallas/quickstart.md
@@ -0,0 +1,45 @@
# Quickstart

Pallas allows you to write your own custom kernels using JAX directly!

## Your first Pallas kernel: `add_one`

Let's try to write a kernel that adds one to a vector. We'll first do some basic imports.
```python
from functools import partial
import jax.numpy as jnp
import numpy as np
import jax

from jax_triton import pallas as pl
```

Next, we'll write the kernel. A kernel is a program that will be executed (potentially multiple times) on an accelerator. Our `add_one_kernel` function reads from its inputs, performs the computation, and then writes to its outputs.
```python
def add_one_kernel(x_ref, o_ref, *, block_size: int):
  i = pl.program_id(0)
  offsets = i * block_size + jnp.arange(block_size)
  o_ref[offsets] = x_ref[offsets] + 1
```
We perform indexed reads from and in-place indexed writes to `Ref`s using NumPy-style indexing.

We now write a JAX function that runs the kernel using `pallas_call`. The `grid` argument indicates how many times the kernel will be invoked.
```python
@jax.jit
def add_one(x):
  return pl.pallas_call(
      partial(add_one_kernel, block_size=8),
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=pl.cdiv(x.shape[0], 8))(x)
```

We can now call this JAX function like any other.
```python
x = jnp.arange(32)
np.testing.assert_allclose(add_one(x), x + 1)
```
We can even `jax.vmap` it!
```python
x = jnp.arange(4 * 32).reshape((4, 32))
np.testing.assert_allclose(jax.vmap(add_one)(x), x + 1)
```
2 changes: 1 addition & 1 deletion docs/triton_call.md → docs/guides/triton_call.md
@@ -1,6 +1,6 @@
# Calling Triton kernels from JAX

The primary way of using JAX Triton is using `jax_triton.triton_call` to call handwritten Triton kernels
The simplest way of using JAX Triton is using `jax_triton.triton_call` to call handwritten Triton kernels
from inside JIT-ted JAX programs.

::: jax_triton.triton_call
2 changes: 1 addition & 1 deletion docs/index.md
@@ -1,4 +1,4 @@
# JAX-Triton documentation
# JAX-Triton

JAX-Triton is a repository containing integrations between [JAX](https://github.com/google/jax)
and [Triton](https://github.com/openai/triton).
27 changes: 25 additions & 2 deletions jax_triton/pallas/pallas_call.py
@@ -319,14 +319,37 @@ def _compute_shape_from_block_spec(block_spec: Optional[BlockSpec],
    return arg_shape
  return tuple(s for s in block_spec.block_shape if s is not None)

def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
def pallas_call(f: Callable, *, out_shape: Any, debug: bool = False,
                grid: Optional[Grid] = None,
                in_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
                out_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
                input_output_aliases: Dict[int, int] = {},
                interpret: bool = False,
                name: Optional[str] = None,
                **compiler_params: Any):
                **compiler_params: Any) -> Any:
  """Executes a Pallas kernel on input JAX arrays.

  Args:
    f: A kernel function. It should accept input `Ref`s first, followed by
      output `Ref`s.
    out_shape: A Pytree of `ShapeDtypeStruct`s (or values that have `.shape`
      and `.dtype` properties). An output `Ref` is passed into `f` for each
      value in `out_shape`.
    debug: Prints out debugging information about the kernel when `True`.
    grid: An optional tuple of integers that determines the schedule by which
      the kernel is executed. If not provided, the grid is empty.
    in_specs: A list of optional `BlockSpec`s, one for each input to the
      kernel.
    out_specs: A list of optional `BlockSpec`s, one for each output of the
      kernel.
    input_output_aliases: A mapping of input indices to output indices,
      indicating which inputs should be aliased to which outputs.
    interpret: Emulates kernel execution in HLO if `True`.
    name: A string name for the kernel.
    **compiler_params: A dictionary mapping compiler name to compiler options.

  Returns:
    A Pytree of JAX Arrays.
  """
  if grid is None:
    if in_specs is not None:
      raise ValueError("Cannot specify `in_specs` with a `None` grid.")
12 changes: 12 additions & 0 deletions jax_triton/pallas/registration.py
@@ -13,10 +13,22 @@
# limitations under the License.

"""Lowering registrations for pallas_call"""
import functools
from jax._src.interpreters import mlir
from jax_triton.pallas import pallas_call_p

def _pallas_call_cpu_lowering_rule(ctx: mlir.LoweringRuleContext, *args,
                                   interpret: bool,
                                   **kwargs):
  del interpret
  return mlir.lower_fun(pallas_call_p.impl)(ctx, *args, interpret=True, **kwargs)
mlir.register_lowering(pallas_call_p, _pallas_call_cpu_lowering_rule,
                       platform="cpu")

try:
  from jax_triton.pallas import triton_lowering
  del triton_lowering
except (ImportError, ModuleNotFoundError):
  pass

15 changes: 13 additions & 2 deletions mkdocs.yml
@@ -24,13 +24,18 @@ theme:
name: Switch to light mode

  features:
    - navigation.instant
    - navigation.tabs
    - navigation.sections
    - toc.integrate
    - navigation.indexes
    - toc.follow
    - header.autohide

markdown_extensions:
- admonition
- pymdownx.highlight:
    anchor_linenums: true
- pymdownx.details
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
@@ -44,6 +49,12 @@

nav:
- "index.md"
- "triton_call.md"
- Guides:
  - guides/triton_call.md
  - Pallas:
    - guides/pallas/index.md
    - guides/pallas/quickstart.md
    - guides/pallas/concepts.md
    - guides/pallas/pallas_call.md

strict: true
11 changes: 11 additions & 0 deletions tests/docs_test.py
@@ -0,0 +1,11 @@
import pathlib
import pytest

import mktestdocs

# Note: using `str` for `ids` makes for pretty test output.
@pytest.mark.parametrize('fpath',
                         (pathlib.Path("docs") / "guides").glob("**/*.md"),
                         ids=str)
def test_guides(fpath):
  mktestdocs.check_md_file(fpath=fpath, memory=True)