Skip to content

Commit aaa6ea5

Browse files
committed
update app.py
1 parent 5fff05b commit aaa6ea5

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

app.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,12 @@ def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
7979
gc.collect()
8080
torch.cuda.empty_cache()
8181

82-
model_path = f'./models/InfiniteYou/infu_flux_v1.0/{model_version}'
82+
if model_version == 'aes_stage2':
83+
model_path = f'./models/InfiniteYou/infu_flux_v1.0/aes_stage2'
84+
elif model_version == 'sim_stage1':
85+
model_path = f'./models/InfiniteYou/infu_flux_v1.0/sim_stage1'
86+
else:
87+
raise ValueError(f'Model version {model_version} not supported.')
8388
print(f'Loading model from {model_path}')
8489

8590
pipeline = InfUFluxPipeline(
@@ -307,5 +312,6 @@ def generate_examples(id_image, control_image, prompt_text, seed, enable_realism
307312
prepare_pipeline(model_version=ModelVersion.DEFAULT_VERSION, enable_realism=ENABLE_REALISM_DEFAULT, enable_anti_blur=ENABLE_ANTI_BLUR_DEFAULT)
308313

309314
demo.queue()
310-
demo.launch(server_name='0.0.0.0') # IPv4
315+
demo.launch(server_name='localhost') # localhost
316+
# demo.launch(server_name='0.0.0.0') # IPv4
311317
# demo.launch(server_name='[::]') # IPv6

pipelines/pipeline_infu_flux.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def __init__(
194194
ff_mult=4,
195195
)
196196
image_proj_model_path = os.path.join(infu_model_path, 'image_proj_model.bin')
197-
ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu")
197+
ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu", weights_only=True)
198198
image_proj_model.load_state_dict(ipm_state_dict['image_proj'])
199199
del ipm_state_dict
200200
image_proj_model.to('cuda', torch.bfloat16)

0 commit comments

Comments
 (0)