diff --git a/.gitignore b/.gitignore index abd6ff3..2ab4133 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,7 @@ dist/ *.gz *-ubyte *.pth +*.pt *.onnx *.npz onnx/* diff --git a/DeepQuant/CustomForwards/Activations.py b/DeepQuant/CustomForwards/Activations.py index 2a30848..d114513 100644 --- a/DeepQuant/CustomForwards/Activations.py +++ b/DeepQuant/CustomForwards/Activations.py @@ -4,63 +4,28 @@ # # Federico Brancasi - import torch.nn as nn -from torch import Tensor from brevitas.nn.quant_layer import QuantNonLinearActLayer +from torch import Tensor -class InnerForwardImplWrapperActivation(nn.Module): - """ - A small wrapper around the activation function of a Brevitas QuantActivation layer. - - This wrapper exposes the original activation function as a standalone submodule - so that FX tracing can display it as a separate node. - """ +class WrapperActivation(nn.Module): + """Expose inner activation so FX sees it as a leaf.""" def __init__(self, actImpl: nn.Module) -> None: - """ - Args: - act_impl: The original activation function module (e.g. an instance of nn.ReLU). - """ super().__init__() self.actImpl = actImpl def forward(self, quantInput: Tensor) -> Tensor: - """ - Applies the wrapped activation function. - - Args: - quant_input: Input tensor after input quantization. - - Returns: - Output tensor after applying the activation. - """ return self.actImpl(quantInput) -def quantActivationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor: - """ - Unrolled forward pass for a Brevitas QuantActivation layer. - - Steps: - 1) Apply self.input_quant to the input. - 2) Apply the activation function via the wrapped activation implementation. - 3) Apply self.act_quant to the activation output. - - Args: - self: The QuantNonLinearActLayer instance. - inp: The input tensor. - - Returns: - Output tensor after applying activation and output quantization. 
- """ +def activationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor: + """Unroll input→act→output quant steps.""" quantInput = self.input_quant(inp) if self.input_quant is not None else inp - # Use the wrapped activation if available; otherwise pass through. if hasattr(self, "wrappedActImpl"): output = self.wrappedActImpl(quantInput) else: output = quantInput - import IPython; IPython.embed() quantOutput = self.act_quant(output) if self.act_quant is not None else output return quantOutput diff --git a/DeepQuant/CustomForwards/Linear.py b/DeepQuant/CustomForwards/Linear.py deleted file mode 100644 index 9043677..0000000 --- a/DeepQuant/CustomForwards/Linear.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - - -import torch.nn as nn -from torch import Tensor -from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer - - -class InnerForwardImplWrapperLinear(nn.Module): - """ - A small wrapper around the 'innerForwardImpl' of a Brevitas QuantLinear - (QuantWeightBiasInputOutputLayer). - - We want to expose the logic within 'innerForwardImpl' as a standalone - submodule, so that FX tracing can see it as a leaf. - """ - - def __init__(self, innerForwardImpl: nn.Module) -> None: - """ - Args: - innerForwardImpl: The original function that processes - (quant_input, quant_weight, quant_bias). - """ - super().__init__() - self.innerForwardImpl = innerForwardImpl - - def forward( - self, quantInput: Tensor, quantWeight: Tensor, quantBias: Tensor - ) -> Tensor: - """ - Applies the wrapped innerForwardImpl. - - Args: - quant_input: Input after input_quant. - quant_weight: Weight after weight_quant. - quant_bias: Bias after bias_quant (or None). - - Returns: - A torch.Tensor with the linear operation applied. 
- """ - return self.innerForwardImpl(quantInput, quantWeight, quantBias) - - -def quantWBIOLForward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor: - """ - Unrolled forward pass for a Brevitas QuantLinear: - - Steps: - 1) self.input_quant - 2) self.weight_quant - 3) self.bias_quant (if bias is present) - 4) innerForwardImpl (wrapped) - 5) self.output_quant - - Args: - self: The QuantWeightBiasInputOutputLayer instance. - inp: The input Tensor to be processed. - - Returns: - Output Tensor after the unrolled quantized linear steps. - """ - quantInput = self.input_quant(inp) - quantWeight = self.weight_quant(self.weight) - - quantBias = None - if self.bias is not None: - quantBias = self.bias_quant(self.bias, quantInput, quantWeight) - - output = self.wrappedInnerForwardImpl(quantInput, quantWeight, quantBias) - quantOutput = self.output_quant(output) - return quantOutput diff --git a/DeepQuant/CustomForwards/MultiHeadAttention.py b/DeepQuant/CustomForwards/MultiHeadAttention.py index 76fe3ae..5e31130 100644 --- a/DeepQuant/CustomForwards/MultiHeadAttention.py +++ b/DeepQuant/CustomForwards/MultiHeadAttention.py @@ -4,43 +4,22 @@ # # Federico Brancasi - import math + import torch import torch.nn.functional as F -from torch import Tensor from brevitas.nn.quant_mha import QuantMultiheadAttention +from torch import Tensor -def unrolledQuantMhaForward( +def mhaForward( self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor ) -> Tensor: - """ - Export-friendly forward that explicitly unrolls the multi-head logic. - - Steps: - 1) Q, K, V projections - 2) Reshapes & permutes for multi-head - 3) Scales queries - 4) Applies softmax and intermediate quantizations - 5) Out projection - - Args: - self: The QuantMultiheadAttention instance. - query: The query tensor of shape [sequence_len, batch_size, embed_dim]. - key: The key tensor, same shape as query. - value: The value tensor, same shape as query. 
- - Returns: - A torch.Tensor of shape [sequence_len, batch_size, embed_dim] - after the unrolled MHA steps. - """ - # 1) Q, K, V projections + """Explicit, export-friendly MHA forward.""" qOut = self.q_proj(query) kOut = self.k_proj(key) vOut = self.v_proj(value) - # 2) Multi-head reshape seqLen, batchSize, embedDim = qOut.shape headDim = embedDim // self.num_heads @@ -60,11 +39,9 @@ def unrolledQuantMhaForward( .reshape(batchSize * self.num_heads, seqLen, headDim) ) - # 3) Scale queries, then quantize qScaled = qOut / math.sqrt(headDim) qScaled = self.q_scaled_quant(qScaled) - # 4) Transpose + quantize K, compute attention weights k_t = kOut.transpose(-2, -1) k_t = self.k_transposed_quant(k_t) @@ -73,7 +50,6 @@ def unrolledQuantMhaForward( attnWeights = F.softmax(attnWeights, dim=-1) attnWeights = self.attn_output_weights_quant(attnWeights) - # 5) Quantize V, multiply, reshape back, and final out projection vOut = self.v_quant(vOut) attnOutput = torch.bmm(attnWeights, vOut) diff --git a/DeepQuant/CustomForwards/WBIOL.py b/DeepQuant/CustomForwards/WBIOL.py new file mode 100644 index 0000000..81e9e65 --- /dev/null +++ b/DeepQuant/CustomForwards/WBIOL.py @@ -0,0 +1,36 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +import torch.nn as nn +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer +from torch import Tensor + + +class WrapperWBIOL(nn.Module): + """Expose `inner_forward_impl` as a standalone submodule.""" + + def __init__(self, innerForwardImpl: nn.Module) -> None: + super().__init__() + self.innerForwardImpl = innerForwardImpl + + def forward( + self, quantInput: Tensor, quantWeight: Tensor, quantBias: Tensor + ) -> Tensor: + return self.innerForwardImpl(quantInput, quantWeight, quantBias) + + +def WBIOLForward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor: + """Quant-in → quant-weight/bias → matmul → quant-out.""" + quantInput = self.input_quant(inp) + quantWeight = self.weight_quant(self.weight) + + quantBias = None + if self.bias is not None: + quantBias = self.bias_quant(self.bias, quantInput, quantWeight) + + output = self.wrappedInnerForwardImpl(quantInput, quantWeight, quantBias) + quantOutput = self.output_quant(output) + return quantOutput diff --git a/DeepQuant/CustomTracer.py b/DeepQuant/CustomTracer.py deleted file mode 100644 index fab5dbe..0000000 --- a/DeepQuant/CustomTracer.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -Custom Brevitas tracer implementation for handling module transformation and tracing. -""" - -import torch.nn as nn -from brevitas.fx.brevitas_tracer import ( - _symbolic_trace, - _is_brevitas_leaf_module, - Tracer, -) -from torch.fx.graph_module import GraphModule -from typing import List, Type, Optional - - -class CustomBrevitasTracer(Tracer): - """ - A custom tracer that allows explicit control over leaf and non-leaf module designation. 
- - This tracer extends the Brevitas tracer to provide fine-grained control over which modules - should be treated as leaf modules (traced as a single unit) vs non-leaf modules - (traced into their constituent operations). - """ - - def __init__( - self, - leafClasses: Optional[List[Type[nn.Module]]] = None, - nonLeafClasses: Optional[List[Type[nn.Module]]] = None, - debug: bool = False, - ) -> None: - """ - Initialize the custom tracer with optional leaf and non-leaf module lists. - - Args: - leaf_classes: List of module classes to be treated as leaf modules. - non_leaf_classes: List of module classes to be treated as non-leaf modules. - debug: Whether to print debug information during tracing. - """ - super().__init__() - self.leafClasses = leafClasses if leafClasses is not None else [] - self.nonLeafClasses = nonLeafClasses if nonLeafClasses is not None else [] - self.debug = debug - - def registerLeafModule(self, moduleCls: Type[nn.Module]) -> None: - """ - Add a module class to the list of leaf modules. - - Args: - module_cls: The module class to register as a leaf module. - """ - if moduleCls not in self.leafClasses: - self.leafClasses.append(moduleCls) - - def registerNonLeafModule(self, moduleCls: Type[nn.Module]) -> None: - """ - Add a module class to the list of non-leaf modules. - - Args: - module_cls: The module class to register as a non-leaf module. - """ - if moduleCls not in self.nonLeafClasses: - self.nonLeafClasses.append(moduleCls) - - def is_leaf_module(self, m: nn.Module, moduleQualifiedName: str) -> bool: - """ - Determine whether a module should be treated as a leaf module. - - The decision follows this priority: - 1. If module is in leaf_classes, treat as leaf - 2. If module is in non_leaf_classes, treat as non-leaf - 3. Otherwise, fall back to default Brevitas behavior - - Args: - m: The module to check. - module_qualified_name: The fully qualified name of the module. 
- - Returns: - bool: True if the module should be treated as a leaf module, False otherwise. - """ - # First check explicitly registered classes - if any(isinstance(m, lc) for lc in self.leafClasses): - return True - if any(isinstance(m, nlc) for nlc in self.nonLeafClasses): - return False - # Fall back to default Brevitas behavior - return _is_brevitas_leaf_module(m, moduleQualifiedName) - - -def customBrevitasTrace( - root: nn.Module, concreteArgs=None, tracer: Optional[CustomBrevitasTracer] = None -) -> GraphModule: - """ - Create an FX GraphModule using the CustomBrevitasTracer. - - Args: - root: The root module to trace. - concrete_args: Concrete arguments to use for tracing. - tracer: Optional pre-configured CustomBrevitasTracer instance. - - Returns: - GraphModule: The traced module. - """ - if tracer is None: - tracer = CustomBrevitasTracer() - return _symbolic_trace(tracer, root, concreteArgs) diff --git a/DeepQuant/Export.py b/DeepQuant/Export.py new file mode 100644 index 0000000..1b55f42 --- /dev/null +++ b/DeepQuant/Export.py @@ -0,0 +1,52 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from pathlib import Path +from typing import Optional, Union + +import torch +import torch.nn as nn + +from DeepQuant.Pipeline.DequantUnify import mergeDequants +from DeepQuant.Pipeline.Injection import injectCustomForwards +from DeepQuant.Pipeline.OnnxExport import exportToOnnx +from DeepQuant.Pipeline.OriginalTracing import traceOriginalModel +from DeepQuant.Pipeline.QuantSplit import splitQuantNodes + + +def brevitasToTrueQuant( + model: nn.Module, + exampleInput: torch.Tensor, + exportPath: Optional[Union[str, Path]] = Path.cwd() / "Tests" / "ONNX", + debug: bool = False, +) -> nn.Module: + """ + Export a Brevitas model to an FX GraphModule with unrolled quantization operations. 
+ + This function applies a series of transformations to make the quantization steps + explicit in the model's computation graph, enabling efficient integer-only execution. + """ + + # Pipeline Step 1: Trace the original model + tracedModel, originalOutput = traceOriginalModel(model, exampleInput, debug) + + # Pipeline Step 2: Inject custom forward implementations + transformedModel, transformedOutput = injectCustomForwards( + tracedModel, exampleInput, originalOutput, debug + ) + + # Pipeline Step 3: Split quantization nodes + splitModel, splitOutput = splitQuantNodes( + transformedModel, exampleInput, transformedOutput, debug + ) + + # Pipeline Step 4: Unify dequant nodes + unifiedModel, _ = mergeDequants(splitModel, exampleInput, splitOutput, debug) + + # Pipeline Step 5: Export to ONNX + onnxFile, _ = exportToOnnx(unifiedModel, exampleInput, exportPath, debug) + + return unifiedModel diff --git a/DeepQuant/ExportBrevitas.py b/DeepQuant/ExportBrevitas.py deleted file mode 100644 index 0ab87f0..0000000 --- a/DeepQuant/ExportBrevitas.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -import torch -import torch.nn as nn -from pathlib import Path - -from DeepQuant.Injects.Transformations import ( - LinearTransformation, # Transformation for quantized linear layers (QuantLinear, QuantConv2d) - ActivationTransformation, # Transformation for quantized activation functions (QuantReLU, etc.) 
- MHATransformation, # Transformation for quantized multi-head attention modules -) -from DeepQuant.Injects.Executor import ( - TransformationExecutor, -) # Orchestrates sequential transformations -from .CustomTracer import ( - CustomBrevitasTracer, - customBrevitasTrace, -) # Custom FX tracer for Brevitas modules -from DeepQuant.QuantManipulation.ParameterExtractor import ( - extract_brevitas_proxy_params, # Extracts quantization parameters from Brevitas proxies - print_quant_params, # Displays quantization parameters in a readable format -) -from DeepQuant.QuantManipulation.QuantNodesDivider import ( - split_quant_nodes, -) # Splits quantization nodes into Quant/Dequant pairs -from brevitas.export.inference import ( - quant_inference_mode, -) # Inference mode for quantized models -from brevitas.export import ( - export_onnx_qcdq, -) # Native Brevitas ONNX export functions -from DeepQuant.QuantManipulation.DequantModifier import ( - unifyLinearDequants, -) # Unifies dequant nodes in linear layers -from brevitas.fx import brevitas_symbolic_trace # Brevitas-specific symbolic tracing -from DeepQuant.Utils.GraphPrinter import ( - GraphModulePrinter, -) # Custom Graph Printer -from DeepQuant.Utils.FxInterpreter import NodeTracer - - -# ANSI color codes for improved debug output readability -BLUE = "\033[94m" -RED = "\033[31m" -ENDC = "\033[0m" - - -def exportBrevitas( - model: nn.Module, exampleInput: torch.Tensor, debug: bool = False -) -> nn.Module: - """ - Export a Brevitas model to an FX GraphModule with unrolled quantization operations. - - This function applies a series of transformations to make the quantization steps - explicit in the model's computation graph, then traces the transformed model using - a custom FX tracer. - - Args: - model: The Brevitas-based model to export. - example_input: A representative input tensor for shape tracing. - debug: If True, prints transformation progress information. 
- - Returns: - nn.Module: An FX GraphModule with explicit quantization operations. - """ - - EXPORT_FOLDER = Path().cwd() - if Path().cwd().name == "DeepQuant": - EXPORT_FOLDER = EXPORT_FOLDER / "Tests/ONNX" - EXPORT_FOLDER.mkdir(parents=True, exist_ok=True) - - printer = GraphModulePrinter() - - ############################################################################### - # 1. Original Network - ############################################################################### - - model = brevitas_symbolic_trace( - model - ) # Symbolically trace the original model using Brevitas - if debug: - print("\n\n=== 1. Original Network ===\n") - printer.print_tabular(model) - print() - - with ( - torch.no_grad(), - quant_inference_mode(model), - ): # Disable gradients and use quantized inference mode - outputModel = model( - exampleInput - ) # Compute original model output on example input for validation - - # export_onnx_qcdq( # Export original model to ONNX format with QCDQ (Quant-Cast-DeQuant) nodes - # model, # Model to export - # args=exampleInput, # Example input for tracing - # export_path=EXPORT_FOLDER / "1_model_qcdq_original.onnx", - # opset_version=13, - # ) - - ############################################################################### - # 2. 
Injection of New Modules - ############################################################################### - - # Create transformation sequence in appropriate order - transformations = [ - MHATransformation(), # Multi-head attention transformation (applied first) - LinearTransformation(), # Quantized linear layers transformation - ActivationTransformation(), # Quantized activation functions transformation - ] - - # Initialize custom tracer for Brevitas - tracer = CustomBrevitasTracer(debug=debug) - - # Create and execute transformation sequence using the executor - executor = TransformationExecutor(transformations, debug=debug, tracer=tracer) - transformedModel = executor.execute( - model, exampleInput - ) # Apply all transformations to the model - - # Generate FX graph using the same tracer for consistency - fxModel = customBrevitasTrace( - root=transformedModel, # Transformed model to trace - concreteArgs=(exampleInput,), - tracer=tracer, # Use same tracer to maintain consistency with transformations - ) - fxModel.recompile() # Recompile the FX module to update its forward method - with torch.no_grad(): - outputFxModel = fxModel(exampleInput) # Compute transformed model output - - if isinstance(outputModel, tuple): - outputModel = outputModel[0] - - if torch.allclose( - outputFxModel, outputModel, atol=1e-5 - ): # Check numerical equivalence within tolerance - if debug: - print(f"{BLUE} ✓ Injection of New Modules: output is consistent{ENDC}") - else: - raise RuntimeError( # Raise error if outputs differ significantly - f"{RED} ✗ Injection of New Modules changed the output significantly{ENDC}" - ) - - if debug: - print(f"{BLUE} ✓ All transformations completed successfully!{ENDC}") - if debug: - print("\n=== 2. 
Network after the Injection of New Modules ===\n") - printer.print_tabular(fxModel) - - # export_onnx_qcdq( # Export transformed model to ONNX - # fxModel, # Transformed model - # args=exampleInput, - # export_path=EXPORT_FOLDER / "2_model_qcdq_transformed.onnx", - # opset_version=13, - # ) - - ############################################################################### - # 3. Extraction of Parameters & Split of Quant Nodes - ############################################################################### - - # Extract quantization parameters from the network's proxies - proxyParams = extract_brevitas_proxy_params( - fxModel - ) # Get scale, zero_point, bit_width for each quant node - - if debug: - print_quant_params( - proxyParams - ) # Display extracted parameters in a readable format - - # Split quantization nodes into separate Quant and Dequant nodes - splitFxModel = split_quant_nodes( - fxModel, proxyParams, debug - ) # Transform quant nodes into quant-dequant pairs - splitFxModel.recompile() # Recompile to update forward method with new nodes - - with torch.no_grad(): - outputFxModelSplitQuant = splitFxModel( - exampleInput - ) # Compute output after node splitting - - # print("Output Original: ", output_model) - # print("Output Split: ", output_fx_model_split_quant) - - if torch.allclose( - outputModel, outputFxModelSplitQuant, atol=1e-5 - ): # Verify numerical consistency - if debug: - print(f"{BLUE} ✓ Split of Quant Nodes: output is consistent{ENDC}") - else: - raise RuntimeError( # Raise error if inconsistent - f"{RED} ✗ Split of Quant Nodes changed the output significantly{ENDC}" - ) - - if debug: - print("\n=== 3. 
Network after the Split of Quant Nodes ===\n") - printer.print_tabular(splitFxModel) - print() - - torch.onnx.export( - splitFxModel, - args=exampleInput, - f=EXPORT_FOLDER / "3_model_splitted_quant.onnx", - opset_version=13, - keep_initializers_as_inputs=True, - do_constant_folding=False, - ) - - # return split_fx_model - - ############################################################################### - # 4. Modification of Dequant Nodes (shift them down) - ############################################################################### - - # Perform the unification of linear dequant nodes (move dequantization after computation) - fxModelUnified = unifyLinearDequants(splitFxModel, debug=debug) - fxModelUnified.recompile() # Recompile to update forward method with new node arrangement - - # Compute output after dequant node unification - with torch.no_grad(): - outputFxModelDequantModified = fxModelUnified( - exampleInput - ) # Output after dequant modification - - print("Output Original: ", outputModel) - print("Output Dequant Modified: ", outputFxModelDequantModified) - - if debug: - print("\n=== 4. Network after the Modification of Dequant Nodes ===\n") - printer.print_tabular(fxModelUnified) - print() - - # # Verify numerical consistency after dequant modification - # if torch.allclose( - # output_model, output_fx_model_dequant_modified, atol=1e-5 - # ): # Verify numerical consistency - # if debug: - # print(f"{BLUE} ✓ Modification of Dequant Nodes: output is consistent{ENDC}") - # else: - # raise RuntimeError( # Raise error if inconsistent - # f"{RED} ✗ Modification of Dequant Nodes changed the output significantly{ENDC}" - # ) - - # if debug: - # print("\n=== 4. 
Network after the Modification of Dequant Nodes ===\n") - # printer.print_tabular(fx_model_unified) - # print() - - onnxFile: str = EXPORT_FOLDER / "4_model_dequant_moved.onnx" - torch.onnx.export( - fxModelUnified, - args=exampleInput, - # f=EXPORT_FOLDER / "4_model_dequant_moved.onnx", - f=onnxFile, - opset_version=13, - keep_initializers_as_inputs=True, - do_constant_folding=False, - input_names=["input"], - output_names=["output"], - ) - - # Verify numerical consistency after dequant modification - if torch.allclose( - outputModel, outputFxModelDequantModified, atol=1e-5 - ): # Verify numerical consistency - if debug: - print(f"{BLUE} ✓ Modification of Dequant Nodes: output is consistent{ENDC}") - else: - raise RuntimeError( # Raise error if inconsistent - f"{RED} ✗ Modification of Dequant Nodes changed the output significantly{ENDC}" - ) - - import numpy as np - import onnxruntime as ort - import onnx - - # Step 2: Load the model and run shape inference - # (All tensors in ONNX graph should have explicit shape information) - onnxModel = onnx.load(onnxFile) - inferredModel = onnx.shape_inference.infer_shapes(onnxModel) - - # Step 3: Save the model with inferred shapes - onnx.save(inferredModel, onnxFile) - - inputFile: str = EXPORT_FOLDER / "inputs.npz" - np.savez(inputFile, input=exampleInput.cpu()) - print("Input npz: ", exampleInput) - print(f"Input data saved to {inputFile} ✓") - - # onnxruntime to run the exported model - ortSession: ort.InferenceSession = ort.InferenceSession(onnxFile) - ortInputs: dict = {"input": exampleInput.cpu().numpy()} - ortOutput: np.ndarray = ortSession.run(None, ortInputs)[0] - - outputFile: str = EXPORT_FOLDER / "outputs.npz" - np.savez(outputFile, output=ortOutput) - print("Output npz: ", ortOutput) - print(f"Output data saved to {outputFile} ✓") - - return fxModelUnified # Return the final optimized FX GraphModule diff --git a/DeepQuant/Injects/Base.py b/DeepQuant/Injects/Base.py deleted file mode 100644 index 
e9d72b9..0000000 --- a/DeepQuant/Injects/Base.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -Base transformation infrastructure for the Brevitas export process. - -This module provides the foundational TransformationPass class that handles: -- Module type matching -- Forward method injection -- Output validation -- Recursive submodule transformation -""" - -import torch -import torch.nn as nn -from abc import ABC, abstractmethod -from typing import Any, Optional, Union, Tuple -from ..CustomTracer import CustomBrevitasTracer - - -class TransformationPass(ABC): - """ - Generic transformation pass for modifying Brevitas modules. - - A transformation pass targets specific module types and applies custom forward - implementations while ensuring output consistency. - """ - - def __init__( - self, - moduleCls: Union[type, Tuple[type, ...]], - validationTol: float = 1e-6, - ) -> None: - """ - Initialize a transformation pass. - - Args: - module_cls: Module class(es) this transformation targets. - injection_fn: Function that modifies the module's forward pass. - validation_tol: Tolerance for numerical comparison in validation. - """ - self.moduleCls = moduleCls - self.validationTol = validationTol - - def checkModuleType(self, module: nn.Module) -> bool: - """ - Check if a module is an instance of the target class(es). - - Args: - module: Module to check. - - Returns: - bool: True if module is an instance of self.module_cls. - """ - return isinstance(module, self.moduleCls) - - @abstractmethod - def injectForward( - self, module: nn.Module, tracer: Optional[CustomBrevitasTracer] = None - ) -> None: - """ - Inject the custom forward implementation into a module. - - Args: - module: Module whose forward method will be replaced. - tracer: Optional tracer for registering module classes. 
- """ - pass - - def validateTransformation( - self, outputBefore: Any, outputAfter: Any, atol: Optional[float] = None - ) -> bool: - """ - Validate transformation by comparing outputs. - - Args: - output_before: Model output before transformation. - output_after: Model output after transformation. - atol: Optional custom tolerance for comparison. - - Returns: - bool: True if outputs match within tolerance. - """ - if atol is None: - atol = self.validationTol - return torch.allclose(outputBefore, outputAfter, atol=atol) - - def transform( - self, model: nn.Module, tracer: Optional[CustomBrevitasTracer] = None - ) -> bool: - """ - Apply the transformation to all matching submodules. - - Args: - model: Model containing submodules to transform. - tracer: Optional tracer for registering transformed modules. - - Returns: - bool: True if any modules were transformed. - """ - transformDone = False - for _, submodule in model.named_modules(): - if self.checkModuleType(submodule): - self.injectForward(submodule, tracer) - transformDone = True - return transformDone diff --git a/DeepQuant/Injects/Executor.py b/DeepQuant/Injects/Executor.py deleted file mode 100644 index e41f3e9..0000000 --- a/DeepQuant/Injects/Executor.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -Executor module for handling transformation sequences in the Brevitas export process. -""" - -import torch -import torch.nn as nn -from typing import List, Optional -from .Base import TransformationPass -from ..CustomTracer import CustomBrevitasTracer - -# ANSI color codes -BLUE = "\033[94m" -RED = "\033[91m" -ENDC = "\033[0m" - - -class TransformationExecutor: - """ - Manages and executes a sequence of model transformations. 
- - The executor applies each transformation in sequence, validating that model outputs - remain consistent after each transformation step. - """ - - def __init__( - self, - transformations: List[TransformationPass], - debug: bool = False, - tracer: Optional[CustomBrevitasTracer] = None, - ) -> None: - """ - Initialize the transformation executor. - - Args: - transformations: List of transformation passes to apply. - debug: Whether to print debug information during execution. - tracer: Optional CustomBrevitasTracer instance for module registration. - """ - self.transformations = transformations - self.debug = debug - self.tracer = tracer - - def execute(self, model: nn.Module, exampleInput: torch.Tensor) -> nn.Module: - """ - Execute all transformations on the model in sequence. - - For each transformation: - 1. Apply the transformation - 2. Validate that model outputs remain consistent - 3. Update the reference output for the next transformation - - Args: - model: The PyTorch model to transform. - example_input: A representative input tensor for validation. - - Returns: - nn.Module: The transformed model. - - Raises: - RuntimeError: If any transformation results in output mismatch. 
- """ - model.eval() - with torch.no_grad(): - outputBefore = model(exampleInput) - if isinstance(outputBefore, tuple): - outputBefore = outputBefore[0] - - for transformation in self.transformations: - if transformation.transform(model, tracer=self.tracer): - outputAfter = model(exampleInput) - if isinstance(outputAfter, tuple): - outputAfter = outputAfter[0] - - if not transformation.validateTransformation( - outputBefore, outputAfter - ): - raise RuntimeError( - f"{RED} ✗ {transformation.__class__.__name__} failed - outputs mismatch{ENDC}" - ) - - if self.debug: - print( - f"{BLUE} ✓ {transformation.__class__.__name__} transformation successful\n{ENDC}" - f" leafClasses: {self.tracer.leafClasses}\n" - f" nonLeafClasses: {self.tracer.nonLeafClasses}\n" - ) - - outputBefore = outputAfter - - return model diff --git a/DeepQuant/Injects/Transformations.py b/DeepQuant/Injects/Transformations.py deleted file mode 100644 index 9a0e031..0000000 --- a/DeepQuant/Injects/Transformations.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -Transformation classes for different types of Brevitas modules. - -This module provides specific transformation passes for each type of quantized module: -- Linear layers (QuantLinear, QuantConv2d) -- Activation functions (QuantReLU, QuantSigmoid) -- Multi-head attention (QuantMultiheadAttention) - -Each transformation class implements the abstract injectForward method from TransformationPass -to define its specific module transformation logic. 
-""" - -import torch.nn as nn -from typing import Optional -from brevitas.nn.quant_layer import ( - QuantWeightBiasInputOutputLayer, - QuantNonLinearActLayer, -) -from brevitas.nn.quant_mha import QuantMultiheadAttention - -from .Base import TransformationPass -from ..CustomForwards.Linear import InnerForwardImplWrapperLinear, quantWBIOLForward -from ..CustomForwards.MultiHeadAttention import unrolledQuantMhaForward -from ..CustomTracer import CustomBrevitasTracer -from ..CustomForwards.Activations import ( - InnerForwardImplWrapperActivation, - quantActivationForward, -) - - -class LinearTransformation(TransformationPass): - """ - Transformation pass for quantized linear layers (QuantLinear, QuantConv2d). - - Replaces the default forward with an unrolled implementation that exposes - all quantization steps in the computation graph. - """ - - def __init__(self) -> None: - """Initialize the linear transformation pass.""" - super().__init__( - moduleCls=QuantWeightBiasInputOutputLayer, - validationTol=1e-6, - ) - - def injectForward( - self, module: nn.Module, tracer: Optional[CustomBrevitasTracer] = None - ) -> None: - """ - Inject custom forward implementation for linear layers. - - Args: - module: The linear module to transform. - tracer: Optional tracer for registering transformed modules. - """ - module.wrappedInnerForwardImpl = InnerForwardImplWrapperLinear( - module.inner_forward_impl - ) - module.forward = quantWBIOLForward.__get__(module) - - if tracer: - tracer.registerLeafModule(InnerForwardImplWrapperLinear) - tracer.registerNonLeafModule(QuantWeightBiasInputOutputLayer) - - -class ActivationTransformation(TransformationPass): - """ - Transformation pass for quantized activation functions. - - Replaces the default forward with an unrolled implementation that exposes - the input quantization and activation quantization steps. 
- """ - - def __init__(self) -> None: - """Initialize the activation transformation pass.""" - super().__init__( - moduleCls=QuantNonLinearActLayer, - validationTol=1e-6, - ) - - def injectForward( - self, module: nn.Module, tracer: Optional[CustomBrevitasTracer] = None - ) -> None: - """ - Inject custom forward implementation for activation layers. - - This method instantiates the original activation function (if provided) and - wraps it using InnerForwardImplWrapperActivation, then overrides the forward method. - - Args: - module: The activation module to transform. - tracer: Optional tracer for registering transformed modules. - """ - # If the activation implementation was provided (e.g. nn.ReLU for QuantReLU), - # instantiate it. Otherwise, default to an identity. - if hasattr(module, "act_impl") and module.act_impl is not None: - actInstance = module.act_impl() # e.g. nn.ReLU() - else: - actInstance = nn.Identity() - - module.wrappedActImpl = InnerForwardImplWrapperActivation(actInstance) - module.forward = quantActivationForward.__get__(module) - - if tracer: - tracer.registerLeafModule(InnerForwardImplWrapperActivation) - tracer.registerNonLeafModule(QuantNonLinearActLayer) - - -class MHATransformation(TransformationPass): - """ - Transformation pass for quantized multi-head attention layers. - - Replaces the default forward with an unrolled implementation that exposes - all attention operations and their associated quantization steps. - """ - - def __init__(self) -> None: - """Initialize the MHA transformation pass.""" - super().__init__( - moduleCls=QuantMultiheadAttention, - validationTol=1e-5, - ) - - def injectForward( - self, module: nn.Module, tracer: Optional[CustomBrevitasTracer] = None - ) -> None: - """ - Inject custom forward implementation for MHA layers. - - Args: - module: The MHA module to transform. - tracer: Optional tracer for registering transformed modules. 
- """ - module.forward = unrolledQuantMhaForward.__get__(module) - - if tracer: - tracer.registerNonLeafModule(QuantMultiheadAttention) diff --git a/DeepQuant/Pipeline/DequantUnify.py b/DeepQuant/Pipeline/DequantUnify.py new file mode 100644 index 0000000..a743c52 --- /dev/null +++ b/DeepQuant/Pipeline/DequantUnify.py @@ -0,0 +1,117 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Tuple + +import torch +import torch.nn as nn + +from DeepQuant.QuantManipulation.DequantModifier import unifyLinearDequants +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter +from DeepQuant.Utils.TensorRecorder import TensorRecorder + + +def mergeDequants( + model: nn.Module, + exampleInput: torch.Tensor, + referenceOutput: torch.Tensor, + debug: bool = False, +) -> Tuple[nn.Module, torch.Tensor]: + """ + Unify dequantization nodes to enable integer-only computation. + + This step modifies the dequantization nodes in the graph to allow + operations to run in the integer domain, applying dequantization + only after the computations are complete (Requantization). 
+ """ + printer = GraphModulePrinter() + tensorRecorder = TensorRecorder(debug=debug) + + if debug: + # FBRANCASI: Register hooks to record tensors from the split model (before dequant modification) + tensorRecorder.registerForwardHooks( + model, + nodeTypes=[ + "wrappedInnerForwardImpl", + "dequant", + "unified_dequant", + "linear", + "conv", + "quant", + "act", + "bias_quant", + "act_quant", + "relu", + ], + ) + + # FBRANCASI: Run the model to record tensors before modification + with torch.no_grad(): + _ = model(exampleInput) + + if debug: + # FBRANCASI: Save tensors as reference for comparison + tensorRecorder.setReferenceTensors() + + # FBRANCASI: Register mappings from wrappedInnerForwardImpl nodes to expected unified_dequant nodes + for node in model.graph.nodes: + if node.op == "call_module" and "wrappedInnerForwardImpl" in node.target: + baseName = node.target.replace(".wrappedInnerForwardImpl", "") + dequantName = f"{baseName}_unified_dequant" + dequantName = dequantName.replace(".", "_") + + tensorRecorder.recordNodeMapping(node.target, dequantName) + + unifiedModel = unifyLinearDequants(model, debug=debug) + unifiedModel.recompile() + + if debug: + print(cc.header("4. 
Network after Modification of Dequant Nodes")) + printer.printTabular(unifiedModel) + print() + + with torch.no_grad(): + output = unifiedModel(exampleInput) + + # FBRANCASI: Check output equivalence with a warning instead of error + if not torch.allclose(referenceOutput, output, atol=1e-5) and debug: + print( + cc.warning( + "Modification of Dequant Nodes may have changed the output slightly" + ) + ) + + if debug: + # FBRANCASI: Register hooks for the unified model and compare tensors + tensorRecorder.registerForwardHooks( + unifiedModel, + nodeTypes=[ + "wrappedInnerForwardImpl", + "dequant", + "unified_dequant", + "linear", + "conv", + "quant", + "act", + "bias_quant", + "act_quant", + "relu", + ], + ) + + # FBRANCASI: Run the model to record tensors after modification + with torch.no_grad(): + _ = unifiedModel(exampleInput) + + # FBRANCASI: Compare tensors before and after modification + print(cc.info("Tensor Comparison Before/After Dequant Unification:")) + results = tensorRecorder.compareTensors() + tensorRecorder.printComparisonResults(results) + + tensorRecorder.removeHooks() + + return unifiedModel, output diff --git a/DeepQuant/Pipeline/Injection.py b/DeepQuant/Pipeline/Injection.py new file mode 100644 index 0000000..477e909 --- /dev/null +++ b/DeepQuant/Pipeline/Injection.py @@ -0,0 +1,65 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Tuple + +import torch +import torch.nn as nn + +from DeepQuant.Transforms.Executor import TransformationExecutor +from DeepQuant.Transforms.Transformations import ( + ActivationTransformation, + LinearTransformation, + MHATransformation, +) +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.CustomTracer import QuantTracer, customBrevitasTrace +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter + + +def injectCustomForwards( + model: nn.Module, + exampleInput: torch.Tensor, + referenceOutput: torch.Tensor, + debug: bool = False, +) -> Tuple[nn.Module, torch.Tensor]: + """Inject custom forward implementations into the model.""" + printer = GraphModulePrinter() + + tracer = QuantTracer(debug=debug) + + transformations = [ + MHATransformation(), + LinearTransformation(), + ActivationTransformation(), + ] + + executor = TransformationExecutor(transformations, debug=debug, tracer=tracer) + transformedModel = executor.execute(model, exampleInput) + + fxModel = customBrevitasTrace( + root=transformedModel, + tracer=tracer, + ) + fxModel.recompile() + + with torch.no_grad(): + output = fxModel(exampleInput) + + if torch.allclose(referenceOutput, output, atol=1e-5): + if debug: + print(cc.success("Injection of New Modules: output is consistent")) + else: + raise RuntimeError( + cc.error("Injection of New Modules changed the output significantly") + ) + + if debug: + print(cc.header("2. Network after Injection of New Modules")) + printer.printTabular(fxModel) + print() + + return fxModel, output diff --git a/DeepQuant/Pipeline/OnnxExport.py b/DeepQuant/Pipeline/OnnxExport.py new file mode 100644 index 0000000..d3ac909 --- /dev/null +++ b/DeepQuant/Pipeline/OnnxExport.py @@ -0,0 +1,61 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from pathlib import Path +from typing import Tuple, Union + +import numpy as np +import onnx +import onnxruntime as ort +import torch +import torch.nn as nn + +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc + + +def exportToOnnx( + model: nn.Module, + exampleInput: torch.Tensor, + exportPath: Union[str, Path], + debug: bool = False, +) -> Tuple[Path, np.ndarray]: + """Export model to ONNX format and save input/output data.""" + exportPath = Path(exportPath) + exportPath.mkdir(parents=True, exist_ok=True) + + onnxFile = exportPath / "network.onnx" + inputFile = exportPath / "inputs.npz" + outputFile = exportPath / "outputs.npz" + + torch.onnx.export( + model, + args=exampleInput, + f=onnxFile, + opset_version=13, + keep_initializers_as_inputs=False, # FBRANCASI: Prevent warnings + do_constant_folding=True, + input_names=["input"], + output_names=["output"], + ) + + onnxModel = onnx.load(onnxFile) + inferredModel = onnx.shape_inference.infer_shapes(onnxModel) + onnx.save(inferredModel, onnxFile) + + np.savez(inputFile, input=exampleInput.cpu().numpy()) + if debug: + print() + print(cc.success(f"Input data saved to {inputFile}")) + + ortSession = ort.InferenceSession(onnxFile) + ortInputs = {"input": exampleInput.cpu().numpy()} + ortOutput = ortSession.run(None, ortInputs)[0] + + np.savez(outputFile, output=ortOutput) + if debug: + print(cc.success(f"Output data saved to {outputFile}\n")) + + return onnxFile, ortOutput diff --git a/DeepQuant/Pipeline/OriginalTracing.py b/DeepQuant/Pipeline/OriginalTracing.py new file mode 100644 index 0000000..d9e6bb9 --- /dev/null +++ b/DeepQuant/Pipeline/OriginalTracing.py @@ -0,0 +1,38 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Tuple + +import torch +import torch.nn as nn +from brevitas.export.inference import quant_inference_mode +from brevitas.fx import brevitas_symbolic_trace + +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter + + +def traceOriginalModel( + model: nn.Module, exampleInput: torch.Tensor, debug: bool = False +) -> Tuple[nn.Module, torch.Tensor]: + """Symbolically trace the original model using Brevitas.""" + printer = GraphModulePrinter() + + tracedModel = brevitas_symbolic_trace(model) + + if debug: + print(cc.header("1. Original Network")) + printer.printTabular(tracedModel) + print() + + with torch.no_grad(), quant_inference_mode(model): + output = model(exampleInput) + + # FBRANCASI: Handle case where output is a tuple (e.g., MHA) + if isinstance(output, tuple): + output = output[0] + + return tracedModel, output diff --git a/DeepQuant/Pipeline/QuantSplit.py b/DeepQuant/Pipeline/QuantSplit.py new file mode 100644 index 0000000..0f30ee7 --- /dev/null +++ b/DeepQuant/Pipeline/QuantSplit.py @@ -0,0 +1,60 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Tuple + +import torch +import torch.nn as nn + +from DeepQuant.QuantManipulation.QuantizationParameterExtractor import ( + extractBrevitasProxyParams, + printQuantParams, +) +from DeepQuant.QuantManipulation.QuantNodesDivider import convertQuantOperations +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter + + +def splitQuantNodes( + model: nn.Module, + exampleInput: torch.Tensor, + referenceOutput: torch.Tensor, + debug: bool = False, +) -> Tuple[nn.Module, torch.Tensor]: + """ + Split quantization nodes into separate Quant and Dequant nodes. + + This step transforms each quantization operation into explicit + Quant and Dequant node pairs, providing clear separation between + quantized and floating-point operations. + """ + printer = GraphModulePrinter() + + proxyParams = extractBrevitasProxyParams(model) + + if debug: + printQuantParams(proxyParams) + + splitModel = convertQuantOperations(model, proxyParams, debug) + splitModel.recompile() + + with torch.no_grad(): + output = splitModel(exampleInput) + + if torch.allclose(referenceOutput, output, atol=1e-5): + if debug: + print(cc.success("Split of Quant Nodes: output is consistent")) + else: + raise RuntimeError( + cc.error("Split of Quant Nodes changed the output significantly") + ) + + if debug: + print(cc.header("3. 
Network after Split of Quant Nodes")) + printer.printTabular(splitModel) + print() + + return splitModel, output diff --git a/DeepQuant/QuantManipulation/DequantModifier.py b/DeepQuant/QuantManipulation/DequantModifier.py index 8bd9ae5..d6357fd 100644 --- a/DeepQuant/QuantManipulation/DequantModifier.py +++ b/DeepQuant/QuantManipulation/DequantModifier.py @@ -4,66 +4,24 @@ # # Federico Brancasi -""" -This module provides a function to unify the linear dequant nodes (input, weight, bias) -into a single final dequant node after the linear wrappedInnerForwardImpl. - -Key steps: - 1) Rewire bias quant to reference the quant nodes of input/weight instead of their dequant. - 2) Rewire the linear's wrappedInnerForwardImpl so it references bias_quant instead of bias_dequant. - 3) Clone the bias dequant parameters (scale/zero_point/bit_width) to a new Dequant node - placed after the linear, removing the old bias_dequant node from the graph. - 4) Remove the input_dequant and weight_dequant nodes as well, once they have no more users. - 5) Recompile the FX GraphModule so that the generated forward code no longer references - the removed nodes. - -By the end, the linear operation is in the integer domain, and the final dequant occurs only once. -""" - import torch.fx as fx from DeepQuant.QuantManipulation.QuantDequantNodes import Dequant +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc -BLUE = "\033[94m" -ENDC = "\033[0m" -CHECK = " ✓" -ARROW = " ›" - - -def unifyLinearDequants( - fxModel: fx.GraphModule, debug: bool = False -) -> fx.GraphModule: - """ - Unify the linear dequant nodes (input, weight, bias) into a single final dequant node. - - This transformation: - * Redirects the linear's inputs to the quant nodes (removing input_dequant, weight_dequant). - * Updates bias_quant to reference those same quant nodes, removing references to dequant. - * Creates a new Dequant node after the linear operation, reusing the bias dequant parameters. 
- * Erases the old dequant nodes from the graph and submodules. - * Recompiles the graph so the final forward does not reference removed nodes. - - Args: - fxModel (fx.GraphModule): The input FX GraphModule to be modified. - debug (bool): If True, prints debug information. - - Returns: - fx.GraphModule: The modified FX GraphModule with a single dequant node after the linear. - """ +def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule: + """Unify the linear dequant nodes (input, weight, bias) into a single final dequant node.""" graph = fxModel.graph allNodes = list(graph.nodes) if debug: - print(f"{BLUE}{ARROW} Starting Modification of Dequant Nodes...{ENDC}") + print(cc.info("Starting Modification of Dequant Nodes...")) for node in allNodes: - # Identify the "wrappedInnerForwardImpl" call for linear if node.op != "call_module" or "wrappedInnerForwardImpl" not in node.target: continue - # Typically the node args are: - # (linear1_input_dequant, linear1_weight_dequant, linear1_bias_dequant) oldArgs = list(node.args) biasDequantNode = None @@ -72,7 +30,6 @@ def unifyLinearDequants( newLinArgs = [] - # Collect and rewire the linear's arguments for arg in oldArgs: if arg.op == "call_module" and "dequant" in arg.target.lower(): if "bias_dequant" in arg.target.lower(): @@ -82,7 +39,6 @@ def unifyLinearDequants( else: inputDequantNode = arg - # Replace the dequant input with the corresponding quant node quantNode = arg.args[0] newLinArgs.append(quantNode) else: @@ -91,46 +47,37 @@ def unifyLinearDequants( node.args = tuple(newLinArgs) if biasDequantNode is None: - # This would be unusual if a linear is missing bias or missing a bias_dequant + # FBRANCASI: This would be unusual if a linear is missing bias or missing a bias_dequant if debug: print(f"Skipping {node.target}: no biasDequantNode found.") continue - # The bias_quant node that feeds biasDequantNode might reference input/weight dequant - # We rewrite it so that it references the 
input/weight quant nodes biasQuantNode = biasDequantNode.args[0] if ( biasQuantNode.op == "call_module" and "bias_quant" in biasQuantNode.target.lower() ): - new_bq_args = list(biasQuantNode.args) - # Typically new_bq_args = [bias, input_dequant, weight_dequant] - for i, bq_arg in enumerate(new_bq_args): - if bq_arg.op == "call_module" and "dequant" in bq_arg.target.lower(): - new_bq_args[i] = bq_arg.args[0] # The corresponding quant node - biasQuantNode.args = tuple(new_bq_args) + newBqArgs = list(biasQuantNode.args) + for i, bqArg in enumerate(newBqArgs): + if bqArg.op == "call_module" and "dequant" in bqArg.target.lower(): + newBqArgs[i] = bqArg.args[0] + biasQuantNode.args = tuple(newBqArgs) else: if debug: print( "Warning: Did not find a typical 'bias_quant' node shape in the graph." ) - # Erase input_dequant/weight_dequant from the graph - # They should now have zero real users for dnode in (inputDequantNode, weightDequantNode): if dnode is not None: - # For safety, remove all references for usr in list(dnode.users.keys()): dnode.users[usr] = None if hasattr(fxModel, dnode.target): delattr(fxModel, dnode.target) graph.erase_node(dnode) - # Now we create the final single Dequant node after the linear - # by cloning the bias_dequant submodule's parameters oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target) - # Construct a new Dequant module from the old bias_dequant newDequantModName = ( node.target.replace(".wrappedInnerForwardImpl", "") + "_unified_dequant" ) @@ -138,21 +85,19 @@ def unifyLinearDequants( newDequantModName = newDequantModName.replace(".", "_") unifiedDequantMod = Dequant( - original_module=oldBiasDequantMod.original_module, + originalModule=oldBiasDequantMod.originalModule, scale=oldBiasDequantMod.scale, - zero_point=oldBiasDequantMod.zero_point, - bit_width=oldBiasDequantMod.bit_width, + zeroPoint=oldBiasDequantMod.zeroPoint, + bitWidth=oldBiasDequantMod.bitWidth, ) fxModel.add_module(newDequantModName, unifiedDequantMod) - # 
Insert the new dequant node after the linear's forward_impl with graph.inserting_after(node): newDequantNode = graph.call_module(newDequantModName, args=(node,)) - # Reroute all users of node to the new dequant node - old_users = list(node.users.keys()) - for usr in old_users: + oldUsers = list(node.users.keys()) + for usr in oldUsers: if usr is not newDequantNode: newArgs = list(usr.args) for i, a in enumerate(newArgs): @@ -160,7 +105,6 @@ def unifyLinearDequants( newArgs[i] = newDequantNode usr.args = tuple(newArgs) - # Remove the old bias_dequant node from the graph for usr in list(biasDequantNode.users.keys()): biasDequantNode.users[usr] = None if hasattr(fxModel, biasDequantNode.target): @@ -168,21 +112,16 @@ def unifyLinearDequants( graph.erase_node(biasDequantNode) if debug: - print(f" {CHECK} Modification done for {node.target}") + print(cc.success(f"Modification done for {node.target}")) - # Clean up any leftover references graph.lint() graph.eliminate_dead_code() - # Remove submodules that are now unused fxModel.delete_all_unused_submodules() - # Recompile so that the generated forward code no longer references removed nodes fxModel.recompile() if debug: - print( - f"{BLUE}{ARROW} Modification of Dequant Nodes completed successfully{ENDC}" - ) + print(cc.info("Modification of Dequant Nodes completed successfully")) return fxModel diff --git a/DeepQuant/QuantManipulation/ParameterExtractor.py b/DeepQuant/QuantManipulation/ParameterExtractor.py deleted file mode 100644 index b11d77b..0000000 --- a/DeepQuant/QuantManipulation/ParameterExtractor.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -This module extracts quantization proxy parameters from an exported FX model. 
-It retrieves scale, zero_point, bit_width and deduces the signedness of the quant -modules in the model by using type- and attribute-based checks rather than string -inspection. - -The safe_get_is_signed() function first looks for an explicit `is_signed` attribute, -then uses the module's min_val (if available) to infer signedness (a negative value -indicates signed quantization). If neither is available, it falls back to checking -the zero_point (a zero or near-zero value suggests unsigned quantization). - -The extracted parameters are printed using a color-coded format. -""" - -from typing import Any, Dict -import torch -import torch.nn as nn -from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector -from brevitas.proxy.parameter_quant import ( - WeightQuantProxyFromInjector, - BiasQuantProxyFromInjector, -) -from colorama import Fore, Style - - -def safe_get_scale(quant_obj: Any) -> Any: - """ - Safely retrieve the scale from a Brevitas quant proxy object. - - Args: - quant_obj: The quant proxy object. - - Returns: - The scale as a float if available, otherwise None. - """ - if quant_obj is None: - return None - maybe_scale = quant_obj.scale() if callable(quant_obj.scale) else quant_obj.scale - if maybe_scale is None: - return None - if isinstance(maybe_scale, torch.Tensor): - return maybe_scale.item() - elif isinstance(maybe_scale, float): - return maybe_scale - try: - return float(maybe_scale) - except Exception: - return None - - -def safe_get_zero_point(quant_obj: Any) -> Any: - """ - Safely retrieve the zero_point from a Brevitas quant proxy object. - - Args: - quant_obj: The quant proxy object. - - Returns: - The zero_point as a float if available, otherwise None. 
- """ - if quant_obj is None: - return None - maybe_zp = ( - quant_obj.zero_point() - if callable(quant_obj.zero_point) - else quant_obj.zero_point - ) - if maybe_zp is None: - return None - if isinstance(maybe_zp, torch.Tensor): - return maybe_zp.item() - elif isinstance(maybe_zp, float): - return maybe_zp - try: - return float(maybe_zp) - except Exception: - return None - - -def safe_get_is_signed(quant_obj: Any) -> bool: - """ - Determine whether a quant proxy/module is signed. - - The function first checks for an explicit `is_signed` attribute. - If not found, it checks for a `min_val` attribute: a negative min_val - indicates signed quantization. If that is unavailable, it examines the - zero_point (if nearly zero, it is assumed unsigned). Defaults to True. - - Args: - quant_obj: The quant proxy object. - - Returns: - True if the quantization is signed, False otherwise. - """ - if hasattr(quant_obj, "is_signed"): - return getattr(quant_obj, "is_signed") - if hasattr(quant_obj, "min_val"): - try: - return quant_obj.min_val < 0 - except Exception: - pass - zp = safe_get_zero_point(quant_obj) - if zp is not None: - # If zero_point is near zero, assume unsigned quantization. - return not (abs(zp) < 1e-5) - return True - - -def extract_brevitas_proxy_params(model: nn.Module) -> Dict[str, Dict[str, Any]]: - """ - Recursively scan the exported FX model to find quant proxy submodules of types: - ActQuantProxyFromInjector, WeightQuantProxyFromInjector, or BiasQuantProxyFromInjector. - For each matching module, extract the scale, zero_point, bit_width, and deduced signedness. - - Args: - model: The exported FX model. - - Returns: - A dictionary mapping module names to their quantization parameters: - { - 'module_name': { - 'scale': float or None, - 'zero_point': float or None, - 'bit_width': float or None, - 'is_signed': bool - }, - ... 
- } - """ - params_dict: Dict[str, Dict[str, Any]] = {} - - def recurse_modules(parent_mod: nn.Module, prefix: str = "") -> None: - for child_name, child_mod in parent_mod.named_children(): - full_name = f"{prefix}.{child_name}" if prefix else child_name - if isinstance( - child_mod, - ( - ActQuantProxyFromInjector, - WeightQuantProxyFromInjector, - BiasQuantProxyFromInjector, - ), - ): - scl = safe_get_scale(child_mod) - zp = safe_get_zero_point(child_mod) - bw = ( - child_mod.bit_width() - ) # Assumes bit_width() returns a numeric value. - is_signed = safe_get_is_signed(child_mod) - params_dict[full_name] = { - "scale": scl, - "zero_point": zp, - "bit_width": bw, - "is_signed": is_signed, - } - recurse_modules(child_mod, prefix=full_name) - - recurse_modules(model) - return params_dict - - -def print_quant_params(params_dict: Dict[str, Dict[str, Any]]) -> None: - """ - Print the extracted quantization parameters for each proxy module in a - color-coded format. - - Args: - params_dict: Dictionary containing quantization parameters. - """ - print(f"\n{Fore.BLUE}Extracted Parameters from the Network:{Style.RESET_ALL}") - for layer_name, quant_values in params_dict.items(): - print(f" {Fore.BLUE}{layer_name}:{Style.RESET_ALL}") - for param_key, param_val in quant_values.items(): - print(f" {param_key}: {param_val}") - print() diff --git a/DeepQuant/QuantManipulation/QuantDequantNodes.py b/DeepQuant/QuantManipulation/QuantDequantNodes.py index 7332833..d130d78 100644 --- a/DeepQuant/QuantManipulation/QuantDequantNodes.py +++ b/DeepQuant/QuantManipulation/QuantDequantNodes.py @@ -4,130 +4,76 @@ # # Federico Brancasi -""" -Basic implementation of Quant and Dequant modules. 
-""" +from typing import Optional import torch import torch.nn as nn -from typing import Any, Optional, Union class Quant(nn.Module): - """ - Fake-quant module that applies a "saturating" approach using scale, zero_point, bit_width, - and signedness parameters extracted from a Brevitas parameter dictionary. - - This module simulates quantization effects on tensors by scaling, shifting, rounding, - and clamping their values. - """ + """Quantization module that applies scale, zero-point, and bit-width constraints.""" def __init__( self, - original_module: nn.Module, + originalModule: nn.Module, scale: float, - zero_point: float, - bit_width: float, + zeroPoint: float, + bitWidth: float, signed: Optional[bool] = True, ) -> None: - """ - Initialize the Quant module. - - Args: - original_module: The original Brevitas quant module (kept for reference). - scale: Scale factor used for quantization. - zero_point: Zero-point used for quantization. - bit_width: Bit width for the quantized representation (e.g., 8.0, 32.0). - signed: Boolean flag indicating if quantization is signed. - """ super().__init__() - self.original_module = original_module + self.originalModule = originalModule self.scale = scale - self.zero_point = zero_point - self.bit_width = bit_width + self.zeroPoint = zeroPoint + self.bitWidth = bitWidth self.signed = signed - if self.bit_width is not None: - bw_int = int(self.bit_width) + if self.bitWidth is not None: + bwInt = int(self.bitWidth) if self.signed: - self.min_val = -(2 ** (bw_int - 1)) - self.max_val = (2 ** (bw_int - 1)) - 1 + self.minVal = -(2 ** (bwInt - 1)) + self.maxVal = (2 ** (bwInt - 1)) - 1 else: - self.min_val = 0 - self.max_val = (2**bw_int) - 1 + self.minVal = 0 + self.maxVal = (2**bwInt) - 1 else: - self.min_val = None - self.max_val = None + self.minVal = None + self.maxVal = None def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Apply fake quantization to the input tensor. 
- - The quantization process is as follows: - 1) Scale the input tensor by 1/scale. - 2) Shift the scaled tensor by the zero_point. - 3) Round the shifted tensor to the nearest integer. - 4) Clamp the rounded tensor to the representable range based on bit_width - and signedness. - - Args: - x: Input tensor. - - Returns: - The fake quantized tensor. - """ - if self.scale is None or self.zero_point is None: + """Quantize the input tensor.""" + if self.scale is None or self.zeroPoint is None: return x - x_scaled = x / self.scale - x_shifted = x_scaled + self.zero_point - x_rounded = torch.round(x_shifted) - if self.bit_width is not None: - x_rounded = torch.clamp(x_rounded, self.min_val, self.max_val) - return x_rounded + xScaled = x / self.scale + xShifted = xScaled + self.zeroPoint + xRounded = torch.round(xShifted) + if self.bitWidth is not None: + xRounded = torch.clamp(xRounded, self.minVal, self.maxVal) + return xRounded class Dequant(nn.Module): - """ - Dequant module that re-applies scale and zero_point to invert the quantization effect. - """ + """Dequantization module that applies inverse scale and zero-point transformations.""" def __init__( self, - original_module: nn.Module, + originalModule: nn.Module, scale: float, - zero_point: float, - bit_width: float, + zeroPoint: float, + bitWidth: float, signed: Optional[bool] = True, ) -> None: - """ - Initialize the Dequant module. - - Args: - original_module: The original Brevitas quant module. - scale: Scale factor from extracted parameters. - zero_point: Zero-point from extracted parameters. - bit_width: Bit width from extracted parameters. - signed: Boolean flag indicating if quantization is signed. 
- """ super().__init__() - self.original_module = original_module + self.originalModule = originalModule self.scale = scale - self.zero_point = zero_point - self.bit_width = bit_width + self.zeroPoint = zeroPoint + self.bitWidth = bitWidth self.signed = signed def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Undo the fake quantization by reversing the shift and scale. - - Args: - x: Input tensor. - - Returns: - The dequantized tensor. - """ - if self.scale is None or self.zero_point is None: + """Dequantize the input tensor.""" + if self.scale is None or self.zeroPoint is None: return x - x_dequant = (x - self.zero_point) * self.scale - return x_dequant + dequantizedX = (x - self.zeroPoint) * self.scale + return dequantizedX diff --git a/DeepQuant/QuantManipulation/QuantNodesDivider.py b/DeepQuant/QuantManipulation/QuantNodesDivider.py index 6b7ab10..ef7f627 100644 --- a/DeepQuant/QuantManipulation/QuantNodesDivider.py +++ b/DeepQuant/QuantManipulation/QuantNodesDivider.py @@ -4,145 +4,129 @@ # # Federico Brancasi -""" -Module for transforming FX graphs by splitting quantization nodes into Quant and Dequant, -while skipping activation quant nodes to preserve nonzero outputs. 
-""" +from typing import Any, Dict, List, Tuple import torch.fx as fx -from typing import Dict, Any, List, Tuple -from .QuantDequantNodes import Quant, Dequant import torch.nn as nn -BLUE = "\033[94m" -ENDC = "\033[0m" -ARROW = " ›" +from DeepQuant.QuantManipulation.QuantDequantNodes import Dequant, Quant +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc -def create_quant_dequant_nodes( +def insertQuantDequantPair( graph: fx.Graph, node: fx.Node, - fx_model: fx.GraphModule, - quant_name: str, - dequant_name: str, - original_module: nn.Module, - param_dict: Dict[str, Any], + fxModel: fx.GraphModule, + quantName: str, + dequantName: str, + originalModule: nn.Module, + paramDict: Dict[str, Any], ) -> Tuple[fx.Node, fx.Node]: - """ - Create separate Quant and Dequant nodes for a given FX node. - - This function replaces a single quantization node (e.g. weight_quant) - with two call_module nodes: one for Quant and one for Dequant. Because - the Quant module only accepts one Tensor argument, multiple arguments - (e.g. bias, input, weight) must be reduced to one. - - Args: - graph: The FX graph to insert new nodes into. - node: The original node referencing a quantization module. - fx_model: The GraphModule containing submodules. - quant_name: Name for the new Quant submodule. - dequant_name: Name for the new Dequant submodule. - original_module: The original Brevitas quant module. - param_dict: Dictionary with keys 'scale', 'zero_point', 'bit_width', - and 'is_signed'. - - Returns: - A tuple containing the newly created Quant and Dequant nodes. 
- """ + """Create separate Quant and Dequant nodes for a given FX node.""" if "bias_quant" in node.target.lower(): - main_arg = node.args[0] + mainArg = node.args[0] elif "weight_quant" in node.target.lower(): - main_arg = node.args[0] + mainArg = node.args[0] else: - main_arg = node.args[0] + mainArg = node.args[0] - scale_val = param_dict.get("scale", None) - zp_val = param_dict.get("zero_point", None) - bw_val = param_dict.get("bit_width", None) - signed_val = param_dict.get("is_signed", True) + scaleVal = paramDict.get("scale", None) + zpVal = paramDict.get("zero_point", None) + bwVal = paramDict.get("bit_width", None) + signedVal = paramDict.get("is_signed", True) - fx_model.add_module( - quant_name, Quant(original_module, scale_val, zp_val, bw_val, signed=signed_val) + fxModel.add_module( + quantName, Quant(originalModule, scaleVal, zpVal, bwVal, signed=signedVal) ) - fx_model.add_module( - dequant_name, - Dequant(original_module, scale_val, zp_val, bw_val, signed=signed_val), + fxModel.add_module( + dequantName, + Dequant(originalModule, scaleVal, zpVal, bwVal, signed=signedVal), ) with graph.inserting_after(node): - quant_node = graph.call_module(quant_name, args=(main_arg,)) + quantNode = graph.call_module(quantName, args=(mainArg,)) - with graph.inserting_after(quant_node): - dequant_node = graph.call_module(dequant_name, args=(quant_node,)) + with graph.inserting_after(quantNode): + dequantNode = graph.call_module(dequantName, args=(quantNode,)) - return quant_node, dequant_node + return quantNode, dequantNode -def split_quant_nodes( - fx_model: fx.GraphModule, full_params_dict: Dict[str, Dict[str, Any]], debug: bool +def convertQuantOperations( + fxModel: fx.GraphModule, fullParamsDict: Dict[str, Dict[str, Any]], debug: bool ) -> fx.GraphModule: - """ - Transform an FX graph by splitting each "call_module(...quant...)" node into - separate Quant -> Dequant nodes, skipping activation quant nodes to preserve - numeric accuracy. 
- - Args: - fx_model: The input FX GraphModule. - full_params_dict: A dictionary mapping module names to quantization parameters. - debug: Whether to print debug output. - - Returns: - The updated FX GraphModule with weight/bias quant calls split. - """ - graph = fx_model.graph - nodes_to_erase: List[fx.Node] = [] + """Split quantization nodes into separate Quant and Dequant nodes.""" + graph = fxModel.graph + nodesToRemove: List[fx.Node] = [] if debug: - print(f"{BLUE}{ARROW} Starting Quantization Node Splitting...{ENDC}") + print(cc.info("Starting Quantization Node Splitting...")) - all_nodes = list(graph.nodes) + allNodes = list(graph.nodes) - for node in all_nodes: + for node in allNodes: if ( node.op == "call_module" and "quant" in node.target.lower() and "act_impl" not in node.target.lower() ): - top_level = node.target.split(".")[0] - if top_level in ["sigmoid"]: - continue # Skip sigmoid + topLevel = node.target.split(".")[0] + if topLevel in ["sigmoid"]: + continue # FBRANCASI: Skip sigmoid - original_module = fx_model.get_submodule(node.target) - safe_target = node.target.replace(".", "_").replace("_quant", "") - quant_name = f"{safe_target}_quant_1" - dequant_name = f"{safe_target}_dequant" - param_info = full_params_dict.get(node.target, {}) + originalModule = fxModel.get_submodule(node.target) + safeTarget = node.target.replace(".", "_").replace("_quant", "") + quantName = f"{safeTarget}_quant_1" + dequantName = f"{safeTarget}_dequant" + paramInfo = fullParamsDict.get(node.target, {}) - quant_node, dequant_node = create_quant_dequant_nodes( + quantNode, dequantNode = insertQuantDequantPair( graph, node, - fx_model, - quant_name, - dequant_name, - original_module, - param_info, + fxModel, + quantName, + dequantName, + originalModule, + paramInfo, ) - # Re-route all users of the original node. 
- for user_node in list(node.users.keys()): - new_args = [] - for arg in user_node.args: - new_args.append(dequant_node if arg is node else arg) - user_node.args = tuple(new_args) - - nodes_to_erase.append(node) - - for erase_node in nodes_to_erase: - graph.erase_node(erase_node) + usersUpdated = False + for userNode in list(node.users.keys()): + if ( + userNode.op == "call_function" + and hasattr(userNode.target, "__name__") + and userNode.target.__name__ == "cat" + ): + # FBRANCASI: This is a concatenation operation - Special Handling + newCatArgs = list(userNode.args) + if len(newCatArgs) >= 1 and isinstance(newCatArgs[0], list): + tensorsList = newCatArgs[0] + updatedTensors = [] + for tensor in tensorsList: + if tensor is node: + updatedTensors.append(dequantNode) + else: + updatedTensors.append(tensor) + newCatArgs[0] = updatedTensors + userNode.args = tuple(newCatArgs) + usersUpdated = True + else: + # FBRANCASI: Standard node reference replacement + newArgs = [] + for arg in userNode.args: + newArgs.append(dequantNode if arg is node else arg) + userNode.args = tuple(newArgs) + usersUpdated = True + + if usersUpdated: + nodesToRemove.append(node) + + for eraseNode in nodesToRemove: + graph.erase_node(eraseNode) graph.lint() if debug: - print(f"{BLUE}{ARROW} Quantization Node Splitting completed Successfully{ENDC}") + print(cc.info("Quantization Node Splitting completed Successfully")) - return fx_model + return fxModel diff --git a/DeepQuant/QuantManipulation/QuantizationParameterExtractor.py b/DeepQuant/QuantManipulation/QuantizationParameterExtractor.py new file mode 100644 index 0000000..22c0629 --- /dev/null +++ b/DeepQuant/QuantManipulation/QuantizationParameterExtractor.py @@ -0,0 +1,110 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
def _asScalar(value: Any) -> Any:
    """Best-effort conversion of a quant parameter to a plain Python float.

    Returns None when the value is missing or cannot be interpreted as a
    single number.
    """
    if value is None:
        return None
    if isinstance(value, torch.Tensor):
        # NOTE(review): .item() assumes a per-tensor (scalar) parameter; a
        # per-channel tensor would raise here — confirm upstream guarantees.
        return value.item()
    if isinstance(value, float):
        return value
    try:
        return float(value)
    except Exception:
        return None


def getScale(quantObj: Any) -> Any:
    """Extract the scale of a quant proxy as a float (None when unavailable)."""
    if quantObj is None:
        return None
    # Brevitas exposes scale either as a method or as an attribute.
    rawScale = quantObj.scale() if callable(quantObj.scale) else quantObj.scale
    return _asScalar(rawScale)


def getZeroPoint(quantObj: Any) -> Any:
    """Extract the zero point of a quant proxy as a float (None when unavailable)."""
    if quantObj is None:
        return None
    rawZp = (
        quantObj.zero_point() if callable(quantObj.zero_point) else quantObj.zero_point
    )
    return _asScalar(rawZp)


def getIsSigned(quantObj: Any) -> bool:
    """Determine whether the quantizer produces signed values.

    Falls back through: explicit ``is_signed`` flag, sign of ``min_val``,
    then a zero-point heuristic; defaults to True when nothing is known.
    """
    if hasattr(quantObj, "is_signed"):
        return quantObj.is_signed
    if hasattr(quantObj, "min_val"):
        try:
            return quantObj.min_val < 0
        except Exception:
            pass
    zp = getZeroPoint(quantObj)
    if zp is not None:
        # A near-zero zero point is treated as unsigned here.
        # NOTE(review): symmetric *signed* quantization also has zero_point 0;
        # confirm this heuristic against the Brevitas quantizers in use.
        return not (abs(zp) < 1e-5)
    return True


def extractBrevitasProxyParams(model: nn.Module) -> Dict[str, Dict[str, Any]]:
    """Collect quantization parameters from every Brevitas proxy in ``model``.

    Returns:
        Mapping from each proxy's qualified module name to a dict with keys
        'scale', 'zero_point', 'bit_width' and 'is_signed'.
    """
    paramsDict: Dict[str, Dict[str, Any]] = {}

    def recurseModules(parentMod: nn.Module, prefix: str = "") -> None:
        # Depth-first walk so nested proxies get fully qualified names.
        for childName, childMod in parentMod.named_children():
            fullName = f"{prefix}.{childName}" if prefix else childName
            if isinstance(
                childMod,
                (
                    ActQuantProxyFromInjector,
                    WeightQuantProxyFromInjector,
                    BiasQuantProxyFromInjector,
                ),
            ):
                paramsDict[fullName] = {
                    "scale": getScale(childMod),
                    "zero_point": getZeroPoint(childMod),
                    "bit_width": childMod.bit_width(),
                    "is_signed": getIsSigned(childMod),
                }
            recurseModules(childMod, prefix=fullName)

    recurseModules(model)
    return paramsDict


def printQuantParams(paramsDict: Dict[str, Dict[str, Any]]) -> None:
    """Print extracted quantization parameters in a readable format."""
    # Imported lazily so extraction alone carries no formatter dependency.
    from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc

    print(f"{cc.wrap('Extracted Parameters from the Network:', cc.blue)}")
    for layerName, quantValues in paramsDict.items():
        print(f"  {cc.wrap(layerName + ':', cc.blue)}")
        for paramKey, paramVal in quantValues.items():
            print(f"    {paramKey}: {paramVal}")
        print()
class TransformationPass(ABC):
    """Abstract base for passes that rewrite matching submodules in place.

    Subclasses declare the target module class(es) plus a validation
    tolerance, and implement :meth:`injectForward`.
    """

    def __init__(
        self,
        moduleCls: Union[type, Tuple[type, ...]],
        validationTol: float = 1e-6,
    ) -> None:
        self.moduleCls = moduleCls
        self.validationTol = validationTol

    def checkModuleType(self, module: nn.Module) -> bool:
        """Return True when ``module`` is an instance of the target class(es)."""
        return isinstance(module, self.moduleCls)

    @abstractmethod
    def injectForward(
        self, module: nn.Module, tracer: Optional[QuantTracer] = None
    ) -> None:
        """Install the custom forward implementation on ``module``."""
        ...

    def validateTransformation(
        self, outputBefore: Any, outputAfter: Any, atol: Optional[float] = None
    ) -> bool:
        """Compare pre/post-transformation outputs within ``atol``."""
        tolerance = self.validationTol if atol is None else atol
        return torch.allclose(outputBefore, outputAfter, atol=tolerance)

    def transform(self, model: nn.Module, tracer: Optional[QuantTracer] = None) -> bool:
        """Apply :meth:`injectForward` to every matching submodule of ``model``.

        Returns:
            True when at least one submodule was transformed.
        """
        anyMatch = False
        for _, candidate in model.named_modules():
            if not self.checkModuleType(candidate):
                continue
            self.injectForward(candidate, tracer)
            anyMatch = True
        return anyMatch
class TransformationExecutor:
    """Runs a sequence of transformation passes, validating each one.

    After every pass that reports a change, the model is re-run on the example
    input and the output is compared against the pre-pass output; a mismatch
    aborts with a RuntimeError.
    """

    def __init__(
        self,
        transformations: List[TransformationPass],
        debug: bool = False,
        tracer: Optional[QuantTracer] = None,
    ) -> None:
        self.transformations = transformations
        self.debug = debug
        self.tracer = tracer

    @staticmethod
    def _runModel(model: nn.Module, exampleInput: torch.Tensor) -> torch.Tensor:
        """Forward-pass helper: unwraps a tuple output to its first element."""
        output = model(exampleInput)
        return output[0] if isinstance(output, tuple) else output

    def execute(self, model: nn.Module, exampleInput: torch.Tensor) -> nn.Module:
        """Execute all transformations on the model.

        Args:
            model: Model transformed in place.
            exampleInput: Input used to validate numerical equivalence.

        Returns:
            The (mutated) model.

        Raises:
            RuntimeError: When a pass changes the model's output beyond the
                pass's validation tolerance.
        """
        model.eval()
        # Fix: run ALL validation forwards under no_grad — the original only
        # wrapped the first one, so every later forward built autograd graphs.
        with torch.no_grad():
            outputBefore = self._runModel(model, exampleInput)

            for transformation in self.transformations:
                if not transformation.transform(model, tracer=self.tracer):
                    continue

                outputAfter = self._runModel(model, exampleInput)

                if not transformation.validateTransformation(
                    outputBefore, outputAfter
                ):
                    raise RuntimeError(
                        cc.error(
                            f"{transformation.__class__.__name__} failed - outputs mismatch"
                        )
                    )

                if self.debug:
                    print(
                        cc.success(
                            f"{transformation.__class__.__name__} transformation successful"
                        )
                    )
                    if self.tracer:
                        print(f"    leafClasses: {self.tracer.leafClasses}")
                        print(f"    nonLeafClasses: {self.tracer.nonLeafClasses}")

                outputBefore = outputAfter

        return model
class LinearTransformation(TransformationPass):
    """Unrolls quantized weight/bias/input/output (WBIOL) layers for FX tracing."""

    def __init__(self) -> None:
        super().__init__(
            moduleCls=QuantWeightBiasInputOutputLayer,
            validationTol=1e-6,
        )

    def injectForward(
        self, module: nn.Module, tracer: Optional[QuantTracer] = None
    ) -> None:
        """Wrap the inner forward impl as a submodule and bind the unrolled forward."""
        module.wrappedInnerForwardImpl = WrapperWBIOL(module.inner_forward_impl)
        module.forward = WBIOLForward.__get__(module)
        if tracer is None:
            return
        tracer.registerLeafModule(WrapperWBIOL)
        tracer.registerNonLeafModule(QuantWeightBiasInputOutputLayer)


class ActivationTransformation(TransformationPass):
    """Unrolls quantized activation layers for FX tracing."""

    def __init__(self) -> None:
        super().__init__(
            moduleCls=QuantNonLinearActLayer,
            validationTol=1e-6,
        )

    def injectForward(
        self, module: nn.Module, tracer: Optional[QuantTracer] = None
    ) -> None:
        """Instantiate the activation (identity fallback) and bind the unrolled forward."""
        # FBRANCASI: QuantReLU-style layers carry the activation CLASS (e.g.
        # nn.ReLU) in act_impl; instantiate it, or fall back to a pass-through.
        actCls = getattr(module, "act_impl", None)
        module.wrappedActImpl = WrapperActivation(
            actCls() if actCls is not None else nn.Identity()
        )
        module.forward = activationForward.__get__(module)
        if tracer is None:
            return
        tracer.registerLeafModule(WrapperActivation)
        tracer.registerNonLeafModule(QuantNonLinearActLayer)


class MHATransformation(TransformationPass):
    """Unrolls quantized multi-head attention layers for FX tracing."""

    def __init__(self) -> None:
        super().__init__(
            moduleCls=QuantMultiheadAttention,
            validationTol=1e-5,
        )

    def injectForward(
        self, module: nn.Module, tracer: Optional[QuantTracer] = None
    ) -> None:
        """Bind the unrolled MHA forward; MHA itself stays traceable (non-leaf)."""
        module.forward = mhaForward.__get__(module)
        if tracer is not None:
            tracer.registerNonLeafModule(QuantMultiheadAttention)
class ConsoleColor:
    """ANSI escape helpers for colored, formatted terminal output."""

    # ANSI color / style escape codes.
    blue = "\033[94m"
    green = "\033[92m"
    red = "\033[91m"
    yellow = "\033[93m"
    cyan = "\033[96m"
    magenta = "\033[95m"
    bold = "\033[1m"
    reset = "\033[0m"

    # Status symbols (leading space keeps them aligned with plain text).
    checkmark = " ✓"
    cross = " ✗"
    arrow = " ›"

    @staticmethod
    def wrap(text: str, color: str) -> str:
        """Return ``text`` preceded by ``color`` and followed by a reset code."""
        return color + text + ConsoleColor.reset

    @staticmethod
    def success(text: str) -> str:
        """Green, check-marked message."""
        return ConsoleColor.wrap(
            ConsoleColor.checkmark + " " + text, ConsoleColor.green
        )

    @staticmethod
    def error(text: str) -> str:
        """Red, cross-marked message."""
        return ConsoleColor.wrap(ConsoleColor.cross + " " + text, ConsoleColor.red)

    @staticmethod
    def info(text: str) -> str:
        """Blue, arrow-prefixed message."""
        return ConsoleColor.wrap(ConsoleColor.arrow + " " + text, ConsoleColor.blue)

    @staticmethod
    def warning(text: str) -> str:
        """Yellow message."""
        return ConsoleColor.wrap(text, ConsoleColor.yellow)

    @staticmethod
    def header(text: str) -> str:
        """Magenta banner: ``text`` between two 50-character '=' rules."""
        rule = "=" * 50
        banner = "\n".join((rule, text, rule))
        return "\n" + ConsoleColor.wrap(banner, ConsoleColor.magenta)
class QuantTracer(Tracer):
    """Brevitas FX tracer with explicit per-class leaf/non-leaf overrides.

    Classes registered as leaves are never traced into; classes registered as
    non-leaves are always traced into; everything else falls back to the stock
    Brevitas heuristic.
    """

    def __init__(
        self,
        leafClasses: Optional[List[Type[nn.Module]]] = None,
        nonLeafClasses: Optional[List[Type[nn.Module]]] = None,
        debug: bool = False,
    ) -> None:
        super().__init__()
        # Caller-provided lists are kept as-is (shared, not copied).
        self.leafClasses = [] if leafClasses is None else leafClasses
        self.nonLeafClasses = [] if nonLeafClasses is None else nonLeafClasses
        self.debug = debug

    def registerLeafModule(self, moduleCls: Type[nn.Module]) -> None:
        """Add ``moduleCls`` to the leaf overrides (idempotent)."""
        if moduleCls in self.leafClasses:
            return
        self.leafClasses.append(moduleCls)

    def registerNonLeafModule(self, moduleCls: Type[nn.Module]) -> None:
        """Add ``moduleCls`` to the non-leaf overrides (idempotent)."""
        if moduleCls in self.nonLeafClasses:
            return
        self.nonLeafClasses.append(moduleCls)

    def is_leaf_module(self, m: nn.Module, moduleQualifiedName: str) -> bool:
        """Leaf override wins, then non-leaf override, then the Brevitas default."""
        for leafCls in self.leafClasses:
            if isinstance(m, leafCls):
                return True
        for nonLeafCls in self.nonLeafClasses:
            if isinstance(m, nonLeafCls):
                return False
        return _is_brevitas_leaf_module(m, moduleQualifiedName)


def customBrevitasTrace(
    root: nn.Module, concreteArgs=None, tracer: Optional[QuantTracer] = None
) -> GraphModule:
    """Create an FX GraphModule via QuantTracer (a fresh one when none is given)."""
    activeTracer = QuantTracer() if tracer is None else tracer
    return _symbolic_trace(activeTracer, root, concreteArgs)
+0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -FX Graph tracer that traces each node by wrapping submodules with proxy objects. -""" - -import torch -import torch.nn as nn -import torch.fx as fx -from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union, Callable -import functools -import inspect - - -class NodeTracer: - """ - Traces execution through an FX graph by wrapping each module with a - proxy that logs input and output values. - """ - - def __init__(self, debug: bool = True) -> None: - """ - Initialize the tracer. - - Args: - debug: Whether to print debug information. - """ - self.debug = debug - self.BLUE = "\033[94m" - self.GREEN = "\033[92m" - self.YELLOW = "\033[93m" - self.RED = "\033[91m" - self.RESET = "\033[0m" - self.traced_modules: Dict[str, nn.Module] = {} - self.call_count: Dict[str, int] = {} - - def trace( - self, model: fx.GraphModule, example_input: torch.Tensor - ) -> Optional[torch.Tensor]: - """ - Trace the execution of the model by wrapping modules with proxies. - - Args: - model: The FX GraphModule to trace. - example_input: The input tensor. - - Returns: - The model output, if successful. 
- """ - if self.debug: - print( - f"\n{self.BLUE}===== Starting FX Graph Execution Tracing ====={self.RESET}\n" - ) - print( - f"{self.BLUE}Input shape: {tuple(example_input.shape)}, dtype: {example_input.dtype}{self.RESET}\n" - ) - - # Wrap all submodules with our proxy - self._wrap_modules(model) - - # Create a copy of the original model to restore wrapped modules after tracing - original_modules = { - name: module - for name, module in model.named_modules() - if not isinstance(module, fx.GraphModule) - } - - try: - # Execute the model with the example input - with torch.no_grad(): - output = model(example_input) - - if self.debug: - print(f"\n{self.GREEN}Execution completed successfully!{self.RESET}") - if isinstance(output, torch.Tensor): - print( - f"{self.GREEN}Output shape: {tuple(output.shape)}, dtype: {output.dtype}{self.RESET}" - ) - else: - print(f"{self.GREEN}Output type: {type(output)}{self.RESET}") - - return output - - except Exception as e: - if self.debug: - print(f"\n{self.RED}Error during execution: {str(e)}{self.RESET}") - return None - - finally: - # Restore original modules - self._restore_modules(model, original_modules) - - def _wrap_modules(self, model: fx.GraphModule) -> None: - """ - Wrap all relevant modules with tracing proxies. - - Args: - model: The model containing modules to wrap. 
- """ - # Find relevant modules that match nodes in the graph - for name, module in list(model.named_modules()): - if not isinstance(module, fx.GraphModule): - if hasattr(module, "forward"): - original_forward = module.forward - self.traced_modules[name] = original_forward - - # Create wrapped forward method with tracing - @functools.wraps(original_forward) - def traced_forward(self, *args, **kwargs): - module_name = self._tracing_name - - # Increment call count - self._tracer.call_count.setdefault(module_name, 0) - self._tracer.call_count[module_name] += 1 - call_idx = self._tracer.call_count[module_name] - - # Print module info before call - if self._tracer.debug: - module_type = type(self).__name__ - print( - f"\n{self._tracer.YELLOW}[{module_name} ({module_type}) - Call #{call_idx}]{self._tracer.RESET}" - ) - - # Print input tensor info - for i, arg in enumerate(args): - if isinstance(arg, torch.Tensor): - print( - f" Input {i}: Tensor{tuple(arg.shape)} ({arg.dtype})" - ) - # Sample values for extra context - if arg.numel() > 0: - flat = arg.reshape(-1) - sample = flat[:3].tolist() - sample_str = ", ".join( - ( - f"{x:.6f}" - if isinstance(x, float) - else str(x) - ) - for x in sample - ) - print( - f" Values: [{sample_str}{'...' 
if flat.numel() > 3 else ''}]" - ) - elif ( - isinstance(arg, (list, tuple)) - and len(arg) > 0 - and isinstance(arg[0], torch.Tensor) - ): - print( - f" Input {i}: {type(arg).__name__} of {len(arg)} Tensors" - ) - else: - print(f" Input {i}: {type(arg).__name__}") - - # Call original forward method - result = self._original_forward(*args, **kwargs) - - # Print output info - if self._tracer.debug: - if isinstance(result, torch.Tensor): - print( - f" {self._tracer.GREEN}Output: Tensor{tuple(result.shape)} ({result.dtype}){self._tracer.RESET}" - ) - # Sample output values - if result.numel() > 0: - flat = result.reshape(-1) - sample = flat[:3].tolist() - sample_str = ", ".join( - f"{x:.6f}" if isinstance(x, float) else str(x) - for x in sample - ) - print( - f" Values: [{sample_str}{'...' if flat.numel() > 3 else ''}]" - ) - elif isinstance(result, (list, tuple)) and len(result) > 0: - print( - f" {self._tracer.GREEN}Output: {type(result).__name__} of length {len(result)}{self._tracer.RESET}" - ) - else: - print( - f" {self._tracer.GREEN}Output: {type(result).__name__}{self._tracer.RESET}" - ) - - return result - - # Attach tracer reference and original forward to the wrapped method - traced_forward.__self__ = module - traced_forward.__self__._tracer = self - traced_forward.__self__._original_forward = original_forward - traced_forward.__self__._tracing_name = name - - # Replace forward with wrapped version - module.forward = traced_forward.__get__(module) - - def _restore_modules( - self, model: fx.GraphModule, original_modules: Dict[str, nn.Module] - ) -> None: - """ - Restore original forward methods for all wrapped modules. - - Args: - model: The model containing wrapped modules. - original_modules: Dictionary of original modules. 
- """ - for name, original_forward in self.traced_modules.items(): - parts = name.split(".") - current = model - - # Navigate to the module - for part in parts: - if hasattr(current, part): - current = getattr(current, part) - else: - break - - # Restore original forward if found - if hasattr(current, "forward") and hasattr(current, "_original_forward"): - current.forward = original_forward diff --git a/DeepQuant/Utils/GraphPrinter.py b/DeepQuant/Utils/GraphPrinter.py index d3d6b9e..35dc97b 100644 --- a/DeepQuant/Utils/GraphPrinter.py +++ b/DeepQuant/Utils/GraphPrinter.py @@ -4,85 +4,23 @@ # # Federico Brancasi -""" -This module provides a specialized GraphModulePrinter class to display an FX GraphModule -in a tabular format, including optional metadata about quantization (like eps, n_levels, signed). +from typing import List, Literal -Usage: - from DeepQuant.graph_printer import GraphModulePrinter - - printer = GraphModulePrinter() - printer.print_tabular( - fx_model, - show_opcode=True, - show_class=True, - show_name=True, - show_target=True, - show_args=True, - show_kwargs=True, - show_eps=False, - show_nlevels=True, - show_signed=True, - unicode=False - ) - -Note: -- This example assumes that each node in the graph may have a `node.meta['quant']` dict - with fields like eps_in, eps_out, n_levels_in, n_levels_out, signed_in, and signed_out. -- If these fields are not present, the code will gracefully skip them or display placeholders. -- If you do not have such metadata in node.meta, you can adapt the logic to suit your needs. 
-""" - -import math -from typing import Any, List, Literal, Optional import torch.fx as fx - -try: - # Optional: colorama for colored output (requires `pip install colorama`) - from colorama import Fore, Back, Style - - COLORAMA_AVAILABLE = True -except ImportError: - COLORAMA_AVAILABLE = False - -try: - # Optional: tabulate for printing tables (requires `pip install tabulate`) - from tabulate import tabulate - - TABULATE_AVAILABLE = True -except ImportError: - TABULATE_AVAILABLE = False +from colorama import Back, Fore, Style +from tabulate import tabulate class GraphModulePrinter: - """ - Class for printing an FX GraphModule in a tabular format, optionally displaying - quantization metadata stored in node.meta['quant']. - - The code is based on an example snippet from a supervisor. The logic is adjusted - to fit our code style and to gracefully handle missing metadata. - """ + """Formatter and printer for FX graph modules.""" @staticmethod - def quant_info( + def quantInfo( node: fx.Node, prop: Literal["eps_in", "eps_out", "n_levels", "signed"] ) -> str: - """ - Retrieve a string representation of the quantization property for a given node. - - Args: - node: The FX node containing potential quantization metadata. - prop: The quantization property to display. One of 'eps_in', 'eps_out', - 'n_levels', or 'signed'. - - Returns: - A string representation of the requested property if it exists, or '{}' otherwise. - """ if "quant" not in node.meta: return "{}" - # At this point, we assume node.meta['quant'] is a dict-like object containing - # fields such as eps_in, eps_out, n_levels_in, n_levels_out, signed_in, signed_out, etc. qmeta = node.meta["quant"] if prop == "eps_in": @@ -90,12 +28,10 @@ def quant_info( elif prop == "eps_out": return str(qmeta.get("eps_out", "{}")) elif prop == "n_levels": - # This is just an example: we might have n_levels_in, n_levels_out, etc. 
n_in = qmeta.get("n_levels_in", "{}") n_out = qmeta.get("n_levels_out", "{}") return f"{n_in} -> {n_out}" elif prop == "signed": - # Example: 'signed_in' and 'signed_out' s_in = qmeta.get("signed_in", "{}") s_out = qmeta.get("signed_out", "{}") return f"{s_in} -> {s_out}" @@ -103,196 +39,126 @@ def quant_info( return "{}" @staticmethod - def class_info(node: fx.Node, gm: fx.GraphModule, unicode: bool = False) -> str: - """ - Retrieve class name for call_module nodes. For example, if node.target is - referencing a submodule of type nn.Conv2d, this returns 'Conv2d'. - - Args: - node: The FX node to analyze. - gm: The FX GraphModule containing the node. - unicode: If True, optionally highlight certain classes. - - Returns: - The class name as a string, or '' if not applicable. - """ + def classInfo(node: fx.Node, gm: fx.GraphModule, unicode: bool = False) -> str: if node.op == "call_module": submodule = gm.get_submodule(node.target) class_name = submodule.__class__.__name__ - if not COLORAMA_AVAILABLE or not unicode: + if not unicode: return class_name - # Optionally highlight if it's a special class, e.g. 'PACT' or so. if "PACT" in class_name: return Fore.GREEN + class_name + Style.RESET_ALL return class_name return "" @staticmethod - def node_info(node: fx.Node, attr: str, unicode: bool = False) -> str: - """ - Retrieve a specified attribute from the node (e.g. 'op', 'name', 'target', 'args'). - - Args: - node: The FX node. - attr: The name of the attribute to retrieve (e.g. 'op', 'name', 'target', 'args'). - unicode: If True, highlight certain functions in color. - - Returns: - A string representation of the requested attribute, or '' if not present. 
- """ + def nodeInfo(node: fx.Node, attr: str, unicode: bool = False) -> str: if not hasattr(node, attr): return "" value = getattr(node, attr) if attr == "op": - # Optionally highlight certain call_function ops - if node.op == "call_function" and COLORAMA_AVAILABLE and unicode: - # Example of a function whitelist + if node.op == "call_function" and unicode: whitelist_functions = ["getitem"] - if node.target.__name__ not in whitelist_functions: + if ( + hasattr(node.target, "__name__") + and node.target.__name__ not in whitelist_functions + ): return Back.YELLOW + str(value) + Style.RESET_ALL return str(value) @classmethod - def get_node_spec( + def getNodeSpec( cls, node: fx.Node, gm: fx.GraphModule, - show_opcode: bool = True, - show_class: bool = True, - show_name: bool = True, - show_target: bool = True, - show_args: bool = True, - show_kwargs: bool = True, - show_eps: bool = False, - show_nlevels: bool = True, - show_signed: bool = True, + showOpcode: bool = True, + showClass: bool = True, + showName: bool = True, + showTarget: bool = True, + showArgs: bool = True, + showKwargs: bool = True, + showEps: bool = False, + showNlevels: bool = True, + showSigned: bool = True, unicode: bool = False, ) -> List[str]: - """ - Collect string representations of the node's attributes/metadata for printing. - - Args: - node: The FX node to process. - gm: The FX GraphModule containing the node. - show_opcode: Whether to display the node's op code. - show_class: Whether to display the submodule class name (for call_module). - show_name: Whether to display the node's name. - show_target: Whether to display the node's target. - show_args: Whether to display the node's args. - show_kwargs: Whether to display the node's kwargs. - show_eps: Whether to display the quantization eps_in/eps_out (if available). - show_nlevels: Whether to display the n_levels_in -> n_levels_out. - show_signed: Whether to display the signed_in -> signed_out. 
- unicode: If True, apply color highlights for certain attributes. - - Returns: - A list of strings representing each requested attribute in order. - """ - node_specs: List[str] = [] - - if show_opcode: - node_specs.append(cls.node_info(node, "op", unicode)) - if show_class: - node_specs.append(cls.class_info(node, gm, unicode)) - if show_name: - node_specs.append(cls.node_info(node, "name", unicode)) - if show_target: - node_specs.append(cls.node_info(node, "target", unicode)) - if show_args: - node_specs.append(cls.node_info(node, "args", unicode)) - if show_kwargs: - node_specs.append(cls.node_info(node, "kwargs", unicode)) - - if show_nlevels: - node_specs.append(cls.quant_info(node, "n_levels")) - if show_signed: - node_specs.append(cls.quant_info(node, "signed")) - if show_eps: - node_specs.append(cls.quant_info(node, "eps_in")) - node_specs.append(cls.quant_info(node, "eps_out")) - - return node_specs + nodeSpecs: List[str] = [] + + if showOpcode: + nodeSpecs.append(cls.nodeInfo(node, "op", unicode)) + if showClass: + nodeSpecs.append(cls.classInfo(node, gm, unicode)) + if showName: + nodeSpecs.append(cls.nodeInfo(node, "name", unicode)) + if showTarget: + nodeSpecs.append(cls.nodeInfo(node, "target", unicode)) + if showArgs: + nodeSpecs.append(cls.nodeInfo(node, "args", unicode)) + if showKwargs: + nodeSpecs.append(cls.nodeInfo(node, "kwargs", unicode)) + + if showNlevels: + nodeSpecs.append(cls.quantInfo(node, "n_levels")) + if showSigned: + nodeSpecs.append(cls.quantInfo(node, "signed")) + if showEps: + nodeSpecs.append(cls.quantInfo(node, "eps_in")) + nodeSpecs.append(cls.quantInfo(node, "eps_out")) + + return nodeSpecs @classmethod - def print_tabular( + def printTabular( cls, gm: fx.GraphModule, - show_opcode: bool = True, - show_class: bool = True, - show_name: bool = True, - show_target: bool = True, - show_args: bool = False, - show_kwargs: bool = False, - show_eps: bool = False, - show_nlevels: bool = False, - show_signed: bool = False, + 
showOpcode: bool = True, + showClass: bool = True, + showName: bool = True, + showTarget: bool = True, + showArgs: bool = False, + showKwargs: bool = False, + showEps: bool = False, + showNlevels: bool = False, + showSigned: bool = False, unicode: bool = False, ) -> None: - """ - Print the graph in a tabular format with optional quantization metadata. - - Args: - gm: The FX GraphModule to display. - show_opcode: Whether to display the node's op code. - show_class: Whether to display the submodule class name (for call_module). - show_name: Whether to display the node's name. - show_target: Whether to display the node's target. - show_args: Whether to display the node's args. - show_kwargs: Whether to display the node's kwargs. - show_eps: Whether to display the quantization eps_in/eps_out (if available). - show_nlevels: Whether to display the n_levels_in -> n_levels_out. - show_signed: Whether to display the signed_in -> signed_out. - unicode: If True, apply color highlights for certain attributes. - - Returns: - None - """ - if not TABULATE_AVAILABLE: - print( - "Warning: 'tabulate' is not installed. Install via 'pip install tabulate' to use print_tabular." 
- ) - return - - node_list = list(gm.graph.nodes) - node_specs = [ - cls.get_node_spec( + nodeList = list(gm.graph.nodes) + nodeSpecs = [ + cls.getNodeSpec( node, gm, - show_opcode=show_opcode, - show_class=show_class, - show_name=show_name, - show_target=show_target, - show_args=show_args, - show_kwargs=show_kwargs, - show_eps=show_eps, - show_nlevels=show_nlevels, - show_signed=show_signed, + showOpcode=showOpcode, + showClass=showClass, + showName=showName, + showTarget=showTarget, + showArgs=showArgs, + showKwargs=showKwargs, + showEps=showEps, + showNlevels=showNlevels, + showSigned=showSigned, unicode=unicode, ) - for node in node_list + for node in nodeList ] headers = [] - if show_opcode: + if showOpcode: headers.append("opcode") - if show_class: + if showClass: headers.append("class") - if show_name: + if showName: headers.append("name") - if show_target: + if showTarget: headers.append("target") - if show_args: + if showArgs: headers.append("args") - if show_kwargs: + if showKwargs: headers.append("kwargs") - if show_nlevels: + if showNlevels: headers.append("n_levels") - if show_signed: + if showSigned: headers.append("signed") - if show_eps: + if showEps: headers.append("eps_in") headers.append("eps_out") - from tabulate import tabulate # safe import inside method - - print(tabulate(node_specs, headers=headers, tablefmt="mixed_grid")) + print(tabulate(nodeSpecs, headers=headers, tablefmt="mixed_grid")) diff --git a/DeepQuant/Utils/TensorRecorder.py b/DeepQuant/Utils/TensorRecorder.py new file mode 100644 index 0000000..d798c3a --- /dev/null +++ b/DeepQuant/Utils/TensorRecorder.py @@ -0,0 +1,177 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from collections import OrderedDict +from typing import Dict, List, Optional, Set + +import torch +import torch.fx as fx + +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc + + +class TensorRecorder: + """Records and compares tensor values during model execution.""" + + def __init__(self, debug: bool = False): + self.debug = debug + self._hooks: List[torch.utils.hooks.RemovableHandle] = [] + self._current: Dict[str, torch.Tensor] = {} + self._reference: Optional[Dict[str, torch.Tensor]] = None + self._executionOrder: List[str] = [] + self._nameMap: Dict[str, str] = {} + self._ignore: Set[str] = set() + + def clear(self) -> None: + self.removeHooks() + self._current.clear() + self._reference = None + self._executionOrder.clear() + self._nameMap.clear() + self._ignore.clear() + + def removeHooks(self) -> None: + for hook in self._hooks: + hook.remove() + self._hooks.clear() + + def registerForwardHooks( + self, model: fx.GraphModule, nodeTypes: Optional[List[str]] = None + ) -> None: + self.removeHooks() + wanted = [w.lower() for w in nodeTypes] if nodeTypes else [] + + def makeHook(name: str): + def hook(_, __, output): + if isinstance(output, torch.Tensor): + self._current[name] = output.detach().clone() + if name not in self._executionOrder: + self._executionOrder.append(name) + + return hook + + for name, module in model.named_modules(): + if name and any(w in name.lower() for w in wanted): + self._hooks.append(module.register_forward_hook(makeHook(name))) + + def recordNodeMapping(self, referenceName: str, currentName: str) -> None: + self._nameMap[referenceName] = currentName + if self.debug: + print(f"Registered mapping: {referenceName} → {currentName}") + + def setReferenceTensors(self) -> None: + self._reference = {k: v.clone() for k, v in self._current.items()} + self._referenceOrder = list(self._executionOrder) + + def compareTensors(self) -> Dict[str, Dict]: + if 
self._reference is None: + raise RuntimeError("setReferenceTensors has not been called") + + results: Dict[str, Dict] = OrderedDict() + for refName, refTensor in self._reference.items(): + if refName in self._ignore: + continue + + curName = self._nameMap.get(refName, refName) + if curName not in self._current: + results[refName] = {"match": False, "error": f"missing '{curName}'"} + continue + + curTensor = self._current[curName] + equal = torch.equal(refTensor, curTensor) + diffMask = refTensor != curTensor + + results[refName] = { + "match": equal, + "mapped": curName != refName, + "current_name": curName, + "shape": tuple(refTensor.shape), + "diff_count": diffMask.sum().item() if not equal else 0, + "diff_mask": diffMask, + "ref_tensor": refTensor, + "cur_tensor": curTensor, + } + return results + + def _topDifferences( + self, ref: torch.Tensor, cur: torch.Tensor, diffMask: torch.Tensor + ) -> List[str]: + maskFlat = diffMask.view(-1).bool() + if maskFlat.sum() == 0: + return [] + + absDiff = (ref - cur).abs().view(-1)[maskFlat] + unique, counts = torch.unique(absDiff, return_counts=True) + order = counts.argsort(descending=True) + + lines: List[str] = [] + for idx in order[:5]: + delta = unique[idx].item() + count = counts[idx].item() + sampleIndex = (absDiff == delta).nonzero(as_tuple=False)[0].item() + globalIndex = maskFlat.nonzero(as_tuple=False)[sampleIndex].item() + beforeValue = ref.view(-1)[globalIndex].item() + afterValue = cur.view(-1)[globalIndex].item() + + lines.append( + f" · Δ={delta:.6f} ({count} values) e.g. 
idx {globalIndex}: " + f"{beforeValue:.6f} → {afterValue:.6f}" + ) + return lines + + def printComparisonResults(self, results: Dict[str, Dict]) -> None: + if not results: + print("No comparison data available.") + return + + matches = sum(1 for r in results.values() if r["match"]) + total = len(results) + + print( + f"Compared {total}: " + f"{cc.wrap(str(matches) + ' equal', cc.green)}, " + f"{cc.wrap(str(total - matches) + ' different', cc.red)}\n" + ) + + orderedNames = getattr(self, "_referenceOrder", list(results.keys())) + for name in orderedNames: + if name not in results: + continue + + res = results[name] + statusColor = cc.green if res["match"] else cc.red + statusTag = cc.wrap("[OK]" if res["match"] else "[DIFF]", statusColor) + mappedNote = f" → {res['current_name']}" if res["mapped"] else "" + + print(f" {statusTag} {name}{mappedNote} | shape {res['shape']}") + if res["match"]: + continue + + if "error" in res: + print(cc.wrap(f" {res['error']}", cc.yellow)) + continue + + diffCount = res["diff_count"] + totalValues = torch.tensor(res["shape"]).prod().item() + percentage = diffCount / totalValues * 100 + absDiff = (res["ref_tensor"] - res["cur_tensor"]).abs() + nonZero = absDiff[absDiff > 0] + minDiff = nonZero.min().item() if nonZero.numel() else 0.0 + + print(f" Max diff: {absDiff.max().item():.8f}") + print(f" Min diff: {minDiff:.8f}") + print(f" Mean diff: {absDiff.mean().item():.8f}") + print( + f" Total differing values: {diffCount} of {totalValues} ({percentage:.4f}%)" + ) + + topLines = self._topDifferences( + res["ref_tensor"], res["cur_tensor"], res["diff_mask"] + ) + if topLines: + print(" Most common differences (up to 5):") + for line in topLines: + print(line) diff --git a/DeepQuant/__init__.py b/DeepQuant/__init__.py new file mode 100644 index 0000000..6ac381f --- /dev/null +++ b/DeepQuant/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 ETH Zurich and University of Bologna. 
+# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +# FBRANCASI: Workaround for PyTorch/FX API change: ensure private alias exists +import torch.fx.node as _fx_node + +if not hasattr(_fx_node.Node, "_Node__update_args_kwargs"): + _fx_node.Node._Node__update_args_kwargs = _fx_node.Node._update_args_kwargs + +from DeepQuant.Export import brevitasToTrueQuant + +__all__ = ["brevitasToTrueQuant"] diff --git a/Tests/TestConv.py b/Tests/TestConv.py index 011612c..2d5a135 100644 --- a/Tests/TestConv.py +++ b/Tests/TestConv.py @@ -5,22 +5,23 @@ # Victor Jung # Federico Brancasi - +import brevitas.nn as qnn import pytest import torch import torch.nn as nn -import brevitas.nn as qnn from brevitas.quant.scaled_int import ( Int8ActPerTensorFloat, - Int32Bias, Int8WeightPerTensorFloat, + Int32Bias, ) -from DeepQuant.ExportBrevitas import exportBrevitas + +from DeepQuant import brevitasToTrueQuant class QuantConvNet(nn.Module): + """Simple quantized CNN with a single conv layer.""" - convAndLinQuantParams = { + convQuantParams = { "bias": True, "weight_bit_width": 4, "bias_quant": Int32Bias, @@ -30,31 +31,26 @@ class QuantConvNet(nn.Module): "return_quant_tensor": True, } - def __init__(self, in_channels: int = 1) -> None: + def __init__(self, inChannels: int = 1) -> None: super().__init__() self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True) - self.conv1 = qnn.QuantConv2d( - in_channels=in_channels, + in_channels=inChannels, out_channels=16, kernel_size=3, padding=1, - **QuantConvNet.convAndLinQuantParams + **QuantConvNet.convQuantParams, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.inputQuant(x) x = self.conv1(x) - return x @pytest.mark.SingleLayerTests def deepQuantTestConv() -> None: - torch.manual_seed(42) - model = QuantConvNet().eval() sampleInput = torch.randn(1, 1, 28, 28) - exportBrevitas(model, sampleInput, debug=True) + brevitasToTrueQuant(model, 
sampleInput, debug=True) diff --git a/Tests/TestLinear.py b/Tests/TestLinear.py index 675653f..e4c17b5 100644 --- a/Tests/TestLinear.py +++ b/Tests/TestLinear.py @@ -4,33 +4,28 @@ # # Federico Brancasi - +import brevitas.nn as qnn import pytest - -### PyTorch Imports ### import torch import torch.nn as nn - -### Brevitas Import ### -import brevitas.nn as qnn from brevitas.quant.scaled_int import ( Int8ActPerTensorFloat, - Int32Bias, Int8WeightPerTensorFloat, + Int32Bias, ) -from DeepQuant.ExportBrevitas import exportBrevitas + +from DeepQuant import brevitasToTrueQuant class QuantLinearNet(nn.Module): + """Simple quantized network with a single linear layer.""" - def __init__(self, in_features: int = 16, hidden_features: int = 32) -> None: + def __init__(self, inFeatures: int = 16, hiddenFeatures: int = 32) -> None: super().__init__() - self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True) - self.linear1 = qnn.QuantLinear( - in_features=in_features, - out_features=hidden_features, + in_features=inFeatures, + out_features=hiddenFeatures, bias=True, weight_bit_width=4, bias_quant=Int32Bias, @@ -41,19 +36,14 @@ def __init__(self, in_features: int = 16, hidden_features: int = 32) -> None: ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.inputQuant(x) x = self.linear1(x) - return x @pytest.mark.SingleLayerTests def deepQuantTestLinear() -> None: - torch.manual_seed(42) - model = QuantLinearNet().eval() sampleInput = torch.randn(1, 4, 16) - - exportBrevitas(model, sampleInput, debug=True) + brevitasToTrueQuant(model, sampleInput, debug=True) diff --git a/Tests/TestMHSA.py b/Tests/TestMHSA.py index d5be3a9..448057c 100644 --- a/Tests/TestMHSA.py +++ b/Tests/TestMHSA.py @@ -4,39 +4,34 @@ # # Federico Brancasi - +import brevitas.nn as qnn import pytest import torch import torch.nn as nn -import brevitas.nn as qnn -from torch import Tensor -from DeepQuant.ExportBrevitas import exportBrevitas - from brevitas.quant.scaled_int import ( 
Int8ActPerTensorFloat, - Int32Bias, Int8WeightPerTensorFloat, + Int32Bias, Uint8ActPerTensorFloat, ) +from torch import Tensor + +from DeepQuant import brevitasToTrueQuant class QuantMHSANet(nn.Module): + """Simple quantized network with multi-head self-attention.""" - def __init__(self, embed_dim: int, num_heads: int) -> None: - """ - Args: - embed_dim: The dimension of each embedding vector. - num_heads: The number of attention heads. - """ + def __init__(self, embedDim: int, numHeads: int) -> None: super().__init__() self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True) self.mha = qnn.QuantMultiheadAttention( - embed_dim=embed_dim, - num_heads=num_heads, + embed_dim=embedDim, + num_heads=numHeads, dropout=0.0, bias=True, - packed_in_proj=False, # separate Q, K, V - batch_first=False, # expects (sequence, batch, embed_dim) + packed_in_proj=False, # FBRANCASI: separate Q, K, V + batch_first=False, # FBRANCASI: expects (sequence, batch, embed_dim) in_proj_input_quant=Int8ActPerTensorFloat, in_proj_weight_quant=Int8WeightPerTensorFloat, in_proj_bias_quant=Int32Bias, @@ -51,16 +46,6 @@ def __init__(self, embed_dim: int, num_heads: int) -> None: ) def forward(self, x: Tensor) -> Tensor: - """ - Forward pass that first quantizes the input, then applies multi-head attention. - - Args: - x: Input tensor of shape [sequence_len, batch_size, embed_dim]. - - Returns: - A tuple (output, None) as per the Brevitas MHA API, where output has shape - [sequence_len, batch_size, embed_dim]. 
- """ x = self.inputQuant(x) out = self.mha(x, x, x) return out @@ -68,10 +53,7 @@ def forward(self, x: Tensor) -> Tensor: @pytest.mark.SingleLayerTests def deepQuantTestMHSA() -> None: - torch.manual_seed(42) - - model = QuantMHSANet(embed_dim=16, num_heads=4).eval() + model = QuantMHSANet(embedDim=16, numHeads=4).eval() sampleInput = torch.randn(10, 2, 16) - - exportBrevitas(model, sampleInput, debug=True) + brevitasToTrueQuant(model, sampleInput) diff --git a/Tests/TestMobileNetV3Small.py b/Tests/TestMobileNetV3Small.py index 7a36392..308e3fc 100644 --- a/Tests/TestMobileNetV3Small.py +++ b/Tests/TestMobileNetV3Small.py @@ -2,39 +2,30 @@ # Licensed under the Apache License, Version 2.0, see LICENSE for details. # SPDX-License-Identifier: Apache-2.0 # -# Victor Juing +# Victor Jung +import brevitas.nn as qnn import pytest import torch import torch.nn as nn import torchvision.models as models -from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool -import brevitas.nn as qnn +from brevitas.graph.quantize import preprocess_for_quantize, quantize from brevitas.quant import ( Int8ActPerTensorFloat, Int8WeightPerTensorFloat, Int32Bias, Uint8ActPerTensorFloat, ) -from brevitas.graph.quantize import quantize -from DeepQuant.ExportBrevitas import exportBrevitas +from DeepQuant import brevitasToTrueQuant def prepareMBNetV3Model() -> nn.Module: - """ - Prepare a quantized MobileNetV3Small model for testing. - Steps: - 1) Load the torchvision MobileNetV3Small. - 2) Convert it to eval mode. - 3) Preprocess and adapt average pooling. - 4) Quantize it using Brevitas. - - Returns: - A quantized MobileNetV3Small model ready for export tests. 
- """ - baseModel = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1) + """Prepare a quantized MobileNetV3Small model for testing.""" + baseModel = models.mobilenet_v3_small( + weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1 + ) baseModel = baseModel.eval() computeLayerMap = { @@ -99,9 +90,7 @@ def prepareMBNetV3Model() -> nn.Module: baseModel = preprocess_for_quantize( baseModel, equalize_iters=20, equalize_scale_computation="range" ) - baseModel = AdaptiveAvgPoolToAvgPool().apply( - baseModel, torch.ones(1, 3, 224, 224) - ) + baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, torch.ones(1, 3, 224, 224)) quantizedModel = quantize( graph_model=baseModel, @@ -115,10 +104,7 @@ def prepareMBNetV3Model() -> nn.Module: @pytest.mark.ModelTests def deepQuantTestMobileNetV3Small() -> None: - torch.manual_seed(42) - - quantizedModel = prepareMBNetV3Model() + model = prepareMBNetV3Model() sampleInput = torch.randn(1, 3, 224, 224) - - exportBrevitas(quantizedModel, sampleInput, debug=True) + brevitasToTrueQuant(model, sampleInput, debug=True) diff --git a/Tests/TestResNet18.py b/Tests/TestResNet18.py index 3b62a06..1b41ecd 100644 --- a/Tests/TestResNet18.py +++ b/Tests/TestResNet18.py @@ -4,39 +4,98 @@ # # Federico Brancasi +import tarfile +import urllib.request +from pathlib import Path +import brevitas.nn as qnn import pytest import torch import torch.nn as nn -import torchvision.models as models -from brevitas.graph.quantize import preprocess_for_quantize +import torchvision +import torchvision.transforms as transforms +from brevitas.graph.calibrate import calibration_mode from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool -import brevitas.nn as qnn +from brevitas.graph.quantize import preprocess_for_quantize, quantize from brevitas.quant import ( Int8ActPerTensorFloat, Int8WeightPerTensorFloat, Int32Bias, Uint8ActPerTensorFloat, ) -from brevitas.graph.quantize import quantize +from torch.utils.data import 
DataLoader, Subset +from torchvision.datasets import ImageFolder +from tqdm import tqdm + +from DeepQuant import brevitasToTrueQuant + + +def evaluateModel(model, dataLoader, evalDevice, name="Model"): + model.eval() + correctTop1 = 0 + correctTop5 = 0 + total = 0 + + with torch.no_grad(): + for inputs, targets in tqdm(dataLoader, desc=f"Evaluating {name}"): + isTQ = "TQ" in name + + if isTQ: + # FBRANCASI: Process different batches for the TQ model + for i in range(inputs.size(0)): + singleInput = inputs[i : i + 1].to(evalDevice) + singleOutput = model(singleInput) + + _, predicted = singleOutput.max(1) + if predicted.item() == targets[i].item(): + correctTop1 += 1 + + _, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True) + if targets[i].item() in top5Pred[0].cpu().numpy(): + correctTop5 += 1 + + total += 1 + else: + inputs = inputs.to(evalDevice) + targets = targets.to(evalDevice) + output = model(inputs) + + _, predicted = output.max(1) + correctTop1 += (predicted == targets).sum().item() -from DeepQuant.ExportBrevitas import exportBrevitas + _, top5Pred = output.topk(5, dim=1, largest=True, sorted=True) + for i in range(targets.size(0)): + if targets[i] in top5Pred[i]: + correctTop5 += 1 + total += targets.size(0) -def prepareResnet18Model() -> nn.Module: - """ - Prepare a quantized ResNet18 model for testing. - Steps: - 1) Load the torchvision ResNet18. - 2) Convert it to eval mode. - 3) Preprocess and adapt average pooling. - 4) Quantize it using Brevitas. + top1Accuracy = 100.0 * correctTop1 / total + top5Accuracy = 100.0 * correctTop5 / total + + print( + f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), " + f"Top-5 Accuracy: {top5Accuracy:.2f}%" + ) + + return top1Accuracy, top5Accuracy - Returns: - A quantized ResNet18 model ready for export tests. 
- """ - baseModel = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) - baseModel = baseModel.eval() + +def calibrateModel(model, calibLoader): + model.eval() + with torch.no_grad(), calibration_mode(model): + for inputs, _ in tqdm(calibLoader, desc="Calibrating model"): + inputs = inputs.to("cpu") + model(inputs) + print("Calibration completed.") + + +def prepareFQResNet18(): + """Prepare a fake-quantized (FQ) ResNet18 model.""" + baseModel = torchvision.models.resnet18( + weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1 + ) + baseModel = baseModel.eval().to("cpu") computeLayerMap = { nn.Conv2d: ( @@ -97,29 +156,133 @@ def prepareResnet18Model() -> nn.Module: ), } + dummyInput = torch.ones(1, 3, 224, 224).to("cpu") + + print("Preprocessing model for quantization...") baseModel = preprocess_for_quantize( baseModel, equalize_iters=20, equalize_scale_computation="range" ) - baseModel = AdaptiveAvgPoolToAvgPool().apply( - baseModel, torch.ones(1, 3, 224, 224) - ) - quantizedResnet = quantize( + print("Converting AdaptiveAvgPool to AvgPool...") + baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, dummyInput) + + print("Quantizing model...") + FQModel = quantize( graph_model=baseModel, compute_layer_map=computeLayerMap, quant_act_map=quantActMap, quant_identity_map=quantIdentityMap, ) - return quantizedResnet + return FQModel @pytest.mark.ModelTests def deepQuantTestResnet18() -> None: + HOME = Path.home() + BASE = HOME / "Documents" / "ImagenetV2" + TAR_URL = ( + "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/" + "imagenetv2-matched-frequency.tar.gz" + ) + TAR_PATH = BASE / "imagenetv2-matched-frequency.tar.gz" + EXTRACT_DIR = BASE / "imagenetv2-matched-frequency-format-val" - torch.manual_seed(42) + if not TAR_PATH.exists(): + BASE.mkdir(parents=True, exist_ok=True) + print(f"Downloading ImageNetV2 from {TAR_URL}...") + urllib.request.urlretrieve(TAR_URL, TAR_PATH) - quantizedModel = prepareResnet18Model() - sampleInput = 
torch.randn(1, 3, 224, 224) + if not EXTRACT_DIR.exists(): + print(f"Extracting to {EXTRACT_DIR}...") + with tarfile.open(TAR_PATH, "r:*") as tar: + for member in tqdm(tar.getmembers(), desc="Extracting files"): + tar.extract(member, BASE) + print("Extraction completed.") + + transformsVal = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + dataset = ImageFolder(root=str(EXTRACT_DIR), transform=transformsVal) + dataset.classes = sorted(dataset.classes, key=lambda x: int(x)) + dataset.class_to_idx = {cls: i for i, cls in enumerate(dataset.classes)} + + newSamples = [] + for path, _ in dataset.samples: + clsName = Path(path).parent.name + newLabel = dataset.class_to_idx[clsName] + newSamples.append((path, newLabel)) + dataset.samples = newSamples + dataset.targets = [s[1] for s in newSamples] + + # FBRANCASI: Optional, reduce number of examples for faster validation + DATASET_LIMIT = 256 + dataset = Subset(dataset, list(range(DATASET_LIMIT))) + print(f"Validation dataset size set to {len(dataset)} images.") + + calibLoader = DataLoader( + Subset(dataset, list(range(256))), batch_size=32, shuffle=False, pin_memory=True + ) + valLoader = DataLoader(dataset, batch_size=32, shuffle=False, pin_memory=True) + + # FBRANCASI: I'm on mac, so mps for me + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("mps" if torch.backends.mps.is_available() else device) + print(f"Using device: {device}") + + originalModel = torchvision.models.resnet18( + weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1 + ) + originalModel = originalModel.eval().to(device) + print("Original ResNet18 loaded.") + + print("Evaluating original model...") + originalTop1, originalTop5 = evaluateModel( + originalModel, valLoader, device, "Original ResNet18" + ) + + print("Preparing and quantizing ResNet18...") + 
FQModel = prepareFQResNet18() + + print("Calibrating FQ model...") + calibrateModel(FQModel, calibLoader) + + print("Evaluating FQ model...") + # FBRANCASI: I'm on mac, mps doesn't work with brevitas + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ ResNet18") + + sampleInputImg = torch.randn(1, 3, 224, 224).to("cpu") + TQModel = brevitasToTrueQuant(FQModel, sampleInputImg, debug=True) + + numParameters = sum(p.numel() for p in TQModel.parameters()) + print(f"Number of parameters: {numParameters:,}") + + print("Evaluating TQ model...") + TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ ResNet18") + + print("\nComparison Summary:") + print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}") + print("-" * 75) + print(f"{'Original ResNet18':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}") + print(f"{'FQ ResNet18':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}") + print(f"{'TQ ResNet18':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}") + print( + f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}" + ) + print( + f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}" + ) - exportBrevitas(quantizedModel, sampleInput, debug=True) + if abs(FQTop1 - TQTop1) > 5.0 or abs(FQTop5 - TQTop5) > 5.0: + print( + f"Warning: Large accuracy drop between FQ and TQ models. 
" + f"Top-1 difference: {abs(FQTop1 - TQTop1):.2f}%, " + f"Top-5 difference: {abs(FQTop5 - TQTop5):.2f}%" + ) diff --git a/Tests/TestSimpleCNN.py b/Tests/TestSimpleCNN.py index bc755ec..23738c5 100644 --- a/Tests/TestSimpleCNN.py +++ b/Tests/TestSimpleCNN.py @@ -4,29 +4,23 @@ # # Federico Brancasi - +import brevitas.nn as qnn import pytest import torch import torch.nn as nn -import brevitas.nn as qnn from brevitas.quant.scaled_int import ( Int8ActPerTensorFloat, - Int32Bias, Int8WeightPerTensorFloat, + Int32Bias, ) -from DeepQuant.ExportBrevitas import exportBrevitas + +from DeepQuant import brevitasToTrueQuant class SimpleQuantCNN(nn.Module): - """ - A simple quantized CNN that includes: - - Input quantization - - Two QuantConv2d layers with Quantized ReLU - - MaxPool2d - - A final QuantLinear layer - """ + """A simple quantized CNN with two conv layers and a linear layer.""" - convAndLinQuantParams = { + convQuantParams = { "bias": True, "weight_bit_width": 4, "bias_quant": Int32Bias, @@ -36,21 +30,16 @@ class SimpleQuantCNN(nn.Module): "return_quant_tensor": True, } - def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None: - """ - Args: - in_channels: Number of input channels (e.g., 1 for grayscale). - num_classes: Number of output classes for the final linear layer. 
- """ + def __init__(self, inChannels: int = 1, numClasses: int = 10) -> None: super().__init__() self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True) self.conv1 = qnn.QuantConv2d( - in_channels=in_channels, + in_channels=inChannels, out_channels=16, kernel_size=3, padding=1, - **SimpleQuantCNN.convAndLinQuantParams + **SimpleQuantCNN.convQuantParams, ) self.relu1 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True) self.pool1 = nn.MaxPool2d(kernel_size=2) @@ -60,28 +49,19 @@ def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None: out_channels=32, kernel_size=3, padding=1, - **SimpleQuantCNN.convAndLinQuantParams + **SimpleQuantCNN.convQuantParams, ) self.relu2 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True) self.pool2 = nn.MaxPool2d(kernel_size=2) self.flatten = nn.Flatten() self.fc = qnn.QuantLinear( - in_features=32 * 7 * 7, # If input is 28x28, shape after pooling is 7x7 - out_features=num_classes, - **SimpleQuantCNN.convAndLinQuantParams + in_features=32 * 7 * 7, + out_features=numClasses, + **SimpleQuantCNN.convQuantParams, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the SimpleQuantCNN. - - Args: - x: Input tensor of shape [batch_size, in_channels, height, width]. - - Returns: - A quantized output tensor (batch_size, num_classes). 
- """ x = self.inputQuant(x) x = self.conv1(x) @@ -99,10 +79,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @pytest.mark.ModelTests def deepQuantTestSimpleCNN() -> None: - torch.manual_seed(42) - model = SimpleQuantCNN().eval() sampleInput = torch.randn(1, 1, 28, 28) - - exportBrevitas(model, sampleInput, debug=True) + brevitasToTrueQuant(model, sampleInput, debug=True) diff --git a/Tests/TestSimpleFCNN.py b/Tests/TestSimpleFCNN.py index 33b90f6..c3c7821 100644 --- a/Tests/TestSimpleFCNN.py +++ b/Tests/TestSimpleFCNN.py @@ -4,39 +4,25 @@ # # Federico Brancasi - -import warnings - -warnings.filterwarnings("ignore", category=UserWarning) -warnings.filterwarnings("ignore", category=UserWarning, message=".*has_cuda.*") -warnings.filterwarnings("ignore", category=UserWarning, message=".*has_cudnn.*") -warnings.filterwarnings("ignore", category=UserWarning, message=".*has_mps.*") -warnings.filterwarnings("ignore", category=UserWarning, message=".*has_mkldnn.*") -warnings.filterwarnings( - "ignore", category=UserWarning, message=".*experimental feature.*" -) -warnings.filterwarnings("ignore", category=UserWarning, message=".*deprecated.*") - from pathlib import Path -from tqdm import tqdm +import brevitas.nn as qnn import torch import torch.nn as nn import torch.optim as optim -from torch.utils.data import DataLoader -from torchvision import datasets, transforms - -import brevitas.nn as qnn -from brevitas.graph.quantize import preprocess_for_quantize, quantize from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.quantize import preprocess_for_quantize, quantize from brevitas.quant import ( Int8ActPerTensorFloat, Int8WeightPerTensorFloat, Int32Bias, Uint8ActPerTensorFloat, ) +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from tqdm import tqdm -from DeepQuant.ExportBrevitas import exportBrevitas +from DeepQuant import brevitasToTrueQuant class SimpleFCNN(nn.Module): @@ -65,7 +51,6 @@ def trainModel( 
epochs: int = 10, learningRate: float = 0.001, ) -> nn.Module: - """Train the model if no saved weights exist.""" if savePath.exists(): print(f"Loading existing model from {savePath}") @@ -89,7 +74,6 @@ def trainModel( print(f"Epoch [{epoch+1}/{epochs}], Loss: {runningLoss/len(trainLoader):.4f}") - # Evaluate model.eval() correct = 0 total = 0 @@ -102,36 +86,35 @@ def trainModel( print(f"Accuracy on the test set: {100 * correct / total:.2f}%") - # Save model torch.save(model.state_dict(), savePath) print(f"Model saved to {savePath}") return model -def calibrate_model( - model: nn.Module, calib_loader: DataLoader, device: torch.device +def calibrateModel( + model: nn.Module, calibLoader: DataLoader, device: torch.device ) -> None: - """Calibrate the quantized model.""" model.eval() model.to(device) with ( torch.no_grad(), calibration_mode(model), - tqdm(calib_loader, desc="Calibrating") as pbar, + tqdm(calibLoader, desc="Calibrating") as pbar, ): for images, _ in pbar: images = images.to(device) images = images.to(torch.float) model(images) + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") EXPORT_FOLDER = Path().cwd() / "Tests" MODEL_PATH = EXPORT_FOLDER / "Models" DATA_PATH = EXPORT_FOLDER / "Data" + def deepQuantTestSimpleFCNN() -> None: - EXPORT_FOLDER.mkdir(parents=True, exist_ok=True) MODEL_PATH.mkdir(parents=True, exist_ok=True) @@ -143,26 +126,21 @@ def deepQuantTestSimpleFCNN() -> None: ] ) - train_dataset = datasets.MNIST( + trainDataset = datasets.MNIST( root=DATA_PATH, train=True, download=True, transform=transform ) - test_dataset = datasets.MNIST( + testDataset = datasets.MNIST( root=DATA_PATH, train=False, download=True, transform=transform ) - trainLoader = DataLoader(train_dataset, batch_size=64, shuffle=True) - testLoader = DataLoader( - test_dataset, batch_size=64, shuffle=False, pin_memory=True - ) + trainLoader = DataLoader(trainDataset, batch_size=64, shuffle=True) + testLoader = DataLoader(testDataset, batch_size=64, 
shuffle=False, pin_memory=True) - # Train or load model - m = SimpleFCNN() - model = trainModel(m, trainLoader, testLoader, MODEL_PATH / "mnist_model.pth") + model = SimpleFCNN() + model = trainModel(model, trainLoader, testLoader, MODEL_PATH / "mnist_model.pth") - # Prepare for quantization model = preprocess_for_quantize(model) - # Quantization configurations computeLayerMap = { nn.Linear: ( qnn.QuantLinear, @@ -208,7 +186,6 @@ def deepQuantTestSimpleFCNN() -> None: ), } - # Quantize and calibrate modelQuant = quantize( model, compute_layer_map=computeLayerMap, @@ -216,11 +193,9 @@ def deepQuantTestSimpleFCNN() -> None: quant_identity_map=quantIdentityMap, ) - calibrate_model(modelQuant, testLoader, DEVICE) + calibrateModel(modelQuant, testLoader, DEVICE) - # Export and transform sampleInput, _ = next(iter(testLoader)) sampleInput = sampleInput[0:1] - print(f"Sample input shape: {sampleInput.shape}") - exportBrevitas(modelQuant, sampleInput.to(DEVICE), debug=True) + brevitasToTrueQuant(modelQuant, sampleInput.to(DEVICE), debug=True) diff --git a/Tests/TestYOLOv5.py b/Tests/TestYOLOv5.py new file mode 100644 index 0000000..7231492 --- /dev/null +++ b/Tests/TestYOLOv5.py @@ -0,0 +1,127 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +import brevitas.nn as qnn +import pytest +import torch +import torch.nn as nn +from brevitas.graph.quantize import preprocess_for_quantize, quantize +from brevitas.quant import ( + Int8ActPerTensorFloat, + Int8WeightPerTensorFloat, + Int32Bias, + Uint8ActPerTensorFloat, +) + +from DeepQuant import brevitasToTrueQuant + + +def prepareYOLOv5Backbone() -> nn.Module: + """Prepare a quantized partial YOLOv5 model for testing.""" + from ultralytics import YOLO + + model = YOLO("Models/yolov5nu.pt") + pytorchModel = model.model + + # FBRANCASI: Just first few layers for simplicity + backbone = pytorchModel.model[0:4] + + computeLayerMap = { + nn.Conv2d: ( + qnn.QuantConv2d, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 4, + }, + ), + nn.Linear: ( + qnn.QuantLinear, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 4, + }, + ), + } + + quantActMap = { + nn.SiLU: ( + qnn.QuantReLU, # FBRANCASI: As a substitute for now + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + nn.ReLU: ( + qnn.QuantReLU, + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + nn.LeakyReLU: ( + qnn.QuantReLU, # FBRANCASI: As a substitute for now + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + quantIdentityMap = { + "signed": ( + qnn.QuantIdentity, + { + "act_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + "unsigned": ( + 
qnn.QuantIdentity, + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + backbone = preprocess_for_quantize( + backbone, equalize_iters=10, equalize_scale_computation="range" + ) + + quantizedModel = quantize( + graph_model=backbone, + compute_layer_map=computeLayerMap, + quant_act_map=quantActMap, + quant_identity_map=quantIdentityMap, + ) + + return quantizedModel + + +@pytest.mark.ModelTests +def deepQuantTestYOLOv5(): + torch.manual_seed(42) + quantizedModel = prepareYOLOv5Backbone() + sampleInput = torch.randn(1, 3, 128, 128) + quantizedModel.eval() + brevitasToTrueQuant(quantizedModel, sampleInput, debug=True) diff --git a/conftest.py b/conftest.py deleted file mode 100644 index 950c05e..0000000 --- a/conftest.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -Pytest configuration file that suppresses specific warnings, including those -related to torch.tensor constant registration in FX tracing. 
-""" - -import warnings - -warnings.filterwarnings("ignore", category=DeprecationWarning) -warnings.filterwarnings("ignore", category=UserWarning, message="Named tensors.*") -warnings.filterwarnings( - "ignore", category=UserWarning, message=".*__torch_function__.*" -) -warnings.filterwarnings( - "ignore", category=UserWarning, message="Was not able to add assertion.*" -) -warnings.filterwarnings( - "ignore", category=UserWarning, message="'has_cuda' is deprecated.*" -) -warnings.filterwarnings( - "ignore", category=UserWarning, message="'has_cudnn' is deprecated.*" -) -warnings.filterwarnings( - "ignore", category=UserWarning, message="'has_mps' is deprecated.*" -) -warnings.filterwarnings( - "ignore", category=UserWarning, message="'has_mkldnn' is deprecated.*" -) diff --git a/pyproject.toml b/pyproject.toml index 0534afe..b64cb21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "onnx", "onnxoptimizer", "onnxruntime", + "ultralytics", ] [tool.setuptools]