Skip to content

Commit ce31a77

Browse files
authored
[PPDiffusers] Fix ppdiffusers bug and support ZH stablediffusion (#3663)
* fix win download bug, etc * add attention mask support zh model * pad to max length * update NEG_NF * update diffusers readme * fix windows download * fix windows download
1 parent 58a5a9d commit ce31a77

30 files changed

+314
-440
lines changed

paddlenlp/transformers/clip/modeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def quick_gelu(x):
533533

534534
F.quick_gelu = quick_gelu
535535

536-
NEG_INF = float("-inf") # -1e4 -1e9
536+
NEG_INF = -1e9 # float("-inf") -1e4 -1e9
537537

538538

539539
class VisionTransformer(nn.Layer):

ppdiffusers/README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
## 1. News 📢
77

8+
* 🔥 **2022.11.04 支持 IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1 和 IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-EN-v0.1 中文权重**
89
* 🔥 **2022.10.27 发布 PPDiffusers仓库**
910

1011

@@ -39,7 +40,7 @@ python setup.py install
3940

4041
## 4. 使用PPDiffusers快速体验Stable Diffusion模型!
4142

42-
Stable Diffusion 是一个**文本到图像(text-to-image)****潜在扩散模型(latent diffusion model, ldm)**, 该模型是由来自[CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [LAION](https://laion.ai/) 的工程师以及 [RunwayML](https://runwayml.com/)一起开发而完成的。该模型使用了大小为**512x512**[LAION-5B](https://laion.ai/blog/laion-5b/)数据集子集进行训练。该模型使用了Openai开源的**CLIP ViT-L/14** 文本编码器(text_encoder)来编码提示(prompt)文本,从而作为引导条件(注意该部分权重不进行训练)。该模型使用了Unet模型(860M参数)和text encoder(123M参数),并且可以在具有4GB显存(注:当前paddle版本需要进行优化,无法在4GB的显卡上运行)的GPU进行推理预测
43+
Stable Diffusion 是一个**文本到图像(text-to-image)****潜在扩散模型(latent diffusion model, ldm)**, 该模型是由来自[CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [LAION](https://laion.ai/) 的工程师以及 [RunwayML](https://runwayml.com/)一起开发而完成的。该模型使用了大小为**512x512**[LAION-5B](https://laion.ai/blog/laion-5b/)数据集子集进行训练。该模型使用了Openai开源的**CLIP ViT-L/14** 文本编码器(text_encoder)来编码提示(prompt)文本,从而作为引导条件(注意该部分权重不进行训练)。该模型使用了Unet模型(860M参数)和text encoder(123M参数),并且可以在具有4GB显存的GPU上进行推理预测
4344

4445
___注意___:
4546
___为了方便国内用户下载使用及快速体验Stable Diffusion模型,我们在百度云(BOS)上提供了paddle版本的镜像权重。注意:为了使用该模型与权重,你必须接受该模型所要求的**License**,请访问huggingface的[model card](https://huggingface.co/runwayml/stable-diffusion-v1-5), 仔细阅读里面的**License**,然后签署该协议。___
@@ -61,8 +62,7 @@ image = pipe(prompt).images[0]
6162

6263
image.save("astronaut_rides_horse.png")
6364
```
64-
<center><image src="https://user-images.githubusercontent.com/50394665/197779466-04543823-8b83-41d6-94e8-146a7dac00d7.png" width="600"></center>
65-
65+
<img width="600" alt="image" src="https://user-images.githubusercontent.com/50394665/197779466-04543823-8b83-41d6-94e8-146a7dac00d7.png">
6666

6767
### 4.2 使用Stable Diffusion进行由文本引导的图片-图片的生成
6868

@@ -74,10 +74,10 @@ from io import BytesIO
7474

7575
from ppdiffusers import StableDiffusionImg2ImgPipeline
7676

77-
# load the pipeline
77+
# 加载pipeline
7878
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
7979

80-
# let's download an initial image
80+
# 下载初始图片
8181
url = "https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/sketch-mountains-input.png"
8282

8383
response = requests.get(url)
@@ -92,7 +92,7 @@ with paddle.amp.auto_cast(True):
9292
image.save("fantasy_landscape.png")
9393
```
9494

95-
<center><image src="https://user-images.githubusercontent.com/50394665/197780044-34e6f8ca-6864-4c3d-bb99-28e0aadf867b.png" width="600"></center>
95+
<img width="600" alt="image" src="https://user-images.githubusercontent.com/50394665/197780044-34e6f8ca-6864-4c3d-bb99-28e0aadf867b.png">
9696

9797

9898
### 4.3 使用Stable Diffusion根据文本补全图片
@@ -125,7 +125,7 @@ with paddle.amp.auto_cast(True):
125125

126126
image.save("cat_on_bench.png")
127127
```
128-
<center><image src="https://user-images.githubusercontent.com/50394665/197783711-ab3caf2e-5a4d-4099-8d01-d6ca80ca8e78.png" width="600"></center>
128+
<img width="600" alt="image" src="https://user-images.githubusercontent.com/50394665/197783711-ab3caf2e-5a4d-4099-8d01-d6ca80ca8e78.png">
129129

130130
Tips: 下面的使用方法是新版本的代码,也是官方推荐的代码,注意必须配合**runwayml/stable-diffusion-inpainting**才可正常使用。
131131
```python
@@ -153,7 +153,7 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
153153

154154
image.save("cat_on_bench_new.png")
155155
```
156-
<center><image src="https://user-images.githubusercontent.com/50394665/198016801-87cec13b-0d89-41c3-aedb-c89a43d76153.png" width="600"></center>
156+
<img width="600" alt="image" src="https://user-images.githubusercontent.com/50394665/198016801-87cec13b-0d89-41c3-aedb-c89a43d76153.png">
157157

158158
## 5. Credits
159159

ppdiffusers/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.6.0.dev1
1+
0.6.1

ppdiffusers/examples/community/clip_guided_stable_diffusion.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,10 @@ def __call__(
262262
"The following part of your input was truncated because CLIP can only handle sequences up to"
263263
f" {self.tokenizer.model_max_length} tokens: {removed_text}")
264264
text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
265-
text_embeddings = self.text_encoder(text_input_ids)[0]
265+
266+
attention_mask = paddle.ones_like(text_input_ids)
267+
text_embeddings = self.text_encoder(text_input_ids,
268+
attention_mask=attention_mask)[0]
266269

267270
# duplicate text embeddings for each generation per prompt
268271
bs_embed, seq_len, _ = text_embeddings.shape
@@ -323,7 +326,9 @@ def __call__(
323326
truncation=True,
324327
return_tensors="pd",
325328
)
326-
uncond_embeddings = self.text_encoder(uncond_input.input_ids)[0]
329+
attention_mask = paddle.ones_like(uncond_input.input_ids)
330+
uncond_embeddings = self.text_encoder(
331+
uncond_input.input_ids, attention_mask=attention_mask)[0]
327332

328333
# duplicate unconditional embeddings for each generation per prompt
329334
seq_len = uncond_embeddings.shape[1]

ppdiffusers/examples/community/composable_stable_diffusion.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(
8989
logger.warn(
9090
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
9191
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
92-
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
92+
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
9393
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
9494
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
9595
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
@@ -253,7 +253,9 @@ def __call__(
253253
"The following part of your input was truncated because CLIP can only handle sequences up to"
254254
f" {self.tokenizer.model_max_length} tokens: {removed_text}")
255255
text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
256-
text_embeddings = self.text_encoder(text_input_ids)[0]
256+
attention_mask = paddle.ones_like(text_input_ids)
257+
text_embeddings = self.text_encoder(text_input_ids,
258+
attention_mask=attention_mask)[0]
257259

258260
# duplicate text embeddings for each generation per prompt, using mps friendly method
259261
# bs_embed, seq_len, _ = text_embeddings.shape
@@ -318,7 +320,9 @@ def __call__(
318320
truncation=True,
319321
return_tensors="pd",
320322
)
321-
uncond_embeddings = self.text_encoder(uncond_input.input_ids)[0]
323+
attention_mask = paddle.ones_like(uncond_input.input_ids)
324+
uncond_embeddings = self.text_encoder(
325+
uncond_input.input_ids, attention_mask=attention_mask)[0]
322326

323327
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
324328
# seq_len = uncond_embeddings.shape[1]

ppdiffusers/examples/community/interpolate_stable_diffusion.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __init__(
118118
logger.warn(
119119
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
120120
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
121-
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
121+
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
122122
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
123123
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
124124
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
@@ -277,7 +277,9 @@ def __call__(
277277
)
278278
text_input_ids = text_input_ids[:, :self.tokenizer.
279279
model_max_length]
280-
text_embeddings = self.text_encoder(text_input_ids)[0]
280+
attention_mask = paddle.ones_like(text_input_ids)
281+
text_embeddings = self.text_encoder(
282+
text_input_ids, attention_mask=attention_mask)[0]
281283
else:
282284
batch_size = text_embeddings.shape[0]
283285

@@ -318,7 +320,9 @@ def __call__(
318320
truncation=True,
319321
return_tensors="pd",
320322
)
321-
uncond_embeddings = self.text_encoder(uncond_input.input_ids)[0]
323+
attention_mask = paddle.ones_like(uncond_input.input_ids)
324+
uncond_embeddings = self.text_encoder(
325+
uncond_input.input_ids, attention_mask=attention_mask)[0]
322326

323327
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
324328
seq_len = uncond_embeddings.shape[1]

ppdiffusers/examples/community/lpw_stable_diffusion.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959

6060

6161
def parse_prompt_attention(text):
62-
"""
62+
r"""
6363
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
6464
Accepted tokens are:
6565
(abc) - increases attention to abc by a multiplier of 1.1
@@ -186,6 +186,7 @@ def pad_tokens_and_weights(tokens,
186186
max_length,
187187
bos,
188188
eos,
189+
pad,
189190
no_boseos_middle=True,
190191
chunk_length=77):
191192
r"""
@@ -194,8 +195,9 @@ def pad_tokens_and_weights(tokens,
194195
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
195196
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
196197
for i in range(len(tokens)):
197-
tokens[i] = [bos
198-
] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
198+
tokens[i] = [bos] + tokens[i] + [
199+
eos
200+
] + [pad] * (max_length - 2 - len(tokens[i]))
199201
if no_boseos_middle:
200202
weights[i] = [
201203
1.0
@@ -238,7 +240,9 @@ def get_unweighted_text_embeddings(
238240
# cover the head and the tail by the starting and the ending tokens
239241
text_input_chunk[:, 0] = text_input[0, 0]
240242
text_input_chunk[:, -1] = text_input[0, -1]
241-
text_embedding = pipe.text_encoder(text_input_chunk)[0]
243+
attention_mask = paddle.ones_like(text_input_chunk)
244+
text_embedding = pipe.text_encoder(text_input_chunk,
245+
attention_mask=attention_mask)[0]
242246

243247
if no_boseos_middle:
244248
if i == 0:
@@ -254,7 +258,9 @@ def get_unweighted_text_embeddings(
254258
text_embeddings.append(text_embedding)
255259
text_embeddings = paddle.concat(text_embeddings, axis=1)
256260
else:
257-
text_embeddings = pipe.text_encoder(text_input)[0]
261+
attention_mask = paddle.ones_like(text_input)
262+
text_embeddings = pipe.text_encoder(text_input,
263+
attention_mask=attention_mask)[0]
258264
return text_embeddings
259265

260266

@@ -336,14 +342,17 @@ def get_weighted_text_embeddings(
336342
2) * max_embeddings_multiples + 2
337343

338344
# pad the length of tokens and weights
339-
bos = pipe.tokenizer.bos_token_id
340-
eos = pipe.tokenizer.eos_token_id
345+
# support bert tokenizer
346+
bos = pipe.tokenizer.bos_token_id if pipe.tokenizer.bos_token_id is not None else pipe.tokenizer.cls_token_id
347+
eos = pipe.tokenizer.eos_token_id if pipe.tokenizer.eos_token_id is not None else pipe.tokenizer.sep_token_id
348+
pad = pipe.tokenizer.pad_token_id
341349
prompt_tokens, prompt_weights = pad_tokens_and_weights(
342350
prompt_tokens,
343351
prompt_weights,
344352
max_length,
345353
bos,
346354
eos,
355+
pad,
347356
no_boseos_middle=no_boseos_middle,
348357
chunk_length=pipe.tokenizer.model_max_length,
349358
)
@@ -355,6 +364,7 @@ def get_weighted_text_embeddings(
355364
max_length,
356365
bos,
357366
eos,
367+
pad,
358368
no_boseos_middle=no_boseos_middle,
359369
chunk_length=pipe.tokenizer.model_max_length,
360370
)
@@ -481,7 +491,7 @@ def __init__(
481491
logger.warn(
482492
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
483493
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
484-
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
494+
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
485495
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
486496
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
487497
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
@@ -753,7 +763,8 @@ def __call__(
753763
timesteps = timesteps.tile([
754764
batch_size * num_images_per_prompt,
755765
])
756-
766+
if seed is not None:
767+
paddle.seed(seed)
757768
noise = paddle.randn(
758769
init_latents.shape,
759770
dtype=latents_dtype,
@@ -926,8 +937,8 @@ def text2img(
926937

927938
def img2img(
928939
self,
929-
init_image: Union[paddle.Tensor, PIL.Image.Image],
930940
prompt: Union[str, List[str]],
941+
init_image: Union[paddle.Tensor, PIL.Image.Image],
931942
negative_prompt: Optional[Union[str, List[str]]] = None,
932943
strength: float = 0.8,
933944
num_inference_steps: Optional[int] = 50,
@@ -1016,9 +1027,9 @@ def img2img(
10161027

10171028
def inpaint(
10181029
self,
1030+
prompt: Union[str, List[str]],
10191031
init_image: Union[paddle.Tensor, PIL.Image.Image],
10201032
mask_image: Union[paddle.Tensor, PIL.Image.Image],
1021-
prompt: Union[str, List[str]],
10221033
negative_prompt: Optional[Union[str, List[str]]] = None,
10231034
strength: float = 0.8,
10241035
num_inference_steps: Optional[int] = 50,

ppdiffusers/examples/community/wildcard_stable_diffusion.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def __init__(
153153
logger.warn(
154154
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
155155
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
156-
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
156+
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
157157
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
158158
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
159159
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
@@ -298,7 +298,9 @@ def __call__(
298298
"The following part of your input was truncated because CLIP can only handle sequences up to"
299299
f" {self.tokenizer.model_max_length} tokens: {removed_text}")
300300
text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
301-
text_embeddings = self.text_encoder(text_input_ids)[0]
301+
attention_mask = paddle.ones_like(text_input_ids)
302+
text_embeddings = self.text_encoder(text_input_ids,
303+
attention_mask=attention_mask)[0]
302304

303305
# duplicate text embeddings for each generation per prompt, using mps friendly method
304306
bs_embed, seq_len, _ = text_embeddings.shape
@@ -337,7 +339,9 @@ def __call__(
337339
truncation=True,
338340
return_tensors="pd",
339341
)
340-
uncond_embeddings = self.text_encoder(uncond_input.input_ids)[0]
342+
attention_mask = paddle.ones_like(uncond_input.input_ids)
343+
uncond_embeddings = self.text_encoder(
344+
uncond_input.input_ids, attention_mask=attention_mask)[0]
341345

342346
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
343347
seq_len = uncond_embeddings.shape[1]

0 commit comments

Comments
 (0)