Conversation


@Wei-Lin-Intel Wei-Lin-Intel commented Oct 24, 2025

Conv1D's compute precision for Qwen3-Next on G2 should be kept as float, so the Conv1D weight should be cast to fp32.

This PR removes the unnecessary bf16->fp32 cast from every forward call in flat linear attention. Instead, the cast is performed once on the first call (normally during the profile run), and the original bf16 conv1d weight is then deleted since it is no longer used.
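
A minimal sketch of this lazy one-time cast pattern; the class name, shapes, and conv details below are illustrative assumptions, not the actual vLLM-HPU implementation:

```python
import torch
import torch.nn.functional as F

class FlatLinearAttentionConv(torch.nn.Module):
    """Sketch of the one-time bf16 -> fp32 conv1d weight cast (illustrative only)."""

    def __init__(self, dim: int, kernel_size: int = 4):
        super().__init__()
        # Depthwise conv1d weight as loaded from the checkpoint, in bf16.
        self.conv1d_weight = torch.nn.Parameter(
            torch.randn(dim, 1, kernel_size, dtype=torch.bfloat16),
            requires_grad=False,
        )
        self._conv1d_weight_fp32 = None  # created lazily on the first forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._conv1d_weight_fp32 is None:
            # Cast once, normally hit during the profile run ...
            self._conv1d_weight_fp32 = self.conv1d_weight.data.float()
            # ... then drop the bf16 copy, since it is never read again.
            self.conv1d_weight.data = torch.empty(0)
        w = self._conv1d_weight_fp32
        # Depthwise causal conv in fp32; trim the extra right-side padding.
        y = F.conv1d(x.float(), w, groups=w.shape[0], padding=w.shape[-1] - 1)
        return y[..., : x.shape[-1]]
```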


@czhu15 czhu15 left a comment


LGTM except for one minor question.

attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
row = attn[..., i, :i].contiguous()
sub = attn[..., :i, :]


Why does row need .contiguous() but sub doesn't?

@Wei-Lin-Intel (Author) replied


row is sliced along the last dim, which leaves its elements at non-contiguous addresses, so it is better to call .contiguous() on it.
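
A small illustration of that point, with made-up shapes (e.g. [heads, chunk_size, chunk_size], not the actual attn tensor): slicing the last dim yields a strided view, and .contiguous() packs it into a dense buffer.

```python
import torch

attn = torch.zeros(4, 8, 8)  # illustrative [heads, chunk_size, chunk_size]
i = 5

row_view = attn[..., i, :i]                   # slice along the last dim -> strided view
print(row_view.shape, row_view.stride())      # torch.Size([4, 5]) (64, 1)
print(row_view.is_contiguous())               # False: the 5-element pieces sit far apart
print(row_view.contiguous().is_contiguous())  # True: copied into a packed buffer
```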

@czhu15 czhu15 merged commit 4f5009a into HabanaAI:aice/v1.22.0 Oct 29, 2025
1 check passed