diff --git a/scripts/ipadapter/plugable_ipadapter.py b/scripts/ipadapter/plugable_ipadapter.py index 72c0e6652..012bf53fc 100644 --- a/scripts/ipadapter/plugable_ipadapter.py +++ b/scripts/ipadapter/plugable_ipadapter.py @@ -193,8 +193,14 @@ def apply_effective_region_mask(self, out: torch.Tensor) -> torch.Tensor: assert ( factor > 0 ), f"{factor}, {sequence_length}, {self.latent_width}, {self.latent_height}" - mask_h = int(self.latent_height * factor) - mask_w = int(self.latent_width * factor) + mask_h = round(self.latent_height * factor) + mask_w = round(self.latent_width * factor) + + # Ensure mask_h * mask_w equals sequence_length + if mask_h * mask_w < sequence_length: + mask_w += 1 + elif mask_h * mask_w > sequence_length: + mask_w -= 1 mask = torch.nn.functional.interpolate( self.effective_region_mask.to(out.device),