Skip to content

Commit 316db10

Browse files
authored
Paginated results in list_user_access (#3535)
1 parent c350eed commit 316db10

File tree

2 files changed

+26
-29
lines changed

2 files changed

+26
-29
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8493,7 +8493,7 @@ def delete_collection_item(
84938493
@validate_hf_hub_args
84948494
def list_pending_access_requests(
84958495
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
8496-
) -> list[AccessRequest]:
8496+
) -> Iterable[AccessRequest]:
84978497
"""
84988498
Get pending access requests for a given gated repo.
84998499
@@ -8516,7 +8516,7 @@ def list_pending_access_requests(
85168516
To disable authentication, pass `False`.
85178517
85188518
Returns:
8519-
`list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`,
8519+
`Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`,
85208520
`status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
85218521
be populated with user's answers.
85228522
@@ -8532,7 +8532,7 @@ def list_pending_access_requests(
85328532
>>> from huggingface_hub import list_pending_access_requests, accept_access_request
85338533
85348534
# List pending requests
8535-
>>> requests = list_pending_access_requests("meta-llama/Llama-2-7b")
8535+
>>> requests = list(list_pending_access_requests("meta-llama/Llama-2-7b"))
85368536
>>> len(requests)
85378537
411
85388538
>>> requests[0]
@@ -8552,12 +8552,12 @@ def list_pending_access_requests(
85528552
>>> accept_access_request("meta-llama/Llama-2-7b", "clem")
85538553
```
85548554
"""
8555-
return self._list_access_requests(repo_id, "pending", repo_type=repo_type, token=token)
8555+
yield from self._list_access_requests(repo_id, "pending", repo_type=repo_type, token=token)
85568556

85578557
@validate_hf_hub_args
85588558
def list_accepted_access_requests(
85598559
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
8560-
) -> list[AccessRequest]:
8560+
) -> Iterable[AccessRequest]:
85618561
"""
85628562
Get accepted access requests for a given gated repo.
85638563
@@ -8582,7 +8582,7 @@ def list_accepted_access_requests(
85828582
To disable authentication, pass `False`.
85838583
85848584
Returns:
8585-
`list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`,
8585+
`Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`,
85868586
`status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
85878587
be populated with user's answers.
85888588
@@ -8597,7 +8597,7 @@ def list_accepted_access_requests(
85978597
```py
85988598
>>> from huggingface_hub import list_accepted_access_requests
85998599
8600-
>>> requests = list_accepted_access_requests("meta-llama/Llama-2-7b")
8600+
>>> requests = list(list_accepted_access_requests("meta-llama/Llama-2-7b"))
86018601
>>> len(requests)
86028602
411
86038603
>>> requests[0]
@@ -8614,12 +8614,12 @@ def list_accepted_access_requests(
86148614
]
86158615
```
86168616
"""
8617-
return self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token)
8617+
yield from self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token)
86188618

86198619
@validate_hf_hub_args
86208620
def list_rejected_access_requests(
86218621
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
8622-
) -> list[AccessRequest]:
8622+
) -> Iterable[AccessRequest]:
86238623
"""
86248624
Get rejected access requests for a given gated repo.
86258625
@@ -8644,7 +8644,7 @@ def list_rejected_access_requests(
86448644
To disable authentication, pass `False`.
86458645
86468646
Returns:
8647-
`list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`,
8647+
`Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`,
86488648
`status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
86498649
be populated with user's answers.
86508650
@@ -8659,7 +8659,7 @@ def list_rejected_access_requests(
86598659
```py
86608660
>>> from huggingface_hub import list_rejected_access_requests
86618661
8662-
>>> requests = list_rejected_access_requests("meta-llama/Llama-2-7b")
8662+
>>> requests = list(list_rejected_access_requests("meta-llama/Llama-2-7b"))
86638663
>>> len(requests)
86648664
411
86658665
>>> requests[0]
@@ -8676,36 +8676,33 @@ def list_rejected_access_requests(
86768676
]
86778677
```
86788678
"""
8679-
return self._list_access_requests(repo_id, "rejected", repo_type=repo_type, token=token)
8679+
yield from self._list_access_requests(repo_id, "rejected", repo_type=repo_type, token=token)
86808680

