Conversation
and test for matrix add where it is inferred from writes

Signed-off-by: Alex Zinenko <git@ozinenko.com>
// TODO: pywave just ignores this, not sure if we want to, including the
// case below where there may be zero constraints. Interestingly, it
// asserts if trailing dimensions are not found when computing the
// stride...
Is it safe to simply ignore the symbols for which there are no constraints when setting index sequences from write?
If no constraints are specified or the vector shape is not set to 0 (dimensions we don't want to expand), then the symbol either corresponds to the actual tensor dimension or is set dynamically in the kernel. I don't think we should ignore the symbol because it could be meaningful in the analysis.
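To illustrate the point above, here is a minimal hypothetical sketch (the function and data shapes are invented for illustration, not the actual wave API) of why an unconstrained symbol should be surfaced rather than silently dropped:

```python
# Hypothetical sketch: a symbol without explicit constraints may still be
# meaningful (e.g. a dimension resolved dynamically in the kernel), so the
# analysis records the gap instead of silently ignoring the symbol.

def index_sequences_from_write(symbols, constraints):
    """Map each symbol to its (start, size, stride) sequence; None marks
    'unknown', which callers must handle explicitly rather than assuming
    a default like (0, 1, 1)."""
    sequences = {}
    for sym in symbols:
        if sym in constraints:
            sequences[sym] = constraints[sym]
        else:
            # Do NOT drop the symbol: record the missing constraint so the
            # caller can decide (report an error, or resolve dynamically).
            sequences[sym] = None
    return sequences

seqs = index_sequences_from_write(
    ["M", "N", "K"], {"M": (0, 4, 1), "N": (0, 1, 1)}
)
# "K" maps to None here, signalling the missing constraint explicitly.
```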
emitError() << "expected a single workgroup constraint for dimension "
            << tensorType.getShape()[i]
            << " used in the write operation without explicit "
               "`elements_per_thread`";
return failure();
Ditto, but in the absence of a workgroup constraint?
It feels like we need to set it to start=0, and likely size=1 and stride=1, but I'm not sure.
I think the code does make some assumptions like that where it falls back to start = 0, size and stride of 1, but I think we shouldn't allow that and instead be more explicit.
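The stricter behaviour discussed above could look like the following hedged sketch (hypothetical helper, not the actual implementation): fail loudly when a dimension has no single workgroup constraint, instead of silently falling back to start=0, size=1, stride=1.

```python
# Hypothetical sketch of the stricter behaviour: require exactly one
# workgroup constraint per dimension; raise instead of falling back to
# the implicit (start=0, size=1, stride=1) default.

def index_sequence_for_dim(dim, workgroup_constraints):
    """Return (start, size, stride) for `dim`, or raise if the dimension
    is not covered by exactly one workgroup constraint."""
    matching = [c for c in workgroup_constraints if c["dim"] == dim]
    if len(matching) != 1:
        raise ValueError(
            f"expected a single workgroup constraint for dimension {dim} "
            "used in the write operation without explicit "
            "`elements_per_thread`"
        )
    c = matching[0]
    # Defaults apply only to fields the one matching constraint omits,
    # never to a wholly missing constraint.
    return (c.get("start", 0), c.get("size", 1), c.get("stride", 1))
```

With a constraint present this resolves normally; with none (or several) it reports the error rather than guessing.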
// TODO: in pywave, we always do `startExpr % threadsPerWave` where
// threadsPerWave == 1 for workgroup dims other than X, making it
// always zero. It mentions an assumption about the (64, 1, 1) thread
// shape, but it is unclear whether that assumption always holds.
// It looks like the intention for this was to express lane ID rather
// than thread ID, but it is unclear how it accounts for multiple
// wavefronts running in parallel.
The comment in the original source (wave/wave_lang/kernel/wave/constraints.py, lines 498 to 501 in 601ab68):
This comes up in the SIMT context (no MMA; you can also see this in the example for the atomic case). If you look at the original code, because we don't have an MMA, the default pattern for SIMT is a thread-linear pattern, so for the atomicAdd we were getting a dependence on x and y even though that shouldn't be the case for the example. This was a fix to handle that scenario. Will also tag @nithinsubbiah to add more context.
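The `startExpr % threadsPerWave` pattern discussed above can be modelled with this small sketch (hypothetical names; the (64, 1, 1) wave shape is the assumption stated in the TODO): the modulo turns a thread ID into a lane ID within the wavefront, and for the Y and Z dimensions, where threadsPerWave == 1, the result is always zero, which removes the spurious dependence on y and z.

```python
# Hypothetical model of `startExpr % threadsPerWave`, assuming the
# (64, 1, 1) threads-per-wave shape mentioned in the TODO above.
THREADS_PER_WAVE = (64, 1, 1)

def lane_component(thread_id, dim):
    """Per-dimension contribution of a thread index to its lane ID.

    For dim == 0 (X) this keeps values in 0..63 (the lane within the
    wavefront); for dims 1 and 2 (Y, Z) the divisor is 1, so the
    contribution is always 0 regardless of the thread index.
    """
    return thread_id % THREADS_PER_WAVE[dim]
```

This also shows why the lane-ID reading is plausible but incomplete for multiple wavefronts: two threads with X indices 6 and 70 map to the same lane 6 even though they sit in different wavefronts.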