Conversation

archana-ramalingam (Collaborator)

Revert the mask generation functions from ops back to utils

github-actions bot commented Oct 16, 2025

Coverage report

Files with new statements missing coverage (lines missing):

  sharktank/sharktank/layers/paged_attention.py
  sharktank/sharktank/ops/default_impls.py
  sharktank/sharktank/ops/signatures.py
  sharktank/sharktank/utils/attention.py (124-131, 166-182)
  sharktank/tests/models/llama/attention_test.py
  sharktank/tests/models/llama4/llama4_test.py
  Project Total

This report was generated by python-coverage-comment-action

Comment on lines +43 to +44:

    boolean_mask = torch.logical_or(causal_mask, boolean_input_mask[:, None, None, :])
    numeric_mask = torch.where(boolean_mask, max_negative_value(dtype, device), 0).to(
Contributor

If we have torch.logical_or and torch.where implemented as sharktank.ops, then this function does not need to be decorated as trivially replicable. The same argument holds elsewhere, e.g. in create_attention_mask_for_decode and create_chunked_attention_mask.
We generally want to expand the op coverage to everything that torch has, and definitely to everything that we use.

I see that the problem is create_causal_context_mask, because it creates unreplicated tensors with

    src = torch.arange(src_len, device=device)[None, None, None, :]
    target = torch.arange(target_len, device=device)[None, None, :, None]

We will keep running into the problem of having to construct something that does not depend on a tensor arg, so there is nothing to propagate the sharding nature from, e.g. constructing an all-zeros or all-ones tensor. To generalize this approach we need to come up with something. Maybe allow downstream ops to mix sharded and unsharded args.

If we want to write generic model code, we need a solution, or we will keep dancing around the problem. Things can get quite nasty if you need to modify existing code to create such tensors where sharded tensors are present. This can happen, for example, when extending our LLM to support a new architecture.
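
For context, here is a minimal plain-PyTorch sketch of the pattern under discussion; the helper names and shapes are assumptions for illustration, not the actual sharktank implementation. It shows both the aranges that have no tensor argument to inherit sharding from and the two quoted lines that fold in the padding mask and convert it to an additive mask:

    import torch

    def max_negative_value(dtype: torch.dtype, device=None) -> torch.Tensor:
        # Illustrative stand-in for the utility referenced in the diff: the most
        # negative value representable in `dtype`, used to mask attention scores.
        return torch.tensor(torch.finfo(dtype).min, dtype=dtype, device=device)

    def causal_context_mask(src_len: int, target_len: int, device=None) -> torch.Tensor:
        # True marks positions a query may not attend to. Both aranges are built
        # from plain ints, so there is no tensor arg to propagate sharding from.
        src = torch.arange(src_len, device=device)[None, None, None, :]
        target = torch.arange(target_len, device=device)[None, None, :, None]
        return src > target

    def attention_mask(causal_mask, boolean_input_mask, dtype, device=None):
        # Mirrors the quoted lines: combine causal and padding masks, then turn
        # the boolean mask into an additive numeric mask.
        boolean_mask = torch.logical_or(causal_mask, boolean_input_mask[:, None, None, :])
        return torch.where(boolean_mask, max_negative_value(dtype, device), 0).to(dtype)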

@Alex-Vasile Alex-Vasile (Contributor) Oct 16, 2025


I have started on this (see #2532 for arange) for a different reason (the current approach introduces transfers that break fusion).

Maybe allow downstream ops to mix sharded and unsharded args.

This approach is already causing issues with elementwise_binary: the transfers break fusion, so we can't use it. If we need ShardedTensor inputs, we should be constructing them directly.
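
A schematic illustration of the fusion concern, assuming nothing about the sharktank API (the "sharded tensor" below is just a list of per-device shards, not the real ShardedTensor type):

    import torch

    def add_mixed(shards: list[torch.Tensor], unsharded: torch.Tensor) -> list[torch.Tensor]:
        # Mixing sharded and unsharded args: the unsharded operand has to be
        # copied to each shard's device at the point of use, and that transfer
        # sits between its producer and the add, breaking fusion across it.
        return [shard + unsharded.to(shard.device) for shard in shards]

    def add_presharded(shards: list[torch.Tensor], replicated: list[torch.Tensor]) -> list[torch.Tensor]:
        # Constructing the second operand per device up front (e.g. an arange
        # created directly on each device) removes the transfer at the point of use.
        return [a + b for a, b in zip(shards, replicated)]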

Contributor

Will implement this separately in a few PRs. I want to land this as-is for now, since this was supposed to be a simple move of the functions from A to B.

@Alex-Vasile Alex-Vasile self-assigned this Oct 20, 2025
Signed-off-by: Alex Vasile <[email protected]>
@archana-ramalingam archana-ramalingam merged commit 137e12c into main Oct 21, 2025
39 checks passed
@archana-ramalingam archana-ramalingam deleted the refactor-mask-utils branch October 21, 2025 02:17
archana-ramalingam added a commit that referenced this pull request Oct 21, 2025
Refactor and cleanup unused imports in sharktank
Dependent on #2496