Skip to content
4 changes: 4 additions & 0 deletions packages/inference/src/providers/fal-ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,8 @@ export const FAL_AI_SUPPORTED_MODEL_IDS: ProviderMapping<FalAiId> = {
"automatic-speech-recognition": {
"openai/whisper-large-v3": "fal-ai/whisper",
},
"text-to-video": {
"genmo/mochi-1-preview": "fal-ai/mochi-v1",
"tencent/HunyuanVideo": "fal-ai/hunyuan-video",
},
};
3 changes: 3 additions & 0 deletions packages/inference/src/providers/replicate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,7 @@ export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping<ReplicateId> = {
"text-to-speech": {
"OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26",
},
"text-to-video": {
"genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",
},
};
20 changes: 11 additions & 9 deletions packages/inference/src/tasks/cv/textToImage.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { omit } from "../../utils/omit";
import { request } from "../custom/request";

export type TextToImageArgs = BaseArgs & {
Expand Down Expand Up @@ -57,15 +58,16 @@ interface OutputUrlImageGeneration {
* Recommended model: stabilityai/stable-diffusion-2
*/
export async function textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageOutput> {
if (args.provider === "together" || args.provider === "fal-ai") {
args.prompt = args.inputs;
delete (args as unknown as { inputs: unknown }).inputs;
args.response_format = "base64";
} else if (args.provider === "replicate") {
args.prompt = args.inputs;
delete (args as unknown as { inputs: unknown }).inputs;
}
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(args, {
const payload =
args.provider === "together" || args.provider === "fal-ai" || args.provider === "replicate"
? {
...omit(args, ["inputs", "parameters"]),
...args.parameters,
...(args.provider !== "replicate" ? { response_format: "base64" } : undefined),
prompt: args.inputs,
}
: args;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, this is probably better 👍

const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(payload, {
...options,
taskHint: "text-to-image",
});
Expand Down
67 changes: 67 additions & 0 deletions packages/inference/src/tasks/cv/textToVideo.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import type { BaseArgs, InferenceProvider, Options } from "../../types";
import type { TextToVideoInput } from "@huggingface/tasks";
import { request } from "../custom/request";
import { omit } from "../../utils/omit";
import { isUrl } from "../../lib/isUrl";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { typedInclude } from "../../utils/typedInclude";

/** Arguments for a text-to-video call: common BaseArgs (model, provider, accessToken, …) plus the task's input/parameters. */
export type TextToVideoArgs = BaseArgs & TextToVideoInput;

/** The generated video, returned as a Blob fetched from the provider's result URL. */
export type TextToVideoOutput = Blob;

// Response shape returned by the fal-ai provider: the video is hosted at `video.url`.
interface FalAiOutput {
video: {
url: string;
};
}

// Response shape returned by the Replicate provider: `output` is the video URL.
// NOTE(review): Replicate may return a null/async `output` on timeout — see TODO in textToVideo.
interface ReplicateOutput {
output: string;
}

// Providers that currently expose a text-to-video endpoint through this client.
const SUPPORTED_PROVIDERS = ["fal-ai", "replicate"] as const satisfies readonly InferenceProvider[];

/**
 * Generate a video from a text prompt.
 *
 * Only the providers listed in SUPPORTED_PROVIDERS offer a text-to-video
 * endpoint through this client; any other (or missing) provider is rejected.
 *
 * @param args - Model/provider/credentials plus `inputs` (the text prompt) and
 *   optional provider-specific `parameters` (e.g. seed, num_frames).
 * @param options - Extra request options forwarded to `request`.
 * @returns The generated video as a Blob, downloaded from the URL the provider returns.
 * @throws Error if the provider is missing or unsupported.
 * @throws InferenceOutputError if the provider response does not match the expected shape.
 */
export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
	if (!args.provider || !typedInclude(SUPPORTED_PROVIDERS, args.provider)) {
		throw new Error(
			`textToVideo inference is only supported for the following providers: ${SUPPORTED_PROVIDERS.join(", ")}`
		);
	}

	// Both supported providers take a flat `{ prompt, ...parameters }` payload
	// instead of the HF-style `{ inputs, parameters }` envelope. The guard above
	// already narrows `args.provider` to a supported value, so the previous
	// `provider === "fal-ai" || provider === "replicate" ? … : args` ternary was
	// dead code — its `: args` branch was unreachable.
	const payload = { ...omit(args, ["inputs", "parameters"]), ...args.parameters, prompt: args.inputs };
	const res = await request<FalAiOutput | ReplicateOutput>(payload, {
		...options,
		taskHint: "text-to-video",
	});

	if (args.provider === "fal-ai") {
		// fal-ai responds with { video: { url } }; validate before fetching the file.
		const isValidOutput =
			typeof res === "object" &&
			!!res &&
			"video" in res &&
			typeof res.video === "object" &&
			!!res.video &&
			"url" in res.video &&
			typeof res.video.url === "string" &&
			isUrl(res.video.url);
		if (!isValidOutput) {
			throw new InferenceOutputError("Expected { video: { url: string } }");
		}
		const urlResponse = await fetch(res.video.url);
		return await urlResponse.blob();
	} else {
		/// TODO: Replicate: handle the case where the generation request "times out" / is async (ie output is null)
		/// https://replicate.com/docs/topics/predictions/create-a-prediction
		const isValidOutput =
			typeof res === "object" && !!res && "output" in res && typeof res.output === "string" && isUrl(res.output);
		if (!isValidOutput) {
			throw new InferenceOutputError("Expected { output: string }");
		}
		const urlResponse = await fetch(res.output);
		return await urlResponse.blob();
	}
}
2 changes: 1 addition & 1 deletion packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export interface BaseArgs {
}

/**
 * Arguments accepted by the low-level `request` helper: the common BaseArgs
 * plus one input-payload variant — binary `data`, HF-style `inputs`, a
 * provider-style `prompt` (used e.g. by text-to-image/text-to-video
 * providers), or a chat-completion body — with optional extra `parameters`.
 */
export type RequestArgs = BaseArgs &
({ data: Blob | ArrayBuffer } | { inputs: unknown } | { prompt: string } | ChatCompletionInput) & {
parameters?: Record<string, unknown>;
accessToken?: string;
};
48 changes: 47 additions & 1 deletion packages/inference/test/HfInference.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
import { chatCompletion, FAL_AI_SUPPORTED_MODEL_IDS, HfInference } from "../src";
import "./vcr";
import { readTestFile } from "./test-files";
import { textToVideo } from "../src/tasks/cv/textToVideo";

const TIMEOUT = 60000 * 3;
const env = import.meta.env;
Expand Down Expand Up @@ -47,7 +48,7 @@ describe.concurrent("HfInference", () => {
);
});

it("works without model", async () => {
it.skip("works without model", async () => {
expect(
await hf.fillMask({
inputs: "[MASK] world!",
Expand Down Expand Up @@ -799,6 +800,35 @@ describe.concurrent("HfInference", () => {
});
});
}

it("textToVideo - genmo/mochi-1-preview", async () => {
	// Fixed seed so the recorded (VCR) request/response stays stable.
	const video = await textToVideo({
		provider: "fal-ai",
		accessToken: env.HF_FAL_KEY,
		model: "genmo/mochi-1-preview",
		inputs: "A running dog",
		parameters: { seed: 176 },
	});
	expect(video).toBeInstanceOf(Blob);
});

it("textToVideo - HunyuanVideo", async () => {
	const res = await textToVideo({
		// Fix: the test is titled HunyuanVideo and passes HunyuanVideo-specific
		// parameters (resolution, num_frames), but the model id was copy-pasted
		// from the Mochi test above. Use the mapped HunyuanVideo model instead.
		model: "tencent/HunyuanVideo",
		inputs: "A running dog",
		parameters: {
			seed: 176,
			num_inference_steps: 2,
			num_frames: 85,
			resolution: "480p",
		},
		provider: "fal-ai",
		accessToken: env.HF_FAL_KEY,
	});
	expect(res).toBeInstanceOf(Blob);
});
},
TIMEOUT
);
Expand Down Expand Up @@ -844,6 +874,22 @@ describe.concurrent("HfInference", () => {

expect(res).toBeInstanceOf(Blob);
});

it("textToVideo Mochi", async () => {
	// Short generation (few steps/frames) to keep the recorded request cheap.
	const generationArgs = {
		accessToken: env.HF_REPLICATE_KEY,
		model: "genmo/mochi-1-preview",
		provider: "replicate",
		inputs: "A running dog",
		parameters: {
			num_inference_steps: 10,
			seed: 178,
			num_frames: 30,
		},
	} as const;
	const res = await textToVideo(generationArgs);
	expect(res).toBeInstanceOf(Blob);
});
},
TIMEOUT
);
Expand Down
Loading
Loading