diff --git a/README.md b/README.md index e69de29..ef383bd 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,80 @@ +# CATE Webhook Hook Server Examples + +Example webhook servers for the CATE (Contextual Access for Tool Execution) hook system. These servers implement the webhook protocol for the three hook points: `tool.access`, `tool.pre`, and `tool.post`. + +## Examples + +### Full Server (with Web UI) + +A comprehensive hook server with a web dashboard for configuration. Includes all features: rules, PII redaction, A/B testing, and tool catalog integration. + +```bash +go run ./examples/full_server +# Open http://localhost:8888/ for the dashboard +``` + +[Full documentation](examples/full_server/README.md) + +### Basic Rules + +A configurable test server with YAML-based rule configuration and hot-reload. + +```bash +go run ./examples/basic_rules -config ./examples/basic_rules/example-config.yaml +``` + +[Full documentation](examples/basic_rules/README.md) + +### Content Filter + +Demonstrates blocking users, filtering toolkits, and rejecting requests based on content matching. + +```bash +go run ./examples/content_filter +``` + +[Full documentation](examples/content_filter/README.md) + +### PII Redactor + +Demonstrates detecting and redacting personally identifiable information (emails, IPs, SSNs, phone numbers, credit cards, dates of birth) from tool outputs. + +```bash +go run ./examples/pii_redactor +``` + +[Full documentation](examples/pii_redactor/README.md) + +### A/B Testing + +Demonstrates routing tool executions to different server variants for A/B testing and canary deployments. 
+ +```bash +go run ./examples/ab_testing +``` + +[Full documentation](examples/ab_testing/README.md) + +## Hook Points + +All examples implement the three CATE webhook hook points: + +| Hook Point | Endpoint | Purpose | +| ------------- | -------- | ------------------------------------------ | +| `tool.access` | `/access`| Control which tools a user can see | +| `tool.pre` | `/pre` | Validate/modify requests before execution | +| `tool.post` | `/post` | Filter/modify responses after execution | + +Plus a health check at `/health`. + +## Response Codes + +| Code | Meaning | +| --------------------- | ------------------------- | +| `OK` | Allow / proceed | +| `CHECK_FAILED` | Block / deny | +| `RATE_LIMIT_EXCEEDED` | Rate limit exceeded | + +## Schema + +The webhook request/response types are generated from the [CATE webhook schema](https://github.com/ArcadeAI/schemas/blob/main/logical_extensions/http/1.0/schema.yaml) using oapi-codegen. See `pkg/server/` for the generated types. diff --git a/examples/ab_testing/README.md b/examples/ab_testing/README.md new file mode 100644 index 0000000..1eaa79b --- /dev/null +++ b/examples/ab_testing/README.md @@ -0,0 +1,154 @@ +# A/B Testing Example + +A minimal CATE webhook server that demonstrates how to perform A/B testing and canary deployments by routing tool executions to different server variants. + +## What It Shows + +- **Pre-execution hook**: Override server routing to direct requests to different tool versions +- **Deterministic assignment**: Hash-based variant selection ensures consistent routing +- **Sticky sessions**: Same user always gets the same variant +- **Traffic splitting**: Configurable weight distribution between variants + +## Quick Start + +```bash +go run ./examples/ab_testing -port 8890 +``` + +## How It Works + +The server intercepts tool execution requests at the **pre-hook** and: + +1. Checks if there's an active experiment for the requested tool +2. 
Computes a deterministic hash of the user ID + experiment name +3. Uses the hash to assign a variant based on configured weights +4. Overrides the `server` field in the response to route to the variant's server + +### Experiment Configuration + +Two example experiments are pre-configured: + +**A/B Test** - `search-v2-rollout`: +- Matches: `Search.WebSearch` +- Control (80%): Routes to `http://search-v1:8080` +- Treatment (20%): Routes to `http://search-v2:8080` + +**Canary Deploy** - `email-canary`: +- Matches: `Email.*` (all email tools) +- Stable (95%): Routes to `http://email-stable:8080` +- Canary (5%): Routes to `http://email-canary:8080` + +## Testing + +### A/B Test Routing + +```bash +# User "alice" will be assigned to one variant (deterministic) +curl -X POST http://localhost:8890/pre \ + -H "Content-Type: application/json" \ + -d '{ + "execution_id": "exec-1", + "tool": {"name": "WebSearch", "toolkit": "Search", "version": "1.0"}, + "inputs": {"query": "test search"}, + "context": {"user_id": "alice"} + }' + +# Same user always gets the same variant +curl -X POST http://localhost:8890/pre \ + -H "Content-Type: application/json" \ + -d '{ + "execution_id": "exec-2", + "tool": {"name": "WebSearch", "toolkit": "Search", "version": "1.0"}, + "inputs": {"query": "another search"}, + "context": {"user_id": "alice"} + }' + +# Different user may get a different variant +curl -X POST http://localhost:8890/pre \ + -H "Content-Type: application/json" \ + -d '{ + "execution_id": "exec-3", + "tool": {"name": "WebSearch", "toolkit": "Search", "version": "1.0"}, + "inputs": {"query": "test search"}, + "context": {"user_id": "bob"} + }' +``` + +Expected response (server override): +```json +{ + "code": "OK", + "override": { + "server": { + "name": "search-v2", + "uri": "http://search-v2:8080", + "type": "arcade" + } + } +} +``` + +### Canary Deploy + +```bash +# Email tool - most users routed to stable +curl -X POST http://localhost:8890/pre \ + -H "Content-Type: 
application/json" \ + -d '{ + "execution_id": "exec-4", + "tool": {"name": "sendEmail", "toolkit": "Email", "version": "1.0"}, + "inputs": {"to": "user@example.com", "body": "Hello"}, + "context": {"user_id": "user-123"} + }' +``` + +### Non-Experiment Tool + +Tools without experiments pass through unchanged: + +```bash +curl -X POST http://localhost:8890/pre \ + -H "Content-Type: application/json" \ + -d '{ + "execution_id": "exec-5", + "tool": {"name": "calculate", "toolkit": "Math", "version": "1.0"}, + "inputs": {"expression": "2+2"}, + "context": {"user_id": "user-1"} + }' +``` + +Expected response (no override): +```json +{ + "code": "OK" +} +``` + +### View Assignments + +```bash +curl http://localhost:8890/_assignments +``` + +## Key Concepts + +### Deterministic Assignment + +The variant assignment uses SHA-256 hashing of `experiment_name:user_id` to produce a deterministic bucket (0-99). This means: + +- No external database needed for assignment tracking +- Same user always gets the same variant +- Distribution approximates the configured weights over many users + +### Canary Deployment Pattern + +To do a canary deployment: + +1. Start with a 99/1 or 95/5 split (stable/canary) +2. Monitor error rates for the canary variant +3. Gradually increase canary weight: 5% -> 10% -> 25% -> 50% -> 100% +4. If issues arise, set canary weight to 0% to roll back + +### Server Routing Override + +The A/B testing works by overriding the `server` field in the pre-hook response. The engine then routes the tool execution to the specified server instead of the default one. diff --git a/examples/ab_testing/main.go b/examples/ab_testing/main.go new file mode 100644 index 0000000..ea4f2f7 --- /dev/null +++ b/examples/ab_testing/main.go @@ -0,0 +1,331 @@ +// ab_testing is a minimal CATE webhook server that demonstrates how to +// perform A/B testing and canary deployments by routing tool executions +// to different server variants. 
+//
+// It uses the pre-execution hook to intercept tool requests and override
+// the server routing based on experiment configuration and user assignment.
+//
+// Usage:
+//
+//	go run ./examples/ab_testing -port 8890
+package main
+
+import (
+	"crypto/sha256"
+	"encoding/binary"
+	"flag"
+	"fmt"
+	"log"
+	"net/http"
+	"regexp"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/gin-gonic/gin"
+
+	"github.com/ArcadeAI/logical-extensions-examples/pkg/server"
+)
+
+// =============================================================================
+// A/B Testing Configuration
+// =============================================================================
+
+// Variant represents one arm of an experiment: a traffic share plus the
+// server that requests assigned to this arm are routed to.
+type Variant struct {
+	Name       string // e.g., "control", "treatment"
+	Weight     int    // Percentage of traffic (0-100); assignVariant sums weights cumulatively against a 0-99 bucket
+	ServerName string // Server name for routing
+	ServerURI  string // Server URI for routing
+	ServerType string // "arcade" or "mcp"
+}
+
+// Experiment defines an A/B or canary test. Toolkit and Tool are matched
+// with matchPattern, so "" or "*" matches anything and a "*" inside the
+// pattern acts as a glob wildcard.
+type Experiment struct {
+	Name     string    // Experiment identifier; also part of the assignment hash key
+	Toolkit  string    // Pattern to match toolkit
+	Tool     string    // Pattern to match tool
+	Variants []Variant // Arms of the experiment, walked in order when assigning
+}
+
+// Assignment records which variant a user was assigned to. It is the JSON
+// shape returned by the /_assignments admin endpoint.
+type Assignment struct {
+	ExperimentName string    `json:"experiment"`
+	VariantName    string    `json:"variant"`
+	UserID         string    `json:"user_id"`
+	AssignedAt     time.Time `json:"assigned_at"`
+}
+
+// =============================================================================
+// A/B Testing Server
+// =============================================================================
+
+// ABTestServer implements the CATE webhook ServerInterface.
+// mu guards assignments; experiments is only read after construction.
+type ABTestServer struct {
+	experiments []Experiment
+	mu          sync.RWMutex
+	assignments map[string]*Assignment // key: "experiment:user_id"
+}
+
+// NewABTestServer creates a server with example experiments.
+func NewABTestServer() *ABTestServer {
+	// A/B test: 80/20 split between two versions of a search tool.
+	searchRollout := Experiment{
+		Name:    "search-v2-rollout",
+		Toolkit: "Search",
+		Tool:    "WebSearch",
+		Variants: []Variant{
+			{Name: "control", Weight: 80, ServerName: "search-v1", ServerURI: "http://search-v1:8080", ServerType: "arcade"},
+			{Name: "treatment", Weight: 20, ServerName: "search-v2", ServerURI: "http://search-v2:8080", ServerType: "arcade"},
+		},
+	}
+
+	// Canary: 95/5 split for a new version of every tool in the Email toolkit.
+	emailCanary := Experiment{
+		Name:    "email-canary",
+		Toolkit: "Email",
+		Tool:    "*",
+		Variants: []Variant{
+			{Name: "stable", Weight: 95, ServerName: "email-stable", ServerURI: "http://email-stable:8080", ServerType: "arcade"},
+			{Name: "canary", Weight: 5, ServerName: "email-canary", ServerURI: "http://email-canary:8080", ServerType: "arcade"},
+		},
+	}
+
+	srv := &ABTestServer{assignments: make(map[string]*Assignment)}
+	srv.experiments = []Experiment{searchRollout, emailCanary}
+	return srv
+}
+
+// HealthCheck answers the health endpoint with a "healthy" status.
+func (s *ABTestServer) HealthCheck(c *gin.Context) {
+	healthy := server.Healthy
+	c.JSON(http.StatusOK, server.HealthResponse{Status: &healthy})
+}
+
+// AccessHook returns an empty result: A/B testing never changes which
+// tools a user can see.
+func (s *ABTestServer) AccessHook(c *gin.Context) {
+	var unrestricted server.AccessHookResult
+	c.JSON(http.StatusOK, unrestricted)
+}
+
+// PreHook is where A/B testing routing happens.
+// It intercepts tool execution requests and overrides the server routing
+// to direct the request to the assigned variant's server.
+func (s *ABTestServer) PreHook(c *gin.Context) {
+	var req server.PreHookRequest
+	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
+		c.JSON(http.StatusBadRequest, server.ErrorResponse{
+			Error: strPtr("invalid request: " + bindErr.Error()),
+		})
+		return
+	}
+
+	// A missing user ID is treated as the empty string.
+	var userID string
+	if id := req.Context.UserId; id != nil {
+		userID = *id
+	}
+
+	// Without a matching experiment the request passes through untouched.
+	experiment := s.findExperiment(req.Tool.Toolkit, req.Tool.Name)
+	if experiment == nil {
+		c.JSON(http.StatusOK, server.PreHookResult{Code: server.OK})
+		return
+	}
+
+	// Sticky assignment: the same user always lands on the same variant.
+	variant := s.assignVariant(experiment, userID)
+	if variant == nil {
+		c.JSON(http.StatusOK, server.PreHookResult{Code: server.OK})
+		return
+	}
+
+	log.Printf("[A/B] %s.%s -> experiment=%q user=%q variant=%q server=%s",
+		req.Tool.Toolkit, req.Tool.Name, experiment.Name, userID, variant.Name, variant.ServerURI)
+
+	// Point the execution at the variant's server instead of the default one.
+	target := server.ServerInfo{
+		Name: variant.ServerName,
+		Uri:  variant.ServerURI,
+		Type: server.ServerInfoType(variant.ServerType),
+	}
+	c.JSON(http.StatusOK, server.PreHookResult{
+		Code:     server.OK,
+		Override: &server.PreHookOverride{Server: &target},
+	})
+}
+
+// PostHook always answers OK; routing decisions are made only in the pre-hook.
+func (s *ABTestServer) PostHook(c *gin.Context) {
+	c.JSON(http.StatusOK, server.PostHookResult{Code: server.OK})
+}
+
+// findExperiment returns the first configured experiment whose toolkit and
+// tool patterns both match, or nil when none does.
+func (s *ABTestServer) findExperiment(toolkit, tool string) *Experiment {
+	for idx := range s.experiments {
+		candidate := &s.experiments[idx]
+		if !matchPattern(candidate.Toolkit, toolkit) {
+			continue
+		}
+		if !matchPattern(candidate.Tool, tool) {
+			continue
+		}
+		return candidate
+	}
+	return nil
+}
+
+// assignVariant uses a deterministic hash to assign a user to a variant.
+// This ensures the same user always gets the same variant (sticky assignment),
+// which is critical for consistent A/B test results.
+//
+// The bucket (0-99) comes from the first four bytes of
+// SHA-256("experiment:user_id"). Variants are walked in order and the first
+// one whose cumulative weight exceeds the bucket wins; if the weights sum to
+// less than 100, the last variant absorbs the remainder. Returns nil only
+// when the experiment has no variants.
+func (s *ABTestServer) assignVariant(exp *Experiment, userID string) *Variant {
+	if len(exp.Variants) == 0 {
+		return nil
+	}
+
+	cacheKey := fmt.Sprintf("%s:%s", exp.Name, userID)
+
+	// Fast path: reuse a cached assignment while its variant still exists.
+	var cachedName string
+	s.mu.RLock()
+	cached, found := s.assignments[cacheKey]
+	if found {
+		cachedName = cached.VariantName
+	}
+	s.mu.RUnlock()
+	if found {
+		for i := range exp.Variants {
+			if exp.Variants[i].Name == cachedName {
+				return &exp.Variants[i]
+			}
+		}
+	}
+
+	// Compute a deterministic bucket from the user ID and experiment name.
+	// The modulo is taken while the value is still a uint32: converting to
+	// int first can wrap negative on 32-bit platforms, which would yield a
+	// negative bucket and always select the first variant.
+	hash := sha256.Sum256([]byte(cacheKey))
+	bucket := int(binary.BigEndian.Uint32(hash[:4]) % 100)
+
+	// Default to the last variant, then walk the cumulative weight ranges.
+	selected := &exp.Variants[len(exp.Variants)-1]
+	cumulative := 0
+	for i := range exp.Variants {
+		cumulative += exp.Variants[i].Weight
+		if bucket < cumulative {
+			selected = &exp.Variants[i]
+			break
+		}
+	}
+
+	// Cache the assignment so it can be listed via the admin endpoint.
+	s.mu.Lock()
+	s.assignments[cacheKey] = &Assignment{
+		ExperimentName: exp.Name,
+		VariantName:    selected.Name,
+		UserID:         userID,
+		AssignedAt:     time.Now(),
+	}
+	s.mu.Unlock()
+
+	return selected
+}
+
+// =============================================================================
+// Helpers
+// =============================================================================
+
+// matchPattern reports whether value matches pattern. An empty pattern or
+// "*" matches everything; a pattern containing "*" is treated as a glob in
+// which each "*" matches any run of characters; anything else must be equal.
+func matchPattern(pattern, value string) bool {
+	if pattern == "" || pattern == "*" {
+		return true
+	}
+	if strings.Contains(pattern, "*") {
+		// Quote everything, then turn the escaped "*" back into ".*".
+		regexStr := "^" + strings.ReplaceAll(regexp.QuoteMeta(pattern), "\\*", ".*") + "$"
+		re, err := regexp.Compile(regexStr)
+		if err != nil {
+			return false
+		}
+		return re.MatchString(value)
+	}
+	return pattern == value
+}
+
+// strPtr returns a pointer to a copy of s, for optional string fields.
+func strPtr(s string) *string {
+	return &s
+}
+
+// =============================================================================
+// Main
+// =============================================================================
+
+// main parses the -port flag, registers the webhook handlers plus the
+// /_assignments admin endpoint, prints a startup banner and serves until
+// the listener fails.
+func main() {
+	port := flag.Int("port", 8890, "Port to listen on")
+	flag.Parse()
+
+	srv := NewABTestServer()
+
+	gin.SetMode(gin.ReleaseMode)
+	router := gin.New()
+	router.Use(gin.Recovery())
+
+	server.RegisterHandlers(router, srv)
+
+	// Admin endpoint to view assignments
+	router.GET("/_assignments", func(c *gin.Context) {
+		srv.mu.RLock()
+		defer srv.mu.RUnlock()
+		assignments := make([]*Assignment, 0, len(srv.assignments))
+		for _, a := range srv.assignments {
+			assignments = append(assignments, a)
+		}
+		c.JSON(http.StatusOK, gin.H{
+			"count":       len(assignments),
+			"assignments": assignments,
+		})
+	})
+
+	fmt.Println(strings.Repeat("=", 50))
+	fmt.Println(" A/B Testing Example")
+	fmt.Println(strings.Repeat("=", 50))
+	fmt.Printf(" Port: %d\n", *port)
+	fmt.Println()
+	fmt.Println(" This example demonstrates:")
+	fmt.Println(" - A/B testing via server routing override (pre-hook)")
+	fmt.Println(" - Deterministic variant assignment (hash-based)")
+	fmt.Println(" - Sticky assignment (same user = same variant)")
+	fmt.Println()
+	fmt.Println(" Experiments:")
+	for _, exp := range srv.experiments {
+		fmt.Printf(" %s (%s.%s):\n", exp.Name, exp.Toolkit, exp.Tool)
+		for _, v := range exp.Variants {
+			fmt.Printf(" %s: %d%% -> %s\n", v.Name, v.Weight, v.ServerURI)
+		}
+	}
+	fmt.Println()
+	fmt.Println(" Admin: GET /_assignments - view variant assignments")
+	fmt.Println(strings.Repeat("=", 50))
+
+	addr := fmt.Sprintf(":%d", *port)
+	if err := router.Run(addr); err != nil {
+		// Fatalf rather than Fatal: Fatal formats like fmt.Sprint, which puts
+		// no space between a string operand and the error.
+		log.Fatalf("Failed to start server: %v", err)
+	}
+}
diff --git a/examples/content_filter/README.md b/examples/content_filter/README.md
new file mode 100644
index 0000000..dd586a7
--- /dev/null
+++ b/examples/content_filter/README.md
@@ -0,0 +1,97 @@
+# Content Filter Example
+ +A minimal CATE webhook server that demonstrates how to block users, filter tools, and reject requests based on content matching. + +## What It Shows + +- **Access hook**: Block specific users or toolkits from being discovered +- **Pre-execution hook**: Block requests based on input content (e.g., blocked email domains) +- **Post-execution hook**: Block responses containing prohibited content + +## Quick Start + +```bash +go run ./examples/content_filter -port 8888 +``` + +## How It Works + +### Access Hook (User & Toolkit Blocking) + +The access hook runs when a user requests their available tools. It: + +1. Checks if the user is in the blocked users list +2. Filters out blocked toolkits from the response + +```bash +# This user will see no tools (blocked) +curl -X POST http://localhost:8888/access \ + -H "Content-Type: application/json" \ + -d '{ + "user_id": "suspended-user", + "toolkits": { + "Email": {"tools": {"sendEmail": [{"version": "1.0"}]}} + } + }' + +# This toolkit will be filtered out +curl -X POST http://localhost:8888/access \ + -H "Content-Type: application/json" \ + -d '{ + "user_id": "normal-user", + "toolkits": { + "SafeToolkit": {"tools": {"doStuff": [{"version": "1.0"}]}}, + "DangerousToolkit": {"tools": {"doHarm": [{"version": "1.0"}]}} + } + }' +``` + +### Pre-Execution Hook (Input Content Filtering) + +The pre hook runs before tool execution. 
It checks inputs against content rules: + +```bash +# This will be blocked (email to @blocked.com) +curl -X POST http://localhost:8888/pre \ + -H "Content-Type: application/json" \ + -d '{ + "execution_id": "exec-1", + "tool": {"name": "sendEmail", "toolkit": "Email", "version": "1.0"}, + "inputs": {"to": "someone@blocked.com", "body": "Hello"}, + "context": {"user_id": "user-1"} + }' + +# This will be blocked (password field exists) +curl -X POST http://localhost:8888/pre \ + -H "Content-Type: application/json" \ + -d '{ + "execution_id": "exec-2", + "tool": {"name": "login", "toolkit": "Auth", "version": "1.0"}, + "inputs": {"username": "admin", "password": "secret123"}, + "context": {"user_id": "user-1"} + }' +``` + +### Post-Execution Hook (Output Content Filtering) + +The post hook runs after tool execution. It checks outputs for prohibited content: + +```bash +# This will be blocked (output contains CONFIDENTIAL) +curl -X POST http://localhost:8888/post \ + -H "Content-Type: application/json" \ + -d '{ + "execution_id": "exec-3", + "tool": {"name": "query", "toolkit": "Database", "version": "1.0"}, + "output": {"data": "CONFIDENTIAL: internal report..."}, + "server": {"name": "db", "uri": "http://db:8080", "type": "arcade"}, + "context": {"user_id": "user-1"} + }' +``` + +## Response Codes + +| Code | Meaning | +| -------------- | ---------------------------- | +| `OK` | Request allowed to proceed | +| `CHECK_FAILED` | Request blocked by a rule | diff --git a/examples/content_filter/main.go b/examples/content_filter/main.go new file mode 100644 index 0000000..61c7929 --- /dev/null +++ b/examples/content_filter/main.go @@ -0,0 +1,324 @@ +// content_filter is a minimal CATE webhook server that demonstrates how to +// block users, filter tools, and reject requests based on content matching. 
+//
+// It shows the three hook points in action:
+//   - Access hook: Block specific users or toolkits from being discovered
+//   - Pre-execution hook: Block requests based on input content
+//   - Post-execution hook: Block responses based on output content
+//
+// Usage:
+//
+//	go run ./examples/content_filter -port 8888
+package main
+
+import (
+	"flag"
+	"fmt"
+	"log"
+	"net/http"
+	"regexp"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+
+	"github.com/ArcadeAI/logical-extensions-examples/pkg/server"
+)
+
+// =============================================================================
+// Configuration
+// =============================================================================
+
+// BlockedUser is a user that should be denied all access.
+type BlockedUser struct {
+	UserID string // Compared for exact equality with the request's user ID
+	Reason string // Written to the log when the user is blocked
+}
+
+// BlockedToolkit is a toolkit that should be hidden from all users.
+type BlockedToolkit struct {
+	Pattern string // Exact, glob, or ~regex
+	Reason  string // Written to the log when the toolkit is filtered out
+}
+
+// ContentRule blocks requests/responses containing specific content.
+// A rule applies only when both Toolkit and Tool match the executing tool.
+type ContentRule struct {
+	Toolkit      string // Pattern to match toolkit
+	Tool         string // Pattern to match tool
+	FieldMatch   string // "field contains value", "field=value", or a bare "field" (exists check)
+	ErrorMessage string // Returned to the caller when the rule blocks
+}
+
+// FilterConfig holds all filtering rules.
+type FilterConfig struct {
+	BlockedUsers    []BlockedUser
+	BlockedToolkits []BlockedToolkit
+	InputRules      []ContentRule // Applied in pre-hook
+	OutputRules     []ContentRule // Applied in post-hook
+}
+
+// =============================================================================
+// Filter Server
+// =============================================================================
+
+// FilterServer implements the CATE webhook ServerInterface.
+type FilterServer struct {
+	config FilterConfig
+}
+
+// NewFilterServer creates a server with example filtering rules.
+func NewFilterServer() *FilterServer {
+	var cfg FilterConfig
+
+	// Users that are denied everything.
+	cfg.BlockedUsers = []BlockedUser{
+		{UserID: "suspended-user", Reason: "Account suspended"},
+		{UserID: "terminated-user", Reason: "Account terminated"},
+	}
+
+	// Toolkits hidden from every user.
+	cfg.BlockedToolkits = []BlockedToolkit{
+		{Pattern: "DangerousToolkit", Reason: "Toolkit is disabled"},
+		{Pattern: "Internal*", Reason: "Internal tools not available"},
+	}
+
+	// Requests rejected because of their input content.
+	cfg.InputRules = []ContentRule{
+		{
+			Toolkit:      "Email",
+			Tool:         "sendEmail",
+			FieldMatch:   "to contains @blocked.com",
+			ErrorMessage: "Cannot send emails to blocked domains",
+		},
+		{
+			Toolkit:      "*",
+			Tool:         "*",
+			FieldMatch:   "password",
+			ErrorMessage: "Cannot pass password fields directly",
+		},
+	}
+
+	// Responses rejected because of their output content.
+	cfg.OutputRules = []ContentRule{
+		{
+			Toolkit:      "Database",
+			Tool:         "*",
+			FieldMatch:   "data contains CONFIDENTIAL",
+			ErrorMessage: "Response contains confidential data",
+		},
+	}
+
+	return &FilterServer{config: cfg}
+}
+
+// HealthCheck answers the health endpoint with a "healthy" status.
+func (s *FilterServer) HealthCheck(c *gin.Context) {
+	healthy := server.Healthy
+	c.JSON(http.StatusOK, server.HealthResponse{Status: &healthy})
+}
+
+// AccessHook demonstrates blocking users and toolkits at the access level.
+// This controls what tools a user can even see - blocked tools never appear
+// in the tool list.
+//
+// A blocked user receives an empty "only" list, which denies everything.
+// Otherwise every requested toolkit that matches a blocked pattern is
+// returned in the "deny" set; if nothing matches, the result is empty.
+func (s *FilterServer) AccessHook(c *gin.Context) {
+	var req server.AccessHookRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, server.ErrorResponse{
+			Error: strPtr("invalid request: " + err.Error()),
+		})
+		return
+	}
+
+	// Check if user is blocked
+	for _, blocked := range s.config.BlockedUsers {
+		if req.UserId == blocked.UserID {
+			log.Printf("[ACCESS] Blocked user %q: %s", req.UserId, blocked.Reason)
+			// Return empty "only" list to deny everything
+			empty := make(server.Toolkits)
+			c.JSON(http.StatusOK, server.AccessHookResult{
+				Only: &empty,
+			})
+			return
+		}
+	}
+
+	// Filter out blocked toolkits
+	deny := make(server.Toolkits)
+	for toolkitName, toolkitInfo := range req.Toolkits {
+		for _, blocked := range s.config.BlockedToolkits {
+			if matchPattern(blocked.Pattern, toolkitName) {
+				log.Printf("[ACCESS] Blocked toolkit %q for user %q: %s",
+					toolkitName, req.UserId, blocked.Reason)
+				deny[toolkitName] = toolkitInfo
+				// First match wins: without this, a toolkit matching several
+				// patterns was logged and re-assigned once per pattern.
+				break
+			}
+		}
+	}
+
+	result := &server.AccessHookResult{}
+	if len(deny) > 0 {
+		result.Deny = &deny
+	}
+
+	c.JSON(http.StatusOK, result)
+}
+
+// PreHook demonstrates blocking requests based on input content.
+// This runs before tool execution and can stop a request from proceeding.
+func (s *FilterServer) PreHook(c *gin.Context) {
+	var req server.PreHookRequest
+	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
+		c.JSON(http.StatusBadRequest, server.ErrorResponse{
+			Error: strPtr("invalid request: " + bindErr.Error()),
+		})
+		return
+	}
+
+	// The first input rule that applies to this tool and matches the inputs
+	// rejects the request.
+	for i := range s.config.InputRules {
+		rule := s.config.InputRules[i]
+		applies := matchPattern(rule.Toolkit, req.Tool.Toolkit) && matchPattern(rule.Tool, req.Tool.Name)
+		if !applies || !matchContent(rule.FieldMatch, req.Inputs) {
+			continue
+		}
+
+		log.Printf("[PRE] Blocked %s.%s: %s", req.Tool.Toolkit, req.Tool.Name, rule.ErrorMessage)
+		reason := rule.ErrorMessage
+		c.JSON(http.StatusOK, server.PreHookResult{
+			Code:         server.CHECKFAILED,
+			ErrorMessage: &reason,
+		})
+		return
+	}
+
+	// No rule objected: let the request through.
+	c.JSON(http.StatusOK, server.PreHookResult{Code: server.OK})
+}
+
+// PostHook demonstrates blocking responses based on output content.
+// This runs after tool execution and can prevent the response from reaching
+// the agent if it contains blocked content.
+func (s *FilterServer) PostHook(c *gin.Context) {
+	var req server.PostHookRequest
+	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
+		c.JSON(http.StatusBadRequest, server.ErrorResponse{
+			Error: strPtr("invalid request: " + bindErr.Error()),
+		})
+		return
+	}
+
+	// The first output rule that applies to this tool and matches the output
+	// blocks the response.
+	for i := range s.config.OutputRules {
+		rule := s.config.OutputRules[i]
+		applies := matchPattern(rule.Toolkit, req.Tool.Toolkit) && matchPattern(rule.Tool, req.Tool.Name)
+		if !applies || !matchContent(rule.FieldMatch, req.Output) {
+			continue
+		}
+
+		log.Printf("[POST] Blocked output from %s.%s: %s",
+			req.Tool.Toolkit, req.Tool.Name, rule.ErrorMessage)
+		reason := rule.ErrorMessage
+		c.JSON(http.StatusOK, server.PostHookResult{
+			Code:         server.CHECKFAILED,
+			ErrorMessage: &reason,
+		})
+		return
+	}
+
+	// No rule objected: pass the response through.
+	c.JSON(http.StatusOK, server.PostHookResult{Code: server.OK})
+}
+
+// =============================================================================
+// Helpers
+// =============================================================================
+
+// matchPattern reports whether value matches pattern. "" and "*" match
+// anything, a leading "~" marks the remainder as a regular expression, a
+// pattern containing "*" is a glob, and anything else must be equal.
+// A pattern that fails to compile never matches.
+func matchPattern(pattern, value string) bool {
+	switch {
+	case pattern == "", pattern == "*":
+		return true
+	case strings.HasPrefix(pattern, "~"):
+		re, err := regexp.Compile(pattern[1:])
+		return err == nil && re.MatchString(value)
+	case strings.Contains(pattern, "*"):
+		// Quote everything, then turn the escaped "*" back into ".*".
+		glob := "^" + strings.ReplaceAll(regexp.QuoteMeta(pattern), "\\*", ".*") + "$"
+		re, err := regexp.Compile(glob)
+		return err == nil && re.MatchString(value)
+	default:
+		return pattern == value
+	}
+}
+
+// matchContent checks if a map contains content matching an expression.
+// Supports: "key contains value", "key=value", or just "key" (exists check).
+//
+// The " contains " form is tried first, then "=", then the bare key. Values
+// are compared through their default fmt formatting, and a key that is
+// absent from data never matches.
+func matchContent(expr string, data map[string]interface{}) bool {
+	const containsOp = " contains "
+
+	if at := strings.Index(expr, containsOp); at >= 0 {
+		key := strings.TrimSpace(expr[:at])
+		sub := strings.TrimSpace(expr[at+len(containsOp):])
+		val, ok := data[key]
+		return ok && strings.Contains(fmt.Sprintf("%v", val), sub)
+	}
+
+	if at := strings.Index(expr, "="); at >= 0 {
+		key := strings.TrimSpace(expr[:at])
+		expected := strings.TrimSpace(expr[at+1:])
+		val, ok := data[key]
+		return ok && fmt.Sprintf("%v", val) == expected
+	}
+
+	// Just check if the key exists
+	_, ok := data[expr]
+	return ok
+}
+
+// strPtr returns a pointer to a copy of s, for optional string fields.
+func strPtr(s string) *string {
+	return &s
+}
+
+// =============================================================================
+// Main
+// =============================================================================
+
+// main parses the -port flag, registers the webhook handlers, prints a
+// startup banner and serves until the listener fails.
+func main() {
+	port := flag.Int("port", 8888, "Port to listen on")
+	flag.Parse()
+
+	srv := NewFilterServer()
+
+	gin.SetMode(gin.ReleaseMode)
+	router := gin.New()
+	router.Use(gin.Recovery())
+
+	server.RegisterHandlers(router, srv)
+
+	fmt.Println(strings.Repeat("=", 50))
+	fmt.Println(" Content Filter Example")
+	fmt.Println(strings.Repeat("=", 50))
+	fmt.Printf(" Port: %d\n", *port)
+	fmt.Println()
+	fmt.Println(" This example demonstrates:")
+	fmt.Println(" - Blocking users at the access hook")
+	fmt.Println(" - Filtering toolkits at the access hook")
+	fmt.Println(" - Blocking requests by input content (pre)")
+	fmt.Println(" - Blocking responses by output content (post)")
+	fmt.Println()
+	fmt.Println(" Blocked users: suspended-user, terminated-user")
+	fmt.Println(" Blocked toolkits: DangerousToolkit, Internal*")
+	fmt.Println(" Input rules: emails to @blocked.com, password fields")
+	fmt.Println(" Output rules: responses containing CONFIDENTIAL")
+	fmt.Println(strings.Repeat("=", 50))
+
+	addr := fmt.Sprintf(":%d", *port)
+	if err := router.Run(addr); err != nil {
+		// Fatalf rather than Fatal: Fatal formats like fmt.Sprint, which puts
+		// no space between a string operand and the error.
+		log.Fatalf("Failed to start server: %v", err)
+	}
+}
diff --git a/examples/full_server/README.md b/examples/full_server/README.md
new file mode 100644
index 0000000..25d5021
--- /dev/null
+++ b/examples/full_server/README.md
@@ -0,0 +1,152 @@
+# Full Hook Server
+
+A comprehensive CATE webhook hook server with a web UI for configuration. It implements all webhook endpoints (`/health`, `/access`, `/pre`, `/post`) with:
+
+- **Rule-based access control** - Block users, filter toolkits/tools, pattern matching
+- **Pre-execution hooks** - Block or modify requests before tool execution
+- **Post-execution hooks** - Block or modify responses after tool execution
+- **PII redaction** - Detect and redact personally identifiable information
+- **A/B testing** - Route tool executions to different variants with traffic splitting
+- **Tool catalog integration** - Fetch tools from an external API for experimentation
+- **Web UI** - Configure everything through a modern dashboard
+
+## Quick Start
+
+```bash
+# Run with defaults (port 8888, no auth)
+go run ./examples/full_server
+
+# Run with a specific config file
+go run ./examples/full_server -config ./examples/full_server/default-config.yaml
+
+# Run with authentication
+go run ./examples/full_server -token "my-secret-token"
+
+# Run with TLS
+go run ./examples/full_server -tls -cert server.crt -key server.key
+```
+
+Then open http://localhost:8888/ in your browser to access the dashboard.
+ +## Command Line Flags + +| Flag | Default | Description | +| ---------- | -------------- | ---------------------------------------------------- | +| `-port` | `8888` | Port to listen on | +| `-token` | `""` | Bearer token for authentication (empty = no auth) | +| `-verbose` | `true` | Log all requests to stdout | +| `-config` | `config.yaml` | Path to YAML configuration file | +| `-tls` | `false` | Enable TLS/HTTPS | +| `-cert` | `""` | Path to server certificate (PEM) | +| `-key` | `""` | Path to server private key (PEM) | +| `-ca` | `""` | Path to CA certificate for mTLS | + +## Features + +### 1. Access Control (tool.access hook) + +Control which tools users can see based on user ID, toolkit name, or tool name. + +- **Default action**: Allow or deny all unmatched requests +- **Pattern matching**: Exact, glob (`Admin*`), or regex (`~^test-.*`) +- **First-match wins**: Rules are evaluated in order + +### 2. Pre-Execution Rules (tool.pre hook) + +Validate and modify requests before tool execution: + +- **Block** requests based on user, toolkit, tool, or input content +- **Rate limit** specific tools +- **Override** inputs, secrets, headers, or server routing +- **Content matching**: Check if inputs contain specific values + +### 3. Post-Execution Rules (tool.post hook) + +Filter and transform responses after tool execution: + +- **Block** responses based on success/failure or output content +- **Override** output values (e.g., replace sensitive data) + +### 4. PII Redaction + +Automatically detect and handle personally identifiable information: + +- **Supported types**: Email, IP address, SSN, phone number, credit card, date of birth +- **Modes**: Redact (replace with placeholders) or Block (reject response) +- **Custom patterns**: Add your own regex patterns for domain-specific PII +- **Test tool**: Try redaction before enabling in the dashboard + +### 5. 
A/B Testing + +Route tool executions to different server variants: + +- **Traffic splitting**: Percentage-based distribution between variants +- **Sticky assignment**: Same user always gets the same variant (hash-based) +- **User targeting**: Include/exclude users by pattern +- **Canary deploys**: Start with small percentage, gradually increase + +### 6. Tool Catalog Integration + +Fetch tools from an external API: + +- **Periodic fetching**: Configurable interval +- **Tool browsing**: View fetched tools in the dashboard +- **A/B test setup**: Use fetched tools to create experiments + +## Endpoints + +### Webhook Endpoints (CATE Protocol) + +| Method | Path | Description | +| ------ | --------- | --------------------- | +| GET | `/health` | Health check | +| POST | `/access` | Access control hook | +| POST | `/pre` | Pre-execution hook | +| POST | `/post` | Post-execution hook | + +### API Endpoints + +| Method | Path | Description | +| ------ | -------------------- | ------------------------------ | +| GET | `/api/config` | Get full configuration | +| PUT | `/api/config` | Update configuration | +| GET | `/api/logs` | Get request logs | +| DELETE | `/api/logs` | Clear request logs | +| GET | `/api/status` | Server status | +| POST | `/api/pii/test` | Test PII redaction | +| POST | `/api/arcade/fetch` | Trigger tool catalog fetch | +| GET | `/api/arcade/tools` | Get cached tools | +| GET | `/api/ab/assignments`| Get A/B variant assignments | +| DELETE | `/api/ab/assignments`| Clear variant assignments | + +### Web UI + +| Path | Description | +| ---------- | -------------------- | +| `/` | Dashboard | + +## Configuration File + +The server uses a YAML configuration file (default: `config.yaml`). See `default-config.yaml` for a fully documented example. + +The config file supports hot-reload - changes are picked up automatically. 
+ +## Integration + +Configure a webhook plugin to point to this server: + +```yaml +plugins: + - type: webhook + name: full-hook-server + config: + endpoints: + health: http://localhost:8888/health + access: http://localhost:8888/access + pre: http://localhost:8888/pre + post: http://localhost:8888/post + auth: + type: bearer + token: my-secret-token + timeout: 5s +``` diff --git a/examples/full_server/abtesting.go b/examples/full_server/abtesting.go new file mode 100644 index 0000000..74320ea --- /dev/null +++ b/examples/full_server/abtesting.go @@ -0,0 +1,182 @@ +package main + +import ( + "crypto/sha256" + "encoding/binary" + "fmt" + "regexp" + "strings" + "sync" + "time" +) + +// ============================================================================= +// A/B Testing Engine +// ============================================================================= + +// ABTestingEngine manages experiments and variant assignment. +type ABTestingEngine struct { + mu sync.RWMutex + assignments map[string]*VariantAssignment // key: "experiment:user_id" +} + +// VariantAssignment records a user's variant assignment for an experiment. +type VariantAssignment struct { + ExperimentName string `json:"experiment_name"` + VariantName string `json:"variant_name"` + UserID string `json:"user_id"` + AssignedAt time.Time `json:"assigned_at"` +} + +// NewABTestingEngine creates a new A/B testing engine. +func NewABTestingEngine() *ABTestingEngine { + return &ABTestingEngine{ + assignments: make(map[string]*VariantAssignment), + } +} + +// AssignVariant determines which variant a user should receive for an experiment. +// It uses a deterministic hash of the user ID and experiment name to ensure +// the same user always gets the same variant (sticky assignment). 
+func (e *ABTestingEngine) AssignVariant(experiment *Experiment, userID string) *Variant { + if !experiment.Enabled || len(experiment.Variants) == 0 { + return nil + } + + // Check user targeting + if experiment.UserTargeting != nil { + if !e.matchesTargeting(experiment.UserTargeting, userID) { + return nil + } + } + + // Check cache + cacheKey := fmt.Sprintf("%s:%s", experiment.Name, userID) + e.mu.RLock() + if assignment, ok := e.assignments[cacheKey]; ok { + for i := range experiment.Variants { + if experiment.Variants[i].Name == assignment.VariantName { + e.mu.RUnlock() + return &experiment.Variants[i] + } + } + } + e.mu.RUnlock() + + // Compute deterministic hash for consistent assignment + hash := sha256.Sum256([]byte(fmt.Sprintf("%s:%s", experiment.Name, userID))) + bucket := int(binary.BigEndian.Uint32(hash[:4])) % 100 + + // Find the variant based on weight distribution + cumulative := 0 + var selected *Variant + for i := range experiment.Variants { + cumulative += experiment.Variants[i].Weight + if bucket < cumulative { + selected = &experiment.Variants[i] + break + } + } + + // Fallback to last variant if weights don't sum to 100 + if selected == nil && len(experiment.Variants) > 0 { + selected = &experiment.Variants[len(experiment.Variants)-1] + } + + // Cache the assignment + if selected != nil { + e.mu.Lock() + e.assignments[cacheKey] = &VariantAssignment{ + ExperimentName: experiment.Name, + VariantName: selected.Name, + UserID: userID, + AssignedAt: time.Now(), + } + e.mu.Unlock() + } + + return selected +} + +// FindExperiment finds an active experiment matching the given toolkit and tool. 
+func (e *ABTestingEngine) FindExperiment(cfg *ABTestingConfig, toolkit, tool string) *Experiment { + if cfg == nil || !cfg.Enabled { + return nil + } + + for i := range cfg.Experiments { + exp := &cfg.Experiments[i] + if !exp.Enabled { + continue + } + if matchesGlob(exp.Toolkit, toolkit) && matchesGlob(exp.Tool, tool) { + return exp + } + } + return nil +} + +// GetAssignments returns all current variant assignments. +func (e *ABTestingEngine) GetAssignments() []*VariantAssignment { + e.mu.RLock() + defer e.mu.RUnlock() + + result := make([]*VariantAssignment, 0, len(e.assignments)) + for _, a := range e.assignments { + result = append(result, a) + } + return result +} + +// ClearAssignments clears all cached variant assignments. +func (e *ABTestingEngine) ClearAssignments() { + e.mu.Lock() + defer e.mu.Unlock() + e.assignments = make(map[string]*VariantAssignment) +} + +// matchesTargeting checks if a user ID matches the targeting rules. +func (e *ABTestingEngine) matchesTargeting(targeting *UserTargeting, userID string) bool { + // Check excludes first + for _, pattern := range targeting.Exclude { + if matchesGlob(pattern, userID) { + return false + } + } + + // If includes are specified, user must match at least one + if len(targeting.Include) > 0 { + for _, pattern := range targeting.Include { + if matchesGlob(pattern, userID) { + return true + } + } + return false + } + + // No include list means everyone is included + return true +} + +// matchesGlob matches a pattern against a value. Supports exact, glob (*), and regex (~prefix). 
+func matchesGlob(pattern, value string) bool { + if pattern == "" || pattern == "*" { + return true + } + if strings.HasPrefix(pattern, "~") { + re, err := regexp.Compile(pattern[1:]) + if err != nil { + return false + } + return re.MatchString(value) + } + if strings.Contains(pattern, "*") { + regexPattern := "^" + strings.ReplaceAll(regexp.QuoteMeta(pattern), "\\*", ".*") + "$" + re, err := regexp.Compile(regexPattern) + if err != nil { + return false + } + return re.MatchString(value) + } + return pattern == value +} diff --git a/examples/full_server/arcade.go b/examples/full_server/arcade.go new file mode 100644 index 0000000..7f591ae --- /dev/null +++ b/examples/full_server/arcade.go @@ -0,0 +1,245 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "sync" + "time" +) + +// ============================================================================= +// Arcade API Client +// ============================================================================= + +// ArcadeClient fetches tool definitions from the Arcade API. +type ArcadeClient struct { + mu sync.RWMutex + httpClient *http.Client + tools []ArcadeTool + lastFetch time.Time + fetching bool + stopChan chan struct{} +} + +// ArcadeTool represents a tool fetched from the Arcade API. +type ArcadeTool struct { + Name string `json:"name"` + Description string `json:"description"` + Toolkit ArcadeToolkit `json:"toolkit"` + Version string `json:"version"` + Inputs []ArcadeToolInput `json:"inputs"` + FetchedAt time.Time `json:"fetched_at"` +} + +// ArcadeToolkit represents a toolkit from the Arcade API. +type ArcadeToolkit struct { + Name string `json:"name"` + Description string `json:"description"` +} + +// ArcadeToolInput represents a tool input parameter. 
// FetchTools fills one ArcadeToolInput per map entry found in the raw tool's
// input "parameters" list; keys that are missing or of the wrong type leave
// the corresponding field at its zero value.
type ArcadeToolInput struct {
	Name        string `json:"name"`
	Description string `json:"description"`
	Type        string `json:"type"`
	Required    bool   `json:"required"`
}

