diff --git a/modules/models/sd35/mmditx.py b/modules/models/sd35/mmditx.py index d558bc333..b1b043a2f 100644 --- a/modules/models/sd35/mmditx.py +++ b/modules/models/sd35/mmditx.py @@ -904,7 +904,10 @@ def forward( hw = x.shape[-2:] # The line below should be unnecessary when full integrated. x = x[:1,:16,:,:] - x = self.x_embedder(x) + self.cropped_pos_embed(hw).to("cuda") + # Workaround for unable to promote FP8 error with FP8 models + x_embed = self.x_embedder(x).to(torch.float32) + pos_embed = self.cropped_pos_embed(hw).to(torch.float32).to("cuda") + x = x_embed + pos_embed c = self.t_embedder(t, dtype=x.dtype) # (N, D) if y is not None: y = self.y_embedder(y) # (N, D)