Skip to content

Commit 425ad7a

Browse files
committed
feat(endpoints): scaling metric and threshold
Add the possibility to customize both the scaling metric and threshold when creating or updating an endpoint. Signed-off-by: Raphael Glon <[email protected]>
1 parent d0b2330 commit 425ad7a

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from concurrent.futures import Future, ThreadPoolExecutor
2525
from dataclasses import asdict, dataclass, field
2626
from datetime import datetime
27+
from enum import Enum
2728
from functools import wraps
2829
from itertools import islice
2930
from pathlib import Path
@@ -1646,6 +1647,11 @@ def _inner(self, *args, **kwargs):
16461647
return _inner # type: ignore
16471648

16481649

1650+
class EndpointsScalingMetric(Enum):
    """Scaling metrics accepted when configuring an Inference Endpoint.

    Each member's value is the key the server expects inside the
    ``compute.scaling.measure`` field of the endpoint payload (the
    associated dict value being the scaling threshold).
    """

    pending_requests = "pendingRequests"
    hardware_usage = "hardwareUsage"
1653+
1654+
16491655
class HfApi:
16501656
"""
16511657
Client to interact with the Hugging Face Hub via HTTP.
@@ -7391,6 +7397,8 @@ def create_inference_endpoint(
73917397
account_id: Optional[str] = None,
73927398
min_replica: int = 1,
73937399
max_replica: int = 1,
7400+
scaling_metric: Optional[str | EndpointsScalingMetric] = None,
7401+
scaling_threshold: Optional[int] = None,
73947402
scale_to_zero_timeout: Optional[int] = None,
73957403
revision: Optional[str] = None,
73967404
task: Optional[str] = None,
@@ -7431,6 +7439,12 @@ def create_inference_endpoint(
74317439
scaling to zero, set this value to 0 and adjust `scale_to_zero_timeout` accordingly. Defaults to 1.
74327440
max_replica (`int`, *optional*):
74337441
The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1.
7442+
scaling_metric (`str`, *optional*):
7443+
The metric reference for scaling. Either "pendingRequests" or "hardwareUsage" when provided. Optional.
7444+
Defaults to None (meaning: let the hf endpoints service specify the metric).
7445+
scaling_threshold (`int`, *optional*):
7446+
The scaling metric threshold to trigger a scale up. Optional. Ignored when scaling metric is not
7447+
provided. Defaults to None (meaning: let the hf endpoints service specify the threshold).
74347448
scale_to_zero_timeout (`int`, *optional*):
74357449
The duration in minutes before an inactive endpoint is scaled to zero, or no scaling to zero if
74367450
set to None and `min_replica` is not 0. Defaults to None.
@@ -7582,6 +7596,9 @@ def create_inference_endpoint(
75827596
},
75837597
"type": type,
75847598
}
7599+
if scaling_metric:
7600+
scaling_metric = EndpointsScalingMetric(scaling_metric)
7601+
payload["compute"]["scaling"]["measure"] = {scaling_metric.value: scaling_threshold}
75857602
if env:
75867603
payload["model"]["env"] = env
75877604
if secrets:
@@ -7746,6 +7763,8 @@ def update_inference_endpoint(
77467763
min_replica: Optional[int] = None,
77477764
max_replica: Optional[int] = None,
77487765
scale_to_zero_timeout: Optional[int] = None,
7766+
scaling_metric: Optional[str | EndpointsScalingMetric] = None,
7767+
scaling_threshold: Optional[int] = None,
77497768
# Model update
77507769
repository: Optional[str] = None,
77517770
framework: Optional[str] = None,
@@ -7786,7 +7805,12 @@ def update_inference_endpoint(
77867805
The maximum number of replicas (instances) to scale to for the Inference Endpoint.
77877806
scale_to_zero_timeout (`int`, *optional*):
77887807
The duration in minutes before an inactive endpoint is scaled to zero.
7789-
7808+
scaling_metric (`str`, *optional*):
7809+
The metric reference for scaling. Either "pendingRequests" or "hardwareUsage" when provided. Optional.
7810+
Defaults to None.
7811+
scaling_threshold (`int`, *optional*):
7812+
The scaling metric threshold to trigger a scale up. Optional. Ignored when scaling metric is not
7813+
provided. Defaults to None.
77907814
repository (`str`, *optional*):
77917815
The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
77927816
framework (`str`, *optional*):
@@ -7840,6 +7864,9 @@ def update_inference_endpoint(
78407864
payload["compute"]["scaling"]["minReplica"] = min_replica
78417865
if scale_to_zero_timeout is not None:
78427866
payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout
7867+
if scaling_metric:
7868+
scaling_metric = EndpointsScalingMetric(scaling_metric)
7869+
payload["compute"]["scaling"]["measure"] = {scaling_metric.value: scaling_threshold}
78437870
if repository is not None:
78447871
payload["model"]["repository"] = repository
78457872
if framework is not None:

0 commit comments

Comments
 (0)