Commit 310fdaf

Introduce cache-dit to community optimization (#12366)
* docs: introduce cache-dit to diffusers
* misc: update examples link
* Refine documentation for CacheDiT features: updated the wording for clarity and consistency in the documentation; adjusted sections on cache acceleration, automatic block adapter, patch functor, and hybrid cache configuration.
1 parent dcb6dd9 commit 310fdaf

File tree

2 files changed: +272 -0 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

```diff
@@ -82,6 +82,8 @@
     title: Token merging
   - local: optimization/deepcache
     title: DeepCache
+  - local: optimization/cache_dit
+    title: CacheDiT
   - local: optimization/tgate
     title: TGATE
   - local: optimization/xdit
```
docs/source/en/optimization/cache_dit.md

Lines changed: 270 additions & 0 deletions
# CacheDiT

CacheDiT is a unified, flexible, and training-free cache acceleration framework that supports nearly all of Diffusers' DiT-based pipelines. It provides a unified cache API with support for automatic block adapters, DBCache, and more.

To learn more, refer to the [CacheDiT](https://github.com/vipshop/cache-dit) repository.

Install a stable release of CacheDiT from PyPI, or install the latest version from GitHub.
<hfoptions id="install">
<hfoption id="PyPI">

```bash
pip3 install -U cache-dit
```

</hfoption>
<hfoption id="source">

```bash
pip3 install git+https://github.com/vipshop/cache-dit.git
```

</hfoption>
</hfoptions>
Run the command below to view the supported DiT pipelines.

```python
>>> import cache_dit
>>> cache_dit.supported_pipelines()
(30, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',
'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',
'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',
'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*'])
```

For a complete benchmark, refer to [Benchmarks](https://github.com/vipshop/cache-dit/blob/main/bench/).

## Unified Cache API

CacheDiT works by matching specific input/output patterns, as shown below.

![](https://github.com/vipshop/cache-dit/raw/main/assets/patterns-v1.png)

Call the `enable_cache()` function on a pipeline to enable cache acceleration. This function is the entry point to many of CacheDiT's features.

```python
import cache_dit
from diffusers import DiffusionPipeline

# Can be any diffusion pipeline
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

# One-line code with default cache options.
cache_dit.enable_cache(pipe)

# Just call the pipe as normal.
output = pipe(...)

# Disable cache and run the original pipe.
cache_dit.disable_cache(pipe)
```
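
For reference, here is a minimal end-to-end sketch of the workflow above. The prompt, dtype, and step count are illustrative assumptions rather than values from the CacheDiT docs:

```python
import torch
import cache_dit
from diffusers import DiffusionPipeline

# Load a supported pipeline (Qwen-Image here) and enable caching.
pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
).to("cuda")
cache_dit.enable_cache(pipe)

# Run inference as usual; cached steps skip most transformer blocks.
image = pipe("a cup of coffee on a wooden table", num_inference_steps=50).images[0]
image.save("qwen_image_cached.png")

# Inspect the cache statistics after inference (see `summary()` below).
stats = cache_dit.summary(pipe)
```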

## Automatic Block Adapter

For custom or modified pipelines or transformers that are not included in Diffusers, use `BlockAdapter` in `auto` mode or configure it manually. Check the [BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#automatic-block-adapter) docs for more details, and refer to [Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.

```python
import cache_dit
from cache_dit import ForwardPattern, BlockAdapter

# Use 🔥BlockAdapter with `auto` mode.
cache_dit.enable_cache(
    BlockAdapter(
        # Any DiffusionPipeline, Qwen-Image, etc.
        pipe=pipe,
        auto=True,
        # Check the `📚Forward Pattern Matching` documentation and inspect the
        # code of Qwen-Image; you will find that it satisfies `FORWARD_PATTERN_1`.
        forward_pattern=ForwardPattern.Pattern_1,
    ),
)

# Or, manually set up the transformer configuration.
cache_dit.enable_cache(
    BlockAdapter(
        pipe=pipe,  # Qwen-Image, etc.
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_1,
    ),
)
```

Sometimes a transformer class contains more than one set of transformer blocks. For example, FLUX.1 (as well as HiDream, Chroma, etc.) contains `transformer_blocks` and `single_transformer_blocks`, which have different forward patterns. `BlockAdapter` can detect this hybrid pattern type as well. Refer to [FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.

```python
# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
# single_transformer_blocks have different forward patterns.
cache_dit.enable_cache(
    BlockAdapter(
        pipe=pipe,  # FLUX.1, etc.
        transformer=pipe.transformer,
        blocks=[
            pipe.transformer.transformer_blocks,
            pipe.transformer.single_transformer_blocks,
        ],
        forward_pattern=[
            ForwardPattern.Pattern_1,
            ForwardPattern.Pattern_3,
        ],
    ),
)
```

This also works when there is more than one transformer (namely, `transformer` and `transformer_2`) in a pipeline's structure, as in the sketch below. Refer to [Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.
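
The following is a minimal, hypothetical sketch of that setup. It assumes, by extension of the multi-block FLUX.1 example above, that `BlockAdapter` accepts matching lists of transformers, blocks, and forward patterns; the attribute names and pattern choices below are illustrative, so check the linked Wan 2.2 example for the exact configuration.

```python
# Hypothetical two-transformer (MoE-style) configuration, e.g. Wan 2.2,
# where `transformer` and `transformer_2` handle different denoising stages.
cache_dit.enable_cache(
    BlockAdapter(
        pipe=pipe,  # Wan 2.2 MoE, etc.
        transformer=[
            pipe.transformer,
            pipe.transformer_2,
        ],
        blocks=[
            pipe.transformer.blocks,    # assumed block attribute name
            pipe.transformer_2.blocks,  # assumed block attribute name
        ],
        forward_pattern=[
            ForwardPattern.Pattern_2,   # illustrative pattern choice
            ForwardPattern.Pattern_2,   # illustrative pattern choice
        ],
    ),
)
```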

## Patch Functor

For any pattern not yet included in CacheDiT, use a Patch Functor to convert the pattern into a known one. You need to subclass the Patch Functor, and you may also need to fuse the operations inside the transformer's block for-loop into the block's `forward` method. After implementing a Patch Functor, set the `patch_functor` property of `BlockAdapter`.

![](https://github.com/vipshop/cache-dit/raw/main/assets/patch-functor.png)
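
The exact base-class interface is defined in the CacheDiT repository; the skeleton below is only a hypothetical illustration of the idea, assuming a `PatchFunctor` base class whose `apply()` method receives a transformer and returns it with its blocks rewritten to follow a known forward pattern. Refer to the provided functors linked below for the real interface.

```python
# Hypothetical skeleton; the import path, method name, and signature are
# assumptions, not the confirmed CacheDiT API.
from cache_dit.cache_factory.patch_functors import PatchFunctor


class MyPatchFunctor(PatchFunctor):
    def apply(self, transformer, **kwargs):
        # Rewrite each block's forward (and/or fuse the surrounding for-loop
        # logic into it) so its inputs/outputs match a known forward pattern.
        for block in transformer.blocks:
            original_forward = block.forward

            def patched_forward(*args, _orig=original_forward, **fkwargs):
                hidden_states, encoder_hidden_states = _orig(*args, **fkwargs)
                # Reorder/unpack the outputs into the target pattern here.
                return hidden_states, encoder_hidden_states

            block.forward = patched_forward
        return transformer
```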

Some Patch Functors are already provided in CacheDiT, such as [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py) and [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py).

```python
from cache_dit import ForwardPattern, BlockAdapter

# How CacheDiT registers the built-in HiDream adapter (excerpt).
@BlockAdapterRegistry.register("HiDream")
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
    from diffusers import HiDreamImageTransformer2DModel
    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor

    assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=[
            pipe.transformer.double_stream_blocks,
            pipe.transformer.single_stream_blocks,
        ],
        forward_pattern=[
            ForwardPattern.Pattern_0,
            ForwardPattern.Pattern_3,
        ],
        # NOTE: Set up your custom patch functor here.
        patch_functor=HiDreamPatchFunctor(),
        **kwargs,
    )
```

Finally, call the `cache_dit.summary()` function on a pipeline after it has completed inference to get the cache acceleration details.

```python
stats = cache_dit.summary(pipe)
```

```
⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline

| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
| 23          | 0.045     | 0.084     | 0.114     | 0.147     | 0.241     | 0.297     |
```

## DBCache: Dual Block Cache

![](https://github.com/vipshop/cache-dit/raw/main/assets/dbcache-v1.png)

DBCache (Dual Block Caching) supports different configurations of compute blocks (F8B12, etc.) to enable a balanced trade-off between performance and precision.

- `Fn_compute_blocks`: DBCache uses the **first n** transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
- `Bn_compute_blocks`: DBCache further fuses approximate information in the **last n** transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for the approximate hidden states that use the residual cache.

```python
import torch
import cache_dit
from diffusers import FluxPipeline

pipe_or_adapter = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Default options, F8B0, 8 warmup steps, and unlimited cached
# steps for a good balance between performance and precision
cache_dit.enable_cache(pipe_or_adapter)

# Custom options, F8B8, higher precision
from cache_dit import BasicCacheConfig

cache_dit.enable_cache(
    pipe_or_adapter,
    cache_config=BasicCacheConfig(
        max_warmup_steps=8,  # steps that are not cached
        max_cached_steps=-1,  # -1 means no limit
        Fn_compute_blocks=8,  # Fn, F8, etc.
        Bn_compute_blocks=8,  # Bn, B8, etc.
        residual_diff_threshold=0.12,
    ),
)
```

Check the [DBCache](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) and [User Guide](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#dbcache) docs for more design details.

## TaylorSeer Calibrator

The [TaylorSeers](https://huggingface.co/papers/2503.06923) algorithm further improves the precision of DBCache when the number of cached steps is large (hybrid TaylorSeer + DBCache). At timesteps separated by large intervals, the feature similarity in diffusion models decreases substantially, which harms the generation quality.

TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predicts features at future timesteps with a Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden-states and residual cache types; F_pred can be a residual cache or a hidden-state cache.
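
Conceptually, at first order (a sketch of the idea; see the TaylorSeers paper for the exact formulation and higher-order terms), the calibrator predicts a future feature from the finite difference of previously computed features:

$$
F_{\text{pred}}(t+k) \approx F(t) + k \, \Delta F(t), \qquad \Delta F(t) = \frac{F(t) - F(t-k)}{k}
$$

Increasing `taylorseer_order` adds higher-order difference terms to this expansion.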

```python
from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig

cache_dit.enable_cache(
    pipe_or_adapter,
    # Basic DBCache w/ FnBn configurations
    cache_config=BasicCacheConfig(
        max_warmup_steps=8,  # steps that are not cached
        max_cached_steps=-1,  # -1 means no limit
        Fn_compute_blocks=8,  # Fn, F8, etc.
        Bn_compute_blocks=8,  # Bn, B8, etc.
        residual_diff_threshold=0.12,
    ),
    # Then, use the TaylorSeer calibrator to approximate
    # the values at cached steps; taylorseer_order defaults to 1.
    calibrator_config=TaylorSeerCalibratorConfig(
        taylorseer_order=1,
    ),
)
```

> [!TIP]
> The `Bn_compute_blocks` parameter of DBCache can be set to `0` if you use TaylorSeer as the calibrator for approximate hidden states. DBCache's `Bn_compute_blocks` also acts as a calibrator, so you can choose either `Bn_compute_blocks > 0` or TaylorSeer. We recommend the TaylorSeer + DBCache FnB0 configuration scheme, as sketched below.
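
Building on the options shown above, the recommended scheme could look like this (F8B0 here is illustrative; tune `Fn_compute_blocks` and the threshold for your model):

```python
from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig

# DBCache F8B0 + TaylorSeer: let TaylorSeer do the calibration
# instead of the last-n compute blocks.
cache_dit.enable_cache(
    pipe_or_adapter,
    cache_config=BasicCacheConfig(
        max_warmup_steps=8,
        max_cached_steps=-1,
        Fn_compute_blocks=8,  # F8
        Bn_compute_blocks=0,  # B0: no last-n calibration blocks
        residual_diff_threshold=0.12,
    ),
    calibrator_config=TaylorSeerCalibratorConfig(
        taylorseer_order=1,
    ),
)
```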

## Hybrid Cache CFG

CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or that do not include CFG in the forward step, leave the `enable_separate_cfg` parameter at `False` (the default, `None`). Otherwise, set it to `True`.

```python
from cache_dit import BasicCacheConfig

cache_dit.enable_cache(
    pipe_or_adapter,
    cache_config=BasicCacheConfig(
        ...,
        # For example, set it to True for Wan 2.1 and Qwen-Image,
        # and to False for FLUX.1, HunyuanVideo, etc.
        enable_separate_cfg=True,
    ),
)
```

## torch.compile

CacheDiT is designed to work with torch.compile for even better performance. Call `torch.compile` after enabling the cache.

```python
import torch

cache_dit.enable_cache(pipe)

# Compile the transformer module
pipe.transformer = torch.compile(pipe.transformer)
```

If you're using CacheDiT with dynamic input shapes, consider increasing the `recompile_limit` of `torch._dynamo`. Otherwise, the `recompile_limit` error may be triggered, causing the module to fall back to eager mode.

```python
torch._dynamo.config.recompile_limit = 96  # default is 8
torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256
```

Please check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.
