Use IncSubtensor in gradient of repeat #1621
Conversation
Force-pushed from 9f98ccb to 6935466.
Force-pushed from 6935466 to 2230e91.
Codecov Report

❌ Your patch check has failed because the patch coverage (78.00%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

```diff
@@           Coverage Diff            @@
##            main    #1621     +/-  ##
=========================================
- Coverage   81.64%   81.62%   -0.02%
=========================================
  Files         231      231
  Lines       52997    52968     -29
  Branches     9395     9384     -11
=========================================
- Hits        43267    43237     -30
- Misses       7282     7285      +3
+ Partials     2448     2446      -2
```
Pull Request Overview

This PR refactors the Repeat operation to use IncSubtensor for gradient computation instead of dense dot products, removing support for scalar repeats and axis=None at the Op level. These cases are now handled by the helper function pt.repeat using broadcast_to and reshape operations.

- Replace gradient computation with an IncSubtensor approach for better memory use and performance (see the sketch after this list)
- Remove scalar repeat and axis=None logic from the Repeat Op, delegating to the helper function
- Simplify the Numba implementation by removing object mode for most cases
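To make the first bullet concrete, here is a minimal, hand-written sketch of the idea (not the PR's actual grad code): an advanced inc_subtensor sums all output-gradient entries that share an input index, which is exactly the adjoint of repeat.

```python
# Hand-written sketch, not the PR's grad implementation: the adjoint of
# repeat sums every output-gradient entry that came from the same input entry.
import numpy as np
import pytensor.tensor as pt
from pytensor import function

x = pt.dvector("x")
repeats = pt.lvector("repeats")
g_out = pt.dvector("g_out")  # gradient w.r.t. repeat(x, repeats)

# Output element i originates from input element idx[i]; advanced
# inc_subtensor accumulates duplicate indices, yielding the input gradient
# without materializing a dense (len(x), len(g_out)) matrix.
idx = pt.repeat(pt.arange(x.shape[0]), repeats)
g_x = pt.inc_subtensor(pt.zeros_like(x)[idx], g_out)

f = function([x, repeats, g_out], g_x)
print(f(np.zeros(3), np.array([1, 2, 3]), np.arange(6.0)))  # [0., 3., 12.]
```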
Reviewed Changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| pytensor/tensor/extra_ops.py | Main refactoring of the Repeat Op constructor, make_node, grad, and the repeat helper function |
| tests/tensor/test_extra_ops.py | Updated tests to reflect new Op constraints and added static shape inference tests |
| tests/link/numba/test_extra_ops.py | Updated test parameters to match new Op requirements |
| pytensor/link/numba/dispatch/extra_ops.py | Simplified Numba implementation with fallback for unsupported cases |
| pytensor/link/numba/dispatch/basic.py | Made node parameter required in generate_fallback_impl |
| pytensor/utils.py | Removed outdated comment about Python < 2.6 |
```diff
-def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
+def generate_fallback_impl(op, node, storage_map=None, **kwargs):
```
Why?
Because it's not really optional; the code immediately assumes node is passed (and not None).
```python
if axis is None or axis < 0:
    # Operator Repeat does not support None or negative axis
    continue
```
Check that the relevant error is raised?
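For concreteness, something along these lines would do it; the exact exception type and message are assumptions, check what the Op actually raises:

```python
# Sketch of the requested assertion; the exception type is assumed.
import pytest
from pytensor.tensor.extra_ops import Repeat

with pytest.raises(ValueError):
    Repeat(axis=-1)  # assumed: negative axis rejected at construction
with pytest.raises(ValueError):
    Repeat(axis=None)  # assumed: axis=None rejected at the Op level
```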
```python
r = Repeat(axis=0)(x, 2)
assert r.broadcastable == (False, True, False)

def test_static_shape(self):
    x = TensorType(config.floatX, shape=(None, 1, 3))()
```
```diff
-x = TensorType(config.floatX, shape=(None, 1, 3))()
+x = pt.tensor(shape=(None, 1, 3))
```
Why is TensorType being used directly here?
Someone wrote it like that at first.
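For reference, the two spellings should build the same variable type, assuming pt.tensor defaults its dtype to config.floatX:

```python
# Both spellings should yield the same variable type (sketch).
import pytensor.tensor as pt
from pytensor import config
from pytensor.tensor.type import TensorType

a = TensorType(config.floatX, shape=(None, 1, 3))()  # instantiate the type directly
b = pt.tensor(shape=(None, 1, 3))                    # dtype defaults to config.floatX
assert a.type == b.type
```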
```python
# This case could be implemented in the future
r = repeat(x, [1, 2, 4], axis=2)
assert r.type.shape == (None, 1, None)
```
The static shape in this case is just the sum of the repeat values?
You want to do something like trying to constant-fold the sum of repeats? We could just check if it's a constant and grab the data out?
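Hypothetically, something like this (names assumed, not the PR's code) could recover the static length when repeats is a constant:

```python
# Hypothetical sketch of the suggestion: if `repeats` is a constant,
# the static output length along `axis` is just the sum of its values.
from pytensor.graph.basic import Constant
from pytensor.tensor import as_tensor_variable

repeats = as_tensor_variable([1, 2, 4])
if isinstance(repeats, Constant):
    static_dim = int(repeats.data.sum())  # -> 7, usable for the output's static shape
```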
We could... we can always COULD
This should be better than a dense dot product, both memory- and speed-wise. It's also more elegant?
Also:

- pt.repeat uses broadcast_to + alloc for this case (scalar repeats), so the Op is never triggered unless users create it manually. Well, no more (a NumPy sketch of the equivalence follows below).
- axis=None is handled symbolically instead of within the Op, as in "Handle axis=None symbolically instead of within CumOp" (#1574).

This way we only have to support the code that is really specific to Repeat: vector repeats.
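For illustration, a hand-written NumPy sketch of the scalar-repeats equivalence described above (the helper's actual implementation, per the comment, uses broadcast_to + alloc):

```python
# NumPy sketch: scalar repeats along an axis are a broadcast + reshape,
# so no Repeat Op is needed for this case.
import numpy as np

x = np.arange(6).reshape(3, 2)
n = 2  # scalar number of repeats along axis 0
expected = np.repeat(x, n, axis=0)
rewritten = np.broadcast_to(
    x[:, None, :], (x.shape[0], n, x.shape[1])
).reshape(x.shape[0] * n, x.shape[1])
assert (expected == rewritten).all()
```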
📚 Documentation preview 📚: https://pytensor--1621.org.readthedocs.build/en/1621/