
Commit 2b22e31

refactor model in intermediate api mode
1 parent bf8047e commit 2b22e31

8 files changed, +172 −2687 lines changed

examples/experiments/auto_parallel/llama/run_pretrain_auto.py

Lines changed: 25 additions & 25 deletions
@@ -24,41 +24,32 @@
 import paddle
 import paddle.distributed as dist
 
+from paddleformers.data.causal_dataset import (
+    build_train_valid_test_datasets,
+    check_data_split,
+    print_rank_0,
+)
 from paddleformers.ops import Topology
 from paddleformers.trainer import PdArgumentParser, get_last_checkpoint
-from paddleformers.trainer.auto_trainer import AutoTrainer
-from paddleformers.trainer.auto_training_args import AutoTrainingArguments
+from paddleformers.trainer.trainer import Trainer
 from paddleformers.trainer.trainer_utils import IntervalStrategy, _get_distributed_seeds
+from paddleformers.trainer.training_args import TrainingArguments
+from paddleformers.trainer.utils.doc import add_start_docstrings
 from paddleformers.transformers import (
     AutoTokenizer,
     CosineAnnealingWithWarmupDecay,
     LinearAnnealingWithWarmupDecay,
     LlamaConfig,
-    LlamaForCausalLM3DAuto,
-    LlamaForCausalLMNet,
-    LlamaPretrainingCriterion3DAuto,
-    LlamaPretrainingCriterionNet,
+    LlamaForCausalLM,
+    LlamaPretrainingCriterion,
 )
 from paddleformers.utils.log import logger
-
-MODEL_CLASSES = {
-    "llama": (LlamaConfig, LlamaForCausalLM3DAuto, LlamaPretrainingCriterion3DAuto),
-    "llama_network": (LlamaConfig, LlamaForCausalLMNet, LlamaPretrainingCriterionNet),
-}
-
-
-from paddleformers.data.causal_dataset import (
-    build_train_valid_test_datasets,
-    check_data_split,
-    print_rank_0,
-)
-from paddleformers.trainer.utils.doc import add_start_docstrings
 from paddleformers.utils.tools import get_env_device
 
 
 @dataclass
-@add_start_docstrings(AutoTrainingArguments.__doc__)
-class PreTrainingArguments(AutoTrainingArguments):
+@add_start_docstrings(TrainingArguments.__doc__)
+class PreTrainingArguments(TrainingArguments):
     min_learning_rate: float = field(
         default=1e-5,
         metadata={"help": "Minimum learning rate deacyed to."},
@@ -338,7 +329,7 @@ def get_train_data_file(args):
     return files
 
 
-class PretrainingTrainer(AutoTrainer):
+class PretrainingTrainer(Trainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.is_pretraining = True
@@ -474,7 +465,9 @@ def main():
             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
         )
 
-    config_class, model_class, criterion_class = MODEL_CLASSES[model_args.model_type]
+    config_class = LlamaConfig
+    model_class = LlamaForCausalLM
+    criterion_class = LlamaPretrainingCriterion
 
     config = config_class.from_pretrained(model_args.model_name_or_path)
     tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
@@ -542,8 +535,6 @@ def main():
         # It's OK, not use accumulate_steps optimization
         pass
 
-    print("Final pre-training config:", config)
-
     if (
         "replace_with_parallel_cross_entropy" in training_args.tensor_parallel_config
        and config.tensor_parallel_degree > 1
@@ -553,6 +544,15 @@ def main():
 
         replace_cross_entropy()
 
+    if training_args.use_intermediate_api:
+        config.run_single_model = True
+        config.tensor_parallel_degree = 1
+        config.sharding_parallel_degree = 1
+        config.sep_parallel_degree = 1
+        config.context_parallel_degree = 1
+
+    print("Final pre-training config:", config)
+
     # # Set the dtype for loading model
     # dtype = "float32"
     # if training_args.fp16_opt_level == "O2":
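
The key behavioural change in this script is the new `use_intermediate_api` branch: the config is degraded to a single-card layout before the network is built, and the intermediate auto-parallel API is expected to re-apply tensor/sequence/pipeline parallelism afterwards from the model's distribution plan. A minimal sketch of that branch in isolation follows; the SimpleNamespace objects stand in for the real TrainingArguments and LlamaConfig instances and are purely illustrative.

from types import SimpleNamespace

# Illustrative stand-ins for the real TrainingArguments / LlamaConfig objects.
training_args = SimpleNamespace(use_intermediate_api=True)
config = SimpleNamespace(
    run_single_model=False,
    tensor_parallel_degree=4,
    sharding_parallel_degree=2,
    sep_parallel_degree=1,
    context_parallel_degree=1,
)

# Mirrors the new branch above: build the network as a plain single-card model
# and leave every parallelism decision to the intermediate API.
if training_args.use_intermediate_api:
    config.run_single_model = True
    config.tensor_parallel_degree = 1
    config.sharding_parallel_degree = 1
    config.sep_parallel_degree = 1
    config.context_parallel_degree = 1

print("Final pre-training config:", config)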

paddleformers/transformers/__init__.py

Lines changed: 0 additions & 35 deletions
@@ -210,41 +210,6 @@
         "LlamaPretrainingCriterion",
         "LlamaNTKScalingRotaryEmbedding",
     ],
-    "llama.modeling_auto": [
-        "enable_fuse_ffn_qkv_pass",
-        "LlamaDecoderLayerAuto",
-        "LlamaAttentionAuto",
-        "LlamaPretrainedModelAuto",
-        "LlamaLMHeadAuto",
-        "LlamaModelAuto",
-        "LlamaForCausalLM3DAuto",
-        "LlamaMLPAuto",
-        "get_mesh",
-        "LlamaRMSNormAuto",
-        "is_pp_enable",
-        "LlamaPretrainingCriterion3DAuto",
-        "global_mesh_starts_with_pp",
-        "scaled_dot_product_attention",
-    ],
-    "llama.modeling_network": [
-        "LlamaPretrainedModelNet",
-        "layer_input_parallel_row_and_col_hook",
-        "LlamaModelNet",
-        "LlamaPretrainingCriterionNet",
-        "layer_input_replicate_hook",
-        "LlamaLMHeadNet",
-        "LlamaForCausalLMNetDPO",
-        "GlobalOutputNet",
-        "layer_input_parallel_row_hook",
-        "LlamaRMSNormNet",
-        "LlamaAttentionNet",
-        "scaled_dot_product_attention",
-        "ReshardLayer",
-        "LlamaForCausalLMNet",
-        "enable_fuse_ffn_qkv_pass",
-        "LlamaMLPNet",
-        "LlamaDecoderLayerNet",
-    ],
     "llama.modeling_pp": ["LlamaForCausalLMPipe"],
     "llama.tokenizer": ["LlamaTokenizer", "Llama3Tokenizer"],
     "llama.tokenizer_fast": ["LlamaTokenizerFast"],
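
With the `llama.modeling_auto` and `llama.modeling_network` entries dropped from the lazy-import table, only the single-model classes remain importable from the package root, which is exactly what the refactored pretraining script uses:

# After this commit, the script relies solely on the standard exports;
# the *3DAuto and *Net variants are no longer listed in the import table.
from paddleformers.transformers import LlamaConfig, LlamaForCausalLM, LlamaPretrainingCriterion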

paddleformers/transformers/configuration_utils.py

Lines changed: 10 additions & 0 deletions
@@ -537,6 +537,9 @@ class PretrainedConfig:
            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
            model has a output word embedding layer.
 
+        run_single_model (`bool`, *optional*, defaults to `False`):
+            Whether to run the model in single card mode. When enabled, all parallel degree configurations will be disabled.
+
        dtype (`str`, *optional*):
            The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
            (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
@@ -601,6 +604,13 @@ def __init__(self, **kwargs):
        self.use_cache = kwargs.pop("use_cache", False)
        self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)
 
+        # for run model in single card mode
+        self.run_single_model = kwargs.pop("run_single_model", False)
+        if self.run_single_model:
+            self.tensor_parallel_degree = 1
+            self.sep_parallel_degree = 1
+            self.context_parallel_degree = 1
+
        # for transformers fuse
        self.fuse_linear = kwargs.pop("fuse_linear", False)
        self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False)
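
The new kwarg is consumed in `PretrainedConfig.__init__`, so any config built with `run_single_model=True` describes a single-card layout. A small illustration, constructing `LlamaConfig` directly only to show the effect of the flag:

from paddleformers.transformers import LlamaConfig

# Per the __init__ change above: run_single_model is popped from kwargs and,
# when True, the tensor/sep/context parallel degrees are pinned to 1.
config = LlamaConfig(run_single_model=True)
print(config.run_single_model)         # True
print(config.tensor_parallel_degree)   # 1
print(config.sep_parallel_degree)      # 1
print(config.context_parallel_degree)  # 1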

paddleformers/transformers/llama/__init__.py

Lines changed: 0 additions & 35 deletions
@@ -50,41 +50,6 @@
         "LlamaPretrainingCriterion",
         "LlamaNTKScalingRotaryEmbedding",
     ],
-    "modeling_auto": [
-        "enable_fuse_ffn_qkv_pass",
-        "LlamaDecoderLayerAuto",
-        "LlamaAttentionAuto",
-        "LlamaPretrainedModelAuto",
-        "LlamaLMHeadAuto",
-        "LlamaModelAuto",
-        "LlamaForCausalLM3DAuto",
-        "LlamaMLPAuto",
-        "get_mesh",
-        "LlamaRMSNormAuto",
-        "is_pp_enable",
-        "LlamaPretrainingCriterion3DAuto",
-        "global_mesh_starts_with_pp",
-        "scaled_dot_product_attention",
-    ],
-    "modeling_network": [
-        "LlamaPretrainedModelNet",
-        "layer_input_parallel_row_and_col_hook",
-        "LlamaModelNet",
-        "LlamaPretrainingCriterionNet",
-        "layer_input_replicate_hook",
-        "LlamaLMHeadNet",
-        "LlamaForCausalLMNetDPO",
-        "GlobalOutputNet",
-        "layer_input_parallel_row_hook",
-        "LlamaRMSNormNet",
-        "LlamaAttentionNet",
-        "scaled_dot_product_attention",
-        "ReshardLayer",
-        "LlamaForCausalLMNet",
-        "enable_fuse_ffn_qkv_pass",
-        "LlamaMLPNet",
-        "LlamaDecoderLayerNet",
-    ],
     "modeling_pp": ["LlamaForCausalLMPipe"],
     "tokenizer": ["LlamaTokenizer", "Llama3Tokenizer"],
     "tokenizer_fast": ["LlamaTokenizerFast"],
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import paddle.distributed as dist
+from paddle.distributed.auto_parallel.intermediate.tensor_parallel import (
+    PrepareLayerInput,
+)
+
+
+def layer_input_parallel_row_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        for input in inputs:
+            if not input.is_dist():
+                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Replicate()])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Replicate()]))
+            else:
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Replicate()]))
+        return tuple(res_inputs)
+
+    return hook
+
+
+def layer_input_parallel_row_and_col_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        for input in inputs:
+            if not input.is_dist():
+                x = dist.shard_tensor(input, process_mesh, [dist.Shard(0), dist.Shard(1)])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Shard(0), dist.Shard(1)]))
+            else:
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Shard(0), dist.Shard(1)]))
+        return tuple(res_inputs)
+
+    return hook
+
+
+def layer_input_replicate_hook(process_mesh):
+    def hook(layer, inputs, output=None):
+        res_inputs = []
+        for input in inputs:
+            if not input.is_dist():
+                x = dist.shard_tensor(input, process_mesh, [dist.Replicate(), dist.Replicate()])
+                res_inputs.append(dist.reshard(x, process_mesh, [dist.Replicate(), dist.Replicate()]))
+            else:
+                res_inputs.append(dist.reshard(input, process_mesh, [dist.Replicate(), dist.Replicate()]))
+        return tuple(res_inputs)
+
+    return hook
+
+
+def auto_dist_config(self, prefix=""):
+    if prefix != "":
+        assert prefix.endswith(".")
+    config = {
+        "sp_config": {
+            "parallelize_plan": {
+                f"{prefix}llama.embed_tokens": [
+                    dist.ColWiseParallel(),
+                    dist.SequenceParallelBegin(),
+                ],
+                f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook),
+                f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook),
+                f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook),
+                f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn": dist.SequenceParallelDisable(),
+                f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.mlp": dist.SequenceParallelDisable(need_transpose=False),
+                f"{prefix}lm_head.weight": dist.ColWiseParallel(),
+                f"{prefix}lm_head": dist.SequenceParallelEnd(),
+            }
+        },
+        "mp_config": {
+            "parallelize_plan": {
+                f"{prefix}llama.embed_tokens": dist.ColWiseParallel(gather_output=True),
+                f"{prefix}llama.reshard_row": PrepareLayerInput(layer_input_parallel_row_hook),
+                f"{prefix}llama.reshard_row_and_col": PrepareLayerInput(layer_input_parallel_row_and_col_hook),
+                f"{prefix}llama.global_layer.reshard_replicate": PrepareLayerInput(layer_input_replicate_hook),
+                f"{prefix}llama.layers.*.self_attn.qkv_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": dist.ColWiseParallel(),
+                f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(),
+                f"{prefix}lm_head.weight": dist.ColWiseParallel(),
+            }
+        },
+        "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"},
+    }
+
+    return config
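
The returned dict is keyed the way Paddle's intermediate-level auto-parallel API expects (`sp_config`, `mp_config`, `pp_config`, each wrapping a `parallelize_plan`): column/row-wise sharding for the attention and MLP projections, sequence-parallel begin/end markers around the embedding and LM head, and a pipeline split along `llama.layers`. Below is a rough usage sketch under two assumptions that go beyond this diff: that `auto_dist_config` is bound as a method on `LlamaForCausalLM` (its `self` parameter suggests so), and that the entry point is a `paddle.distributed.parallelize`-style call whose exact signature and return value may differ across Paddle releases.

import paddle.distributed as dist
from paddleformers.transformers import LlamaConfig, LlamaForCausalLM

# Build the network as a plain single-card model first (see run_pretrain_auto.py above).
config = LlamaConfig(run_single_model=True)
model = LlamaForCausalLM(config)

# Assumed to be bound as a model method; returns
# {"sp_config": ..., "mp_config": ..., "pp_config": ...}.
dist_config = model.auto_dist_config()

# Assumed call shape for the intermediate API; check your Paddle release
# for the exact signature and return value.
parallelized = dist.parallelize(model, config=dist_config)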
