Commit a647e5a

🗜 Hotfix: avoid passing quantization_config=None (#4019)
1 parent: 816ac61

19 files changed: +110 −64 lines
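
The fix is the same across every script below: build the model kwargs dict without `device_map` or `quantization_config`, then add both keys only when a quantization config actually exists. Per the comment added in each hunk, an explicit `quantization_config=None` is not equivalent to omitting the argument. A minimal, self-contained sketch (a hypothetical function, not the transformers API) of how a key-presence check produces that asymmetry:

# Code that tests for the *presence* of a keyword takes the quantized path
# even when the value passed is None.
def load_model(**kwargs):
    if "quantization_config" in kwargs:  # presence check, not a value check
        return f"quantized path ({kwargs['quantization_config']!r})"
    return "default path"

print(load_model())                          # -> default path
print(load_model(quantization_config=None))  # -> quantized path (None)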

examples/scripts/dpo_vlm.py
Lines changed: 6 additions & 3 deletions

@@ -88,15 +88,18 @@
     # Model & Tokenizer
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)

     model_kwargs = dict(
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
+
     model = AutoModelForImageTextToText.from_pretrained(
         model_args.model_name_or_path,
         trust_remote_code=model_args.trust_remote_code,

examples/scripts/gkd.py
Lines changed: 11 additions & 5 deletions

@@ -81,16 +81,19 @@
     ################
     # Model & Tokenizer
     ################
-    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
         revision=model_args.model_revision,
         trust_remote_code=model_args.trust_remote_code,
         attn_implementation=model_args.attn_implementation,
         dtype=model_args.dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
+
     training_args.model_init_kwargs = model_kwargs

     teacher_model_kwargs = dict(
@@ -99,9 +102,12 @@
         attn_implementation=model_args.attn_implementation,
         dtype=model_args.dtype,
         use_cache=True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        teacher_model_kwargs["device_map"] = get_kbit_device_map()
+        teacher_model_kwargs["quantization_config"] = quantization_config
+
     training_args.teacher_model_init_kwargs = teacher_model_kwargs

     tokenizer = AutoTokenizer.from_pretrained(

examples/scripts/grpo_vlm.py
Lines changed: 5 additions & 3 deletions

@@ -97,14 +97,16 @@
     # Model & Processor
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)
     training_args.model_init_kwargs = dict(
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
+        training_args.model_init_kwargs["quantization_config"] = quantization_config

     ################
     # Dataset

examples/scripts/gspo.py
Lines changed: 5 additions & 3 deletions

@@ -83,14 +83,16 @@
     # Model & Processor
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)
     training_args.model_init_kwargs = dict(
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
+        training_args.model_init_kwargs["quantization_config"] = quantization_config

     ################
     # Dataset

examples/scripts/gspo_vlm.py
Lines changed: 5 additions & 3 deletions

@@ -84,14 +84,16 @@
     # Model & Processor
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)
     training_args.model_init_kwargs = dict(
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
+        training_args.model_init_kwargs["quantization_config"] = quantization_config

     ################
     # Dataset

examples/scripts/mpo_vlm.py
Lines changed: 6 additions & 3 deletions

@@ -72,16 +72,19 @@
     # Model & Processor
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)

     model_kwargs = dict(
         trust_remote_code=model_args.trust_remote_code,
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config
+
     model = AutoModelForImageTextToText.from_pretrained(
         model_args.model_name_or_path,
         **model_kwargs,

examples/scripts/nash_md.py
Lines changed: 5 additions & 3 deletions

@@ -88,15 +88,17 @@
     training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config

     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs

examples/scripts/online_dpo.py
Lines changed: 5 additions & 3 deletions

@@ -84,15 +84,17 @@
     training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config

     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs

examples/scripts/online_dpo_vlm.py
Lines changed: 5 additions & 3 deletions

@@ -115,15 +115,17 @@
     training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config

     # Load the VLM model using correct architecture (from GRPO pattern)
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)

examples/scripts/ppo/ppo.py
Lines changed: 5 additions & 3 deletions

@@ -91,14 +91,16 @@
     # Model & Tokenizer
     ################
     dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
-    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
         revision=model_args.model_revision,
         attn_implementation=model_args.attn_implementation,
         dtype=dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
     )
+    quantization_config = get_quantization_config(model_args)
+    if quantization_config is not None:
+        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
+        model_kwargs["device_map"] = get_kbit_device_map()
+        model_kwargs["quantization_config"] = quantization_config

     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
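
For context, here is a simplified sketch of what the two helpers used throughout these diffs plausibly do, inferred from how the scripts call them: get_quantization_config builds a bitsandbytes config from the CLI flags or returns None when no quantization is requested (the case the hotfix guards against), and get_kbit_device_map pins the model to the current process's device. DummyArgs and the exact config fields are illustrative assumptions, not the TRL source.

from dataclasses import dataclass
from typing import Optional

from accelerate import PartialState
from transformers import BitsAndBytesConfig


@dataclass
class DummyArgs:  # stand-in for the ModelConfig fields used here (assumption)
    load_in_4bit: bool = False
    load_in_8bit: bool = False


def get_quantization_config(model_args) -> Optional[BitsAndBytesConfig]:
    # Build a bitsandbytes config from the flags, or return None when no
    # quantization was requested; callers must then omit the kwarg entirely.
    if model_args.load_in_4bit:
        return BitsAndBytesConfig(load_in_4bit=True)
    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None


def get_kbit_device_map() -> dict:
    # Map the whole model ("") onto this process's local device so each
    # rank loads its own quantized copy.
    return {"": PartialState().local_process_index}


assert get_quantization_config(DummyArgs()) is None  # -> kwargs stay omitted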