// ArcadeAPIResponse represents the response from the Arcade tools endpoint.
// FetchTools reads only Items; TotalCount, Offset and Limit are decoded but
// not consulted by the code in this file section.
type ArcadeAPIResponse struct {
	Items      []ArcadeToolRaw `json:"items"`
	TotalCount int             `json:"total_count"`
	Offset     int             `json:"offset"`
	Limit      int             `json:"limit"`
}

// ArcadeToolRaw represents the raw tool data from the API.
// Toolkit and Input stay as untyped maps; FetchTools picks out the
// "name"/"description" keys of Toolkit and the "parameters" key of Input.
type ArcadeToolRaw struct {
	FullyQualifiedName string                 `json:"fully_qualified_name"`
	Description        string                 `json:"description"`
	Toolkit            map[string]interface{} `json:"toolkit"`
	Version            string                 `json:"version"`
	Input              map[string]interface{} `json:"input"`
}

// NewArcadeClient creates a new Arcade API client with a 30-second HTTP
// timeout, an empty (non-nil) tool cache, and an open stop channel.
func NewArcadeClient() *ArcadeClient {
	return &ArcadeClient{
		httpClient: &http.Client{Timeout: 30 * time.Second},
		tools:      []ArcadeTool{},
		stopChan:   make(chan struct{}),
	}
}

