Commit 8b056d8

per-token scaling for fp8 linear (#160)
* per-token scaling for fp8 linear

* Update diffsynth_engine/utils/fp8_linear.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 73e9179 commit 8b056d8

diffsynth_engine/utils/fp8_linear.py

Lines changed: 19 additions & 25 deletions
@@ -72,39 +72,33 @@ def fp8_linear(
 ) -> torch.Tensor:
     device = input.device
     origin_dtype = input.dtype
-    scale_a = 1.0
+    origin_shape = input.shape
+    input = input.reshape(-1, origin_shape[-1])
+
+    x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
+    fp8_max = 448.0
     # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
     # To avoid overflow and ensure numerical compatibility during FP8 computation,
     # we scale down the input by 2.0 in advance.
     # This scaling will be compensated later during the final result scaling.
     if DTYPE_FP8 == torch.float8_e4m3fnuz:
-        scale_a = 2.0
-        input = input / scale_a
+        fp8_max = fp8_max / 2.0
+    scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
+    scale_b = torch.ones((weight.shape[0], 1)).float().to(device=device)
+    input = input / scale_a
     input = input.to(DTYPE_FP8)
     weight = weight.to(DTYPE_FP8)
 
-    if len(input.shape) > 2:
-        origin_shape = input.shape
-        input = input.reshape(-1, origin_shape[-1])
-        result = torch._scaled_mm(
-            input,
-            weight.T,
-            scale_a=torch.tensor(scale_a).to(device=device),
-            scale_b=torch.tensor(1.0).to(device=device),
-            bias=bias,
-            out_dtype=origin_dtype,
-        )
-        new_shape = origin_shape[:-1] + result.shape[-1:]
-        result = result.reshape(new_shape)
-    else:
-        result = torch._scaled_mm(
-            input,
-            weight.T,
-            scale_a=torch.tensor(scale_a).to(device=device),
-            scale_b=torch.tensor(1.0).to(device=device),
-            bias=bias,
-            out_dtype=origin_dtype,
-        )
+    result = torch._scaled_mm(
+        input,
+        weight.T,
+        scale_a=scale_a,
+        scale_b=scale_b.T,
+        bias=bias,
+        out_dtype=origin_dtype,
+    )
+    new_shape = origin_shape[:-1] + result.shape[-1:]
+    result = result.reshape(new_shape)
     return result
 
 F.linear = fp8_linear
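
For context, the new path replaces the old global scalar scale with a dynamic per-token (row-wise) scale: each token's absolute max decides how far that row is scaled down before FP8 quantization, and torch._scaled_mm folds the scales back in during the matmul. Below is a minimal, self-contained sketch of the same idea (not the repository's code): it assumes PyTorch 2.5+ with row-wise scale support in torch._scaled_mm and a CUDA device that supports float8_e4m3fn; fp8_linear_sketch and x are illustrative names.

    import torch

    def fp8_linear_sketch(x: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
        out_dtype = x.dtype
        orig_shape = x.shape
        x = x.reshape(-1, orig_shape[-1])  # collapse batch dims: one row per token

        fp8_max = 448.0  # largest finite value of float8_e4m3fn
        amax = x.abs().amax(dim=-1, keepdim=True)               # per-token absolute max
        scale_a = torch.clamp(amax / fp8_max, min=1.0).float()  # only scale rows down, never up
        # Weight is left unscaled in this sketch, so its per-column scales are all 1.
        scale_b = torch.ones(1, weight.shape[0], device=x.device, dtype=torch.float32)

        x_fp8 = (x / scale_a).to(torch.float8_e4m3fn)  # quantize each token into FP8 range
        w_fp8 = weight.to(torch.float8_e4m3fn)

        # _scaled_mm computes x_fp8 @ w_fp8.T, rescaling rows by scale_a and
        # columns by scale_b while accumulating in higher precision.
        y = torch._scaled_mm(x_fp8, w_fp8.T,
                             scale_a=scale_a, scale_b=scale_b,
                             bias=bias, out_dtype=out_dtype)
        return y.reshape(*orig_shape[:-1], y.shape[-1])

A quick way to sanity-check a sketch like this is to compare it against torch.nn.functional.linear in bfloat16 on random inputs and confirm the relative error stays small, including on inputs with rows of very different magnitudes, which is the case per-token scaling is meant to handle.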
