@@ -11,20 +11,24 @@ import (
1111 "github.com/mark3labs/mcp-go/mcp"
1212)
1313
14+ // OnRegisterSessionHookFunc is a hook that will be called when a new session is registered.
15+ type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession)
16+
17+
1418// BeforeAnyHookFunc is a function that is called after the request is
1519// parsed but before the method is called.
16- type BeforeAnyHookFunc func(id any, method mcp.MCPMethod, message any)
20+ type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any)
1721
1822// OnSuccessHookFunc is a hook that will be called after the request
1923// successfully generates a result, but before the result is sent to the client.
20- type OnSuccessHookFunc func(id any, method mcp.MCPMethod, message any, result any)
24+ type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any)
2125
2226// OnErrorHookFunc is a hook that will be called when an error occurs,
2327// either during the request parsing or the method execution.
2428//
2529// Example usage:
2630// ```
27- // hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
31+ // hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
2832// // Check for specific error types using errors.Is
2933// if errors.Is(err, ErrUnsupported) {
3034// // Handle capability not supported errors
@@ -51,14 +55,15 @@ type OnSuccessHookFunc func(id any, method mcp.MCPMethod, message any, result an
5155// log.Printf("Tool not found: %v", err)
5256// }
5357// })
54- type OnErrorHookFunc func(id any, method mcp.MCPMethod, message any, err error)
58+ type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error)
5559
5660{{range .}}
57- type OnBefore{{.HookName}}Func func(id any, message *mcp.{{.ParamType}})
58- type OnAfter{{.HookName}}Func func(id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}})
61+ type OnBefore{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}})
62+ type OnAfter{{.HookName}}Func func(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}})
5963{{end}}
6064
6165type Hooks struct {
66+ OnRegisterSession []OnRegisterSessionHookFunc
6267 OnBeforeAny []BeforeAnyHookFunc
6368 OnSuccess []OnSuccessHookFunc
6469 OnError []OnErrorHookFunc
@@ -87,7 +92,7 @@ func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) {
8792//
8893// // Register hook to capture and inspect errors
8994// hooks := &Hooks{}
90- // hooks.AddOnError(func(id any, method mcp.MCPMethod, message any, err error) {
95+ // hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
9196// // For capability-related errors
9297// if errors.Is(err, ErrUnsupported) {
9398// // Handle capability not supported
@@ -124,21 +129,21 @@ func (c *Hooks) AddOnError(hook OnErrorHookFunc) {
124129 c.OnError = append(c.OnError, hook)
125130}
126131
127- func (c *Hooks) beforeAny(id any, method mcp.MCPMethod, message any) {
132+ func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) {
128133 if c == nil {
129134 return
130135 }
131136 for _, hook := range c.OnBeforeAny {
132- hook(id, method, message)
137+ hook(ctx, id, method, message)
133138 }
134139}
135140
136- func (c *Hooks) onSuccess(id any, method mcp.MCPMethod, message any, result any) {
141+ func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) {
137142 if c == nil {
138143 return
139144 }
140145 for _, hook := range c.OnSuccess {
141- hook(id, method, message, result)
146+ hook(ctx, id, method, message, result)
142147 }
143148}
144149
@@ -156,15 +161,28 @@ func (c *Hooks) onSuccess(id any, method mcp.MCPMethod, message any, result any)
156161// - ErrResourceNotFound: When a resource is not found
157162// - ErrPromptNotFound: When a prompt is not found
158163// - ErrToolNotFound: When a tool is not found
159- func (c *Hooks) onError(id any, method mcp.MCPMethod, message any, err error) {
164+ func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
160165 if c == nil {
161166 return
162167 }
163168 for _, hook := range c.OnError {
164- hook(id, method, message, err)
169+ hook(ctx, id, method, message, err)
165170 }
166171}
167172
173+ func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) {
174+ c.OnRegisterSession = append(c.OnRegisterSession, hook)
175+ }
176+
177+ func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) {
178+ if c == nil {
179+ return
180+ }
181+ for _, hook := range c.OnRegisterSession {
182+ hook(ctx, session)
183+ }
184+ }
185+
168186{{- range .}}
169187func (c *Hooks) AddBefore{{.HookName}}(hook OnBefore{{.HookName}}Func) {
170188 c.OnBefore{{.HookName}} = append(c.OnBefore{{.HookName}}, hook)
@@ -174,23 +192,23 @@ func (c *Hooks) AddAfter{{.HookName}}(hook OnAfter{{.HookName}}Func) {
174192 c.OnAfter{{.HookName}} = append(c.OnAfter{{.HookName}}, hook)
175193}
176194
177- func (c *Hooks) before{{.HookName}}(id any, message *mcp.{{.ParamType}}) {
178- c.beforeAny(id, mcp.{{.MethodName}}, message)
195+ func (c *Hooks) before{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}) {
196+ c.beforeAny(ctx, id, mcp.{{.MethodName}}, message)
179197 if c == nil {
180198 return
181199 }
182200 for _, hook := range c.OnBefore{{.HookName}} {
183- hook(id, message)
201+ hook(ctx, id, message)
184202 }
185203}
186204
187- func (c *Hooks) after{{.HookName}}(id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) {
188- c.onSuccess(id, mcp.{{.MethodName}}, message, result)
205+ func (c *Hooks) after{{.HookName}}(ctx context.Context, id any, message *mcp.{{.ParamType}}, result *mcp.{{.ResultType}}) {
206+ c.onSuccess(ctx, id, mcp.{{.MethodName}}, message, result)
189207 if c == nil {
190208 return
191209 }
192210 for _, hook := range c.OnAfter{{.HookName}} {
193- hook(id, message, result)
211+ hook(ctx, id, message, result)
194212 }
195213}
196214{{- end -}}
0 commit comments