Commit 90b1c5d

add refusal
1 parent 9488831 commit 90b1c5d

File tree (4 files changed, +41 -12 lines)

src/huggingface_hub/_webhooks_payload.py
src/huggingface_hub/inference/_client.py
src/huggingface_hub/inference/_generated/_async_client.py
src/huggingface_hub/inference/_generated/types/chat_completion.py

src/huggingface_hub/_webhooks_payload.py

Lines changed: 14 additions & 0 deletions

@@ -39,13 +39,27 @@ def model_json_schema(cls, *args, **kwargs) -> dict[str, Any]:
                 " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
             )

+        @classmethod
+        def schema(cls, *args, **kwargs) -> dict[str, Any]:
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+
         @classmethod
         def model_validate_json(cls, json_data: str | bytes | bytearray, *args, **kwargs) -> "BaseModel":
             raise ImportError(
                 "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
                 " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
             )

+        @classmethod
+        def parse_raw(cls, json_data: str | bytes | bytearray, *args, **kwargs) -> "BaseModel":
+            raise ImportError(
+                "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
+                " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
+            )
+

 # This is an adaptation of the ReportV3 interface implemented in moon-landing. V0, V1 and V2 have been ignored as they
 # are not in used anymore. To keep in sync when format is updated in
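For context, these stubs live on a dummy `BaseModel` that stands in when pydantic is not installed; the new `schema` and `parse_raw` stubs add the pydantic v1 method names next to the v2 ones so every entry point fails with the same actionable error. Below is a self-contained sketch of the pattern only: the real module gates this on its own availability helper and repeats the message string inline, so the bare try/except and the `_MSG` constant here are simplifications.

    try:  # stand-in for the module's pydantic availability check
        from pydantic import BaseModel

        _PYDANTIC_AVAILABLE = True
    except ImportError:
        _PYDANTIC_AVAILABLE = False

    if not _PYDANTIC_AVAILABLE:
        _MSG = (
            "You must have `pydantic` installed to use `WebhookPayload`. This is an optional dependency that"
            " should be installed separately. Please run `pip install --upgrade pydantic` and retry."
        )

        class BaseModel:  # type: ignore[no-redef]
            """Placeholder that fails loudly whichever pydantic API version the caller targets."""

            def __init__(self, *args, **kwargs) -> None:
                raise ImportError(_MSG)

            @classmethod
            def schema(cls, *args, **kwargs) -> dict:  # pydantic v1 spelling
                raise ImportError(_MSG)

            @classmethod
            def model_json_schema(cls, *args, **kwargs) -> dict:  # pydantic v2 spelling
                raise ImportError(_MSG)

            @classmethod
            def parse_raw(cls, *args, **kwargs) -> "BaseModel":  # pydantic v1 spelling
                raise ImportError(_MSG)

            @classmethod
            def model_validate_json(cls, *args, **kwargs) -> "BaseModel":  # pydantic v2 spelling
                raise ImportError(_MSG)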

src/huggingface_hub/inference/_client.py

Lines changed: 13 additions & 6 deletions

@@ -878,13 +878,15 @@ def chat_completion(
         ```
         """
         if issubclass(response_format, BaseModel):
-            base_model = response_format
+            response_model = response_format
             response_format = ChatCompletionInputGrammarType(
                 type="json",
-                value=base_model.model_json_schema(),
+                value=response_model.model_json_schema()
+                if hasattr(response_model, "model_json_schema")
+                else response_model.schema(),
             )
         else:
-            base_model = None
+            response_model = None

         model_url = self._resolve_chat_completion_url(model)

@@ -922,13 +924,18 @@ def chat_completion(
             return _stream_chat_completion_response(data)  # type: ignore[arg-type]

         chat_completion_output = ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
-        if base_model:
+        if response_model:
             for choice in chat_completion_output.choices:
                 if choice.message.content:
                     try:
-                        choice.message.parsed = base_model.model_validate_json(choice.message.content)
+                        # pydantic v2 uses model_validate_json
+                        choice.message.parsed = (
+                            response_model.model_validate_json(choice.message.content)
+                            if hasattr(response_model, "model_validate_json")
+                            else response_model.parse_raw(choice.message.content)
+                        )
                     except ValueError:
-                        pass
+                        choice.message.refusal = f"Failed to generate the response as a {response_model.__name__}"
         return chat_completion_output

     def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
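Taken together, the two hunks above mean a caller can pass a pydantic model class (v1 or v2) as `response_format` and then branch on `parsed` versus the new `refusal` field. A usage sketch follows; the model id, the prompt, and the `Recipe` schema are illustrative and not part of this commit.

    from typing import List

    from pydantic import BaseModel

    from huggingface_hub import InferenceClient


    class Recipe(BaseModel):
        # Illustrative schema; any pydantic v1 or v2 model class follows the same path.
        title: str
        ingredients: List[str]


    client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")  # model id is illustrative

    output = client.chat_completion(
        messages=[{"role": "user", "content": "Give me a cookie recipe formatted as JSON."}],
        response_format=Recipe,  # a BaseModel subclass triggers the JSON-grammar branch
        max_tokens=500,
    )

    message = output.choices[0].message
    if message.parsed is not None:
        print(message.parsed.title)  # validation succeeded: `parsed` is a Recipe instance
    else:
        print(message.refusal)  # validation failed: the new `refusal` field says so

Before this commit a failed `model_validate_json` call was silently swallowed, leaving `parsed` as None with no signal; routing the failure into `refusal` keeps the output shape stable while still surfacing the error.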

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 13 additions & 6 deletions

@@ -933,13 +933,15 @@ async def chat_completion(
         ```
         """
         if issubclass(response_format, BaseModel):
-            base_model = response_format
+            response_model = response_format
             response_format = ChatCompletionInputGrammarType(
                 type="json",
-                value=base_model.model_json_schema(),
+                value=response_model.model_json_schema()
+                if hasattr(response_model, "model_json_schema")
+                else response_model.schema(),
             )
         else:
-            base_model = None
+            response_model = None

         model_url = self._resolve_chat_completion_url(model)

@@ -977,13 +979,18 @@ async def chat_completion(
             return _async_stream_chat_completion_response(data)  # type: ignore[arg-type]

         chat_completion_output = ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]
-        if base_model:
+        if response_model:
             for choice in chat_completion_output.choices:
                 if choice.message.content:
                     try:
-                        choice.message.parsed = base_model.model_validate_json(choice.message.content)
+                        # pydantic v2 uses model_validate_json
+                        choice.message.parsed = (
+                            response_model.model_validate_json(choice.message.content)
+                            if hasattr(response_model, "model_validate_json")
+                            else response_model.parse_raw(choice.message.content)
+                        )
                     except ValueError:
-                        pass
+                        choice.message.refusal = f"Failed to generate the response as a {response_model.__name__}"
         return chat_completion_output

     def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
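The async client receives the identical fallback and refusal handling, so the same pattern applies with `AsyncInferenceClient`. A brief sketch, again with an illustrative model id and schema:

    import asyncio
    from typing import List

    from pydantic import BaseModel

    from huggingface_hub import AsyncInferenceClient


    class Recipe(BaseModel):
        title: str
        ingredients: List[str]


    async def main() -> None:
        client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")  # model id is illustrative
        output = await client.chat_completion(
            messages=[{"role": "user", "content": "Give me a cookie recipe formatted as JSON."}],
            response_format=Recipe,
            max_tokens=500,
        )
        message = output.choices[0].message
        # Same contract as the sync client: `parsed` on success, `refusal` on validation failure.
        print(message.parsed if message.parsed is not None else message.refusal)


    asyncio.run(main())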

src/huggingface_hub/inference/_generated/types/chat_completion.py

Lines changed: 1 addition & 0 deletions

@@ -199,6 +199,7 @@ class ChatCompletionOutputMessage(BaseInferenceType):
     content: Optional[str] = None
     tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None
     parsed: Optional[BaseModel] = None
+    refusal: Optional[str] = None


 @dataclass
