Skip to content

Commit 9638797

Browse files
authored
Merge pull request #72 from AzureAD/release-0.5.0
Release 0.5.0
2 parents 210103c + 04b1edf commit 9638797

File tree

9 files changed

+191
-29
lines changed

9 files changed

+191
-29
lines changed

msal/application.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import requests
1111

12-
from .oauth2cli import Client, JwtSigner
12+
from .oauth2cli import Client, JwtAssertionCreator
1313
from .authority import Authority
1414
from .mex import send_request as mex_send_request
1515
from .wstrust_request import send_request as wst_send_request
@@ -18,7 +18,7 @@
1818

1919

2020
# The __init__.py will import this. Not the other way around.
21-
__version__ = "0.4.1"
21+
__version__ = "0.5.0"
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -50,16 +50,33 @@ def decorate_scope(
5050
return list(decorated)
5151

5252

53+
def extract_certs(public_cert_content):
54+
# Parses raw public certificate file contents and returns a list of strings
55+
# Usage: headers = {"x5c": extract_certs(open("my_cert.pem").read())}
56+
public_certificates = re.findall(
57+
r'-----BEGIN CERTIFICATE-----(?P<cert_value>[^-]+)-----END CERTIFICATE-----',
58+
public_cert_content, re.I)
59+
if public_certificates:
60+
return [cert.strip() for cert in public_certificates]
61+
# The public cert tags are not found in the input,
62+
# let's make best effort to exclude a private key pem file.
63+
if "PRIVATE KEY" in public_cert_content:
64+
raise ValueError(
65+
"We expect your public key but detect a private key instead")
66+
return [public_cert_content.strip()]
67+
68+
5369
class ClientApplication(object):
5470

5571
def __init__(
5672
self, client_id,
5773
client_credential=None, authority=None, validate_authority=True,
5874
token_cache=None,
59-
verify=True, proxies=None, timeout=None):
75+
verify=True, proxies=None, timeout=None,
76+
client_claims=None):
6077
"""Create an instance of application.
6178
62-
:param client_id: Your app has a clinet_id after you register it on AAD.
79+
:param client_id: Your app has a client_id after you register it on AAD.
6380
:param client_credential:
6481
For :class:`PublicClientApplication`, you simply use `None` here.
6582
For :class:`ConfidentialClientApplication`,
@@ -69,6 +86,28 @@ def __init__(
6986
{
7087
"private_key": "...-----BEGIN PRIVATE KEY-----...",
7188
"thumbprint": "A1B2C3D4E5F6...",
89+
"public_certificate": "...-----BEGIN CERTIFICATE-----..." (Optional. See below.)
90+
}
91+
92+
*Added in version 0.5.0*:
93+
public_certificate (optional) is public key certificate
94+
which will be sent through 'x5c' JWT header only for
95+
subject name and issuer authentication to support cert auto rolls.
96+
97+
:param dict client_claims:
98+
*Added in version 0.5.0*:
99+
It is a dictionary of extra claims that would be signed by
100+
by this :class:`ConfidentialClientApplication` 's private key.
101+
For example, you can use {"client_ip": "x.x.x.x"}.
102+
You may also override any of the following default claims::
103+
104+
{
105+
"aud": the_token_endpoint,
106+
"iss": self.client_id,
107+
"sub": same_as_issuer,
108+
"exp": now + 10_min,
109+
"iat": now,
110+
"jti": a_random_uuid
72111
}
73112
74113
:param str authority:
@@ -95,6 +134,7 @@ def __init__(
95134
"""
96135
self.client_id = client_id
97136
self.client_credential = client_credential
137+
self.client_claims = client_claims
98138
self.verify = verify
99139
self.proxies = proxies
100140
self.timeout = timeout
@@ -113,11 +153,15 @@ def _build_client(self, client_credential, authority):
113153
if isinstance(client_credential, dict):
114154
assert ("private_key" in client_credential
115155
and "thumbprint" in client_credential)
116-
signer = JwtSigner(
156+
headers = {}
157+
if 'public_certificate' in client_credential:
158+
headers["x5c"] = extract_certs(client_credential['public_certificate'])
159+
assertion = JwtAssertionCreator(
117160
client_credential["private_key"], algorithm="RS256",
118-
sha1_thumbprint=client_credential.get("thumbprint"))
119-
client_assertion = signer.sign_assertion(
120-
audience=authority.token_endpoint, issuer=self.client_id)
161+
sha1_thumbprint=client_credential.get("thumbprint"), headers=headers)
162+
client_assertion = assertion.create_regenerative_assertion(
163+
audience=authority.token_endpoint, issuer=self.client_id,
164+
additional_claims=self.client_claims or {})
121165
client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT
122166
else:
123167
default_body['client_secret'] = client_credential

