-
Notifications
You must be signed in to change notification settings - Fork 44
Add epilogue subtiling #948
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
from .. import exc | ||
from .._compat import get_tensor_descriptor_fn_name | ||
from .ast_extension import expr_from_string | ||
from .ast_extension import statement_from_string | ||
from .compile_environment import CompileEnvironment | ||
from .device_function import DeviceFunction | ||
from .host_function import HostFunction | ||
|
@@ -103,6 +104,12 @@ def _get_tile_with_offset_info( | |
|
||
return None | ||
|
||
def _supports_epilogue_subtiling() -> bool:
    """Return whether epilogue subtiling may be used for the current compile.

    Subtiling is allowed only when all of the following hold:

    * the current compile environment targets a CUDA device,
    * the ``allow_epilogue_subtiling`` setting is enabled, and
    * the target device reports a compute capability of at least 10.0.

    Returns:
        True if every condition above is met, False otherwise.
    """
    env = CompileEnvironment.current()
    if env.device.type != "cuda" or not env.settings.allow_epilogue_subtiling:
        return False
    # Query the device we are compiling for, not the process-wide current CUDA
    # device; the two can differ on hosts with more than one kind of GPU.
    return torch.cuda.get_device_capability(env.device) >= (10, 0)
|
||
|
||
class IndexingStrategy: | ||
def codegen_load( | ||
|
@@ -376,6 +383,7 @@ def codegen_store( | |
subscript: list[object], | ||
value: ast.AST, | ||
extra_mask: ast.AST | None, | ||
epilogue_subtile: int | None, | ||
) -> ast.AST: | ||
if not self.is_supported(state, fake_tensor, subscript, extra_mask): | ||
return PointerIndexingStrategy().codegen_store( | ||
|
@@ -384,6 +392,10 @@ def codegen_store( | |
assert extra_mask is None | ||
indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) | ||
|
||
config = DeviceFunction.current().config | ||
if _supports_epilogue_subtiling and config.epilogue_subtiling: | ||
return self._codegen_epilogue_subtile_store(state, fake_tensor, indexing, store_value) | ||
|
||
# Apply permutation to the value being stored if needed | ||
desc_arg = indexing.tensor_descriptor_arg(state) | ||
store_value = indexing.reshape_store(state, value) | ||
|
@@ -394,12 +406,102 @@ def codegen_store( | |
f"tl.permute({{store_val}}, {desc_arg.permutation!r})", | ||
store_val=store_value, | ||
) | ||
|
||
return expr_from_string( | ||
f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})", | ||
value=store_value, | ||
) | ||
|
||
def _codegen_epilogue_subtile_store(
    self,
    state: CodegenState,
    fake_tensor: torch.Tensor,
    indexing: BlockedSubscriptIndexing,
    store_value: ast.AST,
) -> ast.AST | None:
    """Emit a tensor-descriptor store split into two half-width subtile stores.

    The value being stored is reshaped to ``[M, 2, N // 2]``, permuted to
    ``[M, N // 2, 2]`` and split with ``tl.split`` into two halves along the
    second tile dimension. The first half is stored through an emitted
    statement; the store of the second half (at column offset ``+ N // 2``)
    is returned as an expression for the caller to emit.

    Rationale (from code review): epilogue subtiling avoids materializing all
    the registers needed to compute the result over the entire TMEM-allocated
    tile.

    TODO(review): only the store itself is subtiled here. Reviewers asked for
    pointwise ops flowing into the store to be performed per subtile as well.

    Args:
        state: Current codegen state; supplies the device function and the
            codegen object used to lift values and add statements.
        fake_tensor: Tensor being stored to (not read by this method).
        indexing: Blocked indexing for the store. On success its second
            ``block_shape`` entry is left halved so the tensor descriptor
            matches the subtile width.
        store_value: AST of the full-tile value to store.

    Returns:
        The AST of the second subtile store, or ``None`` when subtiling does
        not apply: the tile is not 2D, the size hint for the second block
        dimension is unavailable or odd, or the descriptor needs a
        permutation. ``indexing`` is left unmodified when ``None`` is returned.
    """
    # Currently support 2D tiles without permutations
    if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2:
        return None

    env = CompileEnvironment.current()
    block_m, block_n = indexing.block_shape
    try:
        block_n_hint = env.size_hint(block_n)
    except Exception:
        # Best effort: without a usable size hint we cannot prove the tile
        # splits evenly, so decline to subtile.
        return None

    # The tile is split into two equal halves, so its width must be even.
    if block_n_hint % 2 != 0:
        return None

    device_fn = state.device_function
    codegen = state.codegen

    # Capture the full-size dimension strings before the block shape is halved.
    block_m_str = device_fn.literal_expr(block_m)
    block_n_str = device_fn.literal_expr(block_n)

    # The descriptor is requested with the half-width block shape.
    indexing.block_shape[1] //= 2
    desc_arg = indexing.tensor_descriptor_arg(state)

    if desc_arg.permutation is not None:
        # Undo the in-place halving so a fallback (non-subtiled) store built
        # from the same indexing object still sees the full tile width.
        # NOTE(review): tensor_descriptor_arg() was already called with the
        # halved shape; confirm it does not cache anything keyed on it.
        indexing.block_shape[1] = block_n
        return None

    block_n_half_str = f"({block_n_str} // 2)"

    # Lift the store value into a temporary variable for reuse
    acc_var = codegen.lift(store_value, prefix="acc")

    reshape_expr = expr_from_string(
        "tl.reshape({acc}, [{dim_m}, 2, {dim_half}])",
        acc=acc_var,
        dim_m=expr_from_string(block_m_str),
        dim_half=expr_from_string(block_n_half_str),
    )
    reshape_var = codegen.lift(reshape_expr, prefix="acc")

    # Move the size-2 axis last, since tl.split splits along the last dimension.
    permute_expr = expr_from_string(
        "tl.permute({acc}, [0, 2, 1])",
        acc=reshape_var,
    )
    permute_var = codegen.lift(permute_expr, prefix="acc")

    acc0_name = codegen.tmpvar(prefix="acc")
    acc1_name = codegen.tmpvar(prefix="acc")
    codegen.add_statement(
        statement_from_string(
            f"{acc0_name}, {acc1_name} = tl.split({{acc}})",
            acc=permute_var,
        )
    )
    acc0 = expr_from_string(acc0_name)
    acc1 = expr_from_string(acc1_name)

    desc_name = indexing.tensor_descriptor(state)

    # First subtile store
    codegen.add_statement(
        statement_from_string(
            f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
            off0=expr_from_string(indexing.offsets[0]),
            off1=expr_from_string(indexing.offsets[1]),
            value=acc0,
        )
    )

    offset1_shifted = expr_from_string(
        "({offset} + {half})",
        offset=expr_from_string(indexing.offsets[1]),
        half=expr_from_string(block_n_half_str),
    )

    # Emit second subtile store as the expression returned to the caller.
    # Offsets are re-parsed so the two store ASTs do not share node objects.
    return expr_from_string(
        f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})",
        off0=expr_from_string(indexing.offsets[0]),
        off1=offset1_shifted,
        value=acc1,
    )
|
||
class StackIndexingStrategy: | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 didn't you add something to fix this somewhere else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes - we have a sanitization pass for `triton_helpers.*` right now at helion/helion/_compiler/device_function.py, lines 519 to 532 in 1aaba3f.