
Commit d95c300

fix: endpointUrl + provider
1 parent ee3840d commit d95c300

32 files changed: +40 -32 lines

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 9 additions & 1 deletion
@@ -99,8 +99,16 @@ export async function getInferenceProviderMapping(
 
 export async function resolveProvider(
   provider?: InferenceProviderOrPolicy,
-  modelId?: string
+  modelId?: string,
+  endpointUrl?: string
 ): Promise<InferenceProvider> {
+  if (endpointUrl) {
+    if (provider) {
+      throw new Error("Specifying both endpointUrl and provider is not supported.");
+    }
+    /// Defaulting to hf-inference helpers / API
+    return "hf-inference";
+  }
   if (!provider) {
     console.log(
       "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."

packages/inference/src/tasks/audio/audioClassification.ts

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ export async function audioClassification(
   args: AudioClassificationArgs,
   options?: Options
 ): Promise<AudioClassificationOutput> {
-  const provider = await resolveProvider(args.provider, args.model);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "audio-classification");
   const payload = preparePayload(args);
   const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {

packages/inference/src/tasks/audio/automaticSpeechRecognition.ts

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ export async function automaticSpeechRecognition(
   args: AutomaticSpeechRecognitionArgs,
   options?: Options
 ): Promise<AutomaticSpeechRecognitionOutput> {
-  const provider = await resolveProvider(args.provider, args.model);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
   const payload = await buildPayload(args);
   const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {

packages/inference/src/tasks/audio/textToSpeech.ts

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ interface OutputUrlTextToSpeechGeneration {
  * Recommended model: espnet/kan-bayashi_ljspeech_vits
  */
 export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
-  const provider = await resolveProvider(args.provider, args.model);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "text-to-speech");
   const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
     ...options,

packages/inference/src/tasks/custom/request.ts

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ export async function request<T>(
   console.warn(
     "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
   );
-  const provider = await resolveProvider(args.provider, args.model);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, options?.task);
   const result = await innerRequest<T>(args, providerHelper, options);
   return result.data;

packages/inference/src/tasks/custom/streamingRequest.ts

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ export async function* streamingRequest<T>(
   console.warn(
     "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
   );
-  const provider = await resolveProvider(args.provider, args.model);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, options?.task);
   yield* innerStreamingRequest(args, providerHelper, options);
 }

packages/inference/src/tasks/cv/imageClassification.ts

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ export async function imageClassification(
   args: ImageClassificationArgs,
   options?: Options
 ): Promise<ImageClassificationOutput> {
-  const provider = await resolveProvider(args.provider, args.model);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "image-classification");
   const payload = preparePayload(args);
   const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {

packages/inference/src/tasks/cv/imageSegmentation.ts

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ export async function imageSegmentation(
   args: ImageSegmentationArgs,
   options?: Options
 ): Promise<ImageSegmentationOutput> {
-  const provider = await resolveProvider(args.provider, args.model);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "image-segmentation");
   const payload = preparePayload(args);
   const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {

packages/inference/src/tasks/cv/imageToImage.ts

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
  * Recommended model: lllyasviel/sd-controlnet-depth
  */
 export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
-  const provider = await resolveProvider(args.provider, args.model);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "image-to-image");
   let reqArgs: RequestArgs;
   if (!args.parameters) {

packages/inference/src/tasks/cv/imageToText.ts

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
  * This task reads some image input and outputs the text caption.
  */
 export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
-  const provider = await resolveProvider(args.provider, args.model);
+  const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
   const providerHelper = getProviderHelper(provider, "image-to-text");
   const payload = preparePayload(args);
   const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
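
Each task entry point now threads args.endpointUrl through to resolveProvider in the same one-line pattern. A usage sketch against imageToText as changed above (the endpoint URL is a placeholder; the { data } input shape is assumed from the LegacyImageInput variant in the diff header):

import { imageToText } from "./tasks/cv/imageToText";

async function caption(image: Blob): Promise<void> {
  // No provider given: with endpointUrl set, resolveProvider picks "hf-inference",
  // so the request is built with the hf-inference helpers and aimed at the custom endpoint.
  const output = await imageToText({
    endpointUrl: "https://my-endpoint.example.com", // placeholder dedicated endpoint
    data: image,
  });
  console.log(output.generated_text);
}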
