Skip to content

Commit 7ad1b8e

Browse files
[minor] Adding ECS module to boto3-refresh-session (#57)
* ecs * ecs changes * readme updates * updated _get_credentials * docs * doc updates: * static method * custom exceptions + repr * errors docs updates * flake8 updates * black updates
1 parent 80dd182 commit 7ad1b8e

16 files changed

+358
-113
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454

5555
- Auto-refreshing credentials for long-lived `boto3` sessions
5656
- Drop-in replacement for `boto3.session.Session`
57+
- Supports automatic refresh methods for STS and ECS
5758
- Supports `assume_role` configuration, custom STS clients, and profile / region configuration, as well as all other parameters supported by `boto3.session.Session`
5859
- Tested, documented, and published to PyPI
5960

README.template.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454

5555
- Auto-refreshing credentials for long-lived `boto3` sessions
5656
- Drop-in replacement for `boto3.session.Session`
57+
- Supports automatic refresh methods for STS and ECS
5758
- Supports `assume_role` configuration, custom STS clients, and profile / region configuration, as well as all other parameters supported by `boto3.session.Session`
5859
- Tested, documented, and published to PyPI
5960

boto3_refresh_session/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .ecs import ECSRefreshableSession
12
from .session import RefreshableSession
23
from .sts import STSRefreshableSession
34

boto3_refresh_session/ecs.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from __future__ import annotations
2+
3+
__all__ = ["ECSRefreshableSession"]
4+
5+
import os
6+
7+
import requests
8+
9+
from .exceptions import BRSError
10+
from .session import BaseRefreshableSession
11+
12+
_ECS_CREDENTIALS_RELATIVE_URI = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
13+
_ECS_CREDENTIALS_FULL_URI = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
14+
_ECS_AUTHORIZATION_TOKEN = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
15+
_DEFAULT_ENDPOINT_BASE = "http://169.254.170.2"
16+
17+
18+
class ECSRefreshableSession(BaseRefreshableSession, method="ecs"):
19+
"""A boto3 session that automatically refreshes temporary AWS credentials
20+
from the ECS container credentials metadata endpoint.
21+
22+
Parameters
23+
----------
24+
defer_refresh : bool, optional
25+
If ``True`` then temporary credentials are not automatically refreshed until
26+
they are explicitly needed. If ``False`` then temporary credentials refresh
27+
immediately upon expiration. It is highly recommended that you use ``True``.
28+
Default is ``True``.
29+
30+
Other Parameters
31+
----------------
32+
kwargs : dict
33+
Optional keyword arguments passed to :class:`boto3.session.Session`.
34+
"""
35+
36+
def __init__(self, defer_refresh: bool | None = None, **kwargs):
37+
super().__init__(**kwargs)
38+
39+
self._endpoint = self._resolve_endpoint()
40+
self._headers = self._build_headers()
41+
self._http = self._init_http_session()
42+
43+
self._refresh_using(
44+
credentials_method=self._get_credentials,
45+
defer_refresh=defer_refresh is not False,
46+
refresh_method="ecs-container-metadata",
47+
)
48+
49+
def _resolve_endpoint(self) -> str:
50+
uri = os.environ.get(_ECS_CREDENTIALS_FULL_URI) or os.environ.get(
51+
_ECS_CREDENTIALS_RELATIVE_URI
52+
)
53+
if not uri:
54+
raise BRSError(
55+
"Neither AWS_CONTAINER_CREDENTIALS_FULL_URI nor "
56+
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is set. "
57+
"Are you running inside an ECS container?"
58+
)
59+
if uri.startswith("http://") or uri.startswith("https://"):
60+
return uri
61+
return f"{_DEFAULT_ENDPOINT_BASE}{uri}"
62+
63+
def _build_headers(self) -> dict[str, str]:
64+
token = os.environ.get(_ECS_AUTHORIZATION_TOKEN)
65+
if token:
66+
return {"Authorization": f"Bearer {token}"}
67+
return {}
68+
69+
def _init_http_session(self) -> requests.Session:
70+
session = requests.Session()
71+
session.headers.update(self._headers)
72+
return session
73+
74+
def _get_credentials(self) -> dict[str, str]:
75+
try:
76+
response = self._http.get(self._endpoint, timeout=3)
77+
response.raise_for_status()
78+
except requests.RequestException as exc:
79+
raise BRSError(
80+
f"Failed to retrieve ECS credentials from {self._endpoint}"
81+
) from exc
82+
83+
credentials = response.json()
84+
required = {
85+
"AccessKeyId",
86+
"SecretAccessKey",
87+
"SessionToken",
88+
"Expiration",
89+
}
90+
if not required.issubset(credentials):
91+
raise BRSError(f"Incomplete credentials received: {credentials}")
92+
return {
93+
"access_key": credentials.get("AccessKeyId"),
94+
"secret_key": credentials.get("SecretAccessKey"),
95+
"token": credentials.get("SessionToken"),
96+
"expiry_time": credentials.get("Expiration"), # already ISO8601
97+
}
98+
99+
@staticmethod
100+
def get_identity() -> dict[str, str]:
101+
"""Returns metadata about ECS.
102+
103+
Returns
104+
-------
105+
dict[str, str]
106+
Dict containing metadata about ECS.
107+
"""
108+
109+
return {"method": "ecs", "source": "ecs-container-metadata"}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
class BRSError(Exception):
2+
"""The base exception for boto3-refresh-session.
3+
4+
Parameters
5+
----------
6+
message : str, optional
7+
The message to raise.
8+
"""
9+
10+
def __init__(self, message: str | None = None):
11+
self.message = "" if message is None else message
12+
super().__init__(self.message)
13+
14+
def __str__(self) -> str:
15+
return self.message
16+
17+
def __repr__(self) -> str:
18+
return f"{self.__class__.__name__}({repr(self.message)})"
19+
20+
21+
class BRSWarning(UserWarning):
22+
"""The base warning for boto3-refresh-session.
23+
24+
Parameters
25+
----------
26+
message : str, optional
27+
The message to raise.
28+
"""
29+
30+
def __init__(self, message: str | None = None):
31+
self.message = "" if message is None else message
32+
super().__init__(self.message)
33+
34+
def __str__(self) -> str:
35+
return self.message
36+
37+
def __repr__(self) -> str:
38+
return f"{self.__class__.__name__}({repr(self.message)})"

boto3_refresh_session/session.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,21 @@
11
from __future__ import annotations
22

3-
__doc__ = """
4-
boto3_refresh_session.session
5-
=============================
6-
7-
This module provides the main interface for constructing refreshable boto3 sessions.
8-
9-
The ``RefreshableSession`` class serves as a factory that dynamically selects the appropriate
10-
credential refresh strategy based on the ``method`` parameter, e.g., ``sts``.
11-
12-
Users can interact with AWS services just like they would with a normal :class:`boto3.session.Session`,
13-
with the added benefit of automatic credential refreshing.
14-
15-
Examples
16-
--------
17-
>>> from boto3_refresh_session import RefreshableSession
18-
>>> session = RefreshableSession(
19-
... assume_role_kwargs={"RoleArn": "...", "RoleSessionName": "..."},
20-
... region_name="us-east-1"
21-
... )
22-
>>> s3 = session.client("s3")
23-
>>> s3.list_buckets()
24-
25-
.. seealso::
26-
:class:`boto3_refresh_session.sts.STSRefreshableSession`
27-
28-
Factory interface
29-
-----------------
30-
.. autosummary::
31-
:toctree: generated/
32-
:nosignatures:
33-
34-
RefreshableSession
35-
"""
36-
373
__all__ = ["RefreshableSession"]
384

395
from abc import ABC, abstractmethod
406
from typing import Any, Callable, ClassVar, Literal, get_args
41-
from warnings import warn
427

438
from boto3.session import Session
449
from botocore.credentials import (
4510
DeferredRefreshableCredentials,
4611
RefreshableCredentials,
4712
)
4813

14+
from .exceptions import BRSError, BRSWarning
15+
4916
#: Type alias for all currently available credential refresh methods.
50-
Method = Literal["sts"]
51-
RefreshMethod = Literal["sts-assume-role"]
17+
Method = Literal["sts", "ecs"]
18+
RefreshMethod = Literal["sts-assume-role", "ecs-container-metadata"]
5219

5320

5421
class BaseRefreshableSession(ABC, Session):
@@ -77,7 +44,9 @@ def __init_subclass__(cls, method: Method):
7744

7845
# guarantees that methods are unique
7946
if method in BaseRefreshableSession.registry:
80-
warn(f"Method '{method}' is already registered. Overwriting.")
47+
BRSWarning(
48+
f"Method {repr(method)} is already registered. Overwriting."
49+
)
8150

8251
BaseRefreshableSession.registry[method] = cls
8352

@@ -108,6 +77,34 @@ def _refresh_using(
10877
refresh_using=credentials_method, method=refresh_method
10978
)
11079

80+
def refreshable_credentials(self) -> dict[str, str]:
81+
"""The current temporary AWS security credentials.
82+
83+
Returns
84+
-------
85+
dict[str, str]
86+
Temporary AWS security credentials containing:
87+
AWS_ACCESS_KEY_ID : str
88+
AWS access key identifier.
89+
AWS_SECRET_ACCESS_KEY : str
90+
AWS secret access key.
91+
AWS_SESSION_TOKEN : str
92+
AWS session token.
93+
"""
94+
95+
creds = self.get_credentials().get_frozen_credentials()
96+
return {
97+
"AWS_ACCESS_KEY_ID": creds.access_key,
98+
"AWS_SECRET_ACCESS_KEY": creds.secret_key,
99+
"AWS_SESSION_TOKEN": creds.token,
100+
}
101+
102+
@property
103+
def credentials(self) -> dict[str, str]:
104+
"""The current temporary AWS security credentials."""
105+
106+
return self.refreshable_credentials()
107+
111108

112109
class RefreshableSession:
113110
"""Factory class for constructing refreshable boto3 sessions using various authentication
@@ -134,11 +131,18 @@ class RefreshableSession:
134131
See Also
135132
--------
136133
boto3_refresh_session.sts.STSRefreshableSession
134+
boto3_refresh_session.ecs.ECSRefreshableSession
137135
"""
138136

139137
def __new__(
140138
cls, method: Method = "sts", **kwargs
141139
) -> BaseRefreshableSession:
140+
if method not in (methods := cls.get_available_methods()):
141+
raise BRSError(
142+
f"{repr(method)} is an invalid method parameter. Available methods are "
143+
f"{', '.join(repr(meth) for meth in methods)}."
144+
)
145+
142146
obj = BaseRefreshableSession.registry[method]
143147
return obj(**kwargs)
144148

boto3_refresh_session/sts.py

Lines changed: 4 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,10 @@
11
from __future__ import annotations
22

3-
__doc__ = """
4-
boto3_refresh_session.sts
5-
=========================
6-
7-
Implements the STS-based credential refresh strategy for use with
8-
:class:`boto3_refresh_session.session.RefreshableSession`.
9-
10-
This module defines the :class:`STSRefreshableSession` class, which uses
11-
IAM role assumption via STS to automatically refresh temporary credentials
12-
in the background.
13-
14-
.. versionadded:: 1.1.0
15-
16-
Examples
17-
--------
18-
>>> from boto3_refresh_session import RefreshableSession
19-
>>> session = RefreshableSession(
20-
... method="sts",
21-
... assume_role_kwargs={
22-
... "RoleArn": "arn:aws:iam::123456789012:role/MyRole",
23-
... "RoleSessionName": "my-session"
24-
... },
25-
... region_name="us-east-1"
26-
... )
27-
>>> s3 = session.client("s3")
28-
>>> s3.list_buckets()
29-
30-
.. seealso::
31-
:class:`boto3_refresh_session.session.RefreshableSession`
32-
33-
STS
34-
---
35-
36-
.. autosummary::
37-
:toctree: generated/
38-
:nosignatures:
39-
40-
STSRefreshableSession
41-
"""
423
__all__ = ["STSRefreshableSession"]
434

445
from typing import Any
45-
from warnings import warn
466

7+
from .exceptions import BRSWarning
478
from .session import BaseRefreshableSession
489

4910

@@ -73,7 +34,7 @@ class STSRefreshableSession(BaseRefreshableSession, method="sts"):
7334
def __init__(
7435
self,
7536
assume_role_kwargs: dict,
76-
defer_refresh: bool = None,
37+
defer_refresh: bool | None = None,
7738
sts_client_kwargs: dict | None = None,
7839
**kwargs,
7940
):
@@ -84,8 +45,8 @@ def __init__(
8445
if sts_client_kwargs is not None:
8546
# overwriting 'service_name' in case it appears in sts_client_kwargs
8647
if "service_name" in sts_client_kwargs:
87-
warn(
88-
"The sts_client_kwargs parameter cannot contain values for service_name. Reverting to service_name = 'sts'."
48+
BRSWarning(
49+
"'sts_client_kwargs' cannot contain values for 'service_name'. Reverting to service_name = 'sts'."
8950
)
9051
del sts_client_kwargs["service_name"]
9152
self._sts_client = self.client(

doc/conf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@
8282
"members": True,
8383
"member-order": "bysource",
8484
"exclude-members": "__init__,__new__",
85+
"inherited-members": True,
8586
}
86-
autodoc_typehints = "none"
87-
autodoc_preserve_defaults = False
88-
autodoc_class_signature = "separated"
87+
autodoc_typehints = "signature"
88+
autodoc_inherit_docstrings = True
8989

9090
# numpydoc config
9191
numpydoc_show_class_members = False

0 commit comments

Comments
 (0)