
[GGUF] Reduce peak RAM usage by casting dequantized tensors early during load#45386

Open
UsamaKenway wants to merge 3 commits into huggingface:main from UsamaKenway:gguf-early-dtype-casting

Conversation

@UsamaKenway UsamaKenway commented Apr 12, 2026

Optimizes memory usage when loading GGUF models by performing dtype casting immediately after dequantization.

While adding support for Gemma4 in PR #45296, I noticed that GGUF tensors are dequantized to float32 by default during loading, even when the user intends to load the model in float16 or bfloat16. For large models, this creates a significant RAM spike that can lead to out-of-memory errors.

By passing the target torch_dtype directly into the loading utility, we can cast each tensor immediately after dequantization, roughly halving the peak RAM required for the state dict.
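The idea can be sketched roughly like this (the function and argument names below are illustrative, not the actual transformers internals):

```python
import numpy as np
import torch

def load_gguf_tensors(raw_tensors, torch_dtype=None):
    """Hypothetical sketch: dequantized float32 arrays are cast to the
    target dtype one at a time, so only a single full-precision copy
    is alive at any moment instead of the whole state dict."""
    state_dict = {}
    for name, weights in raw_tensors.items():
        # `weights` stands in for the float32 output of dequantization.
        tensor = torch.from_numpy(np.copy(weights))
        if torch_dtype is not None:
            # Cast immediately; the float32 buffer can be freed as soon
            # as the half-sized copy replaces it.
            tensor = tensor.to(torch_dtype)
        state_dict[name] = tensor
    return state_dict

# Tiny demonstration with a fake "dequantized" tensor.
state_dict = load_gguf_tensors(
    {"w": np.ones((4, 4), dtype=np.float32)}, torch_dtype=torch.float16
)
```

With this shape, the peak tracks the target-dtype total plus one float32 tensor, rather than a full float32 state dict.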

Benchmark Results (Gemma 4 26B IT q4_k_m)

I tested the peak RAM (Global Peak RSS) with and without this change using a separate branch for tracking:

- Without this PR (Float32 spike): ~118.7 GB Peak RSS
- With this PR (Early casting):    ~59.4  GB Peak RSS
------------------------------------------------------
Saving:                            ~59.3  GB (50% reduction)
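As a rough sanity check on those numbers (assuming ~26B float parameters; the measured peaks also include the process baseline and temporary buffers, so they sit somewhat above these floors):

```python
# Back-of-envelope state-dict sizes for a ~26B-parameter model.
params = 26e9
fp32_gb = params * 4 / 1e9  # 4 bytes per float32 value
fp16_gb = params * 2 / 1e9  # 2 bytes per float16/bfloat16 value
print(f"float32 state dict: ~{fp32_gb:.0f} GB")  # ~104 GB
print(f"float16 state dict: ~{fp16_gb:.0f} GB")  # ~52 GB
```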
Tests

With the changes

(py312venv) usamakenway@Megatron: RUN_SLOW=1 pytest tests/quantization/ggml/test_ggml.py::GgufModelTests::test_gemma4_26b_it_q4_k_m -s                                                                                                                                                                                                                                                                                                                                                   

[RAM DEBUG] Global Peak RSS (High Water Mark): 2185.59 MB

[RAM DEBUG] Global Peak RSS (High Water Mark): 2197.88 MB

[RAM DEBUG] Global Peak RSS (High Water Mark): 2391.64 MB
Converting and de-quantizing GGUF tensors...: 100%|███████████████████████████████████████████████████████████████████████████████| 658/658 [03:03<00:00,  3.59it/s]

[RAM DEBUG] Global Peak RSS (High Water Mark): 59428.81 MB
Loading weights: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 657/657 [00:00<00:00, 26841.26it/s]
PASSED tests/quantization/ggml/test_ggml.py::GgufModelTests::test_gemma4_26b_it_q4_k_m [PASSED] 287.34s

Without the changes

(py312venv) usamakenway@Megatron:  RUN_SLOW=1 pytest tests/quantization/ggml/test_ggml.py::GgufModelTests::test_gemma4_26b_it_q4_k_m -s
                                                                                                                                       
tests/quantization/ggml/test_ggml.py::GgufModelTests::test_gemma4_26b_it_q4_k_m 
[RAM DEBUG] Global Peak RSS (High Water Mark): 2259.14 MB

[RAM DEBUG] Global Peak RSS (High Water Mark): 2270.75 MB

[RAM DEBUG] Global Peak RSS (High Water Mark): 2464.60 MB
Converting and de-quantizing GGUF tensors...: 100%|███████████████████████████████████████████████████████████████████████████████| 658/658 [05:46<00:00,  1.90it/s]

[RAM DEBUG] Global Peak RSS (High Water Mark): 118747.33 MB
Loading weights: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 657/657 [00:03<00:00, 195.35it/s]
PASSED tests/quantization/ggml/test_ggml.py::GgufModelTests::test_gemma4_26b_it_q4_k_m [PASSED] 499.38s

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Rocketknight1
Member

cc @SunMarc


@SunMarc SunMarc left a comment


Thanks, a couple of comments

Comment on lines 4092 to 4096

can we move that above so that we don't have to replicate the dtype logic?


- parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
+ tensor = torch.from_numpy(np.copy(weights))
+ if torch_dtype is not None and torch_dtype != torch.float32:

do we really need the fp32 check?
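For reference on that question: per the PyTorch docs, `torch.Tensor.to` returns `self` unchanged when the tensor already has the requested dtype and device, so the explicit float32 guard only skips a cheap no-op call rather than preventing a copy:

```python
import torch

t = torch.ones(3, dtype=torch.float32)
# Same dtype: `.to` returns the very same tensor object, no copy made.
same = t.to(torch.float32)
# Different dtype: a new half-precision tensor is allocated.
cast = t.to(torch.float16)
print(same is t)   # the guard is redundant for avoiding copies
print(cast.dtype)
```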

Signed-off-by: Usama Kenway <usamakenway@gmail.com>
