Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/alembic_utils/on_entity_mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING

from alembic_utils.statement import coerce_to_unquoted
from alembic_utils.statement import coerce_to_unquoted, coerce_to_quoted

if TYPE_CHECKING:
from alembic_utils.replaceable_entity import ReplaceableEntity
Expand Down Expand Up @@ -30,6 +30,11 @@ def identity(self) -> str:
"""
return f"{self.__class__.__name__}: {self.schema}.{self.signature} {self.on_entity}"

@property
def sql_on_entity(self) -> str:
"""The SQL representation of the entity that the trigger is defined on"""
return coerce_to_quoted(self.on_entity)

def render_self_for_migration(self, omit_definition=False) -> str:
"""Render a string that is valid python code to reconstruct self in a migration"""
var_name = self.to_variable_name()
Expand Down
8 changes: 4 additions & 4 deletions src/alembic_utils/pg_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ def from_sql(cls, sql: str) -> "PGPolicy":
def to_sql_statement_create(self):
"""Generates a SQL "create poicy" statement for PGPolicy"""

return sql_text(f"CREATE POLICY {self.signature} on {self.on_entity} {self.definition}")
return sql_text(f"CREATE POLICY {self.signature} on {self.sql_on_entity} {self.definition}")

def to_sql_statement_drop(self, cascade=False):
"""Generates a SQL "drop policy" statement for PGPolicy"""
cascade = "cascade" if cascade else ""
return sql_text(f"DROP POLICY {self.signature} on {self.on_entity} {cascade}")
return sql_text(f"DROP POLICY {self.signature} on {self.sql_on_entity} {cascade}")

def to_sql_statement_create_or_replace(self):
"""Not implemented, postgres policies do not support replace."""
yield sql_text(f"DROP POLICY IF EXISTS {self.signature} on {self.on_entity};")
yield sql_text(f"CREATE POLICY {self.signature} on {self.on_entity} {self.definition};")
yield sql_text(f"DROP POLICY IF EXISTS {self.signature} on {self.sql_on_entity};")
yield sql_text(f"CREATE POLICY {self.signature} on {self.sql_on_entity} {self.definition};")

@classmethod
def from_database(cls, connection, schema):
Expand Down
4 changes: 2 additions & 2 deletions src/alembic_utils/pg_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ def to_sql_statement_create(self):
def to_sql_statement_drop(self, cascade=False):
"""Generates a SQL "drop trigger" statement for PGTrigger"""
cascade = "cascade" if cascade else ""
return sql_text(f'DROP TRIGGER "{self.signature}" ON {self.on_entity} {cascade}')
return sql_text(f'DROP TRIGGER "{self.signature}" ON {self.sql_on_entity} {cascade}')

def to_sql_statement_create_or_replace(self):
"""Generates a SQL "replace trigger" statement for PGTrigger"""
yield sql_text(f'DROP TRIGGER IF EXISTS "{self.signature}" ON {self.on_entity};')
yield sql_text(f'DROP TRIGGER IF EXISTS "{self.signature}" ON {self.sql_on_entity};')
yield self.to_sql_statement_create()

@classmethod
Expand Down
82 changes: 38 additions & 44 deletions src/test/test_pg_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,38 @@
from alembic_utils.replaceable_entity import register_entities
from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command

TEST_POLICY = PGPolicy(
schema="public",
signature="some_policy",
on_entity="some_tab", # schema omitted intentionally
definition="""
for all
to anon_user
using (true)
with check (true);
""",
)


@pytest.fixture()
def schema_setup(engine) -> Generator[None, None, None]:

@pytest.fixture(params=["some_tab", "Some_Tab"])
def schema_setup(request, engine) -> Generator[PGPolicy, None, None]:
with engine.begin() as connection:
connection.execute(
text(
"""
create table public.some_tab (
id serial primary key,
name text
);

create user anon_user;
"""
)
text(f"""
create table public."{request.param}" (
id serial primary key,
name text
);
create user anon_user;
""")
)
yield

yield PGPolicy(
schema="public",
signature="some_policy",
on_entity=request.param, # schema omitted intentionally
definition="""
for all
to anon_user
using (true)
with check (true);
""",
)

with engine.begin() as connection:
connection.execute(
text(
"""
drop table public.some_tab;
drop user anon_user;
"""
)
text(f"""
drop table public."{request.param}";
drop user anon_user;
""")
)


Expand All @@ -69,9 +64,9 @@ def test_parse_without_schema_on_entity() -> None:


def test_create_revision(engine, schema_setup) -> None:
register_entities([TEST_POLICY], entity_types=[PGPolicy])
register_entities([schema_setup], entity_types=[PGPolicy])

output = run_alembic_command(
run_alembic_command(
engine=engine,
command="revision",
command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"},
Expand All @@ -96,13 +91,13 @@ def test_create_revision(engine, schema_setup) -> None:
def test_update_revision(engine, schema_setup) -> None:
# Create the view outside of a revision
with engine.begin() as connection:
connection.execute(TEST_POLICY.to_sql_statement_create())
connection.execute(schema_setup.to_sql_statement_create())

# Update definition of TO_UPPER
UPDATED_TEST_POLICY = PGPolicy(
schema=TEST_POLICY.schema,
signature=TEST_POLICY.signature,
on_entity=TEST_POLICY.on_entity,
schema=schema_setup.schema,
signature=schema_setup.signature,
on_entity=schema_setup.on_entity,
definition="""
for update
to anon_user
Expand Down Expand Up @@ -139,9 +134,9 @@ def test_update_revision(engine, schema_setup) -> None:
def test_noop_revision(engine, schema_setup) -> None:
# Create the view outside of a revision
with engine.begin() as connection:
connection.execute(TEST_POLICY.to_sql_statement_create())
connection.execute(schema_setup.to_sql_statement_create())