// FetchTools fetches tools from the Arcade API.
+func (c *ArcadeClient) FetchTools(ctx context.Context, cfg *ArcadeConfig) error { + if cfg == nil || !cfg.Enabled || cfg.APIURL == "" { + return fmt.Errorf("arcade integration not configured") + } + + c.mu.Lock() + if c.fetching { + c.mu.Unlock() + return fmt.Errorf("fetch already in progress") + } + c.fetching = true + c.mu.Unlock() + + defer func() { + c.mu.Lock() + c.fetching = false + c.mu.Unlock() + }() + + url := fmt.Sprintf("%s/v1/tools", cfg.APIURL) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + if cfg.APIKey != "" { + req.Header.Set("Authorization", "Bearer "+cfg.APIKey) + } + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to fetch tools: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + + var apiResp ArcadeAPIResponse + if err := json.Unmarshal(body, &apiResp); err != nil { + // Try parsing as a direct array + var rawTools []ArcadeToolRaw + if err2 := json.Unmarshal(body, &rawTools); err2 != nil { + return fmt.Errorf("failed to parse API response: %w", err) + } + apiResp.Items = rawTools + } + + tools := make([]ArcadeTool, 0, len(apiResp.Items)) + now := time.Now() + for _, raw := range apiResp.Items { + tool := ArcadeTool{ + Name: raw.FullyQualifiedName, + Description: raw.Description, + Version: raw.Version, + FetchedAt: now, + } + + // Extract toolkit info + if raw.Toolkit != nil { + if name, ok := raw.Toolkit["name"].(string); ok { + tool.Toolkit.Name = name + } + if desc, ok := raw.Toolkit["description"].(string); ok { + tool.Toolkit.Description = desc + } + } + + // Extract input 
parameters + if raw.Input != nil { + if params, ok := raw.Input["parameters"].([]interface{}); ok { + for _, p := range params { + if param, ok := p.(map[string]interface{}); ok { + input := ArcadeToolInput{} + if name, ok := param["name"].(string); ok { + input.Name = name + } + if desc, ok := param["description"].(string); ok { + input.Description = desc + } + if t, ok := param["type"].(string); ok { + input.Type = t + } + if req, ok := param["required"].(bool); ok { + input.Required = req + } + tool.Inputs = append(tool.Inputs, input) + } + } + } + } + + tools = append(tools, tool) + } + + c.mu.Lock() + c.tools = tools + c.lastFetch = now + c.mu.Unlock() + + log.Printf("Fetched %d tools from Arcade API", len(tools)) + return nil +} + +// GetTools returns the cached tools. +func (c *ArcadeClient) GetTools() []ArcadeTool { + c.mu.RLock() + defer c.mu.RUnlock() + return append([]ArcadeTool{}, c.tools...) +} + +// GetLastFetch returns when tools were last fetched. +func (c *ArcadeClient) GetLastFetch() time.Time { + c.mu.RLock() + defer c.mu.RUnlock() + return c.lastFetch +} + +// StartPeriodicFetch starts a background goroutine to fetch tools periodically. +func (c *ArcadeClient) StartPeriodicFetch(cfgMgr *ConfigManager) { + go func() { + for { + cfg := cfgMgr.Get() + if cfg.Arcade != nil && cfg.Arcade.Enabled { + interval := time.Duration(cfg.Arcade.FetchIntervalSeconds) * time.Second + if interval < 30*time.Second { + interval = 30 * time.Second + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + if err := c.FetchTools(ctx, cfg.Arcade); err != nil { + log.Printf("Periodic tool fetch failed: %v", err) + } + cancel() + + select { + case <-time.After(interval): + continue + case <-c.stopChan: + return + } + } else { + // Not enabled, check again in 30s + select { + case <-time.After(30 * time.Second): + continue + case <-c.stopChan: + return + } + } + } + }() +} + +// Stop stops the periodic fetch goroutine. 
+func (c *ArcadeClient) Stop() { + close(c.stopChan) +} diff --git a/examples/full_server/config.go b/examples/full_server/config.go new file mode 100644 index 0000000..1821eb9 --- /dev/null +++ b/examples/full_server/config.go @@ -0,0 +1,367 @@ +package main + +import ( + "fmt" + "log" + "os" + "os/signal" + "sync" + "syscall" + + "github.com/fsnotify/fsnotify" + "gopkg.in/yaml.v3" +) + +// ============================================================================= +// Root Configuration +// ============================================================================= + +// Config is the root configuration loaded from YAML. +type Config struct { + Health *HealthConfig `yaml:"health" json:"health"` + Access *AccessConfig `yaml:"access" json:"access"` + Pre *PreConfig `yaml:"pre" json:"pre"` + Post *PostConfig `yaml:"post" json:"post"` + PII *PIIConfig `yaml:"pii" json:"pii"` + ABTesting *ABTestingConfig `yaml:"ab_testing" json:"ab_testing"` + Arcade *ArcadeConfig `yaml:"arcade" json:"arcade"` +} + +// ============================================================================= +// Health Configuration +// ============================================================================= + +// HealthConfig controls health endpoint behavior. +type HealthConfig struct { + Status string `yaml:"status" json:"status"` // healthy, degraded, unhealthy +} + +// ============================================================================= +// Access Control Configuration +// ============================================================================= + +// AccessConfig controls access hook behavior. +type AccessConfig struct { + DefaultAction string `yaml:"default_action" json:"default_action"` + Rules []AccessRule `yaml:"rules" json:"rules"` +} + +// AccessRule defines a single access control rule. 
// Rules are tried in order by matchAccessRule; the first rule whose three
// patterns all match decides the action. Patterns go through matchesGlob, so
// an empty pattern (or "*") matches anything.
type AccessRule struct {
	UserID       string `yaml:"user_id" json:"user_id"` // pattern for the requesting user's ID
	ToolkitMatch string `yaml:"toolkit" json:"toolkit"` // pattern for the toolkit name
	ToolMatch    string `yaml:"tool" json:"tool"`       // pattern for the tool name
	Action       string `yaml:"action" json:"action"`   // "deny" denies; any other value is treated as allow
	Reason       string `yaml:"reason" json:"reason"`   // free-form explanation for the rule
}

