
Conversation

czl66

@czl66 czl66 commented Dec 24, 2024

In my practice with the aishell conformer_ctc ASR task, I found that the script only implements single-machine multi-GPU training, which is inconvenient for our GPU servers. So I modified train.py; I hope it is helpful for the icefall community. :)
[Screenshot attached: WeCom screenshot dad4f422-924e-4b74-82e3-9fca53ce6ca0]

@csukuangfj
Collaborator

Could you describe how to run it for multi-node multi-GPU training?

@czl66
Author

czl66 commented Dec 24, 2024

Could you describe how to run it for multi-node multi-GPU training?

Yes, here is the code for the main bash script:

# Per-node launch script.
# $1 = node_rank, $2 = total number of nodes (WORLD_SIZE),
# $3 = comma-separated GPU ids assigned to this node.
node_rank=$1
WORLD_SIZE=$2
export CUDA_VISIBLE_DEVICES=$3
echo "WORKER INFO:: node_rank=$node_rank, WORLD_SIZE=$WORLD_SIZE, CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
# Number of GPUs on this node = number of comma-separated entries.
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')

DISTRIBUTED_ARGS="
    --nnodes ${WORLD_SIZE:-1} \
    --nproc_per_node $gpu_num \
    --node_rank ${node_rank:-0} \
    --master_addr ${MASTER_ADDR:-127.0.0.1} \
    --master_port ${MASTER_PORT:-26669}
"
torchrun $DISTRIBUTED_ARGS ./conformer_ctc/train.py --world-size $gpu_num --max-duration 200 --num-epochs 100

You also need to write another script to start the training, which assigns the node rank, the WORLD_SIZE, and the GPUs to each node; a sketch of such a launcher follows.
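
As an illustration only (not part of this PR), a minimal sketch of such a launcher, assuming the per-node script above is saved as run_node.sh on every machine (the script name, hostnames, IP address, and path below are all hypothetical), each machine uses all 8 of its GPUs as a single node, and the first machine acts as the master:

export MASTER_ADDR=10.0.0.1        # hypothetical address of the first machine
export MASTER_PORT=26669
HOSTS=(node0 node1 node2 node3)    # hypothetical hostnames, one per machine
WORLD_SIZE=${#HOSTS[@]}            # total number of nodes (here: 4)
for node_rank in "${!HOSTS[@]}"; do
  # Start the per-node script on each machine; every machine uses GPUs 0-7.
  ssh "${HOSTS[$node_rank]}" \
    "cd /path/to/icefall/egs/aishell/ASR && \
     MASTER_ADDR=$MASTER_ADDR MASTER_PORT=$MASTER_PORT \
     bash run_node.sh $node_rank $WORLD_SIZE 0,1,2,3,4,5,6,7" &
done
wait

Equivalently, you can log in to each machine and run the same bash run_node.sh <node_rank> <WORLD_SIZE> <gpu_ids> command by hand.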

@czl66
Author

czl66 commented Dec 24, 2024

e.g., if you have 4 machines and each machine has 8 GPUs: if each node is assigned one GPU, the total number of nodes is 32, and you should pass $1=0,1,2,...,31, $2=32, and $3='0', '1', '2', ..., '7' one by one. If instead each node is assigned 2 GPUs, the total number of nodes is 16, and you should pass $1=0,1,2,...,15, $2=16, and $3='0,1', '2,3', '4,5', '6,7' respectively. Concrete commands for the second case are sketched below.
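
For concreteness, here is a sketch of the commands for the 2-GPUs-per-node case (16 nodes, 4 launcher processes per machine). The script name run_node.sh is hypothetical, and MASTER_ADDR/MASTER_PORT are assumed to be exported on every machine, pointing at the machine that hosts node_rank 0:

# on machine 0:
bash run_node.sh 0 16 '0,1' &
bash run_node.sh 1 16 '2,3' &
bash run_node.sh 2 16 '4,5' &
bash run_node.sh 3 16 '6,7' &
# on machine 1:
bash run_node.sh 4 16 '0,1' &
bash run_node.sh 5 16 '2,3' &
bash run_node.sh 6 16 '4,5' &
bash run_node.sh 7 16 '6,7' &
# ... and so on, up to node_rank 15 with GPUs '6,7' on machine 3.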

@czl66
Author

czl66 commented Dec 24, 2024


The single-machine version is also provided:

export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
torchrun --nproc_per_node $gpu_num ./conformer_ctc/train.py --world-size $gpu_num --max-duration 200 --num-epochs 100

@czl66
Author

czl66 commented Dec 24, 2024

Also, when I used decode.py for ctc_decoding, I found that the speed was really slow: even after several hours had passed, no recognition results were produced. After debugging, I found that num_workers caused this problem.
Moreover, I found that the decoding process only feeds one sample at a time, which is still too slow, so I converted the MonoCut to a dict, which makes batch-wise decoding work. I hope my code is helpful. :) @csukuangfj

@yfyeung
Collaborator

yfyeung commented Dec 24, 2024

There is no need to modify egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py.

To enable multi-node multi-GPU support, simply modify the train.py file with the following changes:

Add the following; the comments indicate roughly where each fragment belongs in train.py:

# Imports, near the top of train.py:
from icefall.dist import (
    cleanup_dist,
    get_local_rank,
    get_rank,
    get_world_size,
    setup_dist,
)

# In get_parser():
    parser.add_argument(
        "--use-multi-node",
        type=str2bool,
        default=False,
        help="""True if using multi-node multi-GPU.
        You are not supposed to set it directly.
        """,
    )

# In run(rank, world_size, args):
    if params.use_multi_node:
        local_rank = get_local_rank()
    else:
        local_rank = rank
    logging.info(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}")
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port, params.use_multi_node)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", local_rank)
    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")

    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[local_rank])

# In main():
    if args.use_multi_node:
        rank = get_rank()
        world_size = get_world_size()
        args.world_size = world_size
        run(rank=rank, world_size=world_size, args=args)
    else:
        world_size = args.world_size
        assert world_size >= 1
        if world_size > 1:
            mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
        else:
            run(rank=0, world_size=1, args=args)
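
A minimal usage sketch, assuming the changes above are applied to ./conformer_ctc/train.py; the master address 10.0.0.1 and the 4-machine, 8-GPU setup are hypothetical. The same command would be run on every machine, with NODE_RANK set to 0..3 per machine:

torchrun \
  --nnodes 4 \
  --nproc_per_node 8 \
  --node_rank $NODE_RANK \
  --master_addr 10.0.0.1 \
  --master_port 26669 \
  ./conformer_ctc/train.py \
  --use-multi-node True \
  --max-duration 200

Here torchrun sets the RANK, WORLD_SIZE, and LOCAL_RANK environment variables, which get_rank(), get_world_size(), and get_local_rank() from icefall.dist are presumably reading.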

@czl66
Author

czl66 commented Dec 24, 2024

There is no need to modify egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py.

Yeah, you are absolutely right. In addition, I think calling barrier() is a must.

@czl66
Author

czl66 commented Dec 24, 2024

There is no need to modify egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py.

By the way, if you set the batch_size of test_dataloaders in egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py from None to some int value like 1 or 10, it will cause this error:
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'lhotse.cut.mono.MonoCut'>.
That's why I modified it. I hope this info is useful.

@yfyeung
Collaborator

yfyeung commented Dec 24, 2024

I think there is no need for torch.distributed.barrier() because DistributedDataParallel already broadcasts parameters from rank 0 during initialization. This built-in synchronization ensures all ranks have consistent weights without requiring an explicit barrier.
