
Commit 7853bfb

Remove Qwen Image Redundant RoPE Cache (huggingface#12452)
Refactor QwenEmbedRope to only use the LRU cache for RoPE caching
1 parent 23ebbb4 commit 7853bfb


src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 17 additions & 14 deletions
@@ -180,7 +180,6 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             ],
             dim=1,
         )
-        self.rope_cache = {}
 
         # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
         self.scale_rope = scale_rope
@@ -195,10 +194,20 @@ def rope_params(self, index, dim, theta=10000):
         freqs = torch.polar(torch.ones_like(freqs), freqs)
         return freqs
 
-    def forward(self, video_fhw, txt_seq_lens, device):
+    def forward(
+        self,
+        video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
+        txt_seq_lens: List[int],
+        device: torch.device,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
-            txt_length: [bs] a list of 1 integers representing the length of the text
+        Args:
+            video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
+                A list of 3 integers [frame, height, width] representing the shape of the video.
+            txt_seq_lens (`List[int]`):
+                A list of integers of length batch_size representing the length of each text prompt.
+            device: (`torch.device`):
+                The device on which to perform the RoPE computation.
         """
         if self.pos_freqs.device != device:
             self.pos_freqs = self.pos_freqs.to(device)
@@ -213,14 +222,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
         max_vid_index = 0
         for idx, fhw in enumerate(video_fhw):
             frame, height, width = fhw
-            rope_key = f"{idx}_{height}_{width}"
-
-            if not torch.compiler.is_compiling():
-                if rope_key not in self.rope_cache:
-                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
-                video_freq = self.rope_cache[rope_key]
-            else:
-                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
+            video_freq = self._compute_video_freqs(frame, height, width, idx)
             video_freq = video_freq.to(device)
             vid_freqs.append(video_freq)
 
@@ -235,8 +238,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
 
         return vid_freqs, txt_freqs
 
-    @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width, idx=0):
+    @functools.lru_cache(maxsize=128)
+    def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
         seq_lens = frame * height * width
         freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
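
For readers unfamiliar with the pattern this commit settles on, the sketch below isolates it: caching per-resolution RoPE frequency tables by decorating the compute method with functools.lru_cache instead of keeping a hand-rolled dict keyed on resolution. The class and the frequency math are hypothetical stand-ins (RopeFreqCacheDemo is not the diffusers QwenEmbedRope, and the toy table is not the real _compute_video_freqs); only the caching mechanics mirror the diff: self plus the (frame, height, width) arguments form the cache key, and maxsize=128 bounds the cache rather than letting it grow without limit as maxsize=None would.

```python
import functools

import torch


class RopeFreqCacheDemo:
    """Hypothetical stand-in that shows only the lru_cache-on-a-method pattern."""

    def __init__(self, dim: int = 16, theta: float = 10000.0):
        self.dim = dim
        self.theta = theta

    @functools.lru_cache(maxsize=128)  # bounded cache, keyed on (self, frame, height, width)
    def _compute_video_freqs(self, frame: int, height: int, width: int) -> torch.Tensor:
        # Toy frequency table: one row per flattened (frame, height, width) position.
        positions = torch.arange(frame * height * width, dtype=torch.float32)
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim))
        return torch.outer(positions, inv_freq)


rope = RopeFreqCacheDemo()
first = rope._compute_video_freqs(1, 32, 32)   # computed on the first call
second = rope._compute_video_freqs(1, 32, 32)  # served from the LRU cache
assert torch.equal(first, second)
print(RopeFreqCacheDemo._compute_video_freqs.cache_info())  # CacheInfo(hits=1, misses=1, ...)
```

A caveat that comes with this choice: because self is part of the cache key, the LRU cache holds a strong reference to each instance for as long as its entries live, and the cache (and its cache_info() statistics) is shared across all instances of the class rather than being per-instance.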