// =============================================================================
// Pre-Execution Configuration
// =============================================================================

// PreConfig controls pre-execution hook behavior.
// The sample configuration lists proceed, block and rate_limit as the
// possible default actions.
type PreConfig struct {
	DefaultAction string    `yaml:"default_action" json:"default_action"`
	Rules         []PreRule `yaml:"rules" json:"rules"`
}

// PreRule defines a single pre-execution rule.
// NOTE(review): the code that evaluates these rules is outside this section;
// the field notes below follow the sample configuration and should be
// confirmed against the evaluator.
type PreRule struct {
	UserID       string             `yaml:"user_id" json:"user_id"`
	Toolkit      string             `yaml:"toolkit" json:"toolkit"`
	Tool         string             `yaml:"tool" json:"tool"`
	ExecutionID  string             `yaml:"execution_id" json:"execution_id"`
	InputMatch   string             `yaml:"input_match" json:"input_match"` // e.g. "to contains @blocked.com" in the sample config
	Action       string             `yaml:"action" json:"action"`           // e.g. block or rate_limit in the sample config
	ErrorMessage string             `yaml:"error_message" json:"error_message"`
	Override     *PreOverrideConfig `yaml:"override" json:"override"` // optional; nil when the rule overrides nothing
}

// PreOverrideConfig defines what to override in pre-hook.
type PreOverrideConfig struct {
	Inputs  map[string]interface{} `yaml:"inputs" json:"inputs"`
	Secrets map[string]string      `yaml:"secrets" json:"secrets"`
	Headers map[string]string      `yaml:"headers" json:"headers"`
	Server  *ServerOverride        `yaml:"server" json:"server"`
}

// ServerOverride defines server routing override.
// It is also used by Variant to name the server an A/B variant routes to.
type ServerOverride struct {
	Name string `yaml:"name" json:"name"`
	URI  string `yaml:"uri" json:"uri"`
	Type string `yaml:"type" json:"type"` // e.g. "arcade" in the sample config
}

// =============================================================================
// Post-Execution Configuration
// =============================================================================

// PostConfig controls post-execution hook behavior.
// The sample configuration uses "proceed" as the default action.
type PostConfig struct {
	DefaultAction string     `yaml:"default_action" json:"default_action"`
	Rules         []PostRule `yaml:"rules" json:"rules"`
}

// PostRule defines a single post-execution rule.
// NOTE(review): the code that evaluates these rules is outside this section;
// confirm the matching semantics against the evaluator.
type PostRule struct {
	UserID       string              `yaml:"user_id" json:"user_id"`
	Toolkit      string              `yaml:"toolkit" json:"toolkit"`
	Tool         string              `yaml:"tool" json:"tool"`
	ExecutionID  string              `yaml:"execution_id" json:"execution_id"`
	Success      *bool               `yaml:"success" json:"success"` // pointer so "unset" differs from false
	OutputMatch  string              `yaml:"output_match" json:"output_match"`
	Action       string              `yaml:"action" json:"action"` // e.g. block in the sample config
	ErrorMessage string              `yaml:"error_message" json:"error_message"`
	Override     *PostOverrideConfig `yaml:"override" json:"override"` // optional; nil when the rule overrides nothing
}

// PostOverrideConfig defines what to override in post-hook.
type PostOverrideConfig struct {
	Output map[string]interface{} `yaml:"output" json:"output"`
}

// =============================================================================
// PII Redaction Configuration
// =============================================================================

