Skip to content

Commit b193fb1

Browse files
committed
question answering test + inference code
1 parent ef5693c commit b193fb1

File tree

3 files changed

+30
-24
lines changed

3 files changed

+30
-24
lines changed

packages/inference/src/tasks/nlp/questionAnswering.ts

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,28 @@ export async function questionAnswering(
1212
args: QuestionAnsweringArgs,
1313
options?: Options
1414
): Promise<QuestionAnsweringOutput> {
15-
const res = await request<QuestionAnsweringOutput>(args, {
15+
const res = await request<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(args, {
1616
...options,
1717
taskHint: "question-answering",
1818
});
1919
const isValidOutput =
20-
Array.isArray(res) &&
21-
res.every(
22-
(elem) =>
23-
typeof elem === "object" &&
24-
!!elem &&
25-
typeof elem.answer === "string" &&
26-
typeof elem.end === "number" &&
27-
typeof elem.score === "number" &&
28-
typeof elem.start === "number"
29-
);
20+
Array.isArray(res) ?
21+
res.every(
22+
(elem) =>
23+
typeof elem === "object" &&
24+
!!elem &&
25+
typeof elem.answer === "string" &&
26+
typeof elem.end === "number" &&
27+
typeof elem.score === "number" &&
28+
typeof elem.start === "number"
29+
) : (typeof res === "object" &&
30+
!!res &&
31+
typeof res.answer === "string" &&
32+
typeof res.end === "number" &&
33+
typeof res.score === "number" &&
34+
typeof res.start === "number");
3035
if (!isValidOutput) {
31-
throw new InferenceOutputError("Expected {answer: string, end: number, score: number, start: number}");
36+
throw new InferenceOutputError("Expected Array<{answer: string, end: number, score: number, start: number}>");
3237
}
33-
return res;
38+
return Array.isArray(res) ? res : [res];
3439
}

packages/inference/test/HfInference.spec.ts

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,20 @@ describe.concurrent("HfInference", () => {
8080
});
8181

8282
it("questionAnswering", async () => {
83-
expect(
84-
await hf.questionAnswering({
85-
model: "deepset/roberta-base-squad2",
86-
inputs: {
87-
question: "What is the capital of France?",
88-
context: "The capital of France is Paris.",
89-
},
90-
})
91-
).toMatchObject({
83+
const res = await hf.questionAnswering({
84+
model: "deepset/roberta-base-squad2",
85+
inputs: {
86+
question: "What is the capital of France?",
87+
context: "The capital of France is Paris.",
88+
},
89+
});
90+
91+
expect(res).toMatchObject([{
9292
answer: "Paris",
9393
score: expect.any(Number),
9494
start: expect.any(Number),
9595
end: expect.any(Number),
96-
});
96+
}]);
9797
});
9898

9999
it("tableQuestionAnswering", async () => {

packages/tasks/src/tasks/visual-question-answering/spec/input.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
"comment": "type=binary"
1616
},
1717
"question": {
18-
"description": "The question to answer based on the image."
18+
"description": "The question to answer based on the image.",
19+
"type": "string"
1920
}
2021
},
2122
"required": ["question", "image"]

0 commit comments

Comments
 (0)