Reduce Memory Usage During Inference #298

@bigximik

Description

🎯 Goal (What & Why)

The current memory footprint during the inference forward pass is approximately 2.5× higher for the same model and batch size when using Fast-LLM compared to Hugging Face Transformers.

🧠 Qwen2 1.5B — Batch Size: 16×4096, Flash Attention, bfloat16, H100 GPU

| Test | Peak GPU Memory Usage (MB) |
|---|---|
| HF (no loss calculation) | 22,162.28 |
| HF (with loss calculation) | 40,962.78 |
| Fast-LLM (no loss calculation) | 59,013.70 |
| Fast-LLM (with loss calculation) | OOM |

What is a reasonable target for reducing Fast-LLM's memory usage?
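For reference, a minimal sketch of how peak-memory numbers like those above can be measured on the Hugging Face side. The checkpoint name and batch shape below are assumptions matching the setup described above (Qwen2 1.5B, 16×4096 tokens, Flash Attention, bfloat16, H100); this is not necessarily the exact benchmark script used.

```python
# Sketch: measure peak GPU memory of a single inference forward pass.
# Checkpoint name and batch shape are assumptions based on the setup above.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-1.5B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).cuda()
model.eval()

input_ids = torch.randint(0, model.config.vocab_size, (16, 4096), device="cuda")

torch.cuda.reset_peak_memory_stats()
with torch.inference_mode():
    # Pass labels=input_ids as well to include the loss calculation variant.
    model(input_ids=input_ids)
torch.cuda.synchronize()
print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 2**20:.2f} MB")
```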

🚀 Execution Plan

(This section may start as an incomplete draft but must be defined before implementation begins.)

Step 1: What is the smallest working version?

(Describe the simplest way to implement this feature with minimal effort.)
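Purely as a hedged illustration (the plan above is still a draft): if part of the gap comes from activation buffers or autograd bookkeeping retained for a backward pass that inference never runs, disabling gradient tracking around the inference forward would be a small first change to measure. The `model` object below is a hypothetical stand-in for Fast-LLM's forward entry point, not an existing API.

```python
# Sketch only: `model` is a hypothetical stand-in for Fast-LLM's
# HF-compatible wrapper, assuming such a forward interface exists.
import torch

def inference_forward(model, input_ids):
    # inference_mode() disables autograd bookkeeping, so no activations
    # are retained for a backward pass during inference.
    with torch.inference_mode():
        return model(input_ids=input_ids)
```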

Step 2: What additional optimizations are possible (but optional)?

(List potential refinements that can be added in later PRs if needed.)
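One candidate refinement, suggested by the gap between the with-loss and no-loss rows in the table above and offered only as a sketch: compute the language-model loss over chunks of token positions so the full `[batch, seq, vocab]` float32 logits tensor is never materialized at once. The function and argument names below are illustrative, not Fast-LLM APIs.

```python
# Sketch: chunked cross-entropy so only `chunk_size` rows of logits
# exist in memory at any time. Names are illustrative placeholders.
import torch
import torch.nn.functional as F

def chunked_lm_loss(hidden_states, lm_head_weight, labels, chunk_size=4096):
    """hidden_states: [batch, seq, hidden]; lm_head_weight: [vocab, hidden];
    labels: [batch, seq] (long)."""
    # Shift so that position t predicts token t + 1.
    hidden = hidden_states[:, :-1, :].reshape(-1, hidden_states.size(-1))
    targets = labels[:, 1:].reshape(-1)
    total = hidden.new_zeros((), dtype=torch.float32)
    for start in range(0, hidden.size(0), chunk_size):
        logits = hidden[start:start + chunk_size] @ lm_head_weight.t()
        total = total + F.cross_entropy(
            logits.float(), targets[start:start + chunk_size], reduction="sum"
        )
    return total / targets.numel()
```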

📌 Acceptance Criteria (Must-Haves for Completion)

  • The feature must be functional and tested.
  • The implementation must be documented in practical terms.
  • The PR must include a performance/impact summary.
  • No refactors unless directly necessary for feature completion.

🛠️ Project Management

  • Assign the project to the Fast-LLM project.
  • Set the Estimate field (in days) in the GitHub project.
  • Use the Size field to categorize the PR size (Small/Medium/Large).
  • Assign an owner when opening the issue.
