
CUDA Memory Issue even with quantized model #949

@Abhivadan

Description


I have been trying to fine-tune the Llama 3.1 8B model on a custom dataset across 3 GPUs with 15 GB each. However, I am running into a CUDA out-of-memory error even with the quantized model.
torchrun --nnodes 3 --nproc_per_node 1 finetuning2.py \
    --enable_fsdp \
    --quantization 4bit \
    --model_name /Volume1/ocr/.models/Meta-Llama-3.1-8B-Instruct \
    --use_peft \
    --batch_size 1 \
    --peft_method lora \
    --max_seq_len 512 \
    --lora_r 8 \
    --lora_alpha 16 \
    --output_dir /Volume1/ocr/.models/Meta-Llama-3.1-8B-PEFT/model \
    --dataset custom_dataset
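
For reference, the "Warning: unknown parameter ..." lines in the output below suggest that the batch_size, max_seq_len, lora_r and lora_alpha flags are not actually being applied, so the run may be using default values for sequence length and batch size; the accepted flag names would need to be double-checked against the llama_cookbook config classes (I have not confirmed the correct names). Independently, the OOM message at the end suggests one low-risk tweak, enabling expandable segments in the caching allocator, which can be set in the launching shell so the worker processes inherit it:

# Allocator option suggested by the CUDA OOM message further down; set this
# before running the torchrun command above so all ranks inherit it.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True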

Even after loading the quantized version, the run takes about 30 GB across the GPUs and crashes during the finetuning process with this error:

W0520 17:34:48.279826 3415756 site-packages/torch/distributed/run.py:766] 
W0520 17:34:48.279826 3415756 site-packages/torch/distributed/run.py:766] *****************************************
W0520 17:34:48.279826 3415756 site-packages/torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0520 17:34:48.279826 3415756 site-packages/torch/distributed/run.py:766] *****************************************
/Volume1/conda_env/llama_env/lib/python3.9/site-packages/llama_cookbook/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/Volume1/conda_env/llama_env/lib/python3.9/site-packages/llama_cookbook/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/Volume1/conda_env/llama_env/lib/python3.9/site-packages/llama_cookbook/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Warning: unknown parameter batch_size
Warning: unknown parameter max_seq_len
Warning: unknown parameter lora_r
Warning: unknown parameter lora_alpha
Warning: unknown parameter batch_size
Warning: unknown parameter max_seq_len
Warning: unknown parameter lora_r
Warning: unknown parameter lora_alpha
Warning: unknown parameter batch_size
Warning: unknown parameter max_seq_len
Warning: unknown parameter lora_r
Warning: unknown parameter lora_alpha
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|█████████████████████████████████████████████| 4/4 [00:55<00:00, 13.76s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████| 4/4 [00:55<00:00, 13.76s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████| 4/4 [00:54<00:00, 13.75s/it]
--> Model /Volume1/ocr/.models/Meta-Llama-3.1-8B-Instruct

--> /Volume1/ocr/.models/Meta-Llama-3.1-8B-Instruct has 1050.939392 Million params

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424
bFloat16 enabled for mixed precision - using bfSixteen policy
trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424
trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> Training Set Length = 1000
Preprocessing dataset: 100%|█████████████████████████████████████████| 1000/1000 [00:00<00:00, 1453.93it/s]
length of dataset_train 116
Preprocessing dataset: 100%|█████████████████████████████████████████| 1000/1000 [00:00<00:00, 1475.40it/s]
length of dataset_train 116
--> Validation Set Length = 1000
Preprocessing dataset: 100%|█████████████████████████████████████████| 1000/1000 [00:00<00:00, 1475.43it/s]
length of dataset_train 116
Can not find the custom data_collator in the dataset.py file (/Volume1/ocr/codes/finetuning_llama/guanco_dataset.py).
Using the default data_collator instead.
--> Num of Training Set Batches loaded = 9
Preprocessing dataset:  45%|██████████████████▉                       | 452/1000 [00:00<00:00, 1450.62it/s]
Can not find the custom data_collator in the dataset.py file (/Volume1/ocr/codes/finetuning_llama/guanco_dataset.py).
Using the default data_collator instead.
--> Num of Training Set Batches loaded = 9
Preprocessing dataset: 100%|█████████████████████████████████████████| 1000/1000 [00:00<00:00, 1506.42it/s]
--> Num of Validation Set Batches loaded = 38
--> Num of Validation Set Batches loaded = 38
Starting epoch 0/3
train_config.max_train_step: 0
Preprocessing dataset:  76%|████████████████████████████████          | 763/1000 [00:00<00:00, 1515.31it/s]
/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/cuda/memory.py:489: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Preprocessing dataset: 100%|█████████████████████████████████████████| 1000/1000 [00:00<00:00, 1501.01it/s]
--> Num of Validation Set Batches loaded = 38
--> Num of Validation Set Batches loaded = 38
Starting epoch 0/3
train_config.max_train_step: 0
Can not find the custom data_collator in the dataset.py file (/Volume1/ocr/codes/finetuning_llama/guanco_dataset.py).
Using the default data_collator instead.
--> Num of Training Set Batches loaded = 9
Preprocessing dataset:   0%|                                                      | 0/1000 [00:00<?, ?it/s]
/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/cuda/memory.py:489: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Preprocessing dataset: 100%|█████████████████████████████████████████| 1000/1000 [00:00<00:00, 1241.26it/s]
--> Num of Validation Set Batches loaded = 38
--> Num of Validation Set Batches loaded = 38
Starting epoch 0/3
train_config.max_train_step: 0
/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/cuda/memory.py:489: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                             | 0/9 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/Volume1/ocr/codes/finetuning_llama/finetuning2.py", line 8, in <module>
[rank0]:     fire.Fire(main)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/fire/core.py", line 135, in Fire
[rank0]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/fire/core.py", line 468, in _Fire
[rank0]:     component, remaining_args = _CallAndUpdateTrace(
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank0]:     component = fn(*varargs, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/llama_cookbook/finetuning.py", line 406, in main
[rank0]:     results = train(
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/llama_cookbook/utils/train_utils.py", line 153, in train
[rank0]:     loss = model(**batch).loss
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 856, in forward
[rank0]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/peft/peft_model.py", line 1757, in forward
[rank0]:     return self.base_model(
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/peft/tuners/tuners_utils.py", line 193, in forward
[rank0]:     return self.model.forward(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/transformers/utils/generic.py", line 965, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 837, in forward
[rank0]:     logits = self.lm_head(hidden_states[:, slice_indices, :])
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 125, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.91 GiB. GPU 0 has a total capacity of 14.61 GiB of which 1.99 GiB is free. Including non-PyTorch memory, this process has 12.62 GiB memory in use. Of the allocated memory 8.50 GiB is allocated by PyTorch, and 3.96 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank2]: Traceback (most recent call last):
[rank2]:   File "/Volume1/ocr/codes/finetuning_llama/finetuning2.py", line 8, in <module>
[rank2]:     fire.Fire(main)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/fire/core.py", line 135, in Fire
[rank2]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/fire/core.py", line 468, in _Fire
[rank2]:     component, remaining_args = _CallAndUpdateTrace(
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank2]:     component = fn(*varargs, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/llama_cookbook/finetuning.py", line 406, in main
[rank2]:     results = train(
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/llama_cookbook/utils/train_utils.py", line 153, in train
[rank2]:     loss = model(**batch).loss
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 856, in forward
[rank2]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/peft/peft_model.py", line 1757, in forward
[rank2]:     return self.base_model(
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/peft/tuners/tuners_utils.py", line 193, in forward
[rank2]:     return self.model.forward(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/transformers/utils/generic.py", line 965, in wrapper
[rank2]:     output = func(self, *args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank2]:     return func(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 837, in forward
[rank2]:     logits = self.lm_head(hidden_states[:, slice_indices, :])
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank2]:     return self._call_impl(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank2]:     return forward_call(*args, **kwargs)
[rank2]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 125, in forward
[rank2]:     return F.linear(input, self.weight, self.bias)
[rank2]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.91 GiB. GPU 2 has a total capacity of 14.61 GiB of which 1.99 GiB is free. Including non-PyTorch memory, this process has 12.62 GiB memory in use. Of the allocated memory 8.50 GiB is allocated by PyTorch, and 3.96 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank1]: Traceback (most recent call last):
[rank1]:   File "/Volume1/ocr/codes/finetuning_llama/finetuning2.py", line 8, in <module>
[rank1]:     fire.Fire(main)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/fire/core.py", line 135, in Fire
[rank1]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/fire/core.py", line 468, in _Fire
[rank1]:     component, remaining_args = _CallAndUpdateTrace(
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank1]:     component = fn(*varargs, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/llama_cookbook/finetuning.py", line 406, in main
[rank1]:     results = train(
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/llama_cookbook/utils/train_utils.py", line 153, in train
[rank1]:     loss = model(**batch).loss
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 856, in forward
[rank1]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/peft/peft_model.py", line 1757, in forward
[rank1]:     return self.base_model(
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/peft/tuners/tuners_utils.py", line 193, in forward
[rank1]:     return self.model.forward(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/transformers/utils/generic.py", line 965, in wrapper
[rank1]:     output = func(self, *args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 837, in forward
[rank1]:     logits = self.lm_head(hidden_states[:, slice_indices, :])
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 125, in forward
[rank1]:     return F.linear(input, self.weight, self.bias)
[rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.91 GiB. GPU 1 has a total capacity of 14.61 GiB of which 1.99 GiB is free. Including non-PyTorch memory, this process has 12.62 GiB memory in use. Of the allocated memory 8.50 GiB is allocated by PyTorch, and 3.96 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Training Epoch: 1:   0%|                                                             | 0/9 [00:24<?, ?it/s]
Training Epoch: 1:   0%|                                                             | 0/9 [00:26<?, ?it/s]
Training Epoch: 1:   0%|                                                             | 0/9 [00:25<?, ?it/s]
W0520 17:36:37.212563 3415756 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3415828 closing signal SIGTERM
W0520 17:36:37.213461 3415756 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 3415830 closing signal SIGTERM
E0520 17:36:37.779025 3415756 site-packages/torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: 1) local_rank: 1 (pid: 3415829) of binary: /Volume1/conda_env/llama_env/bin/python3.9
Traceback (most recent call last):
  File "/Volume1/conda_env/llama_env/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/distributed/run.py", line 892, in main
    run(args)
  File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/distributed/run.py", line 883, in run
    elastic_launch(
  File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 139, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/Volume1/conda_env/llama_env/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 270, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
finetuning2.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):. 
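
For what it's worth, a rough back-of-the-envelope check of the failing 3.91 GiB allocation: the crash happens while materializing the lm_head output, whose size is tokens × vocabulary size. Assuming Llama 3.1's 128,256-token vocabulary and 4 bytes per logit (both assumptions on my part), 3.91 GiB corresponds to roughly 8k tokens, far more than a 512-token sequence would need. That would be consistent with the "unknown parameter max_seq_len" warnings above, i.e. the 512 limit may never have taken effect:

# Hypothetical sizing of the logits tensor produced by lm_head
# (batch x tokens x vocab_size x bytes per element); the numbers are assumptions.
echo $(( 1 * 8192 * 128256 * 4 ))   # 4202692608 bytes, about 3.91 GiB, matching the failed allocation
echo $(( 1 * 512 * 128256 * 4 ))    # 262668288 bytes, about 0.24 GiB, what a 512-token limit would imply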

Also, when I use only 1 GPU instead of all 3, the run takes about 11 GB, which seems strange given that the 3-GPU run used about 30 GB in total.
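
To compare the two setups beyond a single screenshot, per-GPU memory can be logged once per second from a second terminal while the job runs:

# Poll per-GPU memory usage once per second during loading and training.
nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv -l 1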

Is there something I can do to avoid this issue?
