@@ -72,39 +72,33 @@ def fp8_linear(
 ) -> torch.Tensor:
     device = input.device
     origin_dtype = input.dtype
-    scale_a = 1.0
+    origin_shape = input.shape
+    input = input.reshape(-1, origin_shape[-1])
+
+    x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
+    fp8_max = 448.0
     # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
     # To avoid overflow and ensure numerical compatibility during FP8 computation,
     # we scale down the input by 2.0 in advance.
     # This scaling will be compensated later during the final result scaling.
     if DTYPE_FP8 == torch.float8_e4m3fnuz:
-        scale_a = 2.0
-        input = input / scale_a
+        fp8_max = fp8_max / 2.0
+    scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
+    scale_b = torch.ones((weight.shape[0], 1)).float().to(device=device)
+    input = input / scale_a
     input = input.to(DTYPE_FP8)
     weight = weight.to(DTYPE_FP8)

-    if len(input.shape) > 2:
-        origin_shape = input.shape
-        input = input.reshape(-1, origin_shape[-1])
-        result = torch._scaled_mm(
-            input,
-            weight.T,
-            scale_a=torch.tensor(scale_a).to(device=device),
-            scale_b=torch.tensor(1.0).to(device=device),
-            bias=bias,
-            out_dtype=origin_dtype,
-        )
-        new_shape = origin_shape[:-1] + result.shape[-1:]
-        result = result.reshape(new_shape)
-    else:
-        result = torch._scaled_mm(
-            input,
-            weight.T,
-            scale_a=torch.tensor(scale_a).to(device=device),
-            scale_b=torch.tensor(1.0).to(device=device),
-            bias=bias,
-            out_dtype=origin_dtype,
-        )
+    result = torch._scaled_mm(
+        input,
+        weight.T,
+        scale_a=scale_a,
+        scale_b=scale_b.T,
+        bias=bias,
+        out_dtype=origin_dtype,
+    )
+    new_shape = origin_shape[:-1] + result.shape[-1:]
+    result = result.reshape(new_shape)
     return result

 F.linear = fp8_linear
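
Below is a minimal, self-contained sketch (not part of the diff above) of the per-row dynamic scaling the new code performs before calling torch._scaled_mm: each row of the flattened activation gets its own amax-derived scale, clamped so rows that already fit in FP8 range are left untouched. The toy tensor x is an illustrative assumption; 448.0 is the fp8_max constant used in the patch for float8_e4m3fn.

import torch

# Toy activation tile; the name and shape are illustrative only.
x = torch.randn(4, 8, dtype=torch.bfloat16)
fp8_max = 448.0  # largest finite value representable in float8_e4m3fn

# Per-row absolute maximum, kept as a column vector so it broadcasts over each row.
x_max = torch.max(torch.abs(x), dim=-1, keepdim=True).values

# clamp(min=1.0): rows whose amax already fits in FP8 range keep scale 1.0;
# only rows that would overflow are compressed down into range.
scale_a = torch.clamp(x_max / fp8_max, min=1.0).float()

x_fp8 = (x / scale_a).to(torch.float8_e4m3fn)   # quantize
x_back = x_fp8.to(torch.bfloat16) * scale_a     # dequantize to check the round-trip error

print(torch.max(torch.abs(x.float() - x_back.float())))

The same per-row scale_a is what the patch passes to torch._scaled_mm, while scale_b stays at 1.0 per output channel because the weight is cast to FP8 without rescaling.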