Merged
Changes from 7 commits
8 changes: 7 additions & 1 deletion src/huggingface_hub/inference/_client.py
@@ -883,7 +883,13 @@ def chat_completion(
         payload_model = model or self.model

         # Get the provider helper
-        provider_helper = get_provider_helper(self.provider, task="conversational", model=payload_model)
+        provider_helper = get_provider_helper(
+            self.provider,
+            task="conversational",
+            model=model_id_or_url
+            if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://"))
+            else payload_model,
+        )

         # Prepare the payload
         parameters = {
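For context (not part of the diff): a minimal usage sketch of the behavior this enables, mirroring the first test case added below. The endpoint URL is a placeholder.

```python
from huggingface_hub import InferenceClient

# Passing a raw endpoint URL as `model` now routes through the hf-inference
# helper, which resolves the request to <url>/v1/chat/completions.
client = InferenceClient(model="https://my-custom-endpoint.com/custom_path")
response = client.chat_completion(messages=[{"role": "user", "content": "Hello?"}])
print(response.choices[0].message.content)
```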
8 changes: 7 additions & 1 deletion src/huggingface_hub/inference/_generated/_async_client.py
@@ -923,7 +923,13 @@ async def chat_completion(
         payload_model = model or self.model

         # Get the provider helper
-        provider_helper = get_provider_helper(self.provider, task="conversational", model=payload_model)
+        provider_helper = get_provider_helper(
+            self.provider,
+            task="conversational",
+            model=model_id_or_url
+            if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://"))
+            else payload_model,
+        )

         # Prepare the payload
         parameters = {
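The async client mirrors the same change. A usage sketch (not part of the diff; placeholder URL) with AsyncInferenceClient:

```python
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main() -> None:
    # Same URL handling as the sync client: a raw endpoint URL as `model`
    # resolves to <url>/v1/chat/completions.
    client = AsyncInferenceClient(model="https://my-custom-endpoint.com/custom_path")
    response = await client.chat_completion(messages=[{"role": "user", "content": "Hello?"}])
    print(response.choices[0].message.content)

asyncio.run(main())
```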
4 changes: 3 additions & 1 deletion src/huggingface_hub/inference/_providers/__init__.py
@@ -147,7 +147,9 @@ def get_provider_helper(
         ValueError: If provider or task is not supported
     """

-    if model is None and provider in (None, "auto"):
+    if (model is None and provider in (None, "auto")) or (
+        model is not None and model.startswith(("http://", "https://"))
+    ):
         provider = "hf-inference"

     if provider is None:
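To make the new guard concrete, a standalone sketch replicating the condition above; the helper name _pick_provider is hypothetical and not part of the library:

```python
from typing import Optional

def _pick_provider(provider: Optional[str], model: Optional[str]) -> Optional[str]:
    # URL-shaped models are always served by the hf-inference helper, as is
    # the pre-existing case of no model with provider unset or "auto".
    if (model is None and provider in (None, "auto")) or (
        model is not None and model.startswith(("http://", "https://"))
    ):
        return "hf-inference"
    return provider

assert _pick_provider(None, "https://my-endpoint.com") == "hf-inference"
assert _pick_provider("auto", None) == "hf-inference"
assert _pick_provider("together", "username/repo_name") == "together"
```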
79 changes: 79 additions & 0 deletions tests/test_inference_client.py
@@ -1087,3 +1087,82 @@ def test_warning_if_bill_to_with_direct_calls(self):
match="You've provided an external provider's API key, so requests will be billed directly by the provider.",
):
InferenceClient(bill_to="openai", token="replicate_key", provider="replicate")


+@pytest.mark.parametrize(
+    "client_init_arg, init_kwarg_name, expected_request_url, expected_payload_model",
+    [
+        # passing a custom endpoint in the model argument
+        pytest.param(
+            "https://my-custom-endpoint.com/custom_path",
+            "model",
+            "https://my-custom-endpoint.com/custom_path/v1/chat/completions",
+            "dummy",
+            id="client_model_is_url",
+        ),
Comment on lines +1096 to +1102

Contributor:
Are we sure this is the expected behavior? In the past I've drawn a distinction between InferenceClient(model="https://...").chat_completion and InferenceClient(base_url="https://...").chat_completion, but I haven't tested it in a very long time. My reasoning is that if we always add /v1/chat/completions to the URL, we can't support a URL that doesn't expect it (what if "https://..." is already the URL of the endpoint handling chat_completion?).

But OK to remove this distinction if it helps with clarity / aligns better with the JS client.

Contributor (Author):
I might be mistaken, but I'm not sure we had this distinction in the pre-provider implementation of InferenceClient:

def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:

Imo if you pass a URL (either via model or base_url), calling chat_completion hits the /chat/completions route by default (same behavior as the OpenAI client). Agreed that we then can't customize this route, which is a bit annoying, but I'm not sure it's worth supporting.

Contributor:
Thanks for checking back. I took a look at #2540 and always adding (/v1)/chat/completions seems indeed the way to go (more consistent). Looks like we merged both behaviors even before the "pre-provider" implementation ^^

+        # passing a custom endpoint in the base_url argument
+        pytest.param(
+            "https://another-endpoint.com/v1/",
+            "base_url",
+            "https://another-endpoint.com/v1/chat/completions",
+            "dummy",
+            id="client_base_url_is_url",
+        ),
+        # passing a model ID
+        pytest.param(
+            "username/repo_name",
+            "model",
+            "https://router.huggingface.co/hf-inference/models/username/repo_name/v1/chat/completions",
+            "username/repo_name",
+            id="client_model_is_id",
+        ),
+        # passing a full chat completions URL in the model argument
+        pytest.param(
+            "https://specific-chat-endpoint.com/v1/chat/completions",
+            "model",
+            "https://specific-chat-endpoint.com/v1/chat/completions",
+            "dummy",
+            id="client_model_is_full_chat_url",
+        ),
+        # passing a localhost URL in the model argument
+        pytest.param(
+            "http://localhost:8080",
+            "model",
+            "http://localhost:8080/v1/chat/completions",
+            "dummy",
+            id="client_model_is_localhost_url",
+        ),
+        # passing a localhost URL in the base_url argument
+        pytest.param(
+            "http://127.0.0.1:8000/custom/path/v1",
+            "base_url",
+            "http://127.0.0.1:8000/custom/path/v1/chat/completions",
+            "dummy",
+            id="client_base_url_is_localhost_ip_with_path",
+        ),
+    ],
+)
+def test_chat_completion_url_resolution(
+    mocker, client_init_arg, init_kwarg_name, expected_request_url, expected_payload_model
+):
+    init_kwargs = {init_kwarg_name: client_init_arg, "provider": "hf-inference"}
+    client = InferenceClient(**init_kwargs)
+
+    mock_response_content = b'{"choices": [{"message": {"content": "Mock response"}}]}'
+    mocker.patch(
+        "huggingface_hub.inference._providers.hf_inference._check_supported_task",
+        return_value=None,
+    )
+
+    with patch.object(InferenceClient, "_inner_post", return_value=mock_response_content) as mock_inner_post:
+        client.chat_completion(messages=[{"role": "user", "content": "Hello?"}], stream=False)
+
+    mock_inner_post.assert_called_once()
+
+    request_params = mock_inner_post.call_args[0][0]
+    inner_post_kwargs = mock_inner_post.call_args[1]
+
+    assert request_params.url == expected_request_url
+
+    assert request_params.json is not None
+    assert request_params.json.get("model") == expected_payload_model
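The expected URLs in these cases all follow one resolution rule. A sketch of that rule as a hypothetical helper, assuming only the behavior the test cases assert:

```python
def _build_chat_url(base: str) -> str:
    # Full .../chat/completions URLs are used verbatim; otherwise append
    # /v1/chat/completions, skipping /v1 when the URL already ends with it.
    base = base.rstrip("/")
    if base.endswith("/chat/completions"):
        return base
    if base.endswith("/v1"):
        return f"{base}/chat/completions"
    return f"{base}/v1/chat/completions"

assert _build_chat_url("https://my-custom-endpoint.com/custom_path").endswith("/v1/chat/completions")
assert _build_chat_url("https://another-endpoint.com/v1/") == "https://another-endpoint.com/v1/chat/completions"
assert _build_chat_url("http://localhost:8080") == "http://localhost:8080/v1/chat/completions"
```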