Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ Currently, we support the following providers:
- [Sambanova](https://sambanova.ai)
- [Scaleway](https://www.scaleway.com/en/generative-apis/)
- [Together](https://together.xyz)
- [Baseten](https://baseten.co)
- [Blackforestlabs](https://blackforestlabs.ai)
- [Cohere](https://cohere.com)
- [Cerebras](https://cerebras.ai/)
Expand Down Expand Up @@ -97,6 +98,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
- [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models)
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [Baseten supported models](https://huggingface.co/api/partners/baseten/models)
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
- [Groq supported models](https://console.groq.com/docs/models)
Expand Down
4 changes: 4 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import * as Baseten from "../providers/baseten.js";
import * as BlackForestLabs from "../providers/black-forest-labs.js";
import * as Cerebras from "../providers/cerebras.js";
import * as Cohere from "../providers/cohere.js";
Expand Down Expand Up @@ -55,6 +56,9 @@ import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from
import { InferenceClientInputError } from "../errors.js";

export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
baseten: {
conversational: new Baseten.BasetenConversationalTask(),
},
"black-forest-labs": {
"text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(),
},
Expand Down
27 changes: 27 additions & 0 deletions packages/inference/src/providers/baseten.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/**
* See the registered mapping of HF model ID => Baseten model ID here:
*
* https://huggingface.co/api/partners/baseten/models
*
* This is a publicly available mapping.
*
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
*
* - If you work at Baseten and want to update this mapping, please use the model mapping API we provide on huggingface.co
* - If you're a community member and want to add a new supported HF model to Baseten, please open an issue on the present repo
* and we will tag Baseten team members.
*
* Thanks!
*/
import {
BaseConversationalTask,
} from "./providerHelper.js";

// Root URL of Baseten's hosted inference endpoint; all requests are built relative to this.
const BASETEN_API_BASE_URL = "https://inference.baseten.co";

/**
 * Conversational (chat-completion) task helper for the Baseten provider.
 *
 * All request building and response parsing is inherited unchanged from
 * `BaseConversationalTask`; this subclass only wires up the provider name
 * and API base URL.
 */
export class BasetenConversationalTask extends BaseConversationalTask {
	constructor() {
		super("baseten", BASETEN_API_BASE_URL);
	}
}
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
* Example:
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
*/
baseten: {},
"black-forest-labs": {},
cerebras: {},
cohere: {},
Expand Down
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ export interface Options {
export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";

export const INFERENCE_PROVIDERS = [
"baseten",
"black-forest-labs",
"cerebras",
"cohere",
Expand Down
167 changes: 167 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2343,4 +2343,171 @@ describe.skip("InferenceClient", () => {
},
TIMEOUT
);

describe.concurrent(
"Baseten",
() => {
const client = new InferenceClient(env.HF_BASETEN_KEY ?? "dummy");

HARDCODED_MODEL_INFERENCE_MAPPING["baseten"] = {
"Qwen/Qwen3-235B-A22B-Instruct-2507": {
provider: "baseten",
hfModelId: "Qwen/Qwen3-235B-A22B-Instruct-2507",
providerId: "Qwen/Qwen3-235B-A22B-Instruct-2507",
status: "live",
task: "conversational",
},
"Qwen/Qwen3-Coder-480B-A35B-Instruct": {
provider: "baseten",
hfModelId: "Qwen/Qwen3-Coder-480B-A35B-Instruct",
providerId: "Qwen/Qwen3-Coder-480B-A35B-Instruct",
status: "live",
task: "conversational",
},
"moonshotai/Kimi-K2-Instruct-0905": {
provider: "baseten",
hfModelId: "moonshotai/Kimi-K2-Instruct-0905",
providerId: "moonshotai/Kimi-K2-Instruct-0905",
status: "live",
task: "conversational",
},
"deepseek-ai/DeepSeek-V3.1": {
provider: "baseten",
hfModelId: "deepseek-ai/DeepSeek-V3.1",
providerId: "deepseek-ai/DeepSeek-V3.1",
status: "live",
task: "conversational",
},
"deepseek-ai/DeepSeek-V3-0324": {
provider: "baseten",
hfModelId: "deepseek-ai/DeepSeek-V3-0324",
providerId: "deepseek-ai/DeepSeek-V3-0324",
status: "live",
task: "conversational",
},
};

it("chatCompletion - Qwen3 235B Instruct", async () => {
const res = await client.chatCompletion({
model: "Qwen/Qwen3-235B-A22B-Instruct-2507",
provider: "baseten",
messages: [{ role: "user", content: "What is 5 + 3?" }],
max_tokens: 20,
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toBeDefined();
expect(typeof completion).toBe("string");
expect(completion).toMatch(/(eight|8)/i);
}
});

it("chatCompletion - Qwen3 Coder 480B", async () => {
const res = await client.chatCompletion({
model: "Qwen/Qwen3-Coder-480B-A35B-Instruct",
provider: "baseten",
messages: [{ role: "user", content: "Write a simple Python function to add two numbers" }],
max_tokens: 100,
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toBeDefined();
expect(typeof completion).toBe("string");
expect(completion).toMatch(/def.*add/i);
}
});

it("chatCompletion - Kimi K2 Instruct", async () => {
const res = await client.chatCompletion({
model: "moonshotai/Kimi-K2-Instruct-0905",
provider: "baseten",
messages: [{ role: "user", content: "What is the capital of Japan?" }],
temperature: 0.1,
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toBeDefined();
expect(typeof completion).toBe("string");
expect(completion).toMatch(/tokyo/i);
}
});

it("chatCompletion - DeepSeek V3.1", async () => {
const res = await client.chatCompletion({
model: "deepseek-ai/DeepSeek-V3.1",
provider: "baseten",
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toContain("two");
}
});

it("chatCompletion - DeepSeek V3", async () => {
const res = await client.chatCompletion({
model: "deepseek-ai/DeepSeek-V3-0324",
provider: "baseten",
messages: [{ role: "user", content: "What is 2 * 3?" }],
temperature: 0.1,
});
if (res.choices && res.choices.length > 0) {
const completion = res.choices[0].message?.content;
expect(completion).toBeDefined();
expect(typeof completion).toBe("string");
expect(completion).toMatch(/(six|6)/i);
}
});

it("chatCompletion stream - Qwen3 235B", async () => {
const stream = client.chatCompletionStream({
model: "Qwen/Qwen3-235B-A22B-Instruct-2507",
provider: "baseten",
messages: [{ role: "user", content: "Count from 1 to 3" }],
stream: true,
max_tokens: 20,
}) as AsyncGenerator<ChatCompletionStreamOutput>;

let fullResponse = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const content = chunk.choices[0].delta?.content;
if (content) {
fullResponse += content;
}
}
}

// Verify we got a meaningful response
expect(fullResponse).toBeTruthy();
expect(fullResponse.length).toBeGreaterThan(0);
expect(fullResponse).toMatch(/1.*2.*3/);
});

it("chatCompletion stream - DeepSeek V3.1", async () => {
const stream = client.chatCompletionStream({
model: "deepseek-ai/DeepSeek-V3.1",
provider: "baseten",
messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
stream: true,
}) as AsyncGenerator<ChatCompletionStreamOutput>;

let fullResponse = "";
for await (const chunk of stream) {
if (chunk.choices && chunk.choices.length > 0) {
const content = chunk.choices[0].delta?.content;
if (content) {
fullResponse += content;
}
}
}

// Verify we got a meaningful response
expect(fullResponse).toBeTruthy();
expect(fullResponse.length).toBeGreaterThan(0);
expect(fullResponse).toMatch(/(two|2)/i);
});
},
TIMEOUT
);
});
Loading