diff --git a/roundrobin/rr.go b/roundrobin/rr.go index 631a97af..fe9e0295 100644 --- a/roundrobin/rr.go +++ b/roundrobin/rr.go @@ -46,17 +46,26 @@ func RoundRobinRequestRewriteListener(rrl RequestRewriteListener) LBOption { } } +// RoundRobinPreRequestRewriteListener is a functional argument that sets error handler of the server +func RoundRobinPreRequestRewriteListener(rrl RequestRewriteListener) LBOption { + return func(s *RoundRobin) error { + s.requestPreRewriteListener = rrl + return nil + } +} + // RoundRobin implements dynamic weighted round robin load balancer http handler type RoundRobin struct { mutex *sync.Mutex next http.Handler errHandler utils.ErrorHandler // Current index (starts from -1) - index int - servers []*server - currentWeight int - stickySession *StickySession - requestRewriteListener RequestRewriteListener + index int + servers []*server + currentWeight int + stickySession *StickySession + requestPreRewriteListener RequestRewriteListener + requestRewriteListener RequestRewriteListener log *log.Logger } @@ -107,6 +116,11 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { // make shallow copy of request before chaning anything to avoid side effects newReq := *req + // Emit event to a listener if one exists + if r.requestPreRewriteListener != nil { + r.requestPreRewriteListener(req, &newReq) + } + stuck := false if r.stickySession != nil { cookieURL, present, err := r.stickySession.GetBackend(&newReq, r.Servers()) @@ -116,6 +130,7 @@ func (r *RoundRobin) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if present { + r.stickySession.StickBackend(cookieURL, &w) newReq.URL = cookieURL stuck = true } diff --git a/roundrobin/rr_test.go b/roundrobin/rr_test.go index 2f2841ca..c2df0a88 100644 --- a/roundrobin/rr_test.go +++ b/roundrobin/rr_test.go @@ -228,6 +228,23 @@ func TestRequestRewriteListener(t *testing.T) { assert.NotNil(t, lb.requestRewriteListener) } +func TestPreRequestRewriteListener(t *testing.T) { + a := testutils.NewResponder("a") + defer a.Close() + + b := testutils.NewResponder("b") + defer b.Close() + + fwd, err := forward.New() + require.NoError(t, err) + + lb, err := New(fwd, + RoundRobinPreRequestRewriteListener(func(oldReq *http.Request, newReq *http.Request) {})) + require.NoError(t, err) + + assert.NotNil(t, lb.requestPreRewriteListener) +} + func seq(t *testing.T, url string, repeat int) []string { var out []string for i := 0; i < repeat; i++ {