msal/oauth2cli/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__version__ = "0.2.0"
1+
__version__ = "0.3.0"
22

33
from .oidc import Client
4-
from .assertion import JwtSigner
4+
from .assertion import JwtAssertionCreator
5+
from .assertion import JwtSigner # Obsolete. For backward compatibility.
56

msal/oauth2cli/assertion.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,57 @@
99

1010
logger = logging.getLogger(__name__)
1111

12-
class Signer(object):
13-
def sign_assertion(
14-
self, audience, issuer, subject, expires_at,
12+
class AssertionCreator(object):
13+
def create_normal_assertion(
14+
self, audience, issuer, subject, expires_at=None, expires_in=600,
1515
issued_at=None, assertion_id=None, **kwargs):
16-
# Names are defined in https://tools.ietf.org/html/rfc7521#section-5
16+
"""Create an assertion in bytes, based on the provided claims.
17+
18+
All parameter names are defined in https://tools.ietf.org/html/rfc7521#section-5
19+
except the expires_in is defined here as lifetime-in-seconds,
20+
which will be automatically translated into expires_at in UTC.
21+
"""
1722
raise NotImplementedError("Will be implemented by sub-class")
1823

24+
def create_regenerative_assertion(
25+
self, audience, issuer, subject=None, expires_in=600, **kwargs):
26+
"""Create an assertion as a callable,
27+
which will then compute the assertion later when necessary.
28+
29+
This is a useful optimization to reuse the client assertion.
30+
"""
31+
return AutoRefresher( # Returns a callable
32+
lambda a=audience, i=issuer, s=subject, e=expires_in, kwargs=kwargs:
33+
self.create_normal_assertion(a, i, s, expires_in=e, **kwargs),
34+
expires_in=max(expires_in-60, 0))
35+
36+
37+
class AutoRefresher(object):
38+
"""Cache the output of a factory, and auto-refresh it when necessary. Usage::
1939
20-
class JwtSigner(Signer):
40+
r = AutoRefresher(time.time, expires_in=5)
41+
for i in range(15):
42+
print(r()) # the timestamp change only after every 5 seconds
43+
time.sleep(1)
44+
"""
45+
def __init__(self, factory, expires_in=540):
46+
self._factory = factory
47+
self._expires_in = expires_in
48+
self._buf = {}
49+
def __call__(self):
50+
EXPIRES_AT, VALUE = "expires_at", "value"
51+
now = time.time()
52+
if self._buf.get(EXPIRES_AT, 0) <= now:
53+
logger.debug("Regenerating new assertion")
54+
self._buf = {VALUE: self._factory(), EXPIRES_AT: now + self._expires_in}
55+
else:
56+
logger.debug("Reusing still valid assertion")
57+
return self._buf.get(VALUE)
58+
59+
60+
class JwtAssertionCreator(AssertionCreator):
2161
def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None):
22-
"""Create a signer.
62+
"""Construct a Jwt assertion creator.
2363
2464
Args:
2565
@@ -37,11 +77,11 @@ def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None):
3777
self.headers["x5t"] = base64.urlsafe_b64encode(
3878
binascii.a2b_hex(sha1_thumbprint)).decode()
3979

