diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index 80c3f83e1fe..521174dca36 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -572,6 +572,8 @@ def __getattr__(self, __name: str) -> Any: ) and (func := getattr(value, "__func__", None)) is not None and not inspect.isclass(getattr(value, "__self__", None)) + # skip SQLAlchemy instrumented methods + and not getattr(value, "_sa_instrumented", False) ): # Rebind `self` to the proxy on methods to capture nested mutations. return functools.partial(func, self) diff --git a/tests/units/test_model.py b/tests/units/test_model.py index 51b8982d926..e7f64f736e5 100644 --- a/tests/units/test_model.py +++ b/tests/units/test_model.py @@ -8,7 +8,7 @@ import reflex.model from reflex.constants.state import FIELD_MARKER from reflex.model import Model, ModelRegistry -from reflex.state import BaseState +from reflex.state import BaseState, State from tests.units.test_state import ( mock_app_simple, # noqa: F401 # for pytest.mark.usefixtures ) @@ -240,3 +240,43 @@ async def test_upcast_event_handler_arg(handler, payload): assert update.delta == { UpcastStateWithSqlAlchemy.get_full_name(): {"passed" + FIELD_MARKER: True} } + + +def test_no_rebind_mutable_proxy_for_instrumented_functions(): + """Test that we don't rebind mutable proxies for instrumented functions.""" + import sqlalchemy + import sqlalchemy.orm + + class SABase(sqlalchemy.orm.MappedAsDataclass, sqlalchemy.orm.DeclarativeBase): + pass + + class SAKeyword(SABase): + __tablename__ = "sa_keyword" + + id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column( + primary_key=True, init=False, default=None + ) + value: sqlalchemy.orm.Mapped[str] = sqlalchemy.orm.mapped_column(default="") + obj_id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column( + sqlalchemy.ForeignKey("sa_obj.id"), default=None + ) + + class SAObj(SABase): + __tablename__ = "sa_obj" + + id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column( + primary_key=True, init=False, default=None + ) + keywords: sqlalchemy.orm.Mapped[list[SAKeyword]] = sqlalchemy.orm.relationship( + lazy="selectin", # codespell:ignore + cascade="all, delete", + default_factory=list, + ) + + class SAState(State): + sa_obj: SAObj = SAObj() + + sa_state = SAState() + assert "sa_obj" not in sa_state.dirty_vars + sa_state.sa_obj.keywords.append(SAKeyword(value="test")) + assert "sa_obj" in sa_state.dirty_vars