34
34
from oauthlib .oauth2 import InsecureTransportError , MissingCodeError , MissingTokenError
35
35
from requests .exceptions import SSLError
36
36
37
-
38
37
OAUTH2TOKEN = {
39
38
'access_token' : 'token' ,
40
39
'token_type' : 'Bearer' ,
@@ -87,8 +86,9 @@ def tearDown(self):
87
86
oauth2 .db = self ._db
88
87
oauth2 .OAuth2Session = self ._OAuth2Session
89
88
90
- def _helper (self , fullname_field = True , mail_field = True , conf = None , missing_conf = None ):
89
+ def _helper (self , fullname_field = True , mail_field = True , conf = None , missing_conf = None , jwt_enable = False ):
91
90
oauth2 .db = MagicMock ()
91
+ oauth2 .jwt = MagicMock ()
92
92
93
93
oauth2 .toolkit .config = {
94
94
'ckan.oauth2.legacy_idm' : 'false' ,
@@ -110,6 +110,9 @@ def _helper(self, fullname_field=True, mail_field=True, conf=None, missing_conf=
110
110
if fullname_field :
111
111
helper .profile_api_fullname_field = self ._fullname_field
112
112
113
+ if jwt_enable :
114
+ helper .jwt_enable = True
115
+
113
116
return helper
114
117
115
118
@parameterized .expand ([
@@ -341,7 +344,6 @@ def test_identify(self, username, fullname=None, email=None, user_exists=True,
341
344
print (username , fullname , email , user_exists , fullname_field , sysadmin )
342
345
343
346
# Create the mocks
344
- request = MagicMock ()
345
347
request = make_request (False , 'localhost' , '/oauth2/callback' , {})
346
348
oauth2 .toolkit .request = request
347
349
oauth2 .model .Session = MagicMock ()
@@ -383,6 +385,29 @@ def test_identify(self, username, fullname=None, email=None, user_exists=True,
383
385
oauth2 .model .Session .commit .assert_called_once ()
384
386
oauth2 .model .Session .remove .assert_called_once ()
385
387
388
+ def test_identify_jwt (self ):
389
+
390
+ helper = self ._helper (jwt_enable = True )
391
+ token = OAUTH2TOKEN
392
+ user_data = {
self .
_user_field :
'test_user' ,
self .
_email_field :
'[email protected] ' }
393
+
394
+ oauth2 .jwt .decode .return_value = user_data
395
+
396
+ oauth2 .model .Session = MagicMock ()
397
+ user = MagicMock ()
398
+ user .name = None
399
+ user .email = None
400
+ oauth2 .model .User = MagicMock (return_value = user )
401
+ oauth2 .model .User .by_email = MagicMock (return_value = [user ])
402
+
403
+ returned_username = helper .identify (token )
404
+
405
+ self .assertEquals (user_data [self ._user_field ], returned_username )
406
+
407
+ oauth2 .model .Session .add .assert_called_once_with (user )
408
+ oauth2 .model .Session .commit .assert_called_once ()
409
+ oauth2 .model .Session .remove .assert_called_once ()
410
+
386
411
@parameterized .expand ([
387
412
({'error' : 'invalid_token' , 'error_description' : 'Error Description' },),
388
413
({'error' : 'another_error' },)
@@ -472,10 +497,12 @@ def test_redirect_from_callback(self, identity):
472
497
self .assertEquals (came_from , oauth2 .toolkit .response .location )
473
498
474
499
@parameterized .expand ([
475
- (True ,),
476
- (False ,)
500
+ (True , True ),
501
+ (True , False ),
502
+ (False , False ),
503
+ (False , True ),
477
504
])
478
- def test_update_token (self , user_exists ):
505
+ def test_update_token (self , user_exists , jwt_expires_in ):
479
506
helper = self ._helper ()
480
507
user = 'user'
481
508
@@ -494,26 +521,48 @@ def test_update_token(self, user_exists):
494
521
oauth2 .db .UserToken .by_user_name = MagicMock (return_value = usertoken )
495
522
496
523
# The token to be updated
497
- newtoken = {
498
- 'access_token' : 'new_access_token' ,
499
- 'token_type' : 'new_token_type' ,
500
- 'expires_in' : 'new_expires_in' ,
501
- 'refresh_token' : 'new_refresh_token'
502
- }
503
-
504
- helper .update_token ('user' , newtoken )
505
-
506
- # Check that the object has been stored
507
- oauth2 .model .Session .add .assert_called_once ()
508
- oauth2 .model .Session .commit .assert_called_once ()
524
+ if jwt_expires_in :
525
+ newtoken = {
526
+ 'access_token' : 'new_access_token' ,
527
+ 'token_type' : 'new_token_type' ,
528
+ 'expires_in' : 'new_expires_in' ,
529
+ 'refresh_token' : 'new_refresh_token'
530
+ }
531
+ helper .update_token ('user' , newtoken )
532
+
533
+ # Check that the object has been stored
534
+ oauth2 .model .Session .add .assert_called_once ()
535
+ oauth2 .model .Session .commit .assert_called_once ()
536
+
537
+ # Check that the object contains the correct information
538
+ tk = oauth2 .model .Session .add .call_args_list [0 ][0 ][0 ]
539
+ self .assertEquals (user , tk .user_name )
540
+ self .assertEquals (newtoken ['access_token' ], tk .access_token )
541
+ self .assertEquals (newtoken ['token_type' ], tk .token_type )
542
+ self .assertEquals (newtoken ['expires_in' ], tk .expires_in )
543
+ self .assertEquals (newtoken ['refresh_token' ], tk .refresh_token )
544
+ else :
545
+ newtoken = {
546
+ 'access_token' : 'new_access_token' ,
547
+ 'token_type' : 'new_token_type' ,
548
+ 'refresh_token' : 'new_refresh_token'
549
+ }
550
+ expires_in_data = {'exp' : 3600 , 'iat' : 0 }
551
+ oauth2 .jwt .decode .return_value = expires_in_data
552
+ helper .update_token ('user' , newtoken )
553
+
554
+ # Check that the object has been stored
555
+ oauth2 .model .Session .add .assert_called_once ()
556
+ oauth2 .model .Session .commit .assert_called_once ()
557
+
558
+ # Check that the object contains the correct information
559
+ tk = oauth2 .model .Session .add .call_args_list [0 ][0 ][0 ]
560
+ self .assertEquals (user , tk .user_name )
561
+ self .assertEquals (newtoken ['access_token' ], tk .access_token )
562
+ self .assertEquals (newtoken ['token_type' ], tk .token_type )
563
+ self .assertEquals (3600 , tk .expires_in )
564
+ self .assertEquals (newtoken ['refresh_token' ], tk .refresh_token )
509
565
510
- # Check that the object contains the correct information
511
- tk = oauth2 .model .Session .add .call_args_list [0 ][0 ][0 ]
512
- self .assertEquals (user , tk .user_name )
513
- self .assertEquals (newtoken ['access_token' ], tk .access_token )
514
- self .assertEquals (newtoken ['token_type' ], tk .token_type )
515
- self .assertEquals (newtoken ['expires_in' ], tk .expires_in )
516
- self .assertEquals (newtoken ['refresh_token' ], tk .refresh_token )
517
566
518
567
@parameterized .expand ([
519
568
(True ,),
0 commit comments