Commit 699140e

committed: wip

1 parent b193fb1 commit 699140e

File tree: 10 files changed, +149 -108 lines changed

10 files changed

+149
-108
lines changed

packages/inference/src/tasks/audio/textToSpeech.ts

Lines changed: 1 addition & 4 deletions

@@ -1,13 +1,10 @@
-import type { TextToSpeechInput } from "@huggingface/tasks";
+import type { TextToSpeechInput, TextToSpeechOutput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";
 
 type TextToSpeechArgs = BaseArgs & TextToSpeechInput;
 
-interface TextToSpeechOutput {
-	audio: Blob;
-}
 /**
  * This task synthesize an audio of a voice pronouncing a given text.
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
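
With the local interface removed, callers get the shared TextToSpeechOutput type from @huggingface/tasks instead. A minimal usage sketch under that assumption (the token is a placeholder):

import { HfInference } from "@huggingface/inference";
import type { TextToSpeechOutput } from "@huggingface/tasks";

const hf = new HfInference("hf_xxx"); // placeholder token
const out: TextToSpeechOutput = await hf.textToSpeech({
	model: "espnet/kan-bayashi_ljspeech_vits",
	inputs: "hello there!",
});
// `audio` is a Blob on the shared type, the same shape the old local interface promised
console.log(out.audio.type, out.audio.size);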

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 0 additions & 1 deletion

@@ -31,7 +31,6 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
 		...options,
 		taskHint: "text-to-image",
 	});
-	console.log(res);
 	if (res && typeof res === "object") {
 		if (args.provider === "fal-ai" && "images" in res && Array.isArray(res.images) && res.images[0].url) {
 			const image = await fetch(res.images[0].url);

packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts

Lines changed: 23 additions & 9 deletions

@@ -4,9 +4,15 @@ import { request } from "../custom/request";
 import type { RequestArgs } from "../../types";
 import { toArray } from "../../utils/toArray";
 import { base64FromBytes } from "../../utils/base64FromBytes";
-import type { DocumentQuestionAnsweringInput, DocumentQuestionAnsweringOutput } from "@huggingface/tasks";
+import type {
+	DocumentQuestionAnsweringInput,
+	DocumentQuestionAnsweringInputData,
+	DocumentQuestionAnsweringOutput,
+} from "@huggingface/tasks";
 
-export type DocumentQuestionAnsweringArgs = BaseArgs & DocumentQuestionAnsweringInput;
+/// Override the type to properly set inputs.image as Blob
+export type DocumentQuestionAnsweringArgs = BaseArgs &
+	DocumentQuestionAnsweringInput & { inputs: DocumentQuestionAnsweringInputData & { image: Blob } };
 
 /**
  * Answers a question on a document image. Recommended model: impira/layoutlm-document-qa.
@@ -20,22 +26,30 @@ export async function documentQuestionAnswering(
 		inputs: {
 			question: args.inputs.question,
 			// convert Blob or ArrayBuffer to base64
-			image: base64FromBytes(new Uint8Array(await args.inputs.arrayBuffer())),
+			image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer())),
 		},
 	} as RequestArgs;
 	const res = toArray(
-		await request<[DocumentQuestionAnsweringOutput] | DocumentQuestionAnsweringOutput>(reqArgs, {
+		await request<DocumentQuestionAnsweringOutput | DocumentQuestionAnsweringOutput[number]>(reqArgs, {
 			...options,
 			taskHint: "document-question-answering",
 		})
-	)?.[0];
+	);
+
 	const isValidOutput =
-		typeof res?.answer === "string" &&
-		(typeof res.end === "number" || typeof res.end === "undefined") &&
-		(typeof res.score === "number" || typeof res.score === "undefined") &&
-		(typeof res.start === "number" || typeof res.start === "undefined");
+		Array.isArray(res) &&
+		res.every(
+			(elem) =>
+				typeof elem === "object" &&
+				!!elem &&
+				typeof elem?.answer === "string" &&
+				(typeof elem.end === "number" || typeof elem.end === "undefined") &&
+				(typeof elem.score === "number" || typeof elem.score === "undefined") &&
+				(typeof elem.start === "number" || typeof elem.start === "undefined")
+		);
 	if (!isValidOutput) {
 		throw new InferenceOutputError("Expected Array<{answer: string, end?: number, score?: number, start?: number}>");
 	}
+
 	return res;
 }
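
The task now requires inputs.image to be a Blob and resolves to the full answer array instead of unwrapping the first element. A sketch of the new call shape, mirroring the updated test further down (token and image bytes are placeholders):

import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // placeholder token
const imageBytes = new Uint8Array([/* png bytes */]); // placeholder image data

const answers = await hf.documentQuestionAnswering({
	model: "impira/layoutlm-document-qa",
	inputs: {
		question: "Invoice number?",
		image: new Blob([imageBytes], { type: "image/png" }),
	},
});
// One entry per candidate answer: { answer: string, end?: number, score?: number, start?: number }
for (const { answer, score } of answers) {
	console.log(answer, score);
}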

packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts

Lines changed: 18 additions & 12 deletions

@@ -1,10 +1,16 @@
-import type { VisualQuestionAnsweringInput, VisualQuestionAnsweringOutput } from "@huggingface/tasks";
+import type {
+	VisualQuestionAnsweringInput,
+	VisualQuestionAnsweringInputData,
+	VisualQuestionAnsweringOutput,
+} from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options, RequestArgs } from "../../types";
 import { base64FromBytes } from "../../utils/base64FromBytes";
 import { request } from "../custom/request";
 
-export type VisualQuestionAnsweringArgs = BaseArgs & VisualQuestionAnsweringInput;
+/// Override the type to properly set inputs.image as Blob
+export type VisualQuestionAnsweringArgs = BaseArgs &
+	VisualQuestionAnsweringInput & { inputs: VisualQuestionAnsweringInputData & { image: Blob } };
 
 /**
  * Answers a question on an image. Recommended model: dandelin/vilt-b32-finetuned-vqa.
@@ -18,18 +24,18 @@ export async function visualQuestionAnswering(
 		inputs: {
 			question: args.inputs.question,
 			// convert Blob or ArrayBuffer to base64
-			image: base64FromBytes(
-				new Uint8Array(args.inputs.image instanceof ArrayBuffer ? args.inputs.image : await args.inputs.arrayBuffer())
-			),
+			image: base64FromBytes(new Uint8Array(await args.inputs.image.arrayBuffer())),
 		},
 	} as RequestArgs;
-	const res = (
-		await request<[VisualQuestionAnsweringOutput]>(reqArgs, {
-			...options,
-			taskHint: "visual-question-answering",
-		})
-	)?.[0];
-	const isValidOutput = typeof res?.answer === "string" && typeof res.score === "number";
+	const res = await request<VisualQuestionAnsweringOutput>(reqArgs, {
+		...options,
+		taskHint: "visual-question-answering",
+	});
+	const isValidOutput =
+		Array.isArray(res) &&
+		res.every(
+			(elem) => typeof elem === "object" && !!elem && typeof elem?.answer === "string" && typeof elem.score === "number"
+		);
 	if (!isValidOutput) {
 		throw new InferenceOutputError("Expected Array<{answer: string, score: number}>");
 	}
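
Same pattern as document QA: the image must be a Blob, and the validated array is returned whole. A sketch based on the test fixture (token and image bytes are placeholders):

import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // placeholder token
const res = await hf.visualQuestionAnswering({
	model: "dandelin/vilt-b32-finetuned-vqa",
	inputs: {
		question: "How many cats are lying down?",
		image: new Blob([new Uint8Array([/* png bytes */])], { type: "image/png" }),
	},
});
// res is Array<{ answer: string, score: number }>
for (const { answer, score } of res) {
	console.log(answer, score);
}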

packages/inference/src/tasks/nlp/questionAnswering.ts

Lines changed: 9 additions & 9 deletions

@@ -16,22 +16,22 @@ export async function questionAnswering(
 		...options,
 		taskHint: "question-answering",
 	});
-	const isValidOutput =
-		Array.isArray(res) ?
-			res.every(
+	const isValidOutput = Array.isArray(res)
+		? res.every(
 				(elem) =>
 					typeof elem === "object" &&
 					!!elem &&
 					typeof elem.answer === "string" &&
 					typeof elem.end === "number" &&
 					typeof elem.score === "number" &&
 					typeof elem.start === "number"
-			) : (typeof res === "object" &&
-			!!res &&
-			typeof res.answer === "string" &&
-			typeof res.end === "number" &&
-			typeof res.score === "number" &&
-			typeof res.start === "number");
+		  )
+		: typeof res === "object" &&
+		  !!res &&
+		  typeof res.answer === "string" &&
+		  typeof res.end === "number" &&
+		  typeof res.score === "number" &&
+		  typeof res.start === "number";
 	if (!isValidOutput) {
 		throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
 	}
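
Because the validation accepts either a bare answer object or an array of them, call sites that want a uniform shape can normalize the result themselves. A hedged sketch (the model name is illustrative, not part of this diff):

import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // placeholder token
const res = await hf.questionAnswering({
	model: "deepset/roberta-base-squad2", // illustrative model
	inputs: {
		question: "What is the capital of France?",
		context: "The capital of France is Paris.",
	},
});
// Normalize the single-or-array union that the check above allows for
const answers = Array.isArray(res) ? res : [res];
console.log(answers[0].answer); // "Paris" in the test fixture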

packages/inference/src/tasks/nlp/tableQuestionAnswering.ts

Lines changed: 22 additions & 12 deletions

@@ -12,24 +12,34 @@ export async function tableQuestionAnswering(
 	args: TableQuestionAnsweringArgs,
 	options?: Options
 ): Promise<TableQuestionAnsweringOutput> {
-	const res = await request<TableQuestionAnsweringOutput>(args, {
+	const res = await request<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(args, {
 		...options,
 		taskHint: "table-question-answering",
 	});
-	const isValidOutput =
-		Array.isArray(res) &&
-		res.every((elem) => {
-			typeof elem?.aggregator === "string" &&
-				typeof elem.answer === "string" &&
-				Array.isArray(elem.cells) &&
-				elem.cells.every((x) => typeof x === "string") &&
-				Array.isArray(elem.coordinates) &&
-				elem.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
-		});
+	const isValidOutput = Array.isArray(res) ? res.every((elem) => validate(elem)) : validate(res);
 	if (!isValidOutput) {
 		throw new InferenceOutputError(
 			"Expected {aggregator: string, answer: string, cells: string[], coordinates: number[][]}"
 		);
 	}
-	return res;
+	return Array.isArray(res) ? res : [res];
+}
+
+function validate(elem: unknown): elem is TableQuestionAnsweringOutput[number] {
+	return (
+		typeof elem === "object" &&
+		!!elem &&
+		"aggregator" in elem &&
+		typeof elem.aggregator === "string" &&
+		"answer" in elem &&
+		typeof elem.answer === "string" &&
+		"cells" in elem &&
+		Array.isArray(elem.cells) &&
+		elem.cells.every((x: unknown): x is string => typeof x === "string") &&
+		"coordinates" in elem &&
+		Array.isArray(elem.coordinates) &&
+		elem.coordinates.every(
+			(coord: unknown): coord is number[] => Array.isArray(coord) && coord.every((x) => typeof x === "number")
+		)
+	);
 }
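
With validate acting as a type guard, a bare-object response is wrapped so the function always resolves to an array. A usage sketch (token, model name, and query are illustrative; the table data echoes the expected test output below):

import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // placeholder token
const out = await hf.tableQuestionAnswering({
	model: "google/tapas-base-finetuned-wtq", // illustrative model
	inputs: {
		query: "What is the average number of stars?", // illustrative query
		table: {
			Repository: ["Transformers", "Datasets"],
			Stars: ["36542", "4512"],
		},
	},
});
// Always an array after this change, even when the API returned a single object
console.log(out[0].answer); // e.g. "AVERAGE > 36542"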

packages/inference/test/HfInference.spec.ts

Lines changed: 48 additions & 42 deletions

@@ -1,6 +1,6 @@
 import { expect, it, describe, assert } from "vitest";
 
-import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
+import type { ChatCompletionStreamOutput, VisualQuestionAnsweringInput } from "@huggingface/tasks";
 
 import { chatCompletion, HfInference } from "../src";
 import "./vcr";
@@ -87,13 +87,14 @@ describe.concurrent("HfInference", () => {
 				context: "The capital of France is Paris.",
 			},
 		});
-
-		expect(res).toMatchObject([{
-			answer: "Paris",
-			score: expect.any(Number),
-			start: expect.any(Number),
-			end: expect.any(Number),
-		}]);
+		expect(res).toMatchObject([
+			{
+				answer: "Paris",
+				score: expect.any(Number),
+				start: expect.any(Number),
+				end: expect.any(Number),
+			},
+		]);
 	});
 
 	it("tableQuestionAnswering", async () => {
@@ -110,30 +111,31 @@ describe.concurrent("HfInference", () => {
 					},
 				},
 			})
-		).toMatchObject({
-			answer: "AVERAGE > 36542",
-			coordinates: [[0, 1]],
-			cells: ["36542"],
-			aggregator: "AVERAGE",
-		});
+		).toMatchObject([
+			{
+				answer: "AVERAGE > 36542",
+				coordinates: [[0, 1]],
+				cells: ["36542"],
+				aggregator: "AVERAGE",
+			},
+		]);
 	});
 
 	it("documentQuestionAnswering", async () => {
-		expect(
-			await hf.documentQuestionAnswering({
-				model: "impira/layoutlm-document-qa",
-				inputs: {
-					question: "Invoice number?",
-					image: new Blob([readTestFile("invoice.png")], { type: "image/png" }),
-				},
-			})
-		).toMatchObject({
-			answer: "us-001",
-			score: expect.any(Number),
-			// not sure what start/end refers to in this case
-			start: expect.any(Number),
-			end: expect.any(Number),
+		const res = await hf.documentQuestionAnswering({
+			model: "impira/layoutlm-document-qa",
+			inputs: {
+				question: "Invoice number?",
+				image: new Blob([readTestFile("invoice.png")], { type: "image/png" }),
+			},
 		});
+		expect(res).toBeInstanceOf(Array);
+		for (const elem of res) {
+			expect(elem).toMatchObject({
+				answer: expect.any(String),
+				score: expect.any(Number),
+			});
+		}
 	});
 
 	// Errors with "Error: If you are using a VisionEncoderDecoderModel, you must provide a feature extractor"
@@ -152,18 +154,20 @@ describe.concurrent("HfInference", () => {
 	});
 
 	it("visualQuestionAnswering", async () => {
-		expect(
-			await hf.visualQuestionAnswering({
-				model: "dandelin/vilt-b32-finetuned-vqa",
-				inputs: {
-					question: "How many cats are lying down?",
-					image: new Blob([readTestFile("cats.png")], { type: "image/png" }),
-				},
-			})
-		).toMatchObject({
-			answer: "2",
-			score: expect.any(Number),
-		});
+		const res = await hf.visualQuestionAnswering({
+			model: "dandelin/vilt-b32-finetuned-vqa",
+			inputs: {
+				question: "How many cats are lying down?",
+				image: new Blob([readTestFile("cats.png")], { type: "image/png" }),
+			},
+		} satisfies VisualQuestionAnsweringInput);
+		expect(res).toBeInstanceOf(Array);
+		for (const elem of res) {
+			expect(elem).toMatchObject({
+				answer: expect.any(String),
+				score: expect.any(Number),
+			});
+		}
 	});
 
 	it("textClassification", async () => {
@@ -451,7 +455,9 @@ describe.concurrent("HfInference", () => {
 				model: "espnet/kan-bayashi_ljspeech_vits",
 				inputs: "hello there!",
 			})
-		).toSatisfy((out) => typeof out === "object" && !!out && "image" in out && out.image instanceof Blob);
+		).toMatchObject({
+			audio: expect.any(Blob),
+		});
 	});
 
 	it("imageClassification", async () => {
@@ -473,7 +479,7 @@ describe.concurrent("HfInference", () => {
 	it("zeroShotImageClassification", async () => {
 		expect(
 			await hf.zeroShotImageClassification({
-				inputs: { image: new Blob([readTestFile("cheetah.png")], { type: "image/png" }) },
+				inputs: new Blob([readTestFile("cheetah.png")], { type: "image/png" }),
 				model: "openai/clip-vit-large-patch14-336",
 				parameters: {
 					candidate_labels: ["animal", "toy", "car"],

packages/tasks/src/tasks/text-to-speech/inference.ts

Lines changed: 14 additions & 16 deletions

@@ -1,3 +1,17 @@
+/**
+ * Outputs of inference for the Text To Speech task
+ */
+export interface TextToSpeechOutput {
+	/**
+	 * The generated audio
+	 */
+	audio: Blob;
+	/**
+	 * The sampling rate of the generated audio waveform.
+	 */
+	sampling_rate?: number;
+	[property: string]: unknown;
+}
 /**
  * Inference code generated from the JSON schema spec in ./spec
  *
@@ -117,19 +131,3 @@ export interface GenerationParameters {
  * Controls the stopping condition for beam-based methods.
  */
 export type EarlyStoppingUnion = boolean | "never";
-/**
- * Outputs for Text to Speech inference
- *
- * Outputs of inference for the Text To Audio task
- */
-export interface TextToSpeechOutput {
-	/**
-	 * The generated audio waveform.
-	 */
-	audio: unknown;
-	/**
-	 * The sampling rate of the generated audio waveform.
-	 */
-	sampling_rate: number;
-	[property: string]: unknown;
-}
packages/tasks/src/tasks/text-to-speech/spec/output.json

Lines changed: 13 additions & 2 deletions

@@ -1,7 +1,18 @@
 {
-	"$ref": "/inference/schemas/text-to-audio/output.json",
 	"$id": "/inference/schemas/text-to-speech/output.json",
 	"$schema": "http://json-schema.org/draft-06/schema#",
+	"description": "Outputs of inference for the Text To Speech task",
 	"title": "TextToSpeechOutput",
-	"description": "Outputs for Text to Speech inference"
+	"type": "object",
+	"properties": {
+		"audio": {
+			"description": "The generated audio",
+			"comment": "type=binary"
+		},
+		"sampling_rate": {
+			"type": "number",
+			"description": "The sampling rate of the generated audio waveform."
+		}
+	},
+	"required": ["audio"]
 }
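
Together with the regenerated interface in inference.ts above, the schema now declares its own object shape instead of $ref-ing the text-to-audio output. An illustrative value that satisfies the resulting type:

import type { TextToSpeechOutput } from "@huggingface/tasks";

// Illustrative only: real audio bytes come from the inference response
const sample: TextToSpeechOutput = {
	audio: new Blob([new Uint8Array([/* wav bytes */])], { type: "audio/wav" }),
	sampling_rate: 22050, // optional after this change
};
console.log(sample.audio.type);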