// PIIConfig controls PII detection and redaction.
type PIIConfig struct {
	Enabled        bool               `yaml:"enabled" json:"enabled"`
	Mode           string             `yaml:"mode" json:"mode"` // "redact" or "block"
	Types          PIITypes           `yaml:"types" json:"types"`
	CustomPatterns []CustomPIIPattern `yaml:"custom_patterns" json:"custom_patterns"`
}

// PIITypes defines which PII types to detect.
// DefaultConfig turns every type on.
type PIITypes struct {
	Email       bool `yaml:"email" json:"email"`
	IPAddress   bool `yaml:"ip_address" json:"ip_address"`
	SSN         bool `yaml:"ssn" json:"ssn"`
	PhoneNumber bool `yaml:"phone_number" json:"phone_number"`
	CreditCard  bool `yaml:"credit_card" json:"credit_card"`
	DateOfBirth bool `yaml:"date_of_birth" json:"date_of_birth"`
}

// CustomPIIPattern defines a custom regex pattern for PII detection.
// The sample configuration shows an entry named "employee_id" with pattern
// "EMP-\\d{6}" and replacement "[EMP-REDACTED]".
type CustomPIIPattern struct {
	Name        string `yaml:"name" json:"name"`
	Pattern     string `yaml:"pattern" json:"pattern"`
	Replacement string `yaml:"replacement" json:"replacement"`
}

// =============================================================================
// A/B Testing Configuration
// =============================================================================

// ABTestingConfig controls A/B and canary testing.
// FindExperiment ignores every experiment while Enabled is false.
type ABTestingConfig struct {
	Enabled     bool         `yaml:"enabled" json:"enabled"`
	Experiments []Experiment `yaml:"experiments" json:"experiments"`
}

// Experiment defines an A/B or canary test.
// FindExperiment matches Toolkit and Tool with matchesGlob and returns the
// first enabled experiment that matches both.
type Experiment struct {
	Name          string         `yaml:"name" json:"name"` // combined with the user ID to key and hash assignments
	Toolkit       string         `yaml:"toolkit" json:"toolkit"`
	Tool          string         `yaml:"tool" json:"tool"`
	Enabled       bool           `yaml:"enabled" json:"enabled"`
	Variants      []Variant      `yaml:"variants" json:"variants"`
	UserTargeting *UserTargeting `yaml:"user_targeting" json:"user_targeting"` // optional; nil means every user is eligible
}

// Variant defines a variant in an A/B test.
// AssignVariant walks the variants in order, accumulating Weight until the
// user's bucket falls inside; the last variant is the fallback.
type Variant struct {
	Name   string          `yaml:"name" json:"name"`
	Weight int             `yaml:"weight" json:"weight"` // percentage 0-100
	Server *ServerOverride `yaml:"server" json:"server"`
}

// UserTargeting defines optional user targeting for experiments.
// Exclude patterns are checked before Include patterns; an empty Include list
// admits every user that was not excluded.
type UserTargeting struct {
	Include []string `yaml:"include" json:"include"`
	Exclude []string `yaml:"exclude" json:"exclude"`
}

// =============================================================================
// Arcade API Configuration
// =============================================================================

