1010from torch import nn
1111from transformers import BatchFeature , PretrainedConfig
1212from transformers .models .cohere2_vision import Cohere2VisionConfig
13+ from transformers .models .cohere2_vision .image_processing_cohere2_vision_fast import ( # noqa: E501
14+ get_optimal_tiled_canvas )
1315from transformers .models .cohere2_vision .processing_cohere2_vision import (
1416 Cohere2VisionProcessor )
1517
@@ -150,14 +152,46 @@ def get_image_size_with_most_features(self) -> ImageSize:
150152 max_patches = image_processor .max_patches
151153 return ImageSize (height = height * max_patches , width = width )
152154
def get_num_patches(
    self,
    *,
    image_width: int,
    image_height: int,
    processor: Optional[Cohere2VisionProcessor] = None,
) -> int:
    """
    Calculate the number of image patches for a given image.

    The count is derived from the image processor's tiling settings and
    ``get_optimal_tiled_canvas`` instead of calling the image processor's
    own ``get_number_of_image_patches``, which the original author notes
    is incorrect.

    Args:
        image_width: Width of the input image.
        image_height: Height of the input image.
        processor: HF processor whose ``image_processor`` supplies the
            tiling settings. When ``None`` (the default), the processor
            returned by ``self.get_hf_processor()`` is used.

    Returns:
        ``1`` when ``crop_to_patches`` is disabled; otherwise the number
        of tiles in the selected canvas, plus one extra patch for the
        thumbnail whenever the canvas holds more than one tile.
    """
    if processor is None:
        processor = self.get_hf_processor()

    image_processor = processor.image_processor

    min_patches = image_processor.min_patches
    max_patches = image_processor.max_patches
    patch_size = image_processor.size
    crop_to_patches = image_processor.crop_to_patches

    # With tiling disabled the image is a single patch.
    if not crop_to_patches:
        return 1

    num_columns, num_rows = get_optimal_tiled_canvas(
        (image_height, image_width),
        (patch_size["height"], patch_size["width"]),
        min_patches,
        max_patches,
    )
    num_patches = num_columns * num_rows
    if num_patches > 1:
        num_patches += 1  # Thumbnail image

    return num_patches
161195
162196
163197class Cohere2VisionDummyInputsBuilder (
@@ -208,6 +242,8 @@ def _call_hf_processor(
208242 # Ensure num_patches is available for proper tensor splitting
209243 if "num_patches" not in processed_outputs and (
210244 images := mm_data .get ("images" )) is not None :
245+ hf_processor = self .info .get_hf_processor (** mm_kwargs )
246+
211247 # Fallback calculation if HF processor didn't provide num_patches
212248 parsed_images = self ._get_data_parser ().parse_mm_data ({
213249 "image" :
@@ -217,8 +253,9 @@ def _call_hf_processor(
217253 num_patches = [
218254 self .info .get_num_patches (
219255 image_width = parsed_images .get_image_size (i ).width ,
220- image_height = parsed_images .get_image_size (i ).height )
221- for i in range (len (parsed_images ))
256+ image_height = parsed_images .get_image_size (i ).height ,
257+ processor = hf_processor ,
258+ ) for i in range (len (parsed_images ))
222259 ]
223260 processed_outputs ["num_patches" ] = torch .tensor (num_patches )
224261
@@ -245,25 +282,25 @@ def _get_prompt_updates(
245282 ) -> Sequence [PromptUpdate ]:
246283 hf_processor = self .info .get_hf_processor (** hf_processor_mm_kwargs )
247284 image_token = hf_processor .image_token
285+ img_tokens_per_tile = int (hf_processor .patch_size ** 2 )
248286 img_line_break_token = hf_processor .img_line_break_token
249287 boi_token = hf_processor .boi_token
250288 eoi_token = hf_processor .eoi_token
251289
252290 def get_replacement (item_idx : int ):
253- images : ImageProcessorItems = mm_items .get ("image" ,
254- ImageProcessorItems )
291+ images = mm_items .get_items ("image" , ImageProcessorItems )
255292 image_size : ImageSize = images .get_image_size (item_idx )
256293
257- num_patches = self .info .get_num_patches (image_size . height ,
258- image_size .width )
259- img_tokens_per_tile = int ( hf_processor . patch_size ** 2 )
260- single_tile_tokens = image_token * img_tokens_per_tile + \
261- img_line_break_token
262- img_string = f" { boi_token } \
263- { single_tile_tokens * num_patches } \
264- { eoi_token } "
294+ num_patches = self .info .get_num_patches (
295+ image_width = image_size .width ,
296+ image_height = image_size . height ,
297+ processor = hf_processor ,
298+ )
299+ patch_tokens = ( image_token * img_tokens_per_tile +
300+ img_line_break_token )
301+ repl = f" { boi_token } { patch_tokens * num_patches } { eoi_token } "
265302
266- return PromptUpdateDetails .select_text (img_string , image_token )
303+ return PromptUpdateDetails .select_text (repl , image_token )
267304
268305 return [
269306 PromptReplacement (
0 commit comments