
Commit 9b8cd3a

tomsun28, gongchao, hanouticelina, and SBrandeis authored
feat: support zai as inference provider (#1766)
Hi 🤗 team,

We'd like to apply to register [ZAI](https://z.ai/) as an inference provider. Thanks!

One more question: when we tried to register the model mapping with the endpoint `POST api/partners/zai/models`, we received a 404 response (see https://huggingface.co/docs/inference-providers/register-as-a-provider#3-model-mapping-api). Does this action only work after the PR has been merged?

---------

Co-authored-by: gongchao <[email protected]>
Co-authored-by: célina <[email protected]>
Co-authored-by: Simon Brandeis <[email protected]>
Co-authored-by: SBrandeis <[email protected]>
1 parent 881c44b commit 9b8cd3a

File tree

6 files changed: +97 −0 lines changed


packages/inference/README.md

Lines changed: 2 additions & 0 deletions

@@ -65,6 +65,7 @@ Currently, we support the following providers:
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
 - [Groq](https://groq.com)
+- [Z.ai](https://z.ai/)

 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. The default value of the `provider` parameter is "auto", which will select the first of the providers available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.

@@ -100,6 +101,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
 - [Groq supported models](https://console.groq.com/docs/models)
 - [Novita AI supported models](https://huggingface.co/api/partners/novita/models)
+- [Z.ai supported models](https://huggingface.co/api/partners/zai-org/models)

 **Important note:** To be compatible, the third-party API must adhere to the "standard" shape API we expect on HF model pages for each pipeline task type.
 This is not an issue for LLMs as everyone converged on the OpenAI API anyways, but can be more tricky for other tasks like "text-to-image" or "automatic-speech-recognition" where there exists no standard API. Let us know if any help is needed or if we can make things easier for you!
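The README text in this hunk describes the "auto" policy: pick the first provider available for the model, following the user's preferred order. A minimal standalone sketch of that selection logic (function name and data shapes are illustrative, not the library's internals):

```typescript
// Hypothetical sketch of the "auto" provider policy: return the first
// provider in the user's preference order that the model supports.
// (Illustrative only; the real resolution lives inside @huggingface/inference.)
function resolveAutoProvider(
	preferredOrder: string[],
	modelProviders: string[]
): string | undefined {
	return preferredOrder.find((p) => modelProviders.includes(p));
}

// A user preferring Groq first still lands on "zai-org" when the model
// is only served by zai-org and together:
console.log(resolveAutoProvider(["groq", "zai-org"], ["zai-org", "together"]));
// → "zai-org"
```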

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 4 additions & 0 deletions

@@ -50,6 +50,7 @@ import * as Replicate from "../providers/replicate.js";
 import * as Sambanova from "../providers/sambanova.js";
 import * as Scaleway from "../providers/scaleway.js";
 import * as Together from "../providers/together.js";
+import * as Zai from "../providers/zai-org.js";
 import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
 import { InferenceClientInputError } from "../errors.js";

@@ -164,6 +165,9 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
 		conversational: new Together.TogetherConversationalTask(),
 		"text-generation": new Together.TogetherTextGenerationTask(),
 	},
+	"zai-org": {
+		conversational: new Zai.ZaiConversationalTask(),
+	},
 };

 /**

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions

@@ -37,4 +37,5 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	sambanova: {},
 	scaleway: {},
 	together: {},
+	"zai-org": {},
 };
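The empty `"zai-org": {}` entry added here is the slot for local dev overrides. A standalone sketch of what registering a model locally could look like, using the entry shape that this commit's test file uses (the `MappingEntry` type name is hypothetical; field names mirror the test code):

```typescript
// Standalone sketch: the mapping-entry shape used in this commit's tests,
// for trying a model locally before it's registered on huggingface.co.
// The type name "MappingEntry" is illustrative, not the library's.
type MappingEntry = {
	provider: string;
	hfModelId: string;
	providerId: string;
	status: "live" | "staging";
	task: string;
};

const HARDCODED_MODEL_INFERENCE_MAPPING: Record<string, Record<string, MappingEntry>> = {
	"zai-org": {},
};

// Map the HF model ID to the provider's own model ID for dev purposes:
HARDCODED_MODEL_INFERENCE_MAPPING["zai-org"]["zai-org/GLM-4.5"] = {
	provider: "zai-org",
	hfModelId: "zai-org/GLM-4.5",
	providerId: "glm-4.5",
	status: "live",
	task: "conversational",
};

console.log(HARDCODED_MODEL_INFERENCE_MAPPING["zai-org"]["zai-org/GLM-4.5"].providerId);
// → "glm-4.5"
```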
packages/inference/src/providers/zai-org.ts

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+/**
+ * See the registered mapping of HF model ID => ZAI model ID here:
+ *
+ * https://huggingface.co/api/partners/zai-org/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at zai and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to zai, please open an issue on the present repo
+ *   and we will tag zai team members.
+ *
+ * Thanks!
+ */
+import { BaseConversationalTask } from "./providerHelper.js";
+import type { HeaderParams } from "../types.js";
+
+const ZAI_API_BASE_URL = "https://api.z.ai";
+
+export class ZaiConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("zai-org", ZAI_API_BASE_URL);
+	}
+
+	override prepareHeaders(params: HeaderParams, binary: boolean): Record<string, string> {
+		const headers = super.prepareHeaders(params, binary);
+		headers["x-source-channel"] = "hugging_face";
+		headers["accept-language"] = "en-US,en";
+		return headers;
+	}
+
+	override makeRoute(): string {
+		return "/api/paas/v4/chat/completions";
+	}
+}
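Taken together, this class determines the full request target and headers for ZAI. A standalone re-derivation of what `makeRoute` and `prepareHeaders` contribute (simplified: the Authorization and Content-Type lines are an assumption about what the base class typically adds, not quoted from it):

```typescript
// Standalone sketch of the request shape ZaiConversationalTask produces.
// Base-class contributions (Authorization, Content-Type) are assumed here.
const ZAI_API_BASE_URL = "https://api.z.ai";

function makeZaiChatUrl(): string {
	// base URL + makeRoute() from the class above
	return `${ZAI_API_BASE_URL}/api/paas/v4/chat/completions`;
}

function makeZaiHeaders(accessToken: string): Record<string, string> {
	return {
		Authorization: `Bearer ${accessToken}`, // assumed base-class behavior
		"Content-Type": "application/json", // assumed base-class behavior
		// ZAI-specific headers added in prepareHeaders above:
		"x-source-channel": "hugging_face",
		"accept-language": "en-US,en",
	};
}

console.log(makeZaiChatUrl());
// → "https://api.z.ai/api/paas/v4/chat/completions"
```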

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions

@@ -64,6 +64,7 @@ export const INFERENCE_PROVIDERS = [
 	"sambanova",
 	"scaleway",
 	"together",
+	"zai-org",
 ] as const;

 export const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
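The `as const` arrays in types.ts double as a runtime list and a compile-time string union, which is why adding `"zai-org"` here is enough to extend the `InferenceProvider` type. A small standalone illustration of the pattern (provider list shortened for brevity):

```typescript
// Standalone illustration of the `as const` pattern used in types.ts:
// one readonly tuple yields both a runtime list and a string-union type.
const INFERENCE_PROVIDERS = ["together", "zai-org"] as const;
type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number]; // "together" | "zai-org"

const PROVIDERS_OR_POLICIES = [...INFERENCE_PROVIDERS, "auto"] as const;
type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];

const choice: InferenceProviderOrPolicy = "zai-org"; // type-checks
console.log(choice, PROVIDERS_OR_POLICIES.length); // → zai-org 3
```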

packages/inference/test/InferenceClient.spec.ts

Lines changed: 52 additions & 0 deletions

@@ -2134,6 +2134,58 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"ZAI",
+		() => {
+			const client = new InferenceClient(env.HF_ZAI_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["zai-org"] = {
+				"zai-org/GLM-4.5": {
+					provider: "zai-org",
+					hfModelId: "zai-org/GLM-4.5",
+					providerId: "glm-4.5",
+					status: "live",
+					task: "conversational",
+				},
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "zai-org/GLM-4.5",
+					provider: "zai-org",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toContain("two");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "zai-org/GLM-4.5",
+					provider: "zai-org",
+					messages: [{ role: "user", content: "Say 'this is a test'" }],
+					stream: true,
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse.length).toBeGreaterThan(0);
+			});
+		},
+		TIMEOUT
+	);
 	describe.concurrent(
 		"OVHcloud",
 		() => {