// ArcadeConfig controls Arcade API integration.
+type ArcadeConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + APIURL string `yaml:"api_url" json:"api_url"` + APIKey string `yaml:"api_key" json:"api_key"` + FetchIntervalSeconds int `yaml:"fetch_interval_seconds" json:"fetch_interval_seconds"` +} + +// ============================================================================= +// Configuration Manager +// ============================================================================= + +// ConfigManager handles loading, saving, and watching configuration. +type ConfigManager struct { + mu sync.RWMutex + config *Config + filePath string +} + +// NewConfigManager creates a new ConfigManager with defaults. +func NewConfigManager(filePath string) *ConfigManager { + return &ConfigManager{ + filePath: filePath, + config: DefaultConfig(), + } +} + +// DefaultConfig returns the default configuration. +func DefaultConfig() *Config { + return &Config{ + Health: &HealthConfig{Status: "healthy"}, + Access: &AccessConfig{DefaultAction: "allow", Rules: []AccessRule{}}, + Pre: &PreConfig{DefaultAction: "proceed", Rules: []PreRule{}}, + Post: &PostConfig{DefaultAction: "proceed", Rules: []PostRule{}}, + PII: &PIIConfig{ + Enabled: false, + Mode: "redact", + Types: PIITypes{ + Email: true, + IPAddress: true, + SSN: true, + PhoneNumber: true, + CreditCard: true, + DateOfBirth: true, + }, + CustomPatterns: []CustomPIIPattern{}, + }, + ABTesting: &ABTestingConfig{ + Enabled: false, + Experiments: []Experiment{}, + }, + Arcade: &ArcadeConfig{ + Enabled: false, + APIURL: "https://api.arcade.dev", + APIKey: "", + FetchIntervalSeconds: 300, + }, + } +} + +// Get returns the current configuration (read-only copy). +func (cm *ConfigManager) Get() *Config { + cm.mu.RLock() + defer cm.mu.RUnlock() + return cm.config +} + +// Set replaces the full configuration. +func (cm *ConfigManager) Set(cfg *Config) { + cm.mu.Lock() + defer cm.mu.Unlock() + cm.config = cfg +} + +// Update merges non-nil fields from the provided config. 
+func (cm *ConfigManager) Update(cfg *Config) { + cm.mu.Lock() + defer cm.mu.Unlock() + if cfg.Health != nil { + cm.config.Health = cfg.Health + } + if cfg.Access != nil { + cm.config.Access = cfg.Access + } + if cfg.Pre != nil { + cm.config.Pre = cfg.Pre + } + if cfg.Post != nil { + cm.config.Post = cfg.Post + } + if cfg.PII != nil { + cm.config.PII = cfg.PII + } + if cfg.ABTesting != nil { + cm.config.ABTesting = cfg.ABTesting + } + if cfg.Arcade != nil { + cm.config.Arcade = cfg.Arcade + } +} + +// LoadFromFile loads configuration from the YAML file. +func (cm *ConfigManager) LoadFromFile() error { + data, err := os.ReadFile(cm.filePath) + if err != nil { + return fmt.Errorf("failed to read config file: %w", err) + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return fmt.Errorf("failed to parse config file: %w", err) + } + + cm.Update(&cfg) + return nil +} + +// SaveToFile writes the current configuration to the YAML file. +func (cm *ConfigManager) SaveToFile() error { + cm.mu.RLock() + data, err := yaml.Marshal(cm.config) + cm.mu.RUnlock() + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + + if err := os.WriteFile(cm.filePath, data, 0644); err != nil { + return fmt.Errorf("failed to write config file: %w", err) + } + + return nil +} + +// FilePath returns the configuration file path. +func (cm *ConfigManager) FilePath() string { + return cm.filePath +} + +// WatchFile starts watching the config file for changes and reloads automatically. 
+func (cm *ConfigManager) WatchFile() { + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Printf("Failed to create file watcher: %v", err) + return + } + defer watcher.Close() + + if err := watcher.Add(cm.filePath); err != nil { + log.Printf("Failed to watch config file: %v", err) + return + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Write == fsnotify.Write { + log.Println("Config file changed, reloading...") + if err := cm.LoadFromFile(); err != nil { + log.Printf("Failed to reload config: %v", err) + } else { + log.Println("Configuration reloaded successfully") + } + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + log.Printf("Watcher error: %v", err) + case <-sigChan: + return + } + } +} diff --git a/examples/full_server/default-config.yaml b/examples/full_server/default-config.yaml new file mode 100644 index 0000000..0c33b2c --- /dev/null +++ b/examples/full_server/default-config.yaml @@ -0,0 +1,121 @@ +# ============================================================================= +# CATE Hook Server - Full Configuration +# ============================================================================= +# This file configures all features of the hook server. +# The server hot-reloads this file when it changes. + +# Health endpoint +health: + status: healthy # healthy, degraded, unhealthy + +# ============================================================================= +# Access Control +# ============================================================================= +# Determines which tools a user can see/use. +# Rules are evaluated in order - first match wins. 
+access: + default_action: allow # allow or deny + rules: + # Block a specific user from all tools + - user_id: "blocked-user" + action: deny + reason: "User is suspended" + + # Block a specific toolkit + - toolkit: "DangerousToolkit" + action: deny + reason: "Toolkit is disabled" + + # Allow admin users to access admin tools (regex user_id via "~" prefix, glob toolkit) + - user_id: "~^admin-.*" + toolkit: "Admin*" + action: allow + +# ============================================================================= +# Pre-Execution Rules +# ============================================================================= +# Runs before a tool is executed - can block or modify the request. +pre: + default_action: proceed # proceed, block, or rate_limit + rules: + # Block based on input content + - toolkit: "Email" + tool: "sendEmail" + input_match: "to contains @blocked.com" + action: block + error_message: "Sending to blocked domains is not allowed" + + # Rate limit a specific tool + - toolkit: "Search" + tool: "webSearch" + action: rate_limit + error_message: "Too many search requests" + +# ============================================================================= +# Post-Execution Rules +# ============================================================================= +# Runs after tool execution - can block or modify the response. +post: + default_action: proceed + rules: + # Block if execution failed + - success: false + action: block + error_message: "Tool execution failed" + +# ============================================================================= +# PII Redaction +# ============================================================================= +# Detects and redacts personally identifiable information from tool outputs. 
+pii: + enabled: false + mode: redact # redact (replace PII) or block (reject response) + types: + email: true + ip_address: true + ssn: true + phone_number: true + credit_card: true + date_of_birth: true + custom_patterns: [] + # - name: "employee_id" + # pattern: "EMP-\\d{6}" + # replacement: "[EMP-REDACTED]" + +# ============================================================================= +# A/B Testing +# ============================================================================= +# Routes tool executions to different variants for testing. +ab_testing: + enabled: false + experiments: [] + # - name: "search-v2-rollout" + # toolkit: "Search" + # tool: "WebSearch" + # enabled: true + # variants: + # - name: "control" + # weight: 80 + # server: + # name: "search-v1" + # uri: "http://search-v1:8080" + # type: "arcade" + # - name: "treatment" + # weight: 20 + # server: + # name: "search-v2" + # uri: "http://search-v2:8080" + # type: "arcade" + # user_targeting: + # include: ["beta-*"] + # exclude: [] + +# ============================================================================= +# Tool Catalog Integration +# ============================================================================= +# Fetches tools from an external API for use in A/B tests. 
+arcade: + enabled: false + api_url: "https://api.arcade.dev" + api_key: "" + fetch_interval_seconds: 300 diff --git a/examples/full_server/hooks.go b/examples/full_server/hooks.go new file mode 100644 index 0000000..816f4e0 --- /dev/null +++ b/examples/full_server/hooks.go @@ -0,0 +1,520 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + + "github.com/ArcadeAI/logical-extensions-examples/pkg/server" +) + +// ============================================================================= +// Request Logging +// ============================================================================= + +// RequestLog stores information about each incoming request. +type RequestLog struct { + Timestamp time.Time `json:"timestamp"` + Endpoint string `json:"endpoint"` + Body interface{} `json:"body"` + Response interface{} `json:"response"` + RuleMatch string `json:"rule_match,omitempty"` + PIIFound []PIIMatch `json:"pii_found,omitempty"` + ABVariant string `json:"ab_variant,omitempty"` +} + +// ============================================================================= +// Hook Server +// ============================================================================= + +// HookServer implements the webhook ServerInterface with all features. +type HookServer struct { + mu sync.RWMutex + logs []RequestLog + cfgMgr *ConfigManager + piiDetector *PIIDetector + abEngine *ABTestingEngine + arcadeClient *ArcadeClient + verbose bool + token string +} + +// NewHookServer creates a new hook server with all subsystems. 
+func NewHookServer(cfgMgr *ConfigManager, token string, verbose bool) *HookServer { + return &HookServer{ + logs: make([]RequestLog, 0), + cfgMgr: cfgMgr, + piiDetector: NewPIIDetector(), + abEngine: NewABTestingEngine(), + arcadeClient: NewArcadeClient(), + verbose: verbose, + token: token, + } +} + +func (s *HookServer) logRequest(endpoint string, body, response interface{}, ruleMatch string, piiFound []PIIMatch, abVariant string) { + s.mu.Lock() + defer s.mu.Unlock() + + entry := RequestLog{ + Timestamp: time.Now(), + Endpoint: endpoint, + Body: body, + Response: response, + RuleMatch: ruleMatch, + PIIFound: piiFound, + ABVariant: abVariant, + } + s.logs = append(s.logs, entry) + + if s.verbose { + jsonBody, _ := json.MarshalIndent(body, "", " ") + jsonResp, _ := json.MarshalIndent(response, "", " ") + fmt.Printf("\n[%s] %s\n", time.Now().Format("15:04:05"), endpoint) + if ruleMatch != "" { + fmt.Printf(" Rule matched: %s\n", ruleMatch) + } + if abVariant != "" { + fmt.Printf(" A/B Variant: %s\n", abVariant) + } + if len(piiFound) > 0 { + fmt.Printf(" PII detected: %d instances\n", len(piiFound)) + } + fmt.Printf(" Request: %s\n", string(jsonBody)) + fmt.Printf(" Response: %s\n", string(jsonResp)) + fmt.Println(strings.Repeat("-", 60)) + } +} + +// GetLogs returns all logged requests. +func (s *HookServer) GetLogs() []RequestLog { + s.mu.RLock() + defer s.mu.RUnlock() + return append([]RequestLog{}, s.logs...) +} + +// ClearLogs clears all logged requests. 
+func (s *HookServer) ClearLogs() { + s.mu.Lock() + defer s.mu.Unlock() + s.logs = make([]RequestLog, 0) +} + +// ============================================================================= +// Auth +// ============================================================================= + +func (s *HookServer) validateAuth(c *gin.Context) bool { + if s.token == "" { + return true + } + + auth := c.GetHeader("Authorization") + expected := "Bearer " + s.token + if auth != expected { + c.JSON(http.StatusUnauthorized, server.ErrorResponse{ + Error: strPtr("invalid or missing bearer token"), + Code: strPtr("UNAUTHORIZED"), + }) + return false + } + return true +} + +// ============================================================================= +// Health Check +// ============================================================================= + +// HealthCheck implements webhook.ServerInterface. +func (s *HookServer) HealthCheck(c *gin.Context) { + cfg := s.cfgMgr.Get() + status := server.HealthResponseStatus(cfg.Health.Status) + resp := server.HealthResponse{Status: &status} + s.logRequest("/health", nil, resp, "", nil, "") + c.JSON(http.StatusOK, resp) +} + +// ============================================================================= +// Access Hook +// ============================================================================= + +// AccessHook implements webhook.ServerInterface. 
+func (s *HookServer) AccessHook(c *gin.Context) { + if !s.validateAuth(c) { + return + } + + var req server.AccessHookRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, server.ErrorResponse{ + Error: strPtr("invalid request body: " + err.Error()), + Code: strPtr("INVALID_REQUEST"), + }) + return + } + + resp, ruleMatch := s.evaluateAccessRules(req) + s.logRequest("/access", req, resp, ruleMatch, nil, "") + c.JSON(http.StatusOK, resp) +} + +func (s *HookServer) evaluateAccessRules(req server.AccessHookRequest) (*server.AccessHookResult, string) { + cfg := s.cfgMgr.Get() + accessCfg := cfg.Access + + allow := make(server.Toolkits) + deny := make(server.Toolkits) + ruleMatch := "" + + for toolkitName, toolkitInfo := range req.Toolkits { + if toolkitInfo.Tools == nil { + continue + } + for toolName, versions := range *toolkitInfo.Tools { + action, matchedRule := s.matchAccessRule(accessCfg, req.UserId, toolkitName, toolName) + if matchedRule != "" { + ruleMatch = matchedRule + } + + if action == "deny" { + if _, ok := deny[toolkitName]; !ok { + deny[toolkitName] = server.ToolkitInfo{Tools: &map[string][]server.ToolVersionInfo{}} + } + (*deny[toolkitName].Tools)[toolName] = versions + } else { + if _, ok := allow[toolkitName]; !ok { + allow[toolkitName] = server.ToolkitInfo{Tools: &map[string][]server.ToolVersionInfo{}} + } + (*allow[toolkitName].Tools)[toolName] = versions + } + } + } + + result := &server.AccessHookResult{} + if len(allow) > 0 { + result.Only = &allow + } + if len(deny) > 0 { + result.Deny = &deny + } + + return result, ruleMatch +} + +func (s *HookServer) matchAccessRule(cfg *AccessConfig, userID, toolkit, tool string) (string, string) { + for i, rule := range cfg.Rules { + if matchesGlob(rule.UserID, userID) && + matchesGlob(rule.ToolkitMatch, toolkit) && + matchesGlob(rule.ToolMatch, tool) { + return rule.Action, fmt.Sprintf("access.rules[%d]", i) + } + } + return cfg.DefaultAction, "" +} + +// 
============================================================================= +// Pre-Execution Hook +// ============================================================================= + +// PreHook implements webhook.ServerInterface. +func (s *HookServer) PreHook(c *gin.Context) { + if !s.validateAuth(c) { + return + } + + var req server.PreHookRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, server.ErrorResponse{ + Error: strPtr("invalid request body: " + err.Error()), + Code: strPtr("INVALID_REQUEST"), + }) + return + } + + cfg := s.cfgMgr.Get() + userID := "" + if req.Context.UserId != nil { + userID = *req.Context.UserId + } + + // 1. Check PII in inputs (if configured for pre-hook blocking) + if cfg.PII != nil && cfg.PII.Enabled && cfg.PII.Mode == "block" { + if s.piiDetector.CheckMapForPII(req.Inputs, cfg.PII) { + matches := s.piiDetector.DetectPII(fmt.Sprintf("%v", req.Inputs), cfg.PII) + resp := &server.PreHookResult{ + Code: server.CHECKFAILED, + ErrorMessage: strPtr("Request contains PII data that is not allowed"), + } + s.logRequest("/pre", req, resp, "pii_block", matches, "") + c.JSON(http.StatusOK, resp) + return + } + } + + // 2. Evaluate pre-execution rules + resp, ruleMatch := s.evaluatePreRules(req, userID) + + // 3. 
Check A/B testing (may override server routing) + abVariant := "" + if cfg.ABTesting != nil && cfg.ABTesting.Enabled && resp.Code == server.OK { + experiment := s.abEngine.FindExperiment(cfg.ABTesting, req.Tool.Toolkit, req.Tool.Name) + if experiment != nil { + variant := s.abEngine.AssignVariant(experiment, userID) + if variant != nil && variant.Server != nil { + abVariant = fmt.Sprintf("%s/%s", experiment.Name, variant.Name) + if resp.Override == nil { + resp.Override = &server.PreHookOverride{} + } + resp.Override.Server = &server.ServerInfo{ + Name: variant.Server.Name, + Uri: variant.Server.URI, + Type: server.ServerInfoType(variant.Server.Type), + } + } + } + } + + s.logRequest("/pre", req, resp, ruleMatch, nil, abVariant) + c.JSON(http.StatusOK, resp) +} + +func (s *HookServer) evaluatePreRules(req server.PreHookRequest, userID string) (*server.PreHookResult, string) { + cfg := s.cfgMgr.Get() + preCfg := cfg.Pre + + for i, rule := range preCfg.Rules { + if s.matchPreRule(rule, userID, req) { + result := s.applyPreRule(rule) + return result, fmt.Sprintf("pre.rules[%d]", i) + } + } + + return &server.PreHookResult{ + Code: actionToCode(preCfg.DefaultAction), + }, "" +} + +func (s *HookServer) matchPreRule(rule PreRule, userID string, req server.PreHookRequest) bool { + if !matchesGlob(rule.UserID, userID) { + return false + } + if !matchesGlob(rule.Toolkit, req.Tool.Toolkit) { + return false + } + if !matchesGlob(rule.Tool, req.Tool.Name) { + return false + } + if !matchesGlob(rule.ExecutionID, req.ExecutionId) { + return false + } + if rule.InputMatch != "" && !matchesInputs(rule.InputMatch, req.Inputs) { + return false + } + return true +} + +func (s *HookServer) applyPreRule(rule PreRule) *server.PreHookResult { + result := &server.PreHookResult{ + Code: actionToCode(rule.Action), + } + + if rule.ErrorMessage != "" { + result.ErrorMessage = &rule.ErrorMessage + } + + if rule.Override != nil && rule.Action == "proceed" { + override := 
&server.PreHookOverride{} + + if len(rule.Override.Inputs) > 0 { + override.Inputs = &rule.Override.Inputs + } + if len(rule.Override.Headers) > 0 { + override.Headers = &rule.Override.Headers + } + if len(rule.Override.Secrets) > 0 { + secrets := []map[string]string{rule.Override.Secrets} + override.Secrets = &secrets + } + if rule.Override.Server != nil { + override.Server = &server.ServerInfo{ + Name: rule.Override.Server.Name, + Uri: rule.Override.Server.URI, + Type: server.ServerInfoType(rule.Override.Server.Type), + } + } + + result.Override = override + } + + return result +} + +// ============================================================================= +// Post-Execution Hook +// ============================================================================= + +// PostHook implements webhook.ServerInterface. +func (s *HookServer) PostHook(c *gin.Context) { + if !s.validateAuth(c) { + return + } + + var req server.PostHookRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, server.ErrorResponse{ + Error: strPtr("invalid request body: " + err.Error()), + Code: strPtr("INVALID_REQUEST"), + }) + return + } + + cfg := s.cfgMgr.Get() + userID := "" + if req.Context.UserId != nil { + userID = *req.Context.UserId + } + + // 1. Evaluate post-execution rules + resp, ruleMatch := s.evaluatePostRules(req, userID) + + // 2. 
Apply PII redaction on output (if enabled and not already blocked) + var piiFound []PIIMatch + if cfg.PII != nil && cfg.PII.Enabled && resp.Code == server.OK { + if cfg.PII.Mode == "redact" { + redactedOutput, matches := s.piiDetector.RedactMap(req.Output, cfg.PII) + if len(matches) > 0 { + piiFound = matches + if resp.Override == nil { + resp.Override = &server.PostHookOverride{} + } + resp.Override.Output = &redactedOutput + } + } else if cfg.PII.Mode == "block" { + if s.piiDetector.CheckMapForPII(req.Output, cfg.PII) { + matches := s.piiDetector.DetectPII(fmt.Sprintf("%v", req.Output), cfg.PII) + piiFound = matches + resp = &server.PostHookResult{ + Code: server.CHECKFAILED, + ErrorMessage: strPtr("Response contains PII data that has been blocked"), + } + } + } + } + + s.logRequest("/post", req, resp, ruleMatch, piiFound, "") + c.JSON(http.StatusOK, resp) +} + +func (s *HookServer) evaluatePostRules(req server.PostHookRequest, userID string) (*server.PostHookResult, string) { + cfg := s.cfgMgr.Get() + postCfg := cfg.Post + + for i, rule := range postCfg.Rules { + if s.matchPostRule(rule, userID, req) { + result := s.applyPostRule(rule) + return result, fmt.Sprintf("post.rules[%d]", i) + } + } + + return &server.PostHookResult{ + Code: actionToCode(postCfg.DefaultAction), + }, "" +} + +func (s *HookServer) matchPostRule(rule PostRule, userID string, req server.PostHookRequest) bool { + if !matchesGlob(rule.UserID, userID) { + return false + } + if !matchesGlob(rule.Toolkit, req.Tool.Toolkit) { + return false + } + if !matchesGlob(rule.Tool, req.Tool.Name) { + return false + } + if !matchesGlob(rule.ExecutionID, req.ExecutionId) { + return false + } + if rule.Success != nil && req.Success != nil && *rule.Success != *req.Success { + return false + } + if rule.OutputMatch != "" && !matchesInputs(rule.OutputMatch, req.Output) { + return false + } + return true +} + +func (s *HookServer) applyPostRule(rule PostRule) *server.PostHookResult { + result := 
&server.PostHookResult{ + Code: actionToCode(rule.Action), + } + + if rule.ErrorMessage != "" { + result.ErrorMessage = &rule.ErrorMessage + } + + if rule.Override != nil && rule.Action == "proceed" { + if len(rule.Override.Output) > 0 { + result.Override = &server.PostHookOverride{ + Output: &rule.Override.Output, + } + } + } + + return result +} + +// ============================================================================= +// Helpers +// ============================================================================= + +func actionToCode(action string) server.ResponseCode { + switch action { + case "proceed", "allow", "": + return server.OK + case "block", "deny": + return server.CHECKFAILED + case "rate_limit": + return server.RATELIMITEXCEEDED + default: + return server.OK + } +} + +func matchesInputs(expr string, inputs map[string]interface{}) bool { + if strings.Contains(expr, " contains ") { + parts := strings.SplitN(expr, " contains ", 2) + if len(parts) != 2 { + return false + } + key, substring := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) + if val, ok := inputs[key]; ok { + return strings.Contains(fmt.Sprintf("%v", val), substring) + } + // Also check the full stringified map for general content matching + fullStr := fmt.Sprintf("%v", inputs) + return strings.Contains(fullStr, substring) + } + if strings.Contains(expr, "=") { + parts := strings.SplitN(expr, "=", 2) + if len(parts) != 2 { + return false + } + key, expected := strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) + if val, ok := inputs[key]; ok { + return fmt.Sprintf("%v", val) == expected + } + return false + } + // Just check if key exists + _, ok := inputs[expr] + return ok +} + +func strPtr(s string) *string { + return &s +} diff --git a/examples/full_server/main.go b/examples/full_server/main.go new file mode 100644 index 0000000..a8a6753 --- /dev/null +++ b/examples/full_server/main.go @@ -0,0 +1,295 @@ +package main + +import ( + "context" + "crypto/tls" + 
"crypto/x509" + "embed" + "errors" + "flag" + "fmt" + "io/fs" + "log" + "net/http" + "os" + "strings" + "time" + + "github.com/gin-gonic/gin" + + "github.com/ArcadeAI/logical-extensions-examples/pkg/server" +) + +//go:embed static +var staticFiles embed.FS + +func main() { + port := flag.Int("port", 8888, "Port to listen on") + token := flag.String("token", "", "Bearer token for authentication (empty = no auth)") + verbose := flag.Bool("verbose", true, "Log all requests to stdout") + configFile := flag.String("config", "config.yaml", "Path to YAML configuration file") + tlsEnabled := flag.Bool("tls", false, "Enable TLS/HTTPS") + certFile := flag.String("cert", "", "Path to server certificate file (PEM)") + keyFile := flag.String("key", "", "Path to server private key file (PEM)") + caFile := flag.String("ca", "", "Path to CA certificate for client verification (enables mTLS)") + flag.Parse() + + if *tlsEnabled { + if *certFile == "" || *keyFile == "" { + log.Fatal("TLS enabled but -cert and -key are required") + } + } + + // Initialize configuration manager + cfgMgr := NewConfigManager(*configFile) + + // Try to load existing config file + if _, err := os.Stat(*configFile); err == nil { + if err := cfgMgr.LoadFromFile(); err != nil { + log.Printf("Warning: Failed to load config from %s: %v", *configFile, err) + } else { + log.Printf("Loaded configuration from %s", *configFile) + } + } else { + // Save default config + if err := cfgMgr.SaveToFile(); err != nil { + log.Printf("Warning: Failed to save default config to %s: %v", *configFile, err) + } else { + log.Printf("Created default configuration at %s", *configFile) + } + } + + // Watch config file for changes + go cfgMgr.WatchFile() + + // Initialize hook server + hookServer := NewHookServer(cfgMgr, *token, *verbose) + + // Start arcade client periodic fetch + hookServer.arcadeClient.StartPeriodicFetch(cfgMgr) + defer hookServer.arcadeClient.Stop() + + // Set up Gin router + gin.SetMode(gin.ReleaseMode) + router 
:= gin.New() + router.Use(gin.Recovery()) + + // Register CATE webhook handlers + server.RegisterHandlers(router, hookServer) + + // API endpoints for the UI + api := router.Group("/api") + { + // Configuration + api.GET("/config", func(c *gin.Context) { + c.JSON(http.StatusOK, cfgMgr.Get()) + }) + api.PUT("/config", func(c *gin.Context) { + var cfg Config + if err := c.ShouldBindJSON(&cfg); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + cfgMgr.Update(&cfg) + if err := cfgMgr.SaveToFile(); err != nil { + log.Printf("Warning: Failed to save config: %v", err) + } + c.JSON(http.StatusOK, gin.H{"message": "configuration updated"}) + }) + + // Request logs + api.GET("/logs", func(c *gin.Context) { + logs := hookServer.GetLogs() + c.JSON(http.StatusOK, gin.H{"count": len(logs), "logs": logs}) + }) + api.DELETE("/logs", func(c *gin.Context) { + hookServer.ClearLogs() + c.JSON(http.StatusOK, gin.H{"message": "logs cleared"}) + }) + + // Server status + api.GET("/status", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "status": "running", + "port": *port, + "auth_enabled": *token != "", + "tls_enabled": *tlsEnabled, + "config_file": *configFile, + "request_count": len(hookServer.GetLogs()), + }) + }) + + // PII test endpoint + api.POST("/pii/test", func(c *gin.Context) { + var req struct { + Text string `json:"text"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + cfg := cfgMgr.Get() + piiCfg := cfg.PII + if piiCfg == nil { + piiCfg = &PIIConfig{ + Enabled: true, + Mode: "redact", + Types: PIITypes{Email: true, IPAddress: true, SSN: true, PhoneNumber: true, CreditCard: true, DateOfBirth: true}, + } + } + detector := NewPIIDetector() + redacted, matches := detector.RedactText(req.Text, piiCfg) + c.JSON(http.StatusOK, gin.H{ + "original": req.Text, + "redacted": redacted, + "matches": matches, + }) + }) + + // Arcade tool catalog + 
api.POST("/arcade/fetch", func(c *gin.Context) { + cfg := cfgMgr.Get() + ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second) + defer cancel() + if err := hookServer.arcadeClient.FetchTools(ctx, cfg.Arcade); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"message": "tools fetched", "count": len(hookServer.arcadeClient.GetTools())}) + }) + api.GET("/arcade/tools", func(c *gin.Context) { + tools := hookServer.arcadeClient.GetTools() + lastFetch := hookServer.arcadeClient.GetLastFetch() + c.JSON(http.StatusOK, gin.H{ + "tools": tools, + "count": len(tools), + "last_fetch": lastFetch, + }) + }) + + // A/B testing info + api.GET("/ab/assignments", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "assignments": hookServer.abEngine.GetAssignments(), + }) + }) + api.DELETE("/ab/assignments", func(c *gin.Context) { + hookServer.abEngine.ClearAssignments() + c.JSON(http.StatusOK, gin.H{"message": "assignments cleared"}) + }) + } + + // Serve the web UI + staticSub, err := fs.Sub(staticFiles, "static") + if err != nil { + log.Fatal("Failed to create sub filesystem:", err) + } + router.GET("/", func(c *gin.Context) { + data, err := fs.ReadFile(staticSub, "index.html") + if err != nil { + c.String(http.StatusInternalServerError, "Failed to load UI: %v", err) + return + } + c.Data(http.StatusOK, "text/html; charset=utf-8", data) + }) + + printBanner(*port, *token, *tlsEnabled, *caFile, *configFile) + + addr := fmt.Sprintf(":%d", *port) + + if *tlsEnabled { + tlsConfig, err := buildTLSConfig(*caFile) + if err != nil { + log.Fatal("Failed to configure TLS:", err) + } + + srv := &http.Server{ + Addr: addr, + Handler: router, + TLSConfig: tlsConfig, + } + + if err := srv.ListenAndServeTLS(*certFile, *keyFile); err != nil { + log.Fatal("Failed to start TLS server:", err) + } + } else { + if err := router.Run(addr); err != nil { + log.Fatal("Failed to start server:", err) + } + } +} + +func 
buildTLSConfig(caFile string) (*tls.Config, error) { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + if caFile != "" { + caCert, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA certificate: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.New("failed to parse CA certificate") + } + + tlsConfig.ClientCAs = caCertPool + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + + return tlsConfig, nil +} + +func printBanner(port int, token string, tlsEnabled bool, caFile, configFile string) { + protocol := "http" + if tlsEnabled { + protocol = "https" + } + + fmt.Println(strings.Repeat("=", 60)) + fmt.Println(" CATE Hook Server (Full)") + fmt.Println(strings.Repeat("=", 60)) + fmt.Printf(" Port: %d\n", port) + fmt.Printf(" Auth: %s\n", authStatus(token)) + fmt.Printf(" TLS: %s\n", tlsStatus(tlsEnabled, caFile)) + fmt.Printf(" Config: %s\n", configFile) + fmt.Println(strings.Repeat("-", 60)) + fmt.Println(" Web UI:") + fmt.Printf(" %s://localhost:%d/\n", protocol, port) + fmt.Println() + fmt.Println(" Webhook Endpoints:") + fmt.Printf(" GET %s://localhost:%d/health\n", protocol, port) + fmt.Printf(" POST %s://localhost:%d/access\n", protocol, port) + fmt.Printf(" POST %s://localhost:%d/pre\n", protocol, port) + fmt.Printf(" POST %s://localhost:%d/post\n", protocol, port) + fmt.Println() + fmt.Println(" API Endpoints:") + fmt.Printf(" GET/PUT %s://localhost:%d/api/config\n", protocol, port) + fmt.Printf(" GET/DEL %s://localhost:%d/api/logs\n", protocol, port) + fmt.Printf(" GET %s://localhost:%d/api/status\n", protocol, port) + fmt.Printf(" POST %s://localhost:%d/api/pii/test\n", protocol, port) + fmt.Printf(" POST %s://localhost:%d/api/arcade/fetch\n", protocol, port) + fmt.Printf(" GET %s://localhost:%d/api/arcade/tools\n", protocol, port) + fmt.Println(strings.Repeat("=", 60)) + fmt.Println(" Ready to receive webhook requests...") + 
fmt.Println() +} + +func authStatus(token string) string { + if token == "" { + return "disabled" + } + return fmt.Sprintf("enabled (token: %s...)", token[:min(8, len(token))]) +} + +func tlsStatus(enabled bool, caFile string) string { + if !enabled { + return "disabled (HTTP)" + } + if caFile != "" { + return "mTLS enabled (client cert required)" + } + return "TLS enabled (HTTPS)" +} diff --git a/examples/full_server/pii.go b/examples/full_server/pii.go new file mode 100644 index 0000000..1f29323 --- /dev/null +++ b/examples/full_server/pii.go @@ -0,0 +1,255 @@ +package main + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +// ============================================================================= +// PII Detector / Redactor +// ============================================================================= + +// PIIPattern represents a named PII detection pattern. +type PIIPattern struct { + Name string + Pattern *regexp.Regexp +} + +// PIIDetector detects and redacts PII from text and structured data. +type PIIDetector struct { + // patterns is an ordered slice - order matters because longer patterns + // (like credit cards) must be matched before shorter ones (like phone numbers) + // to avoid partial matches. + patterns []PIIPattern +} + +// NewPIIDetector creates a new PIIDetector with compiled regex patterns. +// Patterns are ordered so that longer/more-specific patterns (credit_card) +// are applied before shorter/more-general ones (phone_number). 
+func NewPIIDetector() *PIIDetector { + return &PIIDetector{ + patterns: []PIIPattern{ + {Name: "email", Pattern: regexp.MustCompile(`[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`)}, + {Name: "ip_address", Pattern: regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)}, + {Name: "credit_card", Pattern: regexp.MustCompile(`\b(?:\d{4}[-\s]?){3}\d{4}\b`)}, + {Name: "ssn", Pattern: regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`)}, + {Name: "phone_number", Pattern: regexp.MustCompile(`(?:\+?1[-.\s]?)?\(?[2-9]\d{2}\)?[-.\s]?\d{3}[-.\s]?\d{4}`)}, + {Name: "date_of_birth", Pattern: regexp.MustCompile(`\b(?:\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4}|\d{4}[/\-]\d{1,2}[/\-]\d{1,2})\b`)}, + }, + } +} + +// PIIMatch represents a detected PII instance. +type PIIMatch struct { + Type string `json:"type"` + Value string `json:"value"` + StartIndex int `json:"start_index"` + EndIndex int `json:"end_index"` +} + +// replacementForType returns the replacement string for a PII type. +func replacementForType(piiType string) string { + switch piiType { + case "email": + return "[EMAIL_REDACTED]" + case "ip_address": + return "[IP_REDACTED]" + case "ssn": + return "[SSN_REDACTED]" + case "phone_number": + return "[PHONE_REDACTED]" + case "credit_card": + return "[CC_REDACTED]" + case "date_of_birth": + return "[DOB_REDACTED]" + default: + return "[REDACTED]" + } +} + +// DetectPII scans text for PII and returns all matches. 
+func (d *PIIDetector) DetectPII(text string, cfg *PIIConfig) []PIIMatch { + var matches []PIIMatch + + enabledTypes := d.getEnabledTypes(cfg) + + for _, p := range d.patterns { + if !enabledTypes[p.Name] { + continue + } + locs := p.Pattern.FindAllStringIndex(text, -1) + for _, loc := range locs { + matches = append(matches, PIIMatch{ + Type: p.Name, + Value: text[loc[0]:loc[1]], + StartIndex: loc[0], + EndIndex: loc[1], + }) + } + } + + // Check custom patterns + for _, cp := range cfg.CustomPatterns { + re, err := regexp.Compile(cp.Pattern) + if err != nil { + continue + } + locs := re.FindAllStringIndex(text, -1) + for _, loc := range locs { + matches = append(matches, PIIMatch{ + Type: cp.Name, + Value: text[loc[0]:loc[1]], + StartIndex: loc[0], + EndIndex: loc[1], + }) + } + } + + return matches +} + +// RedactText replaces PII in text with redaction markers. +func (d *PIIDetector) RedactText(text string, cfg *PIIConfig) (string, []PIIMatch) { + var allMatches []PIIMatch + result := text + + enabledTypes := d.getEnabledTypes(cfg) + + for _, p := range d.patterns { + if !enabledTypes[p.Name] { + continue + } + found := p.Pattern.FindAllString(result, -1) + for _, f := range found { + allMatches = append(allMatches, PIIMatch{Type: p.Name, Value: f}) + } + replacement := replacementForType(p.Name) + result = p.Pattern.ReplaceAllString(result, replacement) + } + + // Apply custom patterns + for _, cp := range cfg.CustomPatterns { + re, err := regexp.Compile(cp.Pattern) + if err != nil { + continue + } + found := re.FindAllString(result, -1) + for _, f := range found { + allMatches = append(allMatches, PIIMatch{Type: cp.Name, Value: f}) + } + replacement := cp.Replacement + if replacement == "" { + replacement = fmt.Sprintf("[%s_REDACTED]", strings.ToUpper(cp.Name)) + } + result = re.ReplaceAllString(result, replacement) + } + + return result, allMatches +} + +// ContainsPII checks if text contains any PII. 
+func (d *PIIDetector) ContainsPII(text string, cfg *PIIConfig) bool { + enabledTypes := d.getEnabledTypes(cfg) + + for _, p := range d.patterns { + if !enabledTypes[p.Name] { + continue + } + if p.Pattern.MatchString(text) { + return true + } + } + + for _, cp := range cfg.CustomPatterns { + re, err := regexp.Compile(cp.Pattern) + if err != nil { + continue + } + if re.MatchString(text) { + return true + } + } + + return false +} + +// RedactMap recursively redacts PII from a map of string to interface{}. +func (d *PIIDetector) RedactMap(data map[string]interface{}, cfg *PIIConfig) (map[string]interface{}, []PIIMatch) { + var allMatches []PIIMatch + result := make(map[string]interface{}, len(data)) + + for k, v := range data { + redacted, matches := d.redactValue(v, cfg) + result[k] = redacted + allMatches = append(allMatches, matches...) + } + + return result, allMatches +} + +// redactValue recursively redacts PII from any value. +func (d *PIIDetector) redactValue(v interface{}, cfg *PIIConfig) (interface{}, []PIIMatch) { + switch val := v.(type) { + case string: + redacted, matches := d.RedactText(val, cfg) + return redacted, matches + case map[string]interface{}: + redacted, matches := d.RedactMap(val, cfg) + return redacted, matches + case []interface{}: + var allMatches []PIIMatch + result := make([]interface{}, len(val)) + for i, item := range val { + redacted, matches := d.redactValue(item, cfg) + result[i] = redacted + allMatches = append(allMatches, matches...) + } + return result, allMatches + case json.Number: + return val, nil + default: + // For numbers, booleans, nil - return as-is + return v, nil + } +} + +// CheckMapForPII checks if any value in a map contains PII. +func (d *PIIDetector) CheckMapForPII(data map[string]interface{}, cfg *PIIConfig) bool { + for _, v := range data { + if d.checkValueForPII(v, cfg) { + return true + } + } + return false +} + +// checkValueForPII recursively checks if a value contains PII. 
+func (d *PIIDetector) checkValueForPII(v interface{}, cfg *PIIConfig) bool { + switch val := v.(type) { + case string: + return d.ContainsPII(val, cfg) + case map[string]interface{}: + return d.CheckMapForPII(val, cfg) + case []interface{}: + for _, item := range val { + if d.checkValueForPII(item, cfg) { + return true + } + } + } + return false +} + +// getEnabledTypes returns a map of enabled PII type names. +func (d *PIIDetector) getEnabledTypes(cfg *PIIConfig) map[string]bool { + return map[string]bool{ + "email": cfg.Types.Email, + "ip_address": cfg.Types.IPAddress, + "ssn": cfg.Types.SSN, + "phone_number": cfg.Types.PhoneNumber, + "credit_card": cfg.Types.CreditCard, + "date_of_birth": cfg.Types.DateOfBirth, + } +} diff --git a/examples/full_server/static/index.html b/examples/full_server/static/index.html new file mode 100644 index 0000000..4c9d1c3 --- /dev/null +++ b/examples/full_server/static/index.html @@ -0,0 +1,1224 @@ + + + + + +Hook Server Dashboard + + + + + + + + +
+ + + + + +
+ + +
+ +
+
+
Access Rules
+
0
+
Default: allow
+
+
+
Pre-Execution Rules
+
0
+
Default: proceed
+
+
+
Post-Execution Rules
+
0
+
Default: proceed
+
+
+
PII Redaction
+
Off
+
--
+
+
+
A/B Experiments
+
0
+
Disabled
+
+
+
Request Log
+
0
+
Total requests logged
+
+
+
+
+

Quick Actions

+
+
+ + + + + +
+
+
+ + +
+ +
+
+

Default Action

+
+
+ + +
+
+
+
+

Access Rules

+ +
+ + + +
User IDToolkitToolActionReason
+
+
+ + +
+ +
+
+

Default Action

+
+
+ + +
+
+
+
+

Rules

+ +
+ + + +
UserToolkitToolInput MatchActionMessage
+
+
+ + +
+ +
+
+

Default Action

+
+
+ + +
+
+
+
+

Rules

+ +
+ + + +
UserToolkitToolOutput MatchActionMessage
+
+
+ + +
+ +
+
+

PII Protection

+
+
+ + Enable PII Detection +
+
+ + +
+
+
+
+

PII Types

+
+
+
+ + Email Addresses +
+
+ + IP Addresses +
+
+ + Social Security Numbers +
+
+ + Phone Numbers +
+
+ + Credit Card Numbers +
+
+ + Dates of Birth +
+
+
+
+
+

Custom Patterns

+ +
+ + + +
NamePattern (Regex)Replacement
+
+
+
+

Test PII Redaction

+
+
+ + +
+ +
+
+
+ + +
+ +
+

A/B Testing

+
+ + Enable A/B Testing +
+
+
+
+

Experiments

+ +
+
+
+
+ + +
+ +
+

API Configuration

+
+ + Enable Tool Fetching +
+
+
+ + +
+
+ + +
+
+
+ + +
+ +
+
+
+

Fetched Tools

+ +
+ + + +
NameToolkitVersionDescription
+
+
+ + +
+ +
+
+

Logs

+
+ + +
+
+
+
+
+ + +
+ +
+
+

Configuration

+
+ + +
+
+ +
+
+
+ + + + diff --git a/examples/pii_redactor/README.md b/examples/pii_redactor/README.md new file mode 100644 index 0000000..fb116e3 --- /dev/null +++ b/examples/pii_redactor/README.md @@ -0,0 +1,132 @@ +# PII Redactor Example + +A minimal CATE webhook server that demonstrates how to detect and redact personally identifiable information (PII) from tool outputs using the post-execution hook. + +## What It Shows + +- **Post-execution hook**: Scan tool output for PII patterns and either redact or block +- **Output modification**: Replace sensitive data with redaction markers using the `override` mechanism +- **Two modes**: Redact (replace PII) or Block (reject the entire response) + +## Quick Start + +```bash +# Redact mode (default) - replace PII with markers +go run ./examples/pii_redactor -port 8889 + +# Block mode - reject responses containing PII +go run ./examples/pii_redactor -port 8889 -mode block +``` + +## Supported PII Types + +| Type | Example | Replacement | +| -------------- | -------------------- | ------------------ | +| Email | john@example.com | [EMAIL_REDACTED] | +| IP Address | 192.168.1.1 | [IP_REDACTED] | +| SSN | 123-45-6789 | [SSN_REDACTED] | +| Phone Number | (555) 123-4567 | [PHONE_REDACTED] | +| Credit Card | 4111-1111-1111-1111 | [CC_REDACTED] | +| Date of Birth | 01/15/1990 | [DOB_REDACTED] | + +## Testing + +### Redact Mode + +The post-hook replaces PII with markers and returns the modified output: + +```bash +curl -X POST http://localhost:8889/post \ + -H "Content-Type: application/json" \ + -d '{ + "execution_id": "exec-1", + "tool": {"name": "search", "toolkit": "Contacts", "version": "1.0"}, + "output": { + "name": "John Doe", + "email": "john@example.com", + "phone": "(555) 123-4567", + "ip": "Connected from 192.168.1.100", + "ssn": "SSN: 123-45-6789", + "notes": "Born on 01/15/1990, card ending 4111-1111-1111-1111" + }, + "server": {"name": "contacts", "uri": "http://contacts:8080", "type": "arcade"}, + "context": {"user_id": 
"user-1"} + }' +``` + +Expected response (redact mode): +```json +{ + "code": "OK", + "override": { + "output": { + "name": "John Doe", + "email": "[EMAIL_REDACTED]", + "phone": "[PHONE_REDACTED]", + "ip": "Connected from [IP_REDACTED]", + "ssn": "SSN: [SSN_REDACTED]", + "notes": "Born on [DOB_REDACTED], card ending [CC_REDACTED]" + } + } +} +``` + +### Block Mode + +When PII is found, the entire response is blocked: + +```bash +# Start in block mode +go run ./examples/pii_redactor -port 8889 -mode block + +# This response will be blocked because it contains an email +curl -X POST http://localhost:8889/post \ + -H "Content-Type: application/json" \ + -d '{ + "execution_id": "exec-2", + "tool": {"name": "lookup", "toolkit": "HR", "version": "1.0"}, + "output": {"result": "Contact: admin@company.com"}, + "server": {"name": "hr", "uri": "http://hr:8080", "type": "arcade"}, + "context": {"user_id": "user-1"} + }' +``` + +Expected response (block mode): +```json +{ + "code": "CHECK_FAILED", + "error_message": "Response blocked: contains personally identifiable information" +} +``` + +### No PII (Pass Through) + +When no PII is detected, the response passes through unchanged: + +```bash +curl -X POST http://localhost:8889/post \ + -H "Content-Type: application/json" \ + -d '{ + "execution_id": "exec-3", + "tool": {"name": "calculate", "toolkit": "Math", "version": "1.0"}, + "output": {"result": 42, "formula": "6 * 7"}, + "server": {"name": "math", "uri": "http://math:8080", "type": "arcade"}, + "context": {"user_id": "user-1"} + }' +``` + +Expected response: +```json +{ + "code": "OK" +} +``` + +## How It Works + +1. Tool executes and produces output +2. Engine sends the output to this post-hook server +3. Server scans all string values (recursively) for PII patterns +4. In **redact** mode: PII is replaced with `[TYPE_REDACTED]` markers using the `override.output` field +5. In **block** mode: If PII is found, returns `CHECK_FAILED` to reject the response entirely +6. 
// =============================================================================
// PII Detection
// =============================================================================

