diff --git a/drivers/postgres/internal/backfill.go b/drivers/postgres/internal/backfill.go index c075fb8c5..e98b6ba44 100644 --- a/drivers/postgres/internal/backfill.go +++ b/drivers/postgres/internal/backfill.go @@ -27,7 +27,8 @@ func (p *Postgres) ChunkIterator(ctx context.Context, stream types.StreamInterfa defer tx.Rollback() chunkColumn := stream.Self().StreamMetadata.ChunkColumn chunkColumn = utils.Ternary(chunkColumn == "", "ctid", chunkColumn).(string) - stmt := jdbc.PostgresChunkScanQuery(stream, chunkColumn, chunk, filter) + chunkColType, _ := stream.Schema().GetType(chunkColumn) + stmt := jdbc.PostgresChunkScanQuery(stream, chunkColumn, chunk, filter, chunkColType) setter := jdbc.NewReader(ctx, stmt, p.config.BatchSize, func(ctx context.Context, query string, args ...any) (*sql.Rows, error) { return tx.Query(query, args...) }) @@ -129,7 +130,6 @@ func (p *Postgres) splitTableIntoChunks(stream types.StreamInterface) (*types.Se if chunkColumn != "" { var minValue, maxValue interface{} minMaxRowCountQuery := jdbc.MinMaxQuery(stream, chunkColumn) - // TODO: Fails on UUID type (Good First Issue) err := p.client.QueryRow(minMaxRowCountQuery).Scan(&minValue, &maxValue) if err != nil { return nil, fmt.Errorf("failed to fetch table min max: %s", err) @@ -158,7 +158,8 @@ func (p *Postgres) splitTableIntoChunks(stream types.StreamInterface) (*types.Se func (p *Postgres) nextChunkEnd(stream types.StreamInterface, previousChunkEnd interface{}, chunkColumn string) (interface{}, error) { var chunkEnd interface{} - nextChunkEnd := jdbc.PostgresNextChunkEndQuery(stream, chunkColumn, previousChunkEnd, p.config.BatchSize) + chunkColType, _ := stream.Schema().GetType(chunkColumn) + nextChunkEnd := jdbc.PostgresNextChunkEndQuery(stream, chunkColumn, previousChunkEnd, p.config.BatchSize, chunkColType) err := p.client.QueryRow(nextChunkEnd).Scan(&chunkEnd) if err != nil { return nil, fmt.Errorf("failed to query[%s] next chunk end: %s", nextChunkEnd, err) diff --git a/pkg/jdbc/jdbc.go b/pkg/jdbc/jdbc.go index 7f5761510..f8b2eb3e6 100644 --- a/pkg/jdbc/jdbc.go +++ b/pkg/jdbc/jdbc.go @@ -127,26 +127,38 @@ func PostgresWalLSNQuery() string { } // PostgresNextChunkEndQuery generates a SQL query to fetch the maximum value of a specified column -func PostgresNextChunkEndQuery(stream types.StreamInterface, filterColumn string, filterValue interface{}, batchSize int) string { +func PostgresNextChunkEndQuery(stream types.StreamInterface, filterColumn string, filterValue interface{}, batchSize int, dataType types.DataType) string { quotedColumn := QuoteIdentifier(filterColumn, constants.Postgres) quotedTable := QuoteTable(stream.Namespace(), stream.Name(), constants.Postgres) - baseCond := fmt.Sprintf(`%s > %v`, quotedColumn, filterValue) + var baseCond string + if dataType == types.String { + baseCond = fmt.Sprintf(`%s > '%s'`, quotedColumn, strings.ReplaceAll(fmt.Sprintf("%v", filterValue), "'", "''")) + } else { + baseCond = fmt.Sprintf(`%s > %v`, quotedColumn, filterValue) + } return fmt.Sprintf(`SELECT MAX(%s) FROM (SELECT %s FROM %s WHERE %s ORDER BY %s ASC LIMIT %d) AS T`, quotedColumn, quotedColumn, quotedTable, baseCond, quotedColumn, batchSize) } // PostgresBuildSplitScanQuery builds a chunk scan query for PostgreSQL -func PostgresChunkScanQuery(stream types.StreamInterface, filterColumn string, chunk types.Chunk, filter string) string { +func PostgresChunkScanQuery(stream types.StreamInterface, filterColumn string, chunk types.Chunk, filter string, dataType types.DataType) string { quotedFilterColumn := QuoteIdentifier(filterColumn, constants.Postgres) quotedTable := QuoteTable(stream.Namespace(), stream.Name(), constants.Postgres) + formatValue := func(val interface{}) string { + if dataType == types.String && val != nil { + return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprintf("%v", val), "'", "''")) + } + return fmt.Sprintf("%v", val) + } + chunkCond := "" if chunk.Min != nil && chunk.Max != nil { - chunkCond = fmt.Sprintf("%s >= %v AND %s < %v", quotedFilterColumn, chunk.Min, quotedFilterColumn, chunk.Max) + chunkCond = fmt.Sprintf("%s >= %s AND %s < %s", quotedFilterColumn, formatValue(chunk.Min), quotedFilterColumn, formatValue(chunk.Max)) } else if chunk.Min != nil { - chunkCond = fmt.Sprintf("%s >= %v", quotedFilterColumn, chunk.Min) + chunkCond = fmt.Sprintf("%s >= %s", quotedFilterColumn, formatValue(chunk.Min)) } else if chunk.Max != nil { - chunkCond = fmt.Sprintf("%s < %v", quotedFilterColumn, chunk.Max) + chunkCond = fmt.Sprintf("%s < %s", quotedFilterColumn, formatValue(chunk.Max)) } chunkCond = utils.Ternary(filter != "" && chunkCond != "", fmt.Sprintf("(%s) AND (%s)", chunkCond, filter), chunkCond).(string)