22
33import bcrypt
44
5- from sqlalchemy import select
5+ from sqlalchemy import insert , select
66from sqlalchemy .ext .asyncio import AsyncSession
77from sqlalchemy_crud_plus import CRUDPlus , JoinConfig
88
1111from backend .app .admin .schema .user import (
1212 AddOAuth2UserParam ,
1313 AddUserParam ,
14+ AddUserRoleParam ,
1415 UpdateUserParam ,
1516)
1617from backend .common .pagination import paging_data
@@ -72,15 +73,20 @@ async def add(self, db: AsyncSession, obj: AddUserParam) -> None:
7273 """
7374 salt = bcrypt .gensalt ()
7475 obj .password = get_hash_password (obj .password , salt )
76+
7577 dict_obj = obj .model_dump (exclude = {'roles' })
7678 dict_obj .update ({'salt' : salt })
7779 new_user = self .model (** dict_obj )
80+ db .add (new_user )
81+ await db .flush ()
7882
79- stmt = select (Role ).where (Role .id .in_ (obj .roles ))
80- roles = await db .execute (stmt )
81- new_user . roles = roles .scalars ().all ()
83+ role_stmt = select (Role ).where (Role .id .in_ (obj .roles ))
84+ result = await db .execute (role_stmt )
85+ roles = result .scalars ().all ()
8286
83- db .add (new_user )
87+ user_role_data = [AddUserRoleParam (user_id = new_user .id , role_id = role .id ).model_dump () for role in roles ]
88+ user_role_stmt = insert (user_role )
89+ await db .execute (user_role_stmt , user_role_data )
8490
8591 async def add_by_oauth2 (self , db : AsyncSession , obj : AddOAuth2UserParam ) -> None :
8692 """
@@ -93,12 +99,15 @@ async def add_by_oauth2(self, db: AsyncSession, obj: AddOAuth2UserParam) -> None
9399 dict_obj = obj .model_dump ()
94100 dict_obj .update ({'is_staff' : True , 'salt' : None })
95101 new_user = self .model (** dict_obj )
102+ db .add (new_user )
103+ await db .flush ()
96104
97- stmt = select (Role )
98- role = await db .execute (stmt )
99- new_user . roles = [ role .scalars ().first ()] # 默认绑定第一个角色
105+ role_stmt = select (Role )
106+ result = await db .execute (role_stmt )
107+ role = result .scalars ().first () # 默认绑定第一个角色
100108
101- db .add (new_user )
109+ user_role_stmt = insert (user_role ).values (AddUserRoleParam (user_id = new_user .id , role_id = role .id ).model_dump ())
110+ await db .execute (user_role_stmt )
102111
103112 async def update (self , db : AsyncSession , input_user : User , obj : UpdateUserParam ) -> int :
104113 """
@@ -114,9 +123,14 @@ async def update(self, db: AsyncSession, input_user: User, obj: UpdateUserParam)
114123
115124 count = await self .update_model (db , input_user .id , obj )
116125
117- stmt = select (Role ).where (Role .id .in_ (role_ids ))
118- roles = await db .execute (stmt )
119- input_user .roles = roles .scalars ().all ()
126+ role_stmt = select (Role ).where (Role .id .in_ (role_ids ))
127+ result = await db .execute (role_stmt )
128+ roles = result .scalars ().all ()
129+
130+ user_role_data = [AddUserRoleParam (user_id = input_user .id , role_id = role .id ).model_dump () for role in roles ]
131+ user_role_stmt = insert (user_role )
132+ await db .execute (user_role_stmt , user_role_data )
133+
120134 return count
121135
122136 async def update_nickname (self , db : AsyncSession , user_id : int , nickname : str ) -> int :
0 commit comments