40-
def sign_assertion(
41-
self, audience, issuer, subject=None, expires_at=None,
80+
def create_normal_assertion(
81+
self, audience, issuer, subject=None, expires_at=None, expires_in=600,
4282
issued_at=None, assertion_id=None, not_before=None,
4383
additional_claims=None, **kwargs):
44-
"""Sign a JWT Assertion.
84+
"""Create a JWT Assertion.
4585
4686
Parameters are defined in https://tools.ietf.org/html/rfc7523#section-3
4787
Key-value pairs in additional_claims will be added into payload as-is.
@@ -51,7 +91,7 @@ def sign_assertion(
5191
'aud': audience,
5292
'iss': issuer,
5393
'sub': subject or issuer,
54-
'exp': expires_at or (now + 10*60), # 10 minutes
94+
'exp': expires_at or (now + expires_in),
5595
'iat': issued_at or now,
5696
'jti': assertion_id or str(uuid.uuid4()),
5797
}
@@ -68,3 +108,9 @@ def sign_assertion(
68108
'See https://pyjwt.readthedocs.io/en/latest/installation.html#cryptographic-dependencies-optional')
69109
raise
70110

111+
112+
# Obsolete. For backward compatibility. They will be removed in future versions.
113+
Signer = AssertionCreator # For backward compatibility
114+
JwtSigner = JwtAssertionCreator # For backward compatibility
115+
JwtSigner.sign_assertion = JwtAssertionCreator.create_normal_assertion # For backward compatibility
116+

msal/oauth2cli/oauth2.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
server_configuration, # type: dict
3434
client_id, # type: str
3535
client_secret=None, # type: Optional[str]
36-
client_assertion=None, # type: Optional[bytes]
36+
client_assertion=None, # type: Union[bytes, callable, None]
3737
client_assertion_type=None, # type: Optional[str]
3838
default_headers=None, # type: Optional[dict]
3939
default_body=None, # type: Optional[dict]
@@ -55,10 +55,12 @@ def __init__(
5555
https://example.com/.../.well-known/openid-configuration
5656
client_id (str): The client's id, issued by the authorization server
5757
client_secret (str): Triggers HTTP AUTH for Confidential Client
58-
client_assertion (bytes):
58+
client_assertion (bytes, callable):
5959
The client assertion to authenticate this client, per RFC 7521.
6060
It can be a raw SAML2 assertion (this method will encode it for you),
6161
or a raw JWT assertion.
62+
It can also be a callable (recommended),
63+
so that we will do lazy creation of an assertion.
6264
client_assertion_type (str):
6365
The type of your :attr:`client_assertion` parameter.
6466
It is typically the value of :attr:`CLIENT_ASSERTION_TYPE_SAML2` or
@@ -75,11 +77,9 @@ def __init__(
7577
self.configuration = server_configuration
7678
self.client_id = client_id
7779
self.client_secret = client_secret
80+
self.client_assertion = client_assertion
7881
self.default_body = default_body or {}
79-
if client_assertion is not None and client_assertion_type is not None:
80-
# See https://tools.ietf.org/html/rfc7521#section-4.2
81-
encoder = self.client_assertion_encoders.get(client_assertion_type, lambda a: a)
82-
self.default_body["client_assertion"] = encoder(client_assertion)
82+
if client_assertion_type is not None:
8383
self.default_body["client_assertion_type"] = client_assertion_type
8484
self.logger = logging.getLogger(__name__)
8585
self.session = s = requests.Session()
@@ -114,6 +114,15 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
114114
**kwargs # Relay all extra parameters to underlying requests
115115
): # Returns the json object came from the OAUTH2 response
116116
_data = {'client_id': self.client_id, 'grant_type': grant_type}
117+
118+
if self.default_body.get("client_assertion_type") and self.client_assertion:
119+
# See https://tools.ietf.org/html/rfc7521#section-4.2
120+
encoder = self.client_assertion_encoders.get(
121+
self.default_body["client_assertion_type"], lambda a: a)
122+
_data["client_assertion"] = encoder(
123+
self.client_assertion() # Do lazy on-the-fly computation
124+
if callable(self.client_assertion) else self.client_assertion)
125+
117126
_data.update(self.default_body) # It may contain authen parameters
118127
_data.update(data or {}) # So the content in data param prevails
119128
# We don't have to clean up None values here, because requests lib will.

sample/authorization-code-flow-sample/authorization_code_flow_sample.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,16 @@
2424
import json
2525
import logging
2626
import uuid
27+
import os
2728

2829
import flask
2930

3031
import msal
3132

3233
app = flask.Flask(__name__)
3334
app.debug = True
34-
app.secret_key = sys.argv[2] # In this demo, we expect a secret from 2nd CLI param
35+
app.secret_key = os.environ.get("FLASK_SECRET")
36+
assert app.secret_key, "This sample requires a FLASK_SECRET env var to enable session"
3537

3638

3739
# Optional logging

sample/device_flow_sample.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@
5656
# Ideally you should wait here, in order to save some unnecessary polling
5757
# input("Press Enter after you successfully login from another device...")
5858
result = app.acquire_token_by_device_flow(flow) # By default it will block
59+
# You can follow this instruction to shorten the block time
60+
# https://msal-python.readthedocs.io/en/latest/#msal.PublicClientApplication.acquire_token_by_device_flow
61+
# or you may even turn off the blocking behavior,
62+
# and then keep calling acquire_token_by_device_flow(flow) in your own customized loop.
5963

6064
if "access_token" in result:
6165
print(result["access_token"])

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@
6666
'License :: OSI Approved :: MIT License',
6767
'Operating System :: OS Independent',
6868
],
69-
packages=find_packages(),
69+
packages=find_packages(exclude=["tests"]),
70+
data_files=[('', ['LICENSE'])],
7071
install_requires=[
7172
'requests>=2.0.0,<3',
7273
'PyJWT[crypto]>=1.0.0,<2',

tests/test_application.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
with open(CONFIG_FILE) as conf:
2121
CONFIG = json.load(conf)
2222

23-
logger = logging.getLogger(__file__)
23+
logger = logging.getLogger(__name__)
2424
logging.basicConfig(level=logging.DEBUG)
2525

2626

@@ -99,6 +99,46 @@ def test_client_certificate(self):
9999
self.assertIn('access_token', result)
100100
self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None))
101101

