
Commit 570afc2

phi 4 multimodal training version 1 (with limitations) (#1555)
1 parent 07151d5 · commit 570afc2

File tree

11 files changed (+269, -11 lines)
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
# Phi-4-multimodal-instruct

Configs for the Phi-4-multimodal-instruct 5.6B model. See https://huggingface.co/microsoft/Phi-4-multimodal-instruct

This is a multimodal model that combines text, visual, and audio inputs.
It uses a "Mixture of LoRAs" approach, allowing you to plug in adapters for each
modality without needing to retrain the base model. For more information, consider
reading the following:

- [Mixture-of-LoRAs](https://arxiv.org/abs/2403.03432)
- [Phi-4 Multimodal Technical Report](https://arxiv.org/abs/2503.01743)
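
The README above describes the model's "Mixture of LoRAs" design at a high level. As a quick sanity check outside of these configs, the base checkpoint can be loaded directly with Hugging Face transformers; the sketch below is editorial (not part of this commit) and assumes a CUDA GPU with flash-attn installed, matching the requirements stated in the configs that follow.

# Minimal sketch (not from this commit): load the base checkpoint targeted by these configs.
# Assumes a CUDA GPU and that `flash-attn` is installed, as the configs below require.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model_name = "microsoft/Phi-4-multimodal-instruct"
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # The model requires Flash Attention 2.
).to("cuda")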
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
# Phi-4 multimodal 5.6B full fine-tune training job config.
#
# Requirements:
#   - Set up SkyPilot GCP: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html#setup
#   - Log into WandB (`wandb login`) or disable `enable_wandb`
#
# Usage:
#   oumi launch up --config configs/recipes/vision/phi4/sft/gcp_job.yaml --cluster phi4-multimodal
#
# See Also:
#   - Documentation: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html
#   - Config class: oumi.core.configs.JobConfig
#   - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/job_config.py
#   - Other job configs: configs/**/*job.yaml

name: phi4-mm-sft-oumi-train

resources:
  cloud: gcp
  accelerators: "A100:1"  # Feel free to bump up the number of GPUs!
  use_spot: false
  disk_size: 1000  # Disk size in GBs

num_nodes: 1  # Set it to N for multi-node training.

working_dir: .

file_mounts:
  ~/.netrc: ~/.netrc  # WandB credentials
  ~/.cache/huggingface/token: ~/.cache/huggingface/token  # HF credentials

envs:
  WANDB_PROJECT: oumi-train
  OUMI_RUN_NAME: phi4-vl.fft.oumi

setup: |
  set -e

  pip install uv && uv pip install oumi[gpu] hf_transfer
  # Download the model from HF Hub. `hf_transfer` increases download speed compared to
  # downloading the model during training.
  HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download microsoft/Phi-4-multimodal-instruct

  # The model requires flash_attention_2! Install it here.
  pip install -U flash-attn --no-build-isolation


run: |
  set -e  # Exit if any command failed.
  source ./configs/examples/misc/sky_init.sh
  set -x
  oumi distributed torchrun \
      -m oumi train \
      -c configs/recipes/vision/phi4/sft/train.yaml \
      --training.run_name "${OUMI_RUN_NAME}.${SKYPILOT_TASK_ID}" \
      --training.max_steps 25 \
      --training.save_steps 0 \
      --training.save_final_model false

  echo "Node ${SKYPILOT_NODE_RANK} is all done!"
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# Phi-4-multimodal-instruct training config for SFT finetuning.
#
# Phi-4-multimodal-instruct is a multimodal model that combines text, visual, and audio
# inputs. It uses a "Mixture of LoRAs" approach, allowing you to plug in adapters for
# each modality without needing to retrain the base model.
#
# Important Note: Oumi has currently integrated and fully tested Phi-4 for vision and
# text modalities only (!).
#
# Requirements:
#   - Log into WandB (`wandb login`) or disable `enable_wandb`
#   - Run `pip install -U flash-attn --no-build-isolation`
#
# Usage:
#   oumi train -c configs/recipes/vision/phi4/sft/train.yaml
#
# See Also:
#   - Documentation: https://oumi.ai/docs/en/latest/user_guides/train/train.html
#   - Config class: oumi.core.configs.TrainingConfig
#   - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/training_config.py
#   - Other training configs: configs/**/pretraining/, configs/**/sft/, configs/**/dpo/

model:
  model_name: "microsoft/Phi-4-multimodal-instruct"
  torch_dtype_str: "bfloat16"
  model_max_length: 4096
  trust_remote_code: True
  attn_implementation: "flash_attention_2"  # The model requires Flash Attention.

  # The model by default freezes the following audio/image-related modules:
  #   model.embed_tokens_extend.audio_embed
  #   model.embed_tokens_extend.image_embed

data:
  train:
    collator_name: "vision_language_with_padding"
    use_torchdata: true
    datasets:
      - dataset_name: "merve/vqav2-small"
        split: "validation"
        shuffle: True
        seed: 42
        trust_remote_code: True
        transform_num_workers: "auto"
        dataset_kwargs:
          processor_name: "microsoft/Phi-4-multimodal-instruct"
          return_tensors: True

      # Below are examples of other vision SFT datasets:
      # - dataset_name: "HuggingFaceH4/llava-instruct-mix-vsft"
      #   split: "train"
      #   shuffle: True
      #   seed: 42
      #   trust_remote_code: True
      #   transform_num_workers: "auto"
      #   dataset_kwargs:
      #     processor_name: "microsoft/Phi-4-multimodal-instruct"
      #     return_tensors: True
      # - dataset_name: "coco_captions"
      #   split: "train"
      #   trust_remote_code: True
      #   dataset_kwargs:
      #     processor_name: "microsoft/Phi-4-multimodal-instruct"
      #     return_tensors: True
      # - dataset_name: vision_language_jsonl
      #   dataset_path: "training.jsonl"  # See the notebook for an example of how to generate this file.
      #   dataset_kwargs:
      #     data_column: "messages"
      #     processor_name: "microsoft/Phi-4-multimodal-instruct"

training:
  output_dir: "output/vlm_finetuned"
  trainer_type: "TRL_SFT"
  enable_gradient_checkpointing: True
  per_device_train_batch_size: 1  # Due to the processor's handling of variable-sized image features.
  gradient_accumulation_steps: 8
  max_steps: 20

  gradient_checkpointing_kwargs:
    # Reentrant docs: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
    use_reentrant: False
  ddp_find_unused_parameters: True

  optimizer: "adamw_torch_fused"
  learning_rate: 2e-5
  warmup_ratio: 0.03
  weight_decay: 0.0
  lr_scheduler_type: "cosine"

  logging_steps: 10
  dataloader_main_process_only: True
  dataloader_num_workers: 4
  dataloader_prefetch_factor: 8
  include_performance_metrics: True
  enable_wandb: True
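
The config's note about frozen audio/image modules can be verified empirically. The following is a hedged, editorial sketch (not part of this commit) that groups parameter counts by module prefix so you can confirm that model.embed_tokens_extend.audio_embed and model.embed_tokens_extend.image_embed contribute no trainable parameters; it assumes a model object loaded as in the earlier sketch.

# Sketch only: summarize trainable vs. frozen parameters per module prefix.
from collections import defaultdict

import torch.nn as nn


def summarize_frozen_modules(model: nn.Module, depth: int = 3) -> None:
    trainable: dict[str, int] = defaultdict(int)
    frozen: dict[str, int] = defaultdict(int)
    for name, param in model.named_parameters():
        # E.g. prefix "model.embed_tokens_extend.audio_embed" at depth=3.
        prefix = ".".join(name.split(".")[:depth])
        (trainable if param.requires_grad else frozen)[prefix] += param.numel()
    for prefix in sorted(set(trainable) | set(frozen)):
        print(f"{prefix}: trainable={trainable[prefix]:,} frozen={frozen[prefix]:,}")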

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ dependencies = [
     "aiohttp>=3.10,<3.12", # Used by inference engine
     "aiofiles>=22.1.0,<25", # Allows to use async file operations
     "aioresponses>=0.7.6,<0.8", # User by inference engine tests
+    "backoff>=2.2.1,<2.3",
     "datasets>=3.2.0,<3.3",
     "jsonlines",
     "lm_eval[wandb]>=0.4.5,<0.5.0",

src/oumi/builders/collators.py

Lines changed: 3 additions & 0 deletions
@@ -161,6 +161,9 @@ def build_collator_from_config(
         collator_kwargs["allow_multi_image_inputs"] = (
             model_config.visual_config.supports_multiple_images
         )
+        collator_kwargs["main_image_feature"] = (
+            model_config.visual_config.main_image_feature
+        )

     if collator_name == "vision_language_sft":
         processor_name = collator_kwargs.get(

src/oumi/builders/processors.py

Lines changed: 3 additions & 0 deletions
@@ -50,9 +50,11 @@ def build_processor(

     # Initialize model-specific params.
     label_ignore_index: Optional[int] = constants.LABEL_IGNORE_INDEX
+    ignore_features: Optional[list[str]] = None
     processor_kwargs = {}
     if model_config is not None:
         label_ignore_index = model_config.label_ignore_index
+        ignore_features = model_config.ignore_features
         processor_kwargs.update(model_config.processor_kwargs)

     create_processor_fn = functools.partial(
@@ -70,4 +72,5 @@ def build_processor(
         worker_processor,
         tokenizer,
         label_ignore_index=label_ignore_index,
+        ignore_features=ignore_features,
     )

src/oumi/core/collators/vision_language_collator_with_padding.py

Lines changed: 11 additions & 9 deletions
@@ -21,8 +21,6 @@
 from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
 from oumi.utils.torch_utils import pad_to_max_dim_and_stack

-_PIXEL_VALUES_KEY = "pixel_values"
-

 class VisionLanguageCollatorWithPadding:
     def __init__(
@@ -33,6 +31,7 @@ def __init__(
         truncation: bool = False,
         label_ignore_index: Optional[int] = None,
         allow_multi_image_inputs: bool = True,
+        main_image_feature: str = "pixel_values",
     ):
         """Custom collator for multi-modal vision-language training.

@@ -45,8 +44,11 @@ def __init__(
             label_ignore_index: If set, then label values of tokens that shouldn't
                 contribute to the loss computation will be replaced by this special value.
             allow_multi_image_inputs: Whether to allow multi-image inputs.
+            main_image_feature: The key to use for fetching the main image data
+                (e.g., raw pixels, patches, etc.) from the input.
         """
         self._allow_multi_image_inputs = allow_multi_image_inputs
+        self._main_image_feature = main_image_feature
         self._text_collator: TextCollatorWithPadding = TextCollatorWithPadding(
             tokenizer=tokenizer,
             max_length=max_length,
@@ -60,7 +62,7 @@ def __init__(
         )

     def __call__(self, batch) -> dict[str, Any]:
-        """Custom collator for multi-modal vision-language training.
+        """Custom collator for multi-modal vision-language training.

         Args:
             batch: List of batch items.
@@ -71,7 +73,7 @@ def __call__(self, batch) -> dict[str, Any]:
         # Collate batch prompts
         collated_batch = self._text_collator(batch)  # type: ignore
         known_input_names: set[str] = set(collated_batch.keys()).union(
-            {_PIXEL_VALUES_KEY}
+            {self._main_image_feature}
         )
         other_input_names: set[str] = set()

@@ -80,12 +82,12 @@ def __call__(self, batch) -> dict[str, Any]:
             # TODO Consider relaxing this constraint: a vision/language model
             # can handle text-only inputs e.g., a follow-up to an answer,
             # or image-only inputs e.g., captioning.
-            if _PIXEL_VALUES_KEY not in item:
+            if self._main_image_feature not in item:
                 raise ValueError(
-                    f"Item doesn't contain '{_PIXEL_VALUES_KEY}' key. "
+                    f"Item doesn't contain '{self._main_image_feature}' key. "
                     f"Available keys: {item.keys()}"
                 )
-            images.append(item[_PIXEL_VALUES_KEY])
+            images.append(item[self._main_image_feature])

             for key in item:
                 if (
@@ -96,10 +98,10 @@ def __call__(self, batch) -> dict[str, Any]:
                     other_input_names.add(key)

         # Collate images.
-        pixel_values = self.collate_images(images)
+        image_input_features = self.collate_images(images)

         # Add images to other inputs.
-        collated_batch[_PIXEL_VALUES_KEY] = pixel_values
+        collated_batch[self._main_image_feature] = image_input_features

         # For other inputs, let's verify they present in all examples and stack them.
         if len(other_input_names) > 0:
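
With main_image_feature now a constructor argument, the same padding collator can serve models whose processors emit image tensors under a key other than pixel_values, as Phi-4 does with input_image_embeds. Below is a hedged construction sketch; the tokenizer and the specific argument values are illustrative assumptions, not taken from this commit.

# Sketch only: build the collator for a model whose main image feature is
# "input_image_embeds" rather than the default "pixel_values".
from oumi.core.collators.vision_language_collator_with_padding import (
    VisionLanguageCollatorWithPadding,
)

collator = VisionLanguageCollatorWithPadding(
    tokenizer=tokenizer,  # assumed: a tokenizer built elsewhere (e.g., for Phi-4)
    max_length=4096,
    truncation=True,
    label_ignore_index=-100,
    allow_multi_image_inputs=False,  # multi-image inputs are not yet enabled for Phi-4
    main_image_feature="input_image_embeds",
)
# `collator(batch)` then pads and stacks "input_image_embeds" instead of "pixel_values".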

src/oumi/core/configs/internal/internal_model_config.py

Lines changed: 9 additions & 0 deletions
@@ -74,6 +74,12 @@ class InternalFeatureSpec(NamedTuple):


 @dataclass
 class InternalVisualModelConfig(BaseConfig):
+    main_image_feature: str = "pixel_values"
+    """The key corresponding to the main image feature consumed by the model.
+
+    E.g., raw pixels, transformed image patches, etc. resulting from data
+    preprocessing and consumed by the underlying model."""
+
     variable_shape_image_features: bool = False
     """Whether image features can be of variable shape.

@@ -133,5 +139,8 @@ class InternalModelConfig(BaseConfig):
     processor_kwargs: dict[str, Any] = field(default_factory=dict)
     """Extra params to pass to processor constructor."""

+    ignore_features: list[str] = field(default_factory=list)
+    """Features from processing the input to ignore in the model's forward method."""
+
     visual_config: Optional[InternalVisualModelConfig] = None
     """Configuration specific to visual models."""

src/oumi/core/configs/internal/supported_models.py

Lines changed: 50 additions & 0 deletions
@@ -200,6 +200,51 @@ def _create_phi3_vlm_config() -> InternalModelConfig:
     return config


+def _create_phi4_vlm_config() -> InternalModelConfig:
+    config = InternalModelConfig()
+    config.chat_template = "phi3-instruct"
+    config.ignore_features = [
+        "audio_attention_mask",  # We won't use audio features.
+        "audio_embed_sizes",
+        "input_audio_embeds",
+    ]
+
+    config.model_input_features.update(
+        {
+            feature_name: InternalFeatureSpec(
+                name=feature_name,
+                required=True,
+                variable_shape=True,
+                image_dependent=True,
+                first_dim_action=InternalFeatureFirstDimAction.DROP_IF_DUMMY,
+            )
+            for feature_name in (
+                "input_image_embeds",
+                "image_attention_mask",
+            )
+        }
+    )
+    config.model_input_features.update(
+        {
+            feature_name: InternalFeatureSpec(
+                name=feature_name,
+                required=True,
+                variable_shape=False,
+                image_dependent=True,
+            )
+            for feature_name in ("image_sizes",)
+        }
+    )
+    visual_config = InternalVisualModelConfig()
+    # FIXME OPE-355 Set to True once multi-image issues are resolved for the model.
+    visual_config.supports_multiple_images = False
+    visual_config.variable_shape_image_features = True
+    visual_config.main_image_feature = "input_image_embeds"
+
+    config.visual_config = visual_config
+    return config
+
+
 def _create_idefics3_vlm_config() -> InternalModelConfig:
     config = _create_default_vlm_config(
         supports_multiple_images=True, pixel_values_variable_shape=True
@@ -321,6 +366,11 @@ def get_all_models_map() -> (
             tested=True,
             config=_create_phi3_vlm_config(),
         ),
+        _ModelTypeInfo(
+            model_type="phi4mm",
+            model_class=transformers.AutoModelForCausalLM,
+            config=_create_phi4_vlm_config(),
+        ),
     ]

     # Make it immutable.

src/oumi/core/processors/base_processor.py

Lines changed: 6 additions & 0 deletions
@@ -87,6 +87,12 @@ def label_ignore_index(self) -> Optional[int]:
         """Returns a label ignore index."""
         raise NotImplementedError

+    @property
+    @abc.abstractmethod
+    def ignore_features(self) -> list[str]:
+        """Returns a list of keys of features to ignore from feeding the model."""
+        raise NotImplementedError
+
     @abc.abstractmethod
     def __call__(
         self,
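
A concrete processor satisfying the new abstract property only needs to expose the configured keys and drop them before a batch reaches the model's forward method. The sketch below is illustrative; the class and helper names are hypothetical, not Oumi's actual implementation.

# Illustrative sketch; not the actual Oumi processor implementation.
from typing import Any, Optional


class ExampleProcessor:
    def __init__(self, ignore_features: Optional[list[str]] = None) -> None:
        self._ignore_features = list(ignore_features or [])

    @property
    def ignore_features(self) -> list[str]:
        """Returns a list of keys of features to ignore from feeding the model."""
        return list(self._ignore_features)

    def filter_features(self, features: dict[str, Any]) -> dict[str, Any]:
        # For Phi-4 this would drop "input_audio_embeds", "audio_embed_sizes",
        # and "audio_attention_mask", which this commit configures as ignored.
        return {k: v for k, v in features.items() if k not in self._ignore_features}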
