22
33import bcrypt
44
5- from sqlalchemy import delete , insert , select
5+ from sqlalchemy import select
66from sqlalchemy .ext .asyncio import AsyncSession
7+ from sqlalchemy .orm import noload , selectinload
78from sqlalchemy .sql import Select
89from sqlalchemy_crud_plus import CRUDPlus , JoinConfig
910
@@ -76,13 +77,9 @@ async def add(self, db: AsyncSession, obj: AddUserParam) -> None:
7677 dict_obj .update ({'salt' : salt })
7778 new_user = self .model (** dict_obj )
7879
79- db .add (new_user )
80- await db .flush () # 获取用户ID
81-
82- # 添加用户角色关联(逻辑外键)
83- if obj .roles :
84- for role_id in obj .roles :
85- await db .execute (insert (user_role ).values (user_id = new_user .id , role_id = role_id ))
80+ stmt = select (Role ).where (Role .id .in_ (obj .roles ))
81+ roles = await db .execute (stmt )
82+ new_user .roles = roles .scalars ().all ()
8683
8784 async def add_by_oauth2 (self , db : AsyncSession , obj : AddOAuth2UserParam ) -> None :
8885 """
@@ -96,15 +93,11 @@ async def add_by_oauth2(self, db: AsyncSession, obj: AddOAuth2UserParam) -> None
9693 dict_obj .update ({'is_staff' : True , 'salt' : None })
9794 new_user = self .model (** dict_obj )
9895
99- db .add (new_user )
100- await db .flush () # 获取用户ID
101-
102- # 绑定第一个角色(逻辑外键)
103- stmt = select (Role ).limit (1 )
96+ stmt = select (Role )
10497 role = await db .execute (stmt )
105- first_role = role .scalars ().first ()
106- if first_role :
107- await db .execute ( insert ( user_role ). values ( user_id = new_user . id , role_id = first_role . id ) )
98+ new_user . roles = [ role .scalars ().first ()] # 默认绑定第一个角色
99+
100+ db .add ( new_user )
108101
109102 async def update (self , db : AsyncSession , input_user : User , obj : UpdateUserParam ) -> int :
110103 """
@@ -120,13 +113,9 @@ async def update(self, db: AsyncSession, input_user: User, obj: UpdateUserParam)
120113
121114 count = await self .update_model (db , input_user .id , obj )
122115
123- # 删除原有用户角色关联
124- await db .execute (delete (user_role ).where (user_role .c .user_id == input_user .id ))
125-
126- # 添加新的用户角色关联(逻辑外键)
127- if role_ids :
128- for role_id in role_ids :
129- await db .execute (insert (user_role ).values (user_id = input_user .id , role_id = role_id ))
116+ stmt = select (Role ).where (Role .id .in_ (role_ids ))
117+ roles = await db .execute (stmt )
118+ input_user .roles = roles .scalars ().all ()
130119
131120 return count
132121
@@ -220,6 +209,10 @@ async def get_select(self, dept: int | None, username: str | None, phone: str |
220209 return await self .select_order (
221210 'id' ,
222211 'desc' ,
212+ load_options = [
213+ selectinload (self .model .dept ).options (noload (Dept .parent ), noload (Dept .children ), noload (Dept .users )),
214+ selectinload (self .model .roles ).options (noload (Role .users ), noload (Role .menus ), noload (Role .scopes )),
215+ ],
223216 ** filters ,
224217 )
225218
@@ -280,7 +273,7 @@ async def get_joins(
280273 :param db: 数据库会话
281274 :param user_id: 用户 ID
282275 :param username: 用户名
283- :return: 包含用户信息及关联部门、角色等数据的对象,支持访问 .status、.dept、.roles 等属性
276+ :return:
284277 """
285278 filters = {}
286279
0 commit comments