6 changes: 6 additions & 0 deletions modules/core.py
@@ -39,9 +39,12 @@ def parse_args() -> None:
program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True)
program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False)
program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False)
program.add_argument('--color-correction', help='apply color correction to the swapped face', dest='color_correction', action='store_true', default=False) # Added this line back
program.add_argument('--nsfw-filter', help='filter the NSFW image or video', dest='nsfw_filter', action='store_true', default=False)
program.add_argument('--map-faces', help='map source target faces', dest='map_faces', action='store_true', default=False)
program.add_argument('--mouth-mask', help='mask the mouth region', dest='mouth_mask', action='store_true', default=False)
program.add_argument('--poisson-blending', help='use Poisson blending for smoother face integration', dest='poisson_blending', action='store_true', default=False)
program.add_argument('--preserve-ears', help='attempt to preserve target ears by modifying the blend mask', dest='preserve_ears', action='store_true', default=False)
program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
program.add_argument('-l', '--lang', help='Ui language', default="en")
@@ -69,7 +72,10 @@ def parse_args() -> None:
modules.globals.keep_audio = args.keep_audio
modules.globals.keep_frames = args.keep_frames
modules.globals.many_faces = args.many_faces
modules.globals.color_correction = args.color_correction
modules.globals.mouth_mask = args.mouth_mask
modules.globals.use_poisson_blending = args.poisson_blending
modules.globals.preserve_target_ears = args.preserve_ears
modules.globals.nsfw_filter = args.nsfw_filter
modules.globals.map_faces = args.map_faces
modules.globals.video_encoder = args.video_encoder
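For context, a minimal smoke-test sketch of how the three new flags are expected to propagate into `modules.globals` (this is my own illustration, not part of the PR; the `run.py` argv value is a placeholder, and it assumes `parse_args()` defines no required arguments beyond those shown in this diff):

```python
import sys

import modules.globals
from modules.core import parse_args

# Simulate a command line that enables the new options.
sys.argv = ["run.py", "--color-correction", "--poisson-blending", "--preserve-ears"]
parse_args()

# parse_args() copies the parsed flags onto modules.globals, per the hunk above.
assert modules.globals.color_correction is True
assert modules.globals.use_poisson_blending is True
assert modules.globals.preserve_target_ears is True
```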
7 changes: 7 additions & 0 deletions modules/globals.py
@@ -41,3 +41,10 @@
mask_feather_ratio = 8
mask_down_size = 0.50
mask_size = 1
use_poisson_blending = False # Added for Poisson blending
poisson_blending_feather_amount = 5 # Feathering for the mask before Poisson blending
preserve_target_ears = False # Flag to enable preserving target's ears
ear_width_ratio = 0.18 # Width of the ear exclusion box as a ratio of face bbox width
ear_height_ratio = 0.35 # Height of the ear exclusion box as a ratio of face bbox height
ear_vertical_offset_ratio = 0.20 # Vertical offset of the ear box from top of face bbox
ear_horizontal_overlap_ratio = 0.03 # How much the ear exclusion zone can overlap into the face bbox
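To make the geometry behind these ratios concrete, here is a small sketch (mine, not part of the diff; the bounding-box coordinates are an arbitrary example) that computes the exclusion rectangle for the person's right ear the same way `face_swapper.py` does below:

```python
import modules.globals

# Example face bounding box (x1, y1, x2, y2) in pixel coordinates.
x1, y1, x2, y2 = 100, 80, 300, 360
w, h = x2 - x1, y2 - y1  # 200 x 280

ear_w = int(w * modules.globals.ear_width_ratio)                    # 36 px wide
ear_h = int(h * modules.globals.ear_height_ratio)                   # 98 px tall
ear_v_offset = int(h * modules.globals.ear_vertical_offset_ratio)   # starts 56 px below the bbox top
ear_overlap = int(w * modules.globals.ear_horizontal_overlap_ratio) # reaches 6 px into the bbox

# Exclusion rectangle for the person's right ear (left side of the image).
right_ear_box = (x1 - ear_w + ear_overlap, y1 + ear_v_offset,
                 x1 + ear_overlap, y1 + ear_v_offset + ear_h)
print(right_ear_box)  # (70, 136, 106, 234)
```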
197 changes: 189 additions & 8 deletions modules/processors/frame/face_swapper.py
@@ -71,10 +71,43 @@ def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
face_swapper = get_face_swapper()

# Apply the face swap
swapped_frame = face_swapper.get(
swapped_frame_result = face_swapper.get( # Renamed to avoid confusion
temp_frame, target_face, source_face, paste_back=True
)

# Ensure swapped_frame_result is not None and is a valid image
if swapped_frame_result is None or not isinstance(swapped_frame_result, np.ndarray):
logging.error("Face swap operation failed or returned invalid result.")
return temp_frame # Return original frame if swap failed

# Color Correction
if modules.globals.color_correction:
# Get the bounding box of the target face to apply color correction
# more accurately to the swapped region.
# The target_face object should have bbox attribute (x1, y1, x2, y2)
if hasattr(target_face, 'bbox'):
x1, y1, x2, y2 = target_face.bbox.astype(int)
# Ensure coordinates are within frame bounds
x1, y1 = max(0, x1), max(0, y1)
x2, y2 = min(swapped_frame_result.shape[1], x2), min(swapped_frame_result.shape[0], y2)

if x1 < x2 and y1 < y2:
swapped_face_region = swapped_frame_result[y1:y2, x1:x2]
target_face_region_original = temp_frame[y1:y2, x1:x2]

if swapped_face_region.size > 0 and target_face_region_original.size > 0:
corrected_swapped_face_region = apply_histogram_matching_color_correction(swapped_face_region, target_face_region_original)
swapped_frame_result[y1:y2, x1:x2] = corrected_swapped_face_region
else:
# Fallback to full frame color correction if regions are invalid
swapped_frame_result = apply_histogram_matching_color_correction(swapped_frame_result, temp_frame)
else:
# Fallback to full frame color correction if bbox is invalid
swapped_frame_result = apply_histogram_matching_color_correction(swapped_frame_result, temp_frame)
else:
# Fallback to full frame color correction if no bbox
swapped_frame_result = apply_histogram_matching_color_correction(swapped_frame_result, temp_frame)

if modules.globals.mouth_mask:
# Create a mask for the target face
face_mask = create_face_mask(target_face, temp_frame)
@@ -85,22 +118,136 @@ def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
)

# Apply the mouth area
swapped_frame = apply_mouth_area(
swapped_frame, mouth_cutout, mouth_box, face_mask, lower_lip_polygon
swapped_frame_result = apply_mouth_area(
swapped_frame_result, mouth_cutout, mouth_box, face_mask, lower_lip_polygon
)

if modules.globals.show_mouth_mask_box:
mouth_mask_data = (mouth_mask, mouth_cutout, mouth_box, lower_lip_polygon)
swapped_frame = draw_mouth_mask_visualization(
swapped_frame, target_face, mouth_mask_data
swapped_frame_result = draw_mouth_mask_visualization(
swapped_frame_result, target_face, mouth_mask_data
)

return swapped_frame
# Poisson Blending
if modules.globals.use_poisson_blending and hasattr(target_face, 'bbox'):
# Build a mask covering the swapped face region for Poisson blending.
# A rectangle from target_face.bbox is used when available (feathered further below);
# a convex hull of the 106-point landmarks is the fallback. A more precise mask
# from face parsing could be substituted if available.

face_mask_for_blending = np.zeros(temp_frame.shape[:2], dtype=np.uint8)

# Prioritize using the bounding box for a tighter mask
if hasattr(target_face, 'bbox'):
x1, y1, x2, y2 = target_face.bbox.astype(int)
# Ensure coordinates are within frame bounds
x1_b, y1_b = max(0, x1), max(0, y1) # Use different var names to avoid conflict with center calculation
x2_b, y2_b = min(temp_frame.shape[1], x2), min(temp_frame.shape[0], y2)

# Create a rectangular mask based on the bounding box
if x1_b < x2_b and y1_b < y2_b:
face_mask_for_blending[y1_b:y2_b, x1_b:x2_b] = 255
else:
logging.warning("Invalid bounding box for Poisson mask. Attempting landmark-based mask.")
# Fallback to landmark-based convex hull if bbox is invalid
landmarks = target_face.landmark_2d_106 if hasattr(target_face, 'landmark_2d_106') else None
if landmarks is not None and len(landmarks) > 0:
try:
hull_points = cv2.convexHull(landmarks.astype(np.int32))
cv2.fillConvexPoly(face_mask_for_blending, hull_points, 255)
except Exception as e:
logging.error(f"Could not form convex hull for Poisson mask from landmarks: {e}. Blending will be skipped.")
else:
logging.error("No valid bbox or landmarks for Poisson mask. Blending will be skipped.")
else:
# Fallback to landmark-based convex hull if no bbox attribute
landmarks = target_face.landmark_2d_106 if hasattr(target_face, 'landmark_2d_106') else None
if landmarks is not None and len(landmarks) > 0:
try:
hull_points = cv2.convexHull(landmarks.astype(np.int32))
cv2.fillConvexPoly(face_mask_for_blending, hull_points, 255)
except Exception as e:
logging.error(f"Could not form convex hull for Poisson mask from landmarks (no bbox): {e}. Blending will be skipped.")
else:
logging.error("No bbox or landmarks available for Poisson mask. Blending will be skipped.")

# Subtract ear regions if preserve_target_ears is enabled
if modules.globals.preserve_target_ears and np.any(face_mask_for_blending > 0):
mfx1, mfy1, mfx2, mfy2 = target_face.bbox.astype(int)
mfw = mfx2 - mfx1
mfh = mfy2 - mfy1

ear_w = int(mfw * modules.globals.ear_width_ratio)
ear_h = int(mfh * modules.globals.ear_height_ratio)
ear_v_offset = int(mfh * modules.globals.ear_vertical_offset_ratio)
ear_overlap = int(mfw * modules.globals.ear_horizontal_overlap_ratio)

# Person's Right Ear (image left side of face bbox)
# This region in face_mask_for_blending will be set to 0
rex1 = max(0, mfx1 - ear_w + ear_overlap)
rey1 = max(0, mfy1 + ear_v_offset)
rex2 = min(temp_frame.shape[1], mfx1 + ear_overlap) # Extends slightly into face bbox for smoother transition
rey2 = min(temp_frame.shape[0], rey1 + ear_h)
if rex1 < rex2 and rey1 < rey2:
cv2.rectangle(face_mask_for_blending, (rex1, rey1), (rex2, rey2), 0, -1)

# Person's Left Ear (image right side of face bbox)
lex1 = max(0, mfx2 - ear_overlap)
ley1 = max(0, mfy1 + ear_v_offset)
lex2 = min(temp_frame.shape[1], mfx2 + ear_w - ear_overlap)
ley2 = min(temp_frame.shape[0], ley1 + ear_h)
if lex1 < lex2 and ley1 < ley2:
cv2.rectangle(face_mask_for_blending, (lex1, ley1), (lex2, ley2), 0, -1)

# Feather the mask to smooth edges for Poisson blending
if np.any(face_mask_for_blending > 0): # Only feather if there's a mask
feather_amount = modules.globals.poisson_blending_feather_amount
if feather_amount > 0:
# Ensure kernel size is odd
kernel_size = 2 * feather_amount + 1
face_mask_for_blending = cv2.GaussianBlur(face_mask_for_blending, (kernel_size, kernel_size), 0)

# Calculate the center of the target face bbox for seamlessClone
if hasattr(target_face, 'bbox'):
x1, y1, x2, y2 = target_face.bbox.astype(int)
center_x = (x1 + x2) // 2
center_y = (y1 + y2) // 2

# Ensure center is within frame dimensions
center_x = np.clip(center_x, 0, temp_frame.shape[1] -1)
center_y = np.clip(center_y, 0, temp_frame.shape[0] -1)
center = (center_x, center_y)

# Apply Poisson blending
# swapped_frame_result is the source, temp_frame is the destination
if np.any(face_mask_for_blending > 0): # Proceed only if mask is not empty
try:
# Ensure swapped_frame_result and temp_frame are 8-bit 3-channel images
if swapped_frame_result.dtype != np.uint8:
swapped_frame_result = np.clip(swapped_frame_result, 0, 255).astype(np.uint8)
if temp_frame.dtype != np.uint8:
temp_frame_uint8 = np.clip(temp_frame, 0, 255).astype(np.uint8)
else:
temp_frame_uint8 = temp_frame

swapped_frame_result = cv2.seamlessClone(swapped_frame_result, temp_frame_uint8, face_mask_for_blending, center, cv2.NORMAL_CLONE)
except cv2.error as e:
logging.error(f"Error during Poisson blending: {e}")
# Fallback to non-blended result if seamlessClone fails
pass # swapped_frame_result remains as is
else:
logging.warning("Poisson blending mask is empty. Skipping Poisson blending.")

return swapped_frame_result
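For reference, a self-contained sketch of the `cv2.seamlessClone` call used above (the image sizes and mask rectangle are placeholders): OpenCV expects the 8-bit 3-channel source patch, the destination image, an 8-bit single-channel mask selecting the source region, the center of placement in destination coordinates, and a clone flag. In the PR, `swapped_frame_result` plays the role of the source and `temp_frame` the destination.

```python
import cv2
import numpy as np

# Placeholder images, both 8-bit BGR and the same size for simplicity.
dst = np.full((480, 640, 3), 128, dtype=np.uint8)
src = np.full((480, 640, 3), 200, dtype=np.uint8)

# Single-channel mask marking the region of `src` to blend into `dst`.
mask = np.zeros((480, 640), dtype=np.uint8)
cv2.rectangle(mask, (200, 150), (440, 330), 255, -1)

# Center of the masked region, expressed in destination coordinates.
center = (320, 240)

blended = cv2.seamlessClone(src, dst, mask, center, cv2.NORMAL_CLONE)
```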


def process_frame(source_face: Face, temp_frame: Frame) -> Frame:
if modules.globals.color_correction:
temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
# The color_correction logic was moved into swap_face.
# The previous `cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)` call was removed: it converted
# the color space of the entire frame before processing, which is not what we want when only
# the swapped region needs correction. Histogram matching is now done BGR to BGR.

if modules.globals.many_faces:
many_faces = get_many_faces(temp_frame)
@@ -620,3 +767,37 @@ def apply_color_transfer(source, target):
source = (source - source_mean) * (target_std / source_std) + target_mean

return cv2.cvtColor(np.clip(source, 0, 255).astype("uint8"), cv2.COLOR_LAB2BGR)


def apply_histogram_matching_color_correction(source_img: Frame, target_img: Frame) -> Frame:
"""
Applies color correction to the source image to match the target image's color distribution
using histogram matching on each color channel.
"""
corrected_img = np.zeros_like(source_img)
for i in range(source_img.shape[2]): # Iterate over color channels (B, G, R)
source_hist, _ = np.histogram(source_img[:, :, i].flatten(), 256, [0, 256])
target_hist, _ = np.histogram(target_img[:, :, i].flatten(), 256, [0, 256])

# Compute cumulative distribution functions (CDFs), normalized to [0, 1] so the
# source and target CDFs are directly comparable in the mapping loop below
source_cdf = source_hist.cumsum()
source_cdf_normalized = source_cdf / source_cdf.max()

target_cdf = target_hist.cumsum()
target_cdf_normalized = target_cdf / target_cdf.max()

# Create lookup table
lookup_table = np.zeros(256, 'uint8')

gj = 0
for gi in range(256):
while gj < 256 and target_cdf_normalized[gj] < source_cdf_normalized[gi]:
gj += 1
if gj == 256: # If we reach end of target_cdf, map remaining to max value
lookup_table[gi] = 255
else:
lookup_table[gi] = gj

corrected_img[:, :, i] = cv2.LUT(source_img[:, :, i], lookup_table)

return corrected_img
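A short usage sketch for the helper above (the file paths are placeholders; it assumes both crops are BGR images loaded with OpenCV, and they need not be the same size since channel histograms are matched independently):

```python
import cv2

swapped_region = cv2.imread("swapped_face_crop.png")    # placeholder path
original_region = cv2.imread("target_face_crop.png")    # placeholder path

# Match the swapped crop's per-channel color distribution to the original crop.
corrected = apply_histogram_matching_color_correction(swapped_region, original_region)
cv2.imwrite("corrected_face_crop.png", corrected)
```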