diff --git a/integration/pgdog.toml b/integration/pgdog.toml index 041d6bb2e..b70a5a2c5 100644 --- a/integration/pgdog.toml +++ b/integration/pgdog.toml @@ -23,6 +23,8 @@ tls_private_key = "integration/tls/key.pem" query_parser_engine = "pg_query_raw" system_catalogs = "omnisharded_sticky" reload_schema_on_ddl = false +#idle_healthcheck_delay = 50000000 + [memory] net_buffer = 8096 @@ -403,6 +405,28 @@ shard = 1 database = "pgdog_schema" shard = 0 +[[databases]] +name = "pgdog_schema_no_cross" +host = "127.0.0.1" +database_name = "shard_0" +shard = 0 + +[[databases]] +name = "pgdog_schema_no_cross" +host = "127.0.0.1" +database_name = "shard_1" +shard = 1 + + +[[sharded_schemas]] +database = "pgdog_schema_no_cross" +name = "shard_0" +shard = 0 + +[[sharded_schemas]] +database = "pgdog_schema_no_cross" +name = "shard_1" +shard = 1 # ------------------------------------------------------------------------------ # ----- Admin ------------------------------------------------------------------ diff --git a/integration/python/test_session_mode.py b/integration/python/test_session_mode.py new file mode 100644 index 000000000..d0240f43e --- /dev/null +++ b/integration/python/test_session_mode.py @@ -0,0 +1,338 @@ +import asyncpg +import psycopg +import pytest +from globals import no_out_of_sync + + +def session_conn(schema): + return psycopg.connect( + user="pgdog_session_no_cross_shard", + password="pgdog", + dbname="pgdog_schema_no_cross", + host="127.0.0.1", + port=6432, + options=f"-c search_path={schema},public", + ) + + +async def async_session_conn(schema): + return await asyncpg.connect( + user="pgdog_session_no_cross_shard", + password="pgdog", + database="pgdog_schema_no_cross", + host="127.0.0.1", + port=6432, + server_settings={"search_path": f"{schema},public"}, + statement_cache_size=0, + ) + + +def test_session_simple_queries(): + for schema in ["shard_0", "shard_1"]: + conn = session_conn(schema) + cur = conn.cursor() + cur.execute("SELECT 1::bigint") + assert cur.fetchone()[0] == 1 + conn.commit() + conn.close() + no_out_of_sync() + + +def test_session_ddl_and_dml(): + for schema in ["shard_0", "shard_1"]: + conn = session_conn(schema) + conn.autocommit = True + cur = conn.cursor() + + cur.execute("DROP TABLE IF EXISTS test_session_mode") + cur.execute("CREATE TABLE test_session_mode(id BIGINT, value TEXT)") + + cur.execute( + "INSERT INTO test_session_mode (id, value) " + "VALUES (%s, %s) RETURNING *", + (1, "hello"), + ) + row = cur.fetchone() + assert row[0] == 1 + assert row[1] == "hello" + + cur.execute("SELECT * FROM test_session_mode WHERE id = %s", (1,)) + assert cur.fetchone()[0] == 1 + + cur.execute( + "UPDATE test_session_mode SET value = %s WHERE id = %s RETURNING *", + ("world", 1), + ) + row = cur.fetchone() + assert row[1] == "world" + + cur.execute("DELETE FROM test_session_mode WHERE id = %s", (1,)) + assert cur.rowcount == 1 + + cur.execute("DROP TABLE test_session_mode") + conn.close() + no_out_of_sync() + + +def test_session_transactions(): + for schema in ["shard_0", "shard_1"]: + conn = session_conn(schema) + cur = conn.cursor() + + cur.execute("DROP TABLE IF EXISTS test_session_tx") + conn.commit() + + cur.execute("CREATE TABLE test_session_tx(id BIGINT, value TEXT)") + conn.commit() + + cur.execute( + "INSERT INTO test_session_tx (id, value) VALUES (%s, %s)", (1, "a") + ) + cur.execute( + "INSERT INTO test_session_tx (id, value) VALUES (%s, %s)", (2, "b") + ) + conn.commit() + + cur.execute("SELECT count(*) FROM test_session_tx") + assert cur.fetchone()[0] == 2 + conn.commit() + + cur.execute( + "INSERT INTO test_session_tx (id, value) VALUES (%s, %s)", (3, "c") + ) + conn.rollback() + + cur.execute("SELECT count(*) FROM test_session_tx") + assert cur.fetchone()[0] == 2 + conn.commit() + + cur.execute("DROP TABLE test_session_tx") + conn.commit() + conn.close() + no_out_of_sync() + + +def test_session_transaction_set_local(): + for schema in ["shard_0", "shard_1"]: + conn = session_conn(schema) + cur = conn.cursor() + + cur.execute("SET LOCAL statement_timeout TO '10s'") + cur.execute("SELECT 1::bigint") + assert cur.fetchone()[0] == 1 + conn.commit() + conn.close() + no_out_of_sync() + + +def test_session_search_path_visible(): + for schema in ["shard_0", "shard_1"]: + conn = session_conn(schema) + cur = conn.cursor() + cur.execute("SHOW search_path") + search_path = cur.fetchone()[0] + assert schema in search_path + conn.commit() + conn.close() + no_out_of_sync() + + +def test_session_multiple_statements_in_transaction(): + for schema in ["shard_0", "shard_1"]: + conn = session_conn(schema) + cur = conn.cursor() + + cur.execute("DROP TABLE IF EXISTS test_session_multi") + conn.commit() + cur.execute("CREATE TABLE test_session_multi(id BIGINT)") + conn.commit() + + for i in range(10): + cur.execute("INSERT INTO test_session_multi (id) VALUES (%s)", (i,)) + conn.commit() + + cur.execute("SELECT count(*) FROM test_session_multi") + assert cur.fetchone()[0] == 10 + conn.commit() + + cur.execute("DROP TABLE test_session_multi") + conn.commit() + conn.close() + no_out_of_sync() + + +def no_search_path_conn(): + return psycopg.connect( + user="pgdog_session_no_cross_shard", + password="pgdog", + dbname="pgdog_schema_no_cross", + host="127.0.0.1", + port=6432, + ) + + +def _create_no_sp_test_table(): + for db in ["shard_0", "shard_1"]: + direct = psycopg.connect( + user="pgdog", password="pgdog", dbname=db, + host="127.0.0.1", port=5432, + ) + direct.autocommit = True + for schema in ["shard_0", "shard_1"]: + direct.cursor().execute(f"CREATE SCHEMA IF NOT EXISTS {schema}") + direct.cursor().execute( + f"CREATE TABLE IF NOT EXISTS {schema}.no_sp_test(id BIGINT, value TEXT)" + ) + direct.cursor().execute("CREATE TABLE IF NOT EXISTS no_sp_test(id BIGINT, value TEXT)") + direct.close() + + +def test_no_search_path_cross_shard_insert_blocked(): + _create_no_sp_test_table() + + conn = no_search_path_conn() + conn.autocommit = True + cur = conn.cursor() + with pytest.raises(psycopg.errors.SystemError, match="cross-shard queries are disabled"): + cur.execute("INSERT INTO no_sp_test (id, value) VALUES (1, 'test')") + conn.close() + no_out_of_sync() + + +def test_no_search_path_cross_shard_update_blocked(): + conn = no_search_path_conn() + conn.autocommit = True + cur = conn.cursor() + with pytest.raises(psycopg.errors.SystemError, match="cross-shard queries are disabled"): + cur.execute("UPDATE no_sp_test SET value = 'changed'") + conn.close() + no_out_of_sync() + + +@pytest.mark.asyncio +async def test_async_session_simple_queries(): + for schema in ["shard_0", "shard_1"]: + conn = await async_session_conn(schema) + row = await conn.fetchrow("SELECT 1::bigint AS v") + assert row["v"] == 1 + await conn.close() + no_out_of_sync() + + +@pytest.mark.asyncio +async def test_async_session_ddl_and_dml(): + for schema in ["shard_0", "shard_1"]: + conn = await async_session_conn(schema) + + await conn.execute("DROP TABLE IF EXISTS test_async_session_mode") + await conn.execute("CREATE TABLE test_async_session_mode(id BIGINT, value TEXT)") + + row = await conn.fetchrow( + "INSERT INTO test_async_session_mode (id, value) VALUES ($1, $2) RETURNING *", + 1, "hello", + ) + assert row["id"] == 1 + assert row["value"] == "hello" + + row = await conn.fetchrow( + "SELECT * FROM test_async_session_mode WHERE id = $1", 1 + ) + assert row["id"] == 1 + + row = await conn.fetchrow( + "UPDATE test_async_session_mode SET value = $1 WHERE id = $2 RETURNING *", + "world", 1, + ) + assert row["value"] == "world" + + result = await conn.execute( + "DELETE FROM test_async_session_mode WHERE id = $1", 1 + ) + assert result == "DELETE 1" + + await conn.execute("DROP TABLE test_async_session_mode") + await conn.close() + no_out_of_sync() + + +@pytest.mark.asyncio +async def test_async_session_transactions(): + for schema in ["shard_0", "shard_1"]: + conn = await async_session_conn(schema) + + await conn.execute("DROP TABLE IF EXISTS test_async_session_tx") + await conn.execute("CREATE TABLE test_async_session_tx(id BIGINT, value TEXT)") + + async with conn.transaction(): + await conn.execute( + "INSERT INTO test_async_session_tx (id, value) VALUES ($1, $2)", 1, "a" + ) + await conn.execute( + "INSERT INTO test_async_session_tx (id, value) VALUES ($1, $2)", 2, "b" + ) + + row = await conn.fetchrow("SELECT count(*)::bigint AS c FROM test_async_session_tx") + assert row["c"] == 2 + + try: + async with conn.transaction(): + await conn.execute( + "INSERT INTO test_async_session_tx (id, value) VALUES ($1, $2)", + 3, "c", + ) + raise Exception("force rollback") + except Exception: + pass + + row = await conn.fetchrow("SELECT count(*)::bigint AS c FROM test_async_session_tx") + assert row["c"] == 2 + + await conn.execute("DROP TABLE test_async_session_tx") + await conn.close() + no_out_of_sync() + + +@pytest.mark.asyncio +async def test_async_session_transaction_set_local(): + for schema in ["shard_0", "shard_1"]: + conn = await async_session_conn(schema) + async with conn.transaction(): + await conn.execute("SET LOCAL statement_timeout TO '10s'") + row = await conn.fetchrow("SELECT 1::bigint AS v") + assert row["v"] == 1 + await conn.close() + no_out_of_sync() + + +@pytest.mark.asyncio +async def test_async_session_search_path_visible(): + for schema in ["shard_0", "shard_1"]: + conn = await async_session_conn(schema) + row = await conn.fetchrow("SHOW search_path") + assert schema in row["search_path"] + await conn.close() + no_out_of_sync() + + +@pytest.mark.asyncio +async def test_async_session_multiple_statements_in_transaction(): + for schema in ["shard_0", "shard_1"]: + conn = await async_session_conn(schema) + + await conn.execute("DROP TABLE IF EXISTS test_async_session_multi") + await conn.execute("CREATE TABLE test_async_session_multi(id BIGINT)") + + async with conn.transaction(): + for i in range(10): + await conn.execute( + "INSERT INTO test_async_session_multi (id) VALUES ($1)", i + ) + + row = await conn.fetchrow( + "SELECT count(*)::bigint AS c FROM test_async_session_multi" + ) + assert row["c"] == 10 + + await conn.execute("DROP TABLE test_async_session_multi") + await conn.close() + no_out_of_sync() diff --git a/integration/users.toml b/integration/users.toml index 87273f7c6..ca2b987a7 100644 --- a/integration/users.toml +++ b/integration/users.toml @@ -10,6 +10,14 @@ password = "pgdog" server_user = "pgdog" pooler_mode = "session" +[[users]] +name = "pgdog_session" +database = "pgdog_sharded" +password = "pgdog" +server_user = "pgdog" +pooler_mode = "session" +cross_shard_disabled = true + [[users]] name = "pgdog" database = "pgdog_sharded" @@ -67,3 +75,11 @@ min_pool_size = 0 name = "pgdog" password = "pgdog" database = "pgdog_schema" + +[[users]] +name = "pgdog_session_no_cross_shard" +password = "pgdog" +database = "pgdog_schema_no_cross" +server_user = "pgdog" +cross_shard_disabled = true +pooler_mode = "session" diff --git a/pgdog/src/frontend/client/mod.rs b/pgdog/src/frontend/client/mod.rs index cd6fb5f9a..e2b09522f 100644 --- a/pgdog/src/frontend/client/mod.rs +++ b/pgdog/src/frontend/client/mod.rs @@ -394,18 +394,24 @@ impl Client { let shutdown = self.comms.shutting_down(); let mut offline; let mut query_engine = QueryEngine::from_client(self)?; + let mut terminating = false; loop { - offline = (self.comms.offline() && !self.admin || self.shutdown) && query_engine.done(); + offline = (self.comms.offline() && !self.admin || self.shutdown) + && query_engine.can_disconnect(); if offline { break; } + if terminating && query_engine.can_disconnect() { + break; + } + let client_state = query_engine.client_state(); select! { _ = shutdown.notified() => { - if query_engine.done() { + if query_engine.can_disconnect() { continue; // Wake up task. } } @@ -416,7 +422,7 @@ impl Client { self.server_message(&mut query_engine, message).await?; } - buffer = self.buffer(client_state) => { + buffer = self.buffer(client_state), if !terminating => { let event = buffer?; // Only send requests to the backend if they are complete. @@ -429,13 +435,11 @@ impl Client { match event { BufferEvent::DisconnectAbrupt => break, BufferEvent::DisconnectGraceful => { - let done = query_engine.done(); - - if done { + if query_engine.can_disconnect() { break; } + terminating = true; } - BufferEvent::HaveRequest => (), } } diff --git a/pgdog/src/frontend/client/query_engine/mod.rs b/pgdog/src/frontend/client/query_engine/mod.rs index b383b69d5..37ae05285 100644 --- a/pgdog/src/frontend/client/query_engine/mod.rs +++ b/pgdog/src/frontend/client/query_engine/mod.rs @@ -122,9 +122,9 @@ impl QueryEngine { Ok(self.backend.read().await?) } - /// Query engine finished executing. - pub fn done(&self) -> bool { - !self.backend.connected() && self.begin_stmt.is_none() + /// Client can safely disconnect (no active backend connection or pending transaction). + pub fn can_disconnect(&self) -> bool { + self.begin_stmt.is_none() && self.backend.done() } /// Current state. @@ -195,6 +195,7 @@ impl QueryEngine { query, transaction_type, extended, + .. } => { self.start_transaction(context, query.clone(), *transaction_type, *extended) .await? diff --git a/pgdog/src/frontend/router/parser/command.rs b/pgdog/src/frontend/router/parser/command.rs index 4e0b26dc9..b7ba5c2e1 100644 --- a/pgdog/src/frontend/router/parser/command.rs +++ b/pgdog/src/frontend/router/parser/command.rs @@ -21,6 +21,7 @@ pub enum Command { query: BufferedQuery, transaction_type: TransactionType, extended: bool, + route: Route, }, CommitTransaction { extended: bool, @@ -66,6 +67,7 @@ impl Command { match self { Self::Query(route) => route, Self::Set { route, .. } => route, + Self::StartTransaction { route, .. } => route, _ => &DEFAULT_ROUTE, } } diff --git a/pgdog/src/frontend/router/parser/query/transaction.rs b/pgdog/src/frontend/router/parser/query/transaction.rs index ef9e6379a..3639174b9 100644 --- a/pgdog/src/frontend/router/parser/query/transaction.rs +++ b/pgdog/src/frontend/router/parser/query/transaction.rs @@ -35,6 +35,8 @@ impl QueryParser { query: context.query()?.clone(), transaction_type, extended, + route: Route::write(context.shards_calculator.shard()) + .with_read(transaction_type == TransactionType::ReadOnly), }); } TransactionStmtKind::TransStmtRollbackTo => rollback_savepoint = true,