Skip to content

Commit e3e0aa1

Browse files
nimaxinjowilf
andauthored
Add get_details_query to SQLAlchemy ModelView. (#643)
* Add `get_details_query` to SQLAlchemy ModelView. * Integrate `get_details_query` into `find_by_pks` * Update comment --------- Co-authored-by: Jocelin Hounon <[email protected]>
1 parent f9b8fe4 commit e3e0aa1

File tree

1 file changed

+18
-2
lines changed
  • starlette_admin/contrib/sqla

1 file changed

+18
-2
lines changed

starlette_admin/contrib/sqla/view.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,22 @@ async def handle_row_action(
176176
except SQLAlchemyError as exc:
177177
raise ActionFailed(str(exc)) from exc
178178

179+
def get_details_query(self, request: Request) -> Select:
180+
"""
181+
Return a Select expression which is used as base statement for
182+
[find_by_pk][starlette_admin.views.BaseModelView.find_by_pk] and
183+
[find_by_pks][starlette_admin.views.BaseModelView.find_by_pks] methods.
184+
185+
Examples:
186+
```python hl_lines="3-4"
187+
class PostView(ModelView):
188+
189+
def get_details_query(self, request: Request):
190+
return super().get_details_query().options(selectinload(Post.author))
191+
```
192+
"""
193+
return select(self.model)
194+
179195
def get_list_query(self, request: Request) -> Select:
180196
"""
181197
Return a Select expression which is used as base statement for
@@ -324,7 +340,7 @@ async def find_by_pk(self, request: Request, pk: Any) -> Any:
324340
else:
325341
assert isinstance(self._pk_coerce, type)
326342
clause = self._pk_column == self._pk_coerce(pk)
327-
stmt = select(self.model).where(clause)
343+
stmt = self.get_details_query(request).where(clause)
328344
for field in self.get_fields_list(request, request.state.action):
329345
if isinstance(field, RelationField):
330346
stmt = stmt.options(joinedload(getattr(self.model, field.name)))
@@ -360,7 +376,7 @@ async def _exec_find_by_pks(
360376
clause = await self._get_multiple_pks_in_clause(pks, use_composite_in)
361377
else:
362378
clause = self._pk_column.in_(map(self._pk_coerce, pks)) # type: ignore
363-
stmt = select(self.model).where(clause)
379+
stmt = self.get_details_query(request).where(clause)
364380
for field in self.get_fields_list(request, request.state.action):
365381
if isinstance(field, RelationField):
366382
stmt = stmt.options(joinedload(getattr(self.model, field.name)))

0 commit comments

Comments
 (0)