3333 >>> quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
3434"""
3535
36- import torch
3736from torch_geometric .nn .dense .linear import Linear as PyGLinear
3837
3938from modelopt .torch .quantization .nn .modules .quant_module import (
@@ -48,7 +47,13 @@ class QuantPyGLinear(QuantLinearConvBase):
4847
4948 PyTorch Geometric uses a custom Linear layer that is functionally equivalent to
5049 torch.nn.Linear but has a different API (in_channels/out_channels instead of
51- in_features/out_features). This class enables quantization of PyG Linear layers.
50+ in_features/out_features). This class enables quantization of PyG Linear layers
51+ by inheriting from QuantLinearConvBase, which handles all quantization logic.
52+
53+ The quantization is handled automatically by the base classes:
54+ - Input quantization: Handled by QuantInputBase.forward()
55+ - Weight quantization: Handled by QuantLinearConvBase's dynamic weight attribute
56+ - Output quantization: Handled by QuantInputBase.forward()
5257
5358 Note:
5459 Many PyTorch Geometric layers (GCNConv, GATConv, SAGEConv, TransformerConv, etc.)
@@ -58,32 +63,5 @@ class QuantPyGLinear(QuantLinearConvBase):
5863
5964 default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
6065
61- def forward (self , input , * args , ** kwargs ):
62- """Forward pass with quantization.
63-
64- Args:
65- input: Input tensor to the linear layer
66- *args: Additional positional arguments
67- **kwargs: Additional keyword arguments
68-
69- Returns:
70- Quantized output tensor
71- """
72- # Quantize input activations
73- input_q = self .input_quantizer (input )
74-
75- # Quantize weights
76- weight_q = self .weight_quantizer (self .weight )
77-
78- # Perform linear operation
79- output = torch .nn .functional .linear (
80- input_q ,
81- weight_q ,
82- self .bias if hasattr (self , "bias" ) and self .bias is not None else None ,
83- )
84-
85- # Quantize output (typically disabled by default)
86- return self .output_quantizer (output )
87-
8866
8967QuantModuleRegistry .register ({PyGLinear : "torch_geometric.nn.dense.linear.Linear" })(QuantPyGLinear )
0 commit comments