Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions destination/iceberg/java_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,11 @@ func getServerConfigJSON(config *Config, partitionInfo []PartitionInfo, port int
logger.Warnf("No region explicitly provided for Glue catalog, the Java process will attempt to use region from AWS environment")
}

// Configure custom endpoint for S3-compatible services (like MinIO)
if config.S3Endpoint != "" {
serverConfig["s3.endpoint"] = config.S3Endpoint
serverConfig["io-impl"] = "org.apache.iceberg.io.ResolvingFileIO"
// Set SSL/TLS configuration
serverConfig["s3.ssl-enabled"] = utils.Ternary(config.S3UseSSL, "true", "false").(string)
}

// Configure S3 or GCP file IO
serverConfig["io-impl"] = utils.Ternary(strings.HasPrefix(config.IcebergS3Path, "gs://"), "org.apache.iceberg.gcp.gcs.GCSFileIO", "org.apache.iceberg.aws.s3.S3FileIO")
serverConfig["io-impl"] = "org.apache.iceberg.io.ResolvingFileIO"
serverConfig["s3.ssl-enabled"] = utils.Ternary(config.S3UseSSL, "true", "false").(string)

// Marshal the config to JSON
return json.Marshal(serverConfig)
Expand Down
33 changes: 20 additions & 13 deletions drivers/mongodb/internal/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,23 +397,30 @@ func buildMongoCondition(cond types.Condition) bson.D {
"=": "$eq",
"!=": "$ne",
}
//TODO: take val as any type
value := func(field, val string) interface{} {

value := func(field string, val any) interface{} {
// Handle unquoted null
if val == "null" {
switch v := val.(type) {
case nil:
return nil
}
case string:
if v == "null" {
return nil
}

if strings.HasPrefix(val, "\"") && strings.HasSuffix(val, "\"") {
val = val[1 : len(val)-1]
}
if field == "_id" && len(val) == 24 {
if oid, err := primitive.ObjectIDFromHex(val); err == nil {
return oid
if strings.HasPrefix(v, "\"") && strings.HasSuffix(v, "\"") {
v = v[1 : len(v)-1]
}
}
if strings.ToLower(val) == "true" || strings.ToLower(val) == "false" {
return strings.ToLower(val) == "true"
if field == "_id" && len(v) == 24 {
if oid, err := primitive.ObjectIDFromHex(v); err == nil {
return oid
}
}
if strings.ToLower(v) == "true" || strings.ToLower(v) == "false" {
return strings.ToLower(v) == "true"
}

val = v
}
if timeVal, err := typeutils.ReformatDate(val); err == nil {
return timeVal
Expand Down
13 changes: 6 additions & 7 deletions drivers/mysql/internal/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ func (m *MySQL) ChunkIterator(ctx context.Context, stream types.StreamInterface,
Driver: constants.MySQL,
Stream: stream,
State: m.state,
Client: m.client,
}
thresholdFilter, args, err := jdbc.ThresholdFilter(opts)
thresholdFilter, args, err := jdbc.ThresholdFilter(ctx, opts)
if err != nil {
return fmt.Errorf("failed to set threshold filter: %s", err)
}
Expand Down Expand Up @@ -71,7 +70,7 @@ func (m *MySQL) GetOrSplitChunks(ctx context.Context, pool *destination.WriterPo
var approxRowCount int64
var avgRowSize any
approxRowCountQuery := jdbc.MySQLTableRowStatsQuery()
err := m.client.QueryRow(approxRowCountQuery, stream.Name()).Scan(&approxRowCount, &avgRowSize)
err := m.client.QueryRowContext(ctx, approxRowCountQuery, stream.Name()).Scan(&approxRowCount, &avgRowSize)
if err != nil {
return nil, fmt.Errorf("failed to get approx row count and avg row size: %s", err)
}
Expand Down Expand Up @@ -112,7 +111,7 @@ func (m *MySQL) GetOrSplitChunks(ctx context.Context, pool *destination.WriterPo
}
sort.Strings(pkColumns)
// Get table extremes
minVal, maxVal, err := m.getTableExtremes(stream, pkColumns, tx)
minVal, maxVal, err := m.getTableExtremes(ctx, stream, pkColumns, tx)
if err != nil {
return fmt.Errorf("failed to get table extremes: %s", err)
}
Expand Down Expand Up @@ -142,7 +141,7 @@ func (m *MySQL) GetOrSplitChunks(ctx context.Context, pool *destination.WriterPo
}
}
var nextValRaw interface{}
err := tx.QueryRow(query, args...).Scan(&nextValRaw)
err := tx.QueryRowContext(ctx, query, args...).Scan(&nextValRaw)
if err == sql.ErrNoRows || nextValRaw == nil {
break
} else if err != nil {
Expand Down Expand Up @@ -196,8 +195,8 @@ func (m *MySQL) GetOrSplitChunks(ctx context.Context, pool *destination.WriterPo
return chunks, err
}

// getTableExtremes runs the statement built by jdbc.MinMaxQueryMySQL for the
// given stream and primary-key columns on tx, and scans the single result row
// into the named results min and max. The Scan error, if any, is returned
// unchanged as err.
//
// NOTE(review): the signature and the query call each appear twice below, once
// without ctx and once with it. This looks like the before/after pair of a
// change that threads context.Context through to QueryRowContext — confirm
// which variant is current before relying on this block.
func (m *MySQL) getTableExtremes(stream types.StreamInterface, pkColumns []string, tx *sql.Tx) (min, max any, err error) {
func (m *MySQL) getTableExtremes(ctx context.Context, stream types.StreamInterface, pkColumns []string, tx *sql.Tx) (min, max any, err error) {
query := jdbc.MinMaxQueryMySQL(stream, pkColumns)
// Scan writes directly into the named results.
err = tx.QueryRow(query).Scan(&min, &max)
err = tx.QueryRowContext(ctx, query).Scan(&min, &max)
return min, max, err
}
2 changes: 1 addition & 1 deletion drivers/mysql/internal/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (m *MySQL) PreCDC(ctx context.Context, streams []types.StreamInterface) err
// Load or initialize global state
globalState := m.state.GetGlobal()
if globalState == nil || globalState.State == nil {
binlogPos, err := binlog.GetCurrentBinlogPosition(m.client)
binlogPos, err := binlog.GetCurrentBinlogPosition(ctx, m.client)
if err != nil {
return fmt.Errorf("failed to get current binlog position: %s", err)
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/mysql/internal/incremental.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func (m *MySQL) StreamIncrementalChanges(ctx context.Context, stream types.Strea
Stream: stream,
State: m.state,
}
incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(opts)
incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(ctx, opts)
if err != nil {
return fmt.Errorf("failed to build incremental condition: %s", err)
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/mysql/internal/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func (m *MySQL) Close() error {

func (m *MySQL) IsCDCSupported(ctx context.Context) (bool, error) {
// Permission check via SHOW MASTER STATUS / SHOW BINARY LOG STATUS
if _, err := binlog.GetCurrentBinlogPosition(m.client); err != nil {
if _, err := binlog.GetCurrentBinlogPosition(ctx, m.client); err != nil {
return false, fmt.Errorf("failed to get binlog position: %s", err)
}

Expand Down
6 changes: 3 additions & 3 deletions drivers/oracle/internal/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (o *Oracle) ChunkIterator(ctx context.Context, stream types.StreamInterface
State: o.state,
Client: o.client,
}
thresholdFilter, args, err := jdbc.ThresholdFilter(opts)
thresholdFilter, args, err := jdbc.ThresholdFilter(ctx, opts)
if err != nil {
return fmt.Errorf("failed to set threshold filter: %s", err)
}
Expand Down Expand Up @@ -61,7 +61,7 @@ func (o *Oracle) GetOrSplitChunks(ctx context.Context, pool *destination.WriterP
splitViaRowId := func(stream types.StreamInterface) (*types.Set[types.Chunk], error) {
// TODO: Add implementation of AddRecordsToSync function which expects total number of records to be synced
query := jdbc.OracleEmptyCheckQuery(stream)
err := o.client.QueryRow(query).Scan(new(interface{}))
err := o.client.QueryRowContext(ctx, query).Scan(new(interface{}))
if err != nil {
if err == sql.ErrNoRows {
logger.Warnf("Table %s.%s is empty skipping chunking", stream.Namespace(), stream.Name())
Expand All @@ -72,7 +72,7 @@ func (o *Oracle) GetOrSplitChunks(ctx context.Context, pool *destination.WriterP

query = jdbc.OracleBlockSizeQuery()
var blockSize int64
err = o.client.QueryRow(query).Scan(&blockSize)
err = o.client.QueryRowContext(ctx, query).Scan(&blockSize)
if err != nil || blockSize == 0 {
logger.Warnf("failed to get block size from query, switching to default block size value 8192")
blockSize = 8192
Expand Down
2 changes: 1 addition & 1 deletion drivers/oracle/internal/incremental.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func (o *Oracle) StreamIncrementalChanges(ctx context.Context, stream types.Stre
State: o.state,
Client: o.client,
}
incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(opts)
incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(ctx, opts)
if err != nil {
return fmt.Errorf("failed to build incremental condition: %s", err)
}
Expand Down
26 changes: 12 additions & 14 deletions drivers/postgres/internal/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ func (p *Postgres) ChunkIterator(ctx context.Context, stream types.StreamInterfa
Driver: constants.Postgres,
Stream: stream,
State: p.state,
Client: p.client,
}
thresholdFilter, args, err := jdbc.ThresholdFilter(opts)
thresholdFilter, args, err := jdbc.ThresholdFilter(ctx, opts)
if err != nil {
return fmt.Errorf("failed to set threshold filter: %s", err)
}
Expand All @@ -44,7 +43,7 @@ func (p *Postgres) ChunkIterator(ctx context.Context, stream types.StreamInterfa
chunkColumn = utils.Ternary(chunkColumn == "", "ctid", chunkColumn).(string)
stmt := jdbc.PostgresChunkScanQuery(stream, chunkColumn, chunk, filter)
setter := jdbc.NewReader(ctx, stmt, func(ctx context.Context, query string, queryArgs ...any) (*sql.Rows, error) {
return tx.Query(query, args...)
return tx.QueryContext(ctx, query, args...)
})

return setter.Capture(func(rows *sql.Rows) error {
Expand All @@ -61,28 +60,27 @@ func (p *Postgres) ChunkIterator(ctx context.Context, stream types.StreamInterfa
})
}

// GetOrSplitChunks reads an approximate row count for stream using the
// statement from jdbc.PostgresRowCountQuery, reports that count to pool via
// AddRecordsToSyncStats, and then returns whatever p.splitTableIntoChunks
// produces. If the row-count query fails, it returns a nil set and an error
// prefixed with "failed to get approx row count".
//
// NOTE(review): several lines below appear in two variants, one ignoring the
// context and one passing ctx down. This looks like the before/after pair of a
// change that makes the queries context-aware — confirm which variant is
// current.
func (p *Postgres) GetOrSplitChunks(_ context.Context, pool *destination.WriterPool, stream types.StreamInterface) (*types.Set[types.Chunk], error) {
func (p *Postgres) GetOrSplitChunks(ctx context.Context, pool *destination.WriterPool, stream types.StreamInterface) (*types.Set[types.Chunk], error) {
var approxRowCount int64
approxRowCountQuery := jdbc.PostgresRowCountQuery(stream)
// TODO: use ctx while querying
err := p.client.QueryRow(approxRowCountQuery).Scan(&approxRowCount)
err := p.client.QueryRowContext(ctx, approxRowCountQuery).Scan(&approxRowCount)
if err != nil {
return nil, fmt.Errorf("failed to get approx row count: %s", err)
}
// Register the row count with the writer pool's sync stats before chunking starts.
pool.AddRecordsToSyncStats(approxRowCount)
return p.splitTableIntoChunks(stream)
return p.splitTableIntoChunks(ctx, stream)
}

func (p *Postgres) splitTableIntoChunks(stream types.StreamInterface) (*types.Set[types.Chunk], error) {
func (p *Postgres) splitTableIntoChunks(ctx context.Context, stream types.StreamInterface) (*types.Set[types.Chunk], error) {
generateCTIDRanges := func(stream types.StreamInterface) (*types.Set[types.Chunk], error) {
var relPages, blockSize uint32
relPagesQuery := jdbc.PostgresRelPageCount(stream)
err := p.client.QueryRow(relPagesQuery).Scan(&relPages)
err := p.client.QueryRowContext(ctx, relPagesQuery).Scan(&relPages)
if err != nil {
return nil, fmt.Errorf("failed to get relPages: %s", err)
}
blockSizeQuery := jdbc.PostgresBlockSizeQuery()
err = p.client.QueryRow(blockSizeQuery).Scan(&blockSize)
err = p.client.QueryRowContext(ctx, blockSizeQuery).Scan(&blockSize)
if err != nil {
return nil, fmt.Errorf("failed to get block size: %s", err)
}
Expand Down Expand Up @@ -125,7 +123,7 @@ func (p *Postgres) splitTableIntoChunks(stream types.StreamInterface) (*types.Se
chunkStart := min
splits := types.NewSet[types.Chunk]()
for {
chunkEnd, err := p.nextChunkEnd(stream, chunkStart, chunkColumn)
chunkEnd, err := p.nextChunkEnd(ctx, stream, chunkStart, chunkColumn)
if err != nil {
return nil, fmt.Errorf("failed to split chunks based on next query size: %s", err)
}
Expand All @@ -145,7 +143,7 @@ func (p *Postgres) splitTableIntoChunks(stream types.StreamInterface) (*types.Se
var minValue, maxValue interface{}
minMaxRowCountQuery := jdbc.MinMaxQuery(stream, chunkColumn)
// TODO: Fails on UUID type (Good First Issue)
err := p.client.QueryRow(minMaxRowCountQuery).Scan(&minValue, &maxValue)
err := p.client.QueryRowContext(ctx, minMaxRowCountQuery).Scan(&minValue, &maxValue)
if err != nil {
return nil, fmt.Errorf("failed to fetch table min max: %s", err)
}
Expand All @@ -171,10 +169,10 @@ func (p *Postgres) splitTableIntoChunks(stream types.StreamInterface) (*types.Se
}
}

func (p *Postgres) nextChunkEnd(stream types.StreamInterface, previousChunkEnd interface{}, chunkColumn string) (interface{}, error) {
func (p *Postgres) nextChunkEnd(ctx context.Context, stream types.StreamInterface, previousChunkEnd interface{}, chunkColumn string) (interface{}, error) {
var chunkEnd interface{}
nextChunkEnd := jdbc.PostgresNextChunkEndQuery(stream, chunkColumn, previousChunkEnd)
err := p.client.QueryRow(nextChunkEnd).Scan(&chunkEnd)
err := p.client.QueryRowContext(ctx, nextChunkEnd).Scan(&chunkEnd)
if err != nil {
return nil, fmt.Errorf("failed to query[%s] next chunk end: %s", nextChunkEnd, err)
}
Expand Down
11 changes: 6 additions & 5 deletions drivers/postgres/internal/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,23 @@ func (p *Postgres) PostCDC(ctx context.Context, _ types.StreamInterface, noErr b
return nil
}

// doesReplicationSlotExists reports whether pg_replication_slots contains a row
// whose slot_name equals slotName.
//
// If the EXISTS query itself fails, it returns (false, err). Otherwise the
// boolean is the query result and the error is whatever validateReplicationSlot
// returns for the same slot and publication. Validation is invoked regardless
// of the value of exists.
//
// NOTE(review): the signature, the query call and the final return each appear
// in two variants below, without and with ctx. This looks like the before/after
// pair of a change that threads context.Context through — confirm which variant
// is current.
func doesReplicationSlotExists(conn *sqlx.DB, slotName string, publication string) (bool, error) {
func doesReplicationSlotExists(ctx context.Context, conn *sqlx.DB, slotName string, publication string) (bool, error) {
var exists bool
err := conn.QueryRow(
err := conn.QueryRowContext(
ctx,
"SELECT EXISTS(Select 1 from pg_replication_slots where slot_name = $1)",
slotName,
).Scan(&exists)
if err != nil {
return false, err
}

// The validation error is returned alongside exists, not instead of it.
return exists, validateReplicationSlot(conn, slotName, publication)
return exists, validateReplicationSlot(ctx, conn, slotName, publication)
}

func validateReplicationSlot(conn *sqlx.DB, slotName string, publication string) error {
func validateReplicationSlot(ctx context.Context, conn *sqlx.DB, slotName string, publication string) error {
slot := waljs.ReplicationSlot{}
err := conn.Get(&slot, fmt.Sprintf(waljs.ReplicationSlotTempl, slotName))
err := conn.GetContext(ctx, &slot, fmt.Sprintf(waljs.ReplicationSlotTempl, slotName))
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion drivers/postgres/internal/incremental.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func (p *Postgres) StreamIncrementalChanges(ctx context.Context, stream types.St
Stream: stream,
State: p.state,
}
incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(opts)
incrementalQuery, queryArgs, err := jdbc.BuildIncrementalQuery(ctx, opts)
if err != nil {
return fmt.Errorf("failed to build incremental condition: %s", err)
}
Expand Down
8 changes: 4 additions & 4 deletions drivers/postgres/internal/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (p *Postgres) Setup(ctx context.Context) error {

logger.Infof("CDC initial wait time set to: %d", cdc.InitialWaitTime)

exists, err := doesReplicationSlotExists(pgClient, cdc.ReplicationSlot, cdc.Publication)
exists, err := doesReplicationSlotExists(ctx, pgClient, cdc.ReplicationSlot, cdc.Publication)
if err != nil {
if strings.Contains(err.Error(), "sql: no rows in result set") {
err = fmt.Errorf("no record found")
Expand Down Expand Up @@ -174,7 +174,7 @@ func (p *Postgres) CloseConnection() {
func (p *Postgres) GetStreamNames(ctx context.Context) ([]string, error) {
logger.Infof("Starting discover for Postgres database %s", p.config.Database)
var tableNamesOutput []Table
err := p.client.Select(&tableNamesOutput, getPrivilegedTablesTmpl)
err := p.client.SelectContext(ctx, &tableNamesOutput, getPrivilegedTablesTmpl)
if err != nil {
return nil, fmt.Errorf("failed to retrieve table names: %s", err)
}
Expand All @@ -191,7 +191,7 @@ func (p *Postgres) ProduceSchema(ctx context.Context, streamName string) (*types
schemaName, streamName := streamParts[0], streamParts[1]
stream := types.NewStream(streamName, schemaName, &p.config.Database)
var columnSchemaOutput []ColumnDetails
err := p.client.Select(&columnSchemaOutput, getTableSchemaTmpl, schemaName, streamName)
err := p.client.SelectContext(ctx, &columnSchemaOutput, getTableSchemaTmpl, schemaName, streamName)
if err != nil {
return stream, fmt.Errorf("failed to retrieve column details for table %s: %s", streamName, err)
}
Expand All @@ -202,7 +202,7 @@ func (p *Postgres) ProduceSchema(ctx context.Context, streamName string) (*types
}

var primaryKeyOutput []ColumnDetails
err = p.client.Select(&primaryKeyOutput, getTablePrimaryKey, schemaName, streamName)
err = p.client.SelectContext(ctx, &primaryKeyOutput, getTablePrimaryKey, schemaName, streamName)
if err != nil {
return stream, fmt.Errorf("failed to retrieve primary key columns for table %s: %s", streamName, err)
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/binlog/binlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func NewConnection(_ context.Context, config *Config, pos mysql.Position, stream
}

func (c *Connection) StreamMessages(ctx context.Context, client *sqlx.DB, callback abstract.CDCMsgFn) error {
latestBinlogPos, err := GetCurrentBinlogPosition(client)
latestBinlogPos, err := GetCurrentBinlogPosition(ctx, client)
if err != nil {
return fmt.Errorf("failed to get latest binlog position: %s", err)
}
Expand Down Expand Up @@ -127,19 +127,19 @@ func (c *Connection) Cleanup() {
}

// GetCurrentBinlogPosition retrieves the current binlog position from MySQL.
func GetCurrentBinlogPosition(client *sqlx.DB) (mysql.Position, error) {
func GetCurrentBinlogPosition(ctx context.Context, client *sqlx.DB) (mysql.Position, error) {
// SHOW MASTER STATUS is not supported in MySQL 8.4 and after

// Get MySQL version
majorVersion, minorVersion, err := jdbc.MySQLVersion(client)
majorVersion, minorVersion, err := jdbc.MySQLVersion(ctx, client)
if err != nil {
return mysql.Position{}, fmt.Errorf("failed to get MySQL version: %s", err)
}

// Use the appropriate query based on the MySQL version
query := utils.Ternary(majorVersion > 8 || (majorVersion == 8 && minorVersion >= 4), jdbc.MySQLMasterStatusQueryNew(), jdbc.MySQLMasterStatusQuery()).(string)

rows, err := client.Query(query)
rows, err := client.QueryContext(ctx, query)
if err != nil {
return mysql.Position{}, fmt.Errorf("failed to get master status: %s", err)
}
Expand Down
Loading