Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions micro/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,39 @@ func ExampleContextHandler() {
defer srv.Stop()
}

func ExampleMiddleware() {
middle := func(h micro.Handler) micro.Handler {
return micro.HandlerFunc(func(r micro.Request) {
log.Println("in middleware")
r.SetHeader("request-id", "1234")
h.Handle(r)
})
}

nc, err := nats.Connect("127.0.0.1:4222")
if err != nil {
log.Fatal(err)
}
defer nc.Close()

handler := func(req micro.Request) {
log.Println(req.Headers().Get("request-id"))
req.Respond(req.Data())
}
config := micro.Config{
Name: "MiddlewareExample",
Version: "0.1.0",
Endpoint: &micro.EndpointConfig{
Subject: "middleware",
Handler: micro.HandlerFunc(handler),
},
Middleware: []micro.Middleware{middle},
}

srv, _ := micro.AddService(nc, config)
defer srv.Stop()
}

func ExampleControlSubject() {

// subject used to get PING from all services
Expand Down
14 changes: 14 additions & 0 deletions micro/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ type (
// Headers returns request headers.
Headers() Headers

// SetHeader sets a specific header
SetHeader(key, value string)

// Subject returns underlying NATS message subject.
Subject() string
}
Expand Down Expand Up @@ -181,6 +184,17 @@ func (r *request) Headers() Headers {
return Headers(r.msg.Header)
}

func (r *request) SetHeader(key string, value string) {
if len(r.msg.Header) != 0 {
r.msg.Header.Set(key, value)
return
}
headers := nats.Header{
key: []string{value},
}
r.msg.Header = headers
}

// Subject returns underlying NATS message subject.
func (r *request) Subject() string {
return r.msg.Subject
Expand Down
8 changes: 8 additions & 0 deletions micro/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ type (
Metadata map[string]string `json:"metadata"`
}

Middleware func(Handler) Handler

// Stats is the type returned by STATS monitoring endpoint.
// It contains stats of all registered endpoints.
Stats struct {
Expand Down Expand Up @@ -195,6 +197,9 @@ type (

// ErrorHandler is invoked on any nats-related service error.
ErrorHandler ErrHandler

// Middleware is a slice of handlers that should be run on every request
Middleware []Middleware
}

EndpointConfig struct {
Expand Down Expand Up @@ -400,6 +405,9 @@ func (s *service) AddEndpoint(name string, handler Handler, opts ...EndpointOpt)
subject = options.subject
}
queueGroup := queueGroupName(options.queueGroup, s.Config.QueueGroup)
for _, v := range s.Middleware {
handler = v(handler)
}
return addEndpoint(s, name, subject, handler, options.metadata, queueGroup)
}

Expand Down
58 changes: 58 additions & 0 deletions micro/test/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,64 @@ func TestContextHandler(t *testing.T) {
}
}

func TestMiddleware(t *testing.T) {
s := RunServerOnPort(-1)
defer s.Shutdown()

nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Expected to connect to server, got %v", err)
}
defer nc.Close()

handler := func(req micro.Request) {
headers := micro.Headers{
"request-id": []string{req.Headers().Get("request-id")},
"other-value": []string{req.Headers().Get("other-value")},
}
req.Respond(req.Data(), micro.WithHeaders(headers))
}

middle := func(h micro.Handler) micro.Handler {
return micro.HandlerFunc(func(r micro.Request) {
r.SetHeader("request-id", "1234")
r.SetHeader("other-value", "testing")
h.Handle(r)
})
}

config := micro.Config{
Name: "test_service",
Version: "0.1.0",
Endpoint: &micro.EndpointConfig{
Subject: "test.func",
Handler: micro.HandlerFunc(handler),
},
Middleware: []micro.Middleware{middle},
}

srv, err := micro.AddService(nc, config)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
defer srv.Stop()

resp, err := nc.Request("test.func", []byte("test"), 1*time.Second)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
if string(resp.Data) != "test" {
t.Fatalf("Invalid response; want: %q; got: %q", "val", string(resp.Data))
}
if resp.Header.Get("request-id") != "1234" {
t.Fatalf("Invalid response; want: %v; got %v", "1234", resp.Header.Get("request-id"))
}
if resp.Header.Get("other-value") != "testing" {
t.Fatalf("Invalid response; want: %v; got %v", "testing", resp.Header.Get("other-value"))
}

}

func TestAddEndpoint_Concurrency(t *testing.T) {
s := RunServerOnPort(-1)
defer s.Shutdown()
Expand Down