Commit 9b834f8

Add Pruna optimization framework documentation (#11688)
* Add Pruna optimization framework documentation
  - Introduced a new section for Pruna in the table of contents.
  - Added comprehensive documentation for Pruna, detailing its optimization techniques, installation instructions, and examples for optimizing and evaluating models.
* Enhance Pruna documentation with image alt text and code block formatting
  - Added alt text to images for better accessibility and context.
  - Changed code block syntax from diff to python for improved clarity.
* Add installation section to Pruna documentation
  - Introduced a new installation section in the Pruna documentation to guide users on how to install the framework.
  - Enhanced the overall clarity and usability of the documentation for new users.
* Update pruna.md
* Update pruna.md
* Update Pruna documentation for model optimization and evaluation
  - Changed section titles for consistency and clarity, from "Optimizing models" to "Optimize models" and "Evaluating and benchmarking optimized models" to "Evaluate and benchmark models".
  - Enhanced descriptions to clarify the use of `diffusers` models and the evaluation process.
  - Added a new example for evaluating standalone `diffusers` models.
  - Updated references and links for better navigation within the documentation.
* Refactor Pruna documentation for clarity and consistency
  - Removed outdated references to FLUX-juiced and streamlined the explanation of benchmarking.
  - Enhanced the description of evaluating standalone `diffusers` models.
  - Cleaned up code examples by removing unnecessary imports and comments for better readability.
* Apply suggestions from code review
  Co-authored-by: Steven Liu <[email protected]>
* Enhance Pruna documentation with new examples and clarifications
  - Added an image to illustrate the optimization process.
  - Updated the explanation for sharing and loading optimized models on the Hugging Face Hub.
  - Clarified the evaluation process for optimized models using the EvaluationAgent.
  - Improved descriptions for defining metrics and evaluating standalone diffusers models.

---------

Co-authored-by: Steven Liu <[email protected]>
1 parent 81426b0 commit 9b834f8

File tree

2 files changed: +189 -0 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

```diff
@@ -180,6 +180,8 @@
       title: Caching
     - local: optimization/memory
       title: Reduce memory usage
+    - local: optimization/pruna
+      title: Pruna
     - local: optimization/xformers
       title: xFormers
     - local: optimization/tome
```
docs/source/en/optimization/pruna.md

Lines changed: 187 additions & 0 deletions
# Pruna

[Pruna](https://github.com/PrunaAI/pruna) is a model optimization framework that offers various optimization methods - quantization, pruning, caching, compilation - for accelerating inference and reducing memory usage. A general overview of the optimization methods is shown below.

| Technique | Description | Speed | Memory | Quality |
|--------------|-----------------------------------------------------------------------------------------------|:-----:|:------:|:-------:|
| `batcher` | Groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing processing time. | ✅ | ❌ | ➖ |
| `cacher` | Stores intermediate results of computations to speed up subsequent operations. | ✅ | ➖ | ➖ |
| `compiler` | Optimizes the model with instructions for specific hardware. | ✅ | ➖ | ➖ |
| `distiller` | Trains a smaller, simpler model to mimic a larger, more complex model. | ✅ | ✅ | ❌ |
| `quantizer` | Reduces the precision of weights and activations, lowering memory requirements. | ✅ | ✅ | ❌ |
| `pruner` | Removes less important or redundant connections and neurons, resulting in a sparser, more efficient network. | ✅ | ✅ | ❌ |
| `recoverer` | Restores the performance of a model after compression. | ➖ | ➖ | ✅ |
| `factorizer` | Factorization batches several small matrix multiplications into one large fused operation. | ✅ | ➖ | ➖ |
| `enhancer` | Enhances the model output by applying post-processing algorithms such as denoising or upscaling. | ❌ | - | ✅ |

✅ (improves), ➖ (approx. the same), ❌ (worsens)

Explore the full range of optimization methods in the [Pruna documentation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms).
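
Each technique in the table corresponds to a key on a `SmashConfig`. As a minimal sketch of how an algorithm is activated (reusing the `fora` cacher from the full example later on this page):

```python
from pruna import SmashConfig

# each table row maps to a SmashConfig key whose value selects a
# concrete algorithm; here the "fora" cacher used later in this doc
smash_config = SmashConfig()
smash_config["cacher"] = "fora"
```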

## Installation

Install Pruna with the following command.

```bash
pip install pruna
```

## Optimize Diffusers models

A broad range of optimization algorithms is supported for Diffusers models, as shown below.

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/diffusers_combinations.png" alt="Overview of the supported optimization algorithms for diffusers models">
</div>

The example below optimizes [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with a combination of factorizer, compiler, and cacher algorithms. This combination accelerates inference by up to 4.2x and cuts peak GPU memory usage from 34.7GB to 28.0GB, all while maintaining virtually the same output quality.

> [!TIP]
> Refer to the [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html) docs to learn more about the optimization techniques used in this example.

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_combination.png" alt="Optimization techniques used for FLUX.1-dev showing the combination of factorizer, compiler, and cacher algorithms">
</div>

Start by defining a `SmashConfig` with the optimization algorithms to use. To optimize the model, pass the pipeline and the `SmashConfig` to `smash`, then use the pipeline as usual for inference.

```python
import torch
from diffusers import FluxPipeline

from pruna import PrunaModel, SmashConfig, smash

# load the model
# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell if GPU memory is limited
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
).to("cuda")

# define the configuration
smash_config = SmashConfig()
smash_config["factorizer"] = "qkv_diffusers"
smash_config["compiler"] = "torch_compile"
smash_config["torch_compile_target"] = "module_list"
smash_config["cacher"] = "fora"
smash_config["fora_interval"] = 2

# for the best results in terms of speed you can add these configs
# however they will increase your warmup time from 1.5 min to 10 min
# smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
# smash_config["quantizer"] = "torchao"
# smash_config["torchao_quant_type"] = "fp8dq"
# smash_config["torchao_excluded_modules"] = "norm+embedding"

# optimize the model
smashed_pipe = smash(pipe, smash_config)

# run the model
smashed_pipe("a knitted purple prune").images[0]
```

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_smashed_comparison.png" alt="Comparison of outputs from the original and smashed FLUX.1-dev pipelines">
</div>

After optimization, you can share the optimized model on the Hugging Face Hub and load it again with `PrunaModel.from_hub`.

```python
# save the model
smashed_pipe.save_to_hub("<username>/FLUX.1-dev-smashed")

# load the model
smashed_pipe = PrunaModel.from_hub("<username>/FLUX.1-dev-smashed")
```
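
Once reloaded, the `PrunaModel` behaves like the wrapped Diffusers pipeline, so inference works the same way. A minimal usage sketch (the repository name is a placeholder):

```python
from pruna import PrunaModel

# the reloaded model is called like the original pipeline
smashed_pipe = PrunaModel.from_hub("<username>/FLUX.1-dev-smashed")
image = smashed_pipe("a knitted purple prune").images[0]
image.save("prune.png")
```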
## Evaluate and benchmark Diffusers models

Pruna provides the [EvaluationAgent](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html) to evaluate the quality of your optimized models.

Define the metrics you care about, such as total time and throughput, and the dataset to evaluate on, then pass the model to the `EvaluationAgent`.

<hfoptions id="eval">
<hfoption id="optimized model">

Load the optimized model, define a `Task` from the metrics and dataset, and pass the `Task` to the `EvaluationAgent` to run the evaluation.

```python
import torch

from pruna import PrunaModel
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import (
    ThroughputMetric,
    TorchMetricWrapper,
    TotalTimeMetric,
)
from pruna.evaluation.task import Task

# define the device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# load the model
# Try PrunaAI/Segmind-Vega-smashed if GPU memory is limited
smashed_pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed")

# define the metrics
metrics = [
    TotalTimeMetric(n_iterations=20, n_warmup_iterations=5),
    ThroughputMetric(n_iterations=20, n_warmup_iterations=5),
    TorchMetricWrapper("clip"),
]

# define the datamodule
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(10)

# define the task and evaluation agent
task = Task(metrics, datamodule=datamodule, device=device)
eval_agent = EvaluationAgent(task)

# evaluate the smashed model and offload it to CPU
smashed_pipe.move_to_device(device)
smashed_pipe_results = eval_agent.evaluate(smashed_pipe)
smashed_pipe.move_to_device("cpu")
```

</hfoption>
<hfoption id="standalone model">

Instead of comparing the optimized model to the base model, you can also evaluate a standalone `diffusers` model. This is useful for measuring the model's performance without any optimization. Wrap the pipeline in `PrunaModel` and run the `EvaluationAgent` on it, as shown below.

```python
import torch
from diffusers import FluxPipeline

from pruna import PrunaModel

# load the model
# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell if GPU memory is limited
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
).to("cpu")
wrapped_pipe = PrunaModel(model=pipe)
```
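
The wrapped pipeline can then be evaluated exactly like the optimized one. A minimal sketch, assuming the `metrics`, `datamodule`, and `device` defined in the optimized-model example:

```python
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task

# reuse the metrics, datamodule, and device from the optimized-model example
task = Task(metrics, datamodule=datamodule, device=device)
eval_agent = EvaluationAgent(task)

# evaluate the base model, then offload it to CPU
wrapped_pipe.move_to_device(device)
base_pipe_results = eval_agent.evaluate(wrapped_pipe)
wrapped_pipe.move_to_device("cpu")
```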

</hfoption>
</hfoptions>

Now that you have seen how to optimize and evaluate models, you can start using Pruna on your own models. There are many examples to help you get started.

> [!TIP]
> For more details about benchmarking Flux, check out the [Announcing FLUX-Juiced: The Fastest Image Generation Endpoint (2.6 times faster)!](https://huggingface.co/blog/PrunaAI/flux-fastest-image-generation-endpoint) blog post and the [InferBench](https://huggingface.co/spaces/PrunaAI/InferBench) Space.

## Reference

- [Pruna](https://github.com/pruna-ai/pruna)
- [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms)
- [Pruna evaluation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)
- [Pruna tutorials](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html)
