diff --git a/internal/sinks/postgres.go b/internal/sinks/postgres.go index e3fbc1347b..7f9c2318bd 100644 --- a/internal/sinks/postgres.go +++ b/internal/sinks/postgres.go @@ -94,7 +94,7 @@ func NewWriterFromPostgresConn(ctx context.Context, conn db.PgxPoolIface, opts * ctx: ctx, opts: opts, input: make(chan metrics.MeasurementEnvelope, cacheLimit), - lastError: make(chan error), + lastError: make(chan error, 1), sinkDb: conn, forceRecreatePartitions: false, partitionMapMetric: make(map[string]ExistingPartitionInfo), @@ -422,7 +422,10 @@ func (pgw *PostgresWriter) flush(msgs []metrics.MeasurementEnvelope) { } pgw.forceRecreatePartitions = false if err != nil { - pgw.lastError <- err + select { + case pgw.lastError <- err: + default: + } } var rowsBatched, n int64 @@ -446,7 +449,10 @@ func (pgw *PostgresWriter) flush(msgs []metrics.MeasurementEnvelope) { logger.WithField("rows", rowsBatched).WithField("elapsed", diff).Info("measurements written") return } - pgw.lastError <- err + select { + case pgw.lastError <- err: + default: + } } func (pgw *PostgresWriter) EnsureMetricTimescale(pgPartBounds map[string]ExistingPartitionInfo) (err error) { diff --git a/internal/sinks/postgres_test.go b/internal/sinks/postgres_test.go index 3198173077..fe293416bd 100644 --- a/internal/sinks/postgres_test.go +++ b/internal/sinks/postgres_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync/atomic" "testing" "time" @@ -11,6 +12,8 @@ import ( "github.com/cybertec-postgresql/pgwatch/v5/internal/metrics" "github.com/cybertec-postgresql/pgwatch/v5/internal/testutil" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" jsoniter "github.com/json-iterator/go" "github.com/pashagolub/pgxmock/v4" "github.com/stretchr/testify/assert" @@ -975,3 +978,80 @@ func Test_Maintain(t *testing.T) { } }) } + +type fastErrorDB struct { + execErr error + copyFromErr error +} + +func (f *fastErrorDB) Exec(context.Context, string, ...any) (pgconn.CommandTag, error) { + return pgconn.CommandTag{}, f.execErr +} + +func (f *fastErrorDB) CopyFrom(_ context.Context, _ pgx.Identifier, _ []string, rowSrc pgx.CopyFromSource) (int64, error) { + for rowSrc.Next() { + } + return 0, f.copyFromErr +} + +func (f *fastErrorDB) Begin(context.Context) (pgx.Tx, error) { return nil, nil } +func (f *fastErrorDB) QueryRow(context.Context, string, ...any) pgx.Row { return nil } +func (f *fastErrorDB) Query(context.Context, string, ...any) (pgx.Rows, error) { return nil, nil } +func (f *fastErrorDB) Acquire(context.Context) (*pgxpool.Conn, error) { return nil, nil } +func (f *fastErrorDB) BeginTx(context.Context, pgx.TxOptions) (pgx.Tx, error) { return nil, nil } +func (f *fastErrorDB) Close() {} +func (f *fastErrorDB) Config() *pgxpool.Config { return nil } +func (f *fastErrorDB) Ping(context.Context) error { return nil } +func (f *fastErrorDB) Stat() *pgxpool.Stat { return nil } + +// TestLastErrorChannelDeadlock verifies the fix for issue #1212. +// See: https://github.com/cybertec-postgresql/pgwatch/issues/1212 +func TestLastErrorChannelDeadlock(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + fakeDB := &fastErrorDB{ + execErr: errors.New("simulated partition error"), + copyFromErr: errors.New("simulated copy error"), + } + + pgw := &PostgresWriter{ + ctx: ctx, + sinkDb: fakeDB, + opts: &CmdOpts{BatchingDelay: 10 * time.Millisecond}, + input: make(chan metrics.MeasurementEnvelope, cacheLimit), + lastError: make(chan error, 1), + metricSchema: DbStorageSchemaTimescale, + partitionMapMetric: make(map[string]ExistingPartitionInfo), + partitionMapMetricDbname: make(map[string]map[string]ExistingPartitionInfo), + } + + go pgw.poll() + + msg := metrics.MeasurementEnvelope{ + MetricName: "test_metric", + DBName: "test_db", + Data: metrics.Measurements{ + {metrics.EpochColumnName: time.Now().UnixNano(), "value": int64(42)}, + }, + } + + err := pgw.Write(msg) + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) + + var writeSucceeded atomic.Bool + go func() { + for i := 0; i < cacheLimit+10; i++ { + select { + case pgw.input <- msg: + case <-time.After(500 * time.Millisecond): + return + } + } + writeSucceeded.Store(true) + }() + + time.Sleep(700 * time.Millisecond) + assert.True(t, writeSucceeded.Load(), "poll() should not deadlock; all messages should be processed") +}