1616 */
1717import { InferenceOutputError } from "../lib/InferenceOutputError.js" ;
1818import { isUrl } from "../lib/isUrl.js" ;
19+ import type { TextToVideoArgs } from "../tasks/index.js" ;
1920import type { BodyParams , UrlParams } from "../types.js" ;
21+ import { delay } from "../utils/delay.js" ;
2022import { omit } from "../utils/omit.js" ;
2123import {
2224 BaseConversationalTask ,
@@ -26,11 +28,11 @@ import {
2628} from "./providerHelper.js" ;
2729
2830const NOVITA_API_BASE_URL = "https://api.novita.ai" ;
/**
 * Body returned by Novita's async endpoint when a task is submitted.
 * It carries only the identifier used to poll for the task result.
 */
export interface NovitaAsyncAPIOutput {
	task_id: string;
}
35+
3436export class NovitaTextGenerationTask extends BaseTextGenerationTask {
3537 constructor ( ) {
3638 super ( "novita" , NOVITA_API_BASE_URL ) ;
@@ -50,38 +52,94 @@ export class NovitaConversationalTask extends BaseConversationalTask {
5052 return "/v3/openai/chat/completions" ;
5153 }
5254}
55+
/**
 * Text-to-video helper for Novita's asynchronous API.
 *
 * Submitting a request only yields a task id. `getResponse` then polls the
 * task-result endpoint until the task reports success or failure and
 * downloads the first generated video.
 */
export class NovitaTextToVideoTask extends TaskProviderHelper implements TextToVideoTaskHelper {
	constructor() {
		super("novita", NOVITA_API_BASE_URL);
	}

	/** Returns the async submission route for the requested model. */
	override makeRoute(params: UrlParams): string {
		return `/v3/async/${params.model}`;
	}

	/**
	 * Flattens `args.parameters` into the request body, renames
	 * `num_inference_steps` to `steps` and sends `inputs` as `prompt`.
	 */
	override preparePayload(params: BodyParams<TextToVideoArgs>): Record<string, unknown> {
		const { num_inference_steps, ...restParameters } = params.args.parameters ?? {};
		return {
			...omit(params.args, ["inputs", "parameters"]),
			...restParameters,
			// Only emit `steps` when a value was given, so an undefined value
			// cannot overwrite a `steps` key passed through from the caller.
			...(num_inference_steps !== undefined ? { steps: num_inference_steps } : {}),
			prompt: params.args.inputs,
		};
	}

	/**
	 * Polls the task-result endpoint until the task finishes, then downloads
	 * the first generated video.
	 *
	 * @param response - Submission response carrying the task id.
	 * @param url - URL the submission was sent to; its protocol and host are reused for polling.
	 * @param headers - Headers sent with every polling request.
	 * @returns The generated video as a Blob.
	 * @throws InferenceOutputError when `url`/`headers`/task id are missing, a
	 * polling request fails, the payload has an unexpected shape, the task
	 * reports failure, or the video download fails.
	 */
	override async getResponse(
		response: NovitaAsyncAPIOutput,
		url?: string,
		headers?: Record<string, string>
	): Promise<Blob> {
		if (!url || !headers) {
			throw new InferenceOutputError("URL and headers are required for text-to-video task");
		}
		const taskId = response.task_id;
		if (!taskId) {
			throw new InferenceOutputError("No task ID found in the response");
		}

		const parsedUrl = new URL(url);
		// Requests routed through the Hugging Face router need the provider prefix.
		const routerPrefix = parsedUrl.host === "router.huggingface.co" ? "/novita" : "";
		const baseUrl = `${parsedUrl.protocol}//${parsedUrl.host}${routerPrefix}`;
		// Encode the id so reserved characters cannot alter the query string.
		const resultUrl = `${baseUrl}/v3/async/task-result?task_id=${encodeURIComponent(taskId)}`;

		let status = "";
		let taskResult: unknown;

		// NOTE(review): polling is unbounded; a task that never reaches a terminal
		// status keeps this loop running. Consider a timeout once the provider's
		// expected completion time is known.
		while (status !== "TASK_STATUS_SUCCEED" && status !== "TASK_STATUS_FAILED") {
			await delay(500);
			const resultResponse = await fetch(resultUrl, { headers });
			if (!resultResponse.ok) {
				throw new InferenceOutputError("Failed to fetch task result");
			}
			// Only the JSON parse is guarded, so a shape error raised below is not
			// mislabelled as a parse failure.
			try {
				taskResult = await resultResponse.json();
			} catch {
				throw new InferenceOutputError("Failed to parse task result");
			}
			status = NovitaTextToVideoTask.extractStatus(taskResult);
		}

		if (status === "TASK_STATUS_FAILED") {
			throw new InferenceOutputError("Task failed");
		}

		const videoUrl = NovitaTextToVideoTask.extractVideoUrl(taskResult);
		const urlResponse = await fetch(videoUrl);
		if (!urlResponse.ok) {
			// Without this check an HTTP error body would be returned as the video.
			throw new InferenceOutputError(`Failed to download generated video (HTTP ${urlResponse.status})`);
		}
		return await urlResponse.blob();
	}

	/** Narrows an unknown value to a non-null object whose keys can be read. */
	private static isRecord(value: unknown): value is Record<string, unknown> {
		return typeof value === "object" && value !== null;
	}

	/** Reads `task.status` from a task-result payload, throwing when it is absent or not a string. */
	private static extractStatus(taskResult: unknown): string {
		if (
			NovitaTextToVideoTask.isRecord(taskResult) &&
			NovitaTextToVideoTask.isRecord(taskResult.task) &&
			typeof taskResult.task.status === "string"
		) {
			return taskResult.task.status;
		}
		throw new InferenceOutputError("Failed to get task status");
	}

	/** Reads `videos[0].video_url` from a task-result payload, throwing when it is missing or not a URL. */
	private static extractVideoUrl(taskResult: unknown): string {
		if (NovitaTextToVideoTask.isRecord(taskResult) && Array.isArray(taskResult.videos)) {
			const firstVideo: unknown = taskResult.videos[0];
			// Check the element is an object before reading from it; a primitive
			// or null entry must produce the shape error, not a TypeError.
			if (
				NovitaTextToVideoTask.isRecord(firstVideo) &&
				typeof firstVideo.video_url === "string" &&
				isUrl(firstVideo.video_url)
			) {
				return firstVideo.video_url;
			}
		}
		throw new InferenceOutputError("Expected { videos: [{ video_url: string }] }");
	}
}
0 commit comments