
Conversation

@waliwali777 (Contributor) commented Nov 5, 2025

Before submitting

  • Lint code. If there are lint issues, please format the code first.
# Install and register `pre-commit` in the project folder
pip install pre-commit && pre-commit install

# Process previous code files separately
pre-commit run --file XXXX.py
  • Add test cases into the tests folder. If there are codecov issues, please add test cases first.

PR types

Function optimization

PR changes

Models

Description

Building on the trainer refactor (#2801), this PR optimizes the model-building interface of the Llama 3.1 model.
Model construction is unified in modeling.py; modeling_network.py and modeling_auto.py are removed.
The PR works as follows:

  1. The auto-parallel workflow.py imports AutoConfig, AutoModelForCausalLM, and AutoModelForCausalLMPipe, which automatically resolve the model implementation from the model_name specified in the YAML. This path covers single-card, dynamic manual-parallel, and auto-parallel runs (a sketch follows the launch commands below).
    The user's YAML configuration for each mode looks like this:

Intermediate-API dynamic semi-auto parallel:

stage: auto-parallel
model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
tensor_parallel_degree: 4
pipeline_parallel_degree: 1
enable_auto_parallel: true
use_intermediate_api: true
... ...

Single card:

stage: auto-parallel
model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
enable_auto_parallel: false
use_intermediate_api: false
... ...

Dynamic manual parallel:

stage: auto-parallel
model_name_or_path: meta-llama/Llama-3.1-8B-Instruct
tensor_parallel_degree: 4
pipeline_parallel_degree: 1
enable_auto_parallel: false
use_intermediate_api: false
... ...

Launch command for Llama 3:

Single card:
export CUDA_VISIBLE_DEVICES=2
Multi card:
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
paddleformers-cli train ./config.yaml
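
As a rough illustration of step 1, here is a minimal sketch of how the workflow can resolve the model classes from model_name_or_path via the Auto* interfaces (the helper name load_model and the exact from_pretrained arguments are assumptions, not the PR's actual workflow.py code):

# Hypothetical sketch of the Auto* resolution described in step 1.
from paddleformers.transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForCausalLMPipe,
)


def load_model(model_args, training_args, dtype="bfloat16"):
    # AutoConfig picks the matching config class from model_name_or_path
    # (e.g. meta-llama/Llama-3.1-8B-Instruct resolves to a Llama config).
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)

    # The *Pipe class is used only when pipeline parallelism is requested;
    # otherwise the plain causal-LM class from modeling.py covers the
    # single-card, manual-parallel, and auto-parallel cases.
    if training_args.pipeline_parallel_degree > 1:
        model_class = AutoModelForCausalLMPipe
    else:
        model_class = AutoModelForCausalLM

    return model_class.from_pretrained(
        model_args.model_name_or_path, config=config, dtype=dtype
    )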

  2. A run_single_model switch is added to PretrainedConfig. When the intermediate-API dynamic semi-auto mode is used, the switch is turned on and the workflow disables the other parallel settings (tensor_parallel_degree, sep_parallel_degree, and context_parallel_degree are set to 1), so modeling.py runs on its single-card code path and never reaches the manual-parallel communication code (first sketch after this list).
  3. Intermediate-API configuration is added to modeling.py. A new file, auto_dist_config.py, records how each layer is sharded under the different parallel strategies, and LlamaForCausalLM gains an auto_dist_config function that the trainer reads at initialization to set up the intermediate-API environment (second sketch below).
  4. parallel_matmul is rewritten: it takes an explicit tensor_parallel_degree argument and no longer obtains the hcg via try/except; the hcg is only fetched in the dynamic manual-parallel case, so the semi-auto path never runs the manual communication code (third sketch below).
  5. modeling_network.py and modeling_auto.py are removed.
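
A sketch of step 2 (the helper name and the exact set of degrees reset here are assumptions; the authoritative logic is in the PR's workflow changes):

def normalize_for_run_single_model(config, training_args):
    # Hypothetical sketch: when the intermediate-API semi-auto mode is active,
    # turn on run_single_model and reset the hand-written parallel degrees so
    # modeling.py stays on its single-card code path.
    if training_args.enable_auto_parallel and training_args.use_intermediate_api:
        config.run_single_model = True
        config.tensor_parallel_degree = 1
        config.sep_parallel_degree = 1
        config.context_parallel_degree = 1
    return config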
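
For step 3, the returned configuration is roughly of the following shape (a sketch assuming Paddle's intermediate sharding API with ColWiseParallel/RowWiseParallel plans; the real keys and layer names are defined in the new auto_dist_config.py, not here):

# Method-style sketch for LlamaForCausalLM; the plan entries are illustrative.
from paddle.distributed import ColWiseParallel, RowWiseParallel


def auto_dist_config(self, prefix=""):
    return {
        "mp_config": {
            "parallelize_plan": {
                # Column-wise splits for QKV and MLP up/gate projections,
                # row-wise splits for the output projections.
                f"{prefix}llama.layers.*.self_attn.q_proj": ColWiseParallel(),
                f"{prefix}llama.layers.*.self_attn.k_proj": ColWiseParallel(),
                f"{prefix}llama.layers.*.self_attn.v_proj": ColWiseParallel(),
                f"{prefix}llama.layers.*.self_attn.o_proj": RowWiseParallel(),
                f"{prefix}llama.layers.*.mlp.gate_proj": ColWiseParallel(),
                f"{prefix}llama.layers.*.mlp.up_proj": ColWiseParallel(),
                f"{prefix}llama.layers.*.mlp.down_proj": RowWiseParallel(),
                f"{prefix}lm_head.weight": ColWiseParallel(),
            }
        },
    }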
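
And for step 4, a simplified sketch of the reworked parallel_matmul (argument names follow the description above; the actual signature in modeling.py may differ):

import paddle
from paddle.distributed import fleet


def parallel_matmul(x, y, tensor_parallel_degree=1, tensor_parallel_output=True):
    # Only the dynamic manual-parallel branch touches the hybrid communicate
    # group; there is no try/except probing of fleet any more.
    if tensor_parallel_degree > 1 and y.is_distributed:
        hcg = fleet.get_hybrid_communicate_group()
        model_parallel_group = hcg.get_model_parallel_group()
        input_parallel = paddle.distributed.collective._c_identity(
            x, group=model_parallel_group
        )
        logits = paddle.matmul(input_parallel, y, transpose_y=False)
        if tensor_parallel_output:
            return logits
        return paddle.distributed.collective._c_concat(
            logits, group=model_parallel_group
        )

    # Single-card / semi-auto path: plain matmul, no manual communication.
    return paddle.matmul(x, y, transpose_y=False)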

@paddle-bot (bot) commented Nov 5, 2025

Thanks for your contribution!

@codecov-commenter commented Nov 5, 2025

Codecov Report

❌ Patch coverage is 19.14894% with 38 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@26d3e2c). Learn more about missing BASE report.

Files with missing lines                               | Patch %  | Lines
paddleformers/cli/train/auto_parallel/workflow.py      |  0.00%   | 27 Missing ⚠️
...ddleformers/transformers/llama/auto_dist_config.py  | 33.33%   |  4 Missing ⚠️
paddleformers/transformers/llama/modeling.py           | 55.55%   |  4 Missing ⚠️
paddleformers/transformers/configuration_utils.py      | 40.00%   |  3 Missing ⚠️

❌ Your patch status has failed because the patch coverage (19.14%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #2859   +/-   ##
==========================================
  Coverage           ?   31.91%           
==========================================
  Files              ?      419           
  Lines              ?    67382           
  Branches           ?        0           
==========================================
  Hits               ?    21505           
  Misses             ?    45877           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.

@waliwali777 waliwali777 changed the title [AutoParallel] Refactor model in intermediate api [AutoParallel] Refactor llama3.1 model in intermediate api Nov 6, 2025

register_pp_reshard_information(config.num_hidden_layers)
except:
print("Not register llama pp reshard information.")
Collaborator commented:

Under what circumstances does this fail? And what is the impact of not registering the llama pp reshard information?

Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
model has a output word embedding layer.
run_single_model (`bool`, *optional*, defaults to `False`):
Collaborator commented:

If this is meant to express a non-parallel mode, the name is not self-explanatory. Consider renaming it so developers can understand it more easily, e.g. run_without_parallelism or run_in_non_parallel_mode. If dp and sharding are still allowed in this mode, see whether there is an even more suitable name.

any(architecture in str(config.architectures) for architecture in architectures_to_check)
and training_args.data_parallel_degree > 1
):
training_args.use_expert_parallel = True
Collaborator commented:

Is EP allowed in single-card mode?

training_args.use_expert_parallel = True

if model_args.continue_training:
# NOTE(gongenlei): new add
Collaborator commented:

Remove this NOTE.

if training_args.autotuner_benchmark:
model = model_class.from_config(config, dtype=dtype)
else:
model = model_class.from_pretrained(
Collaborator commented:

For warm start, shouldn't this follow the paddle.lazyGuard pattern used below?

@liym27 left a comment:

LGTM