102+
def test_extract_a_tag_less_public_cert(self):
103+
pem = "my_cert"
104+
self.assertEqual(["my_cert"], extract_certs(pem))
105+
106+
def test_extract_a_tag_enclosed_cert(self):
107+
pem = """
108+
-----BEGIN CERTIFICATE-----
109+
my_cert
110+
-----END CERTIFICATE-----
111+
"""
112+
self.assertEqual(["my_cert"], extract_certs(pem))
113+
114+
def test_extract_multiple_tag_enclosed_certs(self):
115+
pem = """
116+
-----BEGIN CERTIFICATE-----
117+
my_cert1
118+
-----END CERTIFICATE-----
119+
120+
-----BEGIN CERTIFICATE-----
121+
my_cert2
122+
-----END CERTIFICATE-----
123+
"""
124+
self.assertEqual(["my_cert1", "my_cert2"], extract_certs(pem))
125+
126+
@unittest.skipUnless("public_certificate" in CONFIG, "Missing Public cert")
127+
def test_subject_name_issuer_authentication(self):
128+
assert ("private_key_file" in CONFIG
129+
and "thumbprint" in CONFIG and "public_certificate" in CONFIG)
130+
with open(os.path.join(THIS_FOLDER, CONFIG['private_key_file'])) as f:
131+
pem = f.read()
132+
with open(os.path.join(THIS_FOLDER, CONFIG['public_certificate'])) as f:
133+
public_certificate = f.read()
134+
app = ConfidentialClientApplication(
135+
CONFIG['client_id'], authority=CONFIG["authority"],
136+
client_credential={"private_key": pem, "thumbprint": CONFIG["thumbprint"],
137+
"public_certificate": public_certificate})
138+
scope = CONFIG.get("scope", [])
139+
result = app.acquire_token_for_client(scope)
140+
self.assertIn('access_token', result)
141+
self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None))
102142

103143
@unittest.skipUnless("client_id" in CONFIG, "client_id missing")
104144
class TestPublicClientApplication(Oauth2TestCase):

tests/test_assertion.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import json
2+
3+
from msal.oauth2cli import JwtSigner
4+
from msal.oauth2cli.oidc import base64decode
5+
6+
from tests import unittest
7+
8+
9+
class AssertionTestCase(unittest.TestCase):
10+
def test_extra_claims(self):
11+
assertion = JwtSigner(key=None, algorithm="none").sign_assertion(
12+
"audience", "issuer", additional_claims={"client_ip": "1.2.3.4"})
13+
payload = json.loads(base64decode(assertion.split(b'.')[1].decode('utf-8')))
14+
self.assertEqual("1.2.3.4", payload.get("client_ip"))
15+

0 commit comments

Comments
 (0)