Skip to content

Commit ea6809a

Browse files
yogesh-chauhanHarness
authored andcommitted
feat: [ML-1459]: Allow fetching tool name to tool set mapping (#241)
* 94652c Code format * a20409 Add test for pkg/modules/registry.go * ebbd2c add unit test * 845fef remove comment * 7eb3d2 add back fme * 5bc430 refactor * 88318b refactor: replace direct toolset registration with auto-registering group for module tracking
1 parent a38d53d commit ea6809a

File tree

6 files changed

+796
-35
lines changed

6 files changed

+796
-35
lines changed

cmd/harness-mcp-server/main.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,11 @@ func runHTTPServer(ctx context.Context, config config.Config) error {
536536
slog.Error("Failed to initialize toolsets", "error", err)
537537
}
538538

539+
// Create module registry for HTTP mode
540+
moduleRegistry := modules.NewModuleRegistry(&config, toolsets)
541+
// Set the global registry for use by middleware
542+
modules.SetGlobalRegistry(moduleRegistry)
543+
539544
// Register the tools with the server
540545
toolsets.RegisterTools(harnessServer)
541546

pkg/harness/tools.go

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -145,38 +145,13 @@ func initLicenseValidation(ctx context.Context, config *config.Config) (*License
145145

146146
// InitToolsets initializes and returns the toolset groups
147147
func InitToolsets(ctx context.Context, config *config.Config) (*toolsets.ToolsetGroup, error) {
148-
// Create a toolset group
148+
// Create a toolset tsg
149149
tsg := toolsets.NewToolsetGroup(config.ReadOnly)
150150

151151
// Initialize license validation if enabled
152-
var licenseInfo *LicenseInfo
153-
var err error
154152
if config.EnableLicense {
155-
licenseInfo, err = initLicenseValidation(ctx, config)
156-
if err != nil {
157-
slog.Warn("License validation failed", "error", err)
158-
return nil, fmt.Errorf("failed to fetch license details, error: %w", err)
159-
}
160-
161-
// Create a module registry
162-
registry := modules.NewModuleRegistry(config, tsg)
163-
164-
// Get all modules that are enabled based on configuration
165-
configEnabledModules := registry.GetEnabledModules()
166-
// Get enabled modules based on configuration and license
167-
168-
enabledModules := getEnabledModules(configEnabledModules, licenseInfo)
169-
// Register toolsets for enabled modules
170-
for _, module := range enabledModules {
171-
slog.Info("registering toolsets for", "modules", module.ID())
172-
if err := module.RegisterToolsets(); err != nil {
173-
return nil, fmt.Errorf("failed to register toolsets for module %s: %w", module.ID(), err)
174-
}
175-
176-
// Enable toolsets for this module
177-
if err := module.EnableToolsets(tsg); err != nil {
178-
return nil, fmt.Errorf("failed to enable toolsets for module %s: %w", module.ID(), err)
179-
}
153+
if err := initModuleBasedToolsets(ctx, config, tsg); err != nil {
154+
return nil, err
180155
}
181156
} else {
182157
// License validation is disabled, use legacy toolset registration
@@ -185,9 +160,57 @@ func InitToolsets(ctx context.Context, config *config.Config) (*toolsets.Toolset
185160
}
186161
}
187162

163+
// Register all toolsets with the main tracker for tool-to-toolset mapping
164+
registerAllToolsetsWithTracker(tsg)
165+
188166
return tsg, nil
189167
}
190168

169+
// registerAllToolsetsWithTracker registers all toolsets in the group with the main tracker
170+
// This ensures that findToolGroup() can find which toolset a tool belongs to
171+
func registerAllToolsetsWithTracker(group *toolsets.ToolsetGroup) {
172+
tracker := toolsets.GetMainToolTracker()
173+
174+
// Register each toolset with the tracker
175+
for _, toolset := range group.Toolsets {
176+
if err := tracker.RegisterToolGroup(toolset); err != nil {
177+
slog.Warn("Failed to register toolset with tracker", "toolset", toolset.Name, "error", err)
178+
}
179+
}
180+
181+
slog.Info("Registered toolsets with tracker", "count", len(group.Toolsets))
182+
}
183+
184+
// New function that handles the module-based initialization
185+
func initModuleBasedToolsets(ctx context.Context, config *config.Config, tsg *toolsets.ToolsetGroup) error {
186+
// Get license info
187+
licenseInfo, err := initLicenseValidation(ctx, config)
188+
if err != nil {
189+
slog.Warn("License validation failed", "error", err)
190+
return fmt.Errorf("failed to fetch license details, error: %w", err)
191+
}
192+
193+
// Create module registry and get enabled modules
194+
registry := modules.NewModuleRegistry(config, tsg)
195+
configEnabledModules := registry.GetEnabledModules()
196+
enabledModules := getEnabledModules(configEnabledModules, licenseInfo)
197+
198+
// Register and enable toolsets for each module
199+
for _, module := range enabledModules {
200+
slog.Info("registering toolsets for", "modules", module.ID())
201+
if err := module.RegisterToolsets(); err != nil {
202+
return fmt.Errorf("failed to register toolsets for module %s: %w", module.ID(), err)
203+
}
204+
205+
if err := module.EnableToolsets(tsg); err != nil {
206+
return fmt.Errorf("failed to enable toolsets for module %s: %w", module.ID(), err)
207+
}
208+
}
209+
210+
return nil
211+
}
212+
213+
// initLegacyToolsets initializes toolsets using the legacy approach (without modules)
191214
func initLegacyToolsets(config *config.Config, tsg *toolsets.ToolsetGroup) error {
192215
// Check if specific toolsets are enabled
193216
if len(config.Toolsets) == 0 {

pkg/modules/registry.go

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@ import (
1515

1616
// ModuleRegistry holds all available modules
1717
type ModuleRegistry struct {
18-
modules []Module
19-
config *config.Config
20-
tsg *toolsets.ToolsetGroup
18+
modules []Module
19+
config *config.Config
20+
tsg *toolsets.ToolsetGroup
21+
moduleToToolsets map[string][]string // Maps module IDs to their toolsets
2122
}
2223

2324
// NewModuleRegistry creates a new module registry with all available modules
2425
func NewModuleRegistry(config *config.Config, tsg *toolsets.ToolsetGroup) *ModuleRegistry {
25-
return &ModuleRegistry{
26+
registry := &ModuleRegistry{
2627
modules: []Module{
2728
NewCoreModule(config, tsg),
2829
NewCIModule(config, tsg),
@@ -38,11 +39,18 @@ func NewModuleRegistry(config *config.Config, tsg *toolsets.ToolsetGroup) *Modul
3839
NewHARModule(config, tsg),
3940
NewDbOpsModule(config, tsg),
4041
NewACMModule(config, tsg),
41-
NewFMEModule(config, tsg),
4242
},
43-
config: config,
44-
tsg: tsg,
43+
config: config,
44+
tsg: tsg,
45+
moduleToToolsets: make(map[string][]string),
4546
}
47+
48+
// Populate the moduleToToolsets mapping
49+
for _, module := range registry.modules {
50+
registry.moduleToToolsets[module.ID()] = module.Toolsets()
51+
}
52+
53+
return registry
4654
}
4755

4856
// GetAllModules returns all available modules
@@ -86,6 +94,19 @@ func (r *ModuleRegistry) GetEnabledModules() []Module {
8694
return enabledModules
8795
}
8896

97+
// GetToolsetsForModule returns the toolsets associated with a module ID
98+
func (r *ModuleRegistry) GetToolsetsForModule(moduleID string) []string {
99+
if toolsets, exists := r.moduleToToolsets[moduleID]; exists {
100+
return toolsets
101+
}
102+
return []string{}
103+
}
104+
105+
// GetToolsetGroup returns the underlying toolset group
106+
func (r *ModuleRegistry) GetToolsetGroup() *toolsets.ToolsetGroup {
107+
return r.tsg
108+
}
109+
89110
// RegisterPrompts registers all prompts for enabled modules with the given MCP server.
90111
// It loops through all enabled modules, checks if each module has prompts, and if so,
91112
// registers the prompts with the MCP server.
@@ -103,6 +124,19 @@ func (r *ModuleRegistry) RegisterPrompts(mcpServer *server.MCPServer) error {
103124
return nil
104125
}
105126

127+
// Global registry instance
128+
var globalRegistry *ModuleRegistry
129+
130+
// SetGlobalRegistry sets the global module registry instance
131+
func SetGlobalRegistry(registry *ModuleRegistry) {
132+
globalRegistry = registry
133+
}
134+
135+
// GetGlobalRegistry returns the global module registry instance
136+
func GetGlobalRegistry() *ModuleRegistry {
137+
return globalRegistry
138+
}
139+
106140
func registerPrompts(moduleID string, cfg *config.Config, mcpServer *server.MCPServer) error {
107141
// Create a map to store prompts by mode
108142
modulePromptsByMode := map[string][]prompts.PromptFile{

0 commit comments

Comments
 (0)