1818from typing import Any , Callable
1919
2020import flax
21+ from flax import nnx
2122import jax
2223from jax import numpy as jnp
2324from qwix ._src import averaging
2425from qwix ._src import flax_util
2526from qwix ._src import qconfig
26- from qwix ._src .core import qarray
2727from qwix ._src .providers import ptq
2828
2929
@@ -32,7 +32,7 @@ class CalibrationProvider(qconfig.QuantizationProvider, metaclass=abc.ABCMeta):
3232
3333 This provider handles the common boilerplate for all calibration providers:
3434 rule type checking, dimension validation, weight name lookup, and LHS
35- reshaping. Subclasses implement `_collect_stats` to define what happens
35+ reshaping. Subclasses implement ``_collect_stats`` to define what happens
3636 with the validated, reshaped activations.
3737 """
3838
@@ -45,7 +45,16 @@ def get_stats_suffix(self) -> str:
4545 """Returns the suffix for the stats variable name (e.g., '_gptq')."""
4646
4747 @abc .abstractmethod
48- def _collect_stats (self , lhs : jax .Array , weight_name : str ) -> None :
48+ def _collect_stats (
49+ self ,
50+ lhs : jax .Array ,
51+ weight_name : str ,
52+ * ,
53+ module_path : tuple [str , ...],
54+ op_name : str ,
55+ op_id : str | None ,
56+ lhs_id : int ,
57+ ) -> None :
4958 """Collects statistics from the reshaped input activations.
5059
5160 Called after all validation passes. The LHS has already been reshaped
@@ -54,6 +63,10 @@ def _collect_stats(self, lhs: jax.Array, weight_name: str) -> None:
5463 Args:
5564 lhs: Input activations reshaped to (ca, rest) format.
5665 weight_name: The name of the weight parameter for this operation.
66+ module_path: The current module path for this operation.
67+ op_name: The intercepted operation name.
68+ op_id: The operation identifier assigned by the quantization tracer.
69+ lhs_id: Python object id of the original lhs before reshaping.
5770 """
5871
5972 def dot_general (
@@ -63,11 +76,35 @@ def dot_general(
6376 dimension_numbers : jax .lax .DotDimensionNumbers ,
6477 * args ,
6578 rule : qconfig .QuantizationRule | None = None ,
79+ op_id : str | None = None ,
6680 ** kwargs ,
6781 ) -> jax .Array :
82+ """Intercepts supported weight-bearing ``dot_general`` ops for calibration.
83+
84+ Subclasses do not need to reimplement the common matching and validation
85+ logic. This method:
86+
87+ - resolves the active quantization rule and op id,
88+ - rejects unsupported dimension-number patterns,
89+ - identifies the weight parameter on the RHS,
90+ - reshapes the LHS to ``(contracting_dim, rest)``, and
91+ - delegates the actual stats handling to ``_collect_stats``.
92+
93+ Args:
94+ lhs: The left-hand side array.
95+ rhs: The right-hand side array.
96+ dimension_numbers: The dimension numbers for the dot_general operation.
97+ *args: Additional positional arguments to pass to dot_general.
98+ rule: The quantization rule to apply.
99+ op_id: The operation identifier assigned by the quantization tracer.
100+ **kwargs: Additional keyword arguments to pass to dot_general.
101+
102+ Returns:
103+ The result of the dot_general operation.
104+ """
68105 res = jax .lax .dot_general (lhs , rhs , dimension_numbers , * args , ** kwargs )
69- if rule is None :
70- rule , _ = self ._get_current_rule_and_op_id ('dot_general' )
106+ if rule is None or op_id is None :
107+ rule , op_id = self ._get_current_rule_and_op_id ('dot_general' )
71108
72109 rule_type = self .get_rule_type ()
73110 if not isinstance (rule , rule_type ):
@@ -84,16 +121,25 @@ def dot_general(
84121 # If we cannot identify the weight parameter, we skip calibration.
85122 return res
86123
124+ lhs_id = id (lhs )
87125 # Reorder lhs to (ca, rest) format.
88126 lhs = jnp .moveaxis (lhs , lhs_ca [0 ], 0 )
89127 lhs = lhs .reshape (lhs .shape [0 ], - 1 )
90128
91- self ._collect_stats (lhs , weight_name )
129+ self ._collect_stats (
130+ lhs ,
131+ weight_name ,
132+ module_path = tuple (map (str , flax_util .get_current_module_path ())),
133+ op_name = 'dot_general' ,
134+ op_id = op_id ,
135+ lhs_id = lhs_id ,
136+ )
92137
93138 return res
94139
95140 def einsum (self , einsum_str , * operands , ** kwargs ):
96- rule , _ = self ._get_current_rule_and_op_id ('einsum' )
141+ """Intercepts supported binary ``einsum`` ops via their lowered dot call."""
142+ rule , op_id = self ._get_current_rule_and_op_id ('einsum' )
97143 rule_type = self .get_rule_type ()
98144 if not isinstance (rule , rule_type ):
99145 return jnp .einsum (einsum_str , * operands , ** kwargs )
@@ -103,7 +149,7 @@ def einsum(self, einsum_str, *operands, **kwargs):
103149
104150 def stats_dot_general (lhs , rhs , dimension_numbers , * args , ** kwargs ):
105151 return self .dot_general (
106- lhs , rhs , dimension_numbers , * args , rule = rule , ** kwargs
152+ lhs , rhs , dimension_numbers , * args , rule = rule , op_id = op_id , ** kwargs
107153 )
108154
109155 with jax .disable_jit ():
@@ -124,16 +170,27 @@ def get_intercept_map(self) -> dict[str, Callable[..., Any]]:
124170class SinglePassCalibrationProvider (CalibrationProvider , metaclass = abc .ABCMeta ):
125171 """Calibration provider that collects single-pass statistics.
126172
127- This provider implements the simple stats template: `compute_stats`
128- produces a dict of arrays, which are accumulated into the `quant_stats`
129- collection using `SimpleMovingAverage`.
173+ This provider implements the simple stats template: ``compute_stats``
174+ produces a dict of arrays, which are accumulated into the ``quant_stats``
175+ collection using ``SimpleMovingAverage``.
130176 """
131177
132178 @abc .abstractmethod
133179 def compute_stats (self , lhs : jax .Array ) -> dict [str , Any ]:
134180 """Computes statistics from the input array."""
135181
136- def _collect_stats (self , lhs : jax .Array , weight_name : str ) -> None :
182+ def _collect_stats (
183+ self ,
184+ lhs : jax .Array ,
185+ weight_name : str ,
186+ * ,
187+ module_path : tuple [str , ...],
188+ op_name : str ,
189+ op_id : str | None ,
190+ lhs_id : int ,
191+ ) -> None :
192+ """Accumulates one batch of single-pass calibration statistics."""
193+ del module_path , op_name , op_id , lhs_id
137194 stats = self .compute_stats (lhs )
138195 aggregator = averaging .SimpleMovingAverage ()
139196 var_name = weight_name + self .get_stats_suffix ()
@@ -146,7 +203,7 @@ def _collect_stats(self, lhs: jax.Array, weight_name: str) -> None:
146203
147204def normalize_weight (
148205 x : jax .Array , contraction_axis : int
149- ) -> tuple [jax .Array , Callable [..., qarray . MaybeQArray ]]:
206+ ) -> tuple [jax .Array , Callable [..., Any ]]:
150207 """Normalizes a weight tensor into (rows, columns) format.
151208
152209 Reshapes a weight tensor of arbitrary rank into a 2D matrix where the
@@ -175,7 +232,7 @@ def restore_shape(x):
175232
176233@dataclasses .dataclass (frozen = True )
177234class CalibratedQuantContext :
178- """Context containing a weight, calibration stats, and quantization metadata .
235+ """A weight prepared for algorithm-specific quantization.
179236
180237 Attributes:
181238 weight: Normalized weight in (rows, columns) format.
@@ -192,10 +249,58 @@ class CalibratedQuantContext:
192249 calibration_stats : dict [str , jax .Array ]
193250 abs_w : ptq .WithAux
194251 contracting_axis : int
195- restore_shape : Callable [..., qarray . MaybeQArray ]
252+ restore_shape : Callable [..., Any ]
196253 path : tuple [str , ...]
197254
198255
def extract_calibrated_quant_context(
    path: tuple[str, ...],
    weight: jax.Array,
    abs_w: ptq.WithAux,
    stats: Any,
) -> CalibratedQuantContext | None:
  """Builds the quantization context for one weight, if it is eligible.

  The weight is eligible only when exactly one of its axes is absent from
  ``abs_w.how.channelwise_axes``; that axis is treated as the contracting
  axis. The weight is then normalized with ``normalize_weight`` and the
  quantization spec is rewritten to refer to the normalized layout.

  Args:
    path: The dictionary path for this weight.
    weight: Floating-point weight to quantize.
    abs_w: The WithAux wrapper from the abstract quantized tree.
    stats: The calibration statistics for this weight.

  Returns:
    A ``CalibratedQuantContext`` for the weight, or ``None`` when the number
    of axes missing from ``channelwise_axes`` is not exactly one (the caller
    is expected to fall back to PTQ in that case).
  """
  # Every axis that is not channelwise is a candidate contracting axis; this
  # relies on all non-contracting axes being listed in channelwise_axes.
  candidate_axes = set(range(weight.ndim)).difference(
      abs_w.how.channelwise_axes
  )
  if len(candidate_axes) != 1:
    # Zero or several candidates: no single contracting axis can be chosen.
    return None
  (contracting_axis,) = candidate_axes

  # Bring the weight into the 2D (ra, ca) layout.
  normalized_weight, restore_shape = normalize_weight(weight, contracting_axis)

  # After normalization the channelwise axis is axis 0.
  normalized_how = dataclasses.replace(abs_w.how, channelwise_axes=[0])
  if contracting_axis in normalized_how.tiled_axes:
    # Carry the contracting axis' tiling over to axis 1 of the 2D layout.
    tile = normalized_how.tiled_axes[contracting_axis]
    normalized_how = dataclasses.replace(normalized_how, tiled_axes={1: tile})

  # Turn the accumulated statistics into their calibration values.
  aggregator = averaging.SimpleMovingAverage()
  calibration_stats = aggregator.get_calibration(stats)

  return CalibratedQuantContext(
      weight=normalized_weight,
      how=normalized_how,
      calibration_stats=calibration_stats,
      abs_w=abs_w,
      contracting_axis=contracting_axis,
      restore_shape=restore_shape,
      path=path,
  )
