Skip to content

Commit 07151d5

Browse files
authored
Support text truncation (max_length) for vision_language_sft collator (#1559)
1 parent a5c541b commit 07151d5

File tree

8 files changed

+646
-8
lines changed

8 files changed

+646
-8
lines changed

src/oumi/core/collators/vision_language_sft_collator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ def __init__(
3030
tokenizer: BaseTokenizer,
3131
processor_name: str,
3232
*,
33-
max_length: Optional[int],
33+
max_length: Optional[int] = None,
3434
truncation: bool = False,
35+
truncation_side: str = "right",
3536
label_ignore_index: Optional[int] = None,
3637
allow_multi_image_inputs: bool = True,
3738
trust_remote_code: bool = False,
@@ -45,6 +46,7 @@ def __init__(
4546
truncation: Whether to truncate long inputs to `max_length`.
4647
If False, the long inputs are preserved as is even if they exceed
4748
`max_length`. Only has effect if `max_length` is specified.
49+
truncation_side: The side to truncate the tokens ("right" or "left").
4850
label_ignore_index: If set, then label values of tokens that shouldn't
4951
contribute to the loss computation will be replaced by
5052
this special value.
@@ -53,10 +55,6 @@ def __init__(
5355
"""
5456
self._allow_multi_image_inputs = allow_multi_image_inputs
5557

56-
# TODO Consider supporting truncation using these params
57-
self._max_length = max_length
58-
self._truncation = truncation
59-
6058
if not processor_name:
6159
raise ValueError("processor_name is required for VisionLanguageSftCollator")
6260

@@ -66,6 +64,9 @@ def __init__(
6664
processor_name=processor_name,
6765
trust_remote_code=trust_remote_code,
6866
return_tensors="pt",
67+
truncation=truncation,
68+
truncation_side=truncation_side,
69+
max_length=max_length,
6970
label_ignore_index=label_ignore_index,
7071
)
7172
)

src/oumi/core/feature_generators/vision_language_conversation_feature_generator.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,14 @@
3737
from oumi.core.types.conversation import (
3838
ContentItem,
3939
Conversation,
40+
Message,
41+
)
42+
from oumi.utils.conversation_utils import (
43+
load_pil_image_from_content_item,
44+
truncate_text_in_content_items,
4045
)
41-
from oumi.utils.conversation_utils import load_pil_image_from_content_item
4246
from oumi.utils.logging import logger
47+
from oumi.utils.str_utils import truncate_text_pieces_to_max_tokens_limit
4348
from oumi.utils.torch_utils import get_first_dim_len
4449

4550

@@ -65,12 +70,24 @@ def __init__(
6570
processor_name: Optional[str] = None,
6671
trust_remote_code: bool = False,
6772
return_tensors: Optional[str] = None,
73+
max_length: Optional[int] = None,
74+
truncation: bool = False,
75+
truncation_side: str = "right",
6876
label_ignore_index: Optional[int] = None,
6977
) -> None:
7078
"""Initializes a new instance of VisionLanguageFeatureProcessor."""
7179
# Importing these here to avoid circular dependencies
7280
from oumi.builders.processors import build_processor
7381

82+
if truncation_side not in ("left", "right"):
83+
raise ValueError(
84+
f"Invalid truncation_side: '{truncation_side}'. "
85+
"Expected 'left' or 'right'."
86+
)
87+
88+
self._max_length: Optional[int] = max_length
89+
self._truncation: bool = truncation
90+
self._truncation_side = truncation_side
7491
self._return_tensors = return_tensors
7592

7693
if tokenizer is None:
@@ -145,6 +162,9 @@ def _prepare_simple_model(
145162
last_text_item: ContentItem = text_turns[-1].text_content_items[-1]
146163

147164
prompt = last_text_item.content or ""
165+
truncated_texts = self._truncate_text_pieces([prompt])
166+
assert len(truncated_texts) == 1
167+
prompt = truncated_texts[0]
148168
image = self._load_image(last_image_item)
149169

150170
return image, prompt
@@ -171,6 +191,8 @@ def _prepare_instruct_model(
171191
f"Unsupported message: {turn.id}. Contains no text and no images."
172192
)
173193

194+
messages = self._truncate_text_in_content_items(messages)
195+
174196
text_prompt = self._processor.apply_chat_template(
175197
messages, add_generation_prompt=False
176198
)
@@ -361,3 +383,37 @@ def transform_conversations(
361383
inputs["labels"] = labels.tolist()
362384

363385
return inputs.data
386+
387+
def _truncate_text_in_content_items(self, messages: list[Message]) -> list[Message]:
388+
"""Truncates text contents in Messages to `max_length` total tokens.
389+
390+
Note that we have to truncate plain texts *before* we apply chat template
391+
as the final processed prompt is generally unsafe to truncate at arbitrary
392+
offset: it may break invariants (e.g., prompt contains `N` images tokens)
393+
leading to runtime errors in processor.
394+
"""
395+
if not (
396+
self._truncation and self._max_length is not None and self._max_length > 0
397+
):
398+
return messages
399+
400+
return truncate_text_in_content_items(
401+
messages,
402+
tokenizer=self._processor.tokenizer,
403+
max_tokens=self._max_length,
404+
truncation_side=self._truncation_side,
405+
)
406+
407+
def _truncate_text_pieces(self, text_pieces: list[str]) -> list[str]:
408+
"""Truncates text pieces to total length not exceeding `max_length`."""
409+
if not (
410+
self._truncation and self._max_length is not None and self._max_length > 0
411+
):
412+
return copy.deepcopy(text_pieces)
413+
414+
return truncate_text_pieces_to_max_tokens_limit(
415+
text_pieces,
416+
tokenizer=self._processor.tokenizer,
417+
max_tokens=self._max_length,
418+
truncation_side=self._truncation_side,
419+
)

src/oumi/core/processors/base_processor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,23 @@ def apply_chat_template(
128128
def save_config(self, output_dir: Union[Path, str]) -> None:
129129
"""Saves processor config to the directory."""
130130
raise NotImplementedError
131+
132+
@abc.abstractmethod
133+
def truncate_text(
134+
self,
135+
text: str,
136+
*,
137+
max_tokens: int,
138+
truncation_side: str = "right",
139+
) -> tuple[str, int]:
140+
"""Truncates text to `max_length` in tokens.
141+
142+
Args:
143+
text: A text prompt.
144+
max_tokens: Maximum number of tokens to keep.
145+
truncation_side: The side to truncate the tokens ("right" or "left").
146+
147+
Returns:
148+
A tuple containing truncated text prompt and the number of tokens.
149+
"""
150+
raise NotImplementedError

src/oumi/core/processors/default_processor.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
2626
from oumi.core.types.conversation import Message
2727
from oumi.utils.logging import logger
28+
from oumi.utils.str_utils import truncate_to_max_tokens_limit
2829

2930

3031
class DefaultProcessor(BaseProcessor):
@@ -54,7 +55,7 @@ def __init__(
5455
and callable(worker_processor.apply_chat_template)
5556
):
5657
raise ValueError(
57-
"Worker processor doesn't have " "the `apply_chat_template` method"
58+
"Worker processor doesn't have the `apply_chat_template` method"
5859
)
5960

6061
self._processor_name = processor_name
@@ -250,3 +251,28 @@ def save_config(self, output_dir: Union[Path, str]) -> None:
250251
return
251252

252253
self._worker_processor.save_pretrained(str(output_dir))
254+
255+
@override
256+
def truncate_text(
257+
self,
258+
text: str,
259+
*,
260+
max_tokens: int,
261+
truncation_side: str = "right",
262+
) -> tuple[str, int]:
263+
"""Truncates text to `max_length` in tokens.
264+
265+
Args:
266+
text: A text prompt.
267+
max_tokens: Maximum number of tokens to keep.
268+
truncation_side: The side to truncate the tokens ("right" or "left").
269+
270+
Returns:
271+
A tuple containing truncated text prompt and the number of tokens.
272+
"""
273+
return truncate_to_max_tokens_limit(
274+
text,
275+
self._tokenizer,
276+
max_tokens=max_tokens,
277+
truncation_side=truncation_side,
278+
)

src/oumi/utils/conversation_utils.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import PIL.Image
1919

20+
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
2021
from oumi.core.types.conversation import ContentItem, Conversation, Message, Type
2122
from oumi.utils.image_utils import (
2223
DEFAULT_IMAGE_MODE,
@@ -26,6 +27,7 @@
2627
load_pil_image_from_path,
2728
load_pil_image_from_url,
2829
)
30+
from oumi.utils.str_utils import truncate_text_pieces_to_max_tokens_limit
2931

3032

3133
def load_image_bytes_to_content_item(
@@ -343,3 +345,91 @@ def remove_excessive_images_from_conversation(
343345
messages=filtered_messages,
344346
metadata=conversation.metadata,
345347
)
348+
349+
350+
def truncate_text_in_content_items(
351+
messages: list[Message],
352+
tokenizer: BaseTokenizer,
353+
*,
354+
max_tokens: int,
355+
truncation_side: str = "right",
356+
) -> list[Message]:
357+
"""Truncates text contents in Messages to `max_length` total tokens.
358+
359+
Note that we have to truncate plain texts *before* we apply chat template
360+
as the final processed prompt is generally unsafe to truncate at arbitrary
361+
offset: it may break invariants (e.g., prompt contains `N` images tokens)
362+
leading to runtime errors in processor.
363+
364+
Args:
365+
messages: A list of messages.
366+
tokenizer: The tokenizer used for encoding the data.
367+
max_tokens: Maximum number of tokens to keep in all text pieces combined.
368+
truncation_side: The side to truncate the tokens ("right" or "left").
369+
370+
Returns:
371+
A list of messages with potentially truncated text prompts.
372+
The returned list contains the same messages as the input list,
373+
except that the text content items may be truncated.
374+
"""
375+
if max_tokens <= 0:
376+
raise ValueError("`max_tokens` must be a positive integer")
377+
elif truncation_side not in ("left", "right"):
378+
raise ValueError(
379+
f"Invalid truncation_side: '{truncation_side}'. Expected 'left' or 'right'."
380+
)
381+
382+
result = [m for m in messages] # shallow copy
383+
384+
text_pieces: list[str] = []
385+
for msg_idx, message in enumerate(result):
386+
for item_idx, item in enumerate(message.content_items):
387+
if item.is_text():
388+
text_pieces.append(item.content or "")
389+
390+
if len(text_pieces) == 0:
391+
return result
392+
393+
truncated_texts = truncate_text_pieces_to_max_tokens_limit(
394+
text_pieces,
395+
tokenizer=tokenizer,
396+
max_tokens=max_tokens,
397+
truncation_side=truncation_side,
398+
)
399+
assert len(text_pieces) == len(truncated_texts)
400+
401+
idx = 0
402+
for msg_idx, message in enumerate(result):
403+
message_truncated = False
404+
items: list[ContentItem] = []
405+
for item_idx, item in enumerate(message.content_items):
406+
if item.is_text():
407+
items.append(
408+
ContentItem(
409+
content=truncated_texts[idx],
410+
type=item.type,
411+
)
412+
)
413+
original_text = item.content or ""
414+
if truncated_texts[idx] != original_text:
415+
message_truncated = True
416+
idx += 1
417+
else:
418+
items.append(item)
419+
420+
if message_truncated:
421+
if (
422+
len(items) == 1
423+
and items[0].is_text()
424+
and isinstance(messages[msg_idx].content, str)
425+
):
426+
assert isinstance(items[0].content, str)
427+
result[msg_idx] = Message(
428+
id=message.id, content=items[0].content, role=message.role
429+
)
430+
else:
431+
result[msg_idx] = Message(
432+
id=message.id, content=items, role=message.role
433+
)
434+
435+
return result

0 commit comments

Comments
 (0)