|
| 1 | +# Copyright 2025 The IREE Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +# See https://llvm.org/LICENSE.txt for license information. |
| 5 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | + |
| 7 | +from .._support.tracing import CapturedTrace |
| 8 | +from ..lang.global_symbols import * |
| 9 | +from ..ops.wave_ops import GatherToLDS, Write, get_custom |
| 10 | +from ..wave.constraints import ( |
| 11 | + Constraint, |
| 12 | +) |
| 13 | +from ..wave.utils.run_utils import get_default_arch |
| 14 | +from .utils.general_utils import is_valid_global_read |
| 15 | +from .utils.graph_utils import DCE |
| 16 | +from .utils.symbol_utils import ( |
| 17 | + subs_idxc, |
| 18 | +) |
| 19 | + |
| 20 | + |
# Architectures for which `gather_to_shared` rewrites the graph; for any other
# default architecture the pass returns without touching the trace.
gather_to_shared_supported_arch = [
    "gfx950",
]
| 22 | + |
| 23 | + |
def get_write_node_info(read_custom):
    """
    Collect the users of `read_custom` that write into shared memory.

    A user qualifies when it is a `Write` whose memory type resolves (via
    `subs_idxc`) to `SHARED_ADDRESS_SPACE`.

    Returns three parallel lists, in user order:
      * the qualifying write ops,
      * the `memory` operand of each write,
      * the first entry of each write's `get_derived_indices`.
    All three lists are empty when no user qualifies.
    """
    shared_writes = [
        consumer
        for consumer in read_custom.users
        if isinstance(consumer, Write)
        and subs_idxc(consumer.memory_type.address_space) == SHARED_ADDRESS_SPACE
    ]
    shared_memories = [write.memory for write in shared_writes]
    shared_indices = [write.get_derived_indices[0] for write in shared_writes]

    return shared_writes, shared_memories, shared_indices
| 37 | + |
| 38 | + |
def gather_to_shared(trace: CapturedTrace, constraints: list[Constraint]):
    """
    Replace global-read -> shared-write pairs with `GatherToLDS` ops.

    The resulting op loads directly from global memory into LDS without
    passing through registers, which reduces data movement. The instruction
    is only available on the architectures listed in
    `gather_to_shared_supported_arch` (gfx950); on any other default
    architecture the pass returns without modifying `trace`.

    Args:
        trace: Captured trace whose graph is rewritten in place.
        constraints: Not read by this pass.
    """
    arch_supported = get_default_arch() in gather_to_shared_supported_arch
    if not arch_supported:
        return

    for fx_read in trace.walk(is_valid_global_read):
        global_read = get_custom(fx_read)
        source_memory = global_read.memory
        source_index = global_read.get_derived_indices[0]
        dtype = global_read.type.dtype

        shared_writes, shared_memories, shared_indices = get_write_node_info(
            global_read
        )
        # A read without shared-memory writes yields empty lists, so the loop
        # below simply does not execute for it.
        for write, target_memory, target_index in zip(
            shared_writes, shared_memories, shared_indices
        ):
            # Emit the gather right before the write it stands in for.
            with write.graph.inserting_before(write.fx_node):
                gather = GatherToLDS(
                    source_memory, source_index, target_memory, target_index, dtype
                ).add_to_graph(write.graph)
                write.replace_all_uses_with(gather)

    # Clean up whatever the rewrite above left dead in the trace.
    DCE(trace)
0 commit comments