11from collections .abc import Sequence
2+ from typing import Any
23
3- from sqlalchemy import Select , select
4+ from sqlalchemy import Select , delete , insert
45from sqlalchemy .ext .asyncio import AsyncSession
5- from sqlalchemy_crud_plus import CRUDPlus
6+ from sqlalchemy_crud_plus import CRUDPlus , JoinConfig
67
78from backend .app .admin .model import DataRule , DataScope
8- from backend .app .admin .schema .data_scope import CreateDataScopeParam , UpdateDataScopeParam , UpdateDataScopeRuleParam
9+ from backend .app .admin .model .m2m import data_scope_rule
10+ from backend .app .admin .schema .data_scope import (
11+ CreateDataScopeParam ,
12+ CreateDataScopeRuleParam ,
13+ UpdateDataScopeParam ,
14+ UpdateDataScopeRuleParam ,
15+ )
16+ from backend .utils .serializers import select_join_serialize
917
1018
1119class CRUDDataScope (CRUDPlus [DataScope ]):
@@ -31,15 +39,24 @@ async def get_by_name(self, db: AsyncSession, name: str) -> DataScope | None:
3139 """
3240 return await self .select_model_by_column (db , name = name )
3341
34- async def get_with_relation (self , db : AsyncSession , pk : int ) -> DataScope :
42+ async def get_with_relation (self , db : AsyncSession , pk : int ) -> Any :
3543 """
3644 获取数据范围关联数据
3745
3846 :param db: 数据库会话
3947 :param pk: 范围 ID
4048 :return:
4149 """
42- return await self .select_model (db , pk , load_strategies = ['rules' ])
50+ result = await self .select_models (
51+ db ,
52+ id = pk ,
53+ join_conditions = [
54+ JoinConfig (model = data_scope_rule , join_on = data_scope_rule .c .data_scope_id == self .model .id ),
55+ JoinConfig (model = DataRule , join_on = DataRule .id == data_scope_rule .c .data_rule_id , fill_result = True ),
56+ ],
57+ )
58+
59+ return await select_join_serialize (result , relationships = ['DataScope-m2m-DataRule' ])
4360
4461 async def get_all (self , db : AsyncSession ) -> Sequence [DataScope ]:
4562 """
@@ -65,7 +82,7 @@ async def get_select(self, name: str | None, status: int | None) -> Select:
6582 if status is not None :
6683 filters ['status' ] = status
6784
68- return await self .select_order ('id' , load_strategies = { 'rules' : 'noload' , 'roles' : 'noload' }, ** filters )
85+ return await self .select_order ('id' , ** filters )
6986
7087 async def create (self , db : AsyncSession , obj : CreateDataScopeParam ) -> None :
7188 """
@@ -88,7 +105,8 @@ async def update(self, db: AsyncSession, pk: int, obj: UpdateDataScopeParam) ->
88105 """
89106 return await self .update_model (db , pk , obj )
90107
91- async def update_rules (self , db : AsyncSession , pk : int , rule_ids : UpdateDataScopeRuleParam ) -> int :
108+ @staticmethod
109+ async def update_rules (db : AsyncSession , pk : int , rule_ids : UpdateDataScopeRuleParam ) -> int :
92110 """
93111 更新数据范围规则
94112
@@ -97,11 +115,16 @@ async def update_rules(self, db: AsyncSession, pk: int, rule_ids: UpdateDataScop
97115 :param rule_ids: 数据规则 ID 列表
98116 :return:
99117 """
100- current_data_scope = await self .get_with_relation (db , pk )
101- stmt = select (DataRule ).where (DataRule .id .in_ (rule_ids .rules ))
102- rules = await db .execute (stmt )
103- current_data_scope .rules = rules .scalars ().all ()
104- return len (current_data_scope .rules )
118+ data_scope_rule_stmt = delete (data_scope_rule ).where (data_scope_rule .c .data_scope_id == pk )
119+ await db .execute (data_scope_rule_stmt )
120+
121+ data_scope_rule_data = [
122+ CreateDataScopeRuleParam (data_scope_id = pk , data_rule_id = rule_id ).model_dump () for rule_id in rule_ids .rules
123+ ]
124+ data_scope_rule_stmt = insert (data_scope_rule )
125+ await db .execute (data_scope_rule_stmt , data_scope_rule_data )
126+
127+ return len (rule_ids .rules )
105128
106129 async def delete (self , db : AsyncSession , pks : list [int ]) -> int :
107130 """
0 commit comments