Skip to content

Commit 2b56ba2

Browse files
committed
Use text_encoder from cache
1 parent e79d0b7 commit 2b56ba2

File tree

7 files changed

+60
-7
lines changed

7 files changed

+60
-7
lines changed

data/repo/diffusion_scripts/sd_cuda_safe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
PipeDevice = Device(opt.device, fptype)
2222

23-
pipe = PipeDevice.GetPipe(opt.mdlpath, opt.mode, opt.nsfw)
23+
pipe = PipeDevice.GetPipe(opt.mdlpath, opt.mode, opt.nsfw, "")
2424
pipe.to(PipeDevice.device, fptype)
2525

2626
if opt.dlora:

data/repo/diffusion_scripts/sd_onnx_safe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
pipe = None
1414
PipeDevice = Device("onnx", torch.float32)
1515

16-
pipe = PipeDevice.GetPipe(opt.mdlpath, opt.mode, opt.nsfw)
16+
pipe = PipeDevice.GetPipe(opt.mdlpath, opt.mode, opt.nsfw, opt.textencoder)
1717
safe_unet = pipe.unet
1818

1919
print("SD: Model preload: done")
@@ -43,7 +43,7 @@
4343
tokenizer_extract = False
4444

4545
if (data['LoRA'] != old_lora_json) or (data['TI'] != old_te_json) or (data['TINeg'] != old_ten_json):
46-
onnx_te_model = onnx.load(opt.mdlpath + "/text_encoder/" + ONNX_MODEL)
46+
onnx_te_model = onnx.load(opt.textencoder + ONNX_MODEL)
4747

4848
# Hard reload
4949
old_lora_json = None

data/repo/diffusion_scripts/sd_xbackend.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def LPW_Path(self):
6969

7070
return dir_path
7171

72-
def GetPipe(self, Model: str, Mode: str, NSFW: bool):
72+
def GetPipe(self, Model: str, Mode: str, NSFW: bool, TextEnc: str):
7373
pipe = None
7474
nsfw_pipe = None
7575

@@ -87,11 +87,13 @@ def GetPipe(self, Model: str, Mode: str, NSFW: bool):
8787
nsfw_pipe = OnnxRuntimeModel.from_pretrained(safety_model, provider=self.prov)
8888
print (Mode)
8989
if Mode == "txt2img":
90-
pipe = OnnxStableDiffusionPipeline.from_pretrained(Model, custom_pipeline=self.LPW_Path(), provider=self.prov, safety_checker=nsfw_pipe)
90+
pipe = OnnxStableDiffusionPipeline.from_pretrained(Model, custom_pipeline=self.LPW_Path(), provider=self.prov, safety_checker=nsfw_pipe, text_encoder=None)
9191
if Mode == "img2img":
92-
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(Model, custom_pipeline=self.LPW_Path(), provider=self.prov, safety_checker=nsfw_pipe)
92+
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(Model, custom_pipeline=self.LPW_Path(), provider=self.prov, safety_checker=nsfw_pipe, text_encoder=None)
9393
if Mode == "inpaint":
94-
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(Model, custom_pipeline=self.LPW_Path(), provider=self.prov, safety_checker=nsfw_pipe)
94+
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(Model, custom_pipeline=self.LPW_Path(), provider=self.prov, safety_checker=nsfw_pipe, text_encoder=None)
95+
96+
pipe.text_encoder = OnnxRuntimeModel.from_pretrained(TextEnc, provider=self.prov)
9597
else:
9698
if Mode == "pix2pix":
9799
if NSFW:
@@ -399,6 +401,9 @@ def ApplyArg(parser):
399401
parser.add_argument(
400402
"--model", type=str, help="Path to model checkpoint (.ckpt or .safetensors)", dest='mdlpath',
401403
)
404+
parser.add_argument(
405+
"--textencoder", type=str, help="Path to model checkpoint (.onnx)", dest='textencoder',
406+
)
402407
parser.add_argument(
403408
"--workdir", default=None, type=str, help="Path to working directory", dest='workdir',
404409
)

ui-src/MainWindow.xaml.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,13 @@ private void Button_ClickBreak(object sender, RoutedEventArgs e)
223223
if (!Directory.Exists(WorkingPath))
224224
{
225225
if (GlobalVariables.Mode == Helper.ImplementMode.DiffCUDA)
226+
{
226227
Install.CheckAndInstallCUDA();
228+
}
227229
else
230+
{
228231
Install.CheckAndInstallONNX();
232+
}
229233

230234
return;
231235
}
@@ -264,6 +268,7 @@ private void btnONNX_Click(object sender, RoutedEventArgs e)
264268

265269
GlobalVariables.Mode = Helper.ImplementMode.ONNX;
266270
Install.CheckAndInstallONNX();
271+
Task.Run(() => TextEncoderONNX.CheckEncoder());
267272

268273
Brush Safe = new SolidColorBrush(Colors.Black);
269274

ui-src/Utils/ModelCMD.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
using Newtonsoft.Json;
2+
using SD_FXUI.Utils.Models;
23
using System;
34
using System.Collections.Generic;
45
using System.Linq;
56
using System.Text;
7+
using System.Text.Encodings.Web;
68
using System.Threading.Tasks;
79
using System.Windows.Xps.Serialization;
810

@@ -73,6 +75,7 @@ public void PreStart(string StartModel, string StartMode, bool StartNSFW, bool I
7375
{
7476
CmdLine += " --nsfw=True ";
7577
}
78+
CmdLine += $" --textencoder={TextEncoderONNX.GetModelDir()} ";
7679

7780
Process = new Host(FS.GetWorkingDir(), "repo/" + PythonEnv.GetPy(Helper.VENV.DiffONNX));
7881
Process.Start("./repo/diffusion_scripts/sd_onnx_safe.py " + CmdLine);
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using System.IO;
2+
3+
namespace SD_FXUI.Utils.Models
4+
{
5+
internal class TextEncoderONNX
6+
{
7+
static public void CheckEncoder()
8+
{
9+
if (GlobalVariables.Mode == Helper.ImplementMode.ONNX)
10+
{
11+
if (!File.Exists(GetModel()))
12+
{
13+
WGetDownloadModels.DownloadTextEncoder();
14+
}
15+
}
16+
}
17+
18+
static public string GetModel()
19+
{
20+
return FS.GetModelDir() + "text_encoder/model.onnx";
21+
}
22+
static public string GetModelDir()
23+
{
24+
return FS.GetModelDir() + "text_encoder/";
25+
}
26+
27+
}
28+
}

ui-src/Utils/Models/WGetDownloadModels.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ public static void DownloadSDPoser()
5050
FileDownloader.DownloadFileAsync("https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/raw/main/config.json", WorkingDir + @"config.json");
5151
FileDownloader.DownloadFileAsync("https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/diffusion_pytorch_model.fp16.bin", WorkingDir + @"diffusion_pytorch_model.bin");
5252
}
53+
public static void DownloadTextEncoder()
54+
{
55+
string WorkingDir = FS.GetModelDir() + "text_encoder/";
56+
57+
if (Directory.Exists(WorkingDir))
58+
{
59+
return;
60+
}
61+
Directory.CreateDirectory(WorkingDir);
62+
63+
FileDownloader.DownloadFileAsync("https://huggingface.co/ForserX/TextEncoderBackupONNX/resolve/main/model.onnx", WorkingDir + @"model.onnx");
64+
}
5365

5466
public static void DownloadSDFacegen()
5567
{

0 commit comments

Comments
 (0)