
Commit 727bfa6

[Wave] Add support for direct global load to lds

Signed-off-by: nithinsubbiah <[email protected]>

Parent: 89d4060

File tree: 9 files changed, +140 −23 lines

iree/turbine/kernel/ops/wave_ops.py

Lines changed: 26 additions & 0 deletions

@@ -291,6 +291,16 @@ def select(cond: "Register", if_true: "Register", if_false: "Register") -> "Register":
     ...


+def gather_to_lds(
+    src: "Memory",
+    src_idx: dict[IndexSymbol, IndexSequence],
+    dst: "Memory",
+    dst_idx: dict[IndexSymbol, IndexSequence],
+    dtype: DataType,
+):
+    ...
+
+
 def define_op(op_name: str) -> Callable[[T], T]:
     def decorator(cls: T) -> T:
         cls.tkw_op_name = op_name

@@ -2332,3 +2342,19 @@ def indexing_dims(self) -> list[IndexExpr]:

     def infer_type(self):
         self.type = get_custom(_to_sequence(self.args)[0]).type
+
+
+@define_op("gather_to_lds")
+@dataclass
+class GatherToLDS(CustomOp):
+    """
+    Represents an instruction that performs a direct load from global
+    memory to LDS. The source points to the global memory to load from
+    and the destination points to shared memory.
+    """
+
+    src: Memory
+    src_idx: dict[IndexSymbol, IndexSequence]
+    dst: Memory
+    dst_idx: dict[IndexSymbol, IndexSequence]
+    dtype: DataType

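Note that the `...` body makes `gather_to_lds` a declaration only; the op itself is modeled by the `GatherToLDS` class, and in this commit nothing calls the stub from kernel code. Instead, the node is materialized by the gather_to_shared compiler pass introduced below. A minimal sketch of that construction, with all variable names illustrative:

    # Build the custom op and insert it into the traced fx graph, as the
    # gather_to_shared pass in this commit does.
    new_node = GatherToLDS(src, src_idx, dst_mem, dst_idx, dtype).add_to_graph(graph)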
iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 13 additions & 0 deletions

@@ -61,6 +61,7 @@
     exp2,
     extract,
     extract_slice,
+    gather_to_lds,
     ge,
     get_custom,
     get_result,

@@ -1612,3 +1613,15 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
         [1],
     )
     emitter.bind_node_proxy(node, IRProxyValue(slice))
+
+
+@handle_op(gather_to_lds)
+def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
+    try:
+        src, src_idx, dst, dst_idx, dtype = node.args
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+
+    return amdgpu_d.gather_to_lds(
+        transfer_type=dtype, src=src, src_indices=src_idx, dst=dst, dst_indices=dst_idx
+    )
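Two details worth flagging in this lowering. First, the positional unpacking of `node.args` relies on the `GatherToLDS` fields arriving in declaration order (`src, src_idx, dst, dst_idx, dtype`), so the two definitions must be kept in sync; a sketch of the invariant, names illustrative:

    # The handler's unpack must mirror the dataclass field order:
    #   class GatherToLDS: src, src_idx, dst, dst_idx, dtype
    src, src_idx, dst, dst_idx, dtype = node.args

Second, `amdgpu_d.gather_to_lds` is presumably the Python binding for the AMDGPU dialect op of the same name, which expresses the global-to-LDS transfer in the generated IR.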
iree/turbine/kernel/wave/gather_to_shared.py (new file)

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._support.tracing import CapturedTrace
+from ..lang.global_symbols import *
+from ..ops.wave_ops import GatherToLDS, Write, get_custom
+from ..wave.constraints import (
+    Constraint,
+)
+from ..wave.utils.run_utils import get_default_arch
+from .utils.general_utils import is_valid_global_read
+from .utils.graph_utils import DCE
+from .utils.symbol_utils import (
+    subs_idxc,
+)
+
+
+gather_to_shared_supported_arch = ["gfx950"]
+
+
+def get_write_node_info(read_custom):
+    write_node, write_memory, write_idx = [], [], []
+
+    for user in read_custom.users:
+        if (
+            isinstance(user, Write)
+            and subs_idxc(user.memory_type.address_space) == SHARED_ADDRESS_SPACE
+        ):
+            write_node.append(user)
+            write_memory.append(user.memory)
+            write_idx.append(user.get_derived_indices[0])
+
+    return write_node, write_memory, write_idx
+
+
+def gather_to_shared(trace: CapturedTrace, constraints: list[Constraint]):
+    """
+    This pass enables direct memory loads from global to LDS without
+    passing through registers, reducing data movement. The instruction
+    is supported only on specific architectures (gfx950).
+    """
+
+    if get_default_arch() not in gather_to_shared_supported_arch:
+        return
+
+    global_read_nodes = trace.walk(is_valid_global_read)
+    for read_node in global_read_nodes:
+        read_custom = get_custom(read_node)
+        src = read_custom.memory
+        src_idx = read_custom.get_derived_indices[0]
+        element_type = read_custom.type.dtype
+        write_node, write_memory, write_idx = get_write_node_info(read_custom)
+        if not write_node:
+            continue
+        for dst_node, dst_memory, dst_idx in zip(write_node, write_memory, write_idx):
+            with dst_node.graph.inserting_before(dst_node.fx_node):
+                dst_node.replace_all_uses_with(
+                    GatherToLDS(
+                        src, src_idx, dst_memory, dst_idx, element_type
+                    ).add_to_graph(dst_node.graph)
+                )
+
+    DCE(trace)

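Schematically, the pass rewrites each global-read/shared-write pair into a single direct transfer; a sketch with hypothetical node names:

    # Before (one operand promoted to shared memory):
    #   r = read(global_buf)       # global memory -> registers
    #   w = write(r, shared_buf)   # registers -> shared memory (LDS)
    #
    # After gather_to_shared:
    #   g = gather_to_lds(global_buf, src_idx, shared_buf, dst_idx, dtype)
    #
    # Uses of w are redirected to g via replace_all_uses_with, and the dead
    # read/write pair is cleaned up by the final DCE(trace).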
iree/turbine/kernel/wave/global_to_shared_gathers.py

Lines changed: 3 additions & 1 deletion

@@ -22,13 +22,15 @@
 from .utils.symbol_utils import subs_idxc
 from .utils.general_utils import is_gather
 from .minimize_global_loads import (
-    has_write_shared_user,
     construct_min_global_access_pattern,
     materialize_shape,
     identify_optimizable_loads,
     update_write_dependencies,
     SharedReadMetadata,
 )
+from .utils.general_utils import (
+    has_write_shared_user,
+)

 """
 We are given N global gathers that are promoted to shared memory. This function

iree/turbine/kernel/wave/minimize_global_loads.py

Lines changed: 4 additions & 18 deletions

@@ -11,7 +11,7 @@
     TilingConstraint,
 )
 from .._support.tracing import CapturedTrace
-from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr
+from .._support.indexing import IndexSequence, IndexSymbol, IndexExpr
 from ..ops.wave_ops import Read, Write, get_custom
 from ..lang.global_symbols import *
 from .utils.general_utils import (

@@ -20,6 +20,9 @@
     is_shared_read,
     get_fastest_index,
 )
+from .utils.general_utils import (
+    is_valid_global_read,
+)
 from .utils.graph_utils import (
     DCE,
 )

@@ -41,23 +44,6 @@ class SharedReadMetadata:
     memory_shape: tuple[int | IndexExpr]


-def has_write_shared_user(node: Read) -> bool:
-    return any(
-        isinstance(user, Write)
-        and subs_idxc(user.memory_type.address_space) == SHARED_ADDRESS_SPACE
-        for user in node.users
-    )
-
-
-def is_valid_global_read(node: fx.Node) -> bool:
-    custom = get_custom(node)
-    return (
-        isinstance(custom, Read)
-        and subs_idxc(custom.memory_type.address_space) == GLOBAL_ADDRESS_SPACE
-        and has_write_shared_user(custom)
-    )
-
-
 def is_transposed_read(custom: Read) -> bool:
     """
     Checks whether or not we are doing a transposed read.

iree/turbine/kernel/wave/promotion.py

Lines changed: 2 additions & 2 deletions

@@ -67,8 +67,8 @@ def apply_promotion_pattern(
     ```
     read_from_global lhs
     write_to_shared lhs
-    read_from_global lhs
-    write_to_shared lhs
+    read_from_global rhs
+    write_to_shared rhs
     shared_barrier
     read_from_shared lhs
     read_from_shared rhs

iree/turbine/kernel/wave/utils/general_utils.py

Lines changed: 22 additions & 1 deletion

@@ -9,12 +9,13 @@
 import os
 import sympy
 import torch
+import torch.fx as fx
 from typing import Any, Callable, Optional


 from ..._support.indexing import IndexExpr, IndexSequence, IndexSymbol
 from ...lang.global_symbols import *
-from ...ops.wave_ops import CustomOp, Read, Iterate, Write
+from ...ops.wave_ops import CustomOp, Read, Iterate, Write, get_custom
 from ..assumptions import Assumption
 from ..constraints import (
     Constraint,

@@ -375,6 +376,26 @@ def is_shared_read(node: CustomOp) -> bool:
     )


+def has_write_shared_user(node: Read) -> bool:
+    return any(
+        isinstance(user, Write)
+        and subs_idxc(user.memory_type.address_space) == SHARED_ADDRESS_SPACE
+        for user in node.users
+    )
+
+
+def is_valid_global_read(node: fx.Node) -> bool:
+    """
+    Check whether a node reads from global memory and has a user that
+    writes to shared memory.
+    """
+    custom = get_custom(node)
+    return (
+        isinstance(custom, Read)
+        and subs_idxc(custom.memory_type.address_space) == GLOBAL_ADDRESS_SPACE
+        and has_write_shared_user(custom)
+    )
+
+
 def is_gather(custom: CustomOp) -> bool:
     if not isinstance(custom, Read):
         return False

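Both helpers are simple predicates over fx nodes; a minimal usage sketch, mirroring the gather_to_shared pass in this commit:

    # Collect reads from global memory whose results feed a shared-memory write.
    global_read_nodes = trace.walk(is_valid_global_read)
    for node in global_read_nodes:
        custom = get_custom(node)  # a Read from GLOBAL_ADDRESS_SPACE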
iree/turbine/kernel/wave/utils/graph_utils.py

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@
 import iree.turbine.kernel.lang as tkl
 from ...ops.wave_ops import (
     get_custom,
+    Read,
     Write,
     NestedRegionOp,
     Output,

iree/turbine/kernel/wave/wave.py

Lines changed: 3 additions & 1 deletion

@@ -12,7 +12,7 @@
 from ..lang import Grid, IndexMapping
 from ..lang.global_symbols import *
 from ..ops import wave_ops
-from ..ops.wave_ops import Iterate, CustomOp, get_custom, IterArg
+from ..ops.wave_ops import Iterate, CustomOp, get_custom
 from .._support.indexing import IndexingContext, IndexExpr
 from .symbolic_constraints import SymbolicAlias
 from .._support.tracing import (

@@ -51,6 +51,7 @@
 from .decompose_scan_ops import decompose_scan_ops
 from .decompose_dot_mma import decompose_dot_mma
 from .expansion.expansion import expand_graph, add_get_results
+from .gather_to_shared import gather_to_shared
 from .global_to_shared_gathers import global_to_shared_gathers
 from .hoisting import hoist_loop_invariant_ops
 from .minimize_global_loads import minimize_global_loads

@@ -541,6 +542,7 @@ def _trace_and_get_kernel_signature(
             partial(hoist_loop_invariant_ops, trace, self.constraints),
             partial(global_to_shared_gathers, trace, self.constraints),
             partial(minimize_global_loads, trace, self.constraints),
+            partial(gather_to_shared, trace, self.constraints),
             partial(apply_shared_memory_indexing_corrections, trace, self.constraints),
         ]

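The new pass is registered directly after minimize_global_loads, presumably so it sees the already-minimized global read/write pairs, and before apply_shared_memory_indexing_corrections. Since it takes the same (trace, constraints) arguments as its neighbors, it can also be invoked standalone; a hedged sketch:

    # No-op on architectures other than gfx950 (see the arch check in the pass).
    gather_to_shared(trace, constraints)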