Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 17 additions & 20 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8493,7 +8493,7 @@ def delete_collection_item(
@validate_hf_hub_args
def list_pending_access_requests(
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> list[AccessRequest]:
) -> Iterable[AccessRequest]:
"""
Get pending access requests for a given gated repo.
Expand All @@ -8516,7 +8516,7 @@ def list_pending_access_requests(
To disable authentication, pass `False`.
Returns:
`list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`,
`Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`,
`status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
be populated with user's answers.
Expand All @@ -8532,7 +8532,7 @@ def list_pending_access_requests(
>>> from huggingface_hub import list_pending_access_requests, accept_access_request
# List pending requests
>>> requests = list_pending_access_requests("meta-llama/Llama-2-7b")
>>> requests = list(list_pending_access_requests("meta-llama/Llama-2-7b"))
>>> len(requests)
411
>>> requests[0]
Expand All @@ -8552,12 +8552,12 @@ def list_pending_access_requests(
>>> accept_access_request("meta-llama/Llama-2-7b", "clem")
```
"""
return self._list_access_requests(repo_id, "pending", repo_type=repo_type, token=token)
yield from self._list_access_requests(repo_id, "pending", repo_type=repo_type, token=token)

@validate_hf_hub_args
def list_accepted_access_requests(
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> list[AccessRequest]:
) -> Iterable[AccessRequest]:
"""
Get accepted access requests for a given gated repo.
Expand All @@ -8582,7 +8582,7 @@ def list_accepted_access_requests(
To disable authentication, pass `False`.
Returns:
`list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`,
`Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`,
`status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
be populated with user's answers.
Expand All @@ -8597,7 +8597,7 @@ def list_accepted_access_requests(
```py
>>> from huggingface_hub import list_accepted_access_requests
>>> requests = list_accepted_access_requests("meta-llama/Llama-2-7b")
>>> requests = list(list_accepted_access_requests("meta-llama/Llama-2-7b"))
>>> len(requests)
411
>>> requests[0]
Expand All @@ -8614,12 +8614,12 @@ def list_accepted_access_requests(
]
```
"""
return self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token)
yield from self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token)

@validate_hf_hub_args
def list_rejected_access_requests(
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> list[AccessRequest]:
) -> Iterable[AccessRequest]:
"""
Get rejected access requests for a given gated repo.
Expand All @@ -8644,7 +8644,7 @@ def list_rejected_access_requests(
To disable authentication, pass `False`.
Returns:
`list[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`,
`Iterable[AccessRequest]`: An iterable of [`AccessRequest`] objects. Each time contains a `username`, `email`,
`status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will
be populated with user's answers.
Expand All @@ -8659,7 +8659,7 @@ def list_rejected_access_requests(
```py
>>> from huggingface_hub import list_rejected_access_requests
>>> requests = list_rejected_access_requests("meta-llama/Llama-2-7b")
>>> requests = list(list_rejected_access_requests("meta-llama/Llama-2-7b"))
>>> len(requests)
411
>>> requests[0]
Expand All @@ -8676,36 +8676,33 @@ def list_rejected_access_requests(
]
```
"""
return self._list_access_requests(repo_id, "rejected", repo_type=repo_type, token=token)
yield from self._list_access_requests(repo_id, "rejected", repo_type=repo_type, token=token)

def _list_access_requests(
self,
repo_id: str,
status: Literal["accepted", "rejected", "pending"],
repo_type: Optional[str] = None,
token: Union[bool, str, None] = None,
) -> list[AccessRequest]:
) -> Iterable[AccessRequest]:
if repo_type not in constants.REPO_TYPES:
raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
if repo_type is None:
repo_type = constants.REPO_TYPE_MODEL

response = get_session().get(
for request in paginate(
f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/{status}",
params={},
headers=self._build_hf_headers(token=token),
)
hf_raise_for_status(response)
return [
AccessRequest(
):
yield AccessRequest(
username=request["user"]["user"],
fullname=request["user"]["fullname"],
email=request["user"].get("email"),
status=request["status"],
timestamp=parse_datetime(request["timestamp"]),
fields=request.get("fields"), # only if custom fields in form
)
for request in response.json()
]

@validate_hf_hub_args
def cancel_access_request(
Expand Down
18 changes: 9 additions & 9 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4095,18 +4095,18 @@ def tearDown(self) -> None:

def test_access_requests_normal_usage(self) -> None:
# No access requests initially
requests = self._api.list_accepted_access_requests(self.repo_id)
requests = list(self._api.list_accepted_access_requests(self.repo_id))
assert len(requests) == 0
requests = self._api.list_pending_access_requests(self.repo_id)
requests = list(self._api.list_pending_access_requests(self.repo_id))
assert len(requests) == 0
requests = self._api.list_rejected_access_requests(self.repo_id)
requests = list(self._api.list_rejected_access_requests(self.repo_id))
assert len(requests) == 0

# Grant access to a user
self._api.grant_access(self.repo_id, OTHER_USER)

# User is in accepted list
requests = self._api.list_accepted_access_requests(self.repo_id)
requests = list(self._api.list_accepted_access_requests(self.repo_id))
assert len(requests) == 1
request = requests[0]
assert isinstance(request, AccessRequest)
Expand All @@ -4117,23 +4117,23 @@ def test_access_requests_normal_usage(self) -> None:

# Cancel access
self._api.cancel_access_request(self.repo_id, OTHER_USER)
requests = self._api.list_accepted_access_requests(self.repo_id)
requests = list(self._api.list_accepted_access_requests(self.repo_id))
assert len(requests) == 0 # not accepted anymore
requests = self._api.list_pending_access_requests(self.repo_id)
requests = list(self._api.list_pending_access_requests(self.repo_id))
assert len(requests) == 1
assert requests[0].username == OTHER_USER

# Reject access
self._api.reject_access_request(self.repo_id, OTHER_USER, rejection_reason="This is a rejection reason")
requests = self._api.list_pending_access_requests(self.repo_id)
requests = list(self._api.list_pending_access_requests(self.repo_id))
assert len(requests) == 0 # not pending anymore
requests = self._api.list_rejected_access_requests(self.repo_id)
requests = list(self._api.list_rejected_access_requests(self.repo_id))
assert len(requests) == 1
assert requests[0].username == OTHER_USER

# Accept again
self._api.accept_access_request(self.repo_id, OTHER_USER)
requests = self._api.list_accepted_access_requests(self.repo_id)
requests = list(self._api.list_accepted_access_requests(self.repo_id))
assert len(requests) == 1
assert requests[0].username == OTHER_USER

Expand Down
Loading