From 9f3d07ffb9558e2c873f5299f4b3a15c10495138 Mon Sep 17 00:00:00 2001 From: firestonelib Date: Wed, 18 Jan 2023 17:05:58 +0800 Subject: [PATCH] add debertav2 --- .../multimodal/imagen/imagen_super_resolution_1024.yaml | 5 +++-- .../multimodal/imagen/imagen_super_resolution_256.yaml | 2 +- .../multimodal/imagen/imagen_text2im_64x64_DebertaV2.yaml | 1 + ppfleetx/models/multimodal_model/imagen/modeling.py | 1 + .../imagen/run_text2im_2B_64x64_T5-11B_sharding8_dp32.sh | 2 +- 5 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolution_1024.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolution_1024.yaml index dd6d85c32..4773a2a2f 100644 --- a/ppfleetx/configs/multimodal/imagen/imagen_super_resolution_1024.yaml +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolution_1024.yaml @@ -8,7 +8,7 @@ Global: Model: name: imagen_SR1024 - text_encoder_name: t5/t5-11b + text_encoder_name: projects/imagen/t5/t5-11b text_embed_dim: 1024 timesteps: 1000 in_chans: 3 @@ -25,12 +25,13 @@ Model: dynamic_thresholding_percentile: 0.95 only_train_unet_number: 1 use_recompute: False + recompute_granularity: Data: Train: dataset: name: ImagenDataset - input_path: ./projects/imagen/filelist/cc12m_base64.lst + input_path: ./projects/imagen/filelist/laion_400M/train shuffle: True input_resolution: 1024 max_seq_len: 128 diff --git a/ppfleetx/configs/multimodal/imagen/imagen_super_resolution_256.yaml b/ppfleetx/configs/multimodal/imagen/imagen_super_resolution_256.yaml index af4b77f65..dc78c8949 100644 --- a/ppfleetx/configs/multimodal/imagen/imagen_super_resolution_256.yaml +++ b/ppfleetx/configs/multimodal/imagen/imagen_super_resolution_256.yaml @@ -53,7 +53,7 @@ Loss: p2_loss_weight_k: 1.0 Distributed: - dp_degree: 1 + dp_degree: 128 mp_degree: 1 pp_degree: 1 sharding: diff --git a/ppfleetx/configs/multimodal/imagen/imagen_text2im_64x64_DebertaV2.yaml b/ppfleetx/configs/multimodal/imagen/imagen_text2im_64x64_DebertaV2.yaml index 7b0d84ad0..fcbb36f23 100644 --- a/ppfleetx/configs/multimodal/imagen/imagen_text2im_64x64_DebertaV2.yaml +++ b/ppfleetx/configs/multimodal/imagen/imagen_text2im_64x64_DebertaV2.yaml @@ -25,6 +25,7 @@ Model: dynamic_thresholding_percentile: 0.95 only_train_unet_number: 1 use_recompute: False + recompute_granularity: Loss: name: mse_loss diff --git a/ppfleetx/models/multimodal_model/imagen/modeling.py b/ppfleetx/models/multimodal_model/imagen/modeling.py index 027f87498..5c51af33d 100644 --- a/ppfleetx/models/multimodal_model/imagen/modeling.py +++ b/ppfleetx/models/multimodal_model/imagen/modeling.py @@ -23,6 +23,7 @@ from .unet import Unet from ppfleetx.models.language_model.t5 import * +from ppfleetx.models.language_model.debertav2 import * from ppfleetx.data.tokenizers import get_t5_tokenizer from .utils import ( GaussianDiffusionContinuousTimes, default, exists, cast_tuple, first, diff --git a/projects/imagen/run_text2im_2B_64x64_T5-11B_sharding8_dp32.sh b/projects/imagen/run_text2im_2B_64x64_T5-11B_sharding8_dp32.sh index e059d677f..e38217b78 100644 --- a/projects/imagen/run_text2im_2B_64x64_T5-11B_sharding8_dp32.sh +++ b/projects/imagen/run_text2im_2B_64x64_T5-11B_sharding8_dp32.sh @@ -22,5 +22,5 @@ python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6, ./tools/train.py \ -c ./ppfleetx/configs/multimodal/imagen/imagen_text2im_64x64_T5-11B.yaml \ -o Distributed.sharding.sharding_stage=2 \ - -o Distributed.dp_degree=32 \ + -o Distributed.dp_degree=32 \ -o Distributed.sharding.sharding_degree=8