From 9c2398475f9236c3cd1a08f4a170036b51eda6c9 Mon Sep 17 00:00:00 2001 From: Badal Prasad Singh Date: Fri, 17 Oct 2025 14:20:08 +0530 Subject: [PATCH 1/4] fix: Azure ADLS Lakekeeper Rest Catalog Issue (#591) Signed-off-by: badalprasadsingh --- destination/iceberg/java_client.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/destination/iceberg/java_client.go b/destination/iceberg/java_client.go index bc8d48d2..72f86402 100644 --- a/destination/iceberg/java_client.go +++ b/destination/iceberg/java_client.go @@ -114,16 +114,11 @@ func getServerConfigJSON(config *Config, partitionInfo []PartitionInfo, port int logger.Warnf("No region explicitly provided for Glue catalog, the Java process will attempt to use region from AWS environment") } - // Configure custom endpoint for S3-compatible services (like MinIO) if config.S3Endpoint != "" { serverConfig["s3.endpoint"] = config.S3Endpoint - serverConfig["io-impl"] = "org.apache.iceberg.io.ResolvingFileIO" - // Set SSL/TLS configuration - serverConfig["s3.ssl-enabled"] = utils.Ternary(config.S3UseSSL, "true", "false").(string) } - - // Configure S3 or GCP file IO - serverConfig["io-impl"] = utils.Ternary(strings.HasPrefix(config.IcebergS3Path, "gs://"), "org.apache.iceberg.gcp.gcs.GCSFileIO", "org.apache.iceberg.aws.s3.S3FileIO") + serverConfig["io-impl"] = "org.apache.iceberg.io.ResolvingFileIO" + serverConfig["s3.ssl-enabled"] = utils.Ternary(config.S3UseSSL, "true", "false").(string) // Marshal the config to JSON return json.Marshal(serverConfig) From 0473e28c0654be232faf8c17635bff618e4b6299 Mon Sep 17 00:00:00 2001 From: Vaibhav Date: Mon, 20 Oct 2025 11:43:12 +0530 Subject: [PATCH 2/4] fix: Integration test spark error fixed (#594) --- utils/testutils/test_utils.go | 45 +++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/utils/testutils/test_utils.go b/utils/testutils/test_utils.go index 4913b989..216263f4 100644 --- a/utils/testutils/test_utils.go +++ b/utils/testutils/test_utils.go @@ -10,6 +10,7 @@ import ( "time" "github.com/apache/spark-connect-go/v35/spark/sql" + "github.com/apache/spark-connect-go/v35/spark/sql/types" "github.com/datazip-inc/olake/constants" "github.com/datazip-inc/olake/utils" "github.com/datazip-inc/olake/utils/typeutils" @@ -375,17 +376,47 @@ func VerifyIcebergSync(t *testing.T, tableName, icebergDB string, datatypeSchema } }() + fullTableName := fmt.Sprintf("%s.%s.%s", icebergCatalog, icebergDB, tableName) selectQuery := fmt.Sprintf( - "SELECT * FROM %s.%s.%s WHERE _op_type = '%s'", - icebergCatalog, icebergDB, tableName, opSymbol, + "SELECT * FROM %s WHERE _op_type = '%s'", + fullTableName, opSymbol, ) t.Logf("Executing query: %s", selectQuery) - selectQueryDf, err := spark.Sql(ctx, selectQuery) - require.NoError(t, err, "Failed to select query from the table") + var selectRows []types.Row + var queryErr error + maxRetries := 5 + retryDelay := 2 * time.Second - selectRows, err := selectQueryDf.Collect(ctx) - require.NoError(t, err, "Failed to collect data rows from Iceberg") + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + time.Sleep(retryDelay) + } + var selectQueryDf sql.DataFrame + // This is to check if the table exists in destination, as race condition might cause table to not be created yet + selectQueryDf, queryErr = spark.Sql(ctx, selectQuery) + if queryErr != nil { + t.Logf("Query attempt %d failed: %v", attempt+1, queryErr) + continue + } + + // To ensure stale data is not being used for verification + 
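For reference, the net effect of the java_client.go hunk is that io-impl is now always "org.apache.iceberg.io.ResolvingFileIO", which resolves the concrete FileIO (S3FileIO, GCSFileIO, or ADLSFileIO) from the warehouse URI scheme at runtime instead of hard-coding the S3/GCS choice; that is what unblocks Azure ADLS behind the Lakekeeper REST catalog. A minimal sketch of the resulting configuration shape, with the hypothetical buildFileIOConfig standing in for the FileIO-related part of getServerConfigJSON:

    package example

    // buildFileIOConfig is an illustrative stand-in, not the actual helper:
    // after the patch, io-impl is set unconditionally and s3.ssl-enabled is
    // moved out of the endpoint branch, while s3.endpoint stays conditional.
    func buildFileIOConfig(s3Endpoint string, useSSL bool) map[string]string {
        cfg := map[string]string{
            // ResolvingFileIO defers the FileIO choice to the path scheme
            // (s3://, gs://, abfss://), covering ADLS without special cases.
            "io-impl": "org.apache.iceberg.io.ResolvingFileIO",
        }
        if s3Endpoint != "" {
            cfg["s3.endpoint"] = s3Endpoint // custom endpoint, e.g. MinIO
        }
        if useSSL {
            cfg["s3.ssl-enabled"] = "true"
        } else {
            cfg["s3.ssl-enabled"] = "false"
        }
        return cfg
    }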
selectRows, queryErr = selectQueryDf.Collect(ctx) + if queryErr != nil { + t.Logf("Query attempt %d failed (Collect error): %v", attempt+1, queryErr) + continue + } + if len(selectRows) > 0 { + queryErr = nil + break + } + + // for every type of operation, op symbol will be different, using that to ensure data is not stale + queryErr = fmt.Errorf("stale data: query succeeded but returned 0 rows for _op_type = '%s'", opSymbol) + t.Logf("Query attempt %d/%d failed: %v", attempt+1, maxRetries, queryErr) + } + + require.NoError(t, queryErr, "Failed to collect data rows from Iceberg after %d attempts: %v", maxRetries, queryErr) require.NotEmpty(t, selectRows, "No rows returned for _op_type = '%s'", opSymbol) // delete row checked @@ -408,7 +439,7 @@ func VerifyIcebergSync(t *testing.T, tableName, icebergDB string, datatypeSchema } t.Logf("Verified Iceberg synced data with respect to data synced from source[%s] found equal", driver) - describeQuery := fmt.Sprintf("DESCRIBE TABLE %s.%s.%s", icebergCatalog, icebergDB, tableName) + describeQuery := fmt.Sprintf("DESCRIBE TABLE %s", fullTableName) describeDf, err := spark.Sql(ctx, describeQuery) require.NoError(t, err, "Failed to describe Iceberg table") From 3831d532797c4a277fb5d3aff7748ea2b33bd45c Mon Sep 17 00:00:00 2001 From: ashh1215 <124177595+ashh1215@users.noreply.github.com> Date: Mon, 20 Oct 2025 17:36:16 +0530 Subject: [PATCH 3/4] Merge pull request #566 from ashh1215/AddContext fix: add context while querying (sql/sqlx) (#319) --- drivers/mysql/internal/backfill.go | 13 ++++++------ drivers/mysql/internal/cdc.go | 2 +- drivers/mysql/internal/incremental.go | 2 +- drivers/mysql/internal/mysql.go | 2 +- drivers/oracle/internal/backfill.go | 6 +++--- drivers/oracle/internal/incremental.go | 2 +- drivers/postgres/internal/backfill.go | 26 +++++++++++------------- drivers/postgres/internal/cdc.go | 11 +++++----- drivers/postgres/internal/incremental.go | 2 +- drivers/postgres/internal/postgres.go | 8 ++++---- pkg/binlog/binlog.go | 8 ++++---- pkg/jdbc/jdbc.go | 16 +++++++-------- pkg/waljs/pgoutput.go | 2 +- pkg/waljs/replicator.go | 2 +- pkg/waljs/waljs.go | 2 +- 15 files changed, 51 insertions(+), 53 deletions(-) diff --git a/drivers/mysql/internal/backfill.go b/drivers/mysql/internal/backfill.go index fcd53d55..363f0b79 100644 --- a/drivers/mysql/internal/backfill.go +++ b/drivers/mysql/internal/backfill.go @@ -23,9 +23,8 @@ func (m *MySQL) ChunkIterator(ctx context.Context, stream types.StreamInterface, Driver: constants.MySQL, Stream: stream, State: m.state, - Client: m.client, } - thresholdFilter, args, err := jdbc.ThresholdFilter(opts) + thresholdFilter, args, err := jdbc.ThresholdFilter(ctx, opts) if err != nil { return fmt.Errorf("failed to set threshold filter: %s", err) } @@ -71,7 +70,7 @@ func (m *MySQL) GetOrSplitChunks(ctx context.Context, pool *destination.WriterPo var approxRowCount int64 var avgRowSize any approxRowCountQuery := jdbc.MySQLTableRowStatsQuery() - err := m.client.QueryRow(approxRowCountQuery, stream.Name()).Scan(&approxRowCount, &avgRowSize) + err := m.client.QueryRowContext(ctx, approxRowCountQuery, stream.Name()).Scan(&approxRowCount, &avgRowSize) if err != nil { return nil, fmt.Errorf("failed to get approx row count and avg row size: %s", err) } @@ -112,7 +111,7 @@ func (m *MySQL) GetOrSplitChunks(ctx context.Context, pool *destination.WriterPo } sort.Strings(pkColumns) // Get table extremes - minVal, maxVal, err := m.getTableExtremes(stream, pkColumns, tx) + minVal, maxVal, err := m.getTableExtremes(ctx, 
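Stepping back to the test_utils.go retry loop above: it is a small polling pattern in which transient Spark SQL failures (the table may not exist yet) and successful-but-empty results (stale reads) are both retried, with only the last error surfacing. An illustrative distillation, with the hypothetical pollRows in place of the inline loop:

    package example

    import (
        "context"
        "fmt"
        "time"
    )

    // pollRows is not part of the test code; it condenses the retry shape of
    // VerifyIcebergSync. run reports how many matching rows a query returned.
    func pollRows(ctx context.Context, maxRetries int, delay time.Duration,
        run func(context.Context) (int, error)) error {
        var lastErr error
        for attempt := 0; attempt < maxRetries; attempt++ {
            if attempt > 0 {
                time.Sleep(delay) // back off before re-querying
            }
            n, err := run(ctx)
            if err != nil {
                lastErr = err // e.g. table not created yet (race with sync)
                continue
            }
            if n > 0 {
                return nil // fresh data observed
            }
            lastErr = fmt.Errorf("stale data: query returned 0 rows")
        }
        return lastErr
    }

In the test, the equivalent of run wraps spark.Sql plus Collect, and the final require.NoError reports the last queryErr after maxRetries attempts.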
stream, pkColumns, tx) if err != nil { return fmt.Errorf("failed to get table extremes: %s", err) } @@ -142,7 +141,7 @@ func (m *MySQL) GetOrSplitChunks(ctx context.Context, pool *destination.WriterPo } } var nextValRaw interface{} - err := tx.QueryRow(query, args...).Scan(&nextValRaw) + err := tx.QueryRowContext(ctx, query, args...).Scan(&nextValRaw) if err == sql.ErrNoRows || nextValRaw == nil { break } else if err != nil { @@ -196,8 +195,8 @@ func (m *MySQL) GetOrSplitChunks(ctx context.Context, pool *destination.WriterPo return chunks, err } -func (m *MySQL) getTableExtremes(stream types.StreamInterface, pkColumns []string, tx *sql.Tx) (min, max any, err error) { +func (m *MySQL) getTableExtremes(ctx context.Context, stream types.StreamInterface, pkColumns []string, tx *sql.Tx) (min, max any, err error) { query := jdbc.MinMaxQueryMySQL(stream, pkColumns) - err = tx.QueryRow(query).Scan(&min, &max) + err = tx.QueryRowContext(ctx, query).Scan(&min, &max) return min, max, err } diff --git a/drivers/mysql/internal/cdc.go b/drivers/mysql/internal/cdc.go index 7dda8f04..65feab42 100644 --- a/drivers/mysql/internal/cdc.go +++ b/drivers/mysql/internal/cdc.go @@ -42,7 +42,7 @@ func (m *MySQL) PreCDC(ctx context.Context, streams []types.StreamInterface) err // Load or initialize global state globalState := m.state.GetGlobal() if globalState == nil || globalState.State == nil { - binlogPos, err := binlog.GetCurrentBinlogPosition(m.client) + binlogPos, err := binlog.GetCurrentBinlogPosition(ctx, m.client) if err != nil { return fmt.Errorf("failed to get current binlog position: %s", err) } diff --git a/drivers/mysql/internal/incremental.go b/drivers/mysql/internal/incremental.go index b0b20040..b553f95c 100644 --- a/drivers/mysql/internal/incremental.go +++ b/drivers/mysql/internal/incremental.go @@ -17,7 +17,7 @@ func (m *MySQL) StreamIncrementalChanges(ctx context.Context, stream types.Strea Stream: stream, State: m.state, } - incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(opts) + incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(ctx, opts) if err != nil { return fmt.Errorf("failed to build incremental condition: %s", err) } diff --git a/drivers/mysql/internal/mysql.go b/drivers/mysql/internal/mysql.go index e9c4fc07..e2467dae 100644 --- a/drivers/mysql/internal/mysql.go +++ b/drivers/mysql/internal/mysql.go @@ -242,7 +242,7 @@ func (m *MySQL) Close() error { func (m *MySQL) IsCDCSupported(ctx context.Context) (bool, error) { // Permission check via SHOW MASTER STATUS / SHOW BINARY LOG STATUS - if _, err := binlog.GetCurrentBinlogPosition(m.client); err != nil { + if _, err := binlog.GetCurrentBinlogPosition(ctx, m.client); err != nil { return false, fmt.Errorf("failed to get binlog position: %s", err) } diff --git a/drivers/oracle/internal/backfill.go b/drivers/oracle/internal/backfill.go index 05de4b85..61b465a7 100644 --- a/drivers/oracle/internal/backfill.go +++ b/drivers/oracle/internal/backfill.go @@ -23,7 +23,7 @@ func (o *Oracle) ChunkIterator(ctx context.Context, stream types.StreamInterface State: o.state, Client: o.client, } - thresholdFilter, args, err := jdbc.ThresholdFilter(opts) + thresholdFilter, args, err := jdbc.ThresholdFilter(ctx, opts) if err != nil { return fmt.Errorf("failed to set threshold filter: %s", err) } @@ -61,7 +61,7 @@ func (o *Oracle) GetOrSplitChunks(ctx context.Context, pool *destination.WriterP splitViaRowId := func(stream types.StreamInterface) (*types.Set[types.Chunk], error) { // TODO: Add implementation of AddRecordsToSync 
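The pattern of patch 3 is mechanical but worth stating once: every database/sql and sqlx call gains a context — QueryRow becomes QueryRowContext, Query becomes QueryContext, Get/Select become GetContext/SelectContext — so cancelling the sync's context now cancels in-flight statements. A minimal before/after, with minMax as an illustrative name for the getTableExtremes shape:

    package example

    import (
        "context"
        "database/sql"
    )

    func minMax(ctx context.Context, tx *sql.Tx, query string) (min, max any, err error) {
        // Before: err = tx.QueryRow(query).Scan(&min, &max) — the statement
        // runs to completion even after ctx is cancelled.
        // After: the driver can abort the statement on cancellation/timeout.
        err = tx.QueryRowContext(ctx, query).Scan(&min, &max)
        return min, max, err
    }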
function which expects total number of records to be synced query := jdbc.OracleEmptyCheckQuery(stream) - err := o.client.QueryRow(query).Scan(new(interface{})) + err := o.client.QueryRowContext(ctx, query).Scan(new(interface{})) if err != nil { if err == sql.ErrNoRows { logger.Warnf("Table %s.%s is empty skipping chunking", stream.Namespace(), stream.Name()) @@ -72,7 +72,7 @@ func (o *Oracle) GetOrSplitChunks(ctx context.Context, pool *destination.WriterP query = jdbc.OracleBlockSizeQuery() var blockSize int64 - err = o.client.QueryRow(query).Scan(&blockSize) + err = o.client.QueryRowContext(ctx, query).Scan(&blockSize) if err != nil || blockSize == 0 { logger.Warnf("failed to get block size from query, switching to default block size value 8192") blockSize = 8192 diff --git a/drivers/oracle/internal/incremental.go b/drivers/oracle/internal/incremental.go index 1ea6f183..da38c726 100644 --- a/drivers/oracle/internal/incremental.go +++ b/drivers/oracle/internal/incremental.go @@ -18,7 +18,7 @@ func (o *Oracle) StreamIncrementalChanges(ctx context.Context, stream types.Stre State: o.state, Client: o.client, } - incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(opts) + incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(ctx, opts) if err != nil { return fmt.Errorf("failed to build incremental condition: %s", err) } diff --git a/drivers/postgres/internal/backfill.go b/drivers/postgres/internal/backfill.go index a6ae8b1b..70a92972 100644 --- a/drivers/postgres/internal/backfill.go +++ b/drivers/postgres/internal/backfill.go @@ -21,9 +21,8 @@ func (p *Postgres) ChunkIterator(ctx context.Context, stream types.StreamInterfa Driver: constants.Postgres, Stream: stream, State: p.state, - Client: p.client, } - thresholdFilter, args, err := jdbc.ThresholdFilter(opts) + thresholdFilter, args, err := jdbc.ThresholdFilter(ctx, opts) if err != nil { return fmt.Errorf("failed to set threshold filter: %s", err) } @@ -44,7 +43,7 @@ func (p *Postgres) ChunkIterator(ctx context.Context, stream types.StreamInterfa chunkColumn = utils.Ternary(chunkColumn == "", "ctid", chunkColumn).(string) stmt := jdbc.PostgresChunkScanQuery(stream, chunkColumn, chunk, filter) setter := jdbc.NewReader(ctx, stmt, func(ctx context.Context, query string, queryArgs ...any) (*sql.Rows, error) { - return tx.Query(query, args...) + return tx.QueryContext(ctx, query, args...) 
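    // Note on the closure above: patch 3 only swaps Query for QueryContext
    // here. As before the patch, the closure forwards the outer args slice
    // rather than its own queryArgs parameter, so the threshold-filter
    // arguments are the ones bound into every chunk scan.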
}) return setter.Capture(func(rows *sql.Rows) error { @@ -61,28 +60,27 @@ func (p *Postgres) ChunkIterator(ctx context.Context, stream types.StreamInterfa }) } -func (p *Postgres) GetOrSplitChunks(_ context.Context, pool *destination.WriterPool, stream types.StreamInterface) (*types.Set[types.Chunk], error) { +func (p *Postgres) GetOrSplitChunks(ctx context.Context, pool *destination.WriterPool, stream types.StreamInterface) (*types.Set[types.Chunk], error) { var approxRowCount int64 approxRowCountQuery := jdbc.PostgresRowCountQuery(stream) - // TODO: use ctx while querying - err := p.client.QueryRow(approxRowCountQuery).Scan(&approxRowCount) + err := p.client.QueryRowContext(ctx, approxRowCountQuery).Scan(&approxRowCount) if err != nil { return nil, fmt.Errorf("failed to get approx row count: %s", err) } pool.AddRecordsToSyncStats(approxRowCount) - return p.splitTableIntoChunks(stream) + return p.splitTableIntoChunks(ctx, stream) } -func (p *Postgres) splitTableIntoChunks(stream types.StreamInterface) (*types.Set[types.Chunk], error) { +func (p *Postgres) splitTableIntoChunks(ctx context.Context, stream types.StreamInterface) (*types.Set[types.Chunk], error) { generateCTIDRanges := func(stream types.StreamInterface) (*types.Set[types.Chunk], error) { var relPages, blockSize uint32 relPagesQuery := jdbc.PostgresRelPageCount(stream) - err := p.client.QueryRow(relPagesQuery).Scan(&relPages) + err := p.client.QueryRowContext(ctx, relPagesQuery).Scan(&relPages) if err != nil { return nil, fmt.Errorf("failed to get relPages: %s", err) } blockSizeQuery := jdbc.PostgresBlockSizeQuery() - err = p.client.QueryRow(blockSizeQuery).Scan(&blockSize) + err = p.client.QueryRowContext(ctx, blockSizeQuery).Scan(&blockSize) if err != nil { return nil, fmt.Errorf("failed to get block size: %s", err) } @@ -125,7 +123,7 @@ func (p *Postgres) splitTableIntoChunks(stream types.StreamInterface) (*types.Se chunkStart := min splits := types.NewSet[types.Chunk]() for { - chunkEnd, err := p.nextChunkEnd(stream, chunkStart, chunkColumn) + chunkEnd, err := p.nextChunkEnd(ctx, stream, chunkStart, chunkColumn) if err != nil { return nil, fmt.Errorf("failed to split chunks based on next query size: %s", err) } @@ -145,7 +143,7 @@ func (p *Postgres) splitTableIntoChunks(stream types.StreamInterface) (*types.Se var minValue, maxValue interface{} minMaxRowCountQuery := jdbc.MinMaxQuery(stream, chunkColumn) // TODO: Fails on UUID type (Good First Issue) - err := p.client.QueryRow(minMaxRowCountQuery).Scan(&minValue, &maxValue) + err := p.client.QueryRowContext(ctx, minMaxRowCountQuery).Scan(&minValue, &maxValue) if err != nil { return nil, fmt.Errorf("failed to fetch table min max: %s", err) } @@ -171,10 +169,10 @@ func (p *Postgres) splitTableIntoChunks(stream types.StreamInterface) (*types.Se } } -func (p *Postgres) nextChunkEnd(stream types.StreamInterface, previousChunkEnd interface{}, chunkColumn string) (interface{}, error) { +func (p *Postgres) nextChunkEnd(ctx context.Context, stream types.StreamInterface, previousChunkEnd interface{}, chunkColumn string) (interface{}, error) { var chunkEnd interface{} nextChunkEnd := jdbc.PostgresNextChunkEndQuery(stream, chunkColumn, previousChunkEnd) - err := p.client.QueryRow(nextChunkEnd).Scan(&chunkEnd) + err := p.client.QueryRowContext(ctx, nextChunkEnd).Scan(&chunkEnd) if err != nil { return nil, fmt.Errorf("failed to query[%s] next chunk end: %s", nextChunkEnd, err) } diff --git a/drivers/postgres/internal/cdc.go b/drivers/postgres/internal/cdc.go index c4c7c882..06d3be05 
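For background on the ctx-threading in splitTableIntoChunks above: the CTID path sizes chunks from pg_class.relpages and the server block size, both now fetched with QueryRowContext. The range arithmetic itself is elided from this hunk, so the following is only a rough sketch under assumed page math (ctidRanges and pagesPerChunk are illustrative, not the driver's actual code):

    package example

    import "fmt"

    // ctidRanges splits a heap of relPages pages into CTID windows of
    // pagesPerChunk pages each; '(p,0)' addresses the first tuple of page p.
    func ctidRanges(relPages, pagesPerChunk uint32) [][2]string {
        var ranges [][2]string
        for start := uint32(0); start <= relPages; start += pagesPerChunk {
            end := start + pagesPerChunk
            ranges = append(ranges, [2]string{
                fmt.Sprintf("'(%d,0)'", start),
                fmt.Sprintf("'(%d,0)'", end),
            })
        }
        return ranges
    }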
100644 --- a/drivers/postgres/internal/cdc.go +++ b/drivers/postgres/internal/cdc.go @@ -100,9 +100,10 @@ func (p *Postgres) PostCDC(ctx context.Context, _ types.StreamInterface, noErr b return nil } -func doesReplicationSlotExists(conn *sqlx.DB, slotName string, publication string) (bool, error) { +func doesReplicationSlotExists(ctx context.Context, conn *sqlx.DB, slotName string, publication string) (bool, error) { var exists bool - err := conn.QueryRow( + err := conn.QueryRowContext( + ctx, "SELECT EXISTS(Select 1 from pg_replication_slots where slot_name = $1)", slotName, ).Scan(&exists) @@ -110,12 +111,12 @@ func doesReplicationSlotExists(conn *sqlx.DB, slotName string, publication strin return false, err } - return exists, validateReplicationSlot(conn, slotName, publication) + return exists, validateReplicationSlot(ctx, conn, slotName, publication) } -func validateReplicationSlot(conn *sqlx.DB, slotName string, publication string) error { +func validateReplicationSlot(ctx context.Context, conn *sqlx.DB, slotName string, publication string) error { slot := waljs.ReplicationSlot{} - err := conn.Get(&slot, fmt.Sprintf(waljs.ReplicationSlotTempl, slotName)) + err := conn.GetContext(ctx, &slot, fmt.Sprintf(waljs.ReplicationSlotTempl, slotName)) if err != nil { return err } diff --git a/drivers/postgres/internal/incremental.go b/drivers/postgres/internal/incremental.go index c3b4f611..d9ee26a5 100644 --- a/drivers/postgres/internal/incremental.go +++ b/drivers/postgres/internal/incremental.go @@ -16,7 +16,7 @@ func (p *Postgres) StreamIncrementalChanges(ctx context.Context, stream types.St Stream: stream, State: p.state, } - incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(opts) + incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(ctx, opts) if err != nil { return fmt.Errorf("failed to build incremental condition: %s", err) } diff --git a/drivers/postgres/internal/postgres.go b/drivers/postgres/internal/postgres.go index 18dc6ae2..2e3d6582 100644 --- a/drivers/postgres/internal/postgres.go +++ b/drivers/postgres/internal/postgres.go @@ -115,7 +115,7 @@ func (p *Postgres) Setup(ctx context.Context) error { logger.Infof("CDC initial wait time set to: %d", cdc.InitialWaitTime) - exists, err := doesReplicationSlotExists(pgClient, cdc.ReplicationSlot, cdc.Publication) + exists, err := doesReplicationSlotExists(ctx, pgClient, cdc.ReplicationSlot, cdc.Publication) if err != nil { if strings.Contains(err.Error(), "sql: no rows in result set") { err = fmt.Errorf("no record found") @@ -174,7 +174,7 @@ func (p *Postgres) CloseConnection() { func (p *Postgres) GetStreamNames(ctx context.Context) ([]string, error) { logger.Infof("Starting discover for Postgres database %s", p.config.Database) var tableNamesOutput []Table - err := p.client.Select(&tableNamesOutput, getPrivilegedTablesTmpl) + err := p.client.SelectContext(ctx, &tableNamesOutput, getPrivilegedTablesTmpl) if err != nil { return nil, fmt.Errorf("failed to retrieve table names: %s", err) } @@ -191,7 +191,7 @@ func (p *Postgres) ProduceSchema(ctx context.Context, streamName string) (*types schemaName, streamName := streamParts[0], streamParts[1] stream := types.NewStream(streamName, schemaName, &p.config.Database) var columnSchemaOutput []ColumnDetails - err := p.client.Select(&columnSchemaOutput, getTableSchemaTmpl, schemaName, streamName) + err := p.client.SelectContext(ctx, &columnSchemaOutput, getTableSchemaTmpl, schemaName, streamName) if err != nil { return stream, fmt.Errorf("failed to retrieve column details 
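The sqlx calls in cdc.go follow the same rule as the database/sql ones: Get and Select become GetContext and SelectContext. A hedged example of the struct-scan variant — Slot and the query below are illustrative stand-ins for waljs.ReplicationSlot and ReplicationSlotTempl:

    package example

    import (
        "context"

        "github.com/jmoiron/sqlx"
    )

    // Slot is an illustrative struct; waljs.ReplicationSlot plays this role
    // in the driver, scanned via db tags from pg_replication_slots.
    type Slot struct {
        SlotName string `db:"slot_name"`
        Plugin   string `db:"plugin"`
    }

    func loadSlot(ctx context.Context, db *sqlx.DB, name string) (Slot, error) {
        var s Slot
        // GetContext scans one row into the struct and honours ctx, unlike
        // the pre-patch db.Get, which could not be cancelled.
        err := db.GetContext(ctx, &s,
            "SELECT slot_name, plugin FROM pg_replication_slots WHERE slot_name = $1", name)
        return s, err
    }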
for table %s: %s", streamName, err) } @@ -202,7 +202,7 @@ func (p *Postgres) ProduceSchema(ctx context.Context, streamName string) (*types } var primaryKeyOutput []ColumnDetails - err = p.client.Select(&primaryKeyOutput, getTablePrimaryKey, schemaName, streamName) + err = p.client.SelectContext(ctx, &primaryKeyOutput, getTablePrimaryKey, schemaName, streamName) if err != nil { return stream, fmt.Errorf("failed to retrieve primary key columns for table %s: %s", streamName, err) } diff --git a/pkg/binlog/binlog.go b/pkg/binlog/binlog.go index c976550a..24de8281 100644 --- a/pkg/binlog/binlog.go +++ b/pkg/binlog/binlog.go @@ -56,7 +56,7 @@ func NewConnection(_ context.Context, config *Config, pos mysql.Position, stream } func (c *Connection) StreamMessages(ctx context.Context, client *sqlx.DB, callback abstract.CDCMsgFn) error { - latestBinlogPos, err := GetCurrentBinlogPosition(client) + latestBinlogPos, err := GetCurrentBinlogPosition(ctx, client) if err != nil { return fmt.Errorf("failed to get latest binlog position: %s", err) } @@ -127,11 +127,11 @@ func (c *Connection) Cleanup() { } // GetCurrentBinlogPosition retrieves the current binlog position from MySQL. -func GetCurrentBinlogPosition(client *sqlx.DB) (mysql.Position, error) { +func GetCurrentBinlogPosition(ctx context.Context, client *sqlx.DB) (mysql.Position, error) { // SHOW MASTER STATUS is not supported in MySQL 8.4 and after // Get MySQL version - majorVersion, minorVersion, err := jdbc.MySQLVersion(client) + majorVersion, minorVersion, err := jdbc.MySQLVersion(ctx, client) if err != nil { return mysql.Position{}, fmt.Errorf("failed to get MySQL version: %s", err) } @@ -139,7 +139,7 @@ func GetCurrentBinlogPosition(client *sqlx.DB) (mysql.Position, error) { // Use the appropriate query based on the MySQL version query := utils.Ternary(majorVersion > 8 || (majorVersion == 8 && minorVersion >= 4), jdbc.MySQLMasterStatusQueryNew(), jdbc.MySQLMasterStatusQuery()).(string) - rows, err := client.Query(query) + rows, err := client.QueryContext(ctx, query) if err != nil { return mysql.Position{}, fmt.Errorf("failed to get master status: %s", err) } diff --git a/pkg/jdbc/jdbc.go b/pkg/jdbc/jdbc.go index 3c3d34ba..a183abdf 100644 --- a/pkg/jdbc/jdbc.go +++ b/pkg/jdbc/jdbc.go @@ -343,9 +343,9 @@ func MySQLTableColumnsQuery() string { // MySQLVersion returns the version of the MySQL server // It returns the major and minor version of the MySQL server -func MySQLVersion(client *sqlx.DB) (int, int, error) { +func MySQLVersion(ctx context.Context, client *sqlx.DB) (int, int, error) { var version string - err := client.QueryRow("SELECT @@version").Scan(&version) + err := client.QueryRowContext(ctx, "SELECT @@version").Scan(&version) if err != nil { return 0, 0, fmt.Errorf("failed to get MySQL version: %s", err) } @@ -461,7 +461,7 @@ func OracleChunkRetrievalQuery(taskName string) string { } // OracleIncrementalValueFormatter is used to format the value of the cursor field for Oracle incremental sync, mainly because of the various timestamp formats -func OracleIncrementalValueFormatter(cursorField, argumentPlaceholder string, isBackfill bool, lastCursorValue any, opts DriverOptions) (string, any, error) { +func OracleIncrementalValueFormatter(ctx context.Context, cursorField, argumentPlaceholder string, isBackfill bool, lastCursorValue any, opts DriverOptions) (string, any, error) { // Get the datatype of the cursor field from streams stream := opts.Stream // in case of incremental sync mode, during backfill to avoid duplicate records we need 
to use '<=', otherwise use '>' @@ -479,7 +479,7 @@ func OracleIncrementalValueFormatter(cursorField, argumentPlaceholder string, is } query := fmt.Sprintf("SELECT DATA_TYPE FROM ALL_TAB_COLUMNS WHERE OWNER = '%s' AND TABLE_NAME = '%s' AND COLUMN_NAME = '%s'", stream.Namespace(), stream.Name(), cursorField) - err = opts.Client.QueryRow(query).Scan(&datatype) + err = opts.Client.QueryRowContext(ctx, query).Scan(&datatype) if err != nil { return "", nil, fmt.Errorf("failed to get column datatype: %s", err) } @@ -576,7 +576,7 @@ type DriverOptions struct { } // BuildIncrementalQuery generates the incremental query SQL based on driver type -func BuildIncrementalQuery(opts DriverOptions) (string, []any, error) { +func BuildIncrementalQuery(ctx context.Context, opts DriverOptions) (string, []any, error) { primaryCursor, secondaryCursor := opts.Stream.Cursor() lastPrimaryCursorValue := opts.State.GetCursor(opts.Stream.Self(), primaryCursor) lastSecondaryCursorValue := opts.State.GetCursor(opts.Stream.Self(), secondaryCursor) @@ -593,7 +593,7 @@ func BuildIncrementalQuery(opts DriverOptions) (string, []any, error) { // buildCursorCondition creates the SQL condition for incremental queries based on cursor fields. buildCursorCondition := func(cursorField string, lastCursorValue any, argumentPosition int) (string, any, error) { if opts.Driver == constants.Oracle { - return OracleIncrementalValueFormatter(cursorField, placeholder(argumentPosition), false, lastCursorValue, opts) + return OracleIncrementalValueFormatter(ctx, cursorField, placeholder(argumentPosition), false, lastCursorValue, opts) } quotedColumn := QuoteIdentifier(cursorField, opts.Driver) return fmt.Sprintf("%s > %s", quotedColumn, placeholder(argumentPosition)), lastCursorValue, nil @@ -666,7 +666,7 @@ func GetMaxCursorValues(ctx context.Context, client *sqlx.DB, driverType constan // ThresholdFilter is used to update the filter for initial run of incremental sync during backfill. // This is to avoid dupliction of records, as max cursor value is fetched before the chunk creation. 
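The two cursor comparisons in this file are the crux of that duplicate-avoidance comment: the backfill of an incremental stream filters on cursor <= threshold (the max value captured before chunking), while later incremental runs filter on cursor > lastValue, so the two passes partition the keyspace. Reduced to a sketch (cursorCondition is an illustrative name, not the driver's):

    package example

    import "fmt"

    func cursorCondition(column, placeholder string, backfill bool) string {
        op := ">" // incremental run: strictly newer than the saved cursor
        if backfill {
            op = "<=" // backfill: everything up to the pre-captured max
        }
        return fmt.Sprintf("%s %s %s", column, op, placeholder)
    }

For example, cursorCondition(`"updated_at"`, "$1", true) yields `"updated_at" <= $1`, the shape ThresholdFilter binds into the chunk scans.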
-func ThresholdFilter(opts DriverOptions) (string, []any, error) { +func ThresholdFilter(ctx context.Context, opts DriverOptions) (string, []any, error) { if opts.Stream.GetSyncMode() != types.INCREMENTAL { return "", nil, nil } @@ -676,7 +676,7 @@ func ThresholdFilter(opts DriverOptions) (string, []any, error) { createThresholdCondition := func(argumentPosition int, cursorField string, cursorValue any) (string, any, error) { if opts.Driver == constants.Oracle { - return OracleIncrementalValueFormatter(cursorField, placeholder(argumentPosition), true, cursorValue, opts) + return OracleIncrementalValueFormatter(ctx, cursorField, placeholder(argumentPosition), true, cursorValue, opts) } conditionFilter := fmt.Sprintf("%s <= %s", QuoteIdentifier(cursorField, opts.Driver), placeholder(argumentPosition)) return conditionFilter, cursorValue, nil diff --git a/pkg/waljs/pgoutput.go b/pkg/waljs/pgoutput.go index 9de678c4..9779f01c 100644 --- a/pkg/waljs/pgoutput.go +++ b/pkg/waljs/pgoutput.go @@ -31,7 +31,7 @@ func (p *pgoutputReplicator) Socket() *Socket { func (p *pgoutputReplicator) StreamChanges(ctx context.Context, db *sqlx.DB, insertFn abstract.CDCMsgFn) error { var slot ReplicationSlot - if err := db.Get(&slot, fmt.Sprintf(ReplicationSlotTempl, p.socket.ReplicationSlot)); err != nil { + if err := db.GetContext(ctx, &slot, fmt.Sprintf(ReplicationSlotTempl, p.socket.ReplicationSlot)); err != nil { return fmt.Errorf("failed to get replication slot: %s", err) } p.socket.CurrentWalPosition = slot.CurrentLSN diff --git a/pkg/waljs/replicator.go b/pkg/waljs/replicator.go index 3a2767a4..3cb299e7 100644 --- a/pkg/waljs/replicator.go +++ b/pkg/waljs/replicator.go @@ -98,7 +98,7 @@ func NewReplicator(ctx context.Context, db *sqlx.DB, config *Config, typeConvert // Get replication slot position var slot ReplicationSlot - if err := db.Get(&slot, fmt.Sprintf(ReplicationSlotTempl, config.ReplicationSlotName)); err != nil { + if err := db.GetContext(ctx, &slot, fmt.Sprintf(ReplicationSlotTempl, config.ReplicationSlotName)); err != nil { return nil, fmt.Errorf("failed to get replication slot: %s", err) } diff --git a/pkg/waljs/waljs.go b/pkg/waljs/waljs.go index f9ab9916..4589f3dd 100644 --- a/pkg/waljs/waljs.go +++ b/pkg/waljs/waljs.go @@ -33,7 +33,7 @@ func (w *wal2jsonReplicator) Socket() *Socket { func (w *wal2jsonReplicator) StreamChanges(ctx context.Context, db *sqlx.DB, callback abstract.CDCMsgFn) error { // update current lsn information var slot ReplicationSlot - if err := db.Get(&slot, fmt.Sprintf(ReplicationSlotTempl, w.socket.ReplicationSlot)); err != nil { + if err := db.GetContext(ctx, &slot, fmt.Sprintf(ReplicationSlotTempl, w.socket.ReplicationSlot)); err != nil { return fmt.Errorf("failed to get replication slot: %s", err) } From 42a9298a2ede9d7ec8e8cdac0bde6fe9fbce8d0f Mon Sep 17 00:00:00 2001 From: vr-varad Date: Sat, 11 Oct 2025 15:49:43 +0530 Subject: [PATCH 4/4] Fix: take val as any type in mongodb --- drivers/mongodb/internal/backfill.go | 33 +++++++++++++++++----------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/drivers/mongodb/internal/backfill.go b/drivers/mongodb/internal/backfill.go index 5f094352..ef5f97ea 100644 --- a/drivers/mongodb/internal/backfill.go +++ b/drivers/mongodb/internal/backfill.go @@ -397,23 +397,30 @@ func buildMongoCondition(cond types.Condition) bson.D { "=": "$eq", "!=": "$ne", } - //TODO: take val as any type - value := func(field, val string) interface{} { + + value := func(field string, val any) interface{} { // Handle unquoted null - 
if val == "null" { + switch v := val.(type) { + case nil: return nil - } + case string: + if v == "null" { + return nil + } - if strings.HasPrefix(val, "\"") && strings.HasSuffix(val, "\"") { - val = val[1 : len(val)-1] - } - if field == "_id" && len(val) == 24 { - if oid, err := primitive.ObjectIDFromHex(val); err == nil { - return oid + if strings.HasPrefix(v, "\"") && strings.HasSuffix(v, "\"") { + v = v[1 : len(v)-1] } - } - if strings.ToLower(val) == "true" || strings.ToLower(val) == "false" { - return strings.ToLower(val) == "true" + if field == "_id" && len(v) == 24 { + if oid, err := primitive.ObjectIDFromHex(v); err == nil { + return oid + } + } + if strings.ToLower(v) == "true" || strings.ToLower(v) == "false" { + return strings.ToLower(v) == "true" + } + + val = v } if timeVal, err := typeutils.ReformatDate(val); err == nil { return timeVal