2020from  http .cookies  import  Morsel 
2121
2222from  tornado  import  escape , httputil , web 
23- from  traitlets  import  Bool , Dict , Type , Unicode , default 
23+ from  traitlets  import  Bool , Dict , Enum ,  List ,  TraitError ,  Type , Unicode , default ,  validate 
2424from  traitlets .config  import  LoggingConfigurable 
2525
2626from  jupyter_server .transutils  import  _i18n 
3131_non_alphanum  =  re .compile (r"[^A-Za-z0-9]" )
3232
3333
34+ # Define the User properties that can be updated 
35+ UpdatableField  =  t .Literal ["name" , "display_name" , "initials" , "avatar_url" , "color" ]
36+ 
37+ 
3438@dataclass  
3539class  User :
3640    """Object representing a User 
@@ -188,6 +192,14 @@ class IdentityProvider(LoggingConfigurable):
188192        help = _i18n ("The logout handler class to use." ),
189193    )
190194
195+     # Define the fields that can be updated 
196+     updatable_fields  =  List (
197+         trait = Enum (list (t .get_args (UpdatableField ))),
198+         default_value = ["color" ],  # Default updatable field 
199+         config = True ,
200+         help = _i18n ("List of fields in the User model that can be updated." ),
201+     )
202+ 
191203    token_generated  =  False 
192204
193205    @default ("token" ) 
@@ -207,6 +219,18 @@ def _token_default(self):
207219            self .token_generated  =  True 
208220            return  binascii .hexlify (os .urandom (24 )).decode ("ascii" )
209221
222+     @validate ("updatable_fields" ) 
223+     def  _validate_updatable_fields (self , proposal ):
224+         """Validate that all fields in updatable_fields are valid.""" 
225+         valid_updatable_fields  =  list (t .get_args (UpdatableField ))
226+         invalid_fields  =  [
227+             field  for  field  in  proposal ["value" ] if  field  not  in valid_updatable_fields 
228+         ]
229+         if  invalid_fields :
230+             msg  =  f"Invalid fields in updatable_fields: { invalid_fields }  
231+             raise  TraitError (msg )
232+         return  proposal ["value" ]
233+ 
210234    need_token : bool  |  Bool [bool , t .Union [bool , int ]] =  Bool (True )
211235
212236    def  get_user (self , handler : web .RequestHandler ) ->  User  |  None  |  t .Awaitable [User  |  None ]:
@@ -269,6 +293,31 @@ async def _get_user(self, handler: web.RequestHandler) -> User | None:
269293
270294        return  user 
271295
296+     def  update_user (
297+         self , handler : web .RequestHandler , user_data : dict [UpdatableField , str ]
298+     ) ->  User :
299+         """Update user information and persist the user model.""" 
300+         self .check_update (user_data )
301+         current_user  =  t .cast (User , handler .current_user )
302+         updated_user  =  self .update_user_model (current_user , user_data )
303+         self .persist_user_model (handler )
304+         return  updated_user 
305+ 
306+     def  check_update (self , user_data : dict [UpdatableField , str ]) ->  None :
307+         """Raises if some fields to update are not updatable.""" 
308+         for  field  in  user_data :
309+             if  field  not  in self .updatable_fields :
310+                 msg  =  f"Field { field }  
311+                 raise  ValueError (msg )
312+ 
313+     def  update_user_model (self , current_user : User , user_data : dict [UpdatableField , str ]) ->  User :
314+         """Update user information.""" 
315+         raise  NotImplementedError 
316+ 
317+     def  persist_user_model (self , handler : web .RequestHandler ) ->  None :
318+         """Persist the user model (i.e. a cookie).""" 
319+         raise  NotImplementedError 
320+ 
272321    def  identity_model (self , user : User ) ->  dict [str , t .Any ]:
273322        """Return a User as an Identity model""" 
274323        # TODO: validate? 
@@ -617,6 +666,16 @@ class PasswordIdentityProvider(IdentityProvider):
617666    def  _need_token_default (self ):
618667        return  not  bool (self .hashed_password )
619668
669+     @default ("updatable_fields" ) 
670+     def  _default_updatable_fields (self ):
671+         return  [
672+             "name" ,
673+             "display_name" ,
674+             "initials" ,
675+             "avatar_url" ,
676+             "color" ,
677+         ]
678+ 
620679    @property  
621680    def  login_available (self ) ->  bool :
622681        """Whether a LoginHandler is needed - and therefore whether the login page should be displayed.""" 
@@ -627,6 +686,17 @@ def auth_enabled(self) -> bool:
627686        """Return whether any auth is enabled""" 
628687        return  bool (self .hashed_password  or  self .token )
629688
689+     def  update_user_model (self , current_user : User , user_data : dict [UpdatableField , str ]) ->  User :
690+         """Update user information.""" 
691+         for  field  in  self .updatable_fields :
692+             if  field  in  user_data :
693+                 setattr (current_user , field , user_data [field ])
694+         return  current_user 
695+ 
696+     def  persist_user_model (self , handler : web .RequestHandler ) ->  None :
697+         """Persist the user model to a cookie.""" 
698+         self .set_login_cookie (handler , handler .current_user )
699+ 
630700    def  passwd_check (self , password ):
631701        """Check password against our stored hashed password""" 
632702        return  passwd_check (self .hashed_password , password )
0 commit comments