## Add Quantization Error Propagation (QEP) support to Qwix #217

Closed
copybara-service[bot] wants to merge 0 commits into main from test_872186505

@copybara-service copybara-service bot commented Feb 19, 2026

Add Quantization Error Propagation (QEP) support to Qwix

QEP extends standard GPTQ by compensating for cascading quantization noise
introduced by preceding layers during inference. While GPTQ minimizes
||W @ X - W_q @ X||^2 assuming perfect float inputs, QEP actively minimizes
||W @ X_float - W_q @ X_q||^2.

This is achieved by computing an input cross-correlation statistic (H_delta)
and applying a localized weight correction (W_corrected = W + alpha * (W @ H_delta @ inv(H)))
prior to standard GPTQ rounding.
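
The correction above can be sketched in a few lines of JAX. This is an illustrative reading, not the code in qep_core.py: it assumes activations are laid out as (in_features, num_samples), that H = X_q @ X_q.T, and that H_delta = (X_float - X_q) @ X_q.T, which is the definition under which the stated formula minimizes ||W @ X_float - W' @ X_q||^2 when alpha = 1. The function names, the damping scheme, and the random test data are all hypothetical.

```python
import jax
import jax.numpy as jnp


def sketch_qep_stats(x_float, x_quant):
    """Returns (H, H_delta) for inputs of shape (in_features, num_samples).

    Assumed definitions (not confirmed by the PR text):
      H       = X_q @ X_q.T
      H_delta = (X_float - X_q) @ X_q.T
    """
    h = x_quant @ x_quant.T
    h_delta = (x_float - x_quant) @ x_quant.T
    return h, h_delta


def sketch_weight_correct(w, h, h_delta, alpha=1.0, damping=1e-6):
    """Computes W + alpha * (W @ H_delta @ inv(H)) using a damped linear solve."""
    eye = jnp.eye(h.shape[0], dtype=h.dtype)
    # Hypothetical damping: add a small multiple of the mean diagonal of H.
    h_damped = h + damping * jnp.mean(jnp.diag(h)) * eye
    # H is symmetric, so solving H @ Z.T = (W @ H_delta).T yields
    # Z = W @ H_delta @ inv(H) without forming the inverse explicitly.
    z = jnp.linalg.solve(h_damped, (w @ h_delta).T).T
    return w + alpha * z


# Synthetic data: "quantized" inputs are float inputs plus small noise.
key_w, key_x, key_n = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(key_w, (4, 8))          # (out_features, in_features)
x_float = jax.random.normal(key_x, (8, 64))   # (in_features, num_samples)
x_quant = x_float + 0.1 * jax.random.normal(key_n, (8, 64))

h, h_delta = sketch_qep_stats(x_float, x_quant)
w_corrected = sketch_weight_correct(w, h, h_delta)

# Output mismatch against the float reference, before and after correction.
err_before = jnp.sum((w @ x_float - w @ x_quant) ** 2)
err_after = jnp.sum((w @ x_float - w_corrected @ x_quant) ** 2)
```

With alpha = 1 and negligible damping, the corrected weights are the least-squares fit of W @ X_float onto X_q, so err_after should not exceed err_before. Setting alpha = 0 leaves W unchanged, which recovers the plain GPTQ starting point.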

API & Usage

The primary entry point is qep.quantize(...). Because QEP must measure the
accumulated error from previously quantized layers, it orchestrates a multi-pass
calibration loop stage-by-stage rather than relying on a single forward pass.

import jax.numpy as jnp

result = qep.quantize(
    model=model,
    # calibration_data must be reiterable because QEP sweeps it multiple times.
    calibration_data=dataset_iterator_factory,
    rules=[qep.QepRule(module_path='Dense_.*', weight_qtype=jnp.int8)],
    variables=variables,
)

# The returned QepResult contains everything needed for inference:
inference_output = result.model.apply(
    {'params': result.params, 'quant_stats': result.quant_stats},
    sample_input,
)
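
The reiterability requirement is a common stumbling block, so here is a generic Python illustration. It does not use the Qwix API: `ReiterableBatches` and `count_batches` are hypothetical names, and whether qep.quantize expects an iterable or a zero-argument factory is not stated above. The sketch only shows why a one-shot generator fails on a second sweep while an object with its own `__iter__` does not.

```python
import numpy as np


class ReiterableBatches:
    """Hypothetical wrapper: each call to iter() starts a fresh pass over the data."""

    def __init__(self, data, batch_size):
        self._data = data
        self._batch_size = batch_size

    def __iter__(self):
        for start in range(0, len(self._data), self._batch_size):
            yield self._data[start:start + self._batch_size]


def count_batches(batches):
    """Consumes one full sweep and returns how many batches it produced."""
    return sum(1 for _ in batches)


data = np.arange(40, dtype=np.float32).reshape(10, 4)

# A generator is exhausted after the first sweep.
one_shot = (data[i:i + 4] for i in range(0, len(data), 4))
one_shot_first = count_batches(one_shot)
one_shot_second = count_batches(one_shot)

# The wrapper yields the same batches on every sweep.
reiterable = ReiterableBatches(data, batch_size=4)
reiterable_first = count_batches(reiterable)
reiterable_second = count_batches(reiterable)
```

A multi-pass calibration loop silently collects no statistics on later passes if it is handed the one-shot generator, which is why the comment in the usage example matters.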

For offline or distributed pipelines where statistics are pre-computed remotely,
qep.quantize_params() can be called directly to apply the QEP correction
and GPTQ rounding to float weights without re-running the model graph.

Key modifications

  • qep_core.py: Pure-JAX algorithms for QEP statistics (compute_qep_stats) and the core weight-shifting logic (weight_correct).
  • qep.py: The stagewise orchestrator (qep.quantize). It dynamically discovers interconnected topological stages, runs a two-pass (float vs. quantized) calibration loop per batch, and updates weights progressively through the network.
  • calibration.py: Refactors the core CalibrationProvider mechanics to decouple the single-pass logic, so that activations can be intercepted across multiple passes as QEP requires.
  • QepRule: New configuration struct that extends GptqRule with two tuning hyperparameters (correction_factor, damping_factor).
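
The stagewise two-pass idea can be illustrated on a toy chain of linear layers. This is a sketch under stated assumptions, not the orchestration in qep.py: each layer is treated as its own stage, plain round-to-nearest int8 fake quantization stands in for GPTQ rounding, the statistics use the assumed definitions H = X_q @ X_q.T and H_delta = (X_float - X_q) @ X_q.T, and every function name here is hypothetical.

```python
import jax
import jax.numpy as jnp


def fake_quant_int8(w):
    """Symmetric per-tensor round-to-nearest; a simple stand-in for GPTQ rounding."""
    scale = jnp.max(jnp.abs(w)) / 127.0
    return jnp.clip(jnp.round(w / scale), -127, 127) * scale


def correct(w, x_float, x_quant, alpha=1.0, damping=1e-6):
    """W + alpha * (W @ H_delta @ inv(H)) under the assumed statistics."""
    h = x_quant @ x_quant.T
    h_delta = (x_float - x_quant) @ x_quant.T
    h = h + damping * jnp.mean(jnp.diag(h)) * jnp.eye(h.shape[0], dtype=h.dtype)
    return w + alpha * jnp.linalg.solve(h, (w @ h_delta).T).T


def stagewise_quantize(weights, batches, alpha=1.0):
    """Quantizes layers in order, re-sweeping the calibration data per stage."""
    quantized = []
    for stage, w in enumerate(weights):
        xs_float, xs_quant = [], []
        for x in batches:  # one full sweep of the calibration data per stage
            a_float, a_quant = x, x
            # Two passes per batch: float weights vs. already-quantized weights.
            for w_f, w_q in zip(weights[:stage], quantized):
                a_float = jax.nn.relu(w_f @ a_float)
                a_quant = jax.nn.relu(w_q @ a_quant)
            xs_float.append(a_float)
            xs_quant.append(a_quant)
        x_float = jnp.concatenate(xs_float, axis=1)
        x_quant = jnp.concatenate(xs_quant, axis=1)
        # Correct against accumulated upstream error, then round.
        quantized.append(fake_quant_int8(correct(w, x_float, x_quant, alpha)))
    return quantized


keys = jax.random.split(jax.random.PRNGKey(0), 6)
weights = [jax.random.normal(keys[0], (6, 8)), jax.random.normal(keys[1], (4, 6))]
batches = [jax.random.normal(k, (8, 16)) for k in keys[2:]]  # 4 batches

quantized_weights = stagewise_quantize(weights, batches)
```

The first stage sees identical float and quantized inputs, so H_delta is zero and its result matches uncorrected rounding; only later stages receive a nonzero correction. Note that `batches` is iterated once per stage, which is the reason calibration data has to be reiterable.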

@copybara-service copybara-service bot force-pushed the test_872186505 branch 3 times, most recently from bf19872 to ed57104 Compare February 19, 2026 06:18
@copybara-service copybara-service bot force-pushed the test_872186505 branch 2 times, most recently from f017668 to 4c65d66 Compare March 17, 2026 15:16
@copybara-service copybara-service bot changed the title from "Adds Quantization Error Propagation (QEP) Algorithm" to "Add Quantization Error Propagation (QEP) support to Qwix" Mar 17, 2026
@copybara-service copybara-service bot force-pushed the test_872186505 branch 2 times, most recently from 1258d5a to e6d1c85 Compare March 24, 2026 00:40
@copybara-service copybara-service bot closed this Mar 24, 2026
@copybara-service copybara-service bot deleted the test_872186505 branch March 24, 2026 00:47