diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index e18a258a0f..a733a37b65 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -28,6 +28,8 @@ import type { ImageToImageTaskHelper, ImageToTextTaskHelper, ImageToVideoTaskHelper, + ImageTextToImageTaskHelper, + ImageTextToVideoTaskHelper, ObjectDetectionTaskHelper, QuestionAnsweringTaskHelper, SentenceSimilarityTaskHelper, @@ -276,6 +278,14 @@ export function getProviderHelper( provider: InferenceProviderOrPolicy, task: "image-to-video" ): ImageToVideoTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProviderOrPolicy, + task: "image-text-to-image" +): ImageTextToImageTaskHelper & TaskProviderHelper; +export function getProviderHelper( + provider: InferenceProviderOrPolicy, + task: "image-text-to-video" +): ImageTextToVideoTaskHelper & TaskProviderHelper; export function getProviderHelper( provider: InferenceProviderOrPolicy, task: "sentence-similarity" diff --git a/packages/inference/src/providers/providerHelper.ts b/packages/inference/src/providers/providerHelper.ts index c92d106317..1766f597f1 100644 --- a/packages/inference/src/providers/providerHelper.ts +++ b/packages/inference/src/providers/providerHelper.ts @@ -19,6 +19,8 @@ import type { ImageToTextInput, ImageToTextOutput, ImageToVideoInput, + ImageTextToImageInput, + ImageTextToVideoInput, ObjectDetectionInput, ObjectDetectionOutput, QuestionAnsweringInput, @@ -54,6 +56,8 @@ import { toArray } from "../utils/toArray.js"; import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js"; import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition.js"; import type { ImageToVideoArgs } from "../tasks/cv/imageToVideo.js"; +import type { ImageTextToImageArgs } from "../tasks/cv/imageTextToImage.js"; +import type { ImageTextToVideoArgs } from "../tasks/cv/imageTextToVideo.js"; import type { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js"; /** @@ -159,6 +163,18 @@ export interface ImageToVideoTaskHelper { preparePayloadAsync(args: ImageToVideoArgs): Promise<RequestArgs>; } +export interface ImageTextToImageTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>; + preparePayload(params: BodyParams<ImageTextToImageArgs>): Record<string, unknown>; + preparePayloadAsync(args: ImageTextToImageArgs): Promise<RequestArgs>; +} + +export interface ImageTextToVideoTaskHelper { + getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>; + preparePayload(params: BodyParams<ImageTextToVideoArgs>): Record<string, unknown>; + preparePayloadAsync(args: ImageTextToVideoArgs): Promise<RequestArgs>; +} + export interface ImageSegmentationTaskHelper { getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<ImageSegmentationOutput>; preparePayload(params: BodyParams<ImageSegmentationArgs>): Record<string, unknown> | BodyInit; diff --git a/packages/inference/src/snippets/getInferenceSnippets.ts b/packages/inference/src/snippets/getInferenceSnippets.ts index 7b18a3c8c8..d9ecdbc36b 100644 --- a/packages/inference/src/snippets/getInferenceSnippets.ts +++ b/packages/inference/src/snippets/getInferenceSnippets.ts @@ -96,7 +96,10 @@ const HF_PYTHON_METHODS: Partial<Record<WidgetType, string>> = { "image-classification": "image_classification", "image-segmentation": "image_segmentation", "image-to-image": "image_to_image", + "image-to-video": "image_to_video", "image-to-text": "image_to_text", + "image-text-to-image": "image_text_to_image", + "image-text-to-video": "image_text_to_video", "object-detection": "object_detection",
"question-answering": "question_answering", "sentence-similarity": "sentence_similarity", @@ -390,7 +393,9 @@ const snippets: Partial< "fill-mask": snippetGenerator("basic"), "image-classification": snippetGenerator("basicImage"), "image-segmentation": snippetGenerator("basicImage"), + "image-text-to-image": snippetGenerator("imageToImage", prepareImageToImageInput), "image-text-to-text": snippetGenerator("conversational"), + "image-text-to-video": snippetGenerator("imageToVideo", prepareImageToImageInput), "image-to-image": snippetGenerator("imageToImage", prepareImageToImageInput), "image-to-text": snippetGenerator("basicImage"), "image-to-video": snippetGenerator("imageToVideo", prepareImageToImageInput), diff --git a/packages/inference/src/tasks/cv/imageTextToImage.ts b/packages/inference/src/tasks/cv/imageTextToImage.ts new file mode 100644 index 0000000000..a621e6ef68 --- /dev/null +++ b/packages/inference/src/tasks/cv/imageTextToImage.ts @@ -0,0 +1,22 @@ +import type { ImageTextToImageInput } from "@huggingface/tasks"; +import { resolveProvider } from "../../lib/getInferenceProviderMapping.js"; +import { getProviderHelper } from "../../lib/getProviderHelper.js"; +import type { BaseArgs, Options } from "../../types.js"; +import { innerRequest } from "../../utils/request.js"; + +export type ImageTextToImageArgs = BaseArgs & ImageTextToImageInput; + +/** + * This task takes an image and text input and outputs a new generated image. + * Recommended model: black-forest-labs/FLUX.2-dev + */ +export async function imageTextToImage(args: ImageTextToImageArgs, options?: Options): Promise { + const provider = await resolveProvider(args.provider, args.model, args.endpointUrl); + const providerHelper = getProviderHelper(provider, "image-text-to-image"); + const payload = await providerHelper.preparePayloadAsync(args); + const { data: res, requestContext } = await innerRequest(payload, providerHelper, { + ...options, + task: "image-text-to-image", + }); + return providerHelper.getResponse(res, requestContext.url, requestContext.info.headers as Record); +} diff --git a/packages/inference/src/tasks/cv/imageTextToVideo.ts b/packages/inference/src/tasks/cv/imageTextToVideo.ts new file mode 100644 index 0000000000..7f71f02984 --- /dev/null +++ b/packages/inference/src/tasks/cv/imageTextToVideo.ts @@ -0,0 +1,22 @@ +import type { ImageTextToVideoInput } from "@huggingface/tasks"; +import { resolveProvider } from "../../lib/getInferenceProviderMapping.js"; +import { getProviderHelper } from "../../lib/getProviderHelper.js"; +import type { BaseArgs, Options } from "../../types.js"; +import { innerRequest } from "../../utils/request.js"; + +export type ImageTextToVideoArgs = BaseArgs & ImageTextToVideoInput; + +/** + * This task takes an image and text input and outputs a generated video. 
diff --git a/packages/inference/src/tasks/cv/imageTextToVideo.ts b/packages/inference/src/tasks/cv/imageTextToVideo.ts new file mode 100644 index 0000000000..7f71f02984 --- /dev/null +++ b/packages/inference/src/tasks/cv/imageTextToVideo.ts @@ -0,0 +1,22 @@ +import type { ImageTextToVideoInput } from "@huggingface/tasks"; +import { resolveProvider } from "../../lib/getInferenceProviderMapping.js"; +import { getProviderHelper } from "../../lib/getProviderHelper.js"; +import type { BaseArgs, Options } from "../../types.js"; +import { innerRequest } from "../../utils/request.js"; + +export type ImageTextToVideoArgs = BaseArgs & ImageTextToVideoInput; + +/** + * This task takes an image and text input and outputs a generated video. + * Recommended model: Lightricks/LTX-Video + */ +export async function imageTextToVideo(args: ImageTextToVideoArgs, options?: Options): Promise<Blob> { + const provider = await resolveProvider(args.provider, args.model, args.endpointUrl); + const providerHelper = getProviderHelper(provider, "image-text-to-video"); + const payload = await providerHelper.preparePayloadAsync(args); + const { data: res, requestContext } = await innerRequest<Blob>(payload, providerHelper, { + ...options, + task: "image-text-to-video", + }); + return providerHelper.getResponse(res, requestContext.url, requestContext.info.headers as Record<string, string>); +} diff --git a/packages/inference/src/tasks/index.ts b/packages/inference/src/tasks/index.ts index a39779a227..685091a4c2 100644 --- a/packages/inference/src/tasks/index.ts +++ b/packages/inference/src/tasks/index.ts @@ -14,6 +14,8 @@ export * from "./cv/imageSegmentation.js"; export * from "./cv/imageToImage.js"; export * from "./cv/imageToText.js"; export * from "./cv/imageToVideo.js"; +export * from "./cv/imageTextToImage.js"; +export * from "./cv/imageTextToVideo.js"; export * from "./cv/objectDetection.js"; export * from "./cv/textToImage.js"; export * from "./cv/textToVideo.js"; diff --git a/packages/tasks/src/pipelines.ts b/packages/tasks/src/pipelines.ts index 293c25b664..36c9c90fca 100644 --- a/packages/tasks/src/pipelines.ts +++ b/packages/tasks/src/pipelines.ts @@ -557,6 +557,14 @@ export const PIPELINE_DATA = { name: "Image-Text-to-Text", modality: "multimodal", }, + "image-text-to-image": { + name: "Image-Text-to-Image", + modality: "multimodal", + }, + "image-text-to-video": { + name: "Image-Text-to-Video", + modality: "multimodal", + }, "visual-question-answering": { name: "Visual Question Answering", subtasks: [ diff --git a/packages/tasks/src/snippets/inputs.ts b/packages/tasks/src/snippets/inputs.ts index e36c05487b..4a55d20a65 100644 --- a/packages/tasks/src/snippets/inputs.ts +++ b/packages/tasks/src/snippets/inputs.ts @@ -94,6 +94,16 @@ const inputsImageToVideo = () => `{ "prompt": "The cat starts to dance" }`; +const inputsImageTextToImage = () => `{ + "image": "cat.png", + "prompt": "Turn the cat into a tiger." +}`; + +const inputsImageTextToVideo = () => `{ + "image": "cat.png", + "prompt": "The cat starts to dance" +}`; + const inputsImageSegmentation = () => `"cats.jpg"`; const inputsObjectDetection = () => `"cats.jpg"`; @@ -130,6 +140,8 @@ const modelInputSnippets: { "image-to-text": inputsImageToText, "image-to-image": inputsImageToImage, "image-to-video": inputsImageToVideo, + "image-text-to-image": inputsImageTextToImage, + "image-text-to-video": inputsImageTextToVideo, "image-segmentation": inputsImageSegmentation, "object-detection": inputsObjectDetection, "question-answering": inputsQuestionAnswering, diff --git a/packages/tasks/src/tasks/image-text-to-image/about.md b/packages/tasks/src/tasks/image-text-to-image/about.md new file mode 100644 index 0000000000..c05a81fcea --- /dev/null +++ b/packages/tasks/src/tasks/image-text-to-image/about.md @@ -0,0 +1,73 @@ +## Use Cases + +### Instruction-based Image Editing + +Image-text-to-image models can be used to edit images based on natural language instructions. For example, you can provide an image of a summer landscape and the instruction "Make it winter, add snow" to generate a winter version of the same scene. + +### Style Transfer + +These models can apply artistic styles or transformations to images based on text descriptions.
For instance, you can transform a photo into a painting style by providing prompts like "Make it look like a Van Gogh painting" or "Convert to watercolor style." + +### Image Variations + +Generate variations of an existing image by providing different text prompts. This is useful for creative workflows where you want to explore different versions of the same image with specific modifications. + +### Guided Image Generation + +Use a reference image along with text prompts to guide the generation process. This allows for more controlled image generation compared to text-to-image models alone, as the reference image provides structural guidance. + +### Image Inpainting and Outpainting + +Fill in missing or masked parts of an image based on text descriptions, or extend an image beyond its original boundaries with text-guided generation. + +## Task Variants + +### Instruction-based Editing + +Models that follow natural language instructions to edit images, which can perform complex edits like object removal, color changes, and compositional modifications. + +### Reference-guided Generation + +Models that use a reference image to guide the generation process while incorporating text prompts to control specific attributes or modifications. + +### Conditional Image-to-Image + +Models that perform specific transformations based on text conditions, such as changing weather conditions, time of day, or seasonal variations. + +## Inference + +You can use the Diffusers library to interact with image-text-to-image models. + +```python +import torch +from diffusers import Flux2Pipeline +from diffusers.utils import load_image + +repo_id = "black-forest-labs/FLUX.2-dev" +device = "cuda:0" +torch_dtype = torch.bfloat16 + +pipe = Flux2Pipeline.from_pretrained( + repo_id, torch_dtype=torch_dtype +) +pipe.enable_model_cpu_offload() # no need to do CPU offload for >80GB VRAM cards like H200, B200, etc.; do a `pipe.to(device)` instead + +prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that starts with #FF5733 at the top and transitions to #33FF57 at the bottom."
+ +#cat_image = load_image("https://huggingface.co/spaces/zerogpu-aoti/FLUX.1-Kontext-Dev-fp8-dynamic/resolve/main/cat.png") +image = pipe( + prompt=prompt, + #image=[cat_image] #multi-image input + generator=torch.Generator(device=device).manual_seed(42), + num_inference_steps=50, + guidance_scale=4, +).images[0] + +image.save("flux2_output.png") +``` + +## Useful Resources + +- [FLUX.2 Model Card](https://huggingface.co/black-forest-labs/FLUX.2-dev) +- [Diffusers documentation on Image-to-Image](https://huggingface.co/docs/diffusers/using-diffusers/img2img) +- [ControlNet for Conditional Image Generation](https://huggingface.co/docs/diffusers/using-diffusers/controlnet) diff --git a/packages/tasks/src/tasks/image-text-to-image/data.ts b/packages/tasks/src/tasks/image-text-to-image/data.ts new file mode 100644 index 0000000000..f45410e91c --- /dev/null +++ b/packages/tasks/src/tasks/image-text-to-image/data.ts @@ -0,0 +1,54 @@ +import type { TaskDataCustom } from "../index.js"; + +const taskData: TaskDataCustom = { + datasets: [], + demo: { + inputs: [ + { + filename: "image-text-to-image-input.jpeg", + type: "img", + }, + { + label: "Input", + content: "A city above clouds, pastel colors, Victorian style", + type: "text", + }, + ], + outputs: [ + { + filename: "image-text-to-image-output.png", + type: "img", + }, + ], + }, + metrics: [ + { + description: + "The Fréchet Inception Distance (FID) calculates the distance between distributions between synthetic and real samples. A lower FID score indicates better similarity between the distributions of real and generated images.", + id: "FID", + }, + { + description: + "CLIP Score measures the similarity between the generated image and the text prompt using CLIP embeddings. A higher score indicates better alignment with the text prompt.", + id: "CLIP", + }, + ], + models: [ + { + description: "A powerful model for image-text-to-image generation.", + id: "black-forest-labs/FLUX.2-dev", + }, + ], + spaces: [ + { + description: "An application for image-text-to-image generation.", + id: "black-forest-labs/FLUX.2-dev", + }, + ], + summary: + "Image-text-to-image models take an image and a text prompt as input and generate a new image based on the reference image and text instructions. These models are useful for image editing, style transfer, image variations, and guided image generation tasks.", + widgetModels: ["black-forest-labs/FLUX.2-dev"], + youtubeId: undefined, +}; + +export default taskData; diff --git a/packages/tasks/src/tasks/image-text-to-image/inference.ts b/packages/tasks/src/tasks/image-text-to-image/inference.ts new file mode 100644 index 0000000000..549b0f4118 --- /dev/null +++ b/packages/tasks/src/tasks/image-text-to-image/inference.ts @@ -0,0 +1,75 @@ +/** + * Inference code generated from the JSON schema spec in ./spec + * + * Using src/scripts/inference-codegen + */ +/** + * Inputs for Image Text To Image inference. Either inputs (image) or prompt (in parameters) + * must be provided, or both. + */ +export interface ImageTextToImageInput { + /** + * The input image data as a base64-encoded string. If no `parameters` are provided, you can + * also provide the image data as a raw bytes payload. Either this or prompt must be + * provided. 
+ */ + inputs?: Blob; + /** + * Additional inference parameters for Image Text To Image + */ + parameters?: ImageTextToImageParameters; + [property: string]: unknown; +} +/** + * Additional inference parameters for Image Text To Image + */ +export interface ImageTextToImageParameters { + /** + * For diffusion models. A higher guidance scale value encourages the model to generate + * images closely linked to the text prompt at the expense of lower image quality. + */ + guidance_scale?: number; + /** + * One prompt to guide what NOT to include in image generation. + */ + negative_prompt?: string; + /** + * For diffusion models. The number of denoising steps. More denoising steps usually lead to + * a higher quality image at the expense of slower inference. + */ + num_inference_steps?: number; + /** + * The text prompt to guide the image generation. Either this or inputs (image) must be + * provided. + */ + prompt?: string; + /** + * Seed for the random number generator. + */ + seed?: number; + /** + * The size in pixels of the output image. This parameter is only supported by some + * providers and for specific models. It will be ignored when unsupported. + */ + target_size?: TargetSize; + [property: string]: unknown; +} +/** + * The size in pixels of the output image. This parameter is only supported by some + * providers and for specific models. It will be ignored when unsupported. + */ +export interface TargetSize { + height: number; + width: number; + [property: string]: unknown; +} +/** + * Outputs of inference for the Image Text To Image task + */ +export interface ImageTextToImageOutput { + /** + * The generated image returned as raw bytes in the payload. + */ + image: unknown; + [property: string]: unknown; +} diff --git a/packages/tasks/src/tasks/image-text-to-image/spec/input.json b/packages/tasks/src/tasks/image-text-to-image/spec/input.json new file mode 100644 index 0000000000..259353e337 --- /dev/null +++ b/packages/tasks/src/tasks/image-text-to-image/spec/input.json @@ -0,0 +1,59 @@ +{ + "$id": "/inference/schemas/image-text-to-image/input.json", + "$schema": "http://json-schema.org/draft-06/schema#", + "description": "Inputs for Image Text To Image inference. Either inputs (image) or prompt (in parameters) must be provided, or both.", + "title": "ImageTextToImageInput", + "type": "object", + "properties": { + "inputs": { + "type": "string", + "description": "The input image data as a base64-encoded string. If no `parameters` are provided, you can also provide the image data as a raw bytes payload. Either this or prompt must be provided.", + "comment": "type=binary" + }, + "parameters": { + "description": "Additional inference parameters for Image Text To Image", + "$ref": "#/$defs/ImageTextToImageParameters" + } + }, + "$defs": { + "ImageTextToImageParameters": { + "title": "ImageTextToImageParameters", + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The text prompt to guide the image generation. Either this or inputs (image) must be provided." + }, + "guidance_scale": { + "type": "number", + "description": "For diffusion models. A higher guidance scale value encourages the model to generate images closely linked to the text prompt at the expense of lower image quality." + }, + "negative_prompt": { + "type": "string", + "description": "One prompt to guide what NOT to include in image generation." + }, + "num_inference_steps": { + "type": "integer", + "description": "For diffusion models. The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the expense of slower inference." + }, + "target_size": { + "type": "object", + "description": "The size in pixels of the output image. This parameter is only supported by some providers and for specific models. It will be ignored when unsupported.", + "properties": { + "width": { + "type": "integer" + }, + "height": { + "type": "integer" + } + }, + "required": ["width", "height"] + }, + "seed": { + "type": "integer", + "description": "Seed for the random number generator." + } + } + } + } +} diff --git a/packages/tasks/src/tasks/image-text-to-image/spec/output.json b/packages/tasks/src/tasks/image-text-to-image/spec/output.json new file mode 100644 index 0000000000..f5b782b625 --- /dev/null +++ b/packages/tasks/src/tasks/image-text-to-image/spec/output.json @@ -0,0 +1,13 @@ +{ + "$id": "/inference/schemas/image-text-to-image/output.json", + "$schema": "http://json-schema.org/draft-06/schema#", + "description": "Outputs of inference for the Image Text To Image task", + "title": "ImageTextToImageOutput", + "type": "object", + "properties": { + "image": { + "description": "The generated image returned as raw bytes in the payload." + } + }, + "required": ["image"] +} diff --git a/packages/tasks/src/tasks/image-text-to-video/about.md b/packages/tasks/src/tasks/image-text-to-video/about.md new file mode 100644 index 0000000000..0119b6a517 --- /dev/null +++ b/packages/tasks/src/tasks/image-text-to-video/about.md @@ -0,0 +1,71 @@ +## Use Cases + +### Image Animation + +Image-text-to-video models can be used to animate still images based on text descriptions. For example, you can provide a landscape photo and the instruction "A camera pan from left to right" to create a video with camera movement. + +### Dynamic Content Creation + +Transform images into video by adding motion, transformations, or effects described in text prompts. This is useful for creating engaging social media content, presentations, or marketing materials. + +### Guided Video Generation + +Use a reference image with text prompts to guide the video generation process. This provides more control over the visual style and composition compared to text-to-video models alone. + +### Story Visualization + +Create video sequences from storyboards or concept art by providing scene descriptions. This can help filmmakers and animators visualize scenes before production. + +### Motion Control + +Generate videos with specific camera movements, object motions, or scene transitions by combining reference images with detailed motion descriptions. + +## Task Variants + +### Image-to-Video with Motion Control + +Models that generate videos from images while following specific motion instructions, such as camera movements, object animations, or scene dynamics. + +### Reference-guided Video Generation + +Models that use a reference image to guide the visual style and composition of the generated video while incorporating text prompts for motion and transformation control. + +### Conditional Video Synthesis + +Models that perform specific video transformations based on text conditions, such as adding weather effects, time-of-day changes, or environmental animations. + +## Inference + +You can use the Diffusers library to interact with image-text-to-video models. Here's an example snippet using `LTXImageToVideoPipeline`.
+ +```python +import torch +from diffusers import LTXImageToVideoPipeline +from diffusers.utils import export_to_video, load_image + +pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +image = load_image( + "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" +) +prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene." +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + +video = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + width=704, + height=480, + num_frames=161, + num_inference_steps=50, +).frames[0] +export_to_video(video, "output.mp4", fps=24) +``` + +## Useful Resources + +- [LTX-Video Model Card](https://huggingface.co/Lightricks/LTX-Video) +- [Text-to-Video: The Task, Challenges and the Current State](https://huggingface.co/blog/text-to-video) +- [Diffusers documentation on Video Generation](https://huggingface.co/docs/diffusers/using-diffusers/text-img2vid) diff --git a/packages/tasks/src/tasks/image-text-to-video/data.ts b/packages/tasks/src/tasks/image-text-to-video/data.ts new file mode 100644 index 0000000000..ded013562d --- /dev/null +++ b/packages/tasks/src/tasks/image-text-to-video/data.ts @@ -0,0 +1,54 @@ +import type { TaskDataCustom } from "../index.js"; + +const taskData: TaskDataCustom = { + datasets: [], + demo: { + inputs: [ + { + filename: "image-text-to-video-input.jpg", + type: "img", + }, + { + label: "Input", + content: "Darth Vader is surfing on the waves.", + type: "text", + }, + ], + outputs: [ + { + filename: "image-text-to-video-output.gif", + type: "img", + }, + ], + }, + metrics: [ + { + description: + "Frechet Video Distance uses a model that captures coherence for changes in frames and the quality of each frame. A smaller score indicates better video generation.", + id: "fvd", + }, + { + description: + "CLIPSIM measures similarity between video frames and text using an image-text similarity model. A higher score indicates better video generation.", + id: "clipsim", + }, + ], + models: [ + { + description: "A powerful model for image-text-to-video generation.", + id: "Lightricks/LTX-Video", + }, + ], + spaces: [ + { + description: "An application for image-text-to-video generation.", + id: "Lightricks/ltx-video-distilled", + }, + ], + summary: + "Image-text-to-video models take a reference image and text instructions as input and generate a video based on them.
These models are useful for animating still images, creating dynamic content from static references, and generating videos with specific motion or transformation guidance.", + widgetModels: ["Lightricks/LTX-Video"], + youtubeId: undefined, +}; + +export default taskData; diff --git a/packages/tasks/src/tasks/image-text-to-video/inference.ts b/packages/tasks/src/tasks/image-text-to-video/inference.ts new file mode 100644 index 0000000000..ad048a7bc9 --- /dev/null +++ b/packages/tasks/src/tasks/image-text-to-video/inference.ts @@ -0,0 +1,77 @@ +/** + * Inference code generated from the JSON schema spec in ./spec + * + * Using src/scripts/inference-codegen + */ +/** + * Inputs for Image Text To Video inference. Either inputs (image) or prompt (in parameters) + * must be provided, or both. + */ +export interface ImageTextToVideoInput { + /** + * The input image data as a base64-encoded string. If no `parameters` are provided, you can + * also provide the image data as a raw bytes payload. Either this or prompt must be + * provided. + */ + inputs?: Blob; + /** + * Additional inference parameters for Image Text To Video + */ + parameters?: ImageTextToVideoParameters; + [property: string]: unknown; +} +/** + * Additional inference parameters for Image Text To Video + */ +export interface ImageTextToVideoParameters { + /** + * For diffusion models. A higher guidance scale value encourages the model to generate + * videos closely linked to the text prompt at the expense of lower image quality. + */ + guidance_scale?: number; + /** + * One prompt to guide what NOT to include in video generation. + */ + negative_prompt?: string; + /** + * The num_frames parameter determines how many video frames are generated. + */ + num_frames?: number; + /** + * The number of denoising steps. More denoising steps usually lead to a higher quality + * video at the expense of slower inference. + */ + num_inference_steps?: number; + /** + * The text prompt to guide the video generation. Either this or inputs (image) must be + * provided. + */ + prompt?: string; + /** + * Seed for the random number generator. + */ + seed?: number; + /** + * The size in pixel of the output video frames. + */ + target_size?: TargetSize; + [property: string]: unknown; +} +/** + * The size in pixel of the output video frames. + */ +export interface TargetSize { + height: number; + width: number; + [property: string]: unknown; +} +/** + * Outputs of inference for the Image Text To Video task + */ +export interface ImageTextToVideoOutput { + /** + * The generated video returned as raw bytes in the payload. + */ + video: unknown; + [property: string]: unknown; +} diff --git a/packages/tasks/src/tasks/image-text-to-video/spec/input.json b/packages/tasks/src/tasks/image-text-to-video/spec/input.json new file mode 100644 index 0000000000..3874073e48 --- /dev/null +++ b/packages/tasks/src/tasks/image-text-to-video/spec/input.json @@ -0,0 +1,63 @@ +{ + "$id": "/inference/schemas/image-text-to-video/input.json", + "$schema": "http://json-schema.org/draft-06/schema#", + "description": "Inputs for Image Text To Video inference. Either inputs (image) or prompt (in parameters) must be provided, or both.", + "title": "ImageTextToVideoInput", + "type": "object", + "properties": { + "inputs": { + "type": "string", + "description": "The input image data as a base64-encoded string. If no `parameters` are provided, you can also provide the image data as a raw bytes payload. 
Either this or prompt must be provided.", + "comment": "type=binary" + }, + "parameters": { + "description": "Additional inference parameters for Image Text To Video", + "$ref": "#/$defs/ImageTextToVideoParameters" + } + }, + "$defs": { + "ImageTextToVideoParameters": { + "title": "ImageTextToVideoParameters", + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The text prompt to guide the video generation. Either this or inputs (image) must be provided." + }, + "guidance_scale": { + "type": "number", + "description": "For diffusion models. A higher guidance scale value encourages the model to generate videos closely linked to the text prompt at the expense of lower image quality." + }, + "negative_prompt": { + "type": "string", + "description": "One prompt to guide what NOT to include in video generation." + }, + "num_inference_steps": { + "type": "integer", + "description": "The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference." + }, + "num_frames": { + "type": "number", + "description": "The num_frames parameter determines how many video frames are generated." + }, + "target_size": { + "type": "object", + "description": "The size in pixel of the output video frames.", + "properties": { + "width": { + "type": "integer" + }, + "height": { + "type": "integer" + } + }, + "required": ["width", "height"] + }, + "seed": { + "type": "integer", + "description": "Seed for the random number generator." + } + } + } + } +} diff --git a/packages/tasks/src/tasks/image-text-to-video/spec/output.json b/packages/tasks/src/tasks/image-text-to-video/spec/output.json new file mode 100644 index 0000000000..26a08e9be5 --- /dev/null +++ b/packages/tasks/src/tasks/image-text-to-video/spec/output.json @@ -0,0 +1,13 @@ +{ + "$id": "/inference/schemas/image-text-to-video/output.json", + "$schema": "http://json-schema.org/draft-06/schema#", + "description": "Outputs of inference for the Image Text To Video task", + "title": "ImageTextToVideoOutput", + "type": "object", + "properties": { + "video": { + "description": "The generated video returned as raw bytes in the payload." 
+ } + }, + "required": ["video"] +} diff --git a/packages/tasks/src/tasks/index.ts b/packages/tasks/src/tasks/index.ts index 9f33085d60..5d1c0284bc 100644 --- a/packages/tasks/src/tasks/index.ts +++ b/packages/tasks/src/tasks/index.ts @@ -14,6 +14,8 @@ import imageFeatureExtraction from "./image-feature-extraction/data.js"; import imageToImage from "./image-to-image/data.js"; import imageToText from "./image-to-text/data.js"; import imageTextToText from "./image-text-to-text/data.js"; +import imageTextToImage from "./image-text-to-image/data.js"; +import imageTextToVideo from "./image-text-to-video/data.js"; import imageSegmentation from "./image-segmentation/data.js"; import imageToVideo from "./image-to-video/data.js"; import maskGeneration from "./mask-generation/data.js"; @@ -74,6 +76,16 @@ export type * from "./image-to-image/inference.js"; export type { ImageToTextInput, ImageToTextOutput, ImageToTextParameters } from "./image-to-text/inference.js"; export type * from "./image-segmentation/inference.js"; export type { ImageToVideoInput, ImageToVideoOutput, ImageToVideoParameters } from "./image-to-video/inference.js"; +export type { + ImageTextToImageInput, + ImageTextToImageOutput, + ImageTextToImageParameters, +} from "./image-text-to-image/inference.js"; +export type { + ImageTextToVideoInput, + ImageTextToVideoOutput, + ImageTextToVideoParameters, +} from "./image-text-to-video/inference.js"; export type * from "./object-detection/inference.js"; export type * from "./depth-estimation/inference.js"; export type * from "./question-answering/inference.js"; @@ -133,6 +145,8 @@ export const TASKS_MODEL_LIBRARIES: Record<PipelineType, ModelLibraryKey[]> = { "image-feature-extraction": ["timm", "transformers"], "image-segmentation": ["transformers", "transformers.js"], "image-text-to-text": ["transformers"], + "image-text-to-image": ["diffusers"], + "image-text-to-video": ["diffusers"], "image-to-image": ["diffusers", "transformers", "transformers.js"], "image-to-text": ["transformers", "transformers.js"], "image-to-video": ["diffusers"], @@ -220,6 +234,8 @@ export const TASKS_DATA: Record<PipelineType, TaskData | undefined> = { "image-segmentation": getData("image-segmentation", imageSegmentation), "image-to-image": getData("image-to-image", imageToImage), "image-text-to-text": getData("image-text-to-text", imageTextToText), + "image-text-to-image": getData("image-text-to-image", imageTextToImage), + "image-text-to-video": getData("image-text-to-video", imageTextToVideo), "image-to-text": getData("image-to-text", imageToText), "image-to-video": getData("image-to-video", imageToVideo), "keypoint-detection": getData("keypoint-detection", keypointDetection),
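To close out, a matching sketch for the new `imageTextToVideo` task function added earlier in this diff. As with the image example, this is a hedged illustration: the root re-export, the token, the model ID, and provider support for the new task are assumptions for illustration rather than something the diff guarantees.

```ts
// Hypothetical usage sketch for the new image-text-to-video task (not part of this PR).
import { writeFile } from "node:fs/promises";
import { imageTextToVideo } from "@huggingface/inference"; // assumes the root index re-exports ./tasks/index.js

const frame = await (await fetch("https://example.com/cat.png")).blob(); // placeholder reference image
const video = await imageTextToVideo({
	accessToken: "hf_xxx", // your Hugging Face token (placeholder)
	model: "Lightricks/LTX-Video", // recommended model per the task docs in this PR; provider support is an assumption
	inputs: frame, // reference image as a Blob, matching ImageTextToVideoInput
	parameters: { prompt: "The cat starts to dance" },
});
await writeFile("output.mp4", new Uint8Array(await video.arrayBuffer())); // the task resolves to a Blob of video bytes
```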