1 parent 9a38fab commit 11d22e0
src/diffusers/models/transformers/transformer_wan.py
@@ -180,6 +180,7 @@ def __init__(
         added_kv_proj_dim: Optional[int] = None,
         cross_attention_dim_head: Optional[int] = None,
         processor=None,
+        is_cross_attention=None,
     ):
         super().__init__()
@@ -207,6 +208,8 @@ def __init__(
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
             self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
 
+        self.is_cross_attention = cross_attention_dim_head is not None
+
         self.set_processor(processor)
 
     def fuse_projections(self):
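Net effect of the commit: `is_cross_attention` is accepted as a constructor keyword, but the attribute stored on the module is derived from whether `cross_attention_dim_head` was provided, not from the keyword itself. Below is a minimal sketch of that behavior; the `Attention` class and its simplified constructor are stand-ins for illustration, not the actual diffusers `WanAttention` API.

from typing import Optional

import torch


class Attention(torch.nn.Module):
    # Stand-in module mirroring the commit's logic; not the real WanAttention.
    def __init__(
        self,
        dim: int,
        cross_attention_dim_head: Optional[int] = None,
        is_cross_attention=None,  # accepted for compatibility; the flag is derived below
    ):
        super().__init__()
        # Cross-attention is inferred from the presence of a K/V head dim,
        # matching the diff's `cross_attention_dim_head is not None` check.
        self.is_cross_attention = cross_attention_dim_head is not None


self_attn = Attention(dim=64)
cross_attn = Attention(dim=64, cross_attention_dim_head=64)
print(self_attn.is_cross_attention)   # False
print(cross_attn.is_cross_attention)  # True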