302+
303+
199304def quantize_params_with_calibration (
200305 params : Any ,
201306 abstract_quantized_params : Any ,
@@ -210,7 +315,7 @@ def quantize_params_with_calibration(
210315 This function handles the common boilerplate for all calibration-based
211316 quantization algorithms (GPTQ, QEP, AWQ): parameter iteration, stats
212317 lookup, weight normalization, and PTQ fallback. The algorithm-specific
213- logic is provided via `quantize_fn`.
318+ logic is provided via ``quantize_fn``.
214319
215320 Args:
216321 params: The floating-point param tree to quantize.
@@ -237,36 +342,11 @@ def quantize_params_with_calibration(
237342 not_quantized_params [path ] = w
238343 continue
239344
240- # Get the contracting axis by assuming that all non-contracting axes
241- # are in channelwise_axes.
242- contracting_axis = set (range (w .ndim )) - set (abs_w .how .channelwise_axes )
243- if len (contracting_axis ) != 1 :
244- # Fallback to PTQ if we can't identify a single contracting axis.
345+ ctx = extract_calibrated_quant_context (path , w , abs_w , stats )
346+ if ctx is None :
245347 not_quantized_params [path ] = w
246348 continue
247- contracting_axis = list (contracting_axis )[0 ]
248-
249- # Normalize the weight to (ra, ca) format.
250- w_norm , restore_shape = normalize_weight (w , contracting_axis )
251- how = dataclasses .replace (abs_w .how , channelwise_axes = [0 ])
252- if contracting_axis in how .tiled_axes :
253- how = dataclasses .replace (
254- how , tiled_axes = {1 : how .tiled_axes [contracting_axis ]}
255- )
256349
257- # Get calibration stats.
258- calibration_stats = averaging .SimpleMovingAverage ().get_calibration (stats )
259-
260- # Delegate to algorithm-specific quantization.
261- ctx = CalibratedQuantContext (
262- weight = w_norm ,
263- how = how ,
264- calibration_stats = calibration_stats ,
265- abs_w = abs_w ,
266- contracting_axis = contracting_axis ,
267- restore_shape = restore_shape ,
268- path = path ,
269- )
270350 quantized_params [path ] = quantize_fn (ctx )
271351
272352 # PTQ fallback for non-quantized params.
@@ -279,4 +359,7 @@ def quantize_params_with_calibration(
279359 ptq_quantized_params = flax .traverse_util .flatten_dict (ptq_quantized_params )
280360 quantized_params .update (ptq_quantized_params )
281361
362+ if isinstance (abstract_quantized_params , nnx .Module ):
363+ quantized_params = nnx .to_pure_dict (nnx .state (quantized_params ))
364+ return flax .traverse_util .unflatten_dict (quantized_params )
282365 return flax .traverse_util .unflatten_dict (quantized_params )
0 commit comments