[sharktank] Refactor mask generation to utils from ops #2496
Conversation
Coverage report (generated by python-coverage-comment-action).
boolean_mask = torch.logical_or(causal_mask, boolean_input_mask[:, None, None, :])
numeric_mask = torch.where(boolean_mask, max_negative_value(dtype, device), 0).to(dtype)
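For orientation, here is a minimal self-contained sketch of the boolean-to-numeric mask construction shown above. The max_negative_value helper and the mask shapes are recreated locally for illustration and are not the sharktank implementations.

```python
import torch


def max_negative_value(dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Illustrative stand-in: the most negative finite value representable in dtype.
    return torch.tensor(torch.finfo(dtype).min, dtype=dtype, device=device)


def make_numeric_attention_mask(
    causal_mask: torch.Tensor,         # [1, 1, T, S] bool, True where attention is NOT allowed
    boolean_input_mask: torch.Tensor,  # [B, S] bool, True at padded positions
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    # Combine causal and padding masks; True means "masked out".
    boolean_mask = torch.logical_or(causal_mask, boolean_input_mask[:, None, None, :])
    # Masked positions get a very large negative bias, everything else 0.
    return torch.where(boolean_mask, max_negative_value(dtype, device), 0).to(dtype)


# Tiny usage example.
B, T, S = 2, 4, 4
causal = torch.triu(torch.ones(T, S, dtype=torch.bool), diagonal=1)[None, None]
padding = torch.zeros(B, S, dtype=torch.bool)
padding[0, -1] = True  # last token of the first sequence is padding
mask = make_numeric_attention_mask(causal, padding, torch.float32, torch.device("cpu"))
```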
If we have torch.logical_or and torch.where implemented as sharktank.ops, then this function does not need to be decorated as trivially replicable. The same argument holds for some of the other functions, such as create_attention_mask_for_decode and create_chunked_attention_mask.
We generally want to expand the op coverage to everything that torch has, and definitely to everything that we use.
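A rough sketch of what that could look like, using a stand-in ops namespace; sharktank's actual dispatch machinery and op registration are not reproduced here, and the names are only illustrative.

```python
import torch
from types import SimpleNamespace

# Stand-in for a sharktank-style ops layer. In the real library these ops would
# dispatch on sharded/replicated tensor types; here they simply forward to torch.
ops = SimpleNamespace(logical_or=torch.logical_or, where=torch.where)


def create_boolean_attention_mask(causal_mask, boolean_input_mask):
    # If ops.logical_or knew how to handle sharded inputs, this helper would work
    # for plain and sharded tensors alike, with no trivially-replicable wrapper:
    # the sharding logic lives inside the ops rather than around the function.
    return ops.logical_or(causal_mask, boolean_input_mask[:, None, None, :])
```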
I see that the problem is create_causal_context_mask, because it creates un-replicated tensors with

src = torch.arange(src_len, device=device)[None, None, None, :]
target = torch.arange(target_len, device=device)[None, None, :, None]

We will keep running into the problem of having to construct something that does not depend on a tensor argument, so there is nothing to propagate the sharding from, e.g. constructing an all-zeros or all-ones tensor. To generalize this approach we need to come up with something. Maybe allow downstream ops to mix sharded and unsharded args.
If we want to write generic model code, we need a solution, or we will keep dancing around the problem. Things can get quite nasty if you need to modify existing code to create such tensors in places where sharded tensors are present. This can happen, for example, when extending our LLM to support a new architecture.
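One possible direction, purely as an illustration of the "nothing to propagate sharding from" problem: thread an explicit reference tensor (or a sharding spec) through constructor-style ops. The arange_like helper below is hypothetical, not an existing sharktank API.

```python
import torch


def arange_like(reference: torch.Tensor, n: int) -> torch.Tensor:
    # Purely illustrative: constructor-style ops (arange/zeros/ones) have no
    # tensor argument to inherit placement/sharding from. One way around that
    # is to pass an explicit reference tensor (or a sharding spec), so the
    # result can be created "like" an existing sharded/replicated value.
    return torch.arange(n, device=reference.device, dtype=torch.int64)


# Usage: derive the index tensors used by a causal mask from a reference
# activation instead of constructing them with no context at all.
x = torch.zeros(2, 8)
src = arange_like(x, 8)[None, None, None, :]
target = arange_like(x, 8)[None, None, :, None]
causal = target < src  # True where the key index is ahead of the query index
```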
I have started on this (see #2532 for arange) for a different reason: the current approach introduces transfers that break fusion.
> Maybe allow downstream ops to mix sharded and unsharded args.
This approach is already causing issues with elementwise_binary: the transfers break fusion, so we can't use it. If we need ShardedTensor inputs, we should be making those directly.
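A toy illustration of "making those directly": build the constant on each shard's device instead of creating one unsharded tensor and letting the elementwise op transfer it around. The replicated_arange helper is hypothetical and only hints at what a real ReplicatedTensor construction would do.

```python
import torch


def replicated_arange(n: int, devices: list) -> list:
    # Illustrative only: rather than creating one unsharded arange and letting a
    # downstream elementwise op transfer it to every shard (those transfers break
    # fusion), construct the constant directly on each shard's device so no
    # cross-device movement is needed inside the op.
    return [torch.arange(n, device=d) for d in devices]


# With a real ReplicatedTensor/ShardedTensor type, these per-device pieces would
# be wrapped into a single replicated value before being passed to the op.
shards = replicated_arange(8, [torch.device("cpu"), torch.device("cpu")])
```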
Will implement this separately in a few PRs. I want to land this as-is for now, since this was supposed to be a simple move of the functions from A to B.
Force-pushed from fb339cf to f5e331f.
Signed-off-by: Alex Vasile <[email protected]>
Refactor and clean up unused imports in sharktank (dependent on #2496)
Revert the mask generation functions from ops back to utils