Skip to content

Conversation

@chengjunlu
Copy link
Contributor

@chengjunlu chengjunlu commented Aug 11, 2025

To use the transpose 2d block io to load column major matrix from global memory. (The column major matrix here could be generalized to the cases that fast change dimension of register layout is not same as the fast change dim on global memory.)

The 2d block io only can transpose the matrix of i32 type when load matrix from memory to register. To transpose the matrix of type bits number < 32, we need to further transpose matrix inside the register.

The steps the to load matrix with transposing with 2D block IO:

  1. To load the matrix as d32 type matrix from memory with transposed to register.
  2. (Optional if scalar type < 32 bits) To transpose the MxNxd32 to Mx(32/m)xNxdm inside the register.

Right now we only use the bitcast operation to transpose the matrix of which the width is equal to the threads per warp for step 2.

Further we will support more matrix of which the width is not equal to the threads per warp.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This draft PR implements transpose 2D block load functionality to efficiently load column major matrices from global memory on Intel Xe+ GPUs. The implementation introduces a transpose operation when the register layout's fast-changing dimension differs from the memory layout, using d32 type matrices with bitcast operations for the transformation.

  • Added support for transpose 2D block IO operations with transpose parameter
  • Enhanced block IO tile size calculation to handle transpose scenarios
  • Implemented new test coverage for transpose and column major load operations

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 8 comments.

File Description
LoadStoreOpToLLVM.cpp Major refactoring of 2D block load implementation to support transpose operations and simplified layout handling
tensor-pointer-load-block-2d.mlir Updated test expectations for new block load configurations and tile sizes
test_block_store.py Added transpose parameter and column major test cases for block operations

@chengjunlu chengjunlu force-pushed the chengjun/trans_2d_load branch from efff84d to 55c896e Compare August 11, 2025 07:42
@etiotto etiotto marked this pull request as draft October 9, 2025 14:09
@chengjunlu chengjunlu force-pushed the chengjun/trans_2d_load branch from 20a1637 to 942ca37 Compare November 4, 2025 04:49
@chengjunlu chengjunlu changed the title [Draft] Transpose 2d load. [LoadStoreOpToLLVM] Transpose 2d load. Nov 4, 2025
@chengjunlu chengjunlu marked this pull request as ready for review November 4, 2025 04:50
@chengjunlu chengjunlu force-pushed the chengjun/trans_2d_load branch 7 times, most recently from 210886e to e979428 Compare November 10, 2025 05:37
packedElemSizeInBits = 32;
numPackedVals = packedElemSizeInBits / elemSizeInBits;

// Improve this. The current 2D block load only transposes the matrix at
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The improvements will be added in another PR to minimal the changes in a single PR.

@chengjunlu chengjunlu requested a review from Copilot November 10, 2025 05:41
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@chengjunlu
Copy link
Contributor Author

@whitneywhtsang @etiotto , The transpose loading is ready for review.

@chengjunlu chengjunlu force-pushed the chengjun/trans_2d_load branch from e979428 to 248ae4c Compare November 12, 2025 03:00
@whitneywhtsang
Copy link
Contributor

Can you fix the typo in the image of the PR description or remove it?

return axisInfo ? axisInfo->getStride(dim) : -1;
if (axisInfo) {
const SmallVector<int64_t> &stride = axisInfo->getStride();
if (dim < stride.size()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would we call getStride with dim more than the size of stride?

Copy link
Contributor Author

@chengjunlu chengjunlu Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a typical case in real Triton kernel but for LIT test cases.

There are simple LIT cases which is not real input kernel by user, like this

tt.func public @regular_pointer_gather_io(%arg0: tensor<128x64x!tt.ptr<f16>, #mma>,

The arguments of tensor type of the function is converted to LLVM struct type before the axis info analysis pass run. The axis info initilize the AxisInfo with only one dimenssion for those non-tensor type. The original code will dereference the stride information with dim > 1 which is out of the boundary of the AxisInfo for the operands.

This is just a simple protection to return unknown stride for out-of-boundary dim of AxisInfo.

@chengjunlu
Copy link
Contributor Author

Can you fix the typo in the image of the PR description or remove it?

Description has been updated for the code of this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[06-fused-attention] Determine if FP8 operand B can use 2d block load

3 participants