@@ -999,6 +999,7 @@ def __init__(
         # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive
         # to a different number of temporal frames.
         self.num_latent_frames_batch_size = 2
+        self.num_sample_frames_batch_size = 8
 
         # We make the minimum height and width of sample for tiling half that of the generally supported
         self.tile_sample_min_height = sample_height // 2
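
The new `num_sample_frames_batch_size = 8` is the pixel-space counterpart of the existing `num_latent_frames_batch_size = 2`: assuming the 4x temporal compression the CogVideoX VAE is generally described with (an assumption, not read from this diff), 8 input frames compress to exactly 2 latent frames, so the two batch sizes stay aligned. A minimal sanity check under that assumption:

```python
# Hypothetical check: 8 sample frames should map to 2 latent frames,
# assuming a temporal compression ratio of 4.
temporal_compression_ratio = 4
num_sample_frames_batch_size = 8
num_latent_frames_batch_size = 2
assert num_sample_frames_batch_size // temporal_compression_ratio == num_latent_frames_batch_size
```
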
@@ -1081,6 +1082,29 @@ def disable_slicing(self) -> None:
         """
         self.use_slicing = False
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
+        frame_batch_size = self.num_sample_frames_batch_size
+        enc = []
+        for i in range(num_frames // frame_batch_size):
+            remaining_frames = num_frames % frame_batch_size
+            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+            end_frame = frame_batch_size * (i + 1) + remaining_frames
+            x_intermediate = x[:, :, start_frame:end_frame]
+            x_intermediate = self.encoder(x_intermediate)
+            if self.quant_conv is not None:
+                x_intermediate = self.quant_conv(x_intermediate)
+            enc.append(x_intermediate)
+
+        self._clear_fake_context_parallel_cache()
+        enc = torch.cat(enc, dim=2)
+
+        return enc
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
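
The chunking arithmetic in `_encode` above is subtle: when `num_frames` is not an exact multiple of `frame_batch_size`, the first chunk absorbs the remainder and every later chunk is shifted by it. A standalone sketch with hypothetical values (49 frames, batches of 8) makes the boundaries visible:

```python
# Replays only the index arithmetic from `_encode`; the values are illustrative.
num_frames = 49
frame_batch_size = 8

chunks = []
for i in range(num_frames // frame_batch_size):
    remaining_frames = num_frames % frame_batch_size  # 1 in this example
    start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
    end_frame = frame_batch_size * (i + 1) + remaining_frames
    chunks.append((start_frame, end_frame))

print(chunks)
# [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]
# The first chunk holds 9 frames, the rest hold 8; together they cover
# [0, 49) exactly once, so no frame is dropped or encoded twice.
```
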
@@ -1094,13 +1118,17 @@ def encode(
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
 
         Returns:
-                The latent representations of the encoded images. If `return_dict` is True, a
+                The latent representations of the encoded videos. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self.encoder(x)
-        if self.quant_conv is not None:
-            h = self.quant_conv(h)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
         posterior = DiagonalGaussianDistribution(h)
+
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
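
Taken together, the new `encode` path means callers get the memory-friendly behavior just by toggling the existing switches. A usage sketch, assuming this diff belongs to the `AutoencoderKLCogVideoX` class in diffusers (the checkpoint id and tensor shape below are illustrative):

```python
import torch
from diffusers import AutoencoderKLCogVideoX

# Illustrative checkpoint; any CogVideoX VAE checkpoint should behave the same way.
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
).to("cuda")
vae.enable_slicing()  # encode batch elements one at a time (the new branch above)
vae.enable_tiling()   # route large frames through tiled_encode (added below)

# [batch, channels, frames, height, width]
video = torch.randn(2, 3, 49, 480, 720, dtype=torch.float16, device="cuda")
with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
```
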
@@ -1172,6 +1200,75 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
         )
         return b
 
+    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.Tensor`): Input batch of videos.
+
+        Returns:
+            `torch.Tensor`:
+                The latent representation of the encoded videos.
+        """
+        # For a rough memory estimate, take a look at the `tiled_decode` method.
+        batch_size, num_channels, num_frames, height, width = x.shape
+
+        overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
+        overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
+        blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
+        blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
+        row_limit_height = self.tile_latent_min_height - blend_extent_height
+        row_limit_width = self.tile_latent_min_width - blend_extent_width
+        frame_batch_size = self.num_sample_frames_batch_size
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, overlap_height):
+            row = []
+            for j in range(0, width, overlap_width):
+                time = []
+                for k in range(num_frames // frame_batch_size):
+                    remaining_frames = num_frames % frame_batch_size
+                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                    end_frame = frame_batch_size * (k + 1) + remaining_frames
+                    tile = x[
+                        :,
+                        :,
+                        start_frame:end_frame,
+                        i : i + self.tile_sample_min_height,
+                        j : j + self.tile_sample_min_width,
+                    ]
+                    tile = self.encoder(tile)
+                    if self.quant_conv is not None:
+                        tile = self.quant_conv(tile)
+                    time.append(tile)
+                self._clear_fake_context_parallel_cache()
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        enc = torch.cat(result_rows, dim=3)
+        return enc
+
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images using a tiled decoder.
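
For reference, the `blend_v`/`blend_h` helpers that `tiled_encode` calls perform a linear crossfade over the latent overlap region; a paraphrased sketch of the horizontal case (the real method operates on `b` in place and returns it):

```python
import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Crossfade the right edge of tile `a` into the left edge of tile `b`
    # along the width axis (dim 4 of a [B, C, T, H, W] tensor).
    blend_extent = min(a.shape[4], b.shape[4], blend_extent)
    for x in range(blend_extent):
        weight = x / blend_extent  # 0 -> all `a`, approaching 1 -> all `b`
        b[:, :, :, :, x] = a[:, :, :, :, x - blend_extent] * (1 - weight) + b[:, :, :, :, x] * weight
    return b
```

`blend_v` is the same idea along the height axis (dim 3), which is why `tiled_encode` can crop each blended tile to `row_limit_height` x `row_limit_width` and concatenate the rows without visible seams.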