Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/f5devcentral/go-bigip v0.0.0-20250731061239-628be0470a84
github.com/f5devcentral/go-bigip/f5teem v0.0.0-20250731061239-628be0470a84
github.com/f5devcentral/mockhttpclient v0.0.0-20210630101009-cc12e8b81051
github.com/fsnotify/fsnotify v1.4.9
github.com/google/uuid v1.3.0
github.com/miekg/dns v1.1.42
github.com/onsi/ginkgo/v2 v2.19.1
Expand Down
163 changes: 112 additions & 51 deletions pkg/controller/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ import (
"github.com/F5Networks/k8s-bigip-ctlr/v2/pkg/health"
bigIPPrometheus "github.com/F5Networks/k8s-bigip-ctlr/v2/pkg/prometheus"
log "github.com/F5Networks/k8s-bigip-ctlr/v2/pkg/vlogger"
"github.com/fsnotify/fsnotify"
"github.com/prometheus/client_golang/prometheus/promhttp"
"net/http"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -28,50 +31,100 @@ type webHook struct {

func (ctlr *Controller) startWebhook() {
webhookServerOnce.Do(func() {
// Initial cert load
cert, err := loadAndValidateTLSCertificate(certFile, keyFile)
if err != nil {
log.Errorf("[Webhook] TLS cert load failed: %v", err)
return
}

// This will be updated when cert changes
var currentCert atomic.Value
currentCert.Store(cert)

// Watch for changes
go watchCertFiles(certFile, keyFile, func() {
newCert, err := loadAndValidateTLSCertificate(certFile, keyFile)
if err != nil {
log.Errorf("[Webhook] Failed to reload webhook TLS cert: %v", err)
return
}
currentCert.Store(newCert)
log.Debugf("[Webhook] TLS cert reloaded")
})

tlsCfg := &tls.Config{
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
c := currentCert.Load().(tls.Certificate)
return &c, nil
},
}

webhookMux := http.NewServeMux()
webhookMux.HandleFunc("/mutate", ctlr.handleMutate)
webhookMux.HandleFunc("/validate", ctlr.handleValidate)
ctlr.webhookServer = webHook{
Server: &http.Server{
Addr: ctlr.agentParams.HttpsAddress,
Handler: webhookMux,
Addr: ctlr.agentParams.HttpsAddress,
Handler: webhookMux,
TLSConfig: tlsCfg,
},
address: ctlr.agentParams.HttpsAddress,
}
webhookShutdownCh := make(chan struct{})

// Check cert/key existence and validity before starting server
if _, err := os.Stat(certFile); err != nil {
log.Errorf("Webhook server failed as TLS certificate file not found: %s, error: %v", certFile, err)
return
}
if _, err := os.Stat(keyFile); err != nil {
log.Errorf("Webhook server failed as TLS key file not found: %s, error: %v", keyFile, err)
return
}
if err := validateTLSCertificate(certFile, keyFile); err != nil {
log.Errorf("Webhook server failed as Invalid TLS certificate or key: %v", err)
return
}

// Graceful shutdown goroutine
go func() {
<-webhookShutdownCh
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := ctlr.webhookServer.GetWebhookServer().Shutdown(ctx); err != nil {
log.Errorf("Webhook server graceful shutdown failed: %v", err)
log.Errorf("[Webhook] server graceful shutdown failed: %v", err)
} else {
log.Infof("Webhook server gracefully stopped")
log.Infof("[Webhook] server gracefully stopped")
}
}()
log.Infof("Starting webhook server on :%s", ctlr.agentParams.HttpsAddress)
log.Infof("[Webhook] starting webhook server on :%s", ctlr.agentParams.HttpsAddress)
if err := ctlr.webhookServer.GetWebhookServer().ListenAndServeTLS(certFile, keyFile); err != nil && err != http.ErrServerClosed {
log.Errorf("Webhook server failed: %v", err)
}
})
}

// loadAndValidateTLSCertificate reads and validates the TLS certificate and key files.
func loadAndValidateTLSCertificate(certPath, keyPath string) (tls.Certificate, error) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not read cert file: %w", err)
}

_, err = os.ReadFile(keyPath)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not read key file: %w", err)
}

block, _ := pem.Decode(certPEM)
if block == nil {
return tls.Certificate{}, fmt.Errorf("failed to parse certificate PEM")
}

parsedCert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to parse certificate: %w", err)
}

now := time.Now()
if now.Before(parsedCert.NotBefore) {
return tls.Certificate{}, fmt.Errorf("certificate is not valid yet (NotBefore: %v)", parsedCert.NotBefore)
}
if now.After(parsedCert.NotAfter) {
return tls.Certificate{}, fmt.Errorf("certificate is expired (NotAfter: %v)", parsedCert.NotAfter)
}

// If valid, load keypair as usual
return tls.LoadX509KeyPair(certPath, keyPath)
}

func (ctlr *Controller) CISHealthCheck() {
healthCheckOnce.Do(func() {
healthMux := http.NewServeMux()
Expand Down Expand Up @@ -114,38 +167,6 @@ func (ctlr *Controller) CISHealthCheck() {
})
}

// validateTLSCertificate checks if the cert/key files are valid and not expired
func validateTLSCertificate(certPath, keyPath string) error {
cert, err := os.ReadFile(certPath)
if err != nil {
return fmt.Errorf("could not read cert file: %w", err)
}
key, err := os.ReadFile(keyPath)
if err != nil {
return fmt.Errorf("could not read key file: %w", err)
}
_, err = tls.X509KeyPair(cert, key)
if err != nil {
return fmt.Errorf("invalid TLS key pair: %w", err)
}
// Check for expiration
block, _ := pem.Decode(cert)
if block == nil {
return fmt.Errorf("failed to parse certificate PEM")
}
parsedCert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse certificate: %w", err)
}
if time.Now().After(parsedCert.NotAfter) {
return fmt.Errorf("certificate is expired (NotAfter: %v)", parsedCert.NotAfter)
}
if time.Now().Before(parsedCert.NotBefore) {
return fmt.Errorf("certificate is not valid yet (NotBefore: %v)", parsedCert.NotBefore)
}
return nil
}

func (ctlr *Controller) CISHealthCheckHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clusterConfig := ctlr.multiClusterHandler.getClusterConfig(ctlr.multiClusterHandler.LocalClusterName)
Expand Down Expand Up @@ -188,3 +209,43 @@ func (w webHook) IsWebhookServerRunning() bool {
func (w webHook) GetWebhookServer() *http.Server {
return w.Server
}

// watchCertFiles monitors the certificate and key files for changes and reloads them when modified.
func watchCertFiles(certPath, keyPath string, certsReload func()) {
absCertPath, _ := filepath.Abs(certPath)
absKeyPath, _ := filepath.Abs(keyPath)

watcher, err := fsnotify.NewWatcher()
if err != nil {
fmt.Printf("[Webhook] fsnotify init failed: %v\n", err)
return
}

defer watcher.Close()

certDir := filepath.Dir(absCertPath)
keyDir := filepath.Dir(absKeyPath)
_ = watcher.Add(certDir)

if certDir != keyDir {
_ = watcher.Add(keyDir)
}

log.Debugf("[Webhook] Watching certificate file: %s and key file: %s for changes...", certPath, keyPath)

for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) != 0 {
certsReload()
}
case err, ok := <-watcher.Errors:
if ok {
log.Errorf("[Webhook] fsnotify error: %v\n", err)
}
}
}
}
Loading