
Commit 3928dfd

JyotinderSingh authored and copybara-github committed
## Add Quantization Error Propagation (QEP) support to Qwix
QEP extends standard GPTQ by compensating for the cascading quantization noise introduced by preceding layers during inference. While GPTQ minimizes `||W @ X - W_q @ X||^2` assuming perfect float inputs, QEP actively minimizes `||W @ X_float - W_q @ X_q||^2`. This is achieved by computing an input cross-correlation statistic (`H_delta`) and applying a localized weight correction (`W_corrected = W + alpha * (W @ H_delta @ inv(H))`) prior to standard GPTQ rounding.

## API & Usage

The primary entry point is `qep.quantize(...)`. Because QEP must measure the accumulated error from previously quantized layers, it orchestrates a multi-pass calibration loop stage by stage rather than relying on a single forward pass.

```python
result = qep.quantize(
    model=model,
    # calibration_data must be reiterable since QEP sweeps it multiple times
    calibration_data=dataset_iterator_factory,
    rules=[qep.QepRule(module_path='Dense_.*', weight_qtype=jnp.int8)],
    variables=variables,
)

# The returned QepResult contains everything needed for inference:
inference_output = result.model.apply(
    {'params': result.params, 'quant_stats': result.quant_stats},
    sample_input,
)
```

For offline or distributed pipelines where statistics are pre-computed remotely, `qep.quantize_params()` can be invoked directly to apply the QEP correction and GPTQ rounding to float weights without re-running the model graph.

## Key modifications

- `qep_core.py`: Pure-JAX algorithms for the QEP statistics (`compute_qep_stats`) and the core weight-shifting logic (`weight_correct`).
- `qep.py`: The stagewise orchestrator (`qep.quantize`). Dynamically discovers interconnected topological stages, applies a two-pass (float vs. quantized) calibration loop per batch, and updates weights progressively through the network.
- `calibration.py`: Refactored the core `CalibrationProvider` mechanics to decouple single-pass logic, enabling robust multi-pass activation interception for QEP.
- `QepRule`: New configuration struct extending `GptqRule` with tuning hyperparameters (`correction_factor`, `damping_factor`).

PiperOrigin-RevId: 888364583
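The correction step described above can be sketched in a few lines of NumPy. This is a hypothetical illustration of the formula from the commit message, not the actual Qwix implementation: the function name `qep_weight_correct`, the `damping` argument, and the use of a damped linear solve in place of an explicit `inv(H)` are all assumptions.

```python
import numpy as np


def qep_weight_correct(w, h, h_delta, alpha=1.0, damping=0.01):
    """Illustrative QEP correction: W + alpha * (W @ H_delta @ inv(H)).

    w:       (rows, ca) weight matrix, contracting axis last.
    h:       (ca, ca) input auto-correlation from float calibration.
    h_delta: (ca, ca) cross-correlation between the accumulated input
             error (X_q - X_float) and the float inputs.
    """
    # Damp the diagonal so the linear solve stays well conditioned
    # (GPTQ-style damping; the exact factor here is an assumption).
    h_damped = h + damping * np.mean(np.diag(h)) * np.eye(h.shape[0])
    # Right-multiplication by inv(H) via a solve, for numerical stability:
    # since H is symmetric, solve(H, A.T).T == A @ inv(H).
    correction = np.linalg.solve(h_damped, (w @ h_delta).T).T
    return w + alpha * correction
```

Note the solve instead of `np.linalg.inv`: when `H` is near-singular (highly correlated calibration inputs), the damped solve is the numerically safer way to apply `inv(H)`.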
1 parent e43873c commit 3928dfd

File tree

6 files changed: +1683 −43 lines changed


qwix/contrib/calibration.py

Lines changed: 126 additions & 43 deletions
```diff
@@ -18,12 +18,12 @@
 from typing import Any, Callable
 
 import flax
+from flax import nnx
 import jax
 from jax import numpy as jnp
 from qwix._src import averaging
 from qwix._src import flax_util
 from qwix._src import qconfig
-from qwix._src.core import qarray
 from qwix._src.providers import ptq
 
 
@@ -32,7 +32,7 @@ class CalibrationProvider(qconfig.QuantizationProvider, metaclass=abc.ABCMeta):
 
   This provider handles the common boilerplate for all calibration providers:
   rule type checking, dimension validation, weight name lookup, and LHS
-  reshaping. Subclasses implement `_collect_stats` to define what happens
+  reshaping. Subclasses implement ``_collect_stats`` to define what happens
   with the validated, reshaped activations.
   """
 
@@ -45,7 +45,16 @@ def get_stats_suffix(self) -> str:
     """Returns the suffix for the stats variable name (e.g., '_gptq')."""
 
   @abc.abstractmethod
-  def _collect_stats(self, lhs: jax.Array, weight_name: str) -> None:
+  def _collect_stats(
+      self,
+      lhs: jax.Array,
+      weight_name: str,
+      *,
+      module_path: tuple[str, ...],
+      op_name: str,
+      op_id: str | None,
+      lhs_id: int,
+  ) -> None:
     """Collects statistics from the reshaped input activations.
 
     Called after all validation passes. The LHS has already been reshaped
@@ -54,6 +63,10 @@ def _collect_stats(self, lhs: jax.Array, weight_name: str) -> None:
     Args:
       lhs: Input activations reshaped to (ca, rest) format.
       weight_name: The name of the weight parameter for this operation.
+      module_path: The current module path for this operation.
+      op_name: The intercepted operation name.
+      op_id: The operation identifier assigned by the quantization tracer.
+      lhs_id: Python object id of the original lhs before reshaping.
     """
 
   def dot_general(
@@ -63,11 +76,35 @@ def dot_general(
       dimension_numbers: jax.lax.DotDimensionNumbers,
       *args,
       rule: qconfig.QuantizationRule | None = None,
+      op_id: str | None = None,
       **kwargs,
   ) -> jax.Array:
+    """Intercepts supported weight-bearing ``dot_general`` ops for calibration.
+
+    Subclasses do not need to reimplement the common matching and validation
+    logic. This method:
+
+    - resolves the active quantization rule and op id,
+    - rejects unsupported dimension-number patterns,
+    - identifies the weight parameter on the RHS,
+    - reshapes the LHS to ``(contracting_dim, rest)``, and
+    - delegates the actual stats handling to ``_collect_stats``.
+
+    Args:
+      lhs: The left-hand side array.
+      rhs: The right-hand side array.
+      dimension_numbers: The dimension numbers for the dot_general operation.
+      *args: Additional positional arguments to pass to dot_general.
+      rule: The quantization rule to apply.
+      op_id: The operation identifier assigned by the quantization tracer.
+      **kwargs: Additional keyword arguments to pass to dot_general.
+
+    Returns:
+      The result of the dot_general operation.
+    """
     res = jax.lax.dot_general(lhs, rhs, dimension_numbers, *args, **kwargs)
-    if rule is None:
-      rule, _ = self._get_current_rule_and_op_id('dot_general')
+    if rule is None or op_id is None:
+      rule, op_id = self._get_current_rule_and_op_id('dot_general')
 
     rule_type = self.get_rule_type()
     if not isinstance(rule, rule_type):
@@ -84,16 +121,25 @@ def dot_general(
       # If we cannot identify the weight parameter, we skip calibration.
       return res
 
+    lhs_id = id(lhs)
     # Reorder lhs to (ca, rest) format.
     lhs = jnp.moveaxis(lhs, lhs_ca[0], 0)
    lhs = lhs.reshape(lhs.shape[0], -1)
 
-    self._collect_stats(lhs, weight_name)
+    self._collect_stats(
+        lhs,
+        weight_name,
+        module_path=tuple(map(str, flax_util.get_current_module_path())),
+        op_name='dot_general',
+        op_id=op_id,
+        lhs_id=lhs_id,
+    )
 
     return res
 
   def einsum(self, einsum_str, *operands, **kwargs):
-    rule, _ = self._get_current_rule_and_op_id('einsum')
+    """Intercepts supported binary ``einsum`` ops via their lowered dot call."""
+    rule, op_id = self._get_current_rule_and_op_id('einsum')
     rule_type = self.get_rule_type()
     if not isinstance(rule, rule_type):
       return jnp.einsum(einsum_str, *operands, **kwargs)
@@ -103,7 +149,7 @@ def einsum(self, einsum_str, *operands, **kwargs):
 
     def stats_dot_general(lhs, rhs, dimension_numbers, *args, **kwargs):
       return self.dot_general(
-          lhs, rhs, dimension_numbers, *args, rule=rule, **kwargs
+          lhs, rhs, dimension_numbers, *args, rule=rule, op_id=op_id, **kwargs
       )
 
     with jax.disable_jit():
@@ -124,16 +170,27 @@ def get_intercept_map(self) -> dict[str, Callable[..., Any]]:
 class SinglePassCalibrationProvider(CalibrationProvider, metaclass=abc.ABCMeta):
   """Calibration provider that collects single-pass statistics.
 
-  This provider implements the simple stats template: `compute_stats`
-  produces a dict of arrays, which are accumulated into the `quant_stats`
-  collection using `SimpleMovingAverage`.
+  This provider implements the simple stats template: ``compute_stats``
+  produces a dict of arrays, which are accumulated into the ``quant_stats``
+  collection using ``SimpleMovingAverage``.
   """
 
   @abc.abstractmethod
   def compute_stats(self, lhs: jax.Array) -> dict[str, Any]:
     """Computes statistics from the input array."""
 
-  def _collect_stats(self, lhs: jax.Array, weight_name: str) -> None:
+  def _collect_stats(
+      self,
+      lhs: jax.Array,
+      weight_name: str,
+      *,
+      module_path: tuple[str, ...],
+      op_name: str,
+      op_id: str | None,
+      lhs_id: int,
+  ) -> None:
+    """Accumulates one batch of single-pass calibration statistics."""
+    del module_path, op_name, op_id, lhs_id
     stats = self.compute_stats(lhs)
     aggregator = averaging.SimpleMovingAverage()
     var_name = weight_name + self.get_stats_suffix()
@@ -146,7 +203,7 @@ def _collect_stats(self, lhs: jax.Array, weight_name: str) -> None:
 
 def normalize_weight(
     x: jax.Array, contraction_axis: int
-) -> tuple[jax.Array, Callable[..., qarray.MaybeQArray]]:
+) -> tuple[jax.Array, Callable[..., Any]]:
   """Normalizes a weight tensor into (rows, columns) format.
 
   Reshapes a weight tensor of arbitrary rank into a 2D matrix where the
@@ -175,7 +232,7 @@ def restore_shape(x):
 
 @dataclasses.dataclass(frozen=True)
 class CalibratedQuantContext:
-  """Context containing a weight, calibration stats, and quantization metadata.
+  """A weight prepared for algorithm-specific quantization.
 
   Attributes:
     weight: Normalized weight in (rows, columns) format.
@@ -192,10 +249,58 @@ class CalibratedQuantContext:
   calibration_stats: dict[str, jax.Array]
   abs_w: ptq.WithAux
   contracting_axis: int
-  restore_shape: Callable[..., qarray.MaybeQArray]
+  restore_shape: Callable[..., Any]
   path: tuple[str, ...]
 
 
+def extract_calibrated_quant_context(
+    path: tuple[str, ...],
+    weight: jax.Array,
+    abs_w: ptq.WithAux,
+    stats: Any,
+) -> CalibratedQuantContext | None:
+  """Extracts the calibration context for a single weight.
+
+  Args:
+    path: The dictionary path for this weight.
+    weight: Floating-point weight to quantize.
+    abs_w: The WithAux wrapper from the abstract quantized tree.
+    stats: The calibration statistics for this weight.
+
+  Returns:
+    The CalibratedQuantContext, or None if the weight cannot be quantized
+    (e.g., if there is not exactly one contracting axis).
+  """
+  # Get the contracting axis by assuming that all non-contracting axes
+  # are in channelwise_axes.
+  contracting_axis = set(range(weight.ndim)) - set(abs_w.how.channelwise_axes)
+  if len(contracting_axis) != 1:
+    # Fallback to PTQ if we can't identify a single contracting axis.
+    return None
+  contracting_axis = contracting_axis.pop()
+
+  # Normalize the weight to (ra, ca) format.
+  w_norm, restore_shape = normalize_weight(weight, contracting_axis)
+  how = dataclasses.replace(abs_w.how, channelwise_axes=[0])
+  if contracting_axis in how.tiled_axes:
+    how = dataclasses.replace(
+        how, tiled_axes={1: how.tiled_axes[contracting_axis]}
+    )
+
+  # Get calibration stats.
+  calibration_stats = averaging.SimpleMovingAverage().get_calibration(stats)
+
+  return CalibratedQuantContext(
+      weight=w_norm,
+      how=how,
+      calibration_stats=calibration_stats,
+      abs_w=abs_w,
+      contracting_axis=contracting_axis,
+      restore_shape=restore_shape,
+      path=path,
+  )
+
+
 def quantize_params_with_calibration(
     params: Any,
     abstract_quantized_params: Any,
@@ -210,7 +315,7 @@ def quantize_params_with_calibration(
   This function handles the common boilerplate for all calibration-based
   quantization algorithms (GPTQ, QEP, AWQ): parameter iteration, stats
   lookup, weight normalization, and PTQ fallback. The algorithm-specific
-  logic is provided via `quantize_fn`.
+  logic is provided via ``quantize_fn``.
 
   Args:
     params: The floating-point param tree to quantize.
@@ -237,36 +342,11 @@ def quantize_params_with_calibration(
       not_quantized_params[path] = w
       continue
 
-    # Get the contracting axis by assuming that all non-contracting axes
-    # are in channelwise_axes.
-    contracting_axis = set(range(w.ndim)) - set(abs_w.how.channelwise_axes)
-    if len(contracting_axis) != 1:
-      # Fallback to PTQ if we can't identify a single contracting axis.
+    ctx = extract_calibrated_quant_context(path, w, abs_w, stats)
+    if ctx is None:
       not_quantized_params[path] = w
       continue
-    contracting_axis = list(contracting_axis)[0]
-
-    # Normalize the weight to (ra, ca) format.
-    w_norm, restore_shape = normalize_weight(w, contracting_axis)
-    how = dataclasses.replace(abs_w.how, channelwise_axes=[0])
-    if contracting_axis in how.tiled_axes:
-      how = dataclasses.replace(
-          how, tiled_axes={1: how.tiled_axes[contracting_axis]}
-      )
 
-    # Get calibration stats.
-    calibration_stats = averaging.SimpleMovingAverage().get_calibration(stats)
-
-    # Delegate to algorithm-specific quantization.
-    ctx = CalibratedQuantContext(
-        weight=w_norm,
-        how=how,
-        calibration_stats=calibration_stats,
-        abs_w=abs_w,
-        contracting_axis=contracting_axis,
-        restore_shape=restore_shape,
-        path=path,
-    )
     quantized_params[path] = quantize_fn(ctx)
 
   # PTQ fallback for non-quantized params.
@@ -279,4 +359,7 @@ def quantize_params_with_calibration(
   ptq_quantized_params = flax.traverse_util.flatten_dict(ptq_quantized_params)
   quantized_params.update(ptq_quantized_params)
 
+  if isinstance(abstract_quantized_params, nnx.Module):
+    quantized_params = nnx.to_pure_dict(nnx.state(quantized_params))
+    return flax.traverse_util.unflatten_dict(quantized_params)
   return flax.traverse_util.unflatten_dict(quantized_params)
```
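To make the `_collect_stats` refactor concrete, here is a hypothetical subclass written against a minimal stub of the new keyword-only signature. `MeanStatsProvider` and its stats keying are illustrative only, not part of Qwix; the stub omits all of the real `CalibrationProvider` machinery.

```python
import abc

import numpy as np


class CalibrationProvider(abc.ABC):
    """Minimal stub mirroring the abstract signature introduced in this commit."""

    @abc.abstractmethod
    def _collect_stats(
        self, lhs, weight_name, *, module_path, op_name, op_id, lhs_id
    ):
        ...


class MeanStatsProvider(CalibrationProvider):
    """Toy provider recording the per-row mean of the reshaped (ca, rest) lhs.

    The extra keyword arguments let multi-pass callers such as QEP key
    statistics per call site (module_path/op_name/op_id) and recognize a
    repeated activation across passes via lhs_id.
    """

    def __init__(self):
        self.stats = {}

    def _collect_stats(
        self, lhs, weight_name, *, module_path, op_name, op_id, lhs_id
    ):
        key = ('/'.join(module_path), op_name, op_id, weight_name)
        self.stats[key] = lhs.mean(axis=1)


provider = MeanStatsProvider()
provider._collect_stats(
    np.ones((3, 8)),          # lhs already reshaped to (ca, rest)
    'kernel',
    module_path=('Dense_0',),
    op_name='dot_general',
    op_id='op0',
    lhs_id=id(object()),
)
```

Because the base `dot_general` performs all matching, validation, and reshaping before delegating, a provider like this only has to decide what statistic to store and under which key.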

0 commit comments
