Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Currently, we support the following providers:
- [Novita](https://novita.ai)
- [Nscale](https://nscale.com)
- [OVHcloud](https://endpoints.ai.cloud.ovh.net/)
- [Public AI](https://publicai.co)
- [Replicate](https://replicate.com)
- [Sambanova](https://sambanova.ai)
- [Scaleway](https://www.scaleway.com/en/generative-apis/)
Expand Down
4 changes: 4 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import * as Novita from "../providers/novita.js";
import * as Nscale from "../providers/nscale.js";
import * as OpenAI from "../providers/openai.js";
import * as OvhCloud from "../providers/ovhcloud.js";
import * as PublicAI from "../providers/publicai.js";
import type {
AudioClassificationTaskHelper,
AudioToAudioTaskHelper,
Expand Down Expand Up @@ -139,6 +140,9 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
conversational: new OvhCloud.OvhCloudConversationalTask(),
"text-generation": new OvhCloud.OvhCloudTextGenerationTask(),
},
publicai: {
conversational: new PublicAI.PublicAIConversationalTask(),
},
replicate: {
"text-to-image": new Replicate.ReplicateTextToImageTask(),
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
novita: {},
nscale: {},
openai: {},
publicai: {},
ovhcloud: {},
replicate: {},
sambanova: {},
Expand Down
7 changes: 7 additions & 0 deletions packages/inference/src/providers/publicai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import { BaseConversationalTask } from "./providerHelper.js";

/** Base URL of the Public AI inference endpoint. */
const PUBLICAI_API_BASE_URL = "https://api.publicai.co";

/**
 * Provider helper for conversational (chat-completion) requests routed to
 * Public AI. All request/response handling is inherited from
 * {@link BaseConversationalTask}; this class only binds the provider name
 * and its API base URL.
 */
export class PublicAIConversationalTask extends BaseConversationalTask {
	constructor() {
		super("publicai", PUBLICAI_API_BASE_URL);
	}
}
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ export const INFERENCE_PROVIDERS = [
"nscale",
"openai",
"ovhcloud",
"publicai",
"replicate",
"sambanova",
"scaleway",
Expand Down
53 changes: 53 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2238,4 +2238,57 @@ describe.skip("InferenceClient", () => {
},
TIMEOUT
);

describe.concurrent(
	"PublicAI",
	() => {
		// HF_PUBLICAI_KEY gates live calls; "dummy" keeps client construction working offline.
		const client = new InferenceClient(env.HF_PUBLICAI_KEY ?? "dummy");

		// Map the HF model id to the id Public AI expects for this provider.
		HARDCODED_MODEL_INFERENCE_MAPPING["publicai"] = {
			"swiss-ai/Apertus-8B-Instruct-2509": {
				provider: "publicai",
				hfModelId: "swiss-ai/Apertus-8B-Instruct-2509",
				providerId: "swiss-ai/apertus-8b-instruct",
				status: "live",
				task: "conversational",
			},
		};

		it("chatCompletion", async () => {
			const res = await client.chatCompletion({
				model: "swiss-ai/Apertus-8B-Instruct-2509",
				provider: "publicai",
				messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
			});
			// Assert unconditionally: the previous `if (res.choices && ...)` guard let
			// the test pass vacuously when the provider returned no choices at all.
			expect(res.choices?.length ?? 0).toBeGreaterThan(0);
			const completion = res.choices?.[0]?.message?.content;
			expect(completion).toContain("two");
		});

		it("chatCompletion stream", async () => {
			const stream = client.chatCompletionStream({
				model: "swiss-ai/Apertus-8B-Instruct-2509",
				provider: "publicai",
				messages: [{ role: "user", content: "Say 'this is a test'" }],
				stream: true,
			}) as AsyncGenerator<ChatCompletionStreamOutput>;

			// Accumulate all streamed delta content.
			let fullResponse = "";
			for await (const chunk of stream) {
				if (chunk.choices && chunk.choices.length > 0) {
					const content = chunk.choices[0].delta?.content;
					if (content) {
						fullResponse += content;
					}
				}
			}

			// Verify we got a meaningful, non-empty streamed response.
			expect(fullResponse).toBeTruthy();
			expect(fullResponse.length).toBeGreaterThan(0);
		});
	},
	TIMEOUT
);
});
Loading