diff --git a/internal/reaper/reaper.go b/internal/reaper/reaper.go index 0c187c400a..4c8de538f6 100644 --- a/internal/reaper/reaper.go +++ b/internal/reaper/reaper.go @@ -9,6 +9,8 @@ import ( "strings" "time" + "sync" + "sync/atomic" "github.com/cybertec-postgresql/pgwatch/v5/internal/cmdopts" @@ -37,6 +39,7 @@ type Reaper struct { measurementCh chan metrics.MeasurementEnvelope measurementCache *InstanceMetricCache logger log.Logger + mu sync.RWMutex monitoredSources sources.SourceConns prevLoopMonitoredDBs sources.SourceConns cancelFuncs map[string]context.CancelFunc @@ -97,7 +100,11 @@ func (r *Reaper) Reap(ctx context.Context) { // UpdateMonitoredDBCache(r.monitoredSources) hostsToShutDownDueToRoleChange := make(map[string]bool) // hosts went from master to standby and have "only if master" set - for _, monitoredSource := range r.monitoredSources { + r.mu.RLock() + srcs := slices.Clone(r.monitoredSources) + r.mu.RUnlock() + + for _, monitoredSource := range srcs { srcL := logger.WithField("source", monitoredSource.Name) ctx = log.WithLogger(ctx, srcL) @@ -418,7 +425,9 @@ func (r *Reaper) LoadSources(ctx context.Context) (err error) { r.logger.WithField("source", md.Name).Info("Source configs changed, restarting all gatherers...") r.ShutdownOldWorkers(ctx, map[string]bool{md.Name: true}) } + r.mu.Lock() r.monitoredSources = newSrcs + r.mu.Unlock() r.logger.WithField("sources", len(r.monitoredSources)).Info("sources refreshed") return nil } @@ -427,9 +436,12 @@ func (r *Reaper) LoadSources(ctx context.Context) (err error) { // every monitoredDbsDatastoreSyncIntervalSeconds (default 10min) func (r *Reaper) WriteMonitoredSources(ctx context.Context) { for { - if len(r.monitoredSources) > 0 { + r.mu.RLock() + srcs := slices.Clone(r.monitoredSources) + r.mu.RUnlock() + if len(srcs) > 0 { now := time.Now().UnixNano() - for _, mdb := range r.monitoredSources { + for _, mdb := range srcs { db := metrics.NewMeasurement(now) db["tag_group"] = mdb.Group db["master_only"] = mdb.OnlyIfMaster diff --git a/internal/reaper/reaper_test.go b/internal/reaper/reaper_test.go index ded3bea739..169a3ca750 100644 --- a/internal/reaper/reaper_test.go +++ b/internal/reaper/reaper_test.go @@ -5,6 +5,11 @@ import ( "os" "path/filepath" "testing" + "time" + "strings" + "fmt" + "github.com/sirupsen/logrus/hooks/test" + "errors" "github.com/cybertec-postgresql/pgwatch/v5/internal/cmdopts" "github.com/cybertec-postgresql/pgwatch/v5/internal/log" @@ -331,3 +336,541 @@ func TestReaper_LoadSources(t *testing.T) { assert.Nil(t, mockConn1.ExpectationsWereMet(), "Expected all mock expectations to be met") }) } + +func TestReaper_Ready(t *testing.T) { + ctx := context.Background() + r := NewReaper(ctx, &cmdopts.Options{}) + + assert.False(t, r.Ready()) + + r.ready.Store(true) + assert.True(t, r.Ready()) +} +func TestReaper_PrintMemStats(t *testing.T) { + ctx := log.WithLogger(context.Background(), log.NewNoopLogger()) + r := NewReaper(ctx, &cmdopts.Options{}) + + assert.NotPanics(t, func() { + r.PrintMemStats() + }) +} + + +// MockSinkWriter simulates the Sinks interface +type MockSinkWriter struct { + WriteCalled bool + LastMsg metrics.MeasurementEnvelope + SyncCalled bool + DeleteCalled bool + WriteError error +} + +func (m *MockSinkWriter) Write(msg metrics.MeasurementEnvelope) error { + m.WriteCalled = true + m.LastMsg = msg + return m.WriteError +} + +func (m *MockSinkWriter) SyncMetric(_, _ string, _ sinks.SyncOp) error { + m.SyncCalled = true + return nil +} + +func TestWriteMeasurements(t *testing.T) { + tests := []struct { + name string + writeError error + }{ + { + name: "Happy Path - Successful Write", + writeError: nil, + }, + { + name: "Error Path - Write Fails", + writeError: errors.New("something went wrong"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSink := &MockSinkWriter{ + WriteError: tt.writeError, + } + opts := &cmdopts.Options{} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r := NewReaper(ctx, opts) + r.SinksWriter = mockSink + + go r.WriteMeasurements(ctx) + + dummyMsg := metrics.MeasurementEnvelope{ + DBName: "test_db", + MetricName: "test_metric", + Data: metrics.Measurements{{"value": 1}}, + } + r.measurementCh <- dummyMsg + + // Allow brief time for channel processing + time.Sleep(50 * time.Millisecond) + + assert.True(t, mockSink.WriteCalled, "Sink Write should have been called") + assert.Equal(t, "test_db", mockSink.LastMsg.DBName) + assert.Equal(t, "test_metric", mockSink.LastMsg.MetricName) + + // Clean up + cancel() + }) + } +} + +func TestWriteMonitoredSources(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r := NewReaper(ctx, &cmdopts.Options{}) + + // Add a fake monitored source + r.monitoredSources = append(r.monitoredSources, &sources.SourceConn{ + Source: sources.Source{ + Name: "test_db_source", + Group: "test_group", + OnlyIfMaster: true, + CustomTags: map[string]string{"env": "prod"}, + }, + }) + + go r.WriteMonitoredSources(ctx) + + // Listen to the channel to verify output + select { + case msg := <-r.measurementCh: + assert.Equal(t, "test_db_source", msg.DBName) + assert.Equal(t, monitoredDbsDatastoreSyncMetricName, msg.MetricName) + + assert.NotEmpty(t, msg.Data) + row := msg.Data[0] + + assert.Equal(t, "test_group", row["tag_group"]) + assert.Equal(t, true, row["master_only"]) + assert.Equal(t, "prod", row["tag_env"]) + + case <-time.After(1 * time.Second): + t.Fatal("Timed out waiting for WriteMonitoredSources to produce data") + } +} + +func TestAddSysinfoToMeasurements(t *testing.T) { + opts := &cmdopts.Options{} + opts.Sinks.RealDbnameField = "real_dbname" + opts.Sinks.SystemIdentifierField = "sys_id" + + r := NewReaper(context.Background(), opts) + + md := &sources.SourceConn{} + md.RealDbname = "postgres_prod" + md.SystemIdentifier = "123456789" + + data := metrics.Measurements{ + {"value": 10}, + {"value": 20}, + } + + r.AddSysinfoToMeasurements(data, md) + + for _, row := range data { + assert.Equal(t, "postgres_prod", row["real_dbname"]) + assert.Equal(t, "123456789", row["sys_id"]) + } + + assert.Equal(t, 10, data[0]["value"]) + assert.Equal(t, 20, data[1]["value"]) +} + +func TestFetchMetric_CacheHit(t *testing.T) { + ctx := context.Background() + opts := &cmdopts.Options{} + opts.Metrics.InstanceLevelCacheMaxSeconds = 60 + + r := NewReaper(ctx, opts) + metricName := "cached_metric" + + md := &sources.SourceConn{ + Source: sources.Source{ + Name: "db1", + }, + } + md.SystemIdentifier = "sys_id_123" + md.Metrics = map[string]float64{metricName: 10} + + // Setup Metric Definition + metricDefs.Lock() + metricDefs.MetricDefs[metricName] = metrics.Metric{ + SQLs: map[int]string{0: "SELECT 1"}, + IsInstanceLevel: true, + } + metricDefs.Unlock() + + // Pre-populate Cache + cacheKey := fmt.Sprintf("%s:%s", md.GetClusterIdentifier(), metricName) + cachedData := metrics.Measurements{{"cached_val": 999}} + + r.measurementCache.Put(cacheKey, cachedData) + + envelope, err := r.FetchMetric(ctx, md, metricName) + + assert.NoError(t, err) + + if envelope == nil { + t.Fatal("Cache Miss! FetchMetric returned nil envelope. Check GetClusterIdentifier() logic.") + } + + assert.Equal(t, metricName, envelope.MetricName) + assert.Equal(t, 999, envelope.Data[0]["cached_val"]) +} + +func TestFetchMetric_NotFound(t *testing.T) { + ctx := context.Background() + r := NewReaper(ctx, &cmdopts.Options{}) + md := &sources.SourceConn{Source: sources.Source{Name: "db1"}} + + // Execute with non-existent metric + envelope, err := r.FetchMetric(ctx, md, "ghost_metric") + + assert.Error(t, err) + assert.Equal(t, metrics.ErrMetricNotFound, err) + assert.Nil(t, envelope) +} + +func TestFetchMetric_EmptySQL_Ignored(t *testing.T) { + ctx := context.Background() + r := NewReaper(ctx, &cmdopts.Options{}) + md := &sources.SourceConn{Source: sources.Source{Name: "db1"}} + + metricName := "empty_sql_metric" + + metricDefs.Lock() + metricDefs.MetricDefs[metricName] = metrics.Metric{ + SQLs: map[int]string{}, + } + metricDefs.Unlock() + + envelope, err := r.FetchMetric(ctx, md, metricName) + + assert.NoError(t, err) + assert.Nil(t, envelope) +} + +func TestFetchMetric_SwitchCases_PgxMock(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + metricName string + setupMetric metrics.Metric + isInRecovery bool + mockDB func(mock pgxmock.PgxPoolIface) + expectNilEnv bool + expectErr bool + }{ + { + name: "Case: Skip PrimaryOnly metric on Standby", + metricName: "test_primary_metric", + setupMetric: metrics.Metric{ + NodeStatus: "primary", + }, + isInRecovery: true, + mockDB: func(_ pgxmock.PgxPoolIface) { + // We expect no queries because the function should exit early + }, + expectNilEnv: true, // Should return nil, nil + expectErr: false, + }, + { + name: "Case: Skip StandbyOnly metric on Primary", + metricName: "test_standby_metric", + setupMetric: metrics.Metric{ + NodeStatus: "standby", + }, + isInRecovery: false, + mockDB: func(_ pgxmock.PgxPoolIface) { + // We expect no queries because the function should exit early + }, + expectNilEnv: true, + expectErr: false, + }, + { + name: "Case: specialMetricInstanceUp", + metricName: specialMetricInstanceUp, + setupMetric: metrics.Metric{ + IsInstanceLevel: false, + }, + mockDB: func(mock pgxmock.PgxPoolIface) { + // The function executes a Ping to check if the DB is up + mock.ExpectPing() + }, + expectNilEnv: false, + expectErr: false, + }, +{ + name: "Case: specialMetricChangeEvents", + metricName: specialMetricChangeEvents, + setupMetric: metrics.Metric{ + IsInstanceLevel: false, + }, + mockDB: func(mock pgxmock.PgxPoolIface) { + // Inject dummy metric definitions so the Detect functions have SQL to run + metricDefs.Lock() + metricDefs.MetricDefs["sproc_hashes"] = metrics.Metric{SQLs: map[int]string{0: "SELECT dummy_sproc"}} + metricDefs.MetricDefs["table_hashes"] = metrics.Metric{SQLs: map[int]string{0: "SELECT dummy_table"}} + metricDefs.MetricDefs["index_hashes"] = metrics.Metric{SQLs: map[int]string{0: "SELECT dummy_index"}} + metricDefs.MetricDefs["configuration_hashes"] = metrics.Metric{SQLs: map[int]string{0: "SELECT dummy_config"}} + metricDefs.MetricDefs["privilege_changes"] = metrics.Metric{SQLs: map[int]string{0: "SELECT dummy_priv"}} + metricDefs.Unlock() + + // Expect all 5 queries in exact order. + mock.ExpectQuery("SELECT dummy_sproc"). + WillReturnRows(pgxmock.NewRows([]string{"tag_sproc", "tag_oid", "md5"})) + + mock.ExpectQuery("SELECT dummy_table"). + WillReturnRows(pgxmock.NewRows([]string{"tag_table", "md5"})) + + mock.ExpectQuery("SELECT dummy_index"). + WillReturnRows(pgxmock.NewRows([]string{"tag_index", "table", "md5", "is_valid"})) + + mock.ExpectQuery("SELECT dummy_config"). + WillReturnRows(pgxmock.NewRows([]string{"epoch", "objIdent", "objValue"})) + + mock.ExpectQuery("SELECT dummy_priv"). + WillReturnRows(pgxmock.NewRows([]string{"object_type", "tag_role", "tag_object", "privilege_type"})) + }, + expectNilEnv: true, + expectErr: false, + }, + { + name: "Case: Default with Valid SQL", + metricName: "default_valid_sql", + setupMetric: metrics.Metric{ + SQLs: map[int]string{0: "SELECT 1 AS val"}, + IsInstanceLevel: false, + }, + mockDB: func(mock pgxmock.PgxPoolIface) { + // Mocking the QueryMeasurements function + columns := []string{"val"} + mock.ExpectQuery("SELECT 1 AS val"). + WillReturnRows(pgxmock.NewRows(columns).AddRow(99)) + }, + expectNilEnv: false, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock, err := pgxmock.NewPool() + assert.NoError(t, err) + defer mock.Close() + + md := &sources.SourceConn{ + Source: sources.Source{ + Name: "db_test", + }, + Conn: mock, + } + md.SystemIdentifier = "sys_id_test" + md.Metrics = map[string]float64{tt.metricName: 10} + md.ChangeState = make(map[string]map[string]string) + md.IsInRecovery = tt.isInRecovery + + if tt.mockDB != nil { + tt.mockDB(mock) + } + + // Setup Reaper and Metrics + opts := &cmdopts.Options{} + r := NewReaper(ctx, opts) + + metricDefs.Lock() + if metricDefs.MetricDefs == nil { + metricDefs.MetricDefs = make(map[string]metrics.Metric) + } + metricDefs.MetricDefs[tt.metricName] = tt.setupMetric + metricDefs.Unlock() + + envelope, err := r.FetchMetric(ctx, md, tt.metricName) + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + if tt.expectNilEnv { + assert.Nil(t, envelope) + } else { + assert.NotNil(t, envelope) + assert.Equal(t, tt.metricName, envelope.MetricName) + } + + assert.NoError(t, mock.ExpectationsWereMet()) + }) + } +} + +func TestCreateSourceHelpers(t *testing.T) { + logger, hook := test.NewNullLogger() + ctx := context.Background() + + tests := []struct { + name string + sourceName string + sourceType sources.Kind + inRecovery bool + inPrevLoop bool + createHelpers bool + tryCreateExts string + + expectedLogMsgs []string + unexpectedLogMsgs []string + }{ + { + name: "Skip if Non-Postgres Source", + sourceName: "PgBouncer", + sourceType: sources.SourcePgBouncer, + inRecovery: false, + createHelpers: true, + unexpectedLogMsgs: []string{"trying to create helper objects"}, + }, + { + name: "Skip if In Recovery", + sourceName: "standby_db", + sourceType: sources.SourcePostgres, + inRecovery: true, + createHelpers: true, + unexpectedLogMsgs: []string{"trying to create helper objects"}, + }, + { + name: "Skip if Already Created", + sourceName: "existing_db", + sourceType: sources.SourcePostgres, + inRecovery: false, + inPrevLoop: true, + createHelpers: true, + tryCreateExts: "plpythonu", + unexpectedLogMsgs: []string{ + "trying to create helper objects", + "trying to create extensions", + }, + }, + { + name: "Happy Path: Create Helpers", + sourceName: "fresh_primary", + sourceType: sources.SourcePostgres, + inRecovery: false, + inPrevLoop: false, + createHelpers: true, + expectedLogMsgs: []string{"trying to create helper objects if missing"}, + }, + { + name: "Happy Path: Create Extensions", + sourceName: "fresh_primary_ext", + sourceType: sources.SourcePostgres, + inRecovery: false, + inPrevLoop: false, + tryCreateExts: "pg_stat_statements", + expectedLogMsgs: []string{"trying to create extensions if missing"}, + }, + { + name: "Happy Path: Create Both", + sourceName: "fresh_primary_full", + sourceType: sources.SourcePostgres, + inRecovery: false, + inPrevLoop: false, + createHelpers: true, + tryCreateExts: "pg_stat_statements", + expectedLogMsgs: []string{ + "trying to create helper objects if missing", + "trying to create extensions if missing", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hook.Reset() + + opts := &cmdopts.Options{} + r := &Reaper{ + Options: opts, + logger: logger, + prevLoopMonitoredDBs: make(sources.SourceConns, 0), + } + + r.Sources.CreateHelpers = tt.createHelpers + r.Sources.TryCreateListedExtsIfMissing = tt.tryCreateExts + + src := &sources.SourceConn{ + Source: sources.Source{ + Name: tt.sourceName, + Kind: tt.sourceType, + }, + } + src.IsInRecovery = tt.inRecovery + + if tt.inPrevLoop { + r.prevLoopMonitoredDBs = append(r.prevLoopMonitoredDBs, src) + } + + mock, err := pgxmock.NewPool() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer mock.Close() + + src.Conn = mock + + // Setup Expectations for "Happy Path" + if strings.Contains(tt.name, "Happy Path") { + // Determine what to expect based on the test case + if strings.Contains(tt.name, "Extensions") || strings.Contains(tt.name, "Both") { + rows := pgxmock.NewRows([]string{"name"}).AddRow("pg_stat_statements").AddRow("plpythonu") + mock.ExpectQuery("select name::text from pg_available_extensions").WillReturnRows(rows) + + mock.ExpectQuery("create extension .*").WillReturnRows(pgxmock.NewRows([]string{}))} + } + + srcL := logger.WithField("source", src.Name) + + assert.NotPanics(t, func() { + r.CreateSourceHelpers(ctx, srcL, src) + }) + + entries := hook.AllEntries() + + // Check Expected Messages + for _, exp := range tt.expectedLogMsgs { + found := false + for _, e := range entries { + if strings.Contains(e.Message, exp) { + found = true + break + } + } + assert.True(t, found, "Expected log message not found: %s", exp) + } + + // Check Unexpected Messages + for _, unexp := range tt.unexpectedLogMsgs { + found := false + for _, e := range entries { + if strings.Contains(e.Message, unexp) { + found = true + break + } + } + assert.False(t, found, "Found unexpected log message: %s", unexp) + } + }) + } +} diff --git a/internal/webserver/wslog_test.go b/internal/webserver/wslog_test.go index dcc1f5ddf3..2a6e02da63 100644 --- a/internal/webserver/wslog_test.go +++ b/internal/webserver/wslog_test.go @@ -41,14 +41,39 @@ func TestServeWsLog_Success(t *testing.T) { // send ping message to keep connection alive assert.NoError(t, ws.WriteMessage(websocket.PingMessage, nil)) + stopChan := make(chan struct{}) + go func() { + for { + select { + case <-stopChan: + return + default: + ts.Info("Test message") + time.Sleep(50 * time.Millisecond) + } + } + }() - // send some log message - time.Sleep(100 * time.Millisecond) - ts.Info("Test message") - // check output though the websocket - assert.NoError(t, ws.SetReadDeadline(time.Now().Add(2*time.Second))) - msgType, msg, err := ws.ReadMessage() - assert.NoError(t, err) + // Set a 5-second deadline for the initial read to succeed + assert.NoError(t, ws.SetReadDeadline(time.Now().Add(5*time.Second))) + + var msgType int + var msg []byte + + // Block and wait to receive the expected message + for { + tpe, m, err := ws.ReadMessage() + require.NoError(t, err, "Websocket read failed or timed out") + + if strings.Contains(string(m), "Test message") { + msgType = tpe + msg = m + break + } + } + + // Stop the background log spammer now that we caught our message + close(stopChan) assert.Equal(t, websocket.TextMessage, msgType) assert.NotEmpty(t, msg) assert.Contains(t, string(msg), "Test message")