Combine for RaggedIterDomain #5716
base: raggediterdomain_clone
Review updated until commit 3a80926
Description
| Relevant files |
|---|
| Enhancement |
| Tests |
| Documentation |
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Memory Management
Greptile Summary
Implements the
Key changes:
Implementation note:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant RaggedIterDomain
participant IterDomainBuilder
participant Combine as Combine Expr
participant IrContainer
User->>RaggedIterDomain: combine(component, ragged)
RaggedIterDomain->>RaggedIterDomain: Validate inputs (null checks, types)
RaggedIterDomain->>RaggedIterDomain: Check parallelization (must be Serial)
RaggedIterDomain->>RaggedIterDomain: Check iter types (must be Iteration)
alt ragged has Partition definition
RaggedIterDomain->>Combine: Validate component matches partition
Note over RaggedIterDomain,Combine: Option 3: Validate when possible
else No Partition definition
Note over RaggedIterDomain: Trust user (after set/segmentation)
end
RaggedIterDomain->>RaggedIterDomain: Get extents from ragged
RaggedIterDomain->>RaggedIterDomain: Verify extents is 1D tensor
RaggedIterDomain->>IrContainer: Create symbolic Val for combined_extent
Note over RaggedIterDomain,IrContainer: Extent represents sum(extents)<br/>but not explicitly computed
RaggedIterDomain->>IterDomainBuilder: Build output IterDomain
IterDomainBuilder-->>RaggedIterDomain: combined_id (regular IterDomain)
RaggedIterDomain->>Combine: Create Combine expression
Combine->>Combine: addOutput(combined_id)
Combine->>Combine: addInput(component)
Combine->>Combine: addInput(ragged)
RaggedIterDomain-->>User: Return combined_id
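The alt branch in the diagram ("validate when possible", otherwise trust the user) can be sketched with a toy model. This is not nvFuser code: `ToyIterDomain`, `ToyPartition`, `ToyRaggedIterDomain`, and `checkCombineInputs` are hypothetical stand-ins, and the only behavior taken from the diagram is that the component is checked against the Partition when one is visible and accepted as-is when it is not.

```cpp
#include <stdexcept>

// Hypothetical stand-ins for illustration only; these are not the nvFuser classes.
struct ToyIterDomain {};

struct ToyPartition {
  const ToyIterDomain* component; // component output of the partition
};

struct ToyRaggedIterDomain {
  // Null when the defining Partition is not visible, e.g. after segmentation.
  const ToyPartition* definition = nullptr;
};

// Returns true if the pair was validated against a visible Partition,
// false if it was accepted on trust. Throws on null inputs or a mismatch.
bool checkCombineInputs(
    const ToyIterDomain* component,
    const ToyRaggedIterDomain* ragged) {
  if (component == nullptr || ragged == nullptr) {
    throw std::invalid_argument("combine: null input");
  }
  if (ragged->definition == nullptr) {
    return false; // no Partition to compare against; trust the caller
  }
  if (ragged->definition->component != component) {
    throw std::invalid_argument("combine: component does not match partition");
  }
  return true;
}
```

The boolean return is only there to make the two paths observable in a test; the real function presumably just proceeds in both cases.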
5 files reviewed, 2 comments
NVF_ERROR(component != nullptr, "combine: component IterDomain is null");
NVF_ERROR(ragged != nullptr, "combine: ragged IterDomain is null");

NVF_ERROR(
    !component->isA<RaggedIterDomain>(),
    "combine: component must be a regular IterDomain, got RaggedIterDomain: ",
    component->toString());

// Validate that component and ragged have compatible properties
NVF_ERROR_EQ(
    component->getParallelType(),
    ParallelType::Serial,
    "Combining parallelized IterDomain not supported: ",
    component->toString());

NVF_ERROR_EQ(
    ragged->getParallelType(),
    ParallelType::Serial,
    "Combining parallelized RaggedIterDomain not supported: ",
    ragged->toString());

NVF_ERROR_EQ(
    component->getIterType(),
    IterType::Iteration,
    "combine: only IterType::Iteration is supported for component, got ",
    component->getIterType(),
    " for IterDomain: ",
    component->toString());

NVF_ERROR_EQ(
    ragged->getIterType(),
    IterType::Iteration,
    "combine: only IterType::Iteration is supported for ragged, got ",
    ragged->getIterType(),
    " for RaggedIterDomain: ",
    ragged->toString());
style: Missing validation that component and ragged are semantically compatible or from the same partition. Consider verifying:
- That the component extent matches the ragged extents dimension (component extent == extents_tv logical domain size)
- Or that both came from the same Partition operation (component->definition() == ragged->definition())
Without this, arbitrary unrelated IterDomains could be combined.
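The first check suggested above could look roughly like the following. This is a toy sketch, not the nvFuser API: the component IterDomain is reduced to a plain integer extent and the 1D extents tensor to a `std::vector`, and `extentsMatchComponent` is a hypothetical helper name.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model: the component extent must equal the number of entries in the
// 1D extents tensor, since each entry gives the ragged extent of one component.
bool extentsMatchComponent(
    std::int64_t component_extent,
    const std::vector<std::int64_t>& extents) {
  return component_extent >= 0 &&
      static_cast<std::size_t>(component_extent) == extents.size();
}
```

In the real IR the extent is generally symbolic, so a check like this may only be decidable at run time or when both sizes are compile-time constants.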
6 files reviewed, 1 comment
IrBuilder::createInContainer<Combine>(
    container, combined_id, component, ragged);
style: Unlike Partition which stores the extents tensor as an attribute (line 2634 in internal_nodes.cpp), Combine doesn't store it anywhere. While the extents can be accessed via ragged->extents(), this creates an inconsistency.
If the ragged IterDomain later loses its connection to the original extents (e.g., through cloning or transformation), the Combine expression won't have a direct reference to them.
- IrBuilder::createInContainer<Combine>(
-     container, combined_id, component, ragged);
+ // Store extents as an attribute for consistency with Partition
+ IrBuilder::createInContainer<Combine>(
+     container, combined_id, component, ragged, ragged->extents());
Note: This would require updating the Combine class signature to accept and store extents as an attribute.
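To make the suggested signature change concrete, here is a toy expression node that keeps the extents as an attribute next to its inputs and outputs. `ToyVal` and `ToyCombine` are hypothetical and only mirror the shape of the suggestion; the actual Combine class and its base Expr interface may differ.

```cpp
#include <vector>

// Hypothetical stand-in for an IR value; not the nvFuser Val class.
struct ToyVal {};

// Toy expression node with a separate attribute slot for the extents.
class ToyCombine {
 public:
  ToyCombine(ToyVal* out, ToyVal* component, ToyVal* ragged, ToyVal* extents)
      : outputs_{out}, inputs_{component, ragged}, attributes_{extents} {}

  ToyVal* out() const { return outputs_.at(0); }
  ToyVal* component() const { return inputs_.at(0); }
  ToyVal* ragged() const { return inputs_.at(1); }
  // Direct reference that does not go through the ragged input.
  ToyVal* extents() const { return attributes_.at(0); }

 private:
  std::vector<ToyVal*> outputs_;
  std::vector<ToyVal*> inputs_;
  std::vector<ToyVal*> attributes_;
};
```

The trade-off is that the same extents are then reachable through two paths, so any cloning or replacement logic has to keep both in sync.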
7 files reviewed, 1 comment
Val* combined_extent =
    IrBuilder::createInContainer<Val>(container, DataType::Index);
style: The symbolic extent approach leaves the combined extent as an opaque Val without establishing its relationship to the actual sum. During lowering/indexing, the system needs to compute actual indices, but there's no expression connecting combined_extent to sum(extents_tv, {0}).
Check whether this symbolic extent will cause issues during:
- Index computation when iterating over the combined dimension
- Extent analysis passes that need to know actual sizes
- Fusion validation that checks dimension compatibility
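For reference, the relationship that the opaque Val leaves implicit can be written out on concrete data. The sketch below assumes 1D extents held in a `std::vector`; `combinedExtent`, `offsets`, and `combinedIndex` are hypothetical helper names, not nvFuser functions. The combined extent is the sum of the per-component extents, and the position of an element in the combined dimension is the exclusive prefix sum up to its component plus its index within the ragged dimension.

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Combined extent: sum of all per-component extents.
std::int64_t combinedExtent(const std::vector<std::int64_t>& extents) {
  return std::accumulate(extents.begin(), extents.end(), std::int64_t{0});
}

// Exclusive prefix sum: offsets[i] is where component i starts in the
// combined dimension.
std::vector<std::int64_t> offsets(const std::vector<std::int64_t>& extents) {
  std::vector<std::int64_t> result(extents.size(), 0);
  for (std::size_t i = 1; i < extents.size(); ++i) {
    result[i] = result[i - 1] + extents[i - 1];
  }
  return result;
}

// Maps (component index, index within the ragged dimension) to the index
// in the combined dimension. Throws std::out_of_range on a bad component.
std::int64_t combinedIndex(
    const std::vector<std::int64_t>& extents,
    std::size_t component,
    std::int64_t ragged_index) {
  return offsets(extents).at(component) + ragged_index;
}
```

With extents {2, 0, 3}, the combined extent is 5, the offsets are {0, 2, 2}, and element 1 of component 2 lands at combined index 3.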
This PR introduces the combine operation as discussed in the RaggedIterDomain design doc.
One design decision that I changed from the original design doc concerns detecting and validating component iter domains. Previously, I was thinking about using the exact graph to find the corresponding component iter domain for a given ragged iter domain (e.g., #5550 (comment)). However, that won't work when, for example, a fusion is segmented and a segment does not have the corresponding Partition expr for a RaggedIterDomain. Suppose a tensor is used as an input for asNested, followed by some other operations. If the fusion is segmented after some of those operations, the latter segment won't be able to see the asNested and Partition operations, as they don't exist in that segment. This could be alleviated by providing an exact graph for the whole fusion, but more fundamentally, if a fusion has a nested tensor as an input, there doesn't seem to be any reasonable way to attach a Partition expr.
See doc/dev/ragged_iter_domain_combine_design_doc.md for detailed discussions. At this moment, I decided not to worry too much about the validation and to assume that correctness is guaranteed by the user.
Note that partitioning is still limited to 1D extents. Multi-dim offsets will be the next step in this series of PRs.