diff --git a/micro/example_test.go b/micro/example_test.go index a1fcd5ab6..57ee6534d 100644 --- a/micro/example_test.go +++ b/micro/example_test.go @@ -306,6 +306,39 @@ func ExampleContextHandler() { defer srv.Stop() } +func ExampleMiddleware() { + middle := func(h micro.Handler) micro.Handler { + return micro.HandlerFunc(func(r micro.Request) { + log.Println("in middleware") + r.SetHeader("request-id", "1234") + h.Handle(r) + }) + } + + nc, err := nats.Connect("127.0.0.1:4222") + if err != nil { + log.Fatal(err) + } + defer nc.Close() + + handler := func(req micro.Request) { + log.Println(req.Headers().Get("request-id")) + req.Respond(req.Data()) + } + config := micro.Config{ + Name: "MiddlewareExample", + Version: "0.1.0", + Endpoint: µ.EndpointConfig{ + Subject: "middleware", + Handler: micro.HandlerFunc(handler), + }, + Middleware: []micro.Middleware{middle}, + } + + srv, _ := micro.AddService(nc, config) + defer srv.Stop() +} + func ExampleControlSubject() { // subject used to get PING from all services diff --git a/micro/request.go b/micro/request.go index 380f4945c..51fe9c924 100644 --- a/micro/request.go +++ b/micro/request.go @@ -56,6 +56,9 @@ type ( // Headers returns request headers. Headers() Headers + // SetHeader sets a specific header + SetHeader(key, value string) + // Subject returns underlying NATS message subject. Subject() string } @@ -181,6 +184,17 @@ func (r *request) Headers() Headers { return Headers(r.msg.Header) } +func (r *request) SetHeader(key string, value string) { + if len(r.msg.Header) != 0 { + r.msg.Header.Set(key, value) + return + } + headers := nats.Header{ + key: []string{value}, + } + r.msg.Header = headers +} + // Subject returns underlying NATS message subject. func (r *request) Subject() string { return r.msg.Subject diff --git a/micro/service.go b/micro/service.go index d6a8c82f4..5555eee64 100644 --- a/micro/service.go +++ b/micro/service.go @@ -101,6 +101,8 @@ type ( Metadata map[string]string `json:"metadata"` } + Middleware func(Handler) Handler + // Stats is the type returned by STATS monitoring endpoint. // It contains stats of all registered endpoints. Stats struct { @@ -195,6 +197,9 @@ type ( // ErrorHandler is invoked on any nats-related service error. ErrorHandler ErrHandler + + // Middleware is a slice of handlers that should be run on every request + Middleware []Middleware } EndpointConfig struct { @@ -400,6 +405,9 @@ func (s *service) AddEndpoint(name string, handler Handler, opts ...EndpointOpt) subject = options.subject } queueGroup := queueGroupName(options.queueGroup, s.Config.QueueGroup) + for _, v := range s.Middleware { + handler = v(handler) + } return addEndpoint(s, name, subject, handler, options.metadata, queueGroup) } diff --git a/micro/test/service_test.go b/micro/test/service_test.go index b9e004946..e0f879a3c 100644 --- a/micro/test/service_test.go +++ b/micro/test/service_test.go @@ -1018,6 +1018,64 @@ func TestContextHandler(t *testing.T) { } } +func TestMiddleware(t *testing.T) { + s := RunServerOnPort(-1) + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Expected to connect to server, got %v", err) + } + defer nc.Close() + + handler := func(req micro.Request) { + headers := micro.Headers{ + "request-id": []string{req.Headers().Get("request-id")}, + "other-value": []string{req.Headers().Get("other-value")}, + } + req.Respond(req.Data(), micro.WithHeaders(headers)) + } + + middle := func(h micro.Handler) micro.Handler { + return micro.HandlerFunc(func(r micro.Request) { + r.SetHeader("request-id", "1234") + r.SetHeader("other-value", "testing") + h.Handle(r) + }) + } + + config := micro.Config{ + Name: "test_service", + Version: "0.1.0", + Endpoint: µ.EndpointConfig{ + Subject: "test.func", + Handler: micro.HandlerFunc(handler), + }, + Middleware: []micro.Middleware{middle}, + } + + srv, err := micro.AddService(nc, config) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer srv.Stop() + + resp, err := nc.Request("test.func", []byte("test"), 1*time.Second) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + if string(resp.Data) != "test" { + t.Fatalf("Invalid response; want: %q; got: %q", "val", string(resp.Data)) + } + if resp.Header.Get("request-id") != "1234" { + t.Fatalf("Invalid response; want: %v; got %v", "1234", resp.Header.Get("request-id")) + } + if resp.Header.Get("other-value") != "testing" { + t.Fatalf("Invalid response; want: %v; got %v", "testing", resp.Header.Get("other-value")) + } + +} + func TestAddEndpoint_Concurrency(t *testing.T) { s := RunServerOnPort(-1) defer s.Shutdown()