// PIIType describes one category of PII: the pattern that detects it and the
// marker that replaces every match.
type PIIType struct {
	Name        string         // short identifier, e.g. "email"
	Pattern     *regexp.Regexp // expression matching a single occurrence
	Replacement string         // literal marker substituted for each match
	Enabled     bool           // types with Enabled == false are skipped
}

// PIIRedactor detects and redacts PII from text and structured data.
type PIIRedactor struct {
	types []*PIIType
	mode  string // "redact" or "block"
}

// NewPIIRedactor creates a redactor with all standard PII types enabled.
// The mode string is stored as given; it is not validated here.
func NewPIIRedactor(mode string) *PIIRedactor {
	// Patterns are applied in slice order by RedactText.
	// Order matters: credit_card must be matched before phone_number
	// to avoid partial matches on credit card digits.
	return &PIIRedactor{
		mode: mode,
		types: []*PIIType{
			{
				Name:        "email",
				Pattern:     regexp.MustCompile(`[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`),
				Replacement: "[EMAIL_REDACTED]",
				Enabled:     true,
			},
			{
				Name:        "ip_address",
				Pattern:     regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`),
				Replacement: "[IP_REDACTED]",
				Enabled:     true,
			},
			{
				Name:        "credit_card",
				Pattern:     regexp.MustCompile(`\b(?:\d{4}[-\s]?){3}\d{4}\b`),
				Replacement: "[CC_REDACTED]",
				Enabled:     true,
			},
			{
				Name:        "ssn",
				Pattern:     regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`),
				Replacement: "[SSN_REDACTED]",
				Enabled:     true,
			},
			{
				Name:        "phone_number",
				Pattern:     regexp.MustCompile(`(?:\+?1[-.\s]?)?\(?[2-9]\d{2}\)?[-.\s]?\d{3}[-.\s]?\d{4}`),
				Replacement: "[PHONE_REDACTED]",
				Enabled:     true,
			},
			{
				Name:        "date_of_birth",
				Pattern:     regexp.MustCompile(`\b(?:\d{1,2}[/\-]\d{1,2}[/\-]\d{2,4}|\d{4}[/\-]\d{1,2}[/\-]\d{1,2})\b`),
				Replacement: "[DOB_REDACTED]",
				Enabled:     true,
			},
		},
	}
}

