diff --git a/qwix/_src/core/dot_general_qt.py b/qwix/_src/core/dot_general_qt.py index b69ab67..d4e3e8e 100644 --- a/qwix/_src/core/dot_general_qt.py +++ b/qwix/_src/core/dot_general_qt.py @@ -215,7 +215,7 @@ def _compute_gradient_for_operand(g: jax.Array, *, for_dlhs: bool): if g_qtype and numerics.should_quantize(g.dtype): if isinstance(y, qarray.QArray) and not any( - v > 1 for v in qarray.get_tiled_axes(y).values() + tile_size > 1 for tile_size in qarray.get_tiled_axes(y).values() ): # Apply the scale of y to g, this trick avoids requantizing y because # the y from fwd pass has different channelwise_axes.