Skip to content

Commit 3ab0cb8

Browse files
committed
Addressing reviews
Signed-off-by: Riyad Islam <[email protected]>
1 parent 56d8ee0 commit 3ab0cb8

File tree

2 files changed

+7
-33
lines changed

2 files changed

+7
-33
lines changed

modelopt/torch/quantization/plugins/pytorch_geometric.py

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
>>> quantized_model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calibrate)
3434
"""
3535

36-
import torch
3736
from torch_geometric.nn.dense.linear import Linear as PyGLinear
3837

3938
from modelopt.torch.quantization.nn.modules.quant_module import (
@@ -48,7 +47,13 @@ class QuantPyGLinear(QuantLinearConvBase):
4847
4948
PyTorch Geometric uses a custom Linear layer that is functionally equivalent to
5049
torch.nn.Linear but has a different API (in_channels/out_channels instead of
51-
in_features/out_features). This class enables quantization of PyG Linear layers.
50+
in_features/out_features). This class enables quantization of PyG Linear layers
51+
by inheriting from QuantLinearConvBase, which handles all quantization logic.
52+
53+
The quantization is handled automatically by the base classes:
54+
- Input quantization: Handled by QuantInputBase.forward()
55+
- Weight quantization: Handled by QuantLinearConvBase's dynamic weight attribute
56+
- Output quantization: Handled by QuantInputBase.forward()
5257
5358
Note:
5459
Many PyTorch Geometric layers (GCNConv, GATConv, SAGEConv, TransformerConv, etc.)
@@ -58,32 +63,5 @@ class QuantPyGLinear(QuantLinearConvBase):
5863

5964
default_quant_desc_weight = QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW
6065

61-
def forward(self, input, *args, **kwargs):
62-
"""Forward pass with quantization.
63-
64-
Args:
65-
input: Input tensor to the linear layer
66-
*args: Additional positional arguments
67-
**kwargs: Additional keyword arguments
68-
69-
Returns:
70-
Quantized output tensor
71-
"""
72-
# Quantize input activations
73-
input_q = self.input_quantizer(input)
74-
75-
# Quantize weights
76-
weight_q = self.weight_quantizer(self.weight)
77-
78-
# Perform linear operation
79-
output = torch.nn.functional.linear(
80-
input_q,
81-
weight_q,
82-
self.bias if hasattr(self, "bias") and self.bias is not None else None,
83-
)
84-
85-
# Quantize output (typically disabled by default)
86-
return self.output_quantizer(output)
87-
8866

8967
QuantModuleRegistry.register({PyGLinear: "torch_geometric.nn.dense.linear.Linear"})(QuantPyGLinear)

tests/unit/torch/quantization/plugins/test_pytorch_geometric_plugin.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,3 @@ def calibrate(m):
183183
mean_relative_error = relative_error.mean().item()
184184

185185
assert mean_relative_error < 0.1, f"Quantization error too large: {mean_relative_error:.2%}"
186-
187-
188-
if __name__ == "__main__":
189-
pytest.main([__file__])

0 commit comments

Comments
 (0)