diff --git a/bridge.go b/bridge.go index 9f2c424..84d8af6 100644 --- a/bridge.go +++ b/bridge.go @@ -9,9 +9,9 @@ import ( "cdr.dev/slog" "github.com/coder/aibridge/mcp" - "go.opentelemetry.io/otel/trace" - "github.com/hashicorp/go-multierror" + "github.com/sony/gobreaker/v2" + "go.opentelemetry.io/otel/trace" ) // RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs; @@ -48,13 +48,42 @@ var _ http.Handler = &RequestBridge{} // A [Recorder] is also required to record prompt, tool, and token use. // // mcpProxy will be closed when the [RequestBridge] is closed. +// +// Circuit breaker configuration is obtained from each provider's CircuitBreakerConfig() method. +// Providers returning nil will not have circuit breaker protection. func NewRequestBridge(ctx context.Context, providers []Provider, recorder Recorder, mcpProxy mcp.ServerProxier, logger slog.Logger, metrics *Metrics, tracer trace.Tracer) (*RequestBridge, error) { mux := http.NewServeMux() for _, provider := range providers { + // Create per-provider circuit breaker if configured + var cbs *ProviderCircuitBreakers + onChange := func(endpoint string, from, to gobreaker.State) {} + + if cfg := provider.CircuitBreakerConfig(); cfg != nil && metrics != nil { + providerName := provider.Name() + onChange = func(endpoint string, from, to gobreaker.State) { + metrics.CircuitBreakerState.WithLabelValues(providerName, endpoint).Set(stateToGaugeValue(to)) + if to == gobreaker.StateOpen { + metrics.CircuitBreakerTrips.WithLabelValues(providerName, endpoint).Inc() + } + } + } + + cbs = NewProviderCircuitBreakers(provider.Name(), *cfg, onChange) + + // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). + for _, path := range provider.BridgedRoutes() { + handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) + // Wrap with circuit breaker middleware (nil cbs passes through) + wrapped := CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler) + mux.Handle(path, wrapped) + } // Add the known provider-specific routes which are bridged (i.e. intercepted and augmented). for _, path := range provider.BridgedRoutes() { - mux.HandleFunc(path, newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer)) + handler := newInterceptionProcessor(provider, recorder, mcpProxy, logger, metrics, tracer) + // Wrap with circuit breaker middleware (nil cbs passes through) + wrapped := CircuitBreakerMiddleware(cbs, metrics, provider.Name())(handler) + mux.Handle(path, wrapped) } // Any requests which passthrough to this will be reverse-proxied to the upstream. diff --git a/circuit_breaker.go b/circuit_breaker.go new file mode 100644 index 0000000..2adbf18 --- /dev/null +++ b/circuit_breaker.go @@ -0,0 +1,182 @@ +package aibridge + +import ( + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/sony/gobreaker/v2" +) + +// CircuitBreakerConfig holds configuration for circuit breakers. +// Fields match gobreaker.Settings for clarity. +type CircuitBreakerConfig struct { + // MaxRequests is the maximum number of requests allowed in half-open state. + MaxRequests uint32 + // Interval is the cyclic period of the closed state for clearing internal counts. + Interval time.Duration + // Timeout is how long the circuit stays open before transitioning to half-open. + Timeout time.Duration + // FailureThreshold is the number of consecutive failures that triggers the circuit to open. + FailureThreshold uint32 + // IsFailure determines if a status code should count as a failure. + // If nil, defaults to 429, 503, and 529 (Anthropic overloaded). + IsFailure func(statusCode int) bool +} + +// DefaultCircuitBreakerConfig returns sensible defaults for circuit breaker configuration. +func DefaultCircuitBreakerConfig() CircuitBreakerConfig { + return CircuitBreakerConfig{ + FailureThreshold: 5, + Interval: 10 * time.Second, + Timeout: 30 * time.Second, + MaxRequests: 3, + IsFailure: DefaultIsFailure, + } +} + +// DefaultIsFailure returns true for status codes that typically indicate +// upstream overload: 429 (Too Many Requests), 503 (Service Unavailable), +// and 529 (Anthropic Overloaded). +func DefaultIsFailure(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests, // 429 + http.StatusServiceUnavailable, // 503 + 529: // Anthropic "Overloaded" + return true + default: + return false + } +} + +// ProviderCircuitBreakers manages per-endpoint circuit breakers for a single provider. +type ProviderCircuitBreakers struct { + provider string + config CircuitBreakerConfig + breakers sync.Map // endpoint -> *gobreaker.CircuitBreaker[struct{}] + onChange func(endpoint string, from, to gobreaker.State) +} + +// NewProviderCircuitBreakers creates circuit breakers for a single provider. +func NewProviderCircuitBreakers(provider string, config CircuitBreakerConfig, onChange func(endpoint string, from, to gobreaker.State)) *ProviderCircuitBreakers { + if config.IsFailure == nil { + config.IsFailure = DefaultIsFailure + } + return &ProviderCircuitBreakers{ + provider: provider, + config: config, + onChange: onChange, + } +} + +// Get returns the circuit breaker for an endpoint, creating it if needed. +func (p *ProviderCircuitBreakers) Get(endpoint string) *gobreaker.CircuitBreaker[struct{}] { + if v, ok := p.breakers.Load(endpoint); ok { + return v.(*gobreaker.CircuitBreaker[struct{}]) + } + + settings := gobreaker.Settings{ + Name: p.provider + ":" + endpoint, + MaxRequests: p.config.MaxRequests, + Interval: p.config.Interval, + Timeout: p.config.Timeout, + ReadyToTrip: func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures >= p.config.FailureThreshold + }, + OnStateChange: func(_ string, from, to gobreaker.State) { + p.onChange(endpoint, from, to) + }, + } + + cb := gobreaker.NewCircuitBreaker[struct{}](settings) + actual, _ := p.breakers.LoadOrStore(endpoint, cb) + return actual.(*gobreaker.CircuitBreaker[struct{}]) +} + +// statusCapturingWriter wraps http.ResponseWriter to capture the status code. +// It also implements http.Flusher to support streaming responses. +type statusCapturingWriter struct { + http.ResponseWriter + statusCode int + headerWritten bool +} + +func (w *statusCapturingWriter) WriteHeader(code int) { + if !w.headerWritten { + w.statusCode = code + w.headerWritten = true + } + w.ResponseWriter.WriteHeader(code) +} + +func (w *statusCapturingWriter) Write(b []byte) (int, error) { + if !w.headerWritten { + w.statusCode = http.StatusOK + w.headerWritten = true + } + return w.ResponseWriter.Write(b) +} + +func (w *statusCapturingWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// Unwrap returns the underlying ResponseWriter for interface checks. +func (w *statusCapturingWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +// CircuitBreakerMiddleware returns middleware that wraps handlers with circuit breaker protection. +// It captures the response status code to determine success/failure without provider-specific logic. +// If cbs is nil, requests pass through without circuit breaker protection. +func CircuitBreakerMiddleware(cbs *ProviderCircuitBreakers, metrics *Metrics, provider string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + // No circuit breaker configured - pass through + if cbs == nil { + return next + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + endpoint := strings.TrimPrefix(r.URL.Path, "/"+provider) + cb := cbs.Get(endpoint) + + // Wrap response writer to capture status code + sw := &statusCapturingWriter{ResponseWriter: w, statusCode: http.StatusOK} + + _, err := cb.Execute(func() (struct{}, error) { + next.ServeHTTP(sw, r) + if cbs.config.IsFailure(sw.statusCode) { + return struct{}{}, fmt.Errorf("upstream error: %d", sw.statusCode) + } + return struct{}{}, nil + }) + + if err != nil && (errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests)) { + if metrics != nil { + metrics.CircuitBreakerRejects.WithLabelValues(provider, endpoint).Inc() + } + http.Error(w, "circuit breaker is open", http.StatusServiceUnavailable) + } + }) + } +} + +// stateToGaugeValue converts gobreaker.State to a gauge value. +// closed=0, half-open=0.5, open=1 +func stateToGaugeValue(s gobreaker.State) float64 { + switch s { + case gobreaker.StateClosed: + return 0 + case gobreaker.StateHalfOpen: + return 0.5 + case gobreaker.StateOpen: + return 1 + default: + return 0 + } +} diff --git a/circuit_breaker_test.go b/circuit_breaker_test.go new file mode 100644 index 0000000..95a7c78 --- /dev/null +++ b/circuit_breaker_test.go @@ -0,0 +1,259 @@ +package aibridge + +import ( + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/sony/gobreaker/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCircuitBreakerMiddleware_TripsOnUpstreamErrors(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + + // Mock upstream that returns 429 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + }) + + // Create circuit breaker with low threshold + cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + }, func(endpoint string, from, to gobreaker.State) {}) + + // Wrap upstream with circuit breaker middleware + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // First 2 requests hit upstream, get 429 + for i := 0; i < 2; i++ { + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(2), upstreamCalls.Load()) + + // Third request should get 503 "circuit breaker is open" without hitting upstream + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Contains(t, string(body), "circuit breaker is open") + assert.Equal(t, int32(2), upstreamCalls.Load()) // No new upstream call + + // Wait for timeout, verify recovery (circuit transitions to half-open) + require.Eventually(t, func() bool { + resp, err = http.Get(server.URL + "/test/v1/messages") + if err != nil { + return false + } + resp.Body.Close() + // Request hit upstream again (half-open state allows probe request) + return upstreamCalls.Load() == 3 + }, 5*time.Second, 25*time.Millisecond) +} + +func TestCircuitBreakerMiddleware_PerEndpointIsolation(t *testing.T) { + t.Parallel() + + chatCalls := atomic.Int32{} + responsesCalls := atomic.Int32{} + + // Mock upstream - /chat returns 429, /responses returns 200 + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/test/v1/chat/completions" { + chatCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + } else { + responsesCalls.Add(1) + w.WriteHeader(http.StatusOK) + } + }) + + cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + }, func(endpoint string, from, to gobreaker.State) {}) + + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // Trip circuit on /chat/completions + resp, err := http.Get(server.URL + "/test/v1/chat/completions") + require.NoError(t, err) + resp.Body.Close() + + // /chat/completions should now be blocked + resp, err = http.Get(server.URL + "/test/v1/chat/completions") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, int32(1), chatCalls.Load()) // Only 1 call, second was blocked + + // /responses should still work + resp, err = http.Get(server.URL + "/test/v1/responses") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(1), responsesCalls.Load()) +} + +func TestCircuitBreakerMiddleware_NotConfigured(t *testing.T) { + t.Parallel() + + var upstreamCalls atomic.Int32 + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upstreamCalls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + }) + + // No circuit breaker configured (nil) + handler := CircuitBreakerMiddleware(nil, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // All requests should pass through even with 429s + for i := 0; i < 10; i++ { + resp, err := http.Get(server.URL + "/test/v1/messages") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + } + assert.Equal(t, int32(10), upstreamCalls.Load()) +} + +func TestCircuitBreakerMiddleware_RecoveryAfterSuccess(t *testing.T) { + t.Parallel() + + var returnError atomic.Bool + returnError.Store(true) + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if returnError.Load() { + w.WriteHeader(http.StatusTooManyRequests) + } else { + w.WriteHeader(http.StatusOK) + } + }) + + cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + FailureThreshold: 2, + Interval: time.Minute, + Timeout: 50 * time.Millisecond, + MaxRequests: 1, + }, func(endpoint string, from, to gobreaker.State) {}) + + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // Trip the circuit + for i := 0; i < 2; i++ { + resp, _ := http.Get(server.URL + "/test/v1/messages") + resp.Body.Close() + } + + // Circuit should be open + resp, _ := http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + resp.Body.Close() + + // Switch upstream to success before we start polling + returnError.Store(false) + + // Wait for timeout (circuit transitions to half-open), then verify recovery + require.Eventually(t, func() bool { + resp, err := http.Get(server.URL + "/test/v1/messages") + if err != nil { + return false + } + defer resp.Body.Close() + // Half-open: request goes through and succeeds + return resp.StatusCode == http.StatusOK + }, 5*time.Second, 25*time.Millisecond) + + // Circuit should be closed now, more requests allowed + resp, _ = http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusOK, resp.StatusCode) + resp.Body.Close() +} + +func TestCircuitBreakerMiddleware_CustomIsFailure(t *testing.T) { + t.Parallel() + + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) // 502 + }) + + // Custom IsFailure that treats 502 as failure + cbs := NewProviderCircuitBreakers("test", CircuitBreakerConfig{ + FailureThreshold: 1, + Interval: time.Minute, + Timeout: time.Minute, + MaxRequests: 1, + IsFailure: func(statusCode int) bool { + return statusCode == http.StatusBadGateway + }, + }, func(endpoint string, from, to gobreaker.State) {}) + + handler := CircuitBreakerMiddleware(cbs, nil, "test")(upstream) + server := httptest.NewServer(handler) + defer server.Close() + + // First request returns 502, trips circuit + resp, _ := http.Get(server.URL + "/test/v1/messages") + resp.Body.Close() + + // Second request should be blocked + resp, _ = http.Get(server.URL + "/test/v1/messages") + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + resp.Body.Close() +} + +func TestDefaultIsFailure(t *testing.T) { + t.Parallel() + + tests := []struct { + statusCode int + isFailure bool + }{ + {http.StatusOK, false}, + {http.StatusBadRequest, false}, + {http.StatusUnauthorized, false}, + {http.StatusTooManyRequests, true}, // 429 + {http.StatusInternalServerError, false}, + {http.StatusBadGateway, false}, + {http.StatusServiceUnavailable, true}, // 503 + {529, true}, // Anthropic Overloaded + } + + for _, tt := range tests { + assert.Equal(t, tt.isFailure, DefaultIsFailure(tt.statusCode), "status code %d", tt.statusCode) + } +} + +func TestStateToGaugeValue(t *testing.T) { + t.Parallel() + + assert.Equal(t, float64(0), stateToGaugeValue(gobreaker.StateClosed)) + assert.Equal(t, float64(0.5), stateToGaugeValue(gobreaker.StateHalfOpen)) + assert.Equal(t, float64(1), stateToGaugeValue(gobreaker.StateOpen)) +} diff --git a/config.go b/config.go index 8dc6f1d..ff1f639 100644 --- a/config.go +++ b/config.go @@ -1,7 +1,8 @@ package aibridge type ProviderConfig struct { - BaseURL, Key string + BaseURL, Key string + CircuitBreaker *CircuitBreakerConfig } type ( diff --git a/go.mod b/go.mod index 9a62089..329c7bf 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/mark3labs/mcp-go v0.38.0 github.com/prometheus/client_golang v1.23.2 + github.com/sony/gobreaker/v2 v2.3.0 github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 diff --git a/go.sum b/go.sum index 385345d..fff1ee3 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/sony/gobreaker/v2 v2.3.0 h1:7VYxZ69QXRQ2Q4eEawHn6eU4FiuwovzJwsUMA03Lu4I= +github.com/sony/gobreaker/v2 v2.3.0/go.mod h1:pTyFJgcZ3h2tdQVLZZruK2C0eoFL1fb/G83wK1ZQl+s= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= diff --git a/interception.go b/interception.go index 46ec7bd..62201aa 100644 --- a/interception.go +++ b/interception.go @@ -45,6 +45,8 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server ctx, span := tracer.Start(r.Context(), "Intercept") defer span.End() + route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) + interceptor, err := p.CreateInterceptor(w, r.WithContext(ctx), tracer) if err != nil { span.SetStatus(codes.Error, fmt.Sprintf("failed to create interceptor: %v", err)) @@ -93,7 +95,6 @@ func newInterceptionProcessor(p Provider, recorder Recorder, mcpProxy mcp.Server return } - route := strings.TrimPrefix(r.URL.Path, fmt.Sprintf("/%s", p.Name())) log := logger.With( slog.F("route", route), slog.F("provider", p.Name()), diff --git a/metrics.go b/metrics.go index 32d5a78..e7fbb5c 100644 --- a/metrics.go +++ b/metrics.go @@ -28,6 +28,11 @@ type Metrics struct { // Tool-related metrics. InjectedToolUseCount *prometheus.CounterVec NonInjectedToolUseCount *prometheus.CounterVec + + // Circuit breaker metrics. + CircuitBreakerState *prometheus.GaugeVec // Current state (0=closed, 0.5=half-open, 1=open) + CircuitBreakerTrips *prometheus.CounterVec // Total times circuit opened + CircuitBreakerRejects *prometheus.CounterVec // Requests rejected due to open circuit } // NewMetrics creates AND registers metrics. It will panic if a collector has already been registered. @@ -102,5 +107,26 @@ func NewMetrics(reg prometheus.Registerer) *Metrics { Name: "total", Help: "The number of times an AI model selected a tool to be invoked by the client.", }, append(baseLabels, "name")), + + // Circuit breaker metrics. + + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerState: promauto.With(reg).NewGaugeVec(prometheus.GaugeOpts{ + Subsystem: "circuit_breaker", + Name: "state", + Help: "Current state of the circuit breaker (0=closed, 0.5=half-open, 1=open).", + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerTrips: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "trips_total", + Help: "Total number of times the circuit breaker has tripped open.", + }, []string{"provider", "endpoint"}), + // Pessimistic cardinality: 2 providers, 5 endpoints = up to 10. + CircuitBreakerRejects: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ + Subsystem: "circuit_breaker", + Name: "rejects_total", + Help: "Total number of requests rejected due to open circuit breaker.", + }, []string{"provider", "endpoint"}), } } diff --git a/provider.go b/provider.go index 20f8f52..a2d1d87 100644 --- a/provider.go +++ b/provider.go @@ -33,4 +33,7 @@ type Provider interface { AuthHeader() string // InjectAuthHeader allows [Provider]s to set its authentication header. InjectAuthHeader(*http.Header) + + // CircuitBreakerConfig returns the circuit breaker configuration for the provider. + CircuitBreakerConfig() *CircuitBreakerConfig } diff --git a/provider_anthropic.go b/provider_anthropic.go index fb5d10b..d07502e 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -108,6 +108,10 @@ func (p *AnthropicProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), p.cfg.Key) } +func (p *AnthropicProvider) CircuitBreakerConfig() *CircuitBreakerConfig { + return p.cfg.CircuitBreaker +} + func getAnthropicErrorResponse(err error) *AnthropicErrorResponse { var apierr *anthropic.Error if !errors.As(err, &apierr) { diff --git a/provider_openai.go b/provider_openai.go index 68777e7..65288f6 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -17,7 +17,8 @@ var _ Provider = &OpenAIProvider{} // OpenAIProvider allows for interactions with the OpenAI API. type OpenAIProvider struct { - baseURL, key string + baseURL, key string + circuitBreaker *CircuitBreakerConfig } const ( @@ -36,8 +37,9 @@ func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider { } return &OpenAIProvider{ - baseURL: cfg.BaseURL, - key: cfg.Key, + baseURL: cfg.BaseURL, + key: cfg.Key, + circuitBreaker: cfg.CircuitBreaker, } } @@ -108,3 +110,7 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), "Bearer "+p.key) } + +func (p *OpenAIProvider) CircuitBreakerConfig() *CircuitBreakerConfig { + return p.circuitBreaker +}