59 changes: 31 additions & 28 deletions app.py
@@ -29,7 +29,7 @@ class ModelVersion:
     STAGE_2 = "aes_stage2"
 
     DEFAULT_VERSION = STAGE_2
-    
+
 ENABLE_ANTI_BLUR_DEFAULT = False
 ENABLE_REALISM_DEFAULT = False
 
@@ -60,13 +60,13 @@ def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
     global pipeline
 
     if (
-        pipeline 
-        and loaded_pipeline_config["enable_realism"] == enable_realism 
+        pipeline
+        and loaded_pipeline_config["enable_realism"] == enable_realism
         and loaded_pipeline_config["enable_anti_blur"] == enable_anti_blur
         and model_version == loaded_pipeline_config["model_version"]
     ):
         return
-    
+
     loaded_pipeline_config["enable_realism"] = enable_realism
     loaded_pipeline_config["enable_anti_blur"] = enable_anti_blur
     loaded_pipeline_config["model_version"] = model_version
@@ -96,15 +96,15 @@ def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
 
 
 def generate_image(
-    input_image, 
-    control_image, 
-    prompt, 
-    seed, 
+    input_image,
+    control_image,
+    prompt,
+    seed,
     width,
     height,
-    guidance_scale, 
-    num_steps, 
-    infusenet_conditioning_scale, 
+    guidance_scale,
+    num_steps,
+    infusenet_conditioning_scale,
     infusenet_guidance_start,
     infusenet_guidance_end,
     enable_realism,
@@ -175,15 +175,15 @@ def generate_examples(id_image, control_image, prompt_text, seed, enable_realism
     4. *[Optional] Adjust advanced hyperparameters or apply optional LoRAs to meet personal needs.* Please refer to **important usage tips** under the Generated Image field.
     5. **Click the "Generate" button to generate an image.** Enjoy!
     """)
-    
+
     with gr.Row():
         with gr.Column(scale=3):
             with gr.Row():
                 ui_id_image = gr.Image(label="Identity Image", type="pil", scale=3, height=370, min_width=100)
-                
+
                 with gr.Column(scale=2, min_width=100):
                     ui_control_image = gr.Image(label="Control Image [Optional]", type="pil", height=370, min_width=100)
-            
+
             ui_prompt_text = gr.Textbox(label="Prompt", value="Portrait, 4K, high quality, cinematic")
             ui_model_version = gr.Dropdown(
                 label="Model Version",
@@ -231,42 +231,42 @@ def generate_examples(id_image, control_image, prompt_text, seed, enable_realism
             )
 
             ui_btn_generate.click(
-                generate_image, 
+                generate_image,
                 inputs=[
-                    ui_id_image, 
-                    ui_control_image, 
-                    ui_prompt_text, 
-                    ui_seed, 
+                    ui_id_image,
+                    ui_control_image,
+                    ui_prompt_text,
+                    ui_seed,
                     ui_width,
                     ui_height,
-                    ui_guidance_scale, 
-                    ui_num_steps, 
-                    ui_infusenet_conditioning_scale, 
-                    ui_infusenet_guidance_start, 
+                    ui_guidance_scale,
+                    ui_num_steps,
+                    ui_infusenet_conditioning_scale,
+                    ui_infusenet_guidance_start,
                     ui_infusenet_guidance_end,
                     ui_enable_realism,
                     ui_enable_anti_blur,
                     ui_model_version
-                ], 
-                outputs=[image_output], 
+                ],
+                outputs=[image_output],
                 concurrency_id="gpu"
             )
 
     with gr.Accordion("Local Gradio Demo for Developers", open=False):
         gr.Markdown(
             'Please refer to our GitHub repository to [run the InfiniteYou-FLUX gradio demo locally](https://github.com/bytedance/InfiniteYou#local-gradio-demo).'
         )
 
     gr.Markdown(
         """
         ---
-        ### 📜 Disclaimer and Licenses 
+        ### 📜 Disclaimer and Licenses
         Some images in this demo are from public domains or generated by models. These pictures are intended solely to show the capabilities of our research. If you have any concerns, please contact us, and we will promptly remove any inappropriate content.
 
         The use of the released code, model, and demo must strictly adhere to the respective licenses. Our code is released under the Apache 2.0 License, and our model is released under the Creative Commons Attribution-NonCommercial 4.0 International Public License for academic research purposes only. Any manual or automatic downloading of the face models from [InsightFace](https://github.com/deepinsight/insightface), the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) base model, LoRAs, etc., must follow their original licenses and be used only for academic research purposes.
 
         This research aims to positively impact the Generative AI field. Users are granted freedom to create images using this tool, but they must comply with local laws and use it responsibly. The developers do not assume any responsibility for potential misuse.
 
         ### 📖 Citation
 
         If you find InfiniteYou useful for your research or applications, please cite our paper:
@@ -285,6 +285,9 @@ def generate_examples(id_image, control_image, prompt_text, seed, enable_realism
         """
     )
 
+if torch.backends.mps.is_available():
+    torch.set_default_device("mps:0")
+
 download_models()
 
 prepare_pipeline(model_version=ModelVersion.DEFAULT_VERSION, enable_realism=ENABLE_REALISM_DEFAULT, enable_anti_blur=ENABLE_ANTI_BLUR_DEFAULT)
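Note on the app.py hunk above: the MPS branch works because `torch.set_default_device()` changes where factory functions such as `torch.empty()` allocate, so later code can probe the effective default device with `torch.empty(1).device`. The probe only reflects a device set this way; on a CUDA machine where `set_default_device()` is never called, it still reports `cpu`. A minimal sketch of the pattern, assuming torch >= 2.0 (the `pick_default_device` helper is illustrative, not part of this PR):

```python
import torch

def pick_default_device() -> torch.device:
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps:0")
    return torch.device("cpu")

torch.set_default_device(pick_default_device())

# Factory calls now allocate on the chosen backend, so a throwaway
# tensor reads the effective default device back.
print(torch.empty(1).device)  # e.g. "mps:0" on Apple Silicon
```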
73 changes: 45 additions & 28 deletions pipelines/pipeline_infu_flux.py
@@ -21,7 +21,8 @@
 import numpy as np
 import torch
 from diffusers.models import FluxControlNetModel
-from facexlib.recognition import init_recognition_model
+from facexlib.recognition import Backbone
+from facexlib.utils import load_file_from_url
 from huggingface_hub import snapshot_download
 from insightface.app import FaceAnalysis
 from insightface.utils import face_align
@@ -30,6 +31,19 @@
 from .pipeline_flux_infusenet import FluxInfuseNetPipeline
 from .resampler import Resampler
 
+def init_recognition_model(model_name, half=False, device='cuda', model_rootpath=None):
+    if model_name == 'arcface':
+        model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se').to(device).eval()
+        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth'
+    else:
+        raise NotImplementedError(f'{model_name} is not implemented.')
+
+    model_path = load_file_from_url(
+        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
+    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
+    model.eval()
+    model = model.to(device)
+    return model
+
 def seed_everything(seed, deterministic=False):
     """Set random seed.
@@ -44,8 +58,11 @@ def seed_everything(seed, deterministic=False):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+    elif torch.backends.mps.is_available():
+        torch.mps.manual_seed(seed)
     os.environ['PYTHONHASHSEED'] = str(seed)
     if deterministic:
         torch.backends.cudnn.deterministic = True
@@ -87,9 +104,9 @@ def extract_arcface_bgr_embedding(in_image, landmark, arcface_model=None, in_set
     arc_face_image = face_align.norm_crop(in_image, landmark=np.array(kps), image_size=112)
     arc_face_image = torch.from_numpy(arc_face_image).unsqueeze(0).permute(0,3,1,2) / 255.
     arc_face_image = 2 * arc_face_image - 1
-    arc_face_image = arc_face_image.cuda().contiguous()
+    arc_face_image = arc_face_image.contiguous()
     if arcface_model is None:
-        arcface_model = init_recognition_model('arcface', device='cuda')
+        arcface_model = init_recognition_model('arcface', device='cuda' if torch.cuda.is_available() else 'cpu')
     face_emb = arcface_model(arc_face_image)[0] # [512], normalized
     return face_emb

@@ -98,34 +115,34 @@ def resize_and_pad_image(source_img, target_img_size):
     # Get original and target sizes
     source_img_size = source_img.size
     target_width, target_height = target_img_size
-    
+
     # Determine the new size based on the shorter side of target_img
     if target_width <= target_height:
         new_width = target_width
         new_height = int(target_width * (source_img_size[1] / source_img_size[0]))
     else:
         new_height = target_height
         new_width = int(target_height * (source_img_size[0] / source_img_size[1]))
-    
+
     # Resize the source image using LANCZOS interpolation for high quality
     resized_source_img = source_img.resize((new_width, new_height), Image.LANCZOS)
-    
+
     # Compute padding to center resized image
     pad_left = (target_width - new_width) // 2
     pad_top = (target_height - new_height) // 2
-    
+
     # Create a new image with white background
     padded_img = Image.new("RGB", target_img_size, (255, 255, 255))
     padded_img.paste(resized_source_img, (pad_left, pad_top))
-    
+
     return padded_img
 
 
 class InfUFluxPipeline:
     def __init__(
-        self, 
-        base_model_path, 
-        infu_model_path, 
+        self,
+        base_model_path,
+        infu_model_path,
         insightface_root_path = './',
         image_proj_num_tokens=8,
         infu_flux_version='v1.0',
@@ -134,7 +151,7 @@ def __init__(
 
         self.infu_flux_version = infu_flux_version
         self.model_version = model_version
-        
+
         # Load pipeline
         try:
             infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel')
@@ -167,7 +184,7 @@ def __init__(
                   'After that, run the code again. If you have downloaded it, please use `base_model_path` to specify the correct path.')
             print('\nIf you are using other models, please download them to a local directory and use `base_model_path` to specify the correct path.')
             exit()
-        pipe.to('cuda', torch.bfloat16)
+        pipe.to(torch.empty(1).device, torch.bfloat16)
         self.pipe = pipe
 
         # Load image proj model
@@ -184,28 +201,28 @@ def __init__(
             ff_mult=4,
         )
         image_proj_model_path = os.path.join(infu_model_path, 'image_proj_model.bin')
-        ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu")
+        ipm_state_dict = torch.load(image_proj_model_path, map_location=torch.device(torch.empty(1).device))
         image_proj_model.load_state_dict(ipm_state_dict['image_proj'])
         del ipm_state_dict
-        image_proj_model.to('cuda', torch.bfloat16)
+        image_proj_model.to(torch.empty(1).device, torch.bfloat16)
         image_proj_model.eval()
 
         self.image_proj_model = image_proj_model
 
         # Load face encoder
-        self.app_640 = FaceAnalysis(name='antelopev2', 
+        self.app_640 = FaceAnalysis(name='antelopev2',
                                     root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
         self.app_640.prepare(ctx_id=0, det_size=(640, 640))
 
-        self.app_320 = FaceAnalysis(name='antelopev2', 
+        self.app_320 = FaceAnalysis(name='antelopev2',
                                     root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
         self.app_320.prepare(ctx_id=0, det_size=(320, 320))
 
-        self.app_160 = FaceAnalysis(name='antelopev2', 
+        self.app_160 = FaceAnalysis(name='antelopev2',
                                     root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
         self.app_160.prepare(ctx_id=0, det_size=(160, 160))
 
-        self.arcface_model = init_recognition_model('arcface', device='cuda')
+        self.arcface_model = init_recognition_model('arcface', device='cuda' if torch.cuda.is_available() else 'cpu')
 
     def load_loras(self, loras):
         names, scales = [],[]
@@ -223,7 +240,7 @@ def _detect_face(self, id_image_cv2):
         face_info = self.app_640.get(id_image_cv2)
         if len(face_info) > 0:
             return face_info
-        
+
         face_info = self.app_320.get(id_image_cv2)
         if len(face_info) > 0:
             return face_info
@@ -244,27 +261,27 @@ def __call__(
         infusenet_conditioning_scale = 1.0,
         infusenet_guidance_start = 0.0,
         infusenet_guidance_end = 1.0,
-    ): 
+    ):
         # Extract ID embeddings
         print('Preparing ID embeddings')
         id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
         face_info = self._detect_face(id_image_cv2)
         if len(face_info) == 0:
             raise ValueError('No face detected in the input ID image')
-        
+
         face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
         landmark = face_info['kps']
         id_embed = extract_arcface_bgr_embedding(id_image_cv2, landmark, self.arcface_model)
-        id_embed = id_embed.clone().unsqueeze(0).float().cuda()
+        id_embed = id_embed.clone().unsqueeze(0).float()
         id_embed = id_embed.reshape([1, -1, 512])
-        id_embed = id_embed.to(device='cuda', dtype=torch.bfloat16)
+        id_embed = id_embed.to(device=torch.empty(1).device, dtype=torch.bfloat16)
         with torch.no_grad():
             id_embed = self.image_proj_model(id_embed)
             bs_embed, seq_len, _ = id_embed.shape
             id_embed = id_embed.repeat(1, 1, 1)
             id_embed = id_embed.view(bs_embed * 1, seq_len, -1)
-            id_embed = id_embed.to(device='cuda', dtype=torch.bfloat16)
+            id_embed = id_embed.to(device=torch.empty(1).device, dtype=torch.bfloat16)
 
         # Load control image
         print('Preparing the control image')
         if control_image is not None:
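A note on the vendored `init_recognition_model` added at the top of pipelines/pipeline_infu_flux.py: it mirrors the facexlib helper it replaces, apparently differing by passing `map_location=device` to `torch.load` so the ArcFace checkpoint can be deserialized on machines without CUDA (an inference from the diff, not stated in the PR). Also note the diff routes ArcFace to `'cuda' if torch.cuda.is_available() else 'cpu'`, i.e. to CPU rather than MPS on Apple Silicon. A hedged usage sketch, with input shapes taken from `extract_arcface_bgr_embedding` above (the `dev` and `dummy` names are illustrative):

```python
import torch

# Load the recognizer on whatever backend is present (CPU on MPS-only machines).
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
arcface = init_recognition_model('arcface', device=dev)

# The model consumes 112x112 BGR face crops in NCHW layout, scaled to [-1, 1].
dummy = torch.rand(1, 3, 112, 112, device=dev) * 2 - 1
with torch.no_grad():
    emb = arcface(dummy)
print(emb.shape)  # torch.Size([1, 512]) -- one 512-d identity embedding
```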
10 changes: 7 additions & 3 deletions test.py
@@ -48,7 +48,11 @@ def main():
     assert args.model_version in ['aes_stage2', 'sim_stage1'], 'Currently only supports model versions: aes_stage2 | sim_stage1'
 
     # Set cuda device
-    torch.cuda.set_device(args.cuda_device)
+    if torch.cuda.is_available():
+        torch.cuda.set_device(args.cuda_device)
+    elif torch.backends.mps.is_available():
+        torch.set_default_device("mps:0")
+    print(f'Using device: {torch.empty(1).device}')
 
     # Load pipeline
     infu_model_path = os.path.join(args.model_dir, f'infu_flux_{args.infu_flux_version}', args.model_version)
@@ -69,7 +73,7 @@ def main():
     if args.enable_anti_blur_lora:
         loras.append([os.path.join(lora_dir, 'flux_anti_blur_lora.safetensors'), 'anti_blur', 1.0])
     pipe.load_loras(loras)
-    
+
     # Perform inference
     if args.seed == 0:
         args.seed = torch.seed() & 0xFFFFFFFF
@@ -84,7 +88,7 @@ def main():
         infusenet_guidance_start=args.infusenet_guidance_start,
         infusenet_guidance_end=args.infusenet_guidance_end,
     )
-    
+
     # Save results
     os.makedirs(args.out_results_dir, exist_ok=True)
     index = len(os.listdir(args.out_results_dir))
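One subtlety in the test.py hunk above: `torch.cuda.set_device()` selects which GPU CUDA operations target, but it does not change the default device used by factory functions, so the `torch.empty(1).device` probe still reports `cpu` on CUDA machines; only the MPS branch, which calls `torch.set_default_device()`, changes what the probe returns. A small sketch of the difference (device indices are illustrative):

```python
import torch

if torch.cuda.is_available():
    torch.cuda.set_device(0)                     # choose GPU 0 for CUDA ops
    print(torch.empty(1).device)                 # cpu -- factory default unchanged
    print(torch.empty(1, device="cuda").device)  # cuda:0 -- explicit placement
elif torch.backends.mps.is_available():
    torch.set_default_device("mps:0")            # change the factory default itself
    print(torch.empty(1).device)                 # mps:0
```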