// RedactText scans text for PII and replaces it with redaction markers.
//
// Every enabled type is applied in order to the output of the previous one.
// It returns the redacted text and the total number of replacements made.
func (r *PIIRedactor) RedactText(text string) (string, int) {
	result := text
	count := 0
	for _, t := range r.types {
		if !t.Enabled {
			continue
		}
		// One pass both counts and replaces. Returning the marker from the
		// callback inserts it literally, so a Replacement containing '$' is
		// never interpreted as a regexp expansion template.
		result = t.Pattern.ReplaceAllStringFunc(result, func(string) string {
			count++
			return t.Replacement
		})
	}
	return result, count
}

// ContainsPII reports whether any enabled pattern matches text.
func (r *PIIRedactor) ContainsPII(text string) bool {
	for _, t := range r.types {
		if t.Enabled && t.Pattern.MatchString(text) {
			return true
		}
	}
	return false
}

// RedactMap recursively scans a map and redacts PII from all string values.
+func (r *PIIRedactor) RedactMap(data map[string]interface{}) (map[string]interface{}, int) { + totalCount := 0 + result := make(map[string]interface{}, len(data)) + for k, v := range data { + redacted, count := r.redactValue(v) + result[k] = redacted + totalCount += count + } + return result, totalCount +} + +func (r *PIIRedactor) redactValue(v interface{}) (interface{}, int) { + switch val := v.(type) { + case string: + redacted, count := r.RedactText(val) + return redacted, count + case map[string]interface{}: + return r.RedactMap(val) + case []interface{}: + total := 0 + result := make([]interface{}, len(val)) + for i, item := range val { + redacted, count := r.redactValue(item) + result[i] = redacted + total += count + } + return result, total + default: + return v, 0 + } +} + +// CheckMapForPII checks if any value in a map contains PII. +func (r *PIIRedactor) CheckMapForPII(data map[string]interface{}) bool { + for _, v := range data { + if r.checkValue(v) { + return true + } + } + return false +} + +func (r *PIIRedactor) checkValue(v interface{}) bool { + switch val := v.(type) { + case string: + return r.ContainsPII(val) + case map[string]interface{}: + return r.CheckMapForPII(val) + case []interface{}: + for _, item := range val { + if r.checkValue(item) { + return true + } + } + } + return false +} + +// ============================================================================= +// PII Redactor Server +// ============================================================================= + +// RedactorServer implements the CATE webhook ServerInterface. +type RedactorServer struct { + redactor *PIIRedactor +} + +// HealthCheck implements webhook.ServerInterface. +func (s *RedactorServer) HealthCheck(c *gin.Context) { + status := server.Healthy + c.JSON(http.StatusOK, server.HealthResponse{Status: &status}) +} + +// AccessHook passes through - PII redaction doesn't affect tool visibility. 
+func (s *RedactorServer) AccessHook(c *gin.Context) { + // Allow all tools - PII redaction only applies to outputs + c.JSON(http.StatusOK, server.AccessHookResult{}) +} + +// PreHook passes through - PII redaction only applies to outputs. +func (s *RedactorServer) PreHook(c *gin.Context) { + // Allow all requests to proceed + c.JSON(http.StatusOK, server.PreHookResult{Code: server.OK}) +} + +// PostHook scans tool output for PII and redacts or blocks it. +// This is where the PII redaction happens - after the tool executes, +// before the response reaches the agent. +func (s *RedactorServer) PostHook(c *gin.Context) { + var req server.PostHookRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, server.ErrorResponse{ + Error: strPtr("invalid request: " + err.Error()), + }) + return + } + + if s.redactor.mode == "block" { + // Block mode: reject the entire response if PII is found + if s.redactor.CheckMapForPII(req.Output) { + msg := "Response blocked: contains personally identifiable information" + log.Printf("[POST] Blocked PII in output from %s.%s", + req.Tool.Toolkit, req.Tool.Name) + c.JSON(http.StatusOK, server.PostHookResult{ + Code: server.CHECKFAILED, + ErrorMessage: &msg, + }) + return + } + } else { + // Redact mode: replace PII with placeholders + redactedOutput, count := s.redactor.RedactMap(req.Output) + if count > 0 { + log.Printf("[POST] Redacted %d PII instances in output from %s.%s", + count, req.Tool.Toolkit, req.Tool.Name) + c.JSON(http.StatusOK, server.PostHookResult{ + Code: server.OK, + Override: &server.PostHookOverride{ + Output: &redactedOutput, + }, + }) + return + } + } + + // No PII found, pass through + c.JSON(http.StatusOK, server.PostHookResult{Code: server.OK}) +} + +func strPtr(s string) *string { + return &s +} + +// ============================================================================= +// Main +// ============================================================================= + +func 
main() { + port := flag.Int("port", 8889, "Port to listen on") + mode := flag.String("mode", "redact", "PII handling mode: 'redact' or 'block'") + flag.Parse() + + redactor := NewPIIRedactor(*mode) + srv := &RedactorServer{redactor: redactor} + + gin.SetMode(gin.ReleaseMode) + router := gin.New() + router.Use(gin.Recovery()) + + server.RegisterHandlers(router, srv) + + fmt.Println(strings.Repeat("=", 50)) + fmt.Println(" PII Redactor Example") + fmt.Println(strings.Repeat("=", 50)) + fmt.Printf(" Port: %d\n", *port) + fmt.Printf(" Mode: %s\n", *mode) + fmt.Println() + fmt.Println(" This example demonstrates:") + fmt.Println(" - Detecting PII in tool outputs (post-hook)") + if *mode == "redact" { + fmt.Println(" - Replacing PII with redaction markers") + } else { + fmt.Println(" - Blocking responses that contain PII") + } + fmt.Println() + fmt.Println(" Detected PII types:") + fmt.Println(" - Email addresses -> [EMAIL_REDACTED]") + fmt.Println(" - IP addresses -> [IP_REDACTED]") + fmt.Println(" - SSN -> [SSN_REDACTED]") + fmt.Println(" - Phone numbers -> [PHONE_REDACTED]") + fmt.Println(" - Credit card numbers-> [CC_REDACTED]") + fmt.Println(" - Dates of birth -> [DOB_REDACTED]") + fmt.Println(strings.Repeat("=", 50)) + + addr := fmt.Sprintf(":%d", *port) + if err := router.Run(addr); err != nil { + log.Fatal("Failed to start server:", err) + } +}