@@ -180,7 +180,6 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
180180 ],
181181 dim = 1 ,
182182 )
183- self .rope_cache = {}
184183
185184 # DO NOT USE REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
186185 self .scale_rope = scale_rope
@@ -195,10 +194,20 @@ def rope_params(self, index, dim, theta=10000):
195194 freqs = torch .polar (torch .ones_like (freqs ), freqs )
196195 return freqs
197196
198- def forward (self , video_fhw , txt_seq_lens , device ):
197+ def forward (
198+ self ,
199+ video_fhw : Union [Tuple [int , int , int ], List [Tuple [int , int , int ]]],
200+ txt_seq_lens : List [int ],
201+ device : torch .device ,
202+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
199203 """
200- Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
201- txt_length: [bs] a list of 1 integers representing the length of the text
204+ Args:
205+ video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
206+ A list of 3 integers [frame, height, width] representing the shape of the video.
207+ txt_seq_lens (`List[int]`):
208+ A list of integers of length batch_size representing the length of each text prompt.
209+ device (`torch.device`):
210+ The device on which to perform the RoPE computation.
202211 """
203212 if self .pos_freqs .device != device :
204213 self .pos_freqs = self .pos_freqs .to (device )
@@ -213,14 +222,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
213222 max_vid_index = 0
214223 for idx , fhw in enumerate (video_fhw ):
215224 frame , height , width = fhw
216- rope_key = f"{ idx } _{ height } _{ width } "
217-
218- if not torch .compiler .is_compiling ():
219- if rope_key not in self .rope_cache :
220- self .rope_cache [rope_key ] = self ._compute_video_freqs (frame , height , width , idx )
221- video_freq = self .rope_cache [rope_key ]
222- else :
223- video_freq = self ._compute_video_freqs (frame , height , width , idx )
225+ # RoPE frequencies are cached via an lru_cache decorator on _compute_video_freqs
226+ video_freq = self ._compute_video_freqs (frame , height , width , idx )
224227 video_freq = video_freq .to (device )
225228 vid_freqs .append (video_freq )
226229
@@ -235,8 +238,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
235238
236239 return vid_freqs , txt_freqs
237240
238- @functools .lru_cache (maxsize = None )
239- def _compute_video_freqs (self , frame , height , width , idx = 0 ) :
241+ @functools .lru_cache (maxsize = 128 )
242+ def _compute_video_freqs (self , frame : int , height : int , width : int , idx : int = 0 ) -> torch . Tensor :
240243 seq_lens = frame * height * width
241244 freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
242245 freqs_neg = self .neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
0 commit comments