register_entities([TEST_POLICY], entity_types=[PGPolicy])
register_entities([schema_setup], entity_types=[PGPolicy])

# Create a third migration without making changes.
# This should result in no create, drop or replace statements
Expand Down Expand Up @@ -176,8 +171,9 @@ def test_drop_revision(engine, schema_setup) -> None:

# Manually create a SQL function
with engine.begin() as connection:
connection.execute(TEST_POLICY.to_sql_statement_create())
output = run_alembic_command(
connection.execute(schema_setup.to_sql_statement_create())

run_alembic_command(
engine=engine,
command="revision",
command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"},
Expand All @@ -188,8 +184,6 @@ def test_drop_revision(engine, schema_setup) -> None:
with migration_create_path.open() as migration_file:
migration_contents = migration_file.read()

# import pdb; pdb.set_trace()

assert "op.drop_entity" in migration_contents
assert "op.create_entity" in migration_contents
assert "from alembic_utils" in migration_contents
Expand Down
69 changes: 33 additions & 36 deletions src/test/test_pg_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,33 @@
from alembic_utils.pg_function import PGFunction
from alembic_utils.pg_trigger import PGTrigger
from alembic_utils.replaceable_entity import register_entities
from alembic_utils.statement import coerce_to_quoted
from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command


@pytest.fixture(scope="function")
def sql_setup(engine):
@pytest.fixture(scope="function", params=['account', 'Account'])
def sql_setup(request, engine):
with engine.begin() as connection:
connection.execute(
text(
"""
create table public.account (
id serial primary key,
email text not null
);
"""
)
text(f"""
create table public."{request.param}" (
id serial primary key,
email text not null
);
""")
)

yield
yield PGTrigger(
schema="public",
signature="lower_account_EMAIL",
on_entity=f"public.{request.param}",
definition=f"""
BEFORE INSERT ON public."{request.param}"
FOR EACH ROW EXECUTE PROCEDURE public.downcase_email()
""",
)
with engine.begin() as connection:
connection.execute(text("drop table public.account cascade"))
connection.execute(text(f'drop table public."{request.param}" cascade'))


FUNC = PGFunction.from_sql(
Expand All @@ -36,22 +43,12 @@ def sql_setup(engine):
"""
)

TRIG = PGTrigger(
schema="public",
signature="lower_account_EMAIL",
on_entity="public.account",
definition="""
BEFORE INSERT ON public.account
FOR EACH ROW EXECUTE PROCEDURE public.downcase_email()
""",
)


def test_create_revision(sql_setup, engine) -> None:
with engine.begin() as connection:
connection.execute(FUNC.to_sql_statement_create())

register_entities([FUNC, TRIG], entity_types=[PGTrigger])
register_entities([FUNC, sql_setup], entity_types=[PGTrigger])
run_alembic_command(
engine=engine,
command="revision",
Expand Down Expand Up @@ -79,14 +76,14 @@ def test_create_revision(sql_setup, engine) -> None:
def test_trig_update_revision(sql_setup, engine) -> None:
with engine.begin() as connection:
connection.execute(FUNC.to_sql_statement_create())
connection.execute(TRIG.to_sql_statement_create())
connection.execute(sql_setup.to_sql_statement_create())

UPDATED_TRIG = PGTrigger(
schema=TRIG.schema,
signature=TRIG.signature,
on_entity=TRIG.on_entity,
definition="""
AFTER INSERT ON public.account
schema=sql_setup.schema,
signature=sql_setup.signature,
on_entity=sql_setup.on_entity,
definition=f"""
AFTER INSERT ON {coerce_to_quoted(sql_setup.on_entity)}
FOR EACH ROW EXECUTE PROCEDURE public.downcase_email()
""",
)
Expand Down Expand Up @@ -121,11 +118,11 @@ def test_trig_update_revision(sql_setup, engine) -> None:
def test_noop_revision(sql_setup, engine) -> None:
with engine.begin() as connection:
connection.execute(FUNC.to_sql_statement_create())
connection.execute(TRIG.to_sql_statement_create())
connection.execute(sql_setup.to_sql_statement_create())

register_entities([FUNC, TRIG], entity_types=[PGTrigger])
register_entities([FUNC, sql_setup], entity_types=[PGTrigger])

output = run_alembic_command(
run_alembic_command(
engine=engine,
command="revision",
command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"},
Expand All @@ -150,7 +147,7 @@ def test_drop(sql_setup, engine) -> None:
# Manually create a SQL function
with engine.begin() as connection:
connection.execute(FUNC.to_sql_statement_create())
connection.execute(TRIG.to_sql_statement_create())
connection.execute(sql_setup.to_sql_statement_create())

# Register no functions locally
register_entities([], schemas=["public"], entity_types=[PGTrigger])
Expand Down Expand Up @@ -189,9 +186,9 @@ def test_on_entity_schema_not_qualified() -> None:

def test_fail_create_sql_statement_create():
trig = PGTrigger(
schema=TRIG.schema,
signature=TRIG.signature,
on_entity=TRIG.on_entity,
schema="public",
signature="lower_account_EMAIL",
on_entity="accountr",
definition="INVALID DEF",
)

Expand Down