86818681
def _list_access_requests(
86828682
self,
86838683
repo_id: str,
86848684
status: Literal["accepted", "rejected", "pending"],
86858685
repo_type: Optional[str] = None,
86868686
token: Union[bool, str, None] = None,
8687-
) -> list[AccessRequest]:
8687+
) -> Iterable[AccessRequest]:
86888688
if repo_type not in constants.REPO_TYPES:
86898689
raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
86908690
if repo_type is None:
86918691
repo_type = constants.REPO_TYPE_MODEL
86928692

8693-
response = get_session().get(
8693+
for request in paginate(
86948694
f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/{status}",
8695+
params={},
86958696
headers=self._build_hf_headers(token=token),
8696-
)
8697-
hf_raise_for_status(response)
8698-
return [
8699-
AccessRequest(
8697+
):
8698+
yield AccessRequest(
87008699
username=request["user"]["user"],
87018700
fullname=request["user"]["fullname"],
87028701
email=request["user"].get("email"),
87038702
status=request["status"],
87048703
timestamp=parse_datetime(request["timestamp"]),
87058704
fields=request.get("fields"), # only if custom fields in form
87068705
)
8707-
for request in response.json()
8708-
]
87098706

87108707
@validate_hf_hub_args
87118708
def cancel_access_request(

tests/test_hf_api.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4095,18 +4095,18 @@ def tearDown(self) -> None:
40954095

40964096
def test_access_requests_normal_usage(self) -> None:
40974097
# No access requests initially
4098-
requests = self._api.list_accepted_access_requests(self.repo_id)
4098+
requests = list(self._api.list_accepted_access_requests(self.repo_id))
40994099
assert len(requests) == 0
4100-
requests = self._api.list_pending_access_requests(self.repo_id)
4100+
requests = list(self._api.list_pending_access_requests(self.repo_id))
41014101
assert len(requests) == 0
4102-
requests = self._api.list_rejected_access_requests(self.repo_id)
4102+
requests = list(self._api.list_rejected_access_requests(self.repo_id))
41034103
assert len(requests) == 0
41044104

41054105
# Grant access to a user
41064106
self._api.grant_access(self.repo_id, OTHER_USER)
41074107

41084108
# User is in accepted list
4109-
requests = self._api.list_accepted_access_requests(self.repo_id)
4109+
requests = list(self._api.list_accepted_access_requests(self.repo_id))
41104110
assert len(requests) == 1
41114111
request = requests[0]
41124112
assert isinstance(request, AccessRequest)
@@ -4117,23 +4117,23 @@ def test_access_requests_normal_usage(self) -> None:
41174117

41184118
# Cancel access
41194119
self._api.cancel_access_request(self.repo_id, OTHER_USER)
4120-
requests = self._api.list_accepted_access_requests(self.repo_id)
4120+
requests = list(self._api.list_accepted_access_requests(self.repo_id))
41214121
assert len(requests) == 0 # not accepted anymore
4122-
requests = self._api.list_pending_access_requests(self.repo_id)
4122+
requests = list(self._api.list_pending_access_requests(self.repo_id))
41234123
assert len(requests) == 1
41244124
assert requests[0].username == OTHER_USER
41254125

41264126
# Reject access
41274127
self._api.reject_access_request(self.repo_id, OTHER_USER, rejection_reason="This is a rejection reason")
4128-
requests = self._api.list_pending_access_requests(self.repo_id)
4128+
requests = list(self._api.list_pending_access_requests(self.repo_id))
41294129
assert len(requests) == 0 # not pending anymore
4130-
requests = self._api.list_rejected_access_requests(self.repo_id)
4130+
requests = list(self._api.list_rejected_access_requests(self.repo_id))
41314131
assert len(requests) == 1
41324132
assert requests[0].username == OTHER_USER
41334133

41344134
# Accept again
41354135
self._api.accept_access_request(self.repo_id, OTHER_USER)
4136-
requests = self._api.list_accepted_access_requests(self.repo_id)
4136+
requests = list(self._api.list_accepted_access_requests(self.repo_id))
41374137
assert len(requests) == 1
41384138
assert requests[0].username == OTHER_USER
41394139

0 commit comments

Comments
 (0)