diff --git a/README.md b/README.md
index 1ab3042..ffe436d 100644
--- a/README.md
+++ b/README.md
@@ -58,19 +58,19 @@
MarkDiffusion is an open-source Python toolkit for generative watermarking of latent diffusion models. As the use of diffusion-based generative models expands, ensuring the authenticity and origin of generated media becomes critical. MarkDiffusion simplifies the access, understanding, and assessment of watermarking technologies, making it accessible to both researchers and the broader community. *Note: if you are interested in LLM watermarking (text watermark), please refer to the [MarkLLM](https://github.com/THU-BPM/MarkLLM) toolkit from our group.*
-The toolkit comprises three key components: a unified implementation framework for streamlined watermarking algorithm integrations and user-friendly interfaces; a mechanism visualization suite that intuitively showcases added and extracted watermark patterns to aid public understanding; and a comprehensive evaluation module offering standard implementations of 24 tools across three essential aspects—detectability, robustness, and output quality, plus 8 automated evaluation pipelines.
+The toolkit comprises three key components: a unified implementation framework for streamlined watermarking algorithm integrations and user-friendly interfaces; a mechanism visualization suite that intuitively showcases added and extracted watermark patterns to aid public understanding; and a comprehensive evaluation module offering standard implementations of 31 tools across three essential aspects—detectability, robustness, and output quality, plus 6 automated evaluation pipelines.
### 💍 Key Features
-- **Unified Implementation Framework:** MarkDiffusion provides a modular architecture supporting eight state-of-the-art generative image/video watermarking algorithms of LDMs.
+- **Unified Implementation Framework:** MarkDiffusion provides a modular architecture supporting eleven state-of-the-art generative image/video watermarking algorithms of LDMs.
-- **Comprehensive Algorithm Support:** Currently implements 8 watermarking algorithms from two major categories: Pattern-based methods (Tree-Ring, Ring-ID, ROBIN, WIND) and Key-based methods (Gaussian-Shading, PRC, SEAL, VideoShield).
+- **Comprehensive Algorithm Support:** Currently implements 11 watermarking algorithms from two major categories: Pattern-based methods (Tree-Ring, Ring-ID, ROBIN, WIND, SFW) and Key-based methods (Gaussian-Shading, PRC, SEAL, VideoShield, GaussMarker, VideoMark).
- **Visualization Solutions:** The toolkit includes custom visualization tools that enable clear and insightful views into how different watermarking algorithms operate under various scenarios. These visualizations help demystify the algorithms' mechanisms, making them more understandable for users.
-- **Evaluation Module:** With 20 evaluation tools covering detectability, robustness, and impact on output quality, MarkDiffusion provides comprehensive assessment capabilities. It features 5 automated evaluation pipelines: Watermark Detection Pipeline, Image Quality Analysis Pipeline, Video Quality Analysis Pipeline, and specialized robustness assessment tools.
+- **Evaluation Module:** With 31 evaluation tools covering detectability, robustness, and impact on output quality, MarkDiffusion provides comprehensive assessment capabilities. It features 6 automated evaluation pipelines, including the Watermark Detection Pipeline, the Image Quality Analysis Pipeline, and the Video Quality Analysis Pipeline, alongside specialized robustness assessment tools.
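For intuition, the detectability side of the evaluation module comes down to picking a detection threshold and measuring success rates. Below is a minimal, self-contained sketch in plain Python — not the toolkit's API. It mirrors the idea behind the toolkit's `DynamicThresholdSuccessRateCalculator`; the function name and exact semantics here are assumptions for illustration only.

```python
# Illustrative sketch only (not the MarkDiffusion API): compute the true-positive
# rate at a fixed false-positive rate from raw detector scores.

def tpr_at_target_fpr(watermarked_scores, unwatermarked_scores, target_fpr=0.01):
    """Choose the detection threshold that keeps the false-positive rate at or
    below target_fpr on unwatermarked media, then report the true-positive
    rate on watermarked media at that threshold (higher score = watermarked)."""
    neg = sorted(unwatermarked_scores, reverse=True)
    # Highest threshold admitting at most target_fpr false positives.
    k = int(target_fpr * len(neg))
    threshold = neg[k] if k < len(neg) else neg[-1]
    tp = sum(s > threshold for s in watermarked_scores)
    return tp / len(watermarked_scores), threshold

# Toy detector scores for watermarked vs. unwatermarked media.
wm = [0.91, 0.88, 0.95, 0.79, 0.93, 0.85, 0.90, 0.97]
no_wm = [0.12, 0.31, 0.22, 0.08, 0.45, 0.19, 0.27, 0.15]
tpr, thr = tpr_at_target_fpr(wm, no_wm, target_fpr=0.1)
print(f"TPR@0.1FPR = {tpr:.2f} (threshold {thr})")  # → TPR@0.1FPR = 1.00 (threshold 0.45)
```

In the toolkit this computation is driven by the two detection pipelines, which feed scores from watermarked and unwatermarked media into the success-rate calculator.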
### ✨ Implemented Algorithms
diff --git a/README_es.md b/README_es.md
index a3c3f54..89488fe 100644
--- a/README_es.md
+++ b/README_es.md
@@ -4,39 +4,44 @@
# Un Kit de Herramientas de Código Abierto para Marcas de Agua Generativas de Modelos de Difusión Latente
-[](https://generative-watermark.github.io/)
+[](https://generative-watermark.github.io/)
[](https://arxiv.org/abs/2509.10569)
-[](https://huggingface.co/Generative-Watermark-Toolkits)
+[](https://huggingface.co/Generative-Watermark-Toolkits)
+[](https://colab.research.google.com/drive/1N1C9elDAB5zwF4FxKKYMCqR3eSpCSqAW?usp=sharing)
+[](https://markdiffusion.readthedocs.io)
+[](https://pypi.org/project/markdiffusion)
+[](https://github.com/conda-forge/markdiffusion-feedstock)
+
-**Versiones de idioma:** [English](README.md) | [中文](README_zh.md) | [Français](README_fr.md) | [Español](README_es.md)
+**Versiones de idioma:** [English](README.md) | [中文](README_zh.md) | [Français](README_fr.md) | [Español](README_es.md)
> 🔥 **¡Como un proyecto recién lanzado, damos la bienvenida a PRs!** Si has implementado un algoritmo de marcas de agua LDM o estás interesado en contribuir con uno, nos encantaría incluirlo en MarkDiffusion. ¡Únete a nuestra comunidad y ayuda a hacer las marcas de agua generativas más accesibles para todos!
## Contenidos
-- [Notas](#-notas)
- [Actualizaciones](#-actualizaciones)
-- [Introducción a MarkDiffusion](#introducción-a-markdiffusion)
- - [Descripción general](#descripción-general)
- - [Características clave](#características-clave)
- - [Algoritmos implementados](#algoritmos-implementados)
- - [Módulo de evaluación](#módulo-de-evaluación)
-- [Instalación](#instalación)
-- [Inicio rápido](#inicio-rápido)
-- [Cómo usar el kit de herramientas](#cómo-usar-el-kit-de-herramientas)
- - [Generación y detección de medios con marcas de agua](#generación-y-detección-de-medios-con-marcas-de-agua)
- - [Visualización de mecanismos de marcas de agua](#visualización-de-mecanismos-de-marcas-de-agua)
- - [Pipelines de evaluación](#pipelines-de-evaluación)
+- [Introducción a MarkDiffusion](#-introducción-a-markdiffusion)
+ - [Descripción general](#-descripción-general)
+ - [Características clave](#-características-clave)
+ - [Algoritmos implementados](#-algoritmos-implementados)
+ - [Módulo de evaluación](#-módulo-de-evaluación)
+- [Inicio rápido](#-inicio-rápido)
+ - [Demo de Google Colab](#demo-de-google-colab)
+ - [Instalación](#instalación)
+ - [Cómo usar el kit de herramientas](#cómo-usar-el-kit-de-herramientas)
+- [Módulos de prueba](#-módulos-de-prueba)
- [Citación](#citación)
-## ❗❗❗ Notas
-A medida que el contenido del repositorio MarkDiffusion se vuelve cada vez más rico y su tamaño crece, hemos creado un repositorio de almacenamiento de modelos en Hugging Face llamado [Generative-Watermark-Toolkits](https://huggingface.co/Generative-Watermark-Toolkits) para facilitar su uso. Este repositorio contiene varios modelos predeterminados para algoritmos de marcas de agua que involucran modelos auto-entrenados. Hemos eliminado los pesos de los modelos de las carpetas `ckpts/` correspondientes de estos algoritmos de marcas de agua en el repositorio principal. **Al usar el código, primero descarga los modelos correspondientes del repositorio de Hugging Face según las rutas de configuración y guárdalos en el directorio `ckpts/` antes de ejecutar el código.**
## 🔥 Actualizaciones
+🛠 **(2025.12.19)** Agregada una suite de pruebas completa para todas las funcionalidades con 454 casos de prueba.
+
+🛠 **(2025.12.10)** Agregado un sistema de pruebas de integración continua usando GitHub Actions.
+
🎯 **(2025.10.10)** Agregadas herramientas de ataque de imagen *Mask, Overlay, AdaptiveNoiseInjection*, ¡gracias a Zheyu Fu por su PR!
-🎯 **(2025.10.09)** Agregadas herramientas de ataque de video *VideoCodecAttack, FrameRateAdapter, FrameInterpolationAttack*, ¡gracias a Luyang Si por su PR!
+🎯 **(2025.10.09)** Agregadas herramientas de ataque de video *FrameRateAdapter, FrameInterpolationAttack*, ¡gracias a Luyang Si por su PR!
🎯 **(2025.10.08)** Agregados analizadores de calidad de imagen *SSIM, BRISQUE, VIF, FSIM*, ¡gracias a Huan Wang por su PR!
@@ -46,27 +51,27 @@ A medida que el contenido del repositorio MarkDiffusion se vuelve cada vez más
✨ **(2025.9.29)** Agregado el método de marca de agua [GaussMarker](https://arxiv.org/abs/2506.11444), ¡gracias a Luyang Si por su PR!
-## Introducción a MarkDiffusion
+## 🔓 Introducción a MarkDiffusion
-### Descripción general
+### 👀 Descripción general
MarkDiffusion es un kit de herramientas de Python de código abierto para marcas de agua generativas de modelos de difusión latente. A medida que se expande el uso de modelos generativos basados en difusión, garantizar la autenticidad y el origen de los medios generados se vuelve crítico. MarkDiffusion simplifica el acceso, la comprensión y la evaluación de tecnologías de marcas de agua, haciéndolo accesible tanto para investigadores como para la comunidad en general. *Nota: si estás interesado en marcas de agua LLM (marca de agua de texto), consulta el kit de herramientas [MarkLLM](https://github.com/THU-BPM/MarkLLM) de nuestro grupo.*
-El kit de herramientas comprende tres componentes clave: un marco de implementación unificado para integraciones simplificadas de algoritmos de marcas de agua e interfaces fáciles de usar; un conjunto de visualización de mecanismos que muestra intuitivamente los patrones de marcas de agua agregados y extraídos para ayudar a la comprensión pública; y un módulo de evaluación integral que ofrece implementaciones estándar de 24 herramientas en tres aspectos esenciales: detectabilidad, robustez y calidad de salida, además de 8 pipelines de evaluación automatizados.
+El kit de herramientas comprende tres componentes clave: un marco de implementación unificado para integraciones simplificadas de algoritmos de marcas de agua e interfaces fáciles de usar; un conjunto de visualización de mecanismos que muestra intuitivamente los patrones de marcas de agua agregados y extraídos para ayudar a la comprensión pública; y un módulo de evaluación integral que ofrece implementaciones estándar de 31 herramientas en tres aspectos esenciales: detectabilidad, robustez y calidad de salida, además de 6 pipelines de evaluación automatizados.
-### Características clave
+### 💍 Características clave
-- **Marco de implementación unificado:** MarkDiffusion proporciona una arquitectura modular que admite ocho algoritmos de marcas de agua generativas de imagen/video de última generación para LDMs.
+- **Marco de implementación unificado:** MarkDiffusion proporciona una arquitectura modular que admite once algoritmos de marcas de agua generativas de imagen/video de última generación para LDMs.
-- **Soporte integral de algoritmos:** Actualmente implementa 8 algoritmos de marcas de agua de dos categorías principales: métodos basados en patrones (Tree-Ring, Ring-ID, ROBIN, WIND) y métodos basados en claves (Gaussian-Shading, PRC, SEAL, VideoShield).
+- **Soporte integral de algoritmos:** Actualmente implementa 11 algoritmos de marcas de agua de dos categorías principales: métodos basados en patrones (Tree-Ring, Ring-ID, ROBIN, WIND, SFW) y métodos basados en claves (Gaussian-Shading, PRC, SEAL, VideoShield, GaussMarker, VideoMark).
- **Soluciones de visualización:** El kit de herramientas incluye herramientas de visualización personalizadas que permiten vistas claras y perspicaces sobre cómo operan los diferentes algoritmos de marcas de agua en varios escenarios. Estas visualizaciones ayudan a desmitificar los mecanismos de los algoritmos, haciéndolos más comprensibles para los usuarios.
-- **Módulo de evaluación:** Con 20 herramientas de evaluación que cubren detectabilidad, robustez e impacto en la calidad de salida, MarkDiffusion proporciona capacidades de evaluación integral. Cuenta con 5 pipelines de evaluación automatizados: Pipeline de detección de marcas de agua, Pipeline de análisis de calidad de imagen, Pipeline de análisis de calidad de video y herramientas especializadas de evaluación de robustez.
+- **Módulo de evaluación:** Con 31 herramientas de evaluación que cubren detectabilidad, robustez e impacto en la calidad de salida, MarkDiffusion proporciona capacidades de evaluación integral. Cuenta con 6 pipelines de evaluación automatizados, incluyendo el Pipeline de detección de marcas de agua, el Pipeline de análisis de calidad de imagen y el Pipeline de análisis de calidad de video, además de herramientas especializadas de evaluación de robustez.
-### Algoritmos implementados
+### ✨ Algoritmos implementados
| **Algoritmo** | **Categoría** | **Objetivo** | **Referencia** |
|---------------|-------------|------------|---------------|
@@ -82,7 +87,7 @@ El kit de herramientas comprende tres componentes clave: un marco de implementac
| VideoShield | Clave | Video | [VideoShield: Regulating Diffusion-based Video Generation Models via Watermarking](https://arxiv.org/abs/2501.14195) |
| VideoMark | Clave | Video | [VideoMark: A Distortion-Free Robust Watermarking Framework for Video Diffusion Models](https://arxiv.org/abs/2504.16359) |
-### Módulo de evaluación
+### 🎯 Módulo de evaluación
#### Pipelines de evaluación
MarkDiffusion admite ocho pipelines, dos para detección (WatermarkedMediaDetectionPipeline y UnWatermarkedMediaDetectionPipeline), y seis para análisis de calidad. La tabla a continuación detalla los pipelines de análisis de calidad.
@@ -116,7 +121,6 @@ MarkDiffusion admite ocho pipelines, dos para detección (WatermarkedMediaDetect
| MPEG4Compression | Robustez (Video) | Ataque de compresión de video MPEG-4, probando la robustez de compresión de marca de agua de video | Fotogramas de video comprimidos |
| FrameAverage | Robustez (Video) | Ataque de promedio de fotogramas, destruyendo marcas de agua a través del promedio entre fotogramas | Fotogramas de video promediados |
| FrameSwap | Robustez (Video) | Ataque de intercambio de fotogramas, probando la robustez cambiando secuencias de fotogramas | Fotogramas de video intercambiados |
-| VideoCodecAttack | Robustez (Video) | Ataque de recodificación de códec simulando transcodificación de plataforma (H.264/H.265/VP9/AV1) | Fotogramas de video recodificados |
| FrameRateAdapter | Robustez (Video) | Ataque de conversión de velocidad de fotogramas que remuestrea fotogramas preservando la duración | Secuencia de fotogramas remuestreada |
| FrameInterpolationAttack | Robustez (Video) | Ataque de interpolación de fotogramas insertando fotogramas mezclados para alterar la densidad temporal | Fotogramas de video interpolados |
| **Analizadores de calidad de imagen** | | | |
@@ -137,326 +141,130 @@ MarkDiffusion admite ocho pipelines, dos para detección (WatermarkedMediaDetect
| DynamicDegreeAnalyzer | Calidad (Video) | Medir nivel dinámico y magnitud de cambio en video | Valor de grado dinámico |
| ImagingQualityAnalyzer | Calidad (Video) | Evaluación integral de calidad de imagen de video | Puntuación de calidad de imagen |
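Como intuición, los ataques de robustez de video de la tabla anterior se reducen a transformaciones simples sobre la secuencia de fotogramas. El siguiente boceto en Python puro ilustra la idea detrás de *FrameSwap* y *FrameAverage* (no es la API del toolkit; los nombres y la semántica exacta aquí son suposiciones), tratando cada fotograma como una lista de valores de píxel:

```python
# Boceto ilustrativo (no es la API de MarkDiffusion): la idea detrás de los
# ataques FrameSwap y FrameAverage sobre una secuencia de fotogramas.

def frame_swap(frames, stride=2):
    """Intercambia pares de fotogramas adyacentes cada `stride` posiciones."""
    out = list(frames)
    for i in range(0, len(out) - 1, stride):
        out[i], out[i + 1] = out[i + 1], out[i]
    return out

def frame_average(frames, window=2):
    """Promedia cada fotograma con sus `window` vecinos anteriores (media por píxel)."""
    out = []
    for i in range(len(frames)):
        group = frames[max(0, i - window + 1): i + 1]
        out.append([sum(px) / len(group) for px in zip(*group)])
    return out

# Cuatro "fotogramas" de dos píxeles cada uno, para simplificar.
frames = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
print(frame_swap(frames))     # pares adyacentes intercambiados
print(frame_average(frames))  # fotogramas suavizados temporalmente
```

Un detector de marca de agua robusto debería seguir funcionando tras este tipo de perturbaciones temporales; los editores de video del módulo de evaluación automatizan exactamente esa comprobación.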
-## Instalación
-
-### Configuración del entorno
-
-- Python 3.10+
-- PyTorch
-- Instalar dependencias:
+## 🧩 Inicio rápido
+### Demo de Google Colab
+Si deseas probar MarkDiffusion sin instalar nada, puedes usar [Google Colab](https://colab.research.google.com/drive/1N1C9elDAB5zwF4FxKKYMCqR3eSpCSqAW?usp=sharing#scrollTo=-kWt7m9Y3o-G) para ver cómo funciona.
+### Instalación
+**(Recomendado)** Hemos publicado un paquete en PyPI para MarkDiffusion. Puedes instalarlo directamente con pip:
```bash
-pip install -r requirements.txt
-```
-
-*Nota:* Algunos algoritmos pueden requerir pasos de configuración adicionales. Consulta la documentación de algoritmos individuales para requisitos específicos.
-
-## Inicio rápido
-
-Aquí hay un ejemplo simple para comenzar con MarkDiffusion:
-
-```python
-import torch
-from watermark.auto_watermark import AutoWatermark
-from utils.diffusion_config import DiffusionConfig
-from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
-
-# Configuración del dispositivo
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-# Configurar pipeline de difusión
-scheduler = DPMSolverMultistepScheduler.from_pretrained("model_path", subfolder="scheduler")
-pipe = StableDiffusionPipeline.from_pretrained("model_path", scheduler=scheduler).to(device)
-diffusion_config = DiffusionConfig(
- scheduler=scheduler,
- pipe=pipe,
- device=device,
- image_size=(512, 512),
- num_inference_steps=50,
- guidance_scale=7.5,
- gen_seed=42,
- inversion_type="ddim"
-)
-
-# Cargar algoritmo de marca de agua
-watermark = AutoWatermark.load('TR',
- algorithm_config='config/TR.json',
- diffusion_config=diffusion_config)
-
-# Generar medios con marca de agua
-prompt = "A beautiful sunset over the ocean"
-watermarked_image = watermark.generate_watermarked_media(prompt)
-
-# Detectar marca de agua
-detection_result = watermark.detect_watermark_in_media(watermarked_image)
-print(f"Watermark detected: {detection_result}")
+conda create -n markdiffusion python=3.11
+conda activate markdiffusion
+pip install markdiffusion[optional]
```
-## Cómo usar el kit de herramientas
-
-Proporcionamos ejemplos extensos en `MarkDiffusion_demo.ipynb`.
-
-### Generación y detección de medios con marcas de agua
-
-#### Casos para generar y detectar medios con marcas de agua
-
-```python
-import torch
-from watermark.auto_watermark import AutoWatermark
-from utils.diffusion_config import DiffusionConfig
-
-# Cargar algoritmo de marca de agua
-mywatermark = AutoWatermark.load(
- 'GS',
- algorithm_config=f'config/GS.json',
- diffusion_config=diffusion_config
-)
-
-# Generar imagen con marca de agua
-watermarked_image = mywatermark.generate_watermarked_media(
- input_data="A beautiful landscape with a river and mountains"
-)
-
-# Visualizar la imagen con marca de agua
-watermarked_image.show()
-
-# Detectar marca de agua
-detection_result = mywatermark.detect_watermark_in_media(watermarked_image)
-print(detection_result)
+**(Alternativa)** Para usuarios *limitados exclusivamente a entornos conda*, también proporcionamos un paquete de conda-forge, que se puede instalar con los siguientes comandos:
+```bash
+conda create -n markdiffusion python=3.11
+conda activate markdiffusion
+conda config --add channels conda-forge
+conda config --set channel_priority strict
+conda install markdiffusion
```
+Sin embargo, ten en cuenta que algunas características avanzadas requieren paquetes adicionales que no están disponibles en conda y no se pueden incluir en la versión. Necesitarás instalarlos por separado si es necesario.
-### Visualización de mecanismos de marcas de agua
-
-El kit de herramientas incluye herramientas de visualización personalizadas que permiten vistas claras y perspicaces sobre cómo operan los diferentes algoritmos de marcas de agua en varios escenarios. Estas visualizaciones ayudan a desmitificar los mecanismos de los algoritmos, haciéndolos más comprensibles para los usuarios.
-
-
+### Cómo usar el kit de herramientas
-#### Casos para visualizar mecanismos de marcas de agua
+Después de la instalación, hay dos formas de usar MarkDiffusion:
-```python
-from visualize.auto_visualization import AutoVisualizer
-
-# Obtener datos para visualización
-data_for_visualization = mywatermark.get_data_for_visualize(watermarked_image)
-
-# Cargar visualizador
-visualizer = AutoVisualizer.load('GS',
- data_for_visualization=data_for_visualization)
-
-# Dibujar diagramas en el lienzo de Matplotlib
-fig = visualizer.visualize(rows=2, cols=2,
- methods=['draw_watermark_bits',
- 'draw_reconstructed_watermark_bits',
- 'draw_inverted_latents',
- 'draw_inverted_latents_fft'])
-```
+1. **Clonar el repositorio para probar las demos o usarlo para desarrollo personalizado.** El notebook `MarkDiffusion_demo.ipynb` ofrece demostraciones detalladas para varios casos de uso — por favor revísalo para obtener orientación. Aquí hay un ejemplo rápido de generación y detección de imagen con marca de agua usando el algoritmo TR:
-### Pipelines de evaluación
-
-#### Casos para evaluación
-
-1. **Pipeline de detección de marcas de agua**
-
-```python
-from evaluation.dataset import StableDiffusionPromptsDataset
-from evaluation.pipelines.detection import (
- WatermarkedMediaDetectionPipeline,
- UnWatermarkedMediaDetectionPipeline,
- DetectionPipelineReturnType
-)
-from evaluation.tools.image_editor import JPEGCompression
-from evaluation.tools.success_rate_calculator import DynamicThresholdSuccessRateCalculator
-
-# Conjunto de datos
-my_dataset = StableDiffusionPromptsDataset(max_samples=200)
-
-# Configurar pipelines de detección
-pipeline1 = WatermarkedMediaDetectionPipeline(
- dataset=my_dataset,
- media_editor_list=[JPEGCompression(quality=60)],
- show_progress=True,
- return_type=DetectionPipelineReturnType.SCORES
-)
-
-pipeline2 = UnWatermarkedMediaDetectionPipeline(
- dataset=my_dataset,
- media_editor_list=[],
- show_progress=True,
- return_type=DetectionPipelineReturnType.SCORES
-)
-
-# Configurar parámetros de detección
-detection_kwargs = {
- "num_inference_steps": 50,
- "guidance_scale": 1.0,
-}
-# Calcular tasas de éxito
-calculator = DynamicThresholdSuccessRateCalculator(
- labels=labels,
- rule=rules,
- target_fpr=target_fpr
-)
-
-results = calculator.calculate(
- pipeline1.evaluate(my_watermark, detection_kwargs=detection_kwargs),
- pipeline2.evaluate(my_watermark, detection_kwargs=detection_kwargs)
-)
-print(results)
-```
+ ```python
+ import torch
+ from watermark.auto_watermark import AutoWatermark
+ from utils.diffusion_config import DiffusionConfig
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
-2. **Pipeline de análisis de calidad de imagen**
-
-```python
-from evaluation.dataset import StableDiffusionPromptsDataset, MSCOCODataset
-from evaluation.pipelines.image_quality_analysis import (
- DirectImageQualityAnalysisPipeline,
- ReferencedImageQualityAnalysisPipeline,
- GroupImageQualityAnalysisPipeline,
- RepeatImageQualityAnalysisPipeline,
- ComparedImageQualityAnalysisPipeline,
- QualityPipelineReturnType
-)
-from evaluation.tools.image_quality_analyzer import (
- NIQECalculator, CLIPScoreCalculator, FIDCalculator,
- InceptionScoreCalculator, LPIPSAnalyzer, PSNRAnalyzer
-)
-
-# Ejemplos de diferentes métricas de calidad:
-
-# NIQE (Evaluador de calidad de imagen natural)
-if metric == 'NIQE':
- my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples)
- pipeline = DirectImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[NIQECalculator()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ # Configuración del dispositivo
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
-# Puntuación CLIP
-elif metric == 'CLIP':
- my_dataset = MSCOCODataset(max_samples=max_samples)
- pipeline = ReferencedImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[CLIPScoreCalculator()],
- unwatermarked_image_source='generated',
- reference_image_source='natural',
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
+ # Configurar pipeline de difusión
+ scheduler = DPMSolverMultistepScheduler.from_pretrained("model_path", subfolder="scheduler")
+ pipe = StableDiffusionPipeline.from_pretrained("model_path", scheduler=scheduler).to(device)
+ diffusion_config = DiffusionConfig(
+ scheduler=scheduler,
+ pipe=pipe,
+ device=device,
+ image_size=(512, 512),
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ gen_seed=42,
+ inversion_type="ddim"
)
-# FID (Distancia de Inception de Fréchet)
-elif metric == 'FID':
- my_dataset = MSCOCODataset(max_samples=max_samples)
- pipeline = GroupImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[FIDCalculator()],
- unwatermarked_image_source='generated',
- reference_image_source='natural',
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
+ # Cargar algoritmo de marca de agua
+ watermark = AutoWatermark.load('TR',
+ algorithm_config='config/TR.json',
+ diffusion_config=diffusion_config)
+
+ # Generar medios con marca de agua
+ prompt = "A beautiful sunset over the ocean"
+ watermarked_image = watermark.generate_watermarked_media(prompt)
+ watermarked_image.save("watermarked_image.png")
+
+ # Detectar marca de agua
+ detection_result = watermark.detect_watermark_in_media(watermarked_image)
+ print(f"Watermark detected: {detection_result}")
+ ```
+
+2. **Importar la biblioteca markdiffusion directamente en tu código sin clonar el repositorio.** El notebook `MarkDiffusion_pypi_demo.ipynb` proporciona ejemplos completos para usar MarkDiffusion a través de la biblioteca markdiffusion — por favor revísalo para obtener orientación. Aquí hay un ejemplo rápido:
+
+ ```python
+ import torch
+ from markdiffusion.watermark import AutoWatermark
+ from markdiffusion.utils import DiffusionConfig
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+
+ # Dispositivo
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # Ruta del modelo
+ MODEL_PATH = "huanzi05/stable-diffusion-2-1-base"
+
+ # Inicializar planificador y pipeline
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(MODEL_PATH, subfolder="scheduler")
+ pipe = StableDiffusionPipeline.from_pretrained(
+ MODEL_PATH,
+ scheduler=scheduler,
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+ safety_checker=None,
+ ).to(device)
+
+ # Crear DiffusionConfig para generación de imágenes
+ image_diffusion_config = DiffusionConfig(
+ scheduler=scheduler,
+ pipe=pipe,
+ device=device,
+ image_size=(512, 512),
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ gen_seed=42,
+ inversion_type="ddim"
)
-# IS (Puntuación Inception)
-elif metric == 'IS':
- my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples)
- pipeline = GroupImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[InceptionScoreCalculator()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ # Cargar algoritmo de marca de agua Tree-Ring
+ tr_watermark = AutoWatermark.load('TR', diffusion_config=image_diffusion_config)
+ print("TR watermark algorithm loaded successfully!")
-# LPIPS (Similitud de parche de imagen perceptual aprendida)
-elif metric == 'LPIPS':
- my_dataset = StableDiffusionPromptsDataset(max_samples=10)
- pipeline = RepeatImageQualityAnalysisPipeline(
- dataset=my_dataset,
- prompt_per_image=20,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[LPIPSAnalyzer()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ # Generar imagen con marca de agua
+ prompt = "A beautiful landscape with mountains and a river at sunset"
-# PSNR (Relación señal-ruido de pico)
-elif metric == 'PSNR':
- my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples)
- pipeline = ComparedImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[PSNRAnalyzer()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ watermarked_image = tr_watermark.generate_watermarked_media(input_data=prompt)
-# Cargar marca de agua y evaluar
-my_watermark = AutoWatermark.load(
- f'{algorithm_name}',
- algorithm_config=f'config/{algorithm_name}.json',
- diffusion_config=diffusion_config
-)
+ # Guardar la imagen con marca de agua
+ watermarked_image.save("watermarked_image.png")
+ print("Watermarked image generated!")
-print(pipeline.evaluate(my_watermark))
-```
+ # Detectar marca de agua en la imagen con marca de agua
+ detection_result = tr_watermark.detect_watermark_in_media(watermarked_image)
+ print("Watermarked image detection result:")
+ print(detection_result)
+ ```
-3. **Pipeline de análisis de calidad de video**
-
-```python
-from evaluation.dataset import VBenchDataset
-from evaluation.pipelines.video_quality_analysis import DirectVideoQualityAnalysisPipeline
-from evaluation.tools.video_quality_analyzer import (
- SubjectConsistencyAnalyzer,
- MotionSmoothnessAnalyzer,
- DynamicDegreeAnalyzer,
- BackgroundConsistencyAnalyzer,
- ImagingQualityAnalyzer
-)
-
-# Cargar conjunto de datos VBench
-my_dataset = VBenchDataset(max_samples=200, dimension=dimension)
-
-# Inicializar analizador según métrica
-if metric == 'subject_consistency':
- analyzer = SubjectConsistencyAnalyzer(device=device)
-elif metric == 'motion_smoothness':
- analyzer = MotionSmoothnessAnalyzer(device=device)
-elif metric == 'dynamic_degree':
- analyzer = DynamicDegreeAnalyzer(device=device)
-elif metric == 'background_consistency':
- analyzer = BackgroundConsistencyAnalyzer(device=device)
-elif metric == 'imaging_quality':
- analyzer = ImagingQualityAnalyzer(device=device)
-else:
- raise ValueError(f'Invalid metric: {metric}. Supported metrics:
- subject_consistency, motion_smoothness, dynamic_degree,
- background_consistency, imaging_quality')
-
-# Crear pipeline de análisis de calidad de video
-pipeline = DirectVideoQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_video_editor_list=[],
- unwatermarked_video_editor_list=[],
- watermarked_frame_editor_list=[],
- unwatermarked_frame_editor_list=[],
- analyzers=[analyzer],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
-)
-
-print(pipeline.evaluate(my_watermark))
-```
+## 🛠 Módulos de prueba
+Proporcionamos un conjunto completo de módulos de prueba para garantizar la calidad del código. La suite incluye 454 pruebas unitarias con una cobertura de código de aproximadamente el 90%. Consulta el directorio `test/` para más detalles.
## Citación
```
diff --git a/README_fr.md b/README_fr.md
index ee43302..33d529e 100644
--- a/README_fr.md
+++ b/README_fr.md
@@ -4,39 +4,44 @@
# Une Boîte à Outils Open-Source pour le Tatouage Numérique Génératif des Modèles de Diffusion Latente
-[](https://generative-watermark.github.io/)
+[](https://generative-watermark.github.io/)
[](https://arxiv.org/abs/2509.10569)
-[](https://huggingface.co/Generative-Watermark-Toolkits)
+[](https://huggingface.co/Generative-Watermark-Toolkits)
+[](https://colab.research.google.com/drive/1N1C9elDAB5zwF4FxKKYMCqR3eSpCSqAW?usp=sharing)
+[](https://markdiffusion.readthedocs.io)
+[](https://pypi.org/project/markdiffusion)
+[](https://github.com/conda-forge/markdiffusion-feedstock)
+
-**Versions linguistiques :** [English](README.md) | [中文](README_zh.md) | [Français](README_fr.md) | [Español](README_es.md)
+**Versions linguistiques :** [English](README.md) | [中文](README_zh.md) | [Français](README_fr.md) | [Español](README_es.md)
> 🔥 **En tant que projet récemment publié, nous accueillons les PR !** Si vous avez implémenté un algorithme de tatouage numérique LDM ou si vous êtes intéressé à en contribuer un, nous serions ravis de l'inclure dans MarkDiffusion. Rejoignez notre communauté et aidez à rendre le tatouage numérique génératif plus accessible à tous !
## Sommaire
-- [Remarques](#-remarques)
- [Mises à jour](#-mises-à-jour)
-- [Introduction à MarkDiffusion](#introduction-à-markdiffusion)
- - [Vue d'ensemble](#vue-densemble)
- - [Caractéristiques clés](#caractéristiques-clés)
- - [Algorithmes implémentés](#algorithmes-implémentés)
- - [Module d'évaluation](#module-dévaluation)
-- [Installation](#installation)
-- [Démarrage rapide](#démarrage-rapide)
-- [Comment utiliser la boîte à outils](#comment-utiliser-la-boîte-à-outils)
- - [Génération et détection de médias tatoués](#génération-et-détection-de-médias-tatoués)
- - [Visualisation des mécanismes de tatouage](#visualisation-des-mécanismes-de-tatouage)
- - [Pipelines d'évaluation](#pipelines-dévaluation)
+- [Introduction à MarkDiffusion](#-introduction-à-markdiffusion)
+ - [Vue d'ensemble](#-vue-densemble)
+ - [Caractéristiques clés](#-caractéristiques-clés)
+ - [Algorithmes implémentés](#-algorithmes-implémentés)
+ - [Module d'évaluation](#-module-dévaluation)
+- [Démarrage rapide](#-démarrage-rapide)
+ - [Démo Google Colab](#démo-google-colab)
+ - [Installation](#installation)
+ - [Comment utiliser la boîte à outils](#comment-utiliser-la-boîte-à-outils)
+- [Modules de test](#-modules-de-test)
- [Citation](#citation)
-## ❗❗❗ Remarques
-Au fur et à mesure que le contenu du dépôt MarkDiffusion s'enrichit et que sa taille augmente, nous avons créé un dépôt de stockage de modèles sur Hugging Face appelé [Generative-Watermark-Toolkits](https://huggingface.co/Generative-Watermark-Toolkits) pour faciliter l'utilisation. Ce dépôt contient divers modèles par défaut pour les algorithmes de tatouage numérique qui impliquent des modèles auto-entraînés. Nous avons supprimé les poids des modèles des dossiers `ckpts/` correspondants de ces algorithmes de tatouage dans le dépôt principal. **Lors de l'utilisation du code, veuillez d'abord télécharger les modèles correspondants depuis le dépôt Hugging Face selon les chemins de configuration et les enregistrer dans le répertoire `ckpts/` avant d'exécuter le code.**
## 🔥 Mises à jour
+🛠 **(2025.12.19)** Ajout d'une suite de tests complète pour toutes les fonctionnalités avec 454 cas de test.
+
+🛠 **(2025.12.10)** Ajout d'un système de tests d'intégration continue utilisant GitHub Actions.
+
🎯 **(2025.10.10)** Ajout des outils d'attaque d'image *Mask, Overlay, AdaptiveNoiseInjection*, merci à Zheyu Fu pour sa PR !
-🎯 **(2025.10.09)** Ajout des outils d'attaque vidéo *VideoCodecAttack, FrameRateAdapter, FrameInterpolationAttack*, merci à Luyang Si pour sa PR !
+🎯 **(2025.10.09)** Ajout des outils d'attaque vidéo *FrameRateAdapter, FrameInterpolationAttack*, merci à Luyang Si pour sa PR !
🎯 **(2025.10.08)** Ajout des analyseurs de qualité d'image *SSIM, BRISQUE, VIF, FSIM*, merci à Huan Wang pour sa PR !
@@ -46,27 +51,27 @@ Au fur et à mesure que le contenu du dépôt MarkDiffusion s'enrichit et que sa
✨ **(2025.9.29)** Ajout de la méthode de tatouage [GaussMarker](https://arxiv.org/abs/2506.11444), merci à Luyang Si pour sa PR !
-## Introduction à MarkDiffusion
+## 🔓 Introduction à MarkDiffusion
-### Vue d'ensemble
+### 👀 Vue d'ensemble
MarkDiffusion est une boîte à outils Python open-source pour le tatouage numérique génératif des modèles de diffusion latente. Alors que l'utilisation des modèles génératifs basés sur la diffusion s'étend, garantir l'authenticité et l'origine des médias générés devient crucial. MarkDiffusion simplifie l'accès, la compréhension et l'évaluation des technologies de tatouage numérique, les rendant accessibles tant aux chercheurs qu'à la communauté au sens large. *Remarque : si vous êtes intéressé par le tatouage LLM (tatouage de texte), veuillez vous référer à la boîte à outils [MarkLLM](https://github.com/THU-BPM/MarkLLM) de notre groupe.*
-La boîte à outils comprend trois composants clés : un cadre d'implémentation unifié pour des intégrations rationalisées d'algorithmes de tatouage et des interfaces conviviales ; une suite de visualisation de mécanismes qui présente intuitivement les motifs de tatouage ajoutés et extraits pour aider à la compréhension du public ; et un module d'évaluation complet offrant des implémentations standard de 24 outils couvrant trois aspects essentiels — détectabilité, robustesse et qualité de sortie, plus 8 pipelines d'évaluation automatisés.
+La boîte à outils comprend trois composants clés : un cadre d'implémentation unifié pour des intégrations rationalisées d'algorithmes de tatouage et des interfaces conviviales ; une suite de visualisation de mécanismes qui présente intuitivement les motifs de tatouage ajoutés et extraits pour aider à la compréhension du public ; et un module d'évaluation complet offrant des implémentations standard de 31 outils couvrant trois aspects essentiels — détectabilité, robustesse et qualité de sortie, plus 6 pipelines d'évaluation automatisés.
-### Caractéristiques clés
+### 💍 Caractéristiques clés
-- **Cadre d'implémentation unifié :** MarkDiffusion fournit une architecture modulaire prenant en charge huit algorithmes de tatouage d'image/vidéo génératifs de pointe pour les LDM.
+- **Cadre d'implémentation unifié :** MarkDiffusion fournit une architecture modulaire prenant en charge onze algorithmes de tatouage d'image/vidéo génératifs de pointe pour les LDM.
-- **Support algorithmique complet :** Implémente actuellement 8 algorithmes de tatouage de deux catégories principales : méthodes basées sur les motifs (Tree-Ring, Ring-ID, ROBIN, WIND) et méthodes basées sur les clés (Gaussian-Shading, PRC, SEAL, VideoShield).
+- **Support algorithmique complet :** Implémente actuellement 11 algorithmes de tatouage de deux catégories principales : méthodes basées sur les motifs (Tree-Ring, Ring-ID, ROBIN, WIND, SFW) et méthodes basées sur les clés (Gaussian-Shading, PRC, SEAL, VideoShield, GaussMarker, VideoMark).
- **Solutions de visualisation :** La boîte à outils comprend des outils de visualisation personnalisés qui permettent des vues claires et perspicaces sur le fonctionnement des différents algorithmes de tatouage dans divers scénarios. Ces visualisations aident à démystifier les mécanismes des algorithmes, les rendant plus compréhensibles pour les utilisateurs.
-- **Module d'évaluation :** Avec 20 outils d'évaluation couvrant la détectabilité, la robustesse et l'impact sur la qualité de sortie, MarkDiffusion fournit des capacités d'évaluation complètes. Il comprend 5 pipelines d'évaluation automatisés : Pipeline de détection de tatouage, Pipeline d'analyse de qualité d'image, Pipeline d'analyse de qualité vidéo et outils d'évaluation de robustesse spécialisés.
+- **Module d'évaluation :** Avec 31 outils d'évaluation couvrant la détectabilité, la robustesse et l'impact sur la qualité de sortie, MarkDiffusion fournit des capacités d'évaluation complètes. Il comprend 6 pipelines d'évaluation automatisés couvrant la détection de tatouage ainsi que l'analyse de qualité d'image et de vidéo, complétés par des outils d'évaluation de robustesse spécialisés.
-### Algorithmes implémentés
+### ✨ Algorithmes implémentés
| **Algorithme** | **Catégorie** | **Cible** | **Référence** |
|---------------|-------------|------------|---------------|
@@ -82,7 +87,7 @@ La boîte à outils comprend trois composants clés : un cadre d'implémentation
| VideoShield | Clé | Vidéo | [VideoShield: Regulating Diffusion-based Video Generation Models via Watermarking](https://arxiv.org/abs/2501.14195) |
| VideoMark | Clé | Vidéo | [VideoMark: A Distortion-Free Robust Watermarking Framework for Video Diffusion Models](https://arxiv.org/abs/2504.16359) |
-### Module d'évaluation
+### 🎯 Module d'évaluation
#### Pipelines d'évaluation
MarkDiffusion prend en charge huit pipelines, deux pour la détection (WatermarkedMediaDetectionPipeline et UnWatermarkedMediaDetectionPipeline), et six pour l'analyse de qualité. Le tableau ci-dessous détaille les pipelines d'analyse de qualité.
@@ -116,7 +121,6 @@ MarkDiffusion prend en charge huit pipelines, deux pour la détection (Watermark
| MPEG4Compression | Robustesse (Vidéo) | Attaque par compression vidéo MPEG-4, testant la robustesse du tatouage vidéo à la compression | Cadres vidéo compressés |
| FrameAverage | Robustesse (Vidéo) | Attaque par moyennage de cadres, détruisant les tatouages par moyennage inter-cadres | Cadres vidéo moyennés |
| FrameSwap | Robustesse (Vidéo) | Attaque par échange de cadres, testant la robustesse en changeant les séquences de cadres | Cadres vidéo échangés |
-| VideoCodecAttack | Robustesse (Vidéo) | Attaque par ré-encodage de codec simulant le transcodage de plateforme (H.264/H.265/VP9/AV1) | Cadres vidéo ré-encodés |
| FrameRateAdapter | Robustesse (Vidéo) | Attaque par conversion de fréquence d'images qui rééchantillonne les cadres tout en préservant la durée | Séquence de cadres rééchantillonnée |
| FrameInterpolationAttack | Robustesse (Vidéo) | Attaque par interpolation de cadres insérant des cadres mélangés pour modifier la densité temporelle | Cadres vidéo interpolés |
| **Analyseurs de qualité d'image** | | | |
@@ -137,326 +141,130 @@ MarkDiffusion prend en charge huit pipelines, deux pour la détection (Watermark
| DynamicDegreeAnalyzer | Qualité (Vidéo) | Mesurer le niveau dynamique et l'amplitude de changement dans la vidéo | Valeur de degré dynamique |
| ImagingQualityAnalyzer | Qualité (Vidéo) | Évaluation complète de la qualité d'imagerie vidéo | Score de qualité d'imagerie |
-## Installation
-
-### Configuration de l'environnement
-
-- Python 3.10+
-- PyTorch
-- Installer les dépendances :
+## 🧩 Démarrage rapide
+### Démo Google Colab
+Si vous souhaitez essayer MarkDiffusion sans rien installer, vous pouvez utiliser [Google Colab](https://colab.research.google.com/drive/1N1C9elDAB5zwF4FxKKYMCqR3eSpCSqAW?usp=sharing#scrollTo=-kWt7m9Y3o-G) pour voir comment cela fonctionne.
+### Installation
+**(Recommandé)** Nous avons publié un package PyPI pour MarkDiffusion. Vous pouvez l'installer directement avec pip :
```bash
-pip install -r requirements.txt
-```
-
-*Remarque :* Certains algorithmes peuvent nécessiter des étapes de configuration supplémentaires. Veuillez vous référer à la documentation des algorithmes individuels pour les exigences spécifiques.
-
-## Démarrage rapide
-
-Voici un exemple simple pour vous aider à démarrer avec MarkDiffusion :
-
-```python
-import torch
-from watermark.auto_watermark import AutoWatermark
-from utils.diffusion_config import DiffusionConfig
-from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
-
-# Configuration du périphérique
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-# Configuration du pipeline de diffusion
-scheduler = DPMSolverMultistepScheduler.from_pretrained("model_path", subfolder="scheduler")
-pipe = StableDiffusionPipeline.from_pretrained("model_path", scheduler=scheduler).to(device)
-diffusion_config = DiffusionConfig(
- scheduler=scheduler,
- pipe=pipe,
- device=device,
- image_size=(512, 512),
- num_inference_steps=50,
- guidance_scale=7.5,
- gen_seed=42,
- inversion_type="ddim"
-)
-
-# Charger l'algorithme de tatouage
-watermark = AutoWatermark.load('TR',
- algorithm_config='config/TR.json',
- diffusion_config=diffusion_config)
-
-# Générer un média tatoué
-prompt = "A beautiful sunset over the ocean"
-watermarked_image = watermark.generate_watermarked_media(prompt)
-
-# Détecter le tatouage
-detection_result = watermark.detect_watermark_in_media(watermarked_image)
-print(f"Watermark detected: {detection_result}")
+conda create -n markdiffusion python=3.11
+conda activate markdiffusion
+pip install "markdiffusion[optional]"
```
-## Comment utiliser la boîte à outils
-
-Nous fournissons de nombreux exemples dans `MarkDiffusion_demo.ipynb`.
-
-### Génération et détection de médias tatoués
-
-#### Cas de génération et de détection de médias tatoués
-
-```python
-import torch
-from watermark.auto_watermark import AutoWatermark
-from utils.diffusion_config import DiffusionConfig
-
-# Charger l'algorithme de tatouage
-mywatermark = AutoWatermark.load(
- 'GS',
- algorithm_config=f'config/GS.json',
- diffusion_config=diffusion_config
-)
-
-# Générer une image tatouée
-watermarked_image = mywatermark.generate_watermarked_media(
- input_data="A beautiful landscape with a river and mountains"
-)
-
-# Visualiser l'image tatouée
-watermarked_image.show()
-
-# Détecter le tatouage
-detection_result = mywatermark.detect_watermark_in_media(watermarked_image)
-print(detection_result)
+**(Alternative)** Pour les utilisateurs *limités à un environnement conda*, nous fournissons également un package conda-forge, installable avec les commandes suivantes :
+```bash
+conda create -n markdiffusion python=3.11
+conda activate markdiffusion
+conda config --add channels conda-forge
+conda config --set channel_priority strict
+conda install markdiffusion
```
+Cependant, veuillez noter que certaines fonctionnalités avancées nécessitent des packages supplémentaires qui ne sont pas disponibles sur conda et ne peuvent donc pas être inclus dans le package conda. Vous devrez les installer séparément si nécessaire.
-### Visualisation des mécanismes de tatouage
-
-La boîte à outils comprend des outils de visualisation personnalisés qui permettent des vues claires et perspicaces sur le fonctionnement des différents algorithmes de tatouage dans divers scénarios. Ces visualisations aident à démystifier les mécanismes des algorithmes, les rendant plus compréhensibles pour les utilisateurs.
-
-
+### Comment utiliser la boîte à outils
-#### Cas de visualisation du mécanisme de tatouage
+Après l'installation, il existe deux façons d'utiliser MarkDiffusion :
-```python
-from visualize.auto_visualization import AutoVisualizer
-
-# Obtenir les données pour la visualisation
-data_for_visualization = mywatermark.get_data_for_visualize(watermarked_image)
-
-# Charger le visualiseur
-visualizer = AutoVisualizer.load('GS',
- data_for_visualization=data_for_visualization)
-
-# Dessiner des diagrammes sur le canevas Matplotlib
-fig = visualizer.visualize(rows=2, cols=2,
- methods=['draw_watermark_bits',
- 'draw_reconstructed_watermark_bits',
- 'draw_inverted_latents',
- 'draw_inverted_latents_fft'])
-```
+1. **Cloner le dépôt pour essayer les démos ou l'utiliser pour un développement personnalisé.** Le notebook `MarkDiffusion_demo.ipynb` offre des démonstrations détaillées pour divers cas d'utilisation — veuillez le consulter pour obtenir des conseils. Voici un exemple rapide de génération et de détection d'image tatouée avec l'algorithme TR :
-### Pipelines d'évaluation
-
-#### Cas d'évaluation
-
-1. **Pipeline de détection de tatouage**
-
-```python
-from evaluation.dataset import StableDiffusionPromptsDataset
-from evaluation.pipelines.detection import (
- WatermarkedMediaDetectionPipeline,
- UnWatermarkedMediaDetectionPipeline,
- DetectionPipelineReturnType
-)
-from evaluation.tools.image_editor import JPEGCompression
-from evaluation.tools.success_rate_calculator import DynamicThresholdSuccessRateCalculator
-
-# Jeu de données
-my_dataset = StableDiffusionPromptsDataset(max_samples=200)
-
-# Configurer les pipelines de détection
-pipeline1 = WatermarkedMediaDetectionPipeline(
- dataset=my_dataset,
- media_editor_list=[JPEGCompression(quality=60)],
- show_progress=True,
- return_type=DetectionPipelineReturnType.SCORES
-)
-
-pipeline2 = UnWatermarkedMediaDetectionPipeline(
- dataset=my_dataset,
- media_editor_list=[],
- show_progress=True,
- return_type=DetectionPipelineReturnType.SCORES
-)
-
-# Configurer les paramètres de détection
-detection_kwargs = {
- "num_inference_steps": 50,
- "guidance_scale": 1.0,
-}
-# Calculer les taux de réussite
-calculator = DynamicThresholdSuccessRateCalculator(
- labels=labels,
- rule=rules,
- target_fpr=target_fpr
-)
-
-results = calculator.calculate(
- pipeline1.evaluate(my_watermark, detection_kwargs=detection_kwargs),
- pipeline2.evaluate(my_watermark, detection_kwargs=detection_kwargs)
-)
-print(results)
-```
+ ```python
+ import torch
+ from watermark.auto_watermark import AutoWatermark
+ from utils.diffusion_config import DiffusionConfig
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
-2. **Pipeline d'analyse de qualité d'image**
-
-```python
-from evaluation.dataset import StableDiffusionPromptsDataset, MSCOCODataset
-from evaluation.pipelines.image_quality_analysis import (
- DirectImageQualityAnalysisPipeline,
- ReferencedImageQualityAnalysisPipeline,
- GroupImageQualityAnalysisPipeline,
- RepeatImageQualityAnalysisPipeline,
- ComparedImageQualityAnalysisPipeline,
- QualityPipelineReturnType
-)
-from evaluation.tools.image_quality_analyzer import (
- NIQECalculator, CLIPScoreCalculator, FIDCalculator,
- InceptionScoreCalculator, LPIPSAnalyzer, PSNRAnalyzer
-)
-
-# Exemples de différentes métriques de qualité :
-
-# NIQE (Évaluateur de qualité d'image naturelle)
-if metric == 'NIQE':
- my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples)
- pipeline = DirectImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[NIQECalculator()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ # Configuration du périphérique
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
-# Score CLIP
-elif metric == 'CLIP':
- my_dataset = MSCOCODataset(max_samples=max_samples)
- pipeline = ReferencedImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[CLIPScoreCalculator()],
- unwatermarked_image_source='generated',
- reference_image_source='natural',
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
+ # Configuration du pipeline de diffusion
+ scheduler = DPMSolverMultistepScheduler.from_pretrained("model_path", subfolder="scheduler")
+ pipe = StableDiffusionPipeline.from_pretrained("model_path", scheduler=scheduler).to(device)
+ diffusion_config = DiffusionConfig(
+ scheduler=scheduler,
+ pipe=pipe,
+ device=device,
+ image_size=(512, 512),
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ gen_seed=42,
+ inversion_type="ddim"
)
-# FID (Distance d'Inception de Fréchet)
-elif metric == 'FID':
- my_dataset = MSCOCODataset(max_samples=max_samples)
- pipeline = GroupImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[FIDCalculator()],
- unwatermarked_image_source='generated',
- reference_image_source='natural',
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
+ # Charger l'algorithme de tatouage
+ watermark = AutoWatermark.load('TR',
+ algorithm_config='config/TR.json',
+ diffusion_config=diffusion_config)
+
+ # Générer un média tatoué
+ prompt = "A beautiful sunset over the ocean"
+ watermarked_image = watermark.generate_watermarked_media(prompt)
+ watermarked_image.save("watermarked_image.png")
+
+ # Détecter le tatouage
+ detection_result = watermark.detect_watermark_in_media(watermarked_image)
+ print(f"Watermark detected: {detection_result}")
+ ```
+
+2. **Importer la bibliothèque markdiffusion directement dans votre code sans cloner le dépôt.** Le notebook `MarkDiffusion_pypi_demo.ipynb` fournit des exemples complets pour utiliser MarkDiffusion via la bibliothèque markdiffusion — veuillez le consulter pour obtenir des conseils. Voici un exemple rapide :
+
+ ```python
+ import torch
+ from markdiffusion.watermark import AutoWatermark
+ from markdiffusion.utils import DiffusionConfig
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+
+ # Périphérique
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # Chemin du modèle
+ MODEL_PATH = "huanzi05/stable-diffusion-2-1-base"
+
+ # Initialiser le planificateur et le pipeline
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(MODEL_PATH, subfolder="scheduler")
+ pipe = StableDiffusionPipeline.from_pretrained(
+ MODEL_PATH,
+ scheduler=scheduler,
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+ safety_checker=None,
+ ).to(device)
+
+ # Créer DiffusionConfig pour la génération d'images
+ image_diffusion_config = DiffusionConfig(
+ scheduler=scheduler,
+ pipe=pipe,
+ device=device,
+ image_size=(512, 512),
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ gen_seed=42,
+ inversion_type="ddim"
)
-# IS (Score Inception)
-elif metric == 'IS':
- my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples)
- pipeline = GroupImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[InceptionScoreCalculator()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ # Charger l'algorithme de tatouage Tree-Ring
+ tr_watermark = AutoWatermark.load('TR', diffusion_config=image_diffusion_config)
+ print("TR watermark algorithm loaded successfully!")
-# LPIPS (Similarité de patch d'image perceptuelle apprise)
-elif metric == 'LPIPS':
- my_dataset = StableDiffusionPromptsDataset(max_samples=10)
- pipeline = RepeatImageQualityAnalysisPipeline(
- dataset=my_dataset,
- prompt_per_image=20,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[LPIPSAnalyzer()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ # Générer une image tatouée
+ prompt = "A beautiful landscape with mountains and a river at sunset"
-# PSNR (Rapport signal sur bruit de crête)
-elif metric == 'PSNR':
- my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples)
- pipeline = ComparedImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[PSNRAnalyzer()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ watermarked_image = tr_watermark.generate_watermarked_media(input_data=prompt)
-# Charger le tatouage et évaluer
-my_watermark = AutoWatermark.load(
- f'{algorithm_name}',
- algorithm_config=f'config/{algorithm_name}.json',
- diffusion_config=diffusion_config
-)
+ # Afficher l'image tatouée
+ watermarked_image.save("watermarked_image.png")
+ print("Watermarked image generated!")
-print(pipeline.evaluate(my_watermark))
-```
+ # Détecter le tatouage dans l'image tatouée
+ detection_result = tr_watermark.detect_watermark_in_media(watermarked_image)
+ print("Watermarked image detection result:")
+ print(detection_result)
+ ```
-3. **Pipeline d'analyse de qualité vidéo**
-
-```python
-from evaluation.dataset import VBenchDataset
-from evaluation.pipelines.video_quality_analysis import DirectVideoQualityAnalysisPipeline
-from evaluation.tools.video_quality_analyzer import (
- SubjectConsistencyAnalyzer,
- MotionSmoothnessAnalyzer,
- DynamicDegreeAnalyzer,
- BackgroundConsistencyAnalyzer,
- ImagingQualityAnalyzer
-)
-
-# Charger le jeu de données VBench
-my_dataset = VBenchDataset(max_samples=200, dimension=dimension)
-
-# Initialiser l'analyseur en fonction de la métrique
-if metric == 'subject_consistency':
- analyzer = SubjectConsistencyAnalyzer(device=device)
-elif metric == 'motion_smoothness':
- analyzer = MotionSmoothnessAnalyzer(device=device)
-elif metric == 'dynamic_degree':
- analyzer = DynamicDegreeAnalyzer(device=device)
-elif metric == 'background_consistency':
- analyzer = BackgroundConsistencyAnalyzer(device=device)
-elif metric == 'imaging_quality':
- analyzer = ImagingQualityAnalyzer(device=device)
-else:
- raise ValueError(f'Invalid metric: {metric}. Supported metrics:
- subject_consistency, motion_smoothness, dynamic_degree,
- background_consistency, imaging_quality')
-
-# Créer le pipeline d'analyse de qualité vidéo
-pipeline = DirectVideoQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_video_editor_list=[],
- unwatermarked_video_editor_list=[],
- watermarked_frame_editor_list=[],
- unwatermarked_frame_editor_list=[],
- analyzers=[analyzer],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
-)
-
-print(pipeline.evaluate(my_watermark))
-```
+## 🛠 Modules de test
+Nous fournissons un ensemble complet de modules de test pour assurer la qualité du code. La suite comprend 454 tests unitaires avec environ 90 % de couverture de code. Veuillez vous référer au répertoire `test/` pour plus de détails.
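+À titre d'exemple, la suite peut être lancée depuis la racine du dépôt cloné (les paquets `pytest` et `pytest-cov` sont supposés ici et ne sont pas confirmés par le dépôt — adaptez selon votre configuration) :
+
+```bash
+# Hypothèse : exécution depuis la racine du dépôt cloné,
+# avec les dépendances du projet déjà installées
+pip install pytest pytest-cov
+pytest test/ --cov
+```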
## Citation
```
diff --git a/README_zh.md b/README_zh.md
index 105b261..7d72a56 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -4,39 +4,44 @@
# 潜在扩散模型生成式水印的开源工具包
-[](https://generative-watermark.github.io/)
+[](https://generative-watermark.github.io/)
[](https://arxiv.org/abs/2509.10569)
-[](https://huggingface.co/Generative-Watermark-Toolkits)
+[](https://huggingface.co/Generative-Watermark-Toolkits)
+[](https://colab.research.google.com/drive/1N1C9elDAB5zwF4FxKKYMCqR3eSpCSqAW?usp=sharing)
+[](https://markdiffusion.readthedocs.io)
+[](https://pypi.org/project/markdiffusion)
+[](https://github.com/conda-forge/markdiffusion-feedstock)
+
-**语言版本:** [English](README.md) | [中文](README_zh.md) | [Français](README_fr.md) | [Español](README_es.md)
+**语言版本:** [English](README.md) | [中文](README_zh.md) | [Français](README_fr.md) | [Español](README_es.md)
> 🔥 **作为一个新发布的项目,我们欢迎 PR!** 如果您已经实现了 LDM 水印算法或有兴趣贡献一个算法,我们很乐意将其包含在 MarkDiffusion 中。加入我们的社区,帮助让生成式水印技术对每个人都更易用!
## 目录
-- [注意事项](#-注意事项)
- [更新日志](#-更新日志)
-- [MarkDiffusion 简介](#markdiffusion-简介)
- - [概述](#概述)
- - [核心特性](#核心特性)
- - [已实现算法](#已实现算法)
- - [评估模块](#评估模块)
-- [安装](#安装)
-- [快速开始](#快速开始)
-- [如何使用工具包](#如何使用工具包)
- - [生成和检测水印媒体](#生成和检测水印媒体)
- - [可视化水印机制](#可视化水印机制)
- - [评估流水线](#评估流水线)
+- [MarkDiffusion 简介](#-markdiffusion-简介)
+ - [概述](#-概述)
+ - [核心特性](#-核心特性)
+ - [已实现算法](#-已实现算法)
+ - [评估模块](#-评估模块)
+- [快速开始](#-快速开始)
+ - [Google Colab 演示](#google-colab-演示)
+ - [安装](#安装)
+ - [如何使用工具包](#如何使用工具包)
+- [测试模块](#-测试模块)
- [引用](#引用)
-## ❗❗❗ 注意事项
-随着 MarkDiffusion 仓库内容日益丰富且体积不断增大,我们在 Hugging Face 上创建了一个名为 [Generative-Watermark-Toolkits](https://huggingface.co/Generative-Watermark-Toolkits) 的模型存储仓库以便于使用。该仓库包含了各种涉及自训练模型的水印算法的默认模型。我们已从主仓库中这些水印算法对应的 `ckpts/` 文件夹中移除了模型权重。**使用代码时,请首先根据配置路径从 Hugging Face 仓库下载相应的模型,并将其保存到 `ckpts/` 目录后再运行代码。**
## 🔥 更新日志
+🛠 **(2025.12.19)** 为所有功能添加了包含454个测试用例的完整测试套件。
+
+🛠 **(2025.12.10)** 使用 GitHub Actions 添加了持续集成测试系统。
+
🎯 **(2025.10.10)** 添加 *Mask、Overlay、AdaptiveNoiseInjection* 图像攻击工具,感谢付哲语的 PR!
-🎯 **(2025.10.09)** 添加 *VideoCodecAttack、FrameRateAdapter、FrameInterpolationAttack* 视频攻击工具,感谢司路阳的 PR!
+🎯 **(2025.10.09)** 添加 *FrameRateAdapter、FrameInterpolationAttack* 视频攻击工具,感谢司路阳的 PR!
🎯 **(2025.10.08)** 添加 *SSIM、BRISQUE、VIF、FSIM* 图像质量分析器,感谢王欢的 PR!
@@ -46,27 +51,27 @@
✨ **(2025.9.29)** 添加 [GaussMarker](https://arxiv.org/abs/2506.11444) 水印方法,感谢司路阳的 PR!
-## MarkDiffusion 简介
+## 🔓 MarkDiffusion 简介
-### 概述
+### 👀 概述
MarkDiffusion 是一个用于潜在扩散模型生成式水印的开源 Python 工具包。随着基于扩散的生成模型应用范围的扩大,确保生成媒体的真实性和来源变得至关重要。MarkDiffusion 简化了水印技术的访问、理解和评估,使研究人员和更广泛的社区都能轻松使用。*注意:如果您对 LLM 水印(文本水印)感兴趣,请参考我们团队的 [MarkLLM](https://github.com/THU-BPM/MarkLLM) 工具包。*
-该工具包包含三个关键组件:统一的实现框架,用于简化水印算法集成和用户友好的界面;机制可视化套件,直观地展示添加和提取的水印模式,帮助公众理解;以及全面的评估模块,提供 24 个工具的标准实现,涵盖三个关键方面——可检测性、鲁棒性和输出质量,以及 8 个自动化评估流水线。
+该工具包包含三个关键组件:统一的实现框架,用于简化水印算法集成和用户友好的界面;机制可视化套件,直观地展示添加和提取的水印模式,帮助公众理解;以及全面的评估模块,提供 31 个工具的标准实现,涵盖三个关键方面——可检测性、鲁棒性和输出质量,以及 6 个自动化评估流水线。
-### 核心特性
+### 💍 核心特性
-- **统一实现框架:** MarkDiffusion 提供了一个模块化架构,支持八种最先进的 LDM 生成式图像/视频水印算法。
+- **统一实现框架:** MarkDiffusion 提供了一个模块化架构,支持十一种最先进的 LDM 生成式图像/视频水印算法。
-- **全面的算法支持:** 目前实现了来自两大类别的 8 种水印算法:基于模式的方法(Tree-Ring、Ring-ID、ROBIN、WIND)和基于密钥的方法(Gaussian-Shading、PRC、SEAL、VideoShield)。
+- **全面的算法支持:** 目前实现了来自两大类别的 11 种水印算法:基于模式的方法(Tree-Ring、Ring-ID、ROBIN、WIND、SFW)和基于密钥的方法(Gaussian-Shading、PRC、SEAL、VideoShield、GaussMarker、VideoMark)。
- **可视化解决方案:** 该工具包包含定制的可视化工具,能够清晰而深入地展示不同水印算法在各种场景下的运行方式。这些可视化有助于揭示算法机制,使其对用户更易理解。
-- **评估模块:** 拥有 20 个评估工具,涵盖可检测性、鲁棒性和对输出质量的影响,MarkDiffusion 提供全面的评估能力。它具有 5 个自动化评估流水线:水印检测流水线、图像质量分析流水线、视频质量分析流水线以及专门的鲁棒性评估工具。
+- **评估模块:** 拥有 31 个评估工具,涵盖可检测性、鲁棒性和对输出质量的影响,MarkDiffusion 提供全面的评估能力。它包含 6 个自动化评估流水线,涵盖水印检测以及图像和视频质量分析,并辅以专门的鲁棒性评估工具。
-### 已实现算法
+### ✨ 已实现算法
| **算法** | **类别** | **目标** | **参考文献** |
|---------------|-------------|------------|---------------|
@@ -82,7 +87,7 @@ MarkDiffusion 是一个用于潜在扩散模型生成式水印的开源 Python
| VideoShield | 密钥 | 视频 | [VideoShield: Regulating Diffusion-based Video Generation Models via Watermarking](https://arxiv.org/abs/2501.14195) |
| VideoMark | 密钥 | 视频 | [VideoMark: A Distortion-Free Robust Watermarking Framework for Video Diffusion Models](https://arxiv.org/abs/2504.16359) |
-### 评估模块
+### 🎯 评估模块
#### 评估流水线
MarkDiffusion 支持八个流水线,两个用于检测(WatermarkedMediaDetectionPipeline 和 UnWatermarkedMediaDetectionPipeline),六个用于质量分析。下表详细说明了质量分析流水线。
@@ -116,7 +121,6 @@ MarkDiffusion 支持八个流水线,两个用于检测(WatermarkedMediaDetec
| MPEG4Compression | 鲁棒性(视频) | MPEG-4 视频压缩攻击,测试视频水印的压缩鲁棒性 | 压缩后的视频帧 |
| FrameAverage | 鲁棒性(视频) | 帧平均攻击,通过帧间平均破坏水印 | 平均后的视频帧 |
| FrameSwap | 鲁棒性(视频) | 帧交换攻击,通过改变帧序列测试鲁棒性 | 交换后的视频帧 |
-| VideoCodecAttack | 鲁棒性(视频) | 编解码器重编码攻击,模拟平台转码(H.264/H.265/VP9/AV1) | 重编码后的视频帧 |
| FrameRateAdapter | 鲁棒性(视频) | 帧率转换攻击,在保持时长的同时重采样帧 | 重采样后的帧序列 |
| FrameInterpolationAttack | 鲁棒性(视频) | 帧插值攻击,插入混合帧以改变时间密度 | 插值后的视频帧 |
| **图像质量分析器** | | | |
@@ -137,326 +141,130 @@ MarkDiffusion 支持八个流水线,两个用于检测(WatermarkedMediaDetec
| DynamicDegreeAnalyzer | 质量(视频) | 测量视频中的动态水平和变化幅度 | 动态度值 |
| ImagingQualityAnalyzer | 质量(视频) | 综合评估视频成像质量 | 成像质量分数 |
-## 安装
-
-### 环境设置
-
-- Python 3.10+
-- PyTorch
-- 安装依赖:
+## 🧩 快速开始
+### Google Colab 演示
+如果您想在不安装任何内容的情况下试用 MarkDiffusion,可以使用 [Google Colab](https://colab.research.google.com/drive/1N1C9elDAB5zwF4FxKKYMCqR3eSpCSqAW?usp=sharing#scrollTo=-kWt7m9Y3o-G) 查看其工作方式。
+### 安装
+**(推荐)** 我们为 MarkDiffusion 发布了 PyPI 包。您可以直接使用 pip 安装:
```bash
-pip install -r requirements.txt
-```
-
-*注意:* 某些算法可能需要额外的设置步骤。请参考各个算法文档了解具体要求。
-
-## 快速开始
-
-这里有一个简单的示例帮助您开始使用 MarkDiffusion:
-
-```python
-import torch
-from watermark.auto_watermark import AutoWatermark
-from utils.diffusion_config import DiffusionConfig
-from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
-
-# 设备设置
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-# 配置扩散流水线
-scheduler = DPMSolverMultistepScheduler.from_pretrained("model_path", subfolder="scheduler")
-pipe = StableDiffusionPipeline.from_pretrained("model_path", scheduler=scheduler).to(device)
-diffusion_config = DiffusionConfig(
- scheduler=scheduler,
- pipe=pipe,
- device=device,
- image_size=(512, 512),
- num_inference_steps=50,
- guidance_scale=7.5,
- gen_seed=42,
- inversion_type="ddim"
-)
-
-# 加载水印算法
-watermark = AutoWatermark.load('TR',
- algorithm_config='config/TR.json',
- diffusion_config=diffusion_config)
-
-# 生成带水印的媒体
-prompt = "A beautiful sunset over the ocean"
-watermarked_image = watermark.generate_watermarked_media(prompt)
-
-# 检测水印
-detection_result = watermark.detect_watermark_in_media(watermarked_image)
-print(f"Watermark detected: {detection_result}")
+conda create -n markdiffusion python=3.11
+conda activate markdiffusion
+pip install "markdiffusion[optional]"
```
-## 如何使用工具包
-
-我们在 `MarkDiffusion_demo.ipynb` 中提供了大量示例。
-
-### 生成和检测水印媒体
-
-#### 生成和检测水印媒体的案例
-
-```python
-import torch
-from watermark.auto_watermark import AutoWatermark
-from utils.diffusion_config import DiffusionConfig
-
-# 加载水印算法
-mywatermark = AutoWatermark.load(
- 'GS',
- algorithm_config=f'config/GS.json',
- diffusion_config=diffusion_config
-)
-
-# 生成带水印的图像
-watermarked_image = mywatermark.generate_watermarked_media(
- input_data="A beautiful landscape with a river and mountains"
-)
-
-# 可视化带水印的图像
-watermarked_image.show()
-
-# 检测水印
-detection_result = mywatermark.detect_watermark_in_media(watermarked_image)
-print(detection_result)
+**(替代方案)** 对于*仅限于使用 conda 环境*的用户,我们还提供了 conda-forge 包,可以使用以下命令安装:
+```bash
+conda create -n markdiffusion python=3.11
+conda activate markdiffusion
+conda config --add channels conda-forge
+conda config --set channel_priority strict
+conda install markdiffusion
```
+但是,请注意,某些高级功能需要的额外包在 conda 上不可用,因此无法包含在 conda 发行版中。如有必要,您需要单独安装这些包。
-### 可视化水印机制
-
-该工具包包含定制的可视化工具,能够清晰而深入地展示不同水印算法在各种场景下的运行方式。这些可视化有助于揭示算法机制,使其对用户更易理解。
-
-
+### 如何使用工具包
-#### 可视化水印机制的案例
+安装后,有两种方式使用 MarkDiffusion:
-```python
-from visualize.auto_visualization import AutoVisualizer
-
-# 获取用于可视化的数据
-data_for_visualization = mywatermark.get_data_for_visualize(watermarked_image)
-
-# 加载可视化器
-visualizer = AutoVisualizer.load('GS',
- data_for_visualization=data_for_visualization)
-
-# 在 Matplotlib 画布上绘制图表
-fig = visualizer.visualize(rows=2, cols=2,
- methods=['draw_watermark_bits',
- 'draw_reconstructed_watermark_bits',
- 'draw_inverted_latents',
- 'draw_inverted_latents_fft'])
-```
+1. **克隆仓库以尝试演示或用于自定义开发。** `MarkDiffusion_demo.ipynb` notebook 提供了各种用例的详细演示——请查看以获取指导。以下是使用 TR 算法生成和检测带水印图像的快速示例:
-### 评估流水线
-
-#### 评估案例
-
-1. **水印检测流水线**
-
-```python
-from evaluation.dataset import StableDiffusionPromptsDataset
-from evaluation.pipelines.detection import (
- WatermarkedMediaDetectionPipeline,
- UnWatermarkedMediaDetectionPipeline,
- DetectionPipelineReturnType
-)
-from evaluation.tools.image_editor import JPEGCompression
-from evaluation.tools.success_rate_calculator import DynamicThresholdSuccessRateCalculator
-
-# 数据集
-my_dataset = StableDiffusionPromptsDataset(max_samples=200)
-
-# Set up the detection pipelines
-pipeline1 = WatermarkedMediaDetectionPipeline(
- dataset=my_dataset,
- media_editor_list=[JPEGCompression(quality=60)],
- show_progress=True,
- return_type=DetectionPipelineReturnType.SCORES
-)
-
-pipeline2 = UnWatermarkedMediaDetectionPipeline(
- dataset=my_dataset,
- media_editor_list=[],
- show_progress=True,
- return_type=DetectionPipelineReturnType.SCORES
-)
-
-# Configure detection parameters
-detection_kwargs = {
-    "num_inference_steps": 50,
-    "guidance_scale": 1.0,
-}
-# Calculate success rates (illustrative settings)
-labels, rules, target_fpr = ['TPR', 'F1'], 'best', 0.01
-calculator = DynamicThresholdSuccessRateCalculator(
-    labels=labels,
-    rule=rules,
-    target_fpr=target_fpr
-)
-
-results = calculator.calculate(
- pipeline1.evaluate(my_watermark, detection_kwargs=detection_kwargs),
- pipeline2.evaluate(my_watermark, detection_kwargs=detection_kwargs)
-)
-print(results)
-```
+ ```python
+ import torch
+ from watermark.auto_watermark import AutoWatermark
+ from utils.diffusion_config import DiffusionConfig
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
-2. **Image Quality Analysis Pipeline**
-
-```python
-from evaluation.dataset import StableDiffusionPromptsDataset, MSCOCODataset
-from evaluation.pipelines.image_quality_analysis import (
- DirectImageQualityAnalysisPipeline,
- ReferencedImageQualityAnalysisPipeline,
- GroupImageQualityAnalysisPipeline,
- RepeatImageQualityAnalysisPipeline,
- ComparedImageQualityAnalysisPipeline,
- QualityPipelineReturnType
-)
-from evaluation.tools.image_quality_analyzer import (
- NIQECalculator, CLIPScoreCalculator, FIDCalculator,
- InceptionScoreCalculator, LPIPSAnalyzer, PSNRAnalyzer
-)
-
-# Examples for different quality metrics:
-
-# NIQE (no-reference image quality evaluator)
-if metric == 'NIQE':
- my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples)
- pipeline = DirectImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[NIQECalculator()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ # Device setup
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
-# CLIP score
-elif metric == 'CLIP':
- my_dataset = MSCOCODataset(max_samples=max_samples)
- pipeline = ReferencedImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[CLIPScoreCalculator()],
- unwatermarked_image_source='generated',
- reference_image_source='natural',
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
+ # Configure the diffusion pipeline
+ scheduler = DPMSolverMultistepScheduler.from_pretrained("model_path", subfolder="scheduler")
+ pipe = StableDiffusionPipeline.from_pretrained("model_path", scheduler=scheduler).to(device)
+ diffusion_config = DiffusionConfig(
+ scheduler=scheduler,
+ pipe=pipe,
+ device=device,
+ image_size=(512, 512),
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ gen_seed=42,
+ inversion_type="ddim"
)
-# FID (Fréchet Inception Distance)
-elif metric == 'FID':
- my_dataset = MSCOCODataset(max_samples=max_samples)
- pipeline = GroupImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[FIDCalculator()],
- unwatermarked_image_source='generated',
- reference_image_source='natural',
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
+ # Load the watermark algorithm
+ watermark = AutoWatermark.load('TR',
+ algorithm_config='config/TR.json',
+ diffusion_config=diffusion_config)
+
+ # Generate watermarked media
+ prompt = "A beautiful sunset over the ocean"
+ watermarked_image = watermark.generate_watermarked_media(prompt)
+ watermarked_image.save("watermarked_image.png")
+
+ # Detect the watermark
+ detection_result = watermark.detect_watermark_in_media(watermarked_image)
+ print(f"Watermark detected: {detection_result}")
+ ```
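The `detection_result` above typically carries a score alongside a boolean decision. Conceptually, many bit-based detectors reduce to thresholding a bit-accuracy statistic; the sketch below is a minimal, plain-Python illustration of that decision rule (the function names and the `tau` value are illustrative assumptions, not the toolkit API):

```python
def bit_accuracy(recovered, target):
    """Fraction of watermark bits recovered correctly."""
    assert len(recovered) == len(target)
    return sum(r == t for r, t in zip(recovered, target)) / len(target)

def is_watermarked(recovered, target, tau=0.7):
    """Decide watermark presence by thresholding bit accuracy against tau."""
    return bit_accuracy(recovered, target) >= tau

bits = [1, 0, 1, 1, 0, 1, 0, 0]
noisy = [1, 0, 1, 1, 0, 1, 1, 0]   # one flipped bit -> accuracy 0.875
print(is_watermarked(noisy, bits))                  # True
print(is_watermarked([1 - b for b in bits], bits))  # False: all bits inverted
```

The threshold `tau` trades off false positives against robustness; each algorithm's config file sets its own decision parameters.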
+
+2. **Import the markdiffusion library directly in your code, without cloning the repository.** The `MarkDiffusion_pypi_demo.ipynb` notebook provides comprehensive examples of using MarkDiffusion via the markdiffusion library; check it out for guidance. Below is a quick example:
+
+ ```python
+ import torch
+ from markdiffusion.watermark import AutoWatermark
+ from markdiffusion.utils import DiffusionConfig
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+
+ # Device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ # Model path
+ MODEL_PATH = "huanzi05/stable-diffusion-2-1-base"
+
+ # Initialize the scheduler and pipeline
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(MODEL_PATH, subfolder="scheduler")
+ pipe = StableDiffusionPipeline.from_pretrained(
+ MODEL_PATH,
+ scheduler=scheduler,
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+ safety_checker=None,
+ ).to(device)
+
+ # Create a DiffusionConfig for image generation
+ image_diffusion_config = DiffusionConfig(
+ scheduler=scheduler,
+ pipe=pipe,
+ device=device,
+ image_size=(512, 512),
+ guidance_scale=7.5,
+ num_inference_steps=50,
+ gen_seed=42,
+ inversion_type="ddim"
)
-# IS (Inception Score)
-elif metric == 'IS':
- my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples)
- pipeline = GroupImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[InceptionScoreCalculator()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ # Load the Tree-Ring watermark algorithm
+ tr_watermark = AutoWatermark.load('TR', diffusion_config=image_diffusion_config)
+ print("TR watermark algorithm loaded successfully!")
-# LPIPS (Learned Perceptual Image Patch Similarity)
-elif metric == 'LPIPS':
- my_dataset = StableDiffusionPromptsDataset(max_samples=10)
- pipeline = RepeatImageQualityAnalysisPipeline(
- dataset=my_dataset,
- prompt_per_image=20,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[LPIPSAnalyzer()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ # Generate a watermarked image
+ prompt = "A beautiful landscape with mountains and a river at sunset"
-# PSNR (Peak Signal-to-Noise Ratio)
-elif metric == 'PSNR':
- my_dataset = StableDiffusionPromptsDataset(max_samples=max_samples)
- pipeline = ComparedImageQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_image_editor_list=[],
- unwatermarked_image_editor_list=[],
- analyzers=[PSNRAnalyzer()],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
- )
+ watermarked_image = tr_watermark.generate_watermarked_media(input_data=prompt)
-# Load the watermark and evaluate
-my_watermark = AutoWatermark.load(
-    algorithm_name,
-    algorithm_config=f'config/{algorithm_name}.json',
-    diffusion_config=diffusion_config
-)
+ # Save the watermarked image
+ watermarked_image.save("watermarked_image.png")
+ print("Watermarked image generated!")
-print(pipeline.evaluate(my_watermark))
-```
+ # Detect the watermark in the watermarked image
+ detection_result = tr_watermark.detect_watermark_in_media(watermarked_image)
+ print("Watermarked image detection result:")
+ print(detection_result)
+ ```
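Evaluating a detector over many prompts is usually reported as a true-positive rate at a fixed false-positive rate: the decision threshold is chosen from scores on unwatermarked media, then applied to watermarked media. The sketch below is a minimal plain-Python illustration of that protocol (illustrative names and data, not the toolkit's evaluation API):

```python
def tpr_at_target_fpr(wm_scores, no_wm_scores, target_fpr=0.05):
    """Set the threshold so that at most target_fpr of unwatermarked scores
    exceed it, then report the true-positive rate at that threshold."""
    neg = sorted(no_wm_scores, reverse=True)
    k = int(len(neg) * target_fpr)          # number of allowed false positives
    threshold = neg[k] if k < len(neg) else neg[-1] - 1.0
    tpr = sum(s > threshold for s in wm_scores) / len(wm_scores)
    return threshold, tpr

no_wm = [i / 100 for i in range(100)]       # unwatermarked detection scores
wm = [0.96] * 10                            # watermarked detection scores
threshold, tpr = tpr_at_target_fpr(wm, no_wm, target_fpr=0.05)
print(threshold, tpr)                       # threshold 0.94, TPR 1.0
```

With real score distributions the two sets overlap, and the achievable TPR at a strict FPR is the quantity the evaluation pipelines measure.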
-3. **Video Quality Analysis Pipeline**
-
-```python
-from evaluation.dataset import VBenchDataset
-from evaluation.pipelines.video_quality_analysis import DirectVideoQualityAnalysisPipeline
-from evaluation.pipelines.image_quality_analysis import QualityPipelineReturnType
-from evaluation.tools.video_quality_analyzer import (
- SubjectConsistencyAnalyzer,
- MotionSmoothnessAnalyzer,
- DynamicDegreeAnalyzer,
- BackgroundConsistencyAnalyzer,
- ImagingQualityAnalyzer
-)
-
-# Load the VBench dataset
-my_dataset = VBenchDataset(max_samples=200, dimension=dimension)
-
-# Initialize the analyzer according to the metric
-if metric == 'subject_consistency':
- analyzer = SubjectConsistencyAnalyzer(device=device)
-elif metric == 'motion_smoothness':
- analyzer = MotionSmoothnessAnalyzer(device=device)
-elif metric == 'dynamic_degree':
- analyzer = DynamicDegreeAnalyzer(device=device)
-elif metric == 'background_consistency':
- analyzer = BackgroundConsistencyAnalyzer(device=device)
-elif metric == 'imaging_quality':
- analyzer = ImagingQualityAnalyzer(device=device)
-else:
-    raise ValueError(f'Invalid metric: {metric}. Supported metrics: '
-                     'subject_consistency, motion_smoothness, dynamic_degree, '
-                     'background_consistency, imaging_quality')
-
-# Create the video quality analysis pipeline
-pipeline = DirectVideoQualityAnalysisPipeline(
- dataset=my_dataset,
- watermarked_video_editor_list=[],
- unwatermarked_video_editor_list=[],
- watermarked_frame_editor_list=[],
- unwatermarked_frame_editor_list=[],
- analyzers=[analyzer],
- show_progress=True,
- return_type=QualityPipelineReturnType.MEAN_SCORES
-)
-
-print(pipeline.evaluate(my_watermark))
-```
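Among the quality metrics above, PSNR is the simplest to state explicitly: it is a log-scale measure of the mean squared error between two images. The sketch below is a pure-Python illustration of the formula (not the toolkit's `PSNRAnalyzer`; flat pixel lists stand in for image arrays):

```python
import math

def psnr(img1, img2, max_val=255.0):
    """Peak Signal-to-Noise Ratio between two equally sized images,
    given as flat sequences of pixel values in [0, max_val]."""
    mse = sum((a - b) ** 2 for a, b in zip(img1, img2)) / len(img1)
    if mse == 0:
        return float("inf")                 # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)

clean = [128.0] * 64
noisy = [144.0] * 64                        # constant offset of 16 -> MSE 256
print(round(psnr(clean, noisy), 2))         # 24.05
```

Higher PSNR means the watermarked image is closer to its unwatermarked counterpart; values above roughly 30 dB are commonly treated as visually negligible distortion.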
+## 🛠 Testing Module
+We provide a comprehensive testing module to ensure code quality. It contains 454 unit tests with roughly 90% coverage; see the `test/` directory for details.
## Citation
```
diff --git a/docs/BUILD.md b/docs/BUILD.md
index e0e2b0a..9b3abff 100644
--- a/docs/BUILD.md
+++ b/docs/BUILD.md
@@ -131,11 +131,9 @@ docs/
│ └── configuration.rst
├── api/ # API reference
│ ├── watermark.rst
-│ ├── detection.rst
│ ├── visualization.rst
-│ ├── evaluation.rst
-│ └── utils.rst
-├── changelog.rst # Changelog
+│ ├── utils.rst
+│ └── evaluation.rst
├── contributing.rst # Contributing guide
├── citation.rst # Citation information
├── _static/ # Static files (CSS, images)
diff --git a/docs/_build/doctrees/BUILD.doctree b/docs/_build/doctrees/BUILD.doctree
deleted file mode 100644
index ad89115..0000000
Binary files a/docs/_build/doctrees/BUILD.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/api/detection.doctree b/docs/_build/doctrees/api/detection.doctree
deleted file mode 100644
index b6f0f8f..0000000
Binary files a/docs/_build/doctrees/api/detection.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/api/evaluation.doctree b/docs/_build/doctrees/api/evaluation.doctree
deleted file mode 100644
index 563b198..0000000
Binary files a/docs/_build/doctrees/api/evaluation.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/api/utils.doctree b/docs/_build/doctrees/api/utils.doctree
deleted file mode 100644
index 725eb86..0000000
Binary files a/docs/_build/doctrees/api/utils.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/api/visualization.doctree b/docs/_build/doctrees/api/visualization.doctree
deleted file mode 100644
index c4e2f0b..0000000
Binary files a/docs/_build/doctrees/api/visualization.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/api/watermark.doctree b/docs/_build/doctrees/api/watermark.doctree
deleted file mode 100644
index 72000e6..0000000
Binary files a/docs/_build/doctrees/api/watermark.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/changelog.doctree b/docs/_build/doctrees/changelog.doctree
deleted file mode 100644
index 1ce9334..0000000
Binary files a/docs/_build/doctrees/changelog.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/citation.doctree b/docs/_build/doctrees/citation.doctree
deleted file mode 100644
index a9cf1b4..0000000
Binary files a/docs/_build/doctrees/citation.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/code_of_conduct.doctree b/docs/_build/doctrees/code_of_conduct.doctree
deleted file mode 100644
index a07e735..0000000
Binary files a/docs/_build/doctrees/code_of_conduct.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/contributing.doctree b/docs/_build/doctrees/contributing.doctree
deleted file mode 100644
index c2c57b4..0000000
Binary files a/docs/_build/doctrees/contributing.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/environment.pickle b/docs/_build/doctrees/environment.pickle
deleted file mode 100644
index 4a4c317..0000000
Binary files a/docs/_build/doctrees/environment.pickle and /dev/null differ
diff --git a/docs/_build/doctrees/index.doctree b/docs/_build/doctrees/index.doctree
deleted file mode 100644
index 2e3eaf9..0000000
Binary files a/docs/_build/doctrees/index.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/installation.doctree b/docs/_build/doctrees/installation.doctree
deleted file mode 100644
index a3f10fa..0000000
Binary files a/docs/_build/doctrees/installation.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/quickstart.doctree b/docs/_build/doctrees/quickstart.doctree
deleted file mode 100644
index cdced12..0000000
Binary files a/docs/_build/doctrees/quickstart.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/tutorial.doctree b/docs/_build/doctrees/tutorial.doctree
deleted file mode 100644
index 17e205a..0000000
Binary files a/docs/_build/doctrees/tutorial.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/user_guide/algorithms.doctree b/docs/_build/doctrees/user_guide/algorithms.doctree
deleted file mode 100644
index 06786d6..0000000
Binary files a/docs/_build/doctrees/user_guide/algorithms.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/user_guide/evaluation.doctree b/docs/_build/doctrees/user_guide/evaluation.doctree
deleted file mode 100644
index 7426b9e..0000000
Binary files a/docs/_build/doctrees/user_guide/evaluation.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/user_guide/visualization.doctree b/docs/_build/doctrees/user_guide/visualization.doctree
deleted file mode 100644
index c3abf65..0000000
Binary files a/docs/_build/doctrees/user_guide/visualization.doctree and /dev/null differ
diff --git a/docs/_build/doctrees/user_guide/watermarking.doctree b/docs/_build/doctrees/user_guide/watermarking.doctree
deleted file mode 100644
index 78eb0ec..0000000
Binary files a/docs/_build/doctrees/user_guide/watermarking.doctree and /dev/null differ
diff --git a/docs/_build/html/.buildinfo b/docs/_build/html/.buildinfo
deleted file mode 100644
index 17acfd6..0000000
--- a/docs/_build/html/.buildinfo
+++ /dev/null
@@ -1,4 +0,0 @@
-# Sphinx build info version 1
-# This file records the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: e8649eef295dd15fd61c57a06ff7b183
-tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/_build/html/BUILD.html b/docs/_build/html/BUILD.html
deleted file mode 100644
index 72ac430..0000000
--- a/docs/_build/html/BUILD.html
+++ /dev/null
@@ -1,1046 +0,0 @@
-
-
-
-
-
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-fromabcimportABC,abstractmethod
-importtorch
-
-
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""GaussMarker detection utilities.
-
-This module adapts the official GaussMarker detection pipeline to the
-MarkDiffusion detection API. It evaluates recovered diffusion latents to
-decide whether a watermark is present, reporting both hard decisions and
-auxiliary scores (bit/message accuracies, frequency-domain distances).
-"""
-
-from__future__importannotations
-
-frompathlibimportPath
-fromtypingimportDict,Optional,Union
-
-importnumpyasnp
-importtorch
-
-importjoblib
-
-fromdetection.baseimportBaseDetector
-fromwatermark.gm.gmimportGaussianShadingChaCha,extract_complex_sign
-fromwatermark.gm.gnrimportGNRRestorer
-
-
-
-[docs]
-classGMDetector(BaseDetector):
-"""Detector for GaussMarker watermarks.
-
- Args:
- watermark_generator: Instance of :class:`GaussianShadingChaCha` that
- holds the original watermark bits and ChaCha20 key stream.
- watermarking_mask: Frequency-domain mask (or label map) indicating the
- region that carries the watermark.
- gt_patch: Reference watermark pattern in the frequency domain.
- w_measurement: Measurement mode (e.g., ``"l1_complex"`` or
- ``"signal_complex"``), mirroring the official implementation.
- device: Torch device used for evaluation.
- bit_threshold: Optional override for the bit-accuracy decision
- threshold. Defaults to the generator's ``tau_bits`` value.
- message_threshold: Optional threshold for message accuracy decisions.
- l1_threshold: Optional threshold for frequency L1 distance decisions
- (smaller is better).
- """
-
-
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-importtorch
-importnumpyasnp
-fromCrypto.Randomimportget_random_bytes
-fromCrypto.CipherimportChaCha20
-fromscipy.statsimporttruncnorm,norm
-fromfunctoolsimportreduce
-fromdetection.baseimportBaseDetector
-fromtypingimportUnion
-
-
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-importtorch
-importnumpyasnp
-fromtypingimportTuple,Type
-fromscipy.specialimporterf
-fromdetection.baseimportBaseDetector
-fromldpcimportbp_decoder
-fromgaloisimportFieldArray
-
-
-
-
- def_recover_posteriors(self,z,basis=None,variances=None):
- ifvariancesisNone:
- default_variance=1.5
- denominators=np.sqrt(2*default_variance*(1+default_variance))*torch.ones_like(z)
- eliftype(variances)isfloat:
- denominators=np.sqrt(2*variances*(1+variances))
- else:
- denominators=torch.sqrt(2*variances*(1+variances))
-
- ifbasisisNone:
- returnerf(z/denominators)
- else:
- returnerf((z@basis)/denominators)
-
- def_detect_watermark(self,posteriors:torch.Tensor)->Tuple[bool,float]:
-"""Detect watermark in posteriors."""
- generator_matrix,parity_check_matrix,one_time_pad,false_positive_rate,noise_rate,test_bits,g,max_bp_iter,t=self.decoding_key
- posteriors=(1-2*noise_rate)*(1-2*np.array(one_time_pad,dtype=float))*posteriors.numpy(force=True)
-
- r=parity_check_matrix.shape[0]
- Pi=np.prod(posteriors[parity_check_matrix.indices.reshape(r,t)],axis=1)
- log_plus=np.log((1+Pi)/2)
- log_minus=np.log((1-Pi)/2)
- log_prod=log_plus+log_minus
-
- const=0.5*np.sum(np.power(log_plus,2)+np.power(log_minus,2)-0.5*np.power(log_prod,2))
- threshold=np.sqrt(2*const*np.log(1/false_positive_rate))+0.5*log_prod.sum()
- #print(f"threshold: {threshold}")
- returnlog_plus.sum()>=threshold,log_plus.sum()
-
- def_boolean_row_reduce(self,A,print_progress=False):
-"""Given a GF(2) matrix, do row elimination and return the first k rows of A that form an invertible matrix
-
- Args:
- A (np.ndarray): A GF(2) matrix
- print_progress (bool, optional): Whether to print the progress. Defaults to False.
-
- Returns:
- np.ndarray: The first k rows of A that form an invertible matrix
- """
- n,k=A.shape
- A_rr=A.copy()
- perm=np.arange(n)
- forjinrange(k):
- idxs=j+np.nonzero(A_rr[j:,j])[0]
- ifidxs.size==0:
- print("The given matrix is not invertible")
- returnNone
- A_rr[[j,idxs[0]]]=A_rr[[idxs[0],j]]# For matrices you have to swap them this way
- (perm[j],perm[idxs[0]])=(perm[idxs[0]],perm[j])# Weirdly, this is MUCH faster if you swap this way instead of using perm[[i,j]]=perm[[j,i]]
- A_rr[idxs[1:]]+=A_rr[j]
- ifprint_progressand(j%5==0orj+1==k):
- sys.stdout.write(f'\rDecoding progress: {j+1} / {k}')
- sys.stdout.flush()
- ifprint_progress:print()
- returnperm[:k]
-
- def_decode_message(self,posteriors,print_progress=False,max_bp_iter=None):
- generator_matrix,parity_check_matrix,one_time_pad,false_positive_rate_key,noise_rate,test_bits,g,max_bp_iter_key,t=self.decoding_key
- ifmax_bp_iterisNone:
- max_bp_iter=max_bp_iter_key
-
- posteriors=(1-2*noise_rate)*(1-2*np.array(one_time_pad,dtype=float))*posteriors.numpy(force=True)
- channel_probs=(1-np.abs(posteriors))/2
- x_recovered=(1-np.sign(posteriors))//2
-
-
- # Apply the belief-propagation decoder.
- ifprint_progress:
- print("Running belief propagation...")
- bpd=bp_decoder(parity_check_matrix,channel_probs=channel_probs,max_iter=max_bp_iter,bp_method="product_sum")
- x_decoded=bpd.decode(x_recovered)
-
- # Compute a confidence score.
- bpd_probs=1/(1+np.exp(bpd.log_prob_ratios))
- confidences=2*np.abs(0.5-bpd_probs)
-
- # Order codeword bits by confidence.
- confidence_order=np.argsort(-confidences)
- ordered_generator_matrix=generator_matrix[confidence_order]
- ordered_x_decoded=x_decoded[confidence_order].astype(int)
-
- # Find the first (according to the confidence order) linearly independent set of rows of the generator matrix.
- top_invertible_rows=self._boolean_row_reduce(ordered_generator_matrix,print_progress=print_progress)
- iftop_invertible_rowsisNone:
- returnNone
-
- # Solve the system.
- ifprint_progress:
- print("Solving linear system...")
- recovered_string=np.linalg.solve(ordered_generator_matrix[top_invertible_rows],self.GF(ordered_x_decoded[top_invertible_rows]))
-
- ifnot(recovered_string[:len(test_bits)]==test_bits).all():
- returnNone
- returnnp.array(recovered_string[len(test_bits)+g:])
-
- def_binary_array_to_str(self,binary_array:np.ndarray)->str:
-"""Convert binary array back to string."""
- # Ensure the binary array length is divisible by 8 (1 byte = 8 bits)
- assertlen(binary_array)%8==0,"Binary array length must be a multiple of 8"
-
- # Group the binary array into chunks of 8 bits
- byte_chunks=binary_array.reshape(-1,8)
-
- # Convert each byte (8 bits) to a character
- chars=[chr(int(''.join(map(str,byte)),2))forbyteinbyte_chunks]
-
- # Join the characters to form the original string
- return''.join(chars)
-
-
-[docs]
- defeval_watermark(self,reversed_latents:torch.Tensor,reference_latents:torch.Tensor=None,detector_type:str="is_watermarked")->float:
-"""Evaluate watermark in reversed latents."""
- ifdetector_type!='is_watermarked':
- raiseValueError(f'Detector type {detector_type} is not supported for PRC. Use "is_watermarked" instead.')
- reversed_prc=self._recover_posteriors(reversed_latents.to(torch.float64).flatten().cpu(),variances=self.var).flatten().cpu()
- self.recovered_prc=reversed_prc
- detect_result,score=self._detect_watermark(reversed_prc)
- decoding_result=self._decode_message(reversed_prc)
- ifdecoding_resultisNone:
- return{
- 'is_watermarked':False,
- "score":score,# Keep the score for potential future use
- 'decoding_result':decoding_result,
- "decoded_message":None
- }
- decoded_message=self._binary_array_to_str(decoding_result)
- combined_result=detect_resultor(decoding_resultisnotNone)
- #print(f"detection_result: {detect_result}, decoding_result: {decoding_result}, combined_result: {combined_result}")
- return{
- 'is_watermarked':bool(combined_result),
- "score":score,# Keep the score for potential future use
- 'decoding_result':decoding_result,
- "decoded_message":decoded_message
- }
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-importtorch
-fromdetection.baseimportBaseDetector
-fromscipy.statsimportncx2
-fromtorch.nnimportfunctionalasF
-
-
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-importtorch
-fromdetection.baseimportBaseDetector
-fromtransformersimportBlip2Processor,Blip2ForConditionalGeneration
-fromsentence_transformersimportSentenceTransformer
-fromPILimportImage
-importmath
-
-
-
-
- def_calculate_patch_l2(self,noise1:torch.Tensor,noise2:torch.Tensor,k:int)->torch.Tensor:
-"""
- Calculate L2 distances patch by patch. Returns a list of L2 values for the first k patches.
- """
- l2_list=[]
- patch_per_side_h=int(math.ceil(math.sqrt(k)))
- patch_per_side_w=int(math.ceil(k/patch_per_side_h))
- patch_height=64//patch_per_side_h
- patch_width=64//patch_per_side_w
- patch_count=0
- foriinrange(patch_per_side_h):
- forjinrange(patch_per_side_w):
- ifpatch_count>=k:
- break
- y_start=i*patch_height
- x_start=j*patch_width
- y_end=min(y_start+patch_height,64)
- x_end=min(x_start+patch_width,64)
- patch1=noise1[:,:,y_start:y_end,x_start:x_end]
- patch2=noise2[:,:,y_start:y_end,x_start:x_end]
- l2_val=torch.norm(patch1-patch2).item()
- l2_list.append(l2_val)
- patch_count+=1
- returnl2_list
-
-
-[docs]
- defeval_watermark(self,
- reversed_latents:torch.Tensor,
- reference_latents:torch.Tensor,
- detector_type:str="patch_accuracy")->float:
-
- ifdetector_type!="patch_accuracy":
- raiseValueError(f"Detector type {detector_type} is not supported for SEAL detector")
-
- l2_patch_list=self._calculate_patch_l2(reversed_latents,reference_latents,self.k)
-
- # Count the number of patches that are less than the threshold
- num_patches_below_threshold=sum(1forl2inl2_patch_listifl2<self.patch_distance_threshold)
-
- return{
- "is_watermarked":bool(num_patches_below_threshold>=self.match_threshold),
- "patch_accuracy":num_patches_below_threshold/self.k,
- }
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-importtorch
-fromdetection.baseimportBaseDetector
-fromscipy.statsimportncx2
-
-
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-importtorch
-fromdetection.baseimportBaseDetector
-fromscipy.statsimportncx2
-
-
Source code for detection.videomark.videomark_detection
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-importtorch
-importnumpyasnp
-fromtypingimportTuple,Type
-fromscipy.specialimporterf
-fromdetection.baseimportBaseDetector
-fromldpcimportbp_decoder
-fromgaloisimportFieldArray
-importsys
-fromLevenshteinimporthamming
-importlogging
-
-logger=logging.getLogger(__name__)
-
-
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-importtorch
-fromdetection.baseimportBaseDetector
-importtorch.nn.functionalasF
-importnumpyasnp
-importlogging
-fromtypingimportDict,Any
-
-
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-importujsonasjson
-fromdatasetsimportload_dataset
-importpandasaspd
-fromPILimportImage
-importrequests
-fromioimportBytesIO
-fromtqdmimporttqdm
-importrandom
-fromtypingimportList
-
-
-[docs]
- def__init__(self,max_samples:int=200):
-"""Initialize the dataset.
-
- Parameters:
- max_samples: Maximum number of samples to load.
- """
- self.max_samples=max_samples
- self.prompts=[]
- self.references=[]
-
-
- @property
- defnum_samples(self)->int:
-"""Number of samples in the dataset."""
- returnlen(self.prompts)
-
- @property
- defnum_references(self)->int:
-"""Number of references in the dataset."""
- returnlen(self.references)
-
-
-[docs]
- defget_prompt(self,idx)->str:
-"""Get the prompt at the given index."""
- returnself.prompts[idx]
-
-
-
-[docs]
- defget_reference(self,idx)->Image.Image:
-"""Get the reference Image at the given index."""
- returnself.references[idx]
-
-
-
-[docs]
- def__len__(self)->int:
-"""Number of samples in the dataset.(Equivalent to num_samples)"""
- returnself.num_samples
-
-
-
-[docs]
- def__getitem__(self,idx)->tuple[str,Image.Image]:
-"""Get the prompt (and reference Image if available) at the given index."""
- iflen(self.references)==0:
- returnself.prompts[idx]
- else:
- returnself.prompts[idx],self.references[idx]
-
-
- def_load_data(self):
-"""Load data from the dataset."""
- pass

    def __init__(self, max_samples: int = 200, split: str = "test", shuffle: bool = False):
        """Initialize the dataset.

        Parameters:
            max_samples: Maximum number of samples to load.
            split: Split to load.
            shuffle: Whether to shuffle the dataset.
        """
        super().__init__(max_samples)
        self.split = split
        self.shuffle = shuffle
        self._load_data()

    def __init__(self, max_samples: int = 200, shuffle: bool = False):
        """Initialize the dataset.

        Parameters:
            max_samples: Maximum number of samples to load.
            shuffle: Whether to shuffle the dataset.
        """
        super().__init__(max_samples)
        self.shuffle = shuffle
        self._load_data()
    @property
    def name(self):
        """Name of the dataset."""
        return "MS-COCO 2017"

    def _load_image_from_url(self, url):
        """Load image from URL."""
        try:
            response = requests.get(url)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            return image
        except Exception as e:
            print(f"Load image from url failed: {e}")
            return None

    def _load_data(self):
        """Load data from the MS-COCO 2017 dataset."""
        df = pd.read_parquet("dataset/mscoco/mscoco.parquet")
        if self.shuffle:
            df = df.sample(frac=1).reset_index(drop=True)
        for i in tqdm(range(self.max_samples), desc="Loading MSCOCO dataset"):
            item = df.iloc[i]
            self.prompts.append(item['TEXT'])
            self.references.append(self._load_image_from_url(item['URL']))
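The `_load_data` pattern above (shuffle the whole table, then take the first `max_samples` rows) can be sketched with the stdlib alone; the record list and function name here are illustrative, not from the toolkit:

```python
import random

# Toy stand-ins for the parquet rows (TEXT/URL columns).
records = [{"TEXT": f"prompt {i}", "URL": f"http://example.com/{i}.jpg"} for i in range(10)]

def take_samples(records, max_samples, shuffle=False, seed=None):
    # Equivalent to df.sample(frac=1).reset_index(drop=True) followed by
    # iterating the first max_samples rows.
    rows = list(records)
    if shuffle:
        random.Random(seed).shuffle(rows)
    return rows[:max_samples]

subset = take_samples(records, max_samples=3, shuffle=True, seed=0)
print(len(subset))  # 3
```

Seeding the shuffle (as sketched here) makes the sampled subset reproducible across runs, which matters when comparing watermarking algorithms on the "same" prompts.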
Source code for evaluation.pipelines.image_quality_analysis
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from watermark.base import BaseWatermark
from evaluation.dataset import BaseDataset
from tqdm import tqdm
from enum import Enum, auto
from PIL import Image
from evaluation.tools.image_editor import ImageEditor
from typing import List, Dict, Union, Tuple, Any, Optional
import numpy as np
from dataclasses import dataclass, field
import os
import random
from evaluation.tools.image_quality_analyzer import (
    ImageQualityAnalyzer
)
import lpips

class QualityPipelineReturnType(Enum):
    """Return type of the image quality analysis pipeline."""
    FULL = auto()
    SCORES = auto()
    MEAN_SCORES = auto()


class QualityComparisonResult:
    """Result of image quality comparison."""

    def __init__(self,
                 store_path: str,
                 watermarked_quality_scores: Dict[str, List[float]],
                 unwatermarked_quality_scores: Dict[str, List[float]],
                 prompts: List[str],
                 ) -> None:
        """
        Initialize the image quality comparison result.

        Parameters:
            store_path: The path to store the results.
            watermarked_quality_scores: The quality scores of the watermarked images.
            unwatermarked_quality_scores: The quality scores of the unwatermarked images.
            prompts: The prompts used to generate the images.
        """
        self.store_path = store_path
        self.watermarked_quality_scores = watermarked_quality_scores
        self.unwatermarked_quality_scores = unwatermarked_quality_scores
        self.prompts = prompts


class ImageQualityAnalysisPipeline:
    """Pipeline for image quality analysis."""

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: List[ImageEditor] = [],
                 unwatermarked_image_editor_list: List[ImageEditor] = [],
                 analyzers: List[ImageQualityAnalyzer] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:
        """
        Initialize the image quality analysis pipeline.

        Parameters:
            dataset: The dataset for evaluation.
            watermarked_image_editor_list: The list of image editors for watermarked images.
            unwatermarked_image_editor_list: The list of image editors for unwatermarked images.
            analyzers: List of quality analyzers for images.
            unwatermarked_image_source: The source of unwatermarked images ('natural' or 'generated').
            reference_image_source: The source of reference images ('natural' or 'generated').
            show_progress: Whether to show progress.
            store_path: The path to store the results. If None, the generated images will not be stored.
            return_type: The return type of the pipeline.
        """
        if unwatermarked_image_source not in ['natural', 'generated']:
            raise ValueError(f"Invalid unwatermarked_image_source: {unwatermarked_image_source}")

        self.dataset = dataset
        self.watermarked_image_editor_list = watermarked_image_editor_list
        self.unwatermarked_image_editor_list = unwatermarked_image_editor_list
        self.analyzers = analyzers or []
        self.unwatermarked_image_source = unwatermarked_image_source
        self.reference_image_source = reference_image_source
        self.show_progress = show_progress
        self.store_path = store_path
        self.return_type = return_type

    def _check_compatibility(self):
        """Check if the pipeline is compatible with the dataset."""
        pass

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        pass

    def _get_progress_bar(self, iterable):
        """Return an iterable possibly wrapped with a progress bar."""
        if self.show_progress:
            return tqdm(iterable, desc="Processing", leave=True)
        return iterable

    def _get_prompt(self, index: int) -> str:
        """Get prompt from dataset."""
        return self.dataset.get_prompt(index)

    def _get_watermarked_image(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> Union[Image.Image, List[Image.Image]]:
        """Generate watermarked image from dataset."""
        prompt = self._get_prompt(index)
        image = watermark.generate_watermarked_media(input_data=prompt, **generation_kwargs)
        return image

    def _get_unwatermarked_image(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> Union[Image.Image, List[Image.Image]]:
        """Generate or retrieve unwatermarked image from dataset."""
        if self.unwatermarked_image_source == 'natural':
            return self.dataset.get_reference(index)
        elif self.unwatermarked_image_source == 'generated':
            prompt = self._get_prompt(index)
            image = watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)
            return image

    def _edit_watermarked_image(self, image: Union[Image.Image, List[Image.Image]]) -> Union[Image.Image, List[Image.Image]]:
        """Edit watermarked image using image editors."""
        if isinstance(image, list):
            edited_images = []
            for img in image:
                for image_editor in self.watermarked_image_editor_list:
                    img = image_editor.edit(img)
                edited_images.append(img)
            return edited_images
        else:
            for image_editor in self.watermarked_image_editor_list:
                image = image_editor.edit(image)
            return image

    def _edit_unwatermarked_image(self, image: Union[Image.Image, List[Image.Image]]) -> Union[Image.Image, List[Image.Image]]:
        """Edit unwatermarked image using image editors."""
        if isinstance(image, list):
            edited_images = []
            for img in image:
                for image_editor in self.unwatermarked_image_editor_list:
                    img = image_editor.edit(img)
                edited_images.append(img)
            return edited_images
        else:
            for image_editor in self.unwatermarked_image_editor_list:
                image = image_editor.edit(image)
            return image

    def _prepare_dataset(self, watermark: BaseWatermark, **generation_kwargs) -> DatasetForEvaluation:
        """
        Prepare and generate all necessary data for quality analysis.

        This method should be overridden by subclasses to implement specific
        data preparation logic based on the analysis requirements.

        Parameters:
            watermark: The watermark algorithm instance.
            generation_kwargs: Additional generation parameters.

        Returns:
            DatasetForEvaluation object containing all prepared data.
        """
        dataset_eval = DatasetForEvaluation()

        # Generate all images
        bar = self._get_progress_bar(self._get_iterable())
        bar.set_description("Generating images for quality analysis")
        for index in bar:
            # Generate and edit watermarked image
            watermarked_image = self._get_watermarked_image(watermark, index, **generation_kwargs)
            watermarked_image = self._edit_watermarked_image(watermarked_image)

            # Generate and edit unwatermarked image
            unwatermarked_image = self._get_unwatermarked_image(watermark, index, **generation_kwargs)
            unwatermarked_image = self._edit_unwatermarked_image(unwatermarked_image)

            dataset_eval.watermarked_images.append(watermarked_image)
            dataset_eval.unwatermarked_images.append(unwatermarked_image)
            if hasattr(self, "prompt_per_image"):
                index = index // self.prompt_per_image
            dataset_eval.indexes.append(index)
            dataset_eval.prompts.append(self._get_prompt(index))

            if self.reference_image_source == 'natural':
                if self.dataset.num_references > 0:
                    reference_image = self.dataset.get_reference(index)
                    dataset_eval.reference_images.append(reference_image)
                else:
                    # For text-based analyzers, add None placeholder
                    dataset_eval.reference_images.append(None)
            else:
                dataset_eval.reference_images.append(unwatermarked_image)

        return dataset_eval

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """
        Prepare input for quality analyzer.

        Parameters:
            prepared_dataset: The prepared dataset.
        """
        pass

    def _store_results(self, prepared_dataset: DatasetForEvaluation):
        """Store results."""
        os.makedirs(self.store_path, exist_ok=True)
        dataset_name = self.dataset.name

        for (index, watermarked_image, unwatermarked_image, prompt) in zip(prepared_dataset.indexes, prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images, prepared_dataset.prompts):
            watermarked_image.save(os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_watermarked_prompt_{index}.png"))
            unwatermarked_image.save(os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_unwatermarked_prompt_{index}.png"))

    def analyze_quality(self, prepared_data, analyzer):
        """Analyze quality of watermarked and unwatermarked images."""
        pass
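The base pipeline above is a template method: `_prepare_dataset` builds paired watermarked/unwatermarked data once, and each subclass overrides `_prepare_input_for_quality_analyzer` and `analyze_quality` for its analysis style. A stdlib-only sketch of that control flow, with illustrative names and no real image generation:

```python
class ToyPipeline:
    """Illustrative template-method skeleton; names are not from the toolkit."""
    def _prepare_dataset(self):
        # Stand-ins for (watermarked, unwatermarked) image pairs.
        return [("w0", "u0"), ("w1", "u1")]

    def _prepare_input(self, prepared):
        # Overridden per analysis type (direct, referenced, group, ...).
        raise NotImplementedError

    def analyze_quality(self, prepared_data, analyzer):
        raise NotImplementedError

    def evaluate(self, analyzer):
        prepared = self._prepare_dataset()
        return self.analyze_quality(self._prepare_input(prepared), analyzer)

class ToyDirectPipeline(ToyPipeline):
    def _prepare_input(self, prepared):
        return prepared  # direct analysis consumes the raw pairs

    def analyze_quality(self, prepared_data, analyzer):
        w_scores = [analyzer(w) for w, _ in prepared_data]
        u_scores = [analyzer(u) for _, u in prepared_data]
        return w_scores, u_scores

w, u = ToyDirectPipeline().evaluate(len)  # len() as a trivial "analyzer"
print(w, u)  # [2, 2] [2, 2]
```

The key design point this illustrates: generation happens once in the base class, while each subclass only decides how the generated pairs are grouped and scored.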


class DirectImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for direct image quality analysis.

    This class analyzes the quality of images by directly comparing the characteristics
    of watermarked images with unwatermarked images. It evaluates metrics such as PSNR,
    SSIM, LPIPS, FID, and BRISQUE without the need for any external reference image.

    Use this pipeline to assess the impact of watermarking on image quality directly.
    """

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Prepare input for quality analyzer."""
        return [(watermarked_image, unwatermarked_image) for watermarked_image, unwatermarked_image in zip(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images)]

    def analyze_quality(self,
                        prepared_data: List[Tuple[Image.Image, Image.Image]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality of watermarked and unwatermarked images."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # For direct analyzers, we analyze each image independently
        for watermarked_image, unwatermarked_image in bar:
            # Watermarked score
            try:
                w_score = analyzer.analyze(watermarked_image)
            except TypeError:
                # Analyzer expects a reference -> use unwatermarked_image as reference
                w_score = analyzer.analyze(watermarked_image, unwatermarked_image)
            # Unwatermarked score
            try:
                u_score = analyzer.analyze(unwatermarked_image)
            except TypeError:
                u_score = analyzer.analyze(unwatermarked_image, watermarked_image)
            w_scores.append(w_score)
            u_scores.append(u_score)

        return w_scores, u_scores
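The try/except `TypeError` pattern above lets one loop serve both no-reference and full-reference analyzers: calling a two-argument `analyze` with one argument raises `TypeError`, which triggers the fallback call with the other image as reference. A toy sketch of the mechanism; the analyzer classes here are illustrative, not the toolkit's:

```python
class NoRefAnalyzer:
    def analyze(self, image):
        return len(image)  # toy "quality" score

class RefAnalyzer:
    def analyze(self, image, reference):
        return abs(len(image) - len(reference))  # toy "distance" score

def score(analyzer, image, other):
    try:
        return analyzer.analyze(image)          # no-reference path
    except TypeError:
        return analyzer.analyze(image, other)   # full-reference fallback

print(score(NoRefAnalyzer(), "abcd", "ab"))  # 4
print(score(RefAnalyzer(), "abcd", "ab"))    # 2
```

One caveat of this dispatch style: a `TypeError` raised *inside* a one-argument `analyze` for an unrelated reason would also trigger the fallback, so analyzer implementations should avoid leaking `TypeError`.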


class ReferencedImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for referenced image quality analysis.

    This pipeline assesses image quality by comparing both watermarked and unwatermarked
    images against a common reference image. It measures the degree of similarity or
    deviation from the reference.

    Ideal for scenarios where the impact of watermarking on image quality needs to be
    assessed, particularly in relation to specific reference images or ground truth.
    """

    def _check_compatibility(self):
        """Check if the pipeline is compatible with the dataset."""
        # Check if we have analyzers that use text as reference
        has_text_analyzer = any(hasattr(analyzer, 'reference_source') and analyzer.reference_source == 'text'
                                for analyzer in self.analyzers)

        # If all analyzers use text reference, we don't need reference images
        if not has_text_analyzer and self.dataset.num_references == 0:
            raise ValueError(f"Reference images are required for referenced image quality analysis. Dataset {self.dataset.name} has no reference images.")

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Prepare input for quality analyzer."""
        return [(watermarked_image, unwatermarked_image, reference_image, prompt)
                for watermarked_image, unwatermarked_image, reference_image, prompt in
                zip(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images, prepared_dataset.reference_images, prepared_dataset.prompts)
                ]

    def analyze_quality(self,
                        prepared_data: List[Tuple[Image.Image, Image.Image, Image.Image, str]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality of watermarked and unwatermarked images."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # For referenced analyzers, we compare against the reference
        for watermarked_image, unwatermarked_image, reference_image, prompt in bar:
            if analyzer.reference_source == "image":
                w_score = analyzer.analyze(watermarked_image, reference_image)
                u_score = analyzer.analyze(unwatermarked_image, reference_image)
            elif analyzer.reference_source == "text":
                w_score = analyzer.analyze(watermarked_image, prompt)
                u_score = analyzer.analyze(unwatermarked_image, prompt)
            else:
                raise ValueError(f"Invalid reference source: {analyzer.reference_source}")
            w_scores.append(w_score)
            u_scores.append(u_score)
        return w_scores, u_scores
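The `reference_source` attribute drives the dispatch above: image-referenced analyzers receive a reference image, while text-referenced ones (CLIP-score-style image/text alignment) receive the prompt. A stdlib sketch of the same dispatch; the analyzer classes are illustrative:

```python
class ToyImageRefAnalyzer:
    reference_source = "image"
    def analyze(self, image, reference):
        return 1.0 if image == reference else 0.0  # toy similarity

class ToyTextRefAnalyzer:
    reference_source = "text"
    def analyze(self, image, prompt):
        return float(prompt in image)  # toy image/text "alignment"

def dispatch(analyzer, image, reference, prompt):
    if analyzer.reference_source == "image":
        return analyzer.analyze(image, reference)
    elif analyzer.reference_source == "text":
        return analyzer.analyze(image, prompt)
    raise ValueError(f"Invalid reference source: {analyzer.reference_source}")

print(dispatch(ToyImageRefAnalyzer(), "imgA", "imgA", "cat"))    # 1.0
print(dispatch(ToyTextRefAnalyzer(), "a cat photo", None, "cat"))  # 1.0
```

Because text-referenced analyzers never touch `reference_image`, the pipeline can append `None` placeholders for datasets without references, as `_prepare_dataset` does above.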


class GroupImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for group-based image quality analysis.

    This pipeline analyzes quality metrics that require comparing distributions
    of multiple images (e.g., FID). It generates all images upfront and then
    performs a single analysis on the entire collection.
    """

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Prepare input for group analyzer: a single tuple holding the full collections."""
        return [(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images, prepared_dataset.reference_images)]

    def analyze_quality(self,
                        prepared_data: List[Tuple[List[Image.Image], List[Image.Image], List[Image.Image]]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality of image groups."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # For group analyzers, we pass the entire collection
        for watermarked_images, unwatermarked_images, reference_images in bar:
            w_score = analyzer.analyze(watermarked_images, reference_images)
            u_score = analyzer.analyze(unwatermarked_images, reference_images)
            w_scores.append(w_score)
            u_scores.append(u_score)
        return w_scores, u_scores


class RepeatImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for repeat-based image quality analysis.

    This pipeline analyzes diversity metrics by generating multiple images
    for each prompt (e.g., LPIPS diversity). It generates multiple versions
    per prompt and analyzes the diversity within each group.
    """

    def analyze_quality(self,
                        prepared_data: List[Tuple[List[Image.Image], List[Image.Image]]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze diversity of image batches."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing diversity for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # For diversity analyzers, we analyze each batch
        for watermarked_images, unwatermarked_images in bar:
            w_score = analyzer.analyze(watermarked_images)
            u_score = analyzer.analyze(unwatermarked_images)
            w_scores.append(w_score)
            u_scores.append(u_score)

        return w_scores, u_scores


class ComparedImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for compared image quality analysis.

    This pipeline directly compares watermarked and unwatermarked images
    to compute metrics like PSNR, SSIM, VIF, FSIM, and MS-SSIM. The analyzer receives
    both images and outputs a single comparison score.
    """


class QualityComparisonResult:
    """Result of quality comparison."""

    def __init__(self,
                 store_path: str,
                 watermarked_quality_scores: Dict[str, List[float]],
                 unwatermarked_quality_scores: Dict[str, List[float]],
                 prompts: List[str],
                 ) -> None:
        """
        Initialize the quality comparison result.

        Parameters:
            store_path: The path to store the results.
            watermarked_quality_scores: The quality scores of the watermarked videos.
            unwatermarked_quality_scores: The quality scores of the unwatermarked videos.
            prompts: The prompts used to generate the videos.
        """
        self.store_path = store_path
        self.watermarked_quality_scores = watermarked_quality_scores
        self.unwatermarked_quality_scores = unwatermarked_quality_scores
        self.prompts = prompts


class VideoQualityAnalysisPipeline:
    """Pipeline for video quality analysis."""

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_video_editor_list: List[VideoEditor] = [],
                 unwatermarked_video_editor_list: List[VideoEditor] = [],
                 watermarked_frame_editor_list: List[ImageEditor] = [],
                 unwatermarked_frame_editor_list: List[ImageEditor] = [],
                 analyzers: List[VideoQualityAnalyzer] = None,
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:
        """Initialize the video quality analysis pipeline.

        Args:
            dataset (BaseDataset): The dataset for evaluation.
            watermarked_video_editor_list (List[VideoEditor], optional): The list of video editors for watermarked videos. Defaults to [].
            unwatermarked_video_editor_list (List[VideoEditor], optional): The list of video editors for unwatermarked videos. Defaults to [].
            watermarked_frame_editor_list (List[ImageEditor], optional): List of image editors for editing individual watermarked frames. Defaults to [].
            unwatermarked_frame_editor_list (List[ImageEditor], optional): List of image editors for editing individual unwatermarked frames. Defaults to [].
            analyzers (List[VideoQualityAnalyzer], optional): List of quality analyzers for videos. Defaults to None.
            show_progress (bool, optional): Whether to show progress. Defaults to True.
            store_path (str, optional): The path to store the results. Defaults to None.
            return_type (QualityPipelineReturnType, optional): The return type of the pipeline. Defaults to QualityPipelineReturnType.MEAN_SCORES.
        """
        self.dataset = dataset
        self.watermarked_video_editor_list = watermarked_video_editor_list
        self.unwatermarked_video_editor_list = unwatermarked_video_editor_list
        self.watermarked_frame_editor_list = watermarked_frame_editor_list
        self.unwatermarked_frame_editor_list = unwatermarked_frame_editor_list
        self.analyzers = analyzers or []
        self.show_progress = show_progress
        self.store_path = store_path
        self.return_type = return_type

    def _check_compatibility(self):
        """Check if the pipeline is compatible with the dataset."""
        pass

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        pass

    def _get_progress_bar(self, iterable):
        """Return an iterable possibly wrapped with a progress bar."""
        if self.show_progress:
            return tqdm(iterable, desc="Processing", leave=True)
        return iterable

    def _get_prompt(self, index: int) -> str:
        """Get prompt from dataset."""
        return self.dataset.get_prompt(index)

    def _get_watermarked_video(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> List[Image.Image]:
        """Generate watermarked video from dataset."""
        prompt = self._get_prompt(index)
        frames = watermark.generate_watermarked_media(input_data=prompt, **generation_kwargs)
        return frames

    def _get_unwatermarked_video(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> List[Image.Image]:
        """Generate unwatermarked video from dataset."""
        prompt = self._get_prompt(index)
        frames = watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)
        return frames

    def _edit_watermarked_video(self, frames: List[Image.Image]) -> List[Image.Image]:
        """Edit watermarked video using video and frame editors."""
        # Step 1: Edit the whole frame sequence using video editors
        for video_editor in self.watermarked_video_editor_list:
            frames = video_editor.edit(frames)
        # Step 2: Edit individual frames using image editors
        for frame_editor in self.watermarked_frame_editor_list:
            frames = [frame_editor.edit(frame) for frame in frames]
        return frames

    def _edit_unwatermarked_video(self, frames: List[Image.Image]) -> List[Image.Image]:
        """Edit unwatermarked video using video and frame editors."""
        # Step 1: Edit the whole frame sequence using video editors
        for video_editor in self.unwatermarked_video_editor_list:
            frames = video_editor.edit(frames)
        # Step 2: Edit individual frames using image editors
        for frame_editor in self.unwatermarked_frame_editor_list:
            frames = [frame_editor.edit(frame) for frame in frames]
        return frames

    def _prepare_dataset(self, watermark: BaseWatermark, **generation_kwargs) -> DatasetForEvaluation:
        """
        Prepare and generate all necessary data for quality analysis.

        This method should be overridden by subclasses to implement specific
        data preparation logic based on the analysis requirements.

        Parameters:
            watermark: The watermark algorithm instance.
            generation_kwargs: Additional generation parameters.

        Returns:
            DatasetForEvaluation object containing all prepared data.
        """
        dataset_eval = DatasetForEvaluation()

        # Generate all videos
        bar = self._get_progress_bar(self._get_iterable())
        bar.set_description("Generating videos for quality analysis")
        for index in bar:
            # Generate and edit watermarked video
            watermarked_frames = self._get_watermarked_video(watermark, index, **generation_kwargs)
            watermarked_frames = self._edit_watermarked_video(watermarked_frames)

            # Generate and edit unwatermarked video
            unwatermarked_frames = self._get_unwatermarked_video(watermark, index, **generation_kwargs)
            unwatermarked_frames = self._edit_unwatermarked_video(unwatermarked_frames)

            dataset_eval.watermarked_videos.append(watermarked_frames)
            dataset_eval.unwatermarked_videos.append(unwatermarked_frames)
            dataset_eval.indexes.append(index)

            if self.dataset.num_references > 0:
                reference_frames = self.dataset.get_reference(index)
                dataset_eval.reference_videos.append(reference_frames)

        return dataset_eval

    def _prepare_input_for_quality_analyzer(self,
                                            watermarked_videos: List[List[Image.Image]],
                                            unwatermarked_videos: List[List[Image.Image]],
                                            reference_videos: List[List[Image.Image]]):
        """Prepare input for quality analyzer.

        Args:
            watermarked_videos (List[List[Image.Image]]): Watermarked video(s)
            unwatermarked_videos (List[List[Image.Image]]): Unwatermarked video(s)
            reference_videos (List[List[Image.Image]]): Reference video(s) if available
        """
        pass

    def _store_results(self, prepared_dataset: DatasetForEvaluation):
        """Store results."""
        os.makedirs(self.store_path, exist_ok=True)
        dataset_name = self.dataset.name

        for (index, watermarked_video, unwatermarked_video) in zip(prepared_dataset.indexes, prepared_dataset.watermarked_videos, prepared_dataset.unwatermarked_videos):
            # Each video is a List[Image.Image], so save it frame by frame
            save_dir = os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_watermarked_prompt{index}")
            os.makedirs(save_dir, exist_ok=True)
            for i, frame in enumerate(watermarked_video):
                frame.save(os.path.join(save_dir, f"frame_{i}.png"))

            save_dir = os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_unwatermarked_prompt{index}")
            os.makedirs(save_dir, exist_ok=True)
            for i, frame in enumerate(unwatermarked_video):
                frame.save(os.path.join(save_dir, f"frame_{i}.png"))

            if self.dataset.num_references > 0:
                reference_frames = self.dataset.get_reference(index)
                save_dir = os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_reference_prompt{index}")
                os.makedirs(save_dir, exist_ok=True)
                for i, frame in enumerate(reference_frames):
                    frame.save(os.path.join(save_dir, f"frame_{i}.png"))

    def analyze_quality(self, prepared_data, analyzer):
        """Analyze quality of watermarked and unwatermarked videos."""
        pass
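The two-stage editing in the video pipeline (whole-video editors first, then per-frame editors) can be sketched with toy editors over plain lists; the editor classes here are illustrative, with strings standing in for frames:

```python
class DropEveryOtherFrame:
    # Toy "video editor": operates on the whole frame list at once.
    def edit(self, frames):
        return frames[::2]

class Uppercase:
    # Toy "frame editor": operates on a single frame (a string here).
    def edit(self, frame):
        return frame.upper()

def edit_video(frames, video_editors, frame_editors):
    for video_editor in video_editors:      # Step 1: whole-video edits
        frames = video_editor.edit(frames)
    for frame_editor in frame_editors:      # Step 2: per-frame edits
        frames = [frame_editor.edit(f) for f in frames]
    return frames

frames = ["f0", "f1", "f2", "f3"]
print(edit_video(frames, [DropEveryOtherFrame()], [Uppercase()]))  # ['F0', 'F2']
```

Running video editors first matters: temporal attacks (frame dropping, reordering) see the original sequence, and per-frame attacks then apply to whatever frames survive.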


class DirectVideoQualityAnalysisPipeline(VideoQualityAnalysisPipeline):
    """Pipeline for direct video quality analysis."""

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_video_editor_list: List[VideoEditor] = [],
                 unwatermarked_video_editor_list: List[VideoEditor] = [],
                 watermarked_frame_editor_list: List[ImageEditor] = [],
                 unwatermarked_frame_editor_list: List[ImageEditor] = [],
                 analyzers: List[VideoQualityAnalyzer] = None,
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:
        """Initialize the video quality analysis pipeline.

        Args:
            dataset (BaseDataset): The dataset for evaluation.
            watermarked_video_editor_list (List[VideoEditor], optional): The list of video editors for watermarked videos. Defaults to [].
            unwatermarked_video_editor_list (List[VideoEditor], optional): The list of video editors for unwatermarked videos. Defaults to [].
            watermarked_frame_editor_list (List[ImageEditor], optional): List of image editors for editing individual watermarked frames. Defaults to [].
            unwatermarked_frame_editor_list (List[ImageEditor], optional): List of image editors for editing individual unwatermarked frames. Defaults to [].
            analyzers (List[VideoQualityAnalyzer], optional): List of quality analyzers for videos. Defaults to None.
            show_progress (bool, optional): Whether to show progress. Defaults to True.
            store_path (str, optional): The path to store the results. Defaults to None.
            return_type (QualityPipelineReturnType, optional): The return type of the pipeline. Defaults to QualityPipelineReturnType.MEAN_SCORES.
        """
        super().__init__(dataset, watermarked_video_editor_list, unwatermarked_video_editor_list, watermarked_frame_editor_list, unwatermarked_frame_editor_list, analyzers, show_progress, store_path, return_type)

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        return range(self.dataset.num_samples)

    def _get_prompt(self, index: int) -> str:
        """Get prompt from dataset."""
        return self.dataset.get_prompt(index)

    def _get_watermarked_video(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> List[Image.Image]:
        """Generate watermarked video from dataset."""
        prompt = self._get_prompt(index)
        frames = watermark.generate_watermarked_media(input_data=prompt, **generation_kwargs)
        return frames

    def _get_unwatermarked_video(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> List[Image.Image]:
        """Generate unwatermarked video from dataset."""
        prompt = self._get_prompt(index)
        frames = watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)
        return frames

    def _prepare_input_for_quality_analyzer(self,
                                            watermarked_videos: List[List[Image.Image]],
                                            unwatermarked_videos: List[List[Image.Image]],
                                            reference_videos: List[List[Image.Image]]):
        """Prepare input for quality analyzer."""
        return watermarked_videos, unwatermarked_videos

    def analyze_quality(self,
                        prepared_data: Tuple[List[List[Image.Image]], List[List[Image.Image]]],
                        analyzer: VideoQualityAnalyzer):
        """Analyze quality of watermarked and unwatermarked videos."""
        watermarked_videos, unwatermarked_videos = prepared_data

        # Create pairs of watermarked and unwatermarked videos
        video_pairs = list(zip(watermarked_videos, unwatermarked_videos))

        bar = self._get_progress_bar(video_pairs)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        for watermarked_video, unwatermarked_video in bar:
            w_score = analyzer.analyze(watermarked_video)
            u_score = analyzer.analyze(unwatermarked_video)
            w_scores.append(w_score)
            u_scores.append(u_score)
        return w_scores, u_scores

class InceptionScoreCalculator(RepeatImageQualityAnalyzer):
    """Inception Score (IS) calculator for evaluating image generation quality.

    Inception Score measures both the quality and diversity of generated images
    by evaluating how confidently an Inception model can classify them and how
    diverse the predictions are across the image set.

    Higher IS indicates better image quality and diversity (typical range: 1-10+).
    """

    def __init__(self, device: str = "cuda", batch_size: int = 32, splits: int = 1):
        """Initialize the Inception Score calculator.

        Args:
            device: Device to run the model on ("cuda" or "cpu")
            batch_size: Batch size for processing images
            splits: Number of splits for computing IS (default: 1). The number of images must be divisible by splits for a fair comparison.
                To obtain the mean and standard error of IS, set splits greater than 1.
                If splits is 1, IS is calculated on the entire dataset (avg = IS, std = 0).
        """
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.splits = splits
        self._load_model()

    def _load_model(self):
        """Load the Inception v3 model for feature extraction."""
        from torchvision import models, transforms

        # Load pre-trained Inception v3 model
        self.model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
        self.model.aux_logits = False  # Disable auxiliary output
        self.model.eval()
        self.model.to(self.device)

        # Keep the original classification layer for proper predictions
        # No need to modify model.fc - it should output 1000 classes

        # Define preprocessing pipeline for Inception v3
        self.preprocess = transforms.Compose([
            transforms.Resize((299, 299)),  # Inception v3 input size
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
        ])

    def _get_predictions(self, images: List[Image.Image]) -> np.ndarray:
        """Extract softmax predictions from images using Inception v3.

        Args:
            images: List of PIL images to process

        Returns:
            Numpy array of shape (n_images, n_classes) containing softmax predictions
        """
        predictions_list = []

        # Process images in batches for efficiency
        for i in range(0, len(images), self.batch_size):
            batch_images = images[i:i + self.batch_size]

            # Preprocess batch
            batch_tensors = []
            for img in batch_images:
                # Ensure RGB format
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                tensor = self.preprocess(img)
                batch_tensors.append(tensor)

            # Stack into batch tensor
            batch_tensor = torch.stack(batch_tensors).to(self.device)

            # Get predictions from Inception model
            with torch.no_grad():
                logits = self.model(batch_tensor)
                # Apply softmax to get probability distributions
                probs = F.softmax(logits, dim=1)
                predictions_list.append(probs.cpu().numpy())

        return np.concatenate(predictions_list, axis=0)

    def _calculate_inception_score(self, predictions: np.ndarray) -> list:
        """Calculate Inception Score from predictions.

        The IS is calculated as exp(KL divergence between conditional and marginal distributions).

        Args:
            predictions: Softmax predictions of shape (n_images, n_classes)

        Returns:
            List of per-split Inception Scores
        """
        # Split predictions for more stable estimation
        n_samples = predictions.shape[0]  # (n_images, n_classes)
        split_size = n_samples // self.splits

        splits = self.splits

        split_scores = []

        for split_idx in range(splits):
            # Get current split
            start_idx = split_idx * split_size
            end_idx = (split_idx + 1) * split_size if split_idx < splits - 1 else n_samples  # Last split gets remaining samples
            split_preds = predictions[start_idx:end_idx]

            # Calculate marginal distribution p(y) - average across all samples
            p_y = np.mean(split_preds, axis=0)

            epsilon = 1e-16
            p_y_x_safe = split_preds + epsilon
            p_y_safe = p_y + epsilon
            kl_divergences = np.sum(
                p_y_x_safe * (np.log(p_y_x_safe / p_y_safe)),
                axis=1)

            # Inception Score for this split is exp(mean(KL divergences))
            split_score = np.exp(np.mean(kl_divergences))
            split_scores.append(split_score)

        # Return the list of scores, one per split
        return split_scores
-[docs]
- defanalyze(self,images:List[Image.Image],*args,**kwargs)->List[float]:
-"""Calculate Inception Score for a set of generated images.
-
- Args:
- images: List of generated images to evaluate
-
- Returns:
- List[float]: Inception Score values for each split (higher is better, typical range: 1-10+)
- """
- iflen(images)<self.splits:
- raiseValueError(f"Inception Score requires at least {self.splits} images (one per split)")
-
- iflen(images)%self.splits!=0:
- raiseValueError(f"Inception Score requires the number of images to be divisible by the number of splits")
-
- # Get predictions from Inception model
- predictions=self._get_predictions(images)
-
- # Calculate Inception Score
- split_scores=self._calculate_inception_score(predictions)
-
- # Log the standard deviation for reference (but return only mean)
- mean_score=np.mean(split_scores)
- std_score=np.std(split_scores)
- ifstd_score>0.5*mean_score:
- print(f"Warning: High standard deviation in IS calculation: {mean_score:.2f} ± {std_score:.2f}")
-
- returnsplit_scores
-
-
-
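The per-split computation above reduces to a few lines of NumPy. A minimal standalone sketch (outside the toolkit's classes; `preds` is a hypothetical softmax matrix): uniform predictions give IS = 1, while confident predictions spread over many classes push IS toward the number of classes.

```python
import numpy as np

def inception_score(preds: np.ndarray, eps: float = 1e-16) -> float:
    """IS = exp(mean KL(p(y|x) || p(y))) over a softmax prediction matrix."""
    p_y = preds.mean(axis=0)  # marginal class distribution p(y)
    kl = np.sum(preds * (np.log(preds + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

# Uniform predictions carry no class information: IS == 1
uniform = np.full((8, 10), 0.1)
print(round(inception_score(uniform), 4))  # → 1.0

# Confident, diverse predictions: IS approaches the number of classes (10)
confident = np.eye(10) * 0.99 + 0.001
confident /= confident.sum(axis=1, keepdims=True)
print(round(inception_score(confident), 2))
```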
-
-class CLIPScoreCalculator(ReferencedImageQualityAnalyzer):
-    """CLIP score calculator for image quality analysis.
-
-    Calculates the CLIP similarity between an image and a reference.
-    Higher scores indicate better semantic similarity.
-    """
-
-    def __init__(self, device: str = "cuda", model_name: str = "ViT-B/32", reference_source: str = "image"):
-        """Initialize the CLIP Score calculator.
-
-        Args:
-            device: Device to run the model on ("cuda" or "cpu")
-            model_name: CLIP model variant to use
-            reference_source: The source of the reference ('image' or 'text')
-        """
-        super().__init__()
-        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
-        self.model_name = model_name
-        self.reference_source = reference_source
-        self._load_model()
-
-    def analyze(self, image: Image.Image, reference: Union[Image.Image, str], *args, **kwargs) -> float:
-        """Calculate the CLIP similarity between an image and a reference.
-
-        Args:
-            image: Input image to evaluate
-            reference: Reference image or text for comparison
-                - If reference_source is 'image': expects a PIL Image
-                - If reference_source is 'text': expects a string
-
-        Returns:
-            float: CLIP similarity score (0 to 1)
-        """
-        # Convert the image to RGB if necessary
-        if image.mode != 'RGB':
-            image = image.convert('RGB')
-
-        # Preprocess the image
-        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
-
-        # Extract features based on the reference source
-        with torch.no_grad():
-            # Encode image features
-            img_features = self.model.encode_image(img_tensor)
-
-            # Encode reference features based on the source type
-            if self.reference_source == 'text':
-                if not isinstance(reference, str):
-                    raise ValueError(f"Expected string reference for text mode, got {type(reference)}")
-
-                # Tokenize and encode the text
-                text_tokens = clip.tokenize([reference]).to(self.device)
-                ref_features = self.model.encode_text(text_tokens)
-
-            elif self.reference_source == 'image':
-                if not isinstance(reference, Image.Image):
-                    raise ValueError(f"Expected PIL Image reference for image mode, got {type(reference)}")
-
-                # Convert the reference image to RGB if necessary
-                if reference.mode != 'RGB':
-                    reference = reference.convert('RGB')
-
-                # Preprocess and encode the reference image
-                ref_tensor = self.preprocess(reference).unsqueeze(0).to(self.device)
-                ref_features = self.model.encode_image(ref_tensor)
-
-            else:
-                raise ValueError(f"Invalid reference_source: {self.reference_source}. Must be 'image' or 'text'")
-
-            # Normalize the features
-            img_features = F.normalize(img_features, p=2, dim=1)
-            ref_features = F.normalize(ref_features, p=2, dim=1)
-
-        # Calculate the cosine similarity
-        similarity = torch.cosine_similarity(img_features, ref_features).item()
-
-        # Map from [-1, 1] to [0, 1]
-        similarity = (similarity + 1) / 2
-
-        return similarity
-
-
-
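The final mapping step in `analyze` is worth making explicit: CLIP cosine similarity lies in [-1, 1], and the calculator rescales it to [0, 1]. A minimal sketch with toy feature vectors (no CLIP model involved):

```python
import numpy as np

def clip_style_score(img_feat: np.ndarray, ref_feat: np.ndarray) -> float:
    """L2-normalize both features, take cosine similarity, map [-1, 1] -> [0, 1]."""
    a = img_feat / np.linalg.norm(img_feat)
    b = ref_feat / np.linalg.norm(ref_feat)
    return (float(a @ b) + 1.0) / 2.0

print(clip_style_score(np.array([1.0, 0.0]), np.array([2.0, 0.0])))   # identical direction → 1.0
print(clip_style_score(np.array([1.0, 0.0]), np.array([-1.0, 0.0])))  # opposite direction → 0.0
print(clip_style_score(np.array([1.0, 0.0]), np.array([0.0, 1.0])))   # orthogonal → 0.5
```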
-
-class FIDCalculator(GroupImageQualityAnalyzer):
-    """FID calculator for image quality analysis.
-
-    Calculates the Fréchet Inception Distance between two sets of images.
-    Lower FID indicates better quality and closer similarity to the reference distribution.
-    """
-
-    def __init__(self, device: str = "cuda", batch_size: int = 32, splits: int = 1):
-        """Initialize the FID calculator.
-
-        Args:
-            device: Device to run the model on ("cuda" or "cpu")
-            batch_size: Batch size for processing images
-            splits: Number of splits for computing FID (default: 1). The number of images must be
-                divisible by the number of splits for a fair comparison.
-                To obtain a mean and standard error of FID, set splits greater than 1.
-                If splits is 1, FID is calculated on the entire dataset (Avg = FID, Std = 0).
-        """
-        super().__init__()
-        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
-        self.batch_size = batch_size
-        self.splits = splits
-        self._load_model()
-
-    def _load_model(self):
-        """Load the Inception v3 model for feature extraction."""
-        from torchvision import models, transforms
-
-        # Load the Inception v3 model
-        inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT, init_weights=False)
-        inception.fc = nn.Identity()  # Remove the final classification layer
-        inception.aux_logits = False
-        inception.eval()
-        inception.to(self.device)
-        self.model = inception
-
-        # Define the preprocessing pipeline
-        self.preprocess = transforms.Compose([
-            transforms.Resize((512, 512)),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
-        ])
-
-    def _extract_features(self, images: List[Image.Image]) -> np.ndarray:
-        """Extract features from a list of images.
-
-        Args:
-            images: List of PIL images
-
-        Returns:
-            Feature matrix of shape (n_images, 2048)
-        """
-        features_list = []
-
-        for i in range(0, len(images), self.batch_size):
-            batch_images = images[i:i + self.batch_size]
-
-            # Preprocess the batch
-            batch_tensors = []
-            for img in batch_images:
-                if img.mode != 'RGB':
-                    img = img.convert('RGB')
-                tensor = self.preprocess(img)
-                batch_tensors.append(tensor)
-
-            batch_tensor = torch.stack(batch_tensors).to(self.device)
-
-            # Extract features
-            with torch.no_grad():
-                features = self.model(batch_tensor)  # (batch_size, 2048)
-                features_list.append(features.cpu().numpy())
-
-        return np.concatenate(features_list, axis=0)  # (n_images, 2048)
-
-    def _calculate_fid(self, features1: np.ndarray, features2: np.ndarray) -> float:
-        """Calculate the FID between two feature sets.
-
-        Args:
-            features1: First feature set
-            features2: Second feature set
-
-        Returns:
-            float: FID score
-        """
-        from scipy.linalg import sqrtm
-
-        # Calculate the Gaussian statistics of each set
-        mu1, sigma1 = features1.mean(axis=0), np.cov(features1, rowvar=False)
-        mu2, sigma2 = features2.mean(axis=0), np.cov(features2, rowvar=False)
-
-        # Calculate FID
-        diff = mu1 - mu2
-        covmean = sqrtm(sigma1.dot(sigma2))
-
-        # Numerical stability: discard any spurious imaginary component
-        if np.iscomplexobj(covmean):
-            covmean = covmean.real
-
-        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
-        return float(fid)
-
-    def analyze(self, images: List[Image.Image], references: List[Image.Image], *args, **kwargs) -> float:
-        """Calculate the FID between two sets of images.
-
-        Args:
-            images: Set of images to evaluate
-            references: Reference set of images
-
-        Returns:
-            float: FID score over the full sets
-        """
-        if len(images) < 2 or len(references) < 2:
-            raise ValueError("FID requires at least 2 images in each set")
-        if len(images) % self.splits != 0 or len(references) % self.splits != 0:
-            raise ValueError("FID requires the number of images to be divisible by the number of splits")
-
-        # Extract features
-        features1 = self._extract_features(images)
-        features2 = self._extract_features(references)
-
-        # Calculate FID over the entire sets
-        fid_score = self._calculate_fid(features1, features2)
-
-        return fid_score
-
-
-
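The statistics step in `_calculate_fid` implements the closed-form Fréchet distance between two Gaussians, FID = ||μ1 − μ2||² + Tr(Σ1 + Σ2 − 2(Σ1Σ2)^½). A minimal standalone sketch on random feature matrices (assumes `scipy` is available): comparing a feature set with itself yields a numerically zero FID, while a pure mean shift contributes its squared norm.

```python
import numpy as np
from scipy.linalg import sqrtm

def fid(f1: np.ndarray, f2: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two feature sets."""
    mu1, sigma1 = f1.mean(axis=0), np.cov(f1, rowvar=False)
    mu2, sigma2 = f2.mean(axis=0), np.cov(f2, rowvar=False)
    covmean = sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard spurious imaginary parts
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean))

rng = np.random.default_rng(0)
feats = rng.normal(size=(64, 8))      # 64 samples, 8-dim features
shifted = feats + 3.0                 # same covariance, mean shifted by 3 per dim
print(abs(fid(feats, feats)) < 1e-6)  # identical sets → True (FID ≈ 0)
print(fid(feats, shifted) > 50)       # mean shift dominates: ≈ 8 * 3² = 72 → True
```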
-
-class LPIPSAnalyzer(RepeatImageQualityAnalyzer):
-    """LPIPS analyzer for image quality analysis.
-
-    Calculates perceptual diversity within a set of images.
-    Higher LPIPS indicates more diverse/varied images.
-    """
-
-    def __init__(self, device: str = "cuda", net: str = "alex"):
-        """Initialize the LPIPS analyzer.
-
-        Args:
-            device: Device to run the model on ("cuda" or "cpu")
-            net: Network to use ('alex', 'vgg', or 'squeeze')
-        """
-        super().__init__()
-        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
-        self.net = net
-        self._load_model()
-
-    def _preprocess(self, image: Image.Image) -> torch.Tensor:
-        """Convert a PIL Image to a tensor in range [0, 1] with shape (1, C, H, W)."""
-        if image.mode != 'RGB':
-            image = image.convert('RGB')
-        arr = np.array(image).astype(np.float32) / 255.0
-        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # BCHW
-        return tensor.to(self.device)
-
-    def analyze(self, image: Image.Image, *args, **kwargs) -> float:
-        """Calculate the BRISQUE score for a single image.
-
-        Args:
-            image: PIL Image
-
-        Returns:
-            float: BRISQUE score (lower is better)
-        """
-        x = self._preprocess(image)
-        with torch.no_grad():
-            score = piq.brisque(x, data_range=1.0)  # piq expects inputs in [0, 1]
-        return float(score.item())
-
-
-class VIFAnalyzer(ComparedImageQualityAnalyzer):
-    """VIF (Visual Information Fidelity) analyzer using piq.
-
-    VIF compares a distorted image with a reference image to
-    quantify the amount of visual information preserved.
-    Higher VIF indicates better quality/similarity.
-    Typical range: 0 to 1 (sometimes higher for very good quality).
-    """
-
-
-class FundamentalSuccessRateCalculator(BaseSuccessRateCalculator):
-    """Calculator for fundamental success rates of watermark detection.
-
-    This class specifically handles the calculation of success rates for scenarios involving
-    watermark detection after fixed thresholding. It provides metrics based on comparisons
-    between expected watermarked results and actual detection outputs.
-
-    Use this class when you need to evaluate the effectiveness of watermark detection algorithms
-    under fixed thresholding conditions.
-    """
-
-    def __init__(self, labels: List[str] = ['TPR', 'TNR', 'FPR', 'FNR', 'P', 'R', 'F1', 'ACC']) -> None:
-        """Initialize the fundamental success rate calculator.
-
-        Parameters:
-            labels (List[str]): The list of metric labels to include in the output.
-        """
-        super().__init__(labels)
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from PIL import Image
-from typing import List
-import cv2
-import numpy as np
-import tempfile
-import os
-import random
-import subprocess
-import shutil
-
-
-class VideoEditor:
-    """Base class for video editors."""
-
-
-class MPEG4Compression(VideoEditor):
-    """MPEG-4 compression video editor."""
-
-    def __init__(self, fps: float = 24.0):
-        """Initialize the MPEG-4 compression video editor.
-
-        Args:
-            fps (float, optional): The frames per second of the compressed video. Defaults to 24.0.
-        """
-        self.fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-        self.fps = fps
-
-    def edit(self, frames: List[Image.Image], prompt: str = None) -> List[Image.Image]:
-        """Compress the video using MPEG-4 compression.
-
-        Args:
-            frames (List[Image.Image]): The frames to compress.
-            prompt (str, optional): The prompt for video editing. Defaults to None.
-
-        Returns:
-            List[Image.Image]: The compressed frames.
-        """
-        # Transform PIL images to numpy arrays and convert to BGR format
-        frame_arrays = [cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR) for f in frames]
-
-        # Get the frame size
-        height, width, _ = frame_arrays[0].shape
-
-        # Use a temporary file to hold the mp4 video
-        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
-            video_path = tmp.name
-
-        # Write the mp4 video (MPEG-4 encoding)
-        out = cv2.VideoWriter(video_path, self.fourcc, self.fps, (width, height))
-        for frame in frame_arrays:
-            out.write(frame)
-        out.release()
-
-        # Read the mp4 video back and decode it into frames
-        cap = cv2.VideoCapture(video_path)
-        compressed_frames = []
-        while True:
-            ret, frame = cap.read()
-            if not ret:
-                break
-            # Transform back to a PIL Image
-            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-            compressed_frames.append(pil_img)
-        cap.release()
-
-        # Clean up the temporary file
-        os.remove(video_path)
-
-        return compressed_frames
-
-
-
-
-
-class VideoCodecAttack(VideoEditor):
-    """Re-encode videos with specific codecs and bitrates to simulate platform processing."""
-
-    _CODEC_MAP = {
-        "h264": ("libx264", ".mp4"),
-        "h265": ("libx265", ".mp4"),
-        "hevc": ("libx265", ".mp4"),
-        "vp9": ("libvpx-vp9", ".webm"),
-        "av1": ("libaom-av1", ".mkv"),
-    }
-
-    def __init__(self, codec: str = "h264", bitrate: str = "2M", fps: float = 24.0, ffmpeg_path: str = None):
-        """Initialize the codec attack editor.
-
-        Args:
-            codec (str, optional): Target codec (h264, h265/hevc, vp9, av1). Defaults to "h264".
-            bitrate (str, optional): Target bitrate passed to ffmpeg (e.g., "2M"). Defaults to "2M".
-            fps (float, optional): Frames per second used for intermediate encoding. Defaults to 24.0.
-            ffmpeg_path (str, optional): Path to the ffmpeg binary. If None, resolved via PATH.
-        """
-        self.codec = codec.lower()
-        if self.codec == "hevc":
-            self.codec = "h265"
-        if self.codec not in self._CODEC_MAP:
-            raise ValueError(f"Unsupported codec '{codec}'. Supported: {', '.join(self._CODEC_MAP.keys())}")
-        self.bitrate = bitrate
-        self.fps = fps
-        self.ffmpeg_path = ffmpeg_path or shutil.which("ffmpeg")
-        if self.ffmpeg_path is None:
-            raise EnvironmentError("ffmpeg executable not found. Install ffmpeg or provide ffmpeg_path.")
-class FrameAverage(VideoEditor):
-    """Frame average video editor."""
-
-    def __init__(self, n_frames: int = 3):
-        """Initialize the frame average video editor.
-
-        Args:
-            n_frames (int, optional): The number of frames to average. Defaults to 3.
-        """
-        self.n_frames = n_frames
-
-    def edit(self, frames: List[Image.Image], prompt: str = None) -> List[Image.Image]:
-        """Average frames in a sliding window of size n_frames.
-
-        Args:
-            frames (List[Image.Image]): The frames to average.
-            prompt (str, optional): The prompt for video editing. Defaults to None.
-
-        Returns:
-            List[Image.Image]: The averaged frames.
-        """
-        n = self.n_frames
-        num_frames = len(frames)
-        # Transform all PIL images to float32 numpy arrays for averaging
-        arrays = [np.asarray(img).astype(np.float32) for img in frames]
-        result = []
-        for i in range(num_frames):
-            # Determine the current window
-            start = max(0, i - n // 2)
-            end = min(num_frames, start + n)
-            # If the window ran past the end, shift it back to the left
-            start = max(0, end - n)
-            window = arrays[start:end]
-            avg = np.mean(window, axis=0).astype(np.uint8)
-            result.append(Image.fromarray(avg))
-        return result
-
-
-
-
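The window logic in `FrameAverage.edit` can be checked on toy data. A minimal standalone sketch (operating on ints rather than images) shows how the window is clamped at the sequence boundaries:

```python
import numpy as np

def window_average(values, n=3):
    """Average each element with its neighbors in a window of size n, clamped at the edges."""
    num = len(values)
    out = []
    for i in range(num):
        start = max(0, i - n // 2)
        end = min(num, start + n)
        start = max(0, end - n)  # shift the window left if it ran past the end
        out.append(float(np.mean(values[start:end])))
    return out

print(window_average([0, 10, 20, 30], n=3))  # → [10.0, 10.0, 20.0, 20.0]
```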
-
-class FrameRateAdapter(VideoEditor):
-    """Resample videos to a target frame rate using linear interpolation."""
-
-    def __init__(self, source_fps: float = 30.0, target_fps: float = 24.0):
-        """Initialize the frame rate adapter.
-
-        Args:
-            source_fps (float, optional): Original frames per second. Defaults to 30.0.
-            target_fps (float, optional): Desired frames per second. Defaults to 24.0.
-        """
-        if source_fps <= 0 or target_fps <= 0:
-            raise ValueError("source_fps and target_fps must be positive numbers")
-        self.source_fps = source_fps
-        self.target_fps = target_fps
-class FrameSwap(VideoEditor):
-    """Frame swap video editor."""
-
-    def __init__(self, p: float = 0.25):
-        """Initialize the frame swap video editor.
-
-        Args:
-            p (float, optional): The probability of swapping neighboring frames. Defaults to 0.25.
-        """
-        self.p = p
-
-    def edit(self, frames: List[Image.Image], prompt: str = None) -> List[Image.Image]:
-        """Swap adjacent frames with probability p.
-
-        Args:
-            frames (List[Image.Image]): The frames to swap.
-            prompt (str, optional): The prompt for video editing. Defaults to None.
-
-        Returns:
-            List[Image.Image]: The swapped frames.
-        """
-        for i, frame in enumerate(frames):
-            if i == 0:
-                continue
-            # Swap this frame with its predecessor with probability p
-            if random.random() < self.p:
-                frames[i - 1], frames[i] = frames[i], frames[i - 1]
-        return frames
-
-
-
-
-
-class FrameInterpolationAttack(VideoEditor):
-    """Insert interpolated frames to alter temporal sampling density."""
-
-    def __init__(self, interpolated_frames: int = 1):
-        """Initialize the interpolation attack editor.
-
-        Args:
-            interpolated_frames (int, optional): Number of synthetic frames added between consecutive original frames. Defaults to 1.
-        """
-        if interpolated_frames < 0:
-            raise ValueError("interpolated_frames must be non-negative")
-        self.interpolated_frames = interpolated_frames
-
-    def analyze(self, frames: List[Image.Image]):
-        """Analyze video quality.
-
-        Args:
-            frames: List of PIL Image frames representing the video
-
-        Returns:
-            Quality score(s)
-        """
-        raise NotImplementedError("Subclasses must implement the analyze method")
-
-
-
-
-
-class SubjectConsistencyAnalyzer(VideoQualityAnalyzer):
-    """Analyzer for evaluating subject consistency across video frames using DINO features.
-
-    This analyzer measures how consistently the main subject appears across frames by:
-    1. Extracting DINO features from each frame
-    2. Computing cosine similarity between consecutive frames and with the first frame
-    3. Averaging these similarities to obtain a consistency score
-    """
-
-    def transform(self, img: Image.Image) -> torch.Tensor:
-        """Transform a PIL Image into a tensor for the DINO model."""
-        transform = dino_transform_Image(224)
-        return transform(img)
-
-    def analyze(self, frames: List[Image.Image]) -> float:
-        """Analyze subject consistency across video frames.
-
-        Args:
-            frames: List of PIL Image frames representing the video
-
-        Returns:
-            Subject consistency score (higher is better, range [0, 1])
-        """
-        if len(frames) < 2:
-            return 1.0  # A single frame is perfectly consistent with itself
-
-        video_sim = 0.0
-        frame_count = 0
-
-        # Process frames and extract features
-        with torch.no_grad():
-            for i, frame in enumerate(frames):
-                # Transform and prepare the frame
-                frame_tensor = self.transform(frame).unsqueeze(0).to(self.device)
-
-                # Extract and normalize the features
-                features = self.model(frame_tensor)
-                features = F.normalize(features, dim=-1, p=2)
-
-                if i == 0:
-                    # Store the first frame's features
-                    first_frame_features = features
-                else:
-                    # Similarity with the previous frame
-                    sim_prev = max(0.0, F.cosine_similarity(prev_features, features).item())
-
-                    # Similarity with the first frame
-                    sim_first = max(0.0, F.cosine_similarity(first_frame_features, features).item())
-
-                    # Average the two similarities
-                    frame_sim = (sim_prev + sim_first) / 2.0
-                    video_sim += frame_sim
-                    frame_count += 1
-
-                # Store the current features as "previous" for the next iteration
-                prev_features = features
-
-        # Return the average similarity across all frame pairs
-        if frame_count > 0:
-            return video_sim / frame_count
-        else:
-            return 1.0
-
-
-
-
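The similarity bookkeeping in `SubjectConsistencyAnalyzer.analyze` can be mimicked with plain vectors. A minimal sketch (no DINO model; unit vectors stand in for frame features): each frame after the first contributes the average of its similarity to the previous frame and to the first frame.

```python
import numpy as np

def subject_consistency(features):
    """Average of per-frame mean(cos(prev, cur), cos(first, cur)), clamped at 0."""
    feats = [f / np.linalg.norm(f) for f in features]
    total, count = 0.0, 0
    for i in range(1, len(feats)):
        sim_prev = max(0.0, float(feats[i - 1] @ feats[i]))
        sim_first = max(0.0, float(feats[0] @ feats[i]))
        total += (sim_prev + sim_first) / 2.0
        count += 1
    return total / count if count else 1.0

identical = [np.array([1.0, 0.0])] * 4
print(subject_consistency(identical))  # identical features every frame → 1.0
```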
-
-class MotionSmoothnessAnalyzer(VideoQualityAnalyzer):
-    """Analyzer for evaluating motion smoothness in videos using the AMT-S model.
-
-    This analyzer measures motion smoothness by:
-    1. Extracting frames at even indices from the video
-    2. Using the AMT-S model to interpolate between consecutive frames
-    3. Comparing the interpolated frames with the actual frames to compute a smoothness score
-
-    The score represents how well the motion can be predicted/interpolated,
-    with smoother motion resulting in higher scores.
-    """
-
-    def __init__(self, model_path: str = "model/amt/amt-s.pth",
-                 device: str = "cuda", niters: int = 1):
-        """Initialize the MotionSmoothnessAnalyzer.
-
-        Args:
-            model_path: Path to the AMT-S model checkpoint
-            device: Device to run the model on ('cuda' or 'cpu')
-            niters: Number of interpolation iterations (default: 1)
-        """
-        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
-        self.niters = niters
-
-        # Initialize processing parameters
-        self._initialize_params()
-
-        # Load the AMT-S model
-        self.model = self._load_amt_model(model_path)
-        self.model.eval()
-        self.model.to(self.device)
-
-    def _initialize_params(self):
-        """Initialize parameters for video processing."""
-        if self.device.type == 'cuda':
-            self.anchor_resolution = 1024 * 512
-            self.anchor_memory = 1500 * 1024 ** 2
-            self.anchor_memory_bias = 2500 * 1024 ** 2
-            self.vram_avail = torch.cuda.get_device_properties(self.device).total_memory
-        else:
-            # Do not resize in CPU mode
-            self.anchor_resolution = 8192 * 8192
-            self.anchor_memory = 1
-            self.anchor_memory_bias = 0
-            self.vram_avail = 1
-
-        # Time embedding for interpolation (t = 0.5)
-        self.embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(self.device)
-
-    def _load_amt_model(self, model_path: str):
-        """Load the AMT-S model.
-
-        Args:
-            model_path: Path to the model checkpoint
-
-        Returns:
-            Loaded AMT-S model
-        """
-        import importlib.util
-
-        # Load the module whose filename contains a hyphen
-        spec = importlib.util.spec_from_file_location("amt_s", "model/amt/networks/AMT-S.py")
-        amt_s_module = importlib.util.module_from_spec(spec)
-        spec.loader.exec_module(amt_s_module)
-        Model = amt_s_module.Model
-
-        # Create the model with default parameters
-        model = Model(
-            corr_radius=3,
-            corr_lvls=4,
-            num_flows=3
-        )
-
-        # Load the checkpoint if present
-        if os.path.exists(model_path):
-            ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
-            model.load_state_dict(ckpt['state_dict'])
-
-        return model
-
-    def _extract_frames(self, frames: List[Image.Image], start_from: int = 0) -> List[np.ndarray]:
-        """Extract frames at even indices starting from start_from.
-
-        Args:
-            frames: List of PIL Image frames
-            start_from: Starting index (default: 0)
-
-        Returns:
-            List of extracted frames as numpy arrays
-        """
-        extracted = []
-        for i in range(start_from, len(frames), 2):
-            # Convert the PIL Image to a numpy array
-            frame_np = np.array(frames[i])
-            extracted.append(frame_np)
-        return extracted
-
-    def _img2tensor(self, img: np.ndarray) -> torch.Tensor:
-        """Convert a numpy image to a tensor.
-
-        Args:
-            img: Image as a numpy array (H, W, C)
-
-        Returns:
-            Image tensor (1, C, H, W)
-        """
-        from model.amt.utils.utils import img2tensor
-        return img2tensor(img)
-
-    def _tensor2img(self, tensor: torch.Tensor) -> np.ndarray:
-        """Convert a tensor to a numpy image.
-
-        Args:
-            tensor: Image tensor (1, C, H, W)
-
-        Returns:
-            Image as a numpy array (H, W, C)
-        """
-        from model.amt.utils.utils import tensor2img
-        return tensor2img(tensor)
-
-    def _check_dim_and_resize(self, tensor_list: List[torch.Tensor]) -> List[torch.Tensor]:
-        """Check dimensions and resize tensors if needed.
-
-        Args:
-            tensor_list: List of image tensors
-
-        Returns:
-            List of resized tensors
-        """
-        from model.amt.utils.utils import check_dim_and_resize
-        return check_dim_and_resize(tensor_list)
-
-    def _calculate_scale(self, h: int, w: int) -> float:
-        """Calculate a scaling factor based on the available VRAM.
-
-        Args:
-            h: Height of the image
-            w: Width of the image
-
-        Returns:
-            Scaling factor
-        """
-        scale = self.anchor_resolution / (h * w) * np.sqrt((self.vram_avail - self.anchor_memory_bias) / self.anchor_memory)
-        scale = 1 if scale > 1 else scale
-        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
-        return scale
-
-    def _interpolate_frames(self, inputs: List[torch.Tensor], scale: float) -> List[torch.Tensor]:
-        """Interpolate frames using the AMT-S model.
-
-        Args:
-            inputs: List of input frame tensors
-            scale: Scaling factor for processing
-
-        Returns:
-            List of interpolated frame tensors
-        """
-        from model.amt.utils.utils import InputPadder
-
-        # Pad the inputs
-        padding = int(16 / scale)
-        padder = InputPadder(inputs[0].shape, padding)
-        inputs = padder.pad(*inputs)
-
-        # Perform interpolation for the specified number of iterations
-        for i in range(self.niters):
-            outputs = [inputs[0]]
-            for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
-                in_0 = in_0.to(self.device)
-                in_1 = in_1.to(self.device)
-                with torch.no_grad():
-                    imgt_pred = self.model(in_0, in_1, self.embt, scale_factor=scale, eval=True)['imgt_pred']
-                outputs += [imgt_pred.cpu(), in_1.cpu()]
-            inputs = outputs
-
-        # Unpad the outputs
-        outputs = padder.unpad(*outputs)
-        return outputs
-
-    def _compute_frame_difference(self, img1: np.ndarray, img2: np.ndarray) -> float:
-        """Compute the average absolute difference between two images.
-
-        Args:
-            img1: First image
-            img2: Second image
-
-        Returns:
-            Average pixel difference
-        """
-        diff = cv2.absdiff(img1, img2)
-        return np.mean(diff)
-
-    def _compute_vfi_score(self, original_frames: List[np.ndarray], interpolated_frames: List[np.ndarray]) -> float:
-        """Compute the video frame interpolation score.
-
-        Args:
-            original_frames: Original video frames
-            interpolated_frames: Interpolated frames
-
-        Returns:
-            VFI score (a lower difference means better interpolation)
-        """
-        # Extract frames at odd indices for comparison
-        ori_compare = self._extract_frames([Image.fromarray(f) for f in original_frames], start_from=1)
-        interp_compare = self._extract_frames([Image.fromarray(f) for f in interpolated_frames], start_from=1)
-
-        scores = []
-        for ori, interp in zip(ori_compare, interp_compare):
-            score = self._compute_frame_difference(ori, interp)
-            scores.append(score)
-
-        return np.mean(scores) if scores else 0.0
-
-    def analyze(self, frames: List[Image.Image]) -> float:
-        """Analyze motion smoothness in video frames.
-
-        Args:
-            frames: List of PIL Image frames representing the video
-
-        Returns:
-            Motion smoothness score (higher is better, range [0, 1])
-        """
-        if len(frames) < 2:
-            return 1.0  # A single frame has perfect smoothness
-
-        # Convert PIL Images to numpy arrays
-        np_frames = [np.array(frame) for frame in frames]
-
-        # Extract frames at even indices
-        frame_list = self._extract_frames(frames, start_from=0)
-
-        # Convert to tensors
-        inputs = [self._img2tensor(frame).to(self.device) for frame in frame_list]
-
-        if len(inputs) <= 1:
-            return 1.0  # Not enough frames for interpolation
-
-        # Check dimensions and resize if needed
-        inputs = self._check_dim_and_resize(inputs)
-        h, w = inputs[0].shape[-2:]
-
-        # Calculate the scale based on the available memory
-        scale = self._calculate_scale(h, w)
-
-        # Perform frame interpolation
-        outputs = self._interpolate_frames(inputs, scale)
-
-        # Convert the outputs back to images
-        output_images = [self._tensor2img(out) for out in outputs]
-
-        # Compute the VFI score
-        vfi_score = self._compute_vfi_score(np_frames, output_images)
-
-        # Normalize to [0, 1] (higher is better): the raw score is an
-        # average pixel difference in [0, 255], so normalize and invert
-        normalized_score = (255.0 - vfi_score) / 255.0
-
-        return normalized_score
-
-
-
-
-
-[docs]
-classDynamicDegreeAnalyzer(VideoQualityAnalyzer):
-"""Analyzer for evaluating dynamic degree (motion intensity) in videos using RAFT optical flow.
-
- This analyzer measures the amount and intensity of motion in videos by:
- 1. Computing optical flow between consecutive frames using RAFT
- 2. Calculating flow magnitude for each pixel
- 3. Extracting top 5% highest flow magnitudes
- 4. Determining if video has sufficient dynamic motion based on thresholds
-
- The score represents whether the video contains dynamic motion (1.0) or is mostly static (0.0).
- """
-
-
-[docs]
- def__init__(self,model_path:str="model/raft/raft-things.pth",
- device:str="cuda",sample_fps:int=8):
-"""Initialize the DynamicDegreeAnalyzer.
-
- Args:
- model_path: Path to the RAFT model checkpoint
- device: Device to run the model on ('cuda' or 'cpu')
- sample_fps: Target FPS for frame sampling (default: 8)
- """
- self.device=torch.device(deviceiftorch.cuda.is_available()else"cpu")
- self.sample_fps=sample_fps
-
- # Load RAFT model
- self.model=self._load_raft_model(model_path)
- self.model.eval()
- self.model.to(self.device)
-
-
- def_load_raft_model(self,model_path:str):
-"""Load RAFT optical flow model.
-
- Args:
- model_path: Path to the model checkpoint
-
- Returns:
- Loaded RAFT model
- """
- frommodel.raft.core.raftimportRAFT
- fromeasydictimportEasyDictasedict
-
- # Configure RAFT arguments
- args=edict({
- "model":model_path,
- "small":False,
- "mixed_precision":False,
- "alternate_corr":False
- })
-
- # Create and load model
- model=RAFT(args)
-
- ifos.path.exists(model_path):
- ckpt=torch.load(model_path,map_location="cpu")
- # Remove 'module.' prefix if present (from DataParallel)
- new_ckpt={k.replace('module.',''):vfork,vinckpt.items()}
- model.load_state_dict(new_ckpt)
-
- returnmodel
-
- def_extract_frames_for_flow(self,frames:List[Image.Image],target_fps:int=8)->List[torch.Tensor]:
-"""Extract and prepare frames for optical flow computation.
-
- Args:
- frames: List of PIL Image frames
- target_fps: Target sampling rate (default: 8 fps)
-
- Returns:
- List of prepared frame tensors
- """
- # Estimate original FPS and calculate sampling interval
- # Assuming 30fps original video, adjust sampling to get ~8fps
- total_frames=len(frames)
- assumed_fps=30# Common video fps
- interval=max(1,round(assumed_fps/target_fps))
-
- # Sample frames at interval
- sampled_frames=[]
- foriinrange(0,total_frames,interval):
- frame=frames[i]
- # Convert PIL to numpy array
- frame_np=np.array(frame)
- # Convert to tensor and normalize
- frame_tensor=torch.from_numpy(frame_np.astype(np.uint8)).permute(2,0,1).float()
- frame_tensor=frame_tensor[None].to(self.device)
- sampled_frames.append(frame_tensor)
-
- returnsampled_frames
-
- def_compute_flow_magnitude(self,flow:torch.Tensor)->float:
-"""Compute flow magnitude score from optical flow.
-
- Args:
- flow: Optical flow tensor (B, 2, H, W)
-
- Returns:
- Flow magnitude score
- """
- # Extract flow components
- flow_np=flow[0].permute(1,2,0).cpu().numpy()
- u=flow_np[:,:,0]
- v=flow_np[:,:,1]
-
- # Compute flow magnitude
- magnitude=np.sqrt(np.square(u)+np.square(v))
-
- # Get top 5% highest magnitudes
- h,w=magnitude.shape
- magnitude_flat=magnitude.flatten()
- cut_index=int(h*w*0.05)
-
- # Sort in descending order and take mean of top 5%
- top_magnitudes=np.sort(-magnitude_flat)[:cut_index]
- mean_magnitude=np.mean(np.abs(top_magnitudes))
-
- returnmean_magnitude.item()
-
-    def _determine_dynamic_threshold(self, frame_shape: tuple, num_frames: int) -> dict:
-        """Determine thresholds for dynamic motion detection.
-
-        Args:
-            frame_shape: Shape of the frame tensor
-            num_frames: Number of frames in the video
-
-        Returns:
-            Dictionary with threshold parameters
-        """
-        # Scale threshold based on image resolution
-        scale = min(frame_shape[-2:])  # min of height and width
-        magnitude_threshold = 6.0 * (scale / 256.0)
-
-        # Scale count threshold based on number of frames
-        count_threshold = round(4 * (num_frames / 16.0))
-
-        return {
-            "magnitude_threshold": magnitude_threshold,
-            "count_threshold": count_threshold
-        }
-
-    def _check_dynamic_motion(self, flow_scores: List[float], thresholds: dict) -> bool:
-        """Check if video has dynamic motion based on flow scores.
-
-        Args:
-            flow_scores: List of optical flow magnitude scores
-            thresholds: Threshold parameters
-
-        Returns:
-            True if video has dynamic motion, False otherwise
-        """
-        magnitude_threshold = thresholds["magnitude_threshold"]
-        count_threshold = thresholds["count_threshold"]
-
-        # Count frames with significant motion
-        motion_count = 0
-        for score in flow_scores:
-            if score > magnitude_threshold:
-                motion_count += 1
-                if motion_count >= count_threshold:
-                    return True
-
-        return False
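The two methods above amount to a simple decision rule: a frame pair counts as "moving" when its flow score exceeds a resolution-scaled threshold, and the clip is "dynamic" when enough such pairs exist (scaled by frame count). A plain-Python sketch of that rule (the function name is illustrative):

```python
def is_dynamic(flow_scores, height, width, num_frames):
    """Decide whether a clip is dynamic from per-pair optical-flow scores."""
    # 6.0 is the base magnitude threshold at 256px; 4 is the base count at 16 frames.
    magnitude_threshold = 6.0 * (min(height, width) / 256.0)
    count_threshold = round(4 * (num_frames / 16.0))
    moving = sum(1 for s in flow_scores if s > magnitude_threshold)
    return moving >= count_threshold
```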
-
-
-    def analyze(self, frames: List[Image.Image]) -> float:
-        """Analyze dynamic degree (motion intensity) in video frames.
-
-        Args:
-            frames: List of PIL Image frames representing the video
-
-        Returns:
-            Dynamic degree score: 1.0 if video has dynamic motion, 0.0 if mostly static
-        """
-        if len(frames) < 2:
-            return 0.0  # Cannot compute optical flow with less than 2 frames
-
-        # Extract and prepare frames for optical flow
-        prepared_frames = self._extract_frames_for_flow(frames, self.sample_fps)
-
-        if len(prepared_frames) < 2:
-            return 0.0
-
-        # Determine thresholds based on video characteristics
-        thresholds = self._determine_dynamic_threshold(
-            prepared_frames[0].shape,
-            len(prepared_frames)
-        )
-
-        # Compute optical flow between consecutive frames
-        flow_scores = []
-
-        with torch.no_grad():
-            for frame1, frame2 in zip(prepared_frames[:-1], prepared_frames[1:]):
-                # Pad frames if necessary
-                from model.raft.core.utils_core.utils import InputPadder
-                padder = InputPadder(frame1.shape)
-                frame1_padded, frame2_padded = padder.pad(frame1, frame2)
-
-                # Compute optical flow
-                _, flow_up = self.model(frame1_padded, frame2_padded, iters=20, test_mode=True)
-
-                # Calculate flow magnitude score
-                magnitude_score = self._compute_flow_magnitude(flow_up)
-                flow_scores.append(magnitude_score)
-
-        # Check if video has dynamic motion
-        has_dynamic_motion = self._check_dynamic_motion(flow_scores, thresholds)
-
-        # Return binary score: 1.0 for dynamic, 0.0 for static
-        return 1.0 if has_dynamic_motion else 0.0
-
-
-
-
-
-class BackgroundConsistencyAnalyzer(VideoQualityAnalyzer):
-    """Analyzer for evaluating background consistency across video frames using CLIP features.
-
-    This analyzer measures how consistently the background appears across frames by:
-    1. Extracting CLIP visual features from each frame
-    2. Computing cosine similarity between consecutive frames and with the first frame
-    3. Averaging these similarities to get a consistency score
-
-    Similar to SubjectConsistencyAnalyzer but focuses on overall visual consistency
-    including background elements, making it suitable for detecting background stability.
-    """
-
-    def __init__(self, model_name: str = "ViT-B/32", device: str = "cuda"):
-        """Initialize the BackgroundConsistencyAnalyzer.
-
-        Args:
-            model_name: CLIP model name (default: "ViT-B/32")
-            device: Device to run the model on ('cuda' or 'cpu')
-        """
-        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
-
-        # Load CLIP model
-        self.model, self.preprocess = self._load_clip_model(model_name)
-        self.model.eval()
-        self.model.to(self.device)
-
-        # Image transform for CLIP (when processing tensor inputs)
-        self.tensor_transform = self._get_clip_tensor_transform(224)
-
-    def _load_clip_model(self, model_name: str):
-        """Load CLIP model.
-
-        Args:
-            model_name: Name of the CLIP model to load
-
-        Returns:
-            Tuple of (model, preprocess_function)
-        """
-        import clip
-
-        model, preprocess = clip.load(model_name, device=self.device)
-        return model, preprocess
-
-    def _get_clip_tensor_transform(self, n_px: int):
-        """Get CLIP transform for tensor inputs.
-
-        Args:
-            n_px: Target image size
-
-        Returns:
-            Transform composition for tensor inputs
-        """
-        try:
-            from torchvision.transforms import InterpolationMode
-            BICUBIC = InterpolationMode.BICUBIC
-        except ImportError:
-            BICUBIC = Image.BICUBIC
-
-        return Compose([
-            Resize(n_px, interpolation=BICUBIC, antialias=False),
-            CenterCrop(n_px),
-            transforms.Lambda(lambda x: x.float().div(255.0)),
-            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
-        ])
-
-    def _prepare_images_for_clip(self, frames: List[Image.Image]) -> torch.Tensor:
-        """Prepare PIL images for CLIP processing.
-
-        Args:
-            frames: List of PIL Image frames
-
-        Returns:
-            Batch tensor of preprocessed images
-        """
-        # Use CLIP's built-in preprocess for PIL images
-        images = []
-        for frame in frames:
-            processed = self.preprocess(frame)
-            images.append(processed)
-
-        # Stack into batch tensor
-        return torch.stack(images).to(self.device)
-
-
-    def analyze(self, frames: List[Image.Image]) -> float:
-        """Analyze background consistency across video frames.
-
-        Args:
-            frames: List of PIL Image frames representing the video
-
-        Returns:
-            Background consistency score (higher is better, range [0, 1])
-        """
-        if len(frames) < 2:
-            return 1.0  # Single frame is perfectly consistent with itself
-
-        # Prepare images for CLIP
-        images = self._prepare_images_for_clip(frames)
-
-        # Extract CLIP features
-        with torch.no_grad():
-            image_features = self.model.encode_image(images)
-            image_features = F.normalize(image_features, dim=-1, p=2)
-
-        video_sim = 0.0
-        frame_count = 0
-
-        # Compute similarity between frames
-        for i in range(len(image_features)):
-            image_feature = image_features[i].unsqueeze(0)
-
-            if i == 0:
-                # Store first frame features
-                first_image_feature = image_feature
-            else:
-                # Compute similarity with previous frame
-                sim_prev = max(0.0, F.cosine_similarity(former_image_feature, image_feature).item())
-
-                # Compute similarity with first frame
-                sim_first = max(0.0, F.cosine_similarity(first_image_feature, image_feature).item())
-
-                # Average the two similarities
-                frame_sim = (sim_prev + sim_first) / 2.0
-                video_sim += frame_sim
-                frame_count += 1
-
-            # Store current features as previous for next iteration
-            former_image_feature = image_feature
-
-        # Return average similarity across all frame pairs
-        if frame_count > 0:
-            return video_sim / frame_count
-        else:
-            return 1.0
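The scoring loop above reduces to: for each frame after the first, average its cosine similarity to the previous frame and to the first frame (clamped at 0), then average over frames. A dependency-free sketch on plain feature vectors (names illustrative, no CLIP involved):

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def consistency_score(features):
    """features: list of per-frame feature vectors (length >= 1)."""
    if len(features) < 2:
        return 1.0
    total = 0.0
    for i in range(1, len(features)):
        sim_prev = max(0.0, cosine(features[i - 1], features[i]))
        sim_first = max(0.0, cosine(features[0], features[i]))
        total += (sim_prev + sim_first) / 2.0
    return total / (len(features) - 1)
```

Anchoring every comparison to the first frame penalizes slow drift that frame-to-frame similarity alone would miss.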
-
-
-
-
-class ImagingQualityAnalyzer(VideoQualityAnalyzer):
-    """Analyzer for evaluating imaging quality of videos.
-
-    This analyzer measures the quality of videos by:
-    1. Inputting frames into MUSIQ image quality predictor
-    2. Determining if the video is blurry or has artifacts
-
-    The score represents the quality of the video (higher is better).
-    """
-
-    def __init__(self, save_every_n_steps: int = 1, to_cpu: bool = True):
-        """Initialize the latents collector.
-
-        Args:
-            save_every_n_steps (int, optional): Save latents every n steps. Defaults to 1.
-            to_cpu (bool, optional): Whether to move latents to CPU. Defaults to True.
-        """
-        self.save_every_n_steps = save_every_n_steps
-        self.to_cpu = to_cpu
-        self.data = []
-        self._call_count = 0
-
-    def __call__(self, step: int, timestep: int, latents: torch.Tensor):
-        self._call_count += 1
-
-        if self._call_count % self.save_every_n_steps == 0:
-            latents_to_save = latents.clone()
-            if self.to_cpu:
-                latents_to_save = latents_to_save.cpu()
-
-            self.data.append({
-                'step': step,
-                'timestep': timestep,
-                'latents': latents_to_save,
-                'call_count': self._call_count
-            })
-
-    @property
-    def latents_list(self) -> List[torch.Tensor]:
-        """Return the list of latents."""
-        return [item['latents'] for item in self.data]
-
-    @property
-    def timesteps_list(self) -> List[int]:
-        """Return the list of timesteps."""
-        return [item['timestep'] for item in self.data]
-
-    def get_latents_at_step(self, step: int) -> torch.Tensor:
-        """Get the latents at a specific step."""
-        for item in self.data:
-            if item['step'] == step:
-                return item['latents']
-        raise ValueError(f"No latents found for step {step}")
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from dataclasses import dataclass
-from typing import Optional, Union, Any, Dict
-import torch
-from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline
-from utils.pipeline_utils import (
-    get_pipeline_type,
-    PIPELINE_TYPE_IMAGE,
-    PIPELINE_TYPE_TEXT_TO_VIDEO,
-    PIPELINE_TYPE_IMAGE_TO_VIDEO
-)
-
-
-@dataclass
-class DiffusionConfig:
-    """Configuration class for diffusion models and parameters."""
-
-
-def pil_to_cv2(pil_img: Image.Image) -> np.ndarray:
-    """Convert a PIL image to a numpy array with values scaled to [0, 1].
-
-    Note: channels remain in RGB order; no BGR conversion is applied here.
-    """
-    return np.asarray(pil_img) / 255.0
-
-
-
-def transform_to_model_format(media: Union[Image.Image, List[Image.Image], np.ndarray, torch.Tensor],
-                              target_size: Optional[int] = None) -> torch.Tensor:
-    """
-    Transform image or video frames to model input format.
-
-    For an image, `media` is a PIL image that is resized to `target_size` (if provided),
-    normalized to [-1, 1], and permuted from [H, W, C] to [C, H, W].
-    For a video, `media` is a list of frames (PIL images or numpy arrays) that are
-    normalized to [-1, 1] and permuted from [F, H, W, C] to [F, C, H, W].
-
-    Args:
-        media: PIL image, list of frames, or video tensor
-        target_size: Target size for resize operations (for images)
-
-    Returns:
-        torch.Tensor: Normalized tensor ready for model input
-    """
-    if isinstance(media, Image.Image):
-        # Single image
-        if target_size is not None:
-            transform = transforms.Compose([
-                transforms.Resize(target_size),
-                transforms.CenterCrop(target_size),
-                transforms.ToTensor(),
-            ])
-        else:
-            transform = transforms.ToTensor()
-        return 2.0 * transform(media) - 1.0
-
-    elif isinstance(media, list):
-        # List of frames (PIL images or numpy arrays)
-        if all(isinstance(frame, Image.Image) for frame in media):
-            return torch.stack([2.0 * transforms.ToTensor()(frame) - 1.0 for frame in media])
-        elif all(isinstance(frame, np.ndarray) for frame in media):
-            return torch.stack([2.0 * transforms.ToTensor()(numpy_to_pil(frame)) - 1.0 for frame in media])
-        else:
-            raise ValueError("All frames must be either PIL images or numpy arrays")
-
-    elif isinstance(media, np.ndarray) and media.ndim >= 3:
-        # Video numpy array
-        if media.ndim == 3:  # Single frame: H, W, C
-            return 2.0 * transforms.ToTensor()(media) - 1.0
-        elif media.ndim == 4:  # Multiple frames: F, H, W, C
-            return torch.stack([2.0 * transforms.ToTensor()(frame) - 1.0 for frame in media])
-        else:
-            raise ValueError(f"Unsupported numpy array shape: {media.shape}")
-
-    else:
-        raise ValueError(f"Unsupported media type: {type(media)}")
-
-# ===== Image-Specific Functions =====
-
-def _get_image_latents(pipe: StableDiffusionPipeline, image: torch.Tensor,
-                       sample: bool = True, rng_generator: Optional[torch.Generator] = None,
-                       decoder_inv: bool = False) -> torch.Tensor:
-    """Get the image latents for the given image."""
-    encoding_dist = pipe.vae.encode(image).latent_dist
-    if sample:
-        encoding = encoding_dist.sample(generator=rng_generator)
-    else:
-        encoding = encoding_dist.mode()
-    latents = encoding * 0.18215
-    if decoder_inv:
-        latents = decoder_inv_optimization(pipe, latents, image)
-    return latents
-
-def _decode_image_latents(pipe: StableDiffusionPipeline, latents: torch.FloatTensor) -> torch.Tensor:
-    """Decode the image from the given latents."""
-    scaled_latents = 1 / 0.18215 * latents
-    image = pipe.vae.decode(scaled_latents, return_dict=False)[0]
-    image = (image / 2 + 0.5).clamp(0, 1)
-    return image
-
-
-def decoder_inv_optimization(pipe: StableDiffusionPipeline, latents: torch.FloatTensor,
-                             image: torch.FloatTensor, num_steps: int = 100) -> torch.Tensor:
-    """
-    Optimize latents to better reconstruct the input image by minimizing the error between
-    the decoded latents and the original image.
-
-    Args:
-        pipe: The diffusion pipeline
-        latents: Initial latents
-        image: Target image
-        num_steps: Number of optimization steps
-
-    Returns:
-        torch.Tensor: Optimized latents
-    """
-    input_image = image.clone().float()
-    z = latents.clone().float().detach()
-    z.requires_grad_(True)
-
-    loss_function = torch.nn.MSELoss(reduction='sum')
-    optimizer = torch.optim.Adam([z], lr=0.1)
-    lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=num_steps)
-
-    for i in tqdm(range(num_steps)):
-        # Decode without normalization to match original implementation
-        scaled_latents = 1 / 0.18215 * z
-        x_pred = pipe.vae.decode(scaled_latents, return_dict=False)[0]
-
-        loss = loss_function(x_pred, input_image)
-
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-        lr_scheduler.step()
-
-    return z.detach()
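Decoder inversion above is plain gradient descent on the latent to minimize reconstruction error. A toy 1-D version of the same idea, with a scalar function standing in for the VAE decoder and finite-difference gradients instead of autograd (all names illustrative):

```python
def invert_by_gradient_descent(decode, target, z0, lr=0.1, steps=200):
    """Find z minimizing (decode(z) - target)^2 by finite-difference gradient descent."""
    z, eps = z0, 1e-5
    for _ in range(steps):
        loss = (decode(z) - target) ** 2
        # Forward-difference estimate of d(loss)/dz
        grad = ((decode(z + eps) - target) ** 2 - loss) / eps
        z -= lr * grad
    return z
```

The real implementation differs only in scale: `z` is a latent tensor, `decode` is the VAE decoder, and Adam with a cosine schedule replaces the fixed step size.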
-
-
-# ===== Video-Specific Functions =====
-
-def _get_video_latents(pipe: Union[TextToVideoSDPipeline, StableVideoDiffusionPipeline],
-                       video_frames: torch.Tensor, sample: bool = True,
-                       rng_generator: Optional[torch.Generator] = None,
-                       permute: bool = True,
-                       decoder_inv: bool = False) -> torch.Tensor:
-    """
-    Encode video frames to latents.
-
-    Args:
-        pipe: Video diffusion pipeline
-        video_frames: Tensor of video frames [F, C, H, W]
-        sample: Whether to sample from the latent distribution
-        rng_generator: Random generator for sampling
-        permute: Whether to permute the latents to [B, C, F, H, W] format
-        decoder_inv: Whether to apply decoder inversion (not yet supported for video)
-
-    Returns:
-        torch.Tensor: Video latents
-    """
-    encoding_dist = pipe.vae.encode(video_frames).latent_dist
-    if sample:
-        encoding = encoding_dist.sample(generator=rng_generator)
-    else:
-        encoding = encoding_dist.mode()
-    latents = (encoding * 0.18215).unsqueeze(0)
-    if permute:
-        latents = latents.permute(0, 2, 1, 3, 4)
-    if decoder_inv:  # TODO: Implement decoder inversion for video latents
-        raise NotImplementedError("Decoder inversion is not implemented for video latents")
-    return latents
-
-
-def tensor2vid(video: torch.Tensor, processor, output_type: str = "np"):
-    """
-    Convert a video tensor to the desired output format.
-
-    Args:
-        video: Video tensor [B, C, F, H, W]
-        processor: Video processor from the diffusion pipeline
-        output_type: Output type - 'np', 'pt', or 'pil'
-
-    Returns:
-        Video in the requested format
-    """
-    batch_size, channels, num_frames, height, width = video.shape
-    outputs = []
-    for batch_idx in range(batch_size):
-        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
-        batch_output = processor.postprocess(batch_vid, output_type)
-        outputs.append(batch_output)
-
-    if output_type == "np":
-        outputs = np.stack(outputs)
-    elif output_type == "pt":
-        outputs = torch.stack(outputs)
-    elif output_type != "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
-
-    return outputs
-
-
-def _decode_video_latents(pipe: Union[TextToVideoSDPipeline, StableVideoDiffusionPipeline],
-                          latents: torch.Tensor,
-                          num_frames: Optional[int] = None) -> np.ndarray:
-    """
-    Decode latents to video frames.
-
-    Args:
-        pipe: Video diffusion pipeline
-        latents: Video latents
-        num_frames: Number of frames to decode
-
-    Returns:
-        np.ndarray: Video frames
-    """
-    if num_frames is None:
-        video_tensor = pipe.decode_latents(latents)
-    else:
-        video_tensor = pipe.decode_latents(latents, num_frames)
-    video = tensor2vid(video_tensor, pipe.video_processor)
-    return video
-
-
-def convert_video_frames_to_images(frames: List[Union[np.ndarray, Image.Image]]) -> List[Image.Image]:
-    """
-    Convert video frames to a list of PIL.Image objects.
-
-    Args:
-        frames: List of video frames (numpy arrays or PIL images)
-
-    Returns:
-        List[Image.Image]: List of PIL images
-    """
-    pil_frames = []
-    for frame in frames:
-        if isinstance(frame, np.ndarray):
-            # Convert numpy array to PIL
-            pil_frames.append(numpy_to_pil(frame))
-        elif isinstance(frame, Image.Image):
-            # Already a PIL image
-            pil_frames.append(frame)
-        else:
-            raise ValueError(f"Unsupported frame type: {type(frame)}")
-    return pil_frames
-
-
-def save_video_frames(frames: List[Union[np.ndarray, Image.Image]], save_dir: str) -> None:
-    """
-    Save video frames to a directory.
-
-    Args:
-        frames: List of video frames (numpy arrays or PIL images)
-        save_dir: Directory to save frames
-    """
-    if isinstance(frames[0], np.ndarray):
-        frames = [(frame * 255).astype(np.uint8) if frame.dtype != np.uint8 else frame for frame in frames]
-    elif isinstance(frames[0], Image.Image):
-        frames = [np.array(frame) for frame in frames]
-
-    for i, frame in enumerate(frames):
-        img = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-        cv2.imwrite(f'{save_dir}/{i:02d}.png', img)
-
-
-# ===== Utility Functions for Working with Different Pipeline Types =====
-
-
-def get_media_latents(pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline],
-                      media: Union[torch.Tensor, List[torch.Tensor]],
-                      sample: bool = True,
-                      rng_generator: Optional[torch.Generator] = None,
-                      decoder_inv: bool = False) -> torch.Tensor:
-    """
-    Get latents from media (either image or video) based on pipeline type.
-
-    Args:
-        pipe: Diffusion pipeline
-        media: Image tensor or video frames tensor
-        sample: Whether to sample from the latent distribution
-        rng_generator: Random generator for sampling
-        decoder_inv: Whether to use decoder inversion optimization
-
-    Returns:
-        torch.Tensor: Media latents
-    """
-    pipeline_type = get_pipeline_type(pipe)
-
-    if pipeline_type == PIPELINE_TYPE_IMAGE:
-        return _get_image_latents(pipe, media, sample, rng_generator, decoder_inv)
-    elif pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]:
-        permute = pipeline_type == PIPELINE_TYPE_TEXT_TO_VIDEO
-        return _get_video_latents(pipe, media, sample, rng_generator, permute, decoder_inv)
-    else:
-        raise ValueError(f"Unsupported pipeline type: {pipeline_type}")
-
-
-def decode_media_latents(pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline],
-                         latents: torch.Tensor,
-                         num_frames: Optional[int] = None) -> Union[torch.Tensor, np.ndarray]:
-    """
-    Decode latents to media (either image or video) based on pipeline type.
-
-    Args:
-        pipe: Diffusion pipeline
-        latents: Media latents
-        num_frames: Number of frames (for video)
-
-    Returns:
-        Union[torch.Tensor, np.ndarray]: Decoded media
-    """
-    pipeline_type = get_pipeline_type(pipe)
-
-    if pipeline_type == PIPELINE_TYPE_IMAGE:
-        return _decode_image_latents(pipe, latents)
-    elif pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]:
-        return _decode_video_latents(pipe, latents, num_frames)
-    else:
-        raise ValueError(f"Unsupported pipeline type: {pipeline_type}")
-def get_pipeline_type(pipeline) -> Optional[str]:
-    """
-    Determine the type of diffusion pipeline.
-
-    Args:
-        pipeline: The diffusion pipeline object
-
-    Returns:
-        str: One of the pipeline type constants, or None if not recognized
-    """
-    if isinstance(pipeline, StableDiffusionPipeline):
-        return PIPELINE_TYPE_IMAGE
-    elif isinstance(pipeline, TextToVideoSDPipeline):
-        return PIPELINE_TYPE_TEXT_TO_VIDEO
-    elif isinstance(pipeline, StableVideoDiffusionPipeline):
-        return PIPELINE_TYPE_IMAGE_TO_VIDEO
-    else:
-        return None
-
-
-def is_video_pipeline(pipeline) -> bool:
-    """
-    Check if the pipeline is a video generation pipeline.
-
-    Args:
-        pipeline: The diffusion pipeline object
-
-    Returns:
-        bool: True if the pipeline is a video generation pipeline, False otherwise
-    """
-    pipeline_type = get_pipeline_type(pipeline)
-    return pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]
-
-
-def is_image_pipeline(pipeline) -> bool:
-    """
-    Check if the pipeline is an image generation pipeline.
-
-    Args:
-        pipeline: The diffusion pipeline object
-
-    Returns:
-        bool: True if the pipeline is an image generation pipeline, False otherwise
-    """
-    return get_pipeline_type(pipeline) == PIPELINE_TYPE_IMAGE
-
-
-def is_t2v_pipeline(pipeline) -> bool:
-    """
-    Check if the pipeline is a text-to-video pipeline.
-
-    Args:
-        pipeline: The diffusion pipeline object
-
-    Returns:
-        bool: True if the pipeline is a text-to-video pipeline, False otherwise
-    """
-    return get_pipeline_type(pipeline) == PIPELINE_TYPE_TEXT_TO_VIDEO
-
-
-def is_i2v_pipeline(pipeline) -> bool:
-    """
-    Check if the pipeline is an image-to-video pipeline.
-
-    Args:
-        pipeline: The diffusion pipeline object
-
-    Returns:
-        bool: True if the pipeline is an image-to-video pipeline, False otherwise
-    """
-    return get_pipeline_type(pipeline) == PIPELINE_TYPE_IMAGE_TO_VIDEO
-
-
-def get_pipeline_requirements(pipeline_type: str) -> Dict[str, Any]:
-    """
-    Get the requirements for a specific pipeline type (required parameters, etc.)
-
-    Args:
-        pipeline_type: The pipeline type string
-
-    Returns:
-        Dict: A dictionary containing the pipeline requirements
-    """
-    if pipeline_type == PIPELINE_TYPE_IMAGE:
-        return {
-            "required_params": [],
-            "optional_params": ["height", "width", "num_images_per_prompt"]
-        }
-    elif pipeline_type == PIPELINE_TYPE_TEXT_TO_VIDEO:
-        return {
-            "required_params": ["num_frames"],
-            "optional_params": ["height", "width", "fps"]
-        }
-    elif pipeline_type == PIPELINE_TYPE_IMAGE_TO_VIDEO:
-        return {
-            "required_params": ["input_image", "num_frames"],
-            "optional_params": ["height", "width", "fps"]
-        }
-    else:
-        return {"required_params": [], "optional_params": []}
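The pattern above — map a pipeline's class to a type constant with `isinstance` checks, then branch on that constant — can be sketched with plain classes (all names here are illustrative, not the toolkit's API):

```python
PIPELINE_TYPE_IMAGE = "image"
PIPELINE_TYPE_TEXT_TO_VIDEO = "t2v"

class ImagePipe:
    pass

class T2VPipe:
    pass

def pipeline_type(pipe):
    """Map a pipeline instance to its type constant, or None if unrecognized."""
    if isinstance(pipe, ImagePipe):
        return PIPELINE_TYPE_IMAGE
    if isinstance(pipe, T2VPipe):
        return PIPELINE_TYPE_TEXT_TO_VIDEO
    return None

def required_params(ptype):
    """Video pipelines additionally require the frame count."""
    return ["num_frames"] if ptype == PIPELINE_TYPE_TEXT_TO_VIDEO else []
```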
-def inherit_docstring(cls):
-    """
-    Inherit docstrings from base classes to methods without docstrings.
-
-    This decorator automatically applies the docstring from a base class method
-    to a derived class method if the derived method doesn't have its own docstring.
-
-    Args:
-        cls: The class to enhance with inherited docstrings
-
-    Returns:
-        cls: The enhanced class
-    """
-    for name, func in vars(cls).items():
-        if not callable(func) or func.__doc__ is not None:
-            continue
-
-        # Look for the same method in base classes
-        for base in cls.__bases__:
-            base_func = getattr(base, name, None)
-            if base_func and getattr(base_func, "__doc__", None):
-                func.__doc__ = base_func.__doc__
-                break
-
-    return cls
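A quick self-contained demonstration of the decorator's effect (the `Base`/`Child` classes are illustrative):

```python
def inherit_docstring(cls):
    """Copy base-class docstrings onto undocumented overriding methods."""
    for name, func in vars(cls).items():
        if not callable(func) or func.__doc__ is not None:
            continue
        for base in cls.__bases__:
            base_func = getattr(base, name, None)
            if base_func and getattr(base_func, "__doc__", None):
                func.__doc__ = base_func.__doc__
                break
    return cls

class Base:
    def run(self):
        """Run the task."""

@inherit_docstring
class Child(Base):
    def run(self):  # no docstring of its own; inherits Base.run's
        pass
```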
-
-
-
-
-def load_config_file(path: str) -> dict:
-    """Load a JSON configuration file from the specified path and return it as a dictionary."""
-    try:
-        with open(path, 'r') as f:
-            config_dict = json.load(f)
-        return config_dict
-
-    except FileNotFoundError:
-        print(f"Error: The file '{path}' does not exist.")
-        return None
-    except json.JSONDecodeError as e:
-        print(f"Error decoding JSON in '{path}': {e}")
-        return None
-    except Exception as e:
-        print(f"An unexpected error occurred: {e}")
-        return None
-
-
-def load_json_as_list(input_file: str) -> list:
-    """Load a JSON Lines file (one JSON object per line) as a list of dictionaries."""
-    res = []
-    with open(input_file, 'r') as f:
-        for line in f:
-            d = json.loads(line)
-            res.append(d)
-    return res
-
-
-def create_directory_for_file(file_path) -> None:
-    """Create the directory for the specified file path if it does not already exist."""
-    directory = os.path.dirname(file_path)
-    if not os.path.exists(directory):
-        os.makedirs(directory)
-
-
-def set_random_seed(seed: int):
-    """Set random seeds for reproducibility."""
-    torch.manual_seed(seed + 0)
-    torch.cuda.manual_seed(seed + 1)
-    torch.cuda.manual_seed_all(seed + 2)
-    np.random.seed((seed + 3) % 2**32)
-    torch.cuda.manual_seed_all(seed + 4)
-    random.seed(seed + 5)
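Note that `load_json_as_list` expects JSON Lines input (one object per line), not a single JSON array. A minimal standard-library round-trip of that format (function name illustrative):

```python
import json

def parse_jsonl(text: str) -> list:
    """Parse JSON Lines text into a list of objects, skipping blank lines."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]
```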
-class AutoVisualizer:
-    """
-    Factory class for creating visualizer instances.
-
-    This is a generic visualization factory that instantiates the appropriate
-    visualizer class based on the algorithm name.
-
-    This class cannot be instantiated directly using __init__() (doing so raises an error).
-    """
-
-    def __init__(self):
-        raise EnvironmentError(
-            "AutoVisualizer is designed to be instantiated "
-            "using the `AutoVisualizer.load(algorithm_name, **kwargs)` method."
-        )
-
-    @staticmethod
-    def _get_visualization_class_name(algorithm_name: str) -> Optional[str]:
-        """Get the visualization data class name from the algorithm name."""
-        for alg_name, class_path in VISUALIZATION_DATA_MAPPING.items():
-            if algorithm_name.lower() == alg_name.lower():
-                return class_path
-        return None
-
-    @classmethod
-    def load(cls, algorithm_name: str, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1) -> BaseVisualizer:
-        """
-        Load the visualizer instance based on the algorithm name.
-
-        Args:
-            algorithm_name: Name of the watermarking algorithm (e.g., 'TR', 'GS', 'PRC')
-            data_for_visualization: DataForVisualization instance
-            dpi: Resolution of the rendered figures
-            watermarking_step: The step at which the watermark is inserted (-1 for the last step)
-
-        Returns:
-            BaseVisualizer: Instance of the appropriate visualizer class
-
-        Raises:
-            ValueError: If the algorithm name is not supported
-        """
-        # Check if the algorithm exists
-        class_path = cls._get_visualization_class_name(algorithm_name)
-
-        if algorithm_name != data_for_visualization.algorithm_name:
-            raise ValueError(f"Algorithm name mismatch: {algorithm_name} != {data_for_visualization.algorithm_name}")
-
-        if class_path is None:
-            supported_algs = list(VISUALIZATION_DATA_MAPPING.keys())
-            raise ValueError(
-                f"Invalid algorithm name: {algorithm_name}. "
-                f"Supported algorithms: {', '.join(supported_algs)}"
-            )
-
-        # Load the visualization module and class
-        module_name, class_name = class_path.rsplit('.', 1)
-        try:
-            module = importlib.import_module(module_name)
-            visualization_class = getattr(module, class_name)
-        except (ImportError, AttributeError) as e:
-            raise ImportError(
-                f"Failed to load visualization data class '{class_name}' "
-                f"from module '{module_name}': {e}"
-            )
-
-        # Create and return the instance
-        instance = visualization_class(data_for_visualization=data_for_visualization, dpi=dpi, watermarking_step=watermarking_step)
-        return instance
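The dynamic loading inside `load` follows the standard `importlib` pattern: split a dotted path into module and class name, import the module, then `getattr` the class. A minimal sketch using a stdlib class as the target (the mapping and names are illustrative, not the toolkit's real `VISUALIZATION_DATA_MAPPING`):

```python
import importlib

# algorithm name -> dotted class path (illustrative mapping)
MAPPING = {"counter": "collections.Counter"}

def load_class(name: str):
    """Resolve a registered name to its class via importlib."""
    module_name, class_name = MAPPING[name].rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
```

Keeping dotted paths in the mapping (rather than importing every visualizer eagerly) defers heavy imports until an algorithm is actually requested.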
-# Copyright 2025 THU-BPM MarkDiffusion.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from abc import ABC, abstractmethod
-from typing import Optional, Dict, Any, List, Tuple
-import torch
-from PIL import Image
-from visualize.data_for_visualization import DataForVisualization
-import matplotlib.pyplot as plt
-from matplotlib.axes import Axes
-import numpy as np
-from numpy.fft import fft2, fftshift, ifft2, ifftshift
-from matplotlib.gridspec import GridSpecFromSubplotSpec
-
-
-class BaseVisualizer(ABC):
-    """Base class for watermark visualizers."""
-
-    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1, is_video: bool = False):
-        """Initialize with common attributes."""
-        self.data = data_for_visualization
-        self.dpi = dpi
-        self.watermarking_step = watermarking_step  # The step at which the watermark is inserted; -1 means the last step
-        self.is_video = is_video  # Whether this is for a T2V (video) or T2I (image) model
-
-    def _fft_transform(self, latent: torch.Tensor) -> np.ndarray:
-        """Apply a centered 2-D FFT to the latent tensor of the watermarked image."""
-        return fftshift(fft2(latent.cpu().numpy()))
-
-    def _ifft_transform(self, fft_data: np.ndarray) -> np.ndarray:
-        """Apply the inverse of `_fft_transform` to the FFT data."""
-        return ifft2(ifftshift(fft_data))
-
-    def _get_latent_data(self, latents: torch.Tensor, channel: Optional[int] = None, frame: Optional[int] = None) -> torch.Tensor:
-        """
-        Extract latent data with proper indexing for both T2I and T2V models.
-
-        Parameters:
-            latents: The latent tensor, [B, C, H, W] for T2I or [B, C, F, H, W] for T2V
-            channel: The channel index to extract
-            frame: The frame index to extract (only for T2V models)
-
-        Returns:
-            The extracted latent tensor
-        """
-        if self.is_video:
-            # T2V model: [B, C, F, H, W]
-            if frame is not None:
-                if channel is not None:
-                    return latents[0, channel, frame]  # [H, W]
-                else:
-                    return latents[0, :, frame]  # [C, H, W]
-            else:
-                # If no frame is specified, use the middle frame
-                mid_frame = latents.shape[2] // 2
-                if channel is not None:
-                    return latents[0, channel, mid_frame]  # [H, W]
-                else:
-                    return latents[0, :, mid_frame]  # [C, H, W]
-        else:
-            # T2I model: [B, C, H, W]
-            if channel is not None:
-                return latents[0, channel]  # [H, W]
-            else:
-                return latents[0]  # [C, H, W]
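The `fftshift(fft2(...))` / `ifft2(ifftshift(...))` pair used by `_fft_transform` and `_ifft_transform` centers the zero-frequency component for display and inverts exactly. A quick NumPy round-trip check:

```python
import numpy as np
from numpy.fft import fft2, fftshift, ifft2, ifftshift

latent = np.arange(16.0).reshape(4, 4)
centered = fftshift(fft2(latent))       # zero frequency moved to the array center
recovered = ifft2(ifftshift(centered))  # exact inverse, up to float error
```

Centering matters for ring-shaped Fourier-domain watermark patterns, which are defined relative to the spectrum's center.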
-
-
-    def draw_orig_latents(self,
-                          channel: Optional[int] = None,
-                          frame: Optional[int] = None,
-                          title: str = "Original Latents",
-                          cmap: str = "viridis",
-                          use_color_bar: bool = True,
-                          vmin: Optional[float] = None,
-                          vmax: Optional[float] = None,
-                          ax: Optional[Axes] = None,
-                          **kwargs) -> Axes:
-        """
-        Draw the original latents of the watermarked image.
-
-        Parameters:
-            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
-            frame (Optional[int]): The frame index for T2V models. If None, uses the middle frame for videos.
-            title (str): The title of the plot.
-            cmap (str): The colormap to use.
-            use_color_bar (bool): Whether to display the colorbar.
-            vmin (Optional[float]): Lower bound of the color scale.
-            vmax (Optional[float]): Upper bound of the color scale.
-            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.
-
-        Returns:
-            Axes: The plotted axes.
-        """
-        if ax is None:
-            ax = plt.gca()  # Fall back to the current axes when none is provided
-        if channel is not None:
-            # Single-channel visualization
-            latent_data = self._get_latent_data(self.data.orig_watermarked_latents, channel, frame).cpu().numpy()
-            im = ax.imshow(latent_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
-            if title != "":
-                ax.set_title(title)
-            if use_color_bar:
-                ax.figure.colorbar(im, ax=ax)
-            ax.axis('off')
-        else:
-            # Multi-channel visualization
-            num_channels = 4
-            rows = 2
-            cols = 2
-
-            # Clear the axis and set the title
-            ax.clear()
-            if title != "":
-                ax.set_title(title, pad=20)
-            ax.axis('off')
-
-            # Use gridspec for better control
-            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
-                                         wspace=0.3, hspace=0.4)
-
-            # Create subplots for each channel
-            for i in range(num_channels):
-                row_idx = i // cols
-                col_idx = i % cols
-
-                # Create a subplot using gridspec
-                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])
-
-                # Draw the latent channel
-                latent_data = self._get_latent_data(self.data.orig_watermarked_latents, i, frame).cpu().numpy()
-                im = sub_ax.imshow(latent_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
-                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
-                sub_ax.axis('off')
-                # Add a small colorbar for each subplot
-                if use_color_bar:
-                    cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
-                    cbar.ax.tick_params(labelsize=6)
-
-        return ax
-
-
-
-    def draw_orig_latents_fft(self,
-                              channel: Optional[int] = None,
-                              frame: Optional[int] = None,
-                              title: str = "Original Latents in Fourier Domain",
-                              cmap: str = "viridis",
-                              use_color_bar: bool = True,
-                              vmin: Optional[float] = None,
-                              vmax: Optional[float] = None,
-                              ax: Optional[Axes] = None,
-                              **kwargs) -> Axes:
-        """
-        Draw the original latents of the watermarked image in the Fourier domain.
-
-        Parameters:
-            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
-            frame (Optional[int]): The frame index for T2V models. If None, uses the middle frame for videos.
-            title (str): The title of the plot.
-            cmap (str): The colormap to use.
-            use_color_bar (bool): Whether to display the colorbar.
-            vmin (Optional[float]): Lower bound of the color scale.
-            vmax (Optional[float]): Upper bound of the color scale.
-            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.
-
-        Returns:
-            Axes: The plotted axes.
-        """
-        if ax is None:
-            ax = plt.gca()  # Fall back to the current axes when none is provided
-        if channel is not None:
-            # Single-channel visualization
-            latent_data = self._get_latent_data(self.data.orig_watermarked_latents, channel, frame)
-            fft_data = self._fft_transform(latent_data)
-
-            im = ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
-            if title != "":
-                ax.set_title(title)
-            if use_color_bar:
-                ax.figure.colorbar(im, ax=ax)
-            ax.axis('off')
-        else:
-            # Multi-channel visualization
-            num_channels = 4
-            rows = 2
-            cols = 2
-
-            # Clear the axis and set the title
-            ax.clear()
-            if title != "":
-                ax.set_title(title, pad=20)
-            ax.axis('off')
-
-            # Use gridspec for better control
-            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
-                                         wspace=0.3, hspace=0.4)
-
-            # Create subplots for each channel
-            for i in range(num_channels):
-                row_idx = i // cols
-                col_idx = i % cols
-
-                # Create a subplot using gridspec
-                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])
-
-                # Draw the FFT of the latent channel
-                latent_data = self._get_latent_data(self.data.orig_watermarked_latents, i, frame)
-                fft_data = self._fft_transform(latent_data)
-                im = sub_ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
-                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
-                sub_ax.axis('off')
-                # Add a small colorbar for each subplot
-                if use_color_bar:
-                    cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
-                    cbar.ax.tick_params(labelsize=6)
-
-        return ax
-
-
-
    def draw_inverted_latents(self,
                              channel: Optional[int] = None,
                              frame: Optional[int] = None,
                              step: Optional[int] = None,
                              title: str = "Inverted Latents",
                              cmap: str = "viridis",
                              use_color_bar: bool = True,
                              vmin: Optional[float] = None,
                              vmax: Optional[float] = None,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the inverted latents of the watermarked image.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses the middle frame for videos.
            step (Optional[int]): The timestep of the inverted latents. If None, the visualizer's watermarking step is used.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if channel is not None:
            # Single-channel visualization
            # Get the inverted latents data
            if step is None:
                reversed_latents = self.data.reversed_latents[self.watermarking_step]
            else:
                reversed_latents = self.data.reversed_latents[step]

            latent_data = self._get_latent_data(reversed_latents, channel, frame).cpu().numpy()
            im = ax.imshow(latent_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
            if title != "":
                ax.set_title(title)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax)
            ax.axis('off')
        else:
            # Multi-channel visualization
            num_channels = 4
            rows = 2
            cols = 2

            # Clear the axis and set the title
            ax.clear()
            if title != "":
                ax.set_title(title, pad=20)
            ax.axis('off')

            # Use gridspec for finer layout control
            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                         wspace=0.3, hspace=0.4)

            # Create a subplot for each channel
            for i in range(num_channels):
                row_idx = i // cols
                col_idx = i % cols

                # Create the subplot using gridspec
                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

                # Get the inverted latents data
                if step is None:
                    reversed_latents = self.data.reversed_latents[self.watermarking_step]
                else:
                    reversed_latents = self.data.reversed_latents[step]

                latent_data = self._get_latent_data(reversed_latents, i, frame).cpu().numpy()

                # Draw the latent channel
                im = sub_ax.imshow(latent_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
                sub_ax.axis('off')
                # Add a small colorbar for each subplot
                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                    cbar.ax.tick_params(labelsize=6)

        return ax

    def draw_inverted_latents_fft(self,
                                  channel: Optional[int] = None,
                                  frame: Optional[int] = None,
                                  step: int = -1,
                                  title: str = "Inverted Latents in Fourier Domain",
                                  cmap: str = "viridis",
                                  use_color_bar: bool = True,
                                  vmin: Optional[float] = None,
                                  vmax: Optional[float] = None,
                                  ax: Optional[Axes] = None,
                                  **kwargs) -> Axes:
        """
        Draw the inverted latents of the watermarked image in the Fourier domain.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses the middle frame for videos.
            step (int): The index into the inverted-latent trajectory (default: -1, the last timestep).
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if channel is not None:
            # Single-channel visualization
            reversed_latents = self.data.reversed_latents[step]
            latent_data = self._get_latent_data(reversed_latents, channel, frame)
            fft_data = self._fft_transform(latent_data)

            im = ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
            if title != "":
                ax.set_title(title)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax)
            ax.axis('off')
        else:
            # Multi-channel visualization
            num_channels = 4
            rows = 2
            cols = 2

            # Clear the axis and set the title
            ax.clear()
            if title != "":
                ax.set_title(title, pad=20)
            ax.axis('off')

            # Use gridspec for finer layout control
            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                         wspace=0.3, hspace=0.4)

            # Create a subplot for each channel
            for i in range(num_channels):
                row_idx = i // cols
                col_idx = i % cols

                # Create the subplot using gridspec
                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

                # Draw the FFT of the inverted latent channel
                reversed_latents = self.data.reversed_latents[step]
                latent_data = self._get_latent_data(reversed_latents, i, frame)
                fft_data = self._fft_transform(latent_data)
                im = sub_ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
                sub_ax.axis('off')
                # Add a small colorbar for each subplot
                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                    cbar.ax.tick_params(labelsize=6)

        return ax

    def draw_diff_latents_fft(self,
                              channel: Optional[int] = None,
                              frame: Optional[int] = None,
                              title: str = "Difference between Original and Inverted Latents in Fourier Domain",
                              cmap: str = "coolwarm",
                              use_color_bar: bool = True,
                              vmin: Optional[float] = None,
                              vmax: Optional[float] = None,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the difference between the original and inverted initial latents of the watermarked image in the Fourier domain.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses the middle frame for videos.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if channel is not None:
            # Single-channel visualization
            # Get the original and inverted latents
            orig_data = self._get_latent_data(self.data.orig_watermarked_latents, channel, frame).cpu().numpy()

            reversed_latents = self.data.reversed_latents[self.watermarking_step]
            inv_data = self._get_latent_data(reversed_latents, channel, frame).cpu().numpy()

            # Compute the difference
            diff_data = orig_data - inv_data

            # Convert to a tensor for the FFT transform
            diff_tensor = torch.tensor(diff_data)
            fft_data = self._fft_transform(diff_tensor)

            im = ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
            if title != "":
                ax.set_title(title)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax)
            ax.axis('off')
        else:
            # Multi-channel visualization
            num_channels = 4
            rows = 2
            cols = 2

            # Clear the axis and set the title
            ax.clear()
            if title != "":
                ax.set_title(title, pad=20)
            ax.axis('off')

            # Use gridspec for finer layout control
            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                         wspace=0.3, hspace=0.4)

            # Create a subplot for each channel
            for i in range(num_channels):
                row_idx = i // cols
                col_idx = i % cols

                # Create the subplot using gridspec
                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

                # Get the original and inverted latents
                orig_data = self._get_latent_data(self.data.orig_watermarked_latents, i, frame).cpu().numpy()

                reversed_latents = self.data.reversed_latents[self.watermarking_step]
                inv_data = self._get_latent_data(reversed_latents, i, frame).cpu().numpy()

                # Compute the difference and its FFT
                diff_data = orig_data - inv_data
                diff_tensor = torch.tensor(diff_data)
                fft_data = self._fft_transform(diff_tensor)

                # Draw the FFT of the difference
                im = sub_ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
                sub_ax.axis('off')
                # Add a small colorbar for each subplot
                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                    cbar.ax.tick_params(labelsize=6)

        return ax

    def draw_watermarked_image(self,
                               title: str = "Watermarked Image",
                               num_frames: int = 4,
                               vmin: Optional[float] = None,
                               vmax: Optional[float] = None,
                               ax: Optional[Axes] = None,
                               **kwargs) -> Axes:
        """
        Draw the watermarked image or video frames.

        For images (is_video=False), displays a single image.
        For videos (is_video=True), displays a grid of video frames.

        Parameters:
            title (str): The title of the plot.
            num_frames (int): Number of frames to display for videos (default: 4).
            vmin (Optional[float]): Minimum value for the colormap.
            vmax (Optional[float]): Maximum value for the colormap.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if self.is_video:
            # Video visualization: display multiple frames
            return self._draw_video_frames(title=title, num_frames=num_frames, ax=ax, **kwargs)
        else:
            # Image visualization: display a single image
            return self._draw_single_image(title=title, vmin=vmin, vmax=vmax, ax=ax, **kwargs)

    def _draw_single_image(self,
                           title: str = "Watermarked Image",
                           vmin: Optional[float] = None,
                           vmax: Optional[float] = None,
                           ax: Optional[Axes] = None,
                           **kwargs) -> Axes:
        """
        Draw a single watermarked image.

        Parameters:
            title (str): The title of the plot.
            vmin (Optional[float]): Minimum value for the colormap.
            vmax (Optional[float]): Maximum value for the colormap.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        # Convert the image data to a numpy array
        if torch.is_tensor(self.data.image):
            # Handle tensor format (as in the RI watermark)
            if self.data.image.dim() == 4:  # [B, C, H, W]
                image_array = self.data.image[0].permute(1, 2, 0).cpu().numpy()
            elif self.data.image.dim() == 3:  # [C, H, W]
                image_array = self.data.image.permute(1, 2, 0).cpu().numpy()
            else:
                image_array = self.data.image.cpu().numpy()

            # Normalize to [0, 1] if needed
            if image_array.max() > 1.0:
                image_array = image_array / 255.0

            # Normalize a [-1, 1] range to [0, 1] for imshow
            if image_array.min() < 0:
                image_array = (image_array + 1.0) / 2.0

            # Clip to the valid range
            image_array = np.clip(image_array, 0, 1)
        else:
            # Handle PIL Image format
            image_array = np.array(self.data.image)

        im = ax.imshow(image_array, vmin=vmin, vmax=vmax, **kwargs)
        if title != "":
            ax.set_title(title, fontsize=12)
        ax.axis('off')

        # Invisible colorbar so the layout matches plots that do show one
        cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
        cbar.ax.set_visible(False)

        return ax

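The pixel-range handling in `_draw_single_image` (0-255, [-1, 1], or already [0, 1]) can be exercised in isolation. A minimal sketch; `normalize_for_imshow` is a hypothetical helper name mirroring the method's logic:

```python
import numpy as np

def normalize_for_imshow(image_array: np.ndarray) -> np.ndarray:
    """Bring pixel data into [0, 1], as _draw_single_image does for tensor inputs."""
    image_array = image_array.astype(np.float32)
    if image_array.max() > 1.0:       # assume an 8-bit 0-255 range
        image_array = image_array / 255.0
    if image_array.min() < 0:         # assume a [-1, 1] diffusion-output range
        image_array = (image_array + 1.0) / 2.0
    return np.clip(image_array, 0.0, 1.0)
```

Note the order matters: the 0-255 check runs first, so a [-1, 1] input (max <= 1) is only shifted and rescaled, never divided by 255.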
    def _draw_video_frames(self,
                           title: str = "Watermarked Video Frames",
                           num_frames: int = 4,
                           ax: Optional[Axes] = None,
                           **kwargs) -> Axes:
        """
        Draw multiple frames from the watermarked video.

        This method displays a grid of video frames to show the temporal
        consistency of the watermarked video.

        Parameters:
            title (str): The title of the plot.
            num_frames (int): Number of frames to display (default: 4).
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if not hasattr(self.data, 'video_frames') or self.data.video_frames is None:
            raise ValueError("No video frames available for visualization. Please ensure video_frames is provided in data_for_visualization.")

        video_frames = self.data.video_frames
        total_frames = len(video_frames)

        # Limit num_frames to the available frames
        num_frames = min(num_frames, total_frames)

        # Calculate which frames to show (evenly distributed)
        if num_frames == 1:
            frame_indices = [total_frames // 2]  # Middle frame
        else:
            frame_indices = [int(i * (total_frames - 1) / (num_frames - 1)) for i in range(num_frames)]

        # Calculate the grid layout
        rows = int(np.ceil(np.sqrt(num_frames)))
        cols = int(np.ceil(num_frames / rows))

        # Clear the axis and set the title
        ax.clear()
        if title != "":
            ax.set_title(title, pad=20, fontsize=12)
        ax.axis('off')

        # Use gridspec for finer layout control
        gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                     wspace=0.1, hspace=0.4)

        # Create a subplot for each frame
        for i, frame_idx in enumerate(frame_indices):
            row_idx = i // cols
            col_idx = i % cols

            # Create the subplot using gridspec
            sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

            # Get the frame
            frame = video_frames[frame_idx]

            # Convert the frame to a displayable format
            try:
                # First, convert tensors to numpy if needed
                if hasattr(frame, 'cpu'):  # PyTorch tensor
                    frame = frame.cpu().numpy()
                elif hasattr(frame, 'numpy'):  # Other tensor types
                    frame = frame.numpy()
                elif hasattr(frame, 'convert'):  # PIL Image
                    frame = np.array(frame)

                # Handle channels-first format: (C, H, W) -> (H, W, C) for numpy arrays
                if isinstance(frame, np.ndarray) and len(frame.shape) == 3:
                    if frame.shape[0] in [1, 3, 4]:  # Channels first
                        frame = np.transpose(frame, (1, 2, 0))

                # Ensure a proper data type for matplotlib
                if isinstance(frame, np.ndarray):
                    if frame.dtype == np.float64:
                        frame = frame.astype(np.float32)
                    elif frame.dtype not in [np.uint8, np.float32]:
                        # Convert to float32 and normalize if needed
                        frame = frame.astype(np.float32)
                        if frame.max() > 1.0:
                            frame = frame / 255.0

                    # Normalize a [-1, 1] range to [0, 1] for imshow
                    if frame.min() < 0:
                        frame = (frame + 1.0) / 2.0

                    # Clip to the valid range [0, 1]
                    frame = np.clip(frame, 0, 1)

                sub_ax.imshow(frame)

            except Exception as e:
                print(f"Error displaying frame {frame_idx}: {e}")

            sub_ax.set_title(f'Frame {frame_idx}', fontsize=10, pad=5)
            sub_ax.axis('off')

        # Hide unused subplots
        for i in range(num_frames, rows * cols):
            row_idx = i // cols
            col_idx = i % cols
            if row_idx < rows and col_idx < cols:
                empty_ax = ax.figure.add_subplot(gs[row_idx, col_idx])
                empty_ax.axis('off')

        return ax

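The frame-selection logic above (one middle frame, or `num_frames` indices spread evenly across the clip) is easy to check standalone. A sketch; `pick_frame_indices` is a hypothetical helper name extracted from `_draw_video_frames`:

```python
def pick_frame_indices(total_frames: int, num_frames: int) -> list:
    """Evenly distribute num_frames indices across [0, total_frames - 1]."""
    num_frames = min(num_frames, total_frames)
    if num_frames == 1:
        return [total_frames // 2]  # middle frame only
    # linear interpolation between the first and last frame index
    return [int(i * (total_frames - 1) / (num_frames - 1)) for i in range(num_frames)]
```

The endpoints are always included when `num_frames > 1`, so the first and last frames of the video are shown.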
    def visualize(self,
                  rows: int,
                  cols: int,
                  methods: List[str],
                  figsize: Optional[Tuple[int, int]] = None,
                  method_kwargs: Optional[List[Dict[str, Any]]] = None,
                  save_path: Optional[str] = None) -> plt.Figure:
        """
        Comprehensive visualization of the watermark analysis.

        Parameters:
            rows (int): The number of rows of subplots.
            cols (int): The number of columns of subplots.
            methods (List[str]): List of drawing methods to call.
            figsize (Tuple[int, int]): The size of the figure.
            method_kwargs (Optional[List[Dict[str, Any]]]): List of keyword arguments for each method.
            save_path (Optional[str]): The path to save the figure.

        Returns:
            plt.Figure: The matplotlib figure object.
        """
        # Check that rows and cols are compatible with the number of methods
        if len(methods) != rows * cols:
            raise ValueError(f"The number of methods ({len(methods)}) is not compatible with the layout ({rows}x{cols})")

        # Initialize the figure size if not provided
        if figsize is None:
            figsize = (cols * 5, rows * 5)

        # Create the figure and subplots
        fig, axes = plt.subplots(rows, cols, figsize=figsize)

        # Ensure axes is always a 2D array for consistent indexing
        if rows == 1 and cols == 1:
            axes = np.array([[axes]])
        elif rows == 1:
            axes = axes.reshape(1, -1)
        elif cols == 1:
            axes = axes.reshape(-1, 1)

        if method_kwargs is None:
            method_kwargs = [{} for _ in methods]

        # Plot each method
        for i, method_name in enumerate(methods):
            row = i // cols
            col = i % cols
            ax = axes[row, col]

            try:
                method = getattr(self, method_name)
            except AttributeError:
                raise ValueError(f"Method '{method_name}' not found in {self.__class__.__name__}")

            try:
                method(ax=ax, **method_kwargs[i])
            except TypeError:
                raise ValueError(f"Method '{method_name}' does not accept the given arguments: {method_kwargs[i]}")

        # If the number of methods is less than the number of axes, hide the unused axes
        for i in range(len(methods), rows * cols):
            row = i // cols
            col = i % cols
            axes[row, col].axis('off')

        plt.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)

        if save_path is not None:
            plt.savefig(save_path, bbox_inches='tight', dpi=self.dpi)

        return fig
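The flat-index-to-grid mapping and the layout check used by `visualize` can be sketched on their own. `grid_positions` is a hypothetical helper name; the arithmetic mirrors the method body:

```python
def grid_positions(n_methods: int, rows: int, cols: int) -> list:
    """Map each flat method index onto (row, col) grid coordinates, as visualize() does."""
    if n_methods != rows * cols:
        raise ValueError(f"{n_methods} methods do not fit a {rows}x{cols} layout")
    return [(i // cols, i % cols) for i in range(n_methods)]
```

Indices are filled row by row, which is why `visualize` requires `len(methods) == rows * cols` up front.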
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional
import torch
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpec
import numpy as np
from visualize.base import BaseVisualizer
from visualize.data_for_visualization import DataForVisualization


class ROBINVisualizer(BaseVisualizer):  # class header restored; name inferred from the ROBIN docstrings below
    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        super().__init__(data_for_visualization, dpi, watermarking_step)
        # ROBIN uses a specific watermarking step
        if hasattr(self.data, 'watermarking_step'):
            self.watermarking_step = self.data.watermarking_step
        else:
            raise ValueError("watermarking_step is required for ROBIN visualization")

    def draw_pattern_fft(self,
                         title: Optional[str] = None,
                         cmap: str = "viridis",
                         use_color_bar: bool = True,
                         vmin: Optional[float] = None,
                         vmax: Optional[float] = None,
                         ax: Optional[Axes] = None,
                         **kwargs) -> Axes:
        """
        Draw the FFT visualization of the original watermark pattern, on an all-zero background.

        Parameters:
            title (Optional[str]): The title of the plot. If None, includes watermarking step info.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        # Use a custom title with the watermarking step if not provided
        if title is None:
            title = f"ROBIN FFT with Watermark Area (Step {self.watermarking_step})"

        orig_latent = self.data.optimized_watermark[0, self.data.w_channel].cpu()
        watermarking_mask = self.data.watermarking_mask[0, self.data.w_channel].cpu()

        fft_data = torch.from_numpy(self._fft_transform(orig_latent))
        fft_vis = torch.zeros_like(fft_data)
        fft_vis[watermarking_mask] = fft_data[watermarking_mask]

        im = ax.imshow(np.abs(fft_vis.cpu().numpy()), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        if title != "":
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(im, ax=ax)
        ax.axis('off')

        return ax

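The masking step above (keep FFT coefficients only inside the watermarking area, zero elsewhere) can be reproduced with plain numpy. A sketch under the assumption that `_fft_transform` is a centered 2-D FFT, the usual convention for frequency-domain watermarks; `masked_fft` is a hypothetical name:

```python
import numpy as np

def masked_fft(latent: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Centered 2-D FFT magnitude, zeroed outside the watermarking mask."""
    fft_data = np.fft.fftshift(np.fft.fft2(latent))  # DC component moved to the center
    fft_vis = np.zeros_like(fft_data)
    fft_vis[mask] = fft_data[mask]                   # keep only the watermark region
    return np.abs(fft_vis)
```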
    def draw_inverted_pattern_fft(self,
                                  step: Optional[int] = None,
                                  title: Optional[str] = None,
                                  cmap: str = "viridis",
                                  use_color_bar: bool = True,
                                  vmin: Optional[float] = None,
                                  vmax: Optional[float] = None,
                                  ax: Optional[Axes] = None,
                                  **kwargs) -> Axes:
        """
        Draw the FFT visualization of the inverted pattern, on an all-zero background.

        Parameters:
            step (Optional[int]): The timestep of the inverted latents. If None, uses ROBIN's specific step.
            title (Optional[str]): The title of the plot. If None, includes watermarking step info.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        # For ROBIN, we need to use the specific watermarking step
        if step is None:
            # Calculate the actual step index for ROBIN
            # ROBIN uses: num_steps_to_use - 1 - self.config.watermarking_step
            num_steps = len(self.data.reversed_latents)
            actual_step = num_steps - 1 - self.watermarking_step
            inverted_latent = self.data.reversed_latents[actual_step][0, self.data.w_channel]
        else:
            inverted_latent = self.data.reversed_latents[step][0, self.data.w_channel]

        # Use a custom title with the watermarking step if not provided
        if title is None:
            title = f"ROBIN FFT with Inverted Watermark Area (Step {self.watermarking_step})"

        watermarking_mask = self.data.watermarking_mask[0, self.data.w_channel].cpu()

        fft_data = torch.from_numpy(self._fft_transform(inverted_latent))
        fft_vis = torch.zeros_like(fft_data)
        fft_vis[watermarking_mask] = fft_data[watermarking_mask]

        im = ax.imshow(np.abs(fft_vis.cpu().numpy()), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        if title != "":
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(im, ax=ax)
        ax.axis('off')

        return ax

    def draw_optimized_watermark(self,
                                 title: Optional[str] = None,
                                 cmap: str = "viridis",
                                 use_color_bar: bool = True,
                                 vmin: Optional[float] = None,
                                 vmax: Optional[float] = None,
                                 ax: Optional[Axes] = None,
                                 **kwargs) -> Axes:
        """
        Draw the optimized watermark pattern (ROBIN-specific).

        Parameters:
            title (Optional[str]): The title of the plot. If None, includes watermarking step info.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        # Use a custom title with the watermarking step if not provided
        if title is None:
            title = f"ROBIN Optimized Watermark (Step {self.watermarking_step})"

        optimized_watermark = self.data.optimized_watermark[0, self.data.w_channel].cpu()

        im = ax.imshow(np.abs(optimized_watermark.numpy()), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        if title != "":
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(im, ax=ax)
        ax.axis('off')

        return ax
Source code for visualize.videomark.video_mark_visualizer
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpecFromSubplotSpec
import numpy as np
from typing import Optional
from visualize.base import BaseVisualizer
from visualize.data_for_visualization import DataForVisualization


class VideoMarkVisualizer(BaseVisualizer):
    """VideoMark watermark visualization class.

    This visualizer handles watermark visualization for the VideoMark algorithm,
    which embeds PRC codewords into the initial latents of video frames.

    Key Members for VideoMarkVisualizer:
        - self.data.orig_watermarked_latents: [B, C, F, H, W]
        - self.data.reversed_latents: List[[B, C, F, H, W]]
    """

    def draw_watermarked_video_frames(self,
                                      num_frames: int = 4,
                                      title: str = "Watermarked Video Frames",
                                      ax: Optional[Axes] = None) -> Axes:
        """
        Draw multiple frames from the watermarked video.

        DEPRECATED:
            This method is deprecated and will be removed in a future version.
            Please use `draw_watermarked_image` instead.

        This method displays a grid of video frames to show the temporal
        consistency of the watermarked video.

        Args:
            num_frames: Number of frames to display (default: 4)
            title: The title of the plot
            ax: The axes to plot on

        Returns:
            The plotted axes
        """
        return self._draw_video_frames(
            title=title,
            num_frames=num_frames,
            ax=ax
        )

    def draw_generator_matrix(self,
                              title: str = "Generator Matrix G",
                              cmap: str = "Blues",
                              use_color_bar: bool = True,
                              max_display_size: int = 50,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the generator matrix visualization.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            max_display_size (int): Maximum size to display (for large matrices)
            ax (Axes): The axes to plot on

        Returns:
            Axes: The plotted axes
        """
        if hasattr(self.data, 'generator_matrix') and self.data.generator_matrix is not None:
            gen_matrix = self.data.generator_matrix.cpu().numpy()

            # Show a sample of the matrix if it is too large
            if gen_matrix.shape[0] > max_display_size or gen_matrix.shape[1] > max_display_size:
                sample_size = min(max_display_size, min(gen_matrix.shape))
                matrix_sample = gen_matrix[:sample_size, :sample_size]
                title += f" (Sample {sample_size}x{sample_size})"
            else:
                matrix_sample = gen_matrix

            im = ax.imshow(matrix_sample, cmap=cmap, aspect='auto', **kwargs)

            if use_color_bar:
                plt.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'Generator Matrix\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=10)
        ax.set_xlabel('Columns')
        ax.set_ylabel('Rows')
        return ax

    def draw_codeword(self,
                      title: str = "VideoMark Codeword",
                      cmap: str = "viridis",
                      use_color_bar: bool = True,
                      ax: Optional[Axes] = None,
                      **kwargs) -> Axes:
        """
        Draw the PRC codeword visualization.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            ax (Axes): The axes to plot on

        Returns:
            Axes: The plotted axes
        """
        if hasattr(self.data, 'prc_codeword') and self.data.prc_codeword is not None:
            codeword = self.data.prc_codeword[0].cpu().numpy()  # Get the first-frame codeword for visualization

            # If 1D, reshape for visualization
            if len(codeword.shape) == 1:
                # Create a reasonable 2D shape
                length = len(codeword)
                height = int(np.sqrt(length))
                width = length // height
                if height * width < length:
                    width += 1
                # Pad if necessary
                padded_codeword = np.zeros(height * width)
                padded_codeword[:length] = codeword
                codeword = padded_codeword.reshape(height, width)

            im = ax.imshow(codeword, cmap=cmap, aspect='equal', **kwargs)

            if use_color_bar:
                plt.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'PRC Codeword\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=12)
        return ax
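The 1-D-to-2-D reshaping used in `draw_codeword` (fold the codeword into a near-square grid, zero-padding the tail) can be checked on its own. A sketch; `reshape_codeword` is a hypothetical helper name with the same arithmetic:

```python
import numpy as np

def reshape_codeword(codeword: np.ndarray) -> np.ndarray:
    """Pad a 1-D codeword and fold it into a near-square 2-D array for display."""
    length = len(codeword)
    height = int(np.sqrt(length))
    width = length // height
    if height * width < length:   # widen by one column if the grid is too small
        width += 1
    padded = np.zeros(height * width)
    padded[:length] = codeword
    return padded.reshape(height, width)
```

For a codeword of length 10 this yields a 3x4 grid with two zero-padded cells; a perfect square length (e.g. 16) folds without padding.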
Source code for visualize.videoshield.video_shield_visualizer
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpecFromSubplotSpec
import numpy as np
from typing import Optional
from visualize.base import BaseVisualizer
from visualize.data_for_visualization import DataForVisualization
from Crypto.Cipher import ChaCha20


class VideoShieldVisualizer(BaseVisualizer):
    """VideoShield watermark visualization class.

    This visualizer handles watermark visualization for the VideoShield algorithm,
    which extends Gaussian Shading to the video domain by adding frame dimensions.

    Key Members for VideoShieldVisualizer:
        - self.data.orig_watermarked_latents: [B, C, F, H, W]
        - self.data.reversed_latents: List[[B, C, F, H, W]]
    """

    def draw_watermark_bits(self,
                            channel: Optional[int] = None,
                            frame: Optional[int] = None,
                            title: str = "Original Watermark Bits",
                            cmap: str = "binary",
                            ax: Optional[Axes] = None) -> Axes:
        """Draw the original watermark bits for VideoShield.

        For video watermarks, this method can visualize specific frames or average
        across frames to create a 2D visualization.

        Args:
            channel: The channel to visualize. If None, all channels are shown.
            frame: The frame to visualize. If None, uses the middle frame for videos.
            title: The title of the plot.
            cmap: The colormap to use.
            ax: The axes to plot on.

        Returns:
            The plotted axes.
        """
        # Reshape the watermark to video dimensions based on the repetition factors
        # VideoShield watermark shape: [1, C//k_c, F//k_f, H//k_h, W//k_w]
        ch_stride = 4 // self.data.k_c
        frame_stride = self.data.num_frames // self.data.k_f
        h_stride = self.data.latents_height // self.data.k_h
        w_stride = self.data.latents_width // self.data.k_w

        watermark = self.data.watermark.reshape(1, ch_stride, frame_stride, h_stride, w_stride)

        if channel is not None:
            # Single-channel visualization
            if channel >= ch_stride:
                raise ValueError(f"Channel {channel} is out of range. Max channel: {ch_stride - 1}")

            # Select a specific frame or use the middle frame
            if frame is not None:
                if frame >= frame_stride:
                    raise ValueError(f"Frame {frame} is out of range. Max frame: {frame_stride - 1}")
                watermark_data = watermark[0, channel, frame].cpu().numpy()
                frame_info = f" - Frame {frame}"
            else:
                # Use the middle frame
                mid_frame = frame_stride // 2
                watermark_data = watermark[0, channel, mid_frame].cpu().numpy()
                frame_info = f" - Frame {mid_frame} (middle)"

            im = ax.imshow(watermark_data, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
            if title != "":
                ax.set_title(f"{title} - Channel {channel}{frame_info}", fontsize=10)
            ax.axis('off')

            cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
            cbar.ax.set_visible(False)
        else:
            # Multi-channel visualization
            num_channels = ch_stride

            # Calculate the grid layout
            rows = int(np.ceil(np.sqrt(num_channels)))
            cols = int(np.ceil(num_channels / rows))

            # Clear the axis and set the title
            ax.clear()
            if title != "":
                if frame is not None:
                    ax.set_title(f"{title} - Frame {frame}", pad=20, fontsize=10)
                else:
                    mid_frame = frame_stride // 2
                    ax.set_title(f"{title} - Frame {mid_frame} (middle)", pad=20, fontsize=10)
            ax.axis('off')

            # Use gridspec for finer layout control
            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                         wspace=0.3, hspace=0.4)

            # Create the subplots
            for i in range(num_channels):
                row_idx = i // cols
                col_idx = i % cols

                # Create the subplot using gridspec
                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

                # Select a specific frame or use the middle frame
                if frame is not None:
                    if frame >= frame_stride:
                        raise ValueError(f"Frame {frame} is out of range. Max frame: {frame_stride - 1}")
                    watermark_data = watermark[0, i, frame].cpu().numpy()
                else:
                    mid_frame = frame_stride // 2
                    watermark_data = watermark[0, i, mid_frame].cpu().numpy()

                # Draw the watermark channel
                sub_ax.imshow(watermark_data, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
                sub_ax.axis('off')

        return ax

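The repetition strides above fix the shape of the de-repeated watermark template. A standalone sketch of that shape computation; `watermark_shape` is a hypothetical helper name, and the argument values in the usage note are illustrative:

```python
def watermark_shape(num_frames: int, latents_height: int, latents_width: int,
                    k_c: int, k_f: int, k_h: int, k_w: int, channels: int = 4) -> tuple:
    """Shape of the VideoShield watermark template after removing the repetition factors:
    [1, C//k_c, F//k_f, H//k_h, W//k_w]."""
    return (1, channels // k_c, num_frames // k_f,
            latents_height // k_h, latents_width // k_w)
```

For example, a 16-frame video with 64x64 latents and repetition factors `k_c=1, k_f=8, k_h=8, k_w=8` gives a `(1, 4, 2, 8, 8)` template.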
    def draw_reconstructed_watermark_bits(self,
                                          channel: Optional[int] = None,
                                          frame: Optional[int] = None,
                                          title: str = "Reconstructed Watermark Bits",
                                          cmap: str = "binary",
                                          ax: Optional[Axes] = None) -> Axes:
        """Draw the reconstructed watermark bits for VideoShield.

        Args:
            channel: The channel to visualize. If None, all channels are shown.
            frame: The frame to visualize. If None, uses the middle frame for videos.
            title: The title of the plot.
            cmap: The colormap to use.
            ax: The axes to plot on.

        Returns:
            The plotted axes.
        """
        # Step 1: Get the reversed latents and reconstruct the watermark bits
        reversed_latent = self.data.reversed_latents[self.watermarking_step]

        # Convert to binary bits
        reversed_m = (reversed_latent > 0).int()

        # Decrypt
        reversed_sd_flat = self._stream_key_decrypt(reversed_m.flatten().cpu().numpy())
        # Reshape back to a video tensor
        reversed_sd = torch.from_numpy(reversed_sd_flat).reshape(reversed_latent.shape).to(torch.uint8)

        # Extract the watermark through the voting mechanism
        reversed_watermark = self._diffusion_inverse(reversed_sd.cuda())

        # Calculate the bit accuracy
        bit_acc = (reversed_watermark == self.data.watermark).float().mean().item()

        # Reshape to video dimensions for visualization
        ch_stride = 4 // self.data.k_c
        frame_stride = self.data.num_frames // self.data.k_f
        h_stride = self.data.latents_height // self.data.k_h
        w_stride = self.data.latents_width // self.data.k_w

        reconstructed_watermark = reversed_watermark.reshape(1, ch_stride, frame_stride, h_stride, w_stride)

        if channel is not None:
            # Single-channel visualization
            if channel >= ch_stride:
                raise ValueError(f"Channel {channel} is out of range. Max channel: {ch_stride - 1}")

            # Select a specific frame or use the middle frame
            if frame is not None:
                if frame >= frame_stride:
                    raise ValueError(f"Frame {frame} is out of range. Max frame: {frame_stride - 1}")
                reconstructed_watermark_data = reconstructed_watermark[0, channel, frame].cpu().numpy()
                frame_info = f" - Frame {frame}"
            else:
                # Use the middle frame
                mid_frame = frame_stride // 2
                reconstructed_watermark_data = reconstructed_watermark[0, channel, mid_frame].cpu().numpy()
                frame_info = f" - Frame {mid_frame} (middle)"

            im = ax.imshow(reconstructed_watermark_data, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
            if title != "":
                ax.set_title(f"{title} - Channel {channel}{frame_info} (Bit Acc: {bit_acc:.3f})", fontsize=10)
            else:
                ax.set_title(f"Channel {channel}{frame_info} (Bit Acc: {bit_acc:.3f})", fontsize=10)
            ax.axis('off')
            cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
            cbar.ax.set_visible(False)
        else:
            # Multi-channel visualization
            num_channels = ch_stride

            # Calculate the grid layout
            rows = int(np.ceil(np.sqrt(num_channels)))
            cols = int(np.ceil(num_channels / rows))

            # Clear the axis and set the title with the bit accuracy
            ax.clear()
            if title != "":
                if frame is not None:
                    ax.set_title(f'{title} - Frame {frame} (Bit Acc: {bit_acc:.3f})', pad=20, fontsize=10)
                else:
                    mid_frame = frame_stride // 2
                    ax.set_title(f'{title} - Frame {mid_frame} (middle) (Bit Acc: {bit_acc:.3f})', pad=20, fontsize=10)
            else:
                if frame is not None:
                    ax.set_title(f'Frame {frame} (Bit Acc: {bit_acc:.3f})', pad=20, fontsize=10)
                else:
                    mid_frame = frame_stride // 2
                    ax.set_title(f'Frame {mid_frame} (middle) (Bit Acc: {bit_acc:.3f})', pad=20, fontsize=10)
            ax.axis('off')

            # Use gridspec for finer layout control
            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                         wspace=0.3, hspace=0.4)

            # Create the subplots
            for i in range(num_channels):
                row_idx = i // cols
                col_idx = i % cols

                # Create the subplot using gridspec
                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

                # Select a specific frame or use the middle frame
                if frame is not None:
                    if frame >= frame_stride:
                        raise ValueError(f"Frame {frame} is out of range. Max frame: {frame_stride - 1}")
                    reconstructed_watermark_data = reconstructed_watermark[0, i, frame].cpu().numpy()
                else:
                    mid_frame = frame_stride // 2
                    reconstructed_watermark_data = reconstructed_watermark[0, i, mid_frame].cpu().numpy()

                # Draw the reconstructed watermark channel
                sub_ax.imshow(reconstructed_watermark_data, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
- sub_ax.axis('off')
-
- returnax
    def draw_watermarked_video_frames(self,
                                      num_frames: int = 4,
                                      title: str = "Watermarked Video Frames",
                                      ax: Optional[Axes] = None) -> Axes:
        """
        Draw multiple frames from the watermarked video.

        DEPRECATED:
            This method is deprecated and will be removed in a future version.
            Please use `draw_watermarked_image` instead.

        This method displays a grid of video frames to show the temporal
        consistency of the watermarked video.

        Args:
            num_frames: Number of frames to display (default: 4)
            title: The title of the plot
            ax: The axes to plot on

        Returns:
            The plotted axes
        """
        return self._draw_video_frames(
            title=title,
            num_frames=num_frames,
            ax=ax
        )

class AutoWatermark:
    """
    This is a generic watermark class that will be instantiated as one of the watermark classes of the library
    when created with the [`AutoWatermark.load`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoWatermark is designed to be instantiated "
            "using the `AutoWatermark.load(algorithm_name, algorithm_config, diffusion_config)` method."
        )

    @staticmethod
    def _check_pipeline_compatibility(pipeline_type: str, algorithm_name: str) -> bool:
        """Check if the pipeline type is compatible with the watermarking algorithm."""
        if pipeline_type is None:
            return False

        if algorithm_name not in WATERMARK_MAPPING_NAMES:
            return False

        return algorithm_name in PIPELINE_SUPPORTED_WATERMARKS.get(pipeline_type, [])

    @classmethod
    def load(cls, algorithm_name, algorithm_config=None, diffusion_config=None, *args, **kwargs) -> BaseWatermark:
        """Load the watermark algorithm instance based on the algorithm name."""
        # Check if the algorithm exists
        watermark_name = watermark_name_from_alg_name(algorithm_name)
        if watermark_name is None:
            supported_algs = list(WATERMARK_MAPPING_NAMES.keys())
            raise ValueError(f"Invalid algorithm name: {algorithm_name}. Please use one of the supported algorithms: {', '.join(supported_algs)}")

        # Check pipeline compatibility
        if diffusion_config and diffusion_config.pipe:
            pipeline_type = get_pipeline_type(diffusion_config.pipe)
            if not cls._check_pipeline_compatibility(pipeline_type, algorithm_name):
                supported_algs = PIPELINE_SUPPORTED_WATERMARKS.get(pipeline_type, [])
                raise ValueError(
                    f"The algorithm '{algorithm_name}' is not compatible with the {pipeline_type} pipeline type. "
                    f"Supported algorithms for this pipeline type are: {', '.join(supported_algs)}"
                )

        # Load the watermark module
        module_name, class_name = watermark_name.rsplit('.', 1)
        module = importlib.import_module(module_name)
        watermark_class = getattr(module, class_name)
        watermark_config = AutoConfig.load(algorithm_name, diffusion_config, algorithm_config_path=algorithm_config, **kwargs)
        watermark_instance = watermark_class(watermark_config)
        return watermark_instance
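The dynamic loading in `load` (dotted path, then module, then class) is the standard `importlib` pattern. A minimal sketch of that pattern, using a stdlib class in place of a real watermark class (in MarkDiffusion the dotted path would come from `watermark_name_from_alg_name`):

```python
import importlib

# Same rsplit + import_module + getattr pattern AutoWatermark.load uses,
# demonstrated on a stdlib class rather than an actual watermark class.
watermark_name = "collections.OrderedDict"

module_name, class_name = watermark_name.rsplit('.', 1)
module = importlib.import_module(module_name)   # imports "collections"
watermark_class = getattr(module, class_name)   # fetches OrderedDict

instance = watermark_class(a=1)
```

Keeping the dotted path as data (rather than hard-coded imports) is what lets new algorithms register themselves via `WATERMARK_MAPPING_NAMES` without touching this loader.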
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
import torch
from typing import Dict, List, Union, Optional, Any, Tuple
from utils.diffusion_config import DiffusionConfig
from utils.utils import load_config_file, set_random_seed
from utils.media_utils import *
from utils.pipeline_utils import (
    get_pipeline_type,
    is_image_pipeline,
    is_video_pipeline,
    is_t2v_pipeline,
    is_i2v_pipeline,
    PIPELINE_TYPE_IMAGE,
    PIPELINE_TYPE_TEXT_TO_VIDEO,
    PIPELINE_TYPE_IMAGE_TO_VIDEO
)
from PIL import Image
from diffusers import (
    StableDiffusionPipeline,
    TextToVideoSDPipeline,
    StableVideoDiffusionPipeline,
    DDIMInverseScheduler
)

class BaseConfig(ABC):
    """Base configuration class for diffusion watermarking methods."""

    def __init__(self, algorithm_config: str, diffusion_config: DiffusionConfig, *args, **kwargs) -> None:
        """Initialize base configuration with common parameters."""

        # Load config file (algorithm_name is a property, so no call parentheses)
        self.config_dict = load_config_file(f'config/{self.algorithm_name}.json') if algorithm_config is None else load_config_file(algorithm_config)

        # Diffusion model parameters
        if diffusion_config is None:
            raise ValueError("diffusion_config cannot be None for BaseConfig initialization")

        if kwargs:
            self.config_dict.update(kwargs)

        self.pipe = diffusion_config.pipe
        self.scheduler = diffusion_config.scheduler
        self.device = diffusion_config.device
        self.guidance_scale = diffusion_config.guidance_scale
        self.num_images = diffusion_config.num_images
        self.num_inference_steps = diffusion_config.num_inference_steps
        self.num_inversion_steps = diffusion_config.num_inversion_steps
        self.image_size = diffusion_config.image_size
        self.dtype = diffusion_config.dtype
        self.gen_seed = diffusion_config.gen_seed
        self.init_latents_seed = diffusion_config.init_latents_seed
        self.inversion_type = diffusion_config.inversion_type
        self.num_frames = diffusion_config.num_frames

        # Set inversion module
        self.inversion = set_inversion(self.pipe, self.inversion_type)
        # Set generation kwargs
        self.gen_kwargs = diffusion_config.gen_kwargs

        # Get initial latents
        init_latents_rng = torch.Generator(device=self.device)
        init_latents_rng.manual_seed(self.init_latents_seed)
        if self.num_frames < 1:
            self.init_latents = get_random_latents(self.pipe, height=self.image_size[0], width=self.image_size[1], generator=init_latents_rng)
        else:
            self.init_latents = get_random_latents(self.pipe, num_frames=self.num_frames, height=self.image_size[0], width=self.image_size[1], generator=init_latents_rng)

        # Initialize algorithm-specific parameters
        self.initialize_parameters()

    @abstractmethod
    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters. Should be overridden by subclasses."""
        raise NotImplementedError

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        raise NotImplementedError


class BaseWatermark(ABC):
    """Base class for diffusion watermarking methods."""

    def _detect_pipeline_type(self) -> str:
        """Detect the type of pipeline being used."""
        pipeline_type = get_pipeline_type(self.config.pipe)
        if pipeline_type is None:
            raise ValueError(f"Unsupported pipeline type: {type(self.config.pipe)}")
        return pipeline_type

    def _validate_pipeline_config(self) -> None:
        """Validate that the pipeline configuration is correct for the pipeline type."""
        # For video (T2V and I2V) pipelines, ensure num_frames is set correctly
        if self.pipeline_type == PIPELINE_TYPE_IMAGE_TO_VIDEO or self.pipeline_type == PIPELINE_TYPE_TEXT_TO_VIDEO:
            if self.config.num_frames < 1:
                raise ValueError(f"For {self.pipeline_type} pipelines, num_frames must be >= 1, got {self.config.num_frames}")
        # For image pipelines, ensure num_frames is -1
        elif self.pipeline_type == PIPELINE_TYPE_IMAGE:
            if self.config.num_frames >= 1:
                raise ValueError(f"For {self.pipeline_type} pipelines, num_frames should be -1, got {self.config.num_frames}")

    def get_orig_watermarked_latents(self) -> torch.Tensor:
        """Get the original watermarked latents."""
        return self.orig_watermarked_latents

    def set_orig_watermarked_latents(self, value: torch.Tensor) -> None:
        """Set the original watermarked latents."""
        self.orig_watermarked_latents = value

    def generate_watermarked_media(self,
                                   input_data: Union[str, Image.Image],
                                   *args,
                                   **kwargs) -> Union[Image.Image, List[Image.Image]]:
        """
        Generate watermarked media (image or video) based on pipeline type.

        This is the main interface for generating watermarked content with any
        watermarking algorithm. It automatically routes to the appropriate generation
        method based on the pipeline type (image or video).

        Args:
            input_data: Text prompt (for T2I or T2V) or input image (for I2V)
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments, including:
                - guidance_scale: Guidance scale for generation
                - num_inference_steps: Number of inference steps
                - height, width: Dimensions of generated media
                - seed: Random seed for generation

        Returns:
            Union[Image.Image, List[Image.Image]]: Generated watermarked media
                - For image pipelines: a single PIL Image
                - For video pipelines: a list of PIL Images (frames)

        Examples:
            ```python
            # Image watermarking
            watermark = AutoWatermark.load('TR', diffusion_config=config)
            image = watermark.generate_watermarked_media(
                input_data="A beautiful landscape",
                guidance_scale=7.5,
                num_inference_steps=50
            )

            # Video watermarking (T2V)
            watermark = AutoWatermark.load('VideoShield', diffusion_config=config)
            frames = watermark.generate_watermarked_media(
                input_data="A dog running in a park",
                num_frames=16
            )

            # Video watermarking (I2V)
            watermark = AutoWatermark.load('VideoShield', diffusion_config=config)
            frames = watermark.generate_watermarked_media(
                input_data=reference_image,
                num_frames=16
            )
            ```
        """
        # Route to the appropriate generation method based on pipeline type
        if is_image_pipeline(self.config.pipe):
            if not isinstance(input_data, str):
                raise ValueError("For image generation, input_data must be a text prompt (string)")
            return self._generate_watermarked_image(input_data, *args, **kwargs)
        elif is_video_pipeline(self.config.pipe):
            return self._generate_watermarked_video(input_data, *args, **kwargs)
        else:
            raise ValueError(f"Unsupported pipeline type: {type(self.config.pipe)}")

    def generate_unwatermarked_media(self,
                                     input_data: Union[str, Image.Image],
                                     *args,
                                     **kwargs) -> Union[Image.Image, List[Image.Image]]:
        """
        Generate unwatermarked media (image or video) based on pipeline type.

        Args:
            input_data: Text prompt (for T2I or T2V) or input image (for I2V)
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments, including:
                - save_path: Path to save the generated media

        Returns:
            Union[Image.Image, List[Image.Image]]: Generated unwatermarked media
        """
        # Route to the appropriate generation method based on pipeline type
        if is_image_pipeline(self.config.pipe):
            if not isinstance(input_data, str):
                raise ValueError("For image generation, input_data must be a text prompt (string)")
            return self._generate_unwatermarked_image(input_data, *args, **kwargs)
        elif is_video_pipeline(self.config.pipe):
            return self._generate_unwatermarked_video(input_data, *args, **kwargs)
        else:
            raise ValueError(f"Unsupported pipeline type: {type(self.config.pipe)}")

    def detect_watermark_in_media(self,
                                  media: Union[Image.Image, List[Image.Image], np.ndarray, torch.Tensor],
                                  *args,
                                  **kwargs) -> Dict[str, Any]:
        """
        Detect watermark in media (image or video).

        Args:
            media: The media to detect the watermark in (PIL image, list of frames, numpy array, or tensor)
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments, including:
                - prompt: Optional text prompt used to generate the media (for some algorithms)
                - num_inference_steps: Optional number of inference steps
                - guidance_scale: Optional guidance scale
                - num_frames: Optional number of frames
                - decoder_inv: Optional decoder inversion
                - inv_order: Inverse order for Exact Inversion
                - detector_type: Type of detector to use

        Returns:
            Dict[str, Any]: Detection results with metrics and possibly visualizations
        """
        # Process the input media into the right format based on pipeline type
        processed_media = self._preprocess_media_for_detection(media)

        # Route to the appropriate detection method
        if is_image_pipeline(self.config.pipe):
            return self._detect_watermark_in_image(
                processed_media,
                *args,
                **kwargs
            )
        else:
            return self._detect_watermark_in_video(
                processed_media,
                *args,
                **kwargs
            )

    def _preprocess_media_for_detection(self,
                                        media: Union[Image.Image, List[Image.Image], np.ndarray, torch.Tensor]
                                        ) -> Union[Image.Image, List[Image.Image], torch.Tensor]:
        """
        Preprocess media for detection based on its type and the pipeline type.

        Args:
            media: The media to preprocess

        Returns:
            Union[Image.Image, List[Image.Image], torch.Tensor]: Preprocessed media
        """
        if is_image_pipeline(self.config.pipe):
            if isinstance(media, Image.Image):
                return media
            elif isinstance(media, np.ndarray):
                return cv2_to_pil(media)
            elif isinstance(media, torch.Tensor):
                # Convert tensor to PIL image
                if media.dim() == 3:  # C, H, W
                    media = media.unsqueeze(0)  # Add batch dimension
                img_np = torch_to_numpy(media)[0]  # Take first image
                return cv2_to_pil(img_np)
            elif isinstance(media, list):  # Compatible with the detection pipeline
                return media[0]
            else:
                raise ValueError(f"Unsupported media type for image pipeline: {type(media)}")
        else:
            # Video pipeline
            if isinstance(media, list):
                # List of frames
                if all(isinstance(frame, Image.Image) for frame in media):
                    return media
                elif all(isinstance(frame, np.ndarray) for frame in media):
                    return [cv2_to_pil(frame) for frame in media]
                else:
                    raise ValueError("All frames must be either PIL images or numpy arrays")
            elif isinstance(media, np.ndarray):
                # Convert numpy video to a list of PIL images
                if media.ndim == 4:  # F, H, W, C
                    return [cv2_to_pil(frame) for frame in media]
                else:
                    raise ValueError(f"Unsupported numpy array shape for video: {media.shape}")
            elif isinstance(media, torch.Tensor):
                # Convert tensor to a list of PIL images
                if media.dim() == 5:  # B, C, F, H, W
                    video_np = torch_to_numpy(media)[0]  # Take first batch
                    return [cv2_to_pil(frame) for frame in video_np]
                elif media.dim() == 4 and media.shape[0] > 3:  # F, C, H, W (assuming F > 3)
                    frames = []
                    for i in range(media.shape[0]):
                        frame_np = torch_to_numpy(media[i].unsqueeze(0))[0]
                        frames.append(cv2_to_pil(frame_np))
                    return frames
                else:
                    raise ValueError(f"Unsupported tensor shape for video: {media.shape}")
            else:
                raise ValueError(f"Unsupported media type for video pipeline: {type(media)}")

    def _generate_watermarked_image(self,
                                    prompt: str,
                                    *args,
                                    **kwargs) -> Image.Image:
        """
        Generate a watermarked image from a text prompt.

        Parameters:
            prompt (str): The input prompt.

        Returns:
            Image.Image: The generated watermarked image.

        Raises:
            ValueError: If the pipeline doesn't support image generation.
        """
        if self.pipeline_type != PIPELINE_TYPE_IMAGE:
            raise ValueError(f"This pipeline ({self.pipeline_type}) does not support image generation. Use generate_watermarked_video instead.")

        # The implementation depends on the specific watermarking algorithm
        # and must be provided by subclasses.
        raise NotImplementedError("This method is not implemented for this watermarking algorithm.")

    def _generate_watermarked_video(self,
                                    input_data: Union[str, Image.Image],
                                    *args,
                                    **kwargs) -> Union[List[Image.Image], Image.Image]:
        """
        Generate a watermarked video from a text prompt or input image.

        Parameters:
            input_data (Union[str, Image.Image]): Either a text prompt (for T2V) or an input image (for I2V).
                - If the pipeline is T2V, input_data should be a string prompt.
                - If the pipeline is I2V, input_data should be an Image object or can be passed as kwargs['input_image'].
            kwargs:
                - 'input_image': The input image for I2V pipelines.
                - 'prompt': The text prompt for T2V pipelines.
                - 'image_path': The path to the input image for I2V pipelines.

        Returns:
            Union[List[Image.Image], Image.Image]: The generated watermarked video frames.

        Raises:
            ValueError: If the pipeline doesn't support video generation or if the input type is incompatible.
        """
        if not is_video_pipeline(self.config.pipe):
            raise ValueError(f"This pipeline ({self.pipeline_type}) does not support video generation. Use generate_watermarked_image instead.")

        # The implementation depends on the specific watermarking algorithm
        # and must be provided by subclasses.
        raise NotImplementedError("This method is not implemented for this watermarking algorithm.")

    def _generate_unwatermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """
        Generate an unwatermarked image from a text prompt.

        Parameters:
            prompt (str): The input prompt.

        Returns:
            Image.Image: The generated unwatermarked image.

        Raises:
            ValueError: If the pipeline doesn't support image generation.
        """
        if not is_image_pipeline(self.config.pipe):
            raise ValueError(f"This pipeline ({self.pipeline_type}) does not support image generation. Use generate_unwatermarked_video instead.")

        # Construct generation parameters
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": self.config.init_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        set_random_seed(self.config.gen_seed)
        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

    def _generate_unwatermarked_video(self, input_data: Union[str, Image.Image], *args, **kwargs) -> List[Image.Image]:
        """
        Generate an unwatermarked video from a text prompt or input image.

        Parameters:
            input_data (Union[str, Image.Image]): Either a text prompt (for T2V) or an input image (for I2V).
                - If the pipeline is T2V, input_data should be a string prompt.
                - If the pipeline is I2V, input_data should be an Image object or can be passed as kwargs['input_image'].
            kwargs:
                - 'input_image': The input image for I2V pipelines.
                - 'prompt': The text prompt for T2V pipelines.
                - 'image_path': The path to the input image for I2V pipelines.

        Returns:
            List[Image.Image]: The generated unwatermarked video frames.

        Raises:
            ValueError: If the pipeline doesn't support video generation or if the input type is incompatible.
        """
        if not is_video_pipeline(self.config.pipe):
            raise ValueError(f"This pipeline ({self.pipeline_type}) does not support video generation. Use generate_unwatermarked_image instead.")

        # Handle Text-to-Video pipeline
        if is_t2v_pipeline(self.config.pipe):
            # For T2V, the input should be a text prompt
            if not isinstance(input_data, str):
                raise ValueError("Text-to-Video pipeline requires a text prompt as input_data")

            # Construct generation parameters
            generation_params = {
                "latents": self.config.init_latents,
                "num_frames": self.config.num_frames,
                "height": self.config.image_size[0],
                "width": self.config.image_size[1],
                "num_inference_steps": self.config.num_inference_steps,
                "guidance_scale": self.config.guidance_scale,
            }

            # Add parameters from config.gen_kwargs
            if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
                for key, value in self.config.gen_kwargs.items():
                    if key not in generation_params:
                        generation_params[key] = value

            # Use kwargs to override default parameters
            for key, value in kwargs.items():
                generation_params[key] = value

            # Generate the video
            set_random_seed(self.config.gen_seed)
            output = self.config.pipe(
                input_data,  # Use prompt
                **generation_params
            )

            # TextToVideoSDPipeline outputs expose a `frames` attribute;
            # fall back to `videos` or tuple unpacking for other pipelines.
            if hasattr(output, 'frames'):
                frames = output.frames[0]
            elif hasattr(output, 'videos'):
                frames = output.videos[0]
            else:
                frames = output[0] if isinstance(output, tuple) else output

            # Convert frames to PIL images
            frame_list = [cv2_to_pil(frame) for frame in frames]
            return frame_list

        # Handle Image-to-Video pipeline
        elif is_i2v_pipeline(self.config.pipe):
            # For I2V, the input should be an image; a text prompt is optional
            input_image = None
            text_prompt = None

            # Check if the input image was passed via kwargs
            if "input_image" in kwargs and isinstance(kwargs["input_image"], Image.Image):
                input_image = kwargs["input_image"]

            # Check if input_data is an image
            elif isinstance(input_data, Image.Image):
                input_image = input_data

            # If input_data is a string but we need an image, check if an image path was provided
            elif isinstance(input_data, str):
                import os
                from PIL import Image as PILImage

                if os.path.exists(input_data):
                    try:
                        input_image = PILImage.open(input_data).convert("RGB")
                    except Exception as e:
                        raise ValueError(f"Input data is neither an Image object nor a valid image path. Failed to load image from path: {e}")
                else:
                    # Treat as a text prompt if no valid image path
                    text_prompt = input_data
            if input_image is None:
                raise ValueError("Input image is required for Image-to-Video pipeline")

            # Construct generation parameters
            generation_params = {
                "image": input_image,
                "height": self.config.image_size[0],
                "width": self.config.image_size[1],
                "num_frames": self.config.num_frames,
                "latents": self.config.init_latents,
                "num_inference_steps": self.config.num_inference_steps,
                "max_guidance_scale": self.config.guidance_scale,
                "output_type": "np",
            }
            # In I2VGen-XL, the text prompt is needed
            if text_prompt is not None:
                generation_params["prompt"] = text_prompt

            # Add parameters from config.gen_kwargs
            if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
                for key, value in self.config.gen_kwargs.items():
                    if key not in generation_params:
                        generation_params[key] = value

            # Use kwargs to override default parameters
            for key, value in kwargs.items():
                generation_params[key] = value

            # Generate the video
            set_random_seed(self.config.gen_seed)
            video = self.config.pipe(
                **generation_params
            ).frames[0]

            # Convert frames to PIL images
            frame_list = [cv2_to_pil(frame) for frame in video]
            return frame_list

        # This should never happen since we already checked the pipeline type
        raise NotImplementedError(f"Unsupported video pipeline type: {self.pipeline_type}")

    def _detect_watermark_in_video(self,
                                   video_frames: List[Image.Image],
                                   *args,
                                   **kwargs) -> Dict[str, Any]:
        """
        Detect watermark in video frames.

        Args:
            video_frames: List of video frames as PIL images
            kwargs:
                - 'prompt': Optional text prompt used for generation (for T2V pipelines)
                - 'reference_image': Optional reference image (for I2V pipelines)
                - 'guidance_scale': The guidance scale for the detector (optional)
                - 'detector_type': The type of detector to use (optional)
                - 'num_inference_steps': Number of inference steps for inversion (optional)
                - 'num_frames': Number of frames to use for detection (optional for I2V pipelines)
                - 'decoder_inv': Whether to use decoder inversion (optional)
                - 'inv_order': Inverse order for Exact Inversion (optional)

        Returns:
            Dict[str, Any]: Detection results

        Raises:
            NotImplementedError: If the watermarking algorithm doesn't support video watermark detection
        """
        raise NotImplementedError("Video watermark detection is not implemented for this algorithm")

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """
        Detect watermark in an image.

        Args:
            image (Image.Image): The input image.
            prompt (str): The prompt used for generation.
            kwargs:
                - 'guidance_scale': The guidance scale for the detector.
                - 'detector_type': The type of detector to use.
                - 'num_inference_steps': Number of inference steps for inversion.
                - 'decoder_inv': Whether to use decoder inversion.
                - 'inv_order': Inverse order for Exact Inversion.

        Returns:
            Dict[str, float]: The detection result.
        """
        raise NotImplementedError("Watermark detection in image is not implemented for this algorithm")

    @abstractmethod
    def get_data_for_visualize(self, media, *args, **kwargs):
        """Get data for visualization."""
        pass

class ROBINUtils:
    """Utility class for the ROBIN algorithm; contains helper functions."""

    def __init__(self, config: ROBINConfig, *args, **kwargs) -> None:
        """
        Initialize the ROBIN watermarking algorithm.

        Parameters:
            config (ROBINConfig): Configuration for the ROBIN algorithm.
        """
        self.config = config
from ..base import BaseWatermark, BaseConfig
import torch
from typing import List
from utils.utils import set_random_seed
from utils.diffusion_config import DiffusionConfig
from visualize.data_for_visualization import DataForVisualization
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from sentence_transformers import SentenceTransformer
from PIL import Image
import math
from detection.seal.seal_detection import SEALDetector
from utils.media_utils import *

class SEALConfig(BaseConfig):
    """Config class for the SEAL algorithm."""


class VideoMarkConfig(BaseConfig):
    """Config class for the VideoMark algorithm."""

    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters."""
        self.fpr = self.config_dict['fpr']
        self.t = self.config_dict['prc_t']
        self.var = self.config_dict['var']
        self.threshold = self.config_dict['threshold']
        self.sequence_length = self.config_dict['sequence_length']  # Length of the watermark sequence
        self.message_length = self.config_dict['message_length']  # Number of bits in each sequence
        self.message_sequence = np.random.randint(0, 2, size=(self.sequence_length, self.message_length))  # <= 512 bits for robustness
        self.shift = np.random.default_rng().integers(0, self.sequence_length - self.num_frames)
        self.message = self.message_sequence[self.shift:self.shift + self.num_frames]
        self.latents_height = self.image_size[0] // self.pipe.vae_scale_factor
        self.latents_width = self.image_size[1] // self.pipe.vae_scale_factor
        self.latents_channel = self.pipe.unet.config.in_channels
        self.n = self.latents_height * self.latents_width * self.latents_channel  # Dimension of the latent space
        self.GF = galois.GF(2)

        # Seeds for key generation
        self.gen_matrix_seed = self.config_dict['keygen']['gen_matrix_seed']
        self.indice_seed = self.config_dict['keygen']['indice_seed']
        self.one_time_pad_seed = self.config_dict['keygen']['one_time_pad_seed']
        self.test_bits_seed = self.config_dict['keygen']['test_bits_seed']
        self.permute_bits_seed = self.config_dict['keygen']['permute_bits_seed']

        # Seeds for encoding
        self.payload_seed = self.config_dict['encode']['payload_seed']
        self.error_seed = self.config_dict['encode']['error_seed']
        self.pseudogaussian_seed = self.config_dict['encode']['pseudogaussian_seed']

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'VideoMark'

    @staticmethod
    def _get_message(length: int, window: int, seed=None) -> int:
        """Return a random start index for a subarray of size `window` in an array of size `length`."""
        rng = np.random.default_rng(seed)
        return rng.integers(0, length - window)


class VideoMarkUtils:
    """Utility class for the VideoMark algorithm."""
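The message selection in `VideoMarkConfig.initialize_parameters` embeds a random contiguous window of `num_frames` rows taken from the full `message_sequence`. The windowing itself can be sketched in isolation (hypothetical sizes, NumPy only):

```python
import numpy as np

rng = np.random.default_rng(0)
sequence_length, message_length, num_frames = 32, 8, 4

# Full watermark sequence: one message row per potential frame position
message_sequence = rng.integers(0, 2, size=(sequence_length, message_length))

# Random start index chosen so the window fits inside the sequence
shift = rng.integers(0, sequence_length - num_frames)

# Rows that would be embedded into the generated frames
message = message_sequence[shift:shift + num_frames]
```

Because the window is contiguous, a detector that recovers the per-frame bits can locate `shift` by sliding the recovered rows along the known sequence, which is what makes the scheme robust to temporal cropping.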
import torch
import hashlib
import numpy as np
import logging
from typing import Dict, Any, Union, List, Optional
from PIL import Image
from utils.media_utils import *
from utils.utils import load_config_file, set_random_seed
from utils.diffusion_config import DiffusionConfig
from utils.media_utils import transform_to_model_format, get_media_latents
from watermark.base import BaseConfig, BaseWatermark
from exceptions.exceptions import AlgorithmNameMismatchError
from detection.wind.wind_detection import WINDetector
from visualize.data_for_visualization import DataForVisualization

logger = logging.getLogger(__name__)


This module adapts the official GaussMarker detection pipeline to the
MarkDiffusion detection API. It evaluates recovered diffusion latents to
decide whether a watermark is present, reporting both hard decisions and
auxiliary scores (bit/message accuracies, frequency-domain distances).
This class analyzes the quality of images by directly comparing the characteristics
-of watermarked images with unwatermarked images. It evaluates metrics such as PSNR,
-SSIM, LPIPS, FID, BRISQUE without the need for any external reference image.
-
Use this pipeline to assess the impact of watermarking on image quality directly.
This pipeline assesses image quality by comparing both watermarked and unwatermarked
-images against a common reference image. It measures the degree of similarity or
-deviation from the reference.
-
Ideal for scenarios where the impact of watermarking on image quality needs to be
-assessed, particularly in relation to specific reference images or ground truth.
This pipeline analyzes quality metrics that require comparing distributions
-of multiple images (e.g., FID). It generates all images upfront and then
-performs a single analysis on the entire collection.
This pipeline analyzes diversity metrics (e.g., LPIPS diversity) by generating multiple versions of each prompt and measuring the diversity within each group.
This pipeline directly compares watermarked and unwatermarked images to compute metrics such as PSNR, SSIM, VIF, FSIM, and MS-SSIM. The analyzer receives both images and outputs a single comparison score.
Calculator for fundamental success rates of watermark detection.

This class handles the calculation of success rates for scenarios involving watermark detection with fixed thresholding. It provides metrics based on comparisons between expected watermarked results and actual detection outputs.

Use this class when you need to evaluate the effectiveness of watermark detection algorithms under fixed thresholding conditions.
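The fixed-threshold bookkeeping described above can be sketched in a few lines. The function name, inputs, and returned metric names below are illustrative, not the toolkit's actual API:

```python
def detection_success_rates(wm_scores, clean_scores, threshold):
    # A sample is flagged as watermarked when its detection score exceeds
    # the fixed threshold; rates compare these flags to the ground truth.
    tp = sum(s > threshold for s in wm_scores)      # watermarked, detected
    fp = sum(s > threshold for s in clean_scores)   # clean, falsely detected
    tpr = tp / len(wm_scores)
    fpr = fp / len(clean_scores)
    acc = (tp + len(clean_scores) - fp) / (len(wm_scores) + len(clean_scores))
    return {"TPR": tpr, "FPR": fpr, "accuracy": acc}
```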
Inception Score (IS) calculator for evaluating image generation quality.

Inception Score measures both the quality and diversity of generated images by evaluating how confidently an Inception model can classify them and how diverse the predictions are across the image set.

Higher IS indicates better image quality and diversity (typical range: 1-10+).
Parameters:

- `device` (str): Device to run the model on ("cuda" or "cpu")
- `batch_size` (int): Batch size for processing images
- `splits` (int): Number of splits for computing IS (default: 1). The number of images must be divisible by `splits` for a fair comparison. To obtain a mean and standard deviation of IS, set `splits` greater than 1; if `splits` is 1, IS is computed on the entire dataset (Avg = IS, Std = 0).
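Under the hood, IS is the exponentiated mean KL divergence between per-image class predictions and the marginal class distribution. A minimal sketch over a matrix of softmax outputs (illustrative only, not the toolkit's calculator class):

```python
import numpy as np

def inception_score(probs, splits=1):
    # probs: (N, num_classes) softmax outputs from the Inception model.
    scores = []
    for part in np.array_split(probs, splits):
        p_y = part.mean(axis=0, keepdims=True)  # marginal class distribution
        kl = (part * (np.log(part + 1e-12) - np.log(p_y + 1e-12))).sum(axis=1)
        scores.append(np.exp(kl.mean()))        # IS for this split
    return float(np.mean(scores)), float(np.std(scores))
```

With `splits=1` the standard deviation is zero by construction, matching the (Avg = IS, Std = 0) behaviour noted above.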
- `reference` (Union[Image, str]): Reference image or text for comparison. If `reference_source` is 'image', expects a PIL Image; if `reference_source` is 'text', expects a string.
- `device` (str): Device to run the model on ("cuda" or "cpu")
- `batch_size` (int): Batch size for processing images
- `splits` (int): Number of splits for computing FID (default: 5). The number of images must be divisible by `splits` for a fair comparison. To obtain a mean and standard deviation of FID, set `splits` greater than 1; if `splits` is 1, FID is computed on the entire dataset (Avg = FID, Std = 0).
Natural Image Quality Evaluator (NIQE) for no-reference image quality assessment.

NIQE evaluates image quality based on deviations from natural scene statistics. It uses a pre-trained model of natural image statistics to assess quality without requiring reference images.
VIF (Visual Information Fidelity) analyzer using piq.

VIF compares a distorted image with a reference image to quantify the amount of visual information preserved. Higher VIF indicates better quality/similarity. Typical range: 0-1 (sometimes higher for good quality).
Analyzer for evaluating subject consistency across video frames using DINO features.

This analyzer measures how consistently the main subject appears across frames by:
1. Extracting DINO features from each frame
2. Computing cosine similarity between consecutive frames and with the first frame
3. Averaging these similarities to get a consistency score
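The averaging scheme in steps 2-3 can be sketched in plain NumPy on precomputed features (the real analyzer extracts DINO features first; this helper is illustrative):

```python
import numpy as np

def consistency_score(features):
    # features: (num_frames, dim) array of per-frame embeddings.
    feats = features / np.linalg.norm(features, axis=1, keepdims=True)
    sims = []
    for i in range(1, len(feats)):
        sims.append(float(feats[i] @ feats[i - 1]))  # vs. previous frame
        sims.append(float(feats[i] @ feats[0]))      # vs. first frame
    return float(np.mean(sims))
```

Identical frames score 1.0; the score drops as the subject's appearance drifts.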
Analyzer for evaluating motion smoothness in videos using the AMT-S model.

This analyzer measures motion smoothness by:
1. Extracting frames at even indices from the video
2. Using the AMT-S model to interpolate between consecutive frames
3. Comparing interpolated frames with actual frames to compute a smoothness score

The score represents how well the motion can be predicted/interpolated, with smoother motion resulting in higher scores.
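The drop-and-interpolate idea above can be sketched with a stand-in interpolator (the real analyzer uses AMT-S; the scoring formula here is a simplification of the same principle):

```python
import numpy as np

def smoothness_score(frames, interpolate):
    # Re-create each odd-index frame from its even-index neighbours and
    # measure how far the interpolation is from the real frame.
    errs = []
    for i in range(1, len(frames) - 1, 2):
        pred = interpolate(frames[i - 1], frames[i + 1])
        errs.append(np.abs(pred - frames[i]).mean())
    return 1.0 - float(np.mean(errs)) / 255.0  # higher = smoother motion
```

Perfectly linear motion is reproduced exactly by a midpoint interpolator and scores 1.0.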
Analyzer for evaluating dynamic degree (motion intensity) in videos using RAFT optical flow.

This analyzer measures the amount and intensity of motion in videos by:
1. Computing optical flow between consecutive frames using RAFT
2. Calculating the flow magnitude for each pixel
3. Extracting the top 5% highest flow magnitudes
4. Determining whether the video has sufficient dynamic motion based on thresholds

The score represents whether the video contains dynamic motion (1.0) or is mostly static (0.0).
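Steps 3-4 amount to thresholding the strongest flow magnitudes. A toy version over precomputed magnitude maps (both threshold values are assumptions, and RAFT itself is out of scope here):

```python
import numpy as np

def dynamic_degree(flow_mags, mag_thresh=6.0, frame_ratio=0.5):
    # flow_mags: one (H, W) flow-magnitude map per consecutive frame pair.
    moving = []
    for mag in flow_mags:
        top = np.sort(mag.ravel())[-max(1, mag.size // 20):]  # top 5%
        moving.append(top.mean() > mag_thresh)
    # Dynamic only if enough frame pairs exhibit strong motion.
    return 1.0 if np.mean(moving) >= frame_ratio else 0.0
```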
Analyzer for evaluating background consistency across video frames using CLIP features.

This analyzer measures how consistently the background appears across frames by:
1. Extracting CLIP visual features from each frame
2. Computing cosine similarity between consecutive frames and with the first frame
3. Averaging these similarities to get a consistency score

Similar to SubjectConsistencyAnalyzer, but it focuses on overall visual consistency including background elements, making it suitable for detecting background stability.
Analyzer for evaluating the imaging quality of videos.

This analyzer measures video quality by:
1. Feeding frames into the MUSIQ image quality predictor
2. Determining whether the video is blurry or contains artifacts

The score represents the quality of the video (higher is better).
Transform image or video frames to model input format.

For an image, `media` is a PIL image that will be resized to `target_size` (if provided), normalized to [-1, 1], and permuted from [H, W, C] to [C, H, W]. For a video, `media` is a list of frames (PIL images or numpy arrays) that will be normalized to [-1, 1] and permuted from [F, H, W, C] to [F, C, H, W].

Parameters:

- `media` (Union[Image, List[Image], ndarray, Tensor]): PIL image, list of frames, or video tensor
- `target_size` (Optional[int]): Target size for resize operations (for images)
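For the image case, the transform described above amounts to roughly the following (a sketch under the stated conventions, not the toolkit's exact implementation):

```python
import numpy as np
from PIL import Image

def to_model_format(media, target_size=None):
    # Resize (optional), scale pixel values from [0, 255] to [-1, 1],
    # and move channels first: [H, W, C] -> [C, H, W].
    if target_size is not None:
        media = media.resize((target_size, target_size))
    arr = np.asarray(media, dtype=np.float32) / 127.5 - 1.0
    return arr.transpose(2, 0, 1)
```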
Inherit docstrings from base classes to methods without docstrings.

This decorator automatically applies the docstring from a base class method to a derived class method if the derived method doesn't have its own docstring.

Parameters:

- `cls`: The class to enhance with inherited docstrings
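Such a decorator can be implemented in a few lines; this is a generic sketch of the pattern, not the toolkit's exact code:

```python
def inherit_docstrings(cls):
    # For each method defined on cls without a docstring, walk the MRO and
    # copy the first base-class docstring found under the same method name.
    for name, member in vars(cls).items():
        if callable(member) and not getattr(member, "__doc__", None):
            for base in cls.__mro__[1:]:
                parent = getattr(base, name, None)
                if parent is not None and parent.__doc__:
                    member.__doc__ = parent.__doc__
                    break
    return cls
```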
This visualizer handles watermark visualization for the VideoShield algorithm, which extends Gaussian Shading to the video domain by adding frame dimensions.
This is a generic watermark class that will be instantiated as one of the watermark classes of the library when created with the `AutoWatermark.load` class method.

This class cannot be instantiated directly using `__init__()` (throws an error).
Generate watermarked media (image or video) based on pipeline type.

This is the main interface for generating watermarked content with any watermarking algorithm. It automatically routes to the appropriate generation method based on the pipeline type (image or video).

Parameters:

- `input_data` (Union[str, Image]): Text prompt (for T2I or T2V) or input image (for I2V)
- `*args`: Additional positional arguments
- `**kwargs`: Additional keyword arguments, including:
  - `guidance_scale`: Guidance scale for generation
  - `num_inference_steps`: Number of inference steps
  - `height`, `width`: Dimensions of generated media
  - `seed`: Random seed for generation

Returns:

Generated watermarked media:
- For image pipelines: a single PIL Image
- For video pipelines: a list of PIL Images (frames)
Parameters:

- `media` (Union[Image, List[Image], ndarray, Tensor]): The media to detect the watermark in (PIL image, list of frames, numpy array, or tensor)
- `*args`: Additional positional arguments
- `**kwargs`: Additional keyword arguments, including:
  - `prompt`: Optional text prompt used to generate the media (for some algorithms)
  - `num_inference_steps`: Optional number of inference steps
  - `guidance_scale`: Optional guidance scale
  - `num_frames`: Optional number of frames
  - `decoder_inv`: Optional decoder inversion
  - `inv_order`: Inverse order for Exact Inversion
  - `detector_type`: Type of detector to use

Returns:

Detection results with metrics and possibly visualizations
This method generates the necessary data for visualizing VideoShield watermarks, including original watermarked latents and reversed latents from inversion.

Parameters:

- `image`: The image to visualize watermarks for (can be None for generation only)
- `prompt` (str): The text prompt used for generation
- `guidance_scale` (float): Guidance scale for generation and inversion

Returns:

`DataForVisualization` object containing the visualization data
This method generates the necessary data for visualizing VideoMark watermarks, including original watermarked latents and reversed latents from inversion.

Parameters:

- `image`: The image to visualize watermarks for (can be None for generation only)
- `prompt` (str): The text prompt used for generation
- `guidance_scale` (float): Guidance scale for generation and inversion

Returns:

`DataForVisualization` object containing the visualization data
```bibtex
@article{pan2025markdiffusion,
  title={MarkDiffusion: An Open-Source Toolkit for Generative Watermarking of Latent Diffusion Models},
  author={Pan, Leyi and Guan, Sheng and Fu, Zheyu and Si, Luyang and Wang, Zian and Hu, Xuming and King, Irwin and Yu, Philip S and Liu, Aiwei and Wen, Lijie},
  journal={arXiv preprint arXiv:2509.10569},
  year={2025}
}
```

Pan, Leyi, et al. "MarkDiffusion: An Open-Source Toolkit for Generative Watermarking of Latent Diffusion Models." arXiv preprint arXiv:2509.10569 (2025).
```bibtex
@misc{wen2023treeringwatermarksfingerprintsdiffusion,
  title={Tree-Ring Watermarks: Fingerprints for Diffusion Images that are Invisible and Robust},
  author={Yuxin Wen and John Kirchenbauer and Jonas Geiping and Tom Goldstein},
  year={2023},
  eprint={2305.20030},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2305.20030}
}

@article{ci2024ringid,
  title={RingID: Rethinking Tree-Ring Watermarking for Enhanced Multi-Key Identification},
  author={Ci, Hai and Yang, Pei and Song, Yiren and Shou, Mike Zheng},
  journal={arXiv preprint arXiv:2404.14055},
  year={2024}
}

@inproceedings{huangrobin,
  title={ROBIN: Robust and Invisible Watermarks for Diffusion Models with Adversarial Optimization},
  author={Huang, Huayang and Wu, Yu and Wang, Qian},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024}
}

@article{arabi2024hidden,
  title={Hidden in the Noise: Two-Stage Robust Watermarking for Images},
  author={Arabi, Kasra and Feuer, Benjamin and Witter, R Teal and Hegde, Chinmay and Cohen, Niv},
  journal={arXiv preprint arXiv:2412.04653},
  year={2024}
}

@inproceedings{lee2025semantic,
  title={Semantic Watermarking Reinvented: Enhancing Robustness and Generation Quality with Fourier Integrity},
  author={Lee, Sung Ju and Cho, Nam Ik},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={18759--18769},
  year={2025}
}

@article{yang2024gaussian,
  title={Gaussian Shading: Provable Performance-Lossless Image Watermarking for Diffusion Models},
  author={Yang, Zijin and Zeng, Kai and Chen, Kejiang and Fang, Han and Zhang, Weiming and Yu, Nenghai},
  journal={arXiv preprint arXiv:2404.04956},
  year={2024}
}

@inproceedings{hu2025videoshield,
  title={VideoShield: Regulating Diffusion-based Video Generation Models via Watermarking},
  author={Runyi Hu and Jie Zhang and Yiming Li and Jiwei Li and Qing Guo and Han Qiu and Tianwei Zhang},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2025}
}

@article{hu2025videomark,
  title={VideoMark: A Distortion-Free Robust Watermarking Framework for Video Diffusion Models},
  author={Hu, Xuming and Li, Hanqian and Li, Jungang and Liu, Aiwei},
  journal={arXiv preprint arXiv:2504.16359},
  year={2025}
}
```
-
"This research utilized MarkDiffusion [1], an open-source toolkit for generative watermarking. We specifically employed the Gaussian-Shading algorithm [2] for watermark embedding and detection."
This document provides technical guidelines for contributing to MarkDiffusion. For the general contribution workflow (forking, cloning, creating branches, submitting PRs), please refer to the Contributing Guidelines in the repository root.
All participants are expected to follow our Code of Conduct. Please be respectful, constructive, and help create a welcoming environment for everyone.

Reporting Issues

For bug reports and feature requests, please use the appropriate templates configured in the GitHub repository.
MarkDiffusion is an open-source Python toolkit for generative watermarking of latent diffusion models. As the use of diffusion-based generative models expands, ensuring the authenticity and origin of generated media becomes critical. MarkDiffusion simplifies the access, understanding, and assessment of watermarking technologies, making it accessible to both researchers and the broader community.

Note: If you are interested in LLM watermarking (text watermark), please refer to the MarkLLM toolkit from our group.
The toolkit includes custom visualization tools that enable clear and insightful views into how different watermarking algorithms operate under various scenarios.

📊 Comprehensive Evaluation Module

With 31 evaluation tools covering detectability, robustness, and impact on output quality, MarkDiffusion provides comprehensive assessment capabilities through 6 automated evaluation pipelines.
MarkDiffusion uses pre-trained models stored on Hugging Face. Download the required models:

```bash
# The models will be downloaded to the ckpts/ directory
# Visit: https://huggingface.co/Generative-Watermark-Toolkits
```

For each algorithm you plan to use, download the corresponding model weights from the Generative-Watermark-Toolkits repository and place them in the appropriate ckpts/ subdirectory.
If you prefer using Conda for environment management:

```bash
# Create a new conda environment
conda create -n markdiffusion python=3.10
conda activate markdiffusion

# Install PyTorch with CUDA support
conda install pytorch torchvision torchaudio pytorch-cuda=12.6 -c pytorch -c nvidia

# Install other dependencies
pip install -r requirements.txt
```
SEAL uses pre-trained models from Hugging Face for caption generation and embedding:

- **BLIP2 Model**: For generating image captions (blip2-flan-t5-xl)
- **Sentence Transformer**: For caption embedding

Setup:

```bash
# These models will be automatically downloaded from Hugging Face on first use
# Or you can pre-download them:

# Download BLIP2 model
python -c "from transformers import Blip2Processor, Blip2ForConditionalGeneration; \
Blip2Processor.from_pretrained('Salesforce/blip2-flan-t5-xl'); \
Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-flan-t5-xl')"

# Download sentence transformer (if using a custom fine-tuned model)
# Update config/SEAL.json paths accordingly
```

Configuration (in config/SEAL.json):

- `cap_processor`: Path or model name for the BLIP2 processor
- `cap_model`: Path or model name for the BLIP2 model
- `sentence_model`: Path or model name for the sentence transformer

Note: SEAL models are large (~15GB for BLIP2). Ensure you have sufficient disk space and memory.
```python
# Define your prompt
prompt = "A beautiful landscape with mountains and a lake at sunset"

# Generate watermarked image
watermarked_image = watermark.generate_watermarked_media(prompt)

# Save the image
watermarked_image.save("watermarked_output.png")

# Display the image
watermarked_image.show()
```
Tree-Ring embeds circular patterns in the Fourier domain of initial latents, making them invisible in the spatial domain but detectable through frequency analysis.
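As a rough illustration of the idea (not the official Tree-Ring implementation, which operates on the FFT of the sampled Gaussian latents with a multi-ring key), one can write a constant key onto a ring in the frequency domain of a latent channel and later check for it:

```python
import numpy as np

def embed_ring_key(latents, radius=10.0, key=50.0, channel=0):
    # Build a thin circular mask around the spectrum centre, then write the
    # key value onto that ring in the Fourier domain of one latent channel.
    h, w = latents.shape[-2:]
    yy, xx = np.mgrid[:h, :w]
    dist = np.sqrt((yy - h / 2) ** 2 + (xx - w / 2) ** 2)
    mask = np.abs(dist - radius) < 1.0
    freq = np.fft.fftshift(np.fft.fft2(latents[channel]))
    freq[mask] = key
    latents[channel] = np.real(np.fft.ifft2(np.fft.ifftshift(freq)))
    return latents, mask

def ring_distance(latents, mask, key=50.0, channel=0):
    # Detection mirrors embedding: FFT the latent channel and measure how
    # far the on-ring values are from the expected key (small = watermarked).
    freq = np.fft.fftshift(np.fft.fft2(latents[channel]))
    return float(np.mean(np.abs(freq[mask] - key)))
```

Because the pattern lives on a ring of frequencies, it survives in the spatial-domain latents invisibly while remaining easy to test for after inversion.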
MarkDiffusion provides comprehensive evaluation tools to assess watermark performance across three key dimensions: detectability, robustness, and output quality.