4949PRIVILEGES : Final = "PRIVILEGES"
5050PGLOGICAL_NODE_ALREADY_EXISTS : Final = "PGLOGICAL_NODE_ALREADY_EXISTS"
5151TABLES_WITH_NO_PK : Final = "TABLES_WITH_NO_PK"
52+ UNSUPPORTED_TABLES_WITH_REPLICA_IDENTITY : Final = "UNSUPPORTED_TABLES_WITH_REPLICA_IDENTITY"
53+ REPLICATION_ROLE : Final = "REPLICATION_ROLE"
5254CLOUDSQL : Final = "CLOUDSQL"
5355ALLOYDB : Final = "ALLOYDB"
5456CloudSQL_SUPER_ROLE : Final = "cloudsqladmin"
@@ -126,21 +128,23 @@ def execute(self) -> None:
126128 self ._check_max_replication_slots ()
127129 self ._check_max_wal_senders_replication_slots ()
128130 self ._check_max_worker_processes ()
129- self ._check_extensions ()
130131 self ._check_fdw ()
131132
132133 # Per DB Checks.
134+ self ._check_extensions ()
135+ self ._check_replication_role ()
133136 # db_check_results stores the verification results for all DBs.
134137 for config in self .rule_config :
135138 db_check_results : dict [str , dict [str , list ]] = {}
136- for db in self .get_all_dbs ():
139+ for db in sorted ( self .get_all_dbs () ):
137140 is_pglogical_installed = self ._check_pglogical_installed (db , db_check_results )
138141 if is_pglogical_installed :
139142 privilege_check_passed = self ._check_privileges (db , db_check_results )
140143 if not privilege_check_passed :
141- break
144+ continue
142145 self ._check_if_node_exists (db , db_check_results )
143146 self ._check_tables_without_pk (db , db_check_results )
147+ self ._check_tables_replica_identity (db , db_check_results )
144148 self ._save_results (config .db_variant , db_check_results )
145149
146150 def _save_results (self , db_variant : PostgresVariants , db_check_results : dict [str , dict [str , list ]]) -> None :
@@ -157,9 +161,9 @@ def _save_results(self, db_variant: PostgresVariants, db_check_results: dict[str
157161
158162 if rule == PRIVILEGES :
159163 if severity == ACTION_REQUIRED :
160- output_str = ";" .join (result [ACTION_REQUIRED ])
164+ output_str = ";\n \n " .join (result [ACTION_REQUIRED ])
161165 elif severity == PASS :
162- output_str = ";" .join (result [PASS ])
166+ output_str = ";\n " .join (result [PASS ])
163167 else :
164168 continue
165169
@@ -179,6 +183,12 @@ def _save_results(self, db_variant: PostgresVariants, db_check_results: dict[str
179183 else :
180184 continue
181185
186+ if rule == UNSUPPORTED_TABLES_WITH_REPLICA_IDENTITY :
187+ if severity == ACTION_REQUIRED :
188+ output_str = f"Source has table(s) with both primary key and replica identity FULL or NOTHING. Please remove replica identity or change it to DEFAULT to migrate: { ';' .join (result [ACTION_REQUIRED ])} "
189+ else :
190+ continue
191+
182192 if len (result [severity ]) > 0 :
183193 self .save_rule_result (
184194 db_variant ,
@@ -225,6 +235,17 @@ def _check_tables_without_pk(self, db_name: str, db_check_results: dict[str, dic
225235 if tables :
226236 db_check_results [rule_code ][WARNING ].append (f"In database { db_name } , { tables } don't have primary keys" )
227237
238+ def _check_tables_replica_identity (self , db_name : str , db_check_results : dict [str , dict [str , list ]]) -> None :
239+ rule_code = UNSUPPORTED_TABLES_WITH_REPLICA_IDENTITY
240+ result = self .local_db .sql (
241+ "select CONCAT(nspname, '.', relname) from collection_postgres_tables_with_primary_key_replica_identity where database_name = $db_name" ,
242+ params = {"db_name" : db_name },
243+ ).fetchall ()
244+ tables = ", " .join (row [0 ] for row in result )
245+ init_results_dict (db_check_results , rule_code )
246+ if tables :
247+ db_check_results [rule_code ][ACTION_REQUIRED ].append (f"{ tables } in database { db_name } " )
248+
228249 def _check_version (self ) -> None :
229250 rule_code = "DATABASE_VERSION"
230251 self .console .print (f"version: { self .db_version } " )
@@ -357,6 +378,29 @@ def check_user_obj_privileges(self, db_name: str) -> list[str]:
357378 errors .extend (f"user doesn't have SELECT privilege on sequence { row [0 ]} .{ row [1 ]} " for row in rows )
358379 return errors
359380
381+ def _check_replication_role (self ) -> None :
382+ if self ._is_rds ():
383+ return
384+ rule_code = REPLICATION_ROLE
385+ result = self .local_db .sql ("SELECT rolreplication FROM collection_postgres_replication_role" ).fetchone ()
386+ if result is None :
387+ return
388+ for c in self .rule_config :
389+ if result [0 ] == "false" :
390+ self .save_rule_result (
391+ c .db_variant ,
392+ rule_code ,
393+ ACTION_REQUIRED ,
394+ "user does not have rolreplication role." ,
395+ )
396+ else :
397+ self .save_rule_result (
398+ c .db_variant ,
399+ rule_code ,
400+ PASS ,
401+ "user has rolreplication role." ,
402+ )
403+
360404 def _check_privileges (self , db_name : str , db_check_results : dict [str , dict [str , list ]]) -> bool :
361405 rule_code = PRIVILEGES
362406 errors = self ._check_pglogical_privileges (db_name )
0 commit comments