From f90782c74e26819e7d6afa32a99ada65d68d7a50 Mon Sep 17 00:00:00 2001 From: Omri Shtamberger Date: Sun, 20 Jul 2025 14:06:03 +0300 Subject: [PATCH 1/2] Add text_search and get_index_dialect tools --- src/common/config.py | 2 - src/tools/hash.py | 2 - src/tools/redis_query_engine.py | 242 ++++++++++++++++++++++++++++++++ 3 files changed, 242 insertions(+), 4 deletions(-) diff --git a/src/common/config.py b/src/common/config.py index 134ca8a..c3c606c 100644 --- a/src/common/config.py +++ b/src/common/config.py @@ -1,5 +1,3 @@ -import sys - from dotenv import load_dotenv import os import urllib.parse diff --git a/src/tools/hash.py b/src/tools/hash.py index 7f30a5a..49586af 100644 --- a/src/tools/hash.py +++ b/src/tools/hash.py @@ -1,5 +1,3 @@ -import sys - from src.common.connection import RedisConnectionManager from redis.exceptions import RedisError from src.common.server import mcp diff --git a/src/tools/redis_query_engine.py b/src/tools/redis_query_engine.py index d7a6ae5..d845fc6 100644 --- a/src/tools/redis_query_engine.py +++ b/src/tools/redis_query_engine.py @@ -1,4 +1,5 @@ import json +from typing import Optional from src.common.connection import RedisConnectionManager from redis.exceptions import RedisError from src.common.server import mcp @@ -136,3 +137,244 @@ async def vector_search_hash(query_vector: list, return [doc.__dict__ for doc in results.docs] except RedisError as e: return f"Error performing vector search on index '{index_name}': {str(e)}" + + +def _get_index_dialect(r, index_name: str) -> int: + """ + Get the dialect version of a Redis search index. + + Args: + r: Redis connection + index_name: Name of the index + + Returns: + int: Dialect version (1, 2, 3, or 4). Defaults to 1 if not detected. + """ + try: + # Get index info + info = r.ft(index_name).info() + + # Check if dialect is specified in index info + if 'dialect' in info: + return int(info['dialect']) + + # Try to detect dialect by checking for dialect-specific features + # This is a fallback method when dialect isn't explicitly stated + + # Check for dialect 4 features (introduced in Redis 7.2) + # Dialect 4 supports more advanced vector search features + if 'vector_fields' in info or any('VECTOR' in str(field) for field in info.get('attributes', [])): + return 4 + + # Check for dialect 3 features (introduced in Redis 7.0) + # Dialect 3 supports JSON path queries + if any('JSONPath' in str(field) for field in info.get('attributes', [])): + return 3 + + # Check for dialect 2 features (default in Redis 6.2+) + # Dialect 2 supports more query operators and syntax + if 'stopwords' in info or 'max_text_fields' in info: + return 2 + + # Default to dialect 1 (legacy) + return 1 + + except Exception: + # If we can't determine dialect, default to 1 (most compatible) + return 1 + + +@mcp.tool() +async def text_search(query_text: str, + index_name: str, + return_fields: list = None, + limit: int = 10, + offset: int = 0, + sort_by: str = None, + sort_ascending: bool = True, + dialect: int = None) -> str: + """ + Perform a general text search using Redis FT.SEARCH command with automatic dialect detection. + + This function allows you to search through indexed text fields using Redis Search. + It automatically detects the dialect of the index and adjusts the query accordingly. + RediSearch supports different dialects (1, 2, 3, 4) with varying query syntax and capabilities. + + Args: + query_text: The search query string. Syntax depends on dialect: + + Dialect 1 (Legacy): + - Simple terms: "hello world" + - Field search: "@title:redis" + - Phrase search: "\"exact phrase\"" + - Boolean: "redis search" (implicit AND) + + Dialect 2+ (Modern): + - All dialect 1 features plus: + - Explicit boolean: "redis AND search", "redis OR query" + - Negation: "redis -search" + - Wildcards: "hel*", "red?" + - Numeric ranges: "@price:[10 20]" + - Geo queries: "@location:[lng lat radius unit]" + + Dialect 3+ (JSON support): + - JSONPath queries: "@$.user.name:john" + + Dialect 4+ (Advanced vectors): + - Enhanced vector search syntax + + index_name: The name of the Redis search index to query against. + return_fields: Optional list of fields to return in results. If None, returns all fields. + limit: Maximum number of results to return (default: 10). + offset: Number of results to skip for pagination (default: 0). + sort_by: Optional field name to sort results by. + sort_ascending: Sort direction, True for ascending, False for descending (default: True). + dialect: Optional explicit dialect to use. If None, will auto-detect from index. + + Returns: + str: JSON string containing search results with document data, metadata, and dialect info, or error message. + + Example queries by dialect: + Dialect 1: "redis database", "@title:redis" + Dialect 2: "redis AND search", "@price:[10 50]", "red*" + Dialect 3: "@$.user.name:john AND @$.status:active" + Dialect 4: Enhanced vector and hybrid search queries + """ + try: + r = RedisConnectionManager.get_connection() + + # Determine dialect + if dialect is None: + detected_dialect = _get_index_dialect(r, index_name) + else: + detected_dialect = dialect + + # Build the query + query = Query(query_text) + + # Set dialect for the query + query = query.dialect(detected_dialect) + + # Add pagination + query = query.paging(offset, limit) + + # Add return fields if specified + if return_fields: + query = query.return_fields(*return_fields) + + # Add sorting if specified + if sort_by: + query = query.sort_by(sort_by, asc=sort_ascending) + + # Execute the search + results = r.ft(index_name).search(query) + + # Format the results + formatted_results = { + "total": results.total, + "docs": [doc.__dict__ for doc in results.docs], + "query": query_text, + "dialect": detected_dialect, + "offset": offset, + "limit": limit, + "index_name": index_name + } + + return json.dumps(formatted_results, indent=2) + + except RedisError as e: + return f"Error performing text search on index '{index_name}': {str(e)}" + + +@mcp.tool() +async def get_index_dialect(index_name: str) -> str: + """ + Get the dialect version of a Redis search index. + + RediSearch supports different dialects with varying capabilities: + - Dialect 1 (Legacy): Basic text search, field queries, phrase search + - Dialect 2 (Modern): Boolean operators, wildcards, numeric ranges, geo queries + - Dialect 3 (JSON): JSONPath queries, enhanced JSON support + - Dialect 4 (Advanced): Enhanced vector search, latest features + + Args: + index_name: The name of the Redis search index. + + Returns: + str: JSON string containing dialect information and capabilities, or error message. + """ + try: + r = RedisConnectionManager.get_connection() + + # Get the dialect + dialect = _get_index_dialect(r, index_name) + + # Get index info for additional context + info = r.ft(index_name).info() + + # Define capabilities by dialect + capabilities = { + 1: { + "description": "Legacy dialect - basic text search", + "features": [ + "Simple term search", + "Field-specific search (@field:value)", + "Phrase search (\"exact phrase\")", + "Implicit AND between terms" + ] + }, + 2: { + "description": "Modern dialect - enhanced query syntax", + "features": [ + "All dialect 1 features", + "Explicit boolean operators (AND, OR, NOT)", + "Negation (-term)", + "Wildcards (*, ?)", + "Numeric ranges (@field:[min max])", + "Geo queries (@location:[lng lat radius unit])", + "Parentheses for grouping" + ] + }, + 3: { + "description": "JSON dialect - JSONPath support", + "features": [ + "All dialect 2 features", + "JSONPath queries (@$.path:value)", + "Enhanced JSON field indexing", + "Nested object search" + ] + }, + 4: { + "description": "Advanced dialect - latest features", + "features": [ + "All dialect 3 features", + "Enhanced vector search syntax", + "Hybrid search capabilities", + "Latest RediSearch features" + ] + } + } + + result = { + "index_name": index_name, + "dialect": dialect, + "capabilities": capabilities.get(dialect, {"description": "Unknown dialect", "features": []}), + "index_info": { + "num_docs": info.get('num_docs', 0), + "max_doc_id": info.get('max_doc_id', 0), + "num_terms": info.get('num_terms', 0), + "num_records": info.get('num_records', 0), + "inverted_sz_mb": info.get('inverted_sz_mb', 0), + "vector_index_sz_mb": info.get('vector_index_sz_mb', 0), + "total_inverted_index_blocks": info.get('total_inverted_index_blocks', 0), + "offset_vectors_sz_mb": info.get('offset_vectors_sz_mb', 0), + "doc_table_size_mb": info.get('doc_table_size_mb', 0), + "sortable_values_size_mb": info.get('sortable_values_size_mb', 0), + "key_table_size_mb": info.get('key_table_size_mb', 0) + } + } + + return json.dumps(result, indent=2) + + except RedisError as e: + return f"Error getting dialect for index '{index_name}': {str(e)}" From 5a8fb81db71d87ac9804085bfff5f4e73acd1b80 Mon Sep 17 00:00:00 2001 From: chkp-omris Date: Wed, 8 Oct 2025 14:38:38 +0300 Subject: [PATCH 2/2] Streamable http (#2) * Implement streamable-http and add --transport flag with stdio or streamable-http * Multi hosts (#1) * Add connect, list_connections, disconnect, switch_default_connection tools to allow dynamically working with multiple redis deployments * Fix cluster mode * Ignore /{db} in --url when --cluster-mode is set. Unify connection config handling of CLI and coonection management. * Refactor config to class-based design and standardize connection management API with cluster_mode support * Fix cluster mode connection * Fix config building * Remove debug logs * Fix dbsize, scan_keys, scan_all_keys in cluster mode. Fix scan_all_keys, json_get, get_index_info, get_indexed_keys_number, hgetall, get_vector_from_hash decoding * Add missing import * Revert dbsize and scan operations special cluster mode handling. Fix client decoding * Auto detect cluster mode in connect tool * Correctly display db info in cluster mode * Fix JsonType --------- Co-authored-by: Omri Shtamberger --------- Co-authored-by: Omri Shtamberger --- README.md | 26 +- pyproject.toml | 4 +- src/__init__.py | 1 + src/common/__init__.py | 4 + src/common/config.py | 147 +++++++-- src/common/connection.py | 473 +++++++++++++++++++++++++---- src/common/server.py | 2 +- src/common/stdio_server.py | 11 + src/common/streaming_server.py | 18 ++ src/main.py | 89 +++--- src/tools/connection_management.py | 156 ++++++++++ src/tools/hash.py | 36 ++- src/tools/json.py | 13 +- src/tools/list.py | 89 ++++-- src/tools/misc.py | 39 ++- src/tools/redis_query_engine.py | 4 +- src/tools/server_management.py | 29 +- src/tools/set.py | 16 +- src/tools/string.py | 11 +- 19 files changed, 964 insertions(+), 204 deletions(-) create mode 100644 src/common/stdio_server.py create mode 100644 src/common/streaming_server.py create mode 100644 src/tools/connection_management.py diff --git a/README.md b/README.md index 7238c5d..aff1c5a 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ Additional tools. ## Installation -The Redis MCP Server supports the `stdio` [transport](https://modelcontextprotocol.io/docs/concepts/transports#standard-input%2Foutput-stdio). Support to the `stremable-http` transport will be added in the future. +The Redis MCP Server supports both `stdio` and `streamable-http` [transports](https://modelcontextprotocol.io/docs/concepts/transports). > No PyPi package is available at the moment. @@ -89,10 +89,26 @@ uvx --from git+https://github.com/redis/mcp-redis.git redis-mcp-server --url "re # Run with individual parameters uvx --from git+https://github.com/redis/mcp-redis.git redis-mcp-server --host localhost --port 6379 --password mypassword +# Run with streamable HTTP transport (default on http://127.0.0.1:8000/mcp) +uvx --from git+https://github.com/redis/mcp-redis.git redis-mcp-server --transport streamable-http --url redis://localhost:6379/0 + +# Run with streamable HTTP on custom host/port +uvx --from git+https://github.com/redis/mcp-redis.git redis-mcp-server --transport streamable-http --http-host 0.0.0.0 --http-port 8080 --url redis://localhost:6379/0 + # See all options uvx --from git+https://github.com/redis/mcp-redis.git redis-mcp-server --help ``` +### Running with Streamable HTTP + +```sh +# Development mode with streamable HTTP +uv run redis-mcp-server --transport streamable-http --url redis://localhost:6379/0 + +# Production mode with custom host and port +uv run redis-mcp-server --transport streamable-http --http-host 0.0.0.0 --http-port 8000 --url redis://localhost:6379/0 +``` + ### Development Installation For development or if you prefer to clone the repository: @@ -110,6 +126,12 @@ uv sync # Run with CLI interface uv run redis-mcp-server --help +# Run with stdio transport (default) +uv run src/main.py + +# Run with streamable HTTP transport +uv run src/main.py --transport streamable-http --http-host 127.0.0.1 --http-port 8000 + # Or run the main file directly (uses environment variables) uv run src/main.py ``` @@ -365,7 +387,7 @@ The procedure will create the proper configuration in the `claude_desktop_config ### VS Code with GitHub Copilot -To use the Redis MCP Server with VS Code, you must nable the [agent mode](https://code.visualstudio.com/docs/copilot/chat/chat-agent-mode) tools. Add the following to your `settings.json`: +To use the Redis MCP Server with VS Code, you must enable the [agent mode](https://code.visualstudio.com/docs/copilot/chat/chat-agent-mode) tools. Add the following to your `settings.json`: ```json { diff --git a/pyproject.toml b/pyproject.toml index 2264189..eee87d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,11 @@ classifiers = [ dependencies = [ "mcp[cli]>=1.9.4", "redis>=6.0.0", - "dotenv>=0.9.9", + "python-dotenv>=0.9.9", "numpy>=2.2.4", "click>=8.0.0", + "uvicorn>=0.23.0", + "starlette>=0.27.0", ] [project.scripts] diff --git a/src/__init__.py b/src/__init__.py index e69de29..358ca3f 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -0,0 +1 @@ +"""Redis MCP Server package.""" diff --git a/src/common/__init__.py b/src/common/__init__.py index e69de29..41d168a 100644 --- a/src/common/__init__.py +++ b/src/common/__init__.py @@ -0,0 +1,4 @@ +# Copyright Redis Contributors +# SPDX-License-Identifier: MIT + +"""Common utilities for Redis MCP Server.""" diff --git a/src/common/config.py b/src/common/config.py index c3c606c..446e325 100644 --- a/src/common/config.py +++ b/src/common/config.py @@ -4,18 +4,52 @@ load_dotenv() -REDIS_CFG = {"host": os.getenv('REDIS_HOST', '127.0.0.1'), - "port": int(os.getenv('REDIS_PORT',6379)), - "username": os.getenv('REDIS_USERNAME', None), - "password": os.getenv('REDIS_PWD',''), - "ssl": os.getenv('REDIS_SSL', False) in ('true', '1', 't'), - "ssl_ca_path": os.getenv('REDIS_SSL_CA_PATH', None), - "ssl_keyfile": os.getenv('REDIS_SSL_KEYFILE', None), - "ssl_certfile": os.getenv('REDIS_SSL_CERTFILE', None), - "ssl_cert_reqs": os.getenv('REDIS_SSL_CERT_REQS', 'required'), - "ssl_ca_certs": os.getenv('REDIS_SSL_CA_CERTS', None), - "cluster_mode": os.getenv('REDIS_CLUSTER_MODE', False) in ('true', '1', 't'), - "db": int(os.getenv('REDIS_DB', 0))} + +class RedisConfig: + """Redis configuration management class.""" + + def __init__(self): + self._config = { + "host": os.getenv('REDIS_HOST', '127.0.0.1'), + "port": int(os.getenv('REDIS_PORT', 6379)), + "username": os.getenv('REDIS_USERNAME', None), + "password": os.getenv('REDIS_PWD', ''), + "ssl": os.getenv('REDIS_SSL', False) in ('true', '1', 't'), + "ssl_ca_path": os.getenv('REDIS_SSL_CA_PATH', None), + "ssl_keyfile": os.getenv('REDIS_SSL_KEYFILE', None), + "ssl_certfile": os.getenv('REDIS_SSL_CERTFILE', None), + "ssl_cert_reqs": os.getenv('REDIS_SSL_CERT_REQS', 'required'), + "ssl_ca_certs": os.getenv('REDIS_SSL_CA_CERTS', None), + "cluster_mode": os.getenv('REDIS_CLUSTER_MODE', False) in ('true', '1', 't'), + "db": int(os.getenv('REDIS_DB', 0)) + } + + @property + def config(self) -> dict: + """Get the current configuration.""" + return self._config.copy() + + def get(self, key: str, default=None): + """Get a configuration value.""" + return self._config.get(key, default) + + def __getitem__(self, key: str): + """Get a configuration value using dictionary syntax.""" + return self._config[key] + + def update(self, config: dict): + """Update configuration from dictionary.""" + for key, value in config.items(): + if key in ['port', 'db']: + # Keep port and db as integers + self._config[key] = int(value) + elif key in ['ssl', 'cluster_mode']: + # Keep ssl and cluster_mode as booleans + self._config[key] = bool(value) + else: + # Store other values as-is + self._config[key] = value if value is not None else None + def parse_redis_uri(uri: str) -> dict: """Parse a Redis URI and return connection parameters.""" @@ -81,17 +115,78 @@ def parse_redis_uri(uri: str) -> dict: return config -def set_redis_config_from_cli(config: dict): - for key, value in config.items(): - if key in ['port', 'db']: - # Keep port and db as integers - REDIS_CFG[key] = int(value) - elif key == 'ssl' or key == 'cluster_mode': - # Keep ssl and cluster_mode as booleans - REDIS_CFG[key] = bool(value) - elif isinstance(value, bool): - # Convert other booleans to strings for environment compatibility - REDIS_CFG[key] = 'true' if value else 'false' - else: - # Convert other values to strings - REDIS_CFG[key] = str(value) if value is not None else None +def build_redis_config(url=None, host=None, port=None, db=None, username=None, + password=None, ssl=None, ssl_ca_path=None, ssl_keyfile=None, + ssl_certfile=None, ssl_cert_reqs=None, ssl_ca_certs=None, + cluster_mode=None, host_id=None): + """ + Build Redis configuration from URL or individual parameters. + Handles cluster mode conflicts and parameter validation. + + Returns: + dict: Redis configuration dictionary + str: Generated host_id if not provided + """ + # Parse configuration from URL or individual parameters + if url: + config = parse_redis_uri(url) + parsed_url = urllib.parse.urlparse(url) + # Generate host_id from URL if not provided + if host_id is None: + host_id = f"{parsed_url.hostname}:{parsed_url.port or 6379}" + else: + # Build config from individual parameters + config = { + "host": host or "127.0.0.1", + "port": port or 6379, + "db": db or 0, + "username": username, + "password": password or "", + "ssl": ssl or False, + "ssl_ca_path": ssl_ca_path, + "ssl_keyfile": ssl_keyfile, + "ssl_certfile": ssl_certfile, + "ssl_cert_reqs": ssl_cert_reqs or "required", + "ssl_ca_certs": ssl_ca_certs, + "cluster_mode": cluster_mode # Allow None for auto-detection + } + # Generate host_id from host:port if not provided + if host_id is None: + host_id = f"{config['host']}:{config['port']}" + + # Override individual parameters if provided (useful when using URL + specific overrides) + # Only override URL values if the parameter was explicitly specified + if url is None or (host is not None and host != "127.0.0.1"): + if host is not None: + config["host"] = host + if url is None or (port is not None and port != 6379): + if port is not None: + config["port"] = port + if url is None or (db is not None and db != 0): + if db is not None: + config["db"] = db + if username is not None: + config["username"] = username + if password is not None: + config["password"] = password + if ssl is not None: + config["ssl"] = ssl + if ssl_ca_path is not None: + config["ssl_ca_path"] = ssl_ca_path + if ssl_keyfile is not None: + config["ssl_keyfile"] = ssl_keyfile + if ssl_certfile is not None: + config["ssl_certfile"] = ssl_certfile + if ssl_cert_reqs is not None: + config["ssl_cert_reqs"] = ssl_cert_reqs + if ssl_ca_certs is not None: + config["ssl_ca_certs"] = ssl_ca_certs + if cluster_mode is not None: + config["cluster_mode"] = cluster_mode + + # Handle cluster mode conflicts + if config.get("cluster_mode", False): + # Remove db parameter in cluster mode as it's not supported + config.pop('db', None) + + return config, host_id diff --git a/src/common/connection.py b/src/common/connection.py index 848150c..c8c7041 100644 --- a/src/common/connection.py +++ b/src/common/connection.py @@ -1,77 +1,424 @@ import sys +import urllib.parse +from typing import Dict, Optional, Type, Union, Any +from enum import Enum from src.version import __version__ import redis from redis import Redis -from redis.cluster import RedisCluster -from typing import Optional, Type, Union -from src.common.config import REDIS_CFG +from redis.cluster import RedisCluster, ClusterNode -class RedisConnectionManager: - _instance: Optional[Redis] = None +def detect_cluster_mode(config: dict) -> bool: + """ + Detect if a Redis instance is running in cluster mode by connecting and checking INFO. + + Args: + config: Redis connection configuration dictionary + + Returns: + True if cluster mode is detected, False otherwise + """ + try: + # Create a temporary non-cluster connection to check INFO + temp_config = config.copy() + temp_config.pop('cluster_mode', None) # Remove cluster_mode to force standalone connection + + connection_params = { + "decode_responses": True, + "lib_name": f"redis-py(mcp-server_v{__version__})", + } + + # Add all config parameters except cluster_mode + for key, value in temp_config.items(): + if value is not None and key != "cluster_mode": + connection_params[key] = value + connection_params["max_connections"] = 10 + + # Create a temporary Redis connection + temp_redis = Redis(**connection_params) + + # Get server info to check cluster_enabled field + info = temp_redis.info("cluster") + cluster_enabled = info.get("cluster_enabled", 0) + + # Close the temporary connection + temp_redis.close() + + # cluster_enabled = 1 means cluster mode is enabled + return cluster_enabled == 1 + + except redis.exceptions.ResponseError as e: + # If we get "This instance has cluster support disabled", it's not a cluster + if "cluster support disabled" in str(e).lower(): + return False + # For other response errors, assume it's not a cluster + return False + except Exception: + # For any other connection issues, default to False + return False - @classmethod - def get_connection(cls, decode_responses=True) -> Redis: + +class DecodeResponsesType(Enum): + """Enum for decode_responses connection types.""" + DECODED = "decoded" # decode_responses=True + RAW = "raw" # decode_responses=False + + +class RedisConnectionPool: + """Manages multiple Redis connections identified by host identifier (Singleton).""" + + _instance = None + _initialized = False + + def __new__(cls): if cls._instance is None: + cls._instance = super(RedisConnectionPool, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not self._initialized: + # Store connections with separate pools for DECODED/RAW + self._connections: Dict[str, Dict[DecodeResponsesType, Redis]] = {} # host_id -> {DECODED: conn, RAW: conn} + self._configs: Dict[str, dict] = {} # Store original configurations + self._default_host: Optional[str] = None + self._initialized = True + + def _create_connection_params(self, config: dict, decode_responses: bool = True) -> dict: + """Create connection parameters from config dictionary.""" + base_params = { + "decode_responses": decode_responses, + "lib_name": f"redis-py(mcp-server_v{__version__})", + } + + cluster_mode = config.get("cluster_mode", False) + + if cluster_mode: + # For cluster mode, we need to use startup_nodes instead of host/port + host = config.get("host", "127.0.0.1") + port = config.get("port", 6379) + startup_nodes = [ClusterNode(host=host, port=port)] + base_params["startup_nodes"] = startup_nodes + base_params["max_connections_per_node"] = 10 + + # Add cluster-specific parameters, excluding host, port, db + cluster_incompatible_keys = {"host", "port", "db", "cluster_mode"} + for key, value in config.items(): + if value is not None and key not in cluster_incompatible_keys: + base_params[key] = value + else: + # For non-cluster mode, add all config parameters except cluster_mode + for key, value in config.items(): + if value is not None and key != "cluster_mode": + base_params[key] = value + base_params["max_connections"] = 10 + + return base_params + + def _get_redis_class(self, cluster_mode: bool) -> Type[Union[Redis, RedisCluster]]: + """Get the appropriate Redis class based on cluster mode.""" + return redis.cluster.RedisCluster if cluster_mode else redis.Redis + + def add_connection(self, host_id: str, config: dict) -> str: + """Add a new Redis connection to the pool. Creates both RAW and DECODED connections.""" + try: + # Initialize connection dict for this host if not exists + if host_id not in self._connections: + self._connections[host_id] = {} + + # Auto-detect cluster mode if not explicitly specified + working_config = config.copy() + if "cluster_mode" not in config or config.get("cluster_mode") is None: + detected_cluster_mode = detect_cluster_mode(config) + working_config["cluster_mode"] = detected_cluster_mode + # Log the detection result + if detected_cluster_mode: + print(f"Auto-detected cluster mode for {host_id}") + else: + print(f"Auto-detected standalone mode for {host_id}") + + # Get the Redis class to use + redis_class = self._get_redis_class(working_config.get("cluster_mode", False)) + + # Create both DECODED and RAW connections + for decode_type in DecodeResponsesType: + is_decoded = (decode_type == DecodeResponsesType.DECODED) + + # Create connection parameters for this decode type + connection_params = self._create_connection_params(working_config, is_decoded) + + # Create and test the connection + connection = redis_class(**connection_params) + connection.ping() + + # Store the connection + self._connections[host_id][decode_type] = connection + + # Store the final config with detected cluster mode + # For cluster connections, don't preserve db in the stored config + config_to_store = working_config.copy() + if working_config.get("cluster_mode", False): + # Remove db from stored config for cluster connections + config_to_store.pop("db", None) + elif "db" not in config_to_store: + # For standalone connections, ensure db is set to default if not specified + config_to_store["db"] = config.get("db", 0) + + self._configs[host_id] = config_to_store + + # Set as default if it's the first connection + if self._default_host is None: + self._default_host = host_id + + cluster_status = "cluster" if working_config.get("cluster_mode", False) else "standalone" + return f"Successfully connected to Redis at {host_id} (both decoded and raw modes, {cluster_status})" + + except redis.exceptions.ConnectionError as e: + raise Exception(f"Failed to connect to Redis server at {host_id}: {e}") + except redis.exceptions.AuthenticationError as e: + raise Exception(f"Authentication failed for Redis server at {host_id}: {e}") + except redis.exceptions.TimeoutError as e: + raise Exception(f"Connection timed out for Redis server at {host_id}: {e}") + except redis.exceptions.ResponseError as e: + raise Exception(f"Response error for Redis server at {host_id}: {e}") + except redis.exceptions.RedisError as e: + raise Exception(f"Redis error for server at {host_id}: {e}") + except redis.exceptions.ClusterError as e: + raise Exception(f"Redis Cluster error for server at {host_id}: {e}") + except Exception as e: + raise Exception(f"Unexpected error connecting to Redis server at {host_id}: {e}") + + def get_connection(self, host_id: Optional[str] = None, decode_responses: bool = True) -> Redis: + """Get a Redis connection by host identifier.""" + if host_id is None: + host_id = self._default_host + + if host_id is None: + raise Exception("No Redis connections available. Use the 'connect' tool to establish a connection first.") + + if host_id not in self._connections: + raise Exception(f"No connection found for host '{host_id}'. Available hosts: {list(self._connections.keys())}") + + # Convert boolean to enum + decode_type = DecodeResponsesType.DECODED if decode_responses else DecodeResponsesType.RAW + + # Both connection types should always exist since add_connection creates both + if decode_type not in self._connections[host_id]: + raise Exception(f"Connection type {decode_type.value} not found for host '{host_id}'. This should not happen.") + + return self._connections[host_id][decode_type] + + def list_connections(self) -> Dict[str, dict]: + """List all active connections with their details.""" + result = {} + for host_id, conn_dict in self._connections.items(): try: - if REDIS_CFG["cluster_mode"]: - redis_class: Type[Union[Redis, RedisCluster]] = redis.cluster.RedisCluster - connection_params = { - "host": REDIS_CFG["host"], - "port": REDIS_CFG["port"], - "username": REDIS_CFG["username"], - "password": REDIS_CFG["password"], - "ssl": REDIS_CFG["ssl"], - "ssl_ca_path": REDIS_CFG["ssl_ca_path"], - "ssl_keyfile": REDIS_CFG["ssl_keyfile"], - "ssl_certfile": REDIS_CFG["ssl_certfile"], - "ssl_cert_reqs": REDIS_CFG["ssl_cert_reqs"], - "ssl_ca_certs": REDIS_CFG["ssl_ca_certs"], - "decode_responses": decode_responses, - "lib_name": f"redis-py(mcp-server_v{__version__})", - "max_connections_per_node": 10 + # Get info from the DECODED connection if available, fallback to RAW + conn = conn_dict.get(DecodeResponsesType.DECODED) or conn_dict.get(DecodeResponsesType.RAW) + if conn: + info = conn.info("server") + config = self._configs.get(host_id, {}) + + # Handle database information properly + is_cluster = config.get("cluster_mode", False) + if is_cluster: + # In cluster mode, database selection is not supported + db_info = "N/A (cluster)" + else: + # For standalone Redis, show the database number + db_info = config.get("db", getattr(conn, 'db', 0)) + if db_info == 'unknown': + db_info = 0 # Default to 0 if unknown + + result[host_id] = { + "status": "connected", + "redis_version": info.get("redis_version", "unknown"), + "host": config.get("host", getattr(conn, 'host', 'unknown')), + "port": config.get("port", getattr(conn, 'port', 'unknown')), + "db": db_info, + "cluster_mode": config.get("cluster_mode", False), + "ssl": config.get("ssl", False), + "is_default": host_id == self._default_host, + "available_modes": [decode_type.value for decode_type in conn_dict.keys()] } + except Exception as e: + config = self._configs.get(host_id, {}) + # Handle database information properly even in error case + is_cluster = config.get("cluster_mode", False) + if is_cluster: + db_info = "N/A (cluster)" else: - redis_class: Type[Union[Redis, RedisCluster]] = redis.Redis - connection_params = { - "host": REDIS_CFG["host"], - "port": REDIS_CFG["port"], - "db": REDIS_CFG["db"], - "username": REDIS_CFG["username"], - "password": REDIS_CFG["password"], - "ssl": REDIS_CFG["ssl"], - "ssl_ca_path": REDIS_CFG["ssl_ca_path"], - "ssl_keyfile": REDIS_CFG["ssl_keyfile"], - "ssl_certfile": REDIS_CFG["ssl_certfile"], - "ssl_cert_reqs": REDIS_CFG["ssl_cert_reqs"], - "ssl_ca_certs": REDIS_CFG["ssl_ca_certs"], - "decode_responses": decode_responses, - "lib_name": f"redis-py(mcp-server_v{__version__})", - "max_connections": 10 - } + db_info = config.get("db", 0) + if db_info == 'unknown': + db_info = 0 + + result[host_id] = { + "status": f"error: {e}", + "host": config.get("host", "unknown"), + "port": config.get("port", "unknown"), + "db": db_info, + "cluster_mode": config.get("cluster_mode", False), + "ssl": config.get("ssl", False), + "is_default": host_id == self._default_host, + "available_modes": [decode_type.value for decode_type in conn_dict.keys()] if host_id in self._connections else [] + } + return result + + def get_connection_details(self, host_id: Optional[str] = None) -> Dict[str, Any]: + """Get details for a specific connection or the default connection.""" + if host_id is None: + host_id = self._default_host + + if host_id is None: + return {"error": "No Redis connections available"} + + if host_id not in self._connections: + available = list(self._connections.keys()) + return {"error": f"Connection '{host_id}' not found. Available connections: {available}"} + + conn_dict = self._connections[host_id] + config = self._configs.get(host_id, {}) + + try: + # Get info from the DECODED connection if available, fallback to RAW + conn = conn_dict.get(DecodeResponsesType.DECODED) or conn_dict.get(DecodeResponsesType.RAW) + if conn: + info = conn.info("server") - cls._instance = redis_class(**connection_params) + # Handle database information properly + is_cluster = config.get("cluster_mode", False) + if is_cluster: + # In cluster mode, database selection is not supported + db_info = "N/A (cluster)" + else: + # For standalone Redis, show the database number + db_info = config.get("db", getattr(conn, 'db', 0)) + if db_info == 'unknown': + db_info = 0 # Default to 0 if unknown + + return { + "host_id": host_id, + "status": "connected", + "redis_version": info.get("redis_version", "unknown"), + "host": config.get("host", getattr(conn, 'host', 'unknown')), + "port": config.get("port", getattr(conn, 'port', 'unknown')), + "db": db_info, + "cluster_mode": config.get("cluster_mode", False), + "ssl": config.get("ssl", False), + "is_default": host_id == self._default_host, + "available_modes": [decode_type.value for decode_type in conn_dict.keys()] + } + except Exception as e: + # Handle database information properly even in error case + is_cluster = config.get("cluster_mode", False) + if is_cluster: + db_info = "N/A (cluster)" + else: + db_info = config.get("db", 0) + if db_info == 'unknown': + db_info = 0 + + return { + "host_id": host_id, + "status": f"error: {e}", + "host": config.get("host", "unknown"), + "port": config.get("port", "unknown"), + "db": db_info, + "cluster_mode": config.get("cluster_mode", False), + "ssl": config.get("ssl", False), + "is_default": host_id == self._default_host, + "available_modes": [decode_type.value for decode_type in conn_dict.keys()] if host_id in self._connections else [] + } + + def remove_connection(self, host_id: str) -> str: + """Remove a connection from the pool.""" + if host_id not in self._connections: + return f"No connection found for host '{host_id}'" + + try: + # Close all connections for this host (both DECODED and RAW) + conn_dict = self._connections[host_id] + for decode_type, conn in conn_dict.items(): + try: + conn.close() + except: + pass # Ignore close errors + except: + pass # Ignore close errors + + # Remove both connection and config + del self._connections[host_id] + self._configs.pop(host_id, None) # Remove config, ignore if not found + + # Update default if needed + if self._default_host == host_id: + self._default_host = next(iter(self._connections.keys())) if self._connections else None + + return f"Connection to '{host_id}' removed successfully" + + @classmethod + def get_instance(cls) -> 'RedisConnectionPool': + """Get the singleton instance.""" + return cls() + + @classmethod + def add_connection_to_pool(cls, host_id: str, config: dict) -> str: + """Class method to add a connection to the singleton pool.""" + return cls.get_instance().add_connection(host_id, config) + + @classmethod + def get_connection_from_pool(cls, host_id: Optional[str] = None, decode_responses: bool = True) -> Redis: + """Class method to get a connection from the singleton pool.""" + return cls.get_instance().get_connection(host_id, decode_responses) + + @classmethod + def list_connections_in_pool(cls) -> Dict[str, dict]: + """Class method to list all connections in the singleton pool.""" + return cls.get_instance().list_connections() + + @classmethod + def remove_connection_from_pool(cls, host_id: str) -> str: + """Class method to remove a connection from the singleton pool.""" + return cls.get_instance().remove_connection(host_id) + + @classmethod + def get_connection_details_from_pool(cls, host_id: Optional[str] = None) -> Dict[str, Any]: + """Class method to get connection details from the singleton pool.""" + return cls.get_instance().get_connection_details(host_id) - except redis.exceptions.ConnectionError: - print("Failed to connect to Redis server", file=sys.stderr) - raise - except redis.exceptions.AuthenticationError: - print("Authentication failed", file=sys.stderr) - raise - except redis.exceptions.TimeoutError: - print("Connection timed out", file=sys.stderr) - raise - except redis.exceptions.ResponseError as e: - print(f"Response error: {e}", file=sys.stderr) - raise - except redis.exceptions.RedisError as e: - print(f"Redis error: {e}", file=sys.stderr) - raise - except redis.exceptions.ClusterError as e: - print(f"Redis Cluster error: {e}", file=sys.stderr) - raise - except Exception as e: - print(f"Unexpected error: {e}", file=sys.stderr) - raise +def get_connection(host_id: Optional[str] = None, decode_responses: bool = True) -> Redis: + """Get a Redis connection by host identifier (legacy function).""" + return RedisConnectionPool.get_connection_from_pool(host_id, decode_responses) - return cls._instance +def get_connection_pool() -> RedisConnectionPool: + """Get the connection pool instance (legacy function).""" + return RedisConnectionPool.get_instance() + + +class RedisConnectionManager: + """Compatibility wrapper for the connection pool.""" + + @classmethod + def get_connection(cls, host_id: Optional[str] = None, decode_responses=True) -> Redis: + """Get a connection for the specified host or the default connection.""" + pool = RedisConnectionPool.get_instance() + + # Get the host_id for the connection + if host_id is None: + host_id = pool._default_host + + # Initialize default connection if none exists and no specific host_id requested + if not pool._connections and host_id is None: + # Create default configuration from environment variables + from src.common.config import RedisConfig + default_config = RedisConfig() + default_host_id = f"{default_config['host']}:{default_config['port']}" + pool.add_connection(default_host_id, default_config.config) + host_id = default_host_id + + # Use the pool's get_connection method which handles both decode_responses types + return pool.get_connection(host_id, decode_responses) + + @classmethod + def get_pool(cls) -> RedisConnectionPool: + """Get the connection pool instance.""" + return RedisConnectionPool.get_instance() diff --git a/src/common/server.py b/src/common/server.py index eb9609e..eb78abc 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -3,6 +3,6 @@ # Initialize FastMCP server mcp = FastMCP( "Redis MCP Server", - dependencies=["redis", "dotenv", "numpy"] + dependencies=["redis", "dotenv", "numpy"], ) diff --git a/src/common/stdio_server.py b/src/common/stdio_server.py new file mode 100644 index 0000000..15db963 --- /dev/null +++ b/src/common/stdio_server.py @@ -0,0 +1,11 @@ +# Copyright Redis Contributors +# SPDX-License-Identifier: MIT + + +async def serve_stdio() -> None: + """Serve the MCP server using stdio transport.""" + # Import the existing FastMCP server + from src.common.server import mcp + + # FastMCP handles stdio transport natively + await mcp.run_stdio_async() diff --git a/src/common/streaming_server.py b/src/common/streaming_server.py new file mode 100644 index 0000000..0b97a84 --- /dev/null +++ b/src/common/streaming_server.py @@ -0,0 +1,18 @@ +# Copyright Redis Contributors +# SPDX-License-Identifier: MIT + + +async def serve_streaming( + host: str = '0.0.0.0', + port: int = 8000, +) -> None: + """Serve the MCP server using streaming (SSE/Streamable HTTP) transport.""" + # Import the existing FastMCP server + from src.common.server import mcp + + # Update host and port settings + mcp.settings.host = host + mcp.settings.port = port + + # FastMCP handles streamable HTTP transport natively + await mcp.run_streamable_http_async() diff --git a/src/main.py b/src/main.py index 58f427c..9cfed3b 100644 --- a/src/main.py +++ b/src/main.py @@ -1,8 +1,11 @@ import sys import click -from src.common.connection import RedisConnectionManager -from src.common.server import mcp -from src.common.config import parse_redis_uri, set_redis_config_from_cli +import asyncio + +from src.common.config import build_redis_config +from src.common.connection import RedisConnectionPool +from src.common.stdio_server import serve_stdio +from src.common.streaming_server import serve_streaming import src.tools.server_management import src.tools.misc import src.tools.redis_query_engine @@ -14,17 +17,14 @@ import src.tools.set import src.tools.stream import src.tools.pub_sub - - -class RedisMCPServer: - def __init__(self): - print("Starting the Redis MCP Server", file=sys.stderr) - - def run(self): - mcp.run() +import src.tools.connection_management @click.command() +@click.option('--transport', default='stdio', type=click.Choice(['stdio', 'streamable-http']), + help='Transport method (stdio or streamable-http)') +@click.option('--http-host', default='127.0.0.1', help='HTTP server host (for streamable-http transport)') +@click.option('--http-port', default=8000, type=int, help='HTTP server port (for streamable-http transport)') @click.option('--url', help='Redis connection URI (redis://user:pass@host:port/db or rediss:// for SSL)') @click.option('--host', default='127.0.0.1', help='Redis host') @click.option('--port', default=6379, type=int, help='Redis port') @@ -38,55 +38,40 @@ def run(self): @click.option('--ssl-cert-reqs', default='required', help='SSL certificate requirements') @click.option('--ssl-ca-certs', help='Path to CA certificates file') @click.option('--cluster-mode', is_flag=True, help='Enable Redis cluster mode') -def cli(url, host, port, db, username, password, +def cli(transport, http_host, http_port, url, host, port, db, username, password, ssl, ssl_ca_path, ssl_keyfile, ssl_certfile, ssl_cert_reqs, ssl_ca_certs, cluster_mode): """Redis MCP Server - Model Context Protocol server for Redis.""" + + try: + # Build configuration using unified logic - URL takes precedence but individual params can override + config, host_id = build_redis_config( + url=url, host=host, port=port, db=db, username=username, password=password, + ssl=ssl, ssl_ca_path=ssl_ca_path, ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, cluster_mode=cluster_mode + ) + + # Add connection directly to pool + RedisConnectionPool.add_connection_to_pool(host_id, config) + + except ValueError as e: + click.echo(f"Error parsing Redis configuration: {e}", err=True) + sys.exit(1) + except Exception as e: + click.echo(f"Error connecting to Redis: {e}", err=True) + sys.exit(1) - # Handle Redis URI if provided - if url: - try: - uri_config = parse_redis_uri(url) - set_redis_config_from_cli(uri_config) - except ValueError as e: - click.echo(f"Error parsing Redis URI: {e}", err=True) - sys.exit(1) + # Start the appropriate server + if transport == "streamable-http": + asyncio.run(serve_streaming(host=http_host, port=http_port)) else: - # Set individual Redis parameters - config = { - 'host': host, - 'port': port, - 'db': db, - 'ssl': ssl, - 'cluster_mode': cluster_mode - } - - if username: - config['username'] = username - if password: - config['password'] = password - if ssl_ca_path: - config['ssl_ca_path'] = ssl_ca_path - if ssl_keyfile: - config['ssl_keyfile'] = ssl_keyfile - if ssl_certfile: - config['ssl_certfile'] = ssl_certfile - if ssl_cert_reqs: - config['ssl_cert_reqs'] = ssl_cert_reqs - if ssl_ca_certs: - config['ssl_ca_certs'] = ssl_ca_certs - - set_redis_config_from_cli(config) - - # Start the server - server = RedisMCPServer() - server.run() + asyncio.run(serve_stdio()) def main(): - """Legacy main function for backward compatibility.""" - server = RedisMCPServer() - server.run() + """Main entry point for backward compatibility.""" + cli() if __name__ == "__main__": diff --git a/src/tools/connection_management.py b/src/tools/connection_management.py new file mode 100644 index 0000000..c1e6993 --- /dev/null +++ b/src/tools/connection_management.py @@ -0,0 +1,156 @@ +from typing import Optional, Dict, Any +from src.common.connection import RedisConnectionPool +from src.common.config import build_redis_config +from src.common.server import mcp +import urllib.parse + + +@mcp.tool() +async def connect( + url: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + db: Optional[int] = None, + username: Optional[str] = None, + password: Optional[str] = None, + ssl: Optional[bool] = None, + ssl_ca_path: Optional[str] = None, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[str] = None, + ssl_cert_reqs: Optional[str] = None, + ssl_ca_certs: Optional[str] = None, + cluster_mode: Optional[bool] = None, + host_id: Optional[str] = None +) -> str: + """Connect to a Redis server and add it to the connection pool. + + Args: + url: Redis connection URI (redis://user:pass@host:port/db or rediss:// for SSL) + host: Redis host (default: 127.0.0.1) + port: Redis port (default: 6379) + db: Redis database number (default: 0) + username: Redis username + password: Redis password + ssl: Use SSL connection + ssl_ca_path: Path to CA certificate file + ssl_keyfile: Path to SSL key file + ssl_certfile: Path to SSL certificate file + ssl_cert_reqs: SSL certificate requirements (default: required) + ssl_ca_certs: Path to CA certificates file + cluster_mode: Enable Redis cluster mode + host_id: Custom identifier for this connection (auto-generated if not provided) + + Returns: + Success message with connection details or error message. + """ + try: + # Build configuration using unified logic + config, generated_host_id = build_redis_config( + url=url, host=host, port=port, db=db, username=username, + password=password, ssl=ssl, ssl_ca_path=ssl_ca_path, + ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs, + cluster_mode=cluster_mode, host_id=host_id + ) + + # Use provided host_id or generated one + final_host_id = host_id or generated_host_id + + # Add connection to pool + result = RedisConnectionPool.add_connection_to_pool(final_host_id, config) + + return f"{result}. Host identifier: '{final_host_id}'" + + except Exception as e: + return f"Failed to connect to Redis: {str(e)}" + + +@mcp.tool() +async def list_connections() -> Dict[str, Any]: + """List all active Redis connections in the pool. + + Returns: + Dictionary containing details of all active connections. + """ + try: + connections = RedisConnectionPool.list_connections_in_pool() + + if not connections: + return {"message": "No active connections", "connections": {}} + + return { + "message": f"Found {len(connections)} active connection(s)", + "connections": connections + } + + except Exception as e: + return {"error": f"Failed to list connections: {str(e)}"} + + +@mcp.tool() +async def disconnect(host_id: str) -> str: + """Disconnect from a Redis server and remove it from the connection pool. + + Args: + host_id: The identifier of the connection to remove + + Returns: + Success message or error message. + """ + try: + result = RedisConnectionPool.remove_connection_from_pool(host_id) + return result + + except Exception as e: + return f"Failed to disconnect from {host_id}: {str(e)}" + + +@mcp.tool() +async def switch_default_connection(host_id: str) -> str: + """Switch the default connection to a different host. + + Args: + host_id: The identifier of the connection to set as default + + Returns: + Success message or error message. + """ + try: + pool = RedisConnectionPool.get_instance() + + # Check if connection exists + if host_id not in pool._connections: + available = list(pool._connections.keys()) + return f"Connection '{host_id}' not found. Available connections: {available}" + + # Set as default + pool._default_host = host_id + return f"Default connection switched to '{host_id}'" + + except Exception as e: + return f"Failed to switch default connection: {str(e)}" + + +@mcp.tool() +async def get_connection(host_id: Optional[str] = None) -> Dict[str, Any]: + """Get details for a specific Redis connection or the default connection. + + Args: + host_id: The identifier of the connection to get details for. If not provided, uses the default connection. + + Returns: + Dictionary containing connection details or error message. + """ + try: + details = RedisConnectionPool.get_connection_details_from_pool(host_id) + + if "error" in details: + return {"error": details["error"]} + + return { + "message": f"Connection details for '{details['host_id']}'", + "connection": details + } + + except Exception as e: + return {"error": f"Failed to get connection details: {str(e)}"} diff --git a/src/tools/hash.py b/src/tools/hash.py index 49586af..0e90c21 100644 --- a/src/tools/hash.py +++ b/src/tools/hash.py @@ -1,3 +1,4 @@ +from typing import Optional from src.common.connection import RedisConnectionManager from redis.exceptions import RedisError from src.common.server import mcp @@ -5,7 +6,7 @@ @mcp.tool() -async def hset(name: str, key: str, value: str | int | float, expire_seconds: int = None) -> str: +async def hset(name: str, key: str, value: str | int | float, expire_seconds: int = None, host_id: Optional[str] = None) -> str: """Set a field in a hash stored at key with an optional expiration time. Args: @@ -13,12 +14,13 @@ async def hset(name: str, key: str, value: str | int | float, expire_seconds: in key: The field name inside the hash. value: The value to set. expire_seconds: Optional; time in seconds after which the key should expire. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: A success message or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) r.hset(name, key, str(value)) if expire_seconds is not None: @@ -30,89 +32,94 @@ async def hset(name: str, key: str, value: str | int | float, expire_seconds: in return f"Error setting field '{key}' in hash '{name}': {str(e)}" @mcp.tool() -async def hget(name: str, key: str) -> str: +async def hget(name: str, key: str, host_id: Optional[str] = None) -> str: """Get the value of a field in a Redis hash. Args: name: The Redis hash key. key: The field name inside the hash. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: The field value or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) value = r.hget(name, key) return value if value else f"Field '{key}' not found in hash '{name}'." except RedisError as e: return f"Error getting field '{key}' from hash '{name}': {str(e)}" @mcp.tool() -async def hdel(name: str, key: str) -> str: +async def hdel(name: str, key: str, host_id: Optional[str] = None) -> str: """Delete a field from a Redis hash. Args: name: The Redis hash key. key: The field name inside the hash. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: A success message or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) deleted = r.hdel(name, key) return f"Field '{key}' deleted from hash '{name}'." if deleted else f"Field '{key}' not found in hash '{name}'." except RedisError as e: return f"Error deleting field '{key}' from hash '{name}': {str(e)}" @mcp.tool() -async def hgetall(name: str) -> dict: +async def hgetall(name: str, host_id: Optional[str] = None) -> dict: """Get all fields and values from a Redis hash. Args: name: The Redis hash key. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: A dictionary of field-value pairs or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) hash_data = r.hgetall(name) return {k: v for k, v in hash_data.items()} if hash_data else f"Hash '{name}' is empty or does not exist." except RedisError as e: return f"Error getting all fields from hash '{name}': {str(e)}" @mcp.tool() -async def hexists(name: str, key: str) -> bool: +async def hexists(name: str, key: str, host_id: Optional[str] = None) -> bool: """Check if a field exists in a Redis hash. Args: name: The Redis hash key. key: The field name inside the hash. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: True if the field exists, False otherwise. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) return r.hexists(name, key) except RedisError as e: return f"Error checking existence of field '{key}' in hash '{name}': {str(e)}" @mcp.tool() -async def set_vector_in_hash(name: str, vector: list, vector_field: str = "vector") -> bool: +async def set_vector_in_hash(name: str, vector: list, vector_field: str = "vector", host_id: Optional[str] = None) -> bool: """Store a vector as a field in a Redis hash. Args: name: The Redis hash key. vector_field: The field name inside the hash. Unless specifically required, use the default field name vector: The vector (list of numbers) to store in the hash. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: True if the vector was successfully stored, False otherwise. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) # Convert the vector to a NumPy array, then to a binary blob using np.float32 vector_array = np.array(vector, dtype=np.float32) @@ -125,18 +132,19 @@ async def set_vector_in_hash(name: str, vector: list, vector_field: str = "vecto @mcp.tool() -async def get_vector_from_hash(name: str, vector_field: str = "vector"): +async def get_vector_from_hash(name: str, vector_field: str = "vector", host_id: Optional[str] = None): """Retrieve a vector from a Redis hash and convert it back from binary blob. Args: name: The Redis hash key. vector_field: The field name inside the hash. Unless specifically required, use the default field name + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: The vector as a list of floats, or an error message if retrieval fails. """ try: - r = RedisConnectionManager.get_connection(decode_responses=False) + r = RedisConnectionManager.get_connection(host_id, decode_responses=False) # Retrieve the binary blob stored in the hash binary_blob = r.hget(name, vector_field) diff --git a/src/tools/json.py b/src/tools/json.py index 892017b..c8ff480 100644 --- a/src/tools/json.py +++ b/src/tools/json.py @@ -1,7 +1,16 @@ +from typing import Union, Mapping, Optional, List, TYPE_CHECKING from src.common.connection import RedisConnectionManager from redis.exceptions import RedisError from src.common.server import mcp -from redis.commands.json._util import JsonType +# Define JsonType for type checking to match redis-py definition +# Use object as runtime type to avoid issubclass() issues with Any in Python 3.10 +if TYPE_CHECKING: + JsonType = Union[ + str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"] + ] +else: + # Use object at runtime to avoid MCP framework issubclass() issues + JsonType = object @mcp.tool() @@ -31,7 +40,7 @@ async def json_set(name: str, path: str, value: JsonType, expire_seconds: int = @mcp.tool() -async def json_get(name: str, path: str = '$') -> str: +async def json_get(name: str, path: str = '$') -> str | Optional[List[JsonType]]: """Retrieve a JSON value from Redis at a given path. Args: diff --git a/src/tools/list.py b/src/tools/list.py index 8938400..eaa7fa3 100644 --- a/src/tools/list.py +++ b/src/tools/list.py @@ -1,14 +1,25 @@ import json +from typing import Optional from src.common.connection import RedisConnectionManager from redis.exceptions import RedisError from src.common.server import mcp from redis.typing import FieldT @mcp.tool() -async def lpush(name: str, value: FieldT, expire: int = None) -> str: - """Push a value onto the left of a Redis list and optionally set an expiration time.""" +async def lpush(name: str, value: FieldT, expire: int = None, host_id: Optional[str] = None) -> str: + """Push a value onto the left of a Redis list and optionally set an expiration time. + + Args: + name: The Redis list key. + value: The value to push. + expire: Optional expiration time in seconds. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. + + Returns: + A success message or an error message. + """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) r.lpush(name, value) if expire: r.expire(name, expire) @@ -17,10 +28,20 @@ async def lpush(name: str, value: FieldT, expire: int = None) -> str: return f"Error pushing value to list '{name}': {str(e)}" @mcp.tool() -async def rpush(name: str, value: FieldT, expire: int = None) -> str: - """Push a value onto the right of a Redis list and optionally set an expiration time.""" +async def rpush(name: str, value: FieldT, expire: int = None, host_id: Optional[str] = None) -> str: + """Push a value onto the right of a Redis list and optionally set an expiration time. + + Args: + name: The Redis list key. + value: The value to push. + expire: Optional expiration time in seconds. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. + + Returns: + A success message or an error message. + """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) r.rpush(name, value) if expire: r.expire(name, expire) @@ -29,34 +50,56 @@ async def rpush(name: str, value: FieldT, expire: int = None) -> str: return f"Error pushing value to list '{name}': {str(e)}" @mcp.tool() -async def lpop(name: str) -> str: - """Remove and return the first element from a Redis list.""" +async def lpop(name: str, host_id: Optional[str] = None) -> str: + """Remove and return the first element from a Redis list. + + Args: + name: The Redis list key. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. + + Returns: + The popped value or an error message. + """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) value = r.lpop(name) return value if value else f"List '{name}' is empty or does not exist." except RedisError as e: return f"Error popping value from list '{name}': {str(e)}" @mcp.tool() -async def rpop(name: str) -> str: - """Remove and return the last element from a Redis list.""" +async def rpop(name: str, host_id: Optional[str] = None) -> str: + """Remove and return the last element from a Redis list. + + Args: + name: The Redis list key. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. + + Returns: + The popped value or an error message. + """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) value = r.rpop(name) return value if value else f"List '{name}' is empty or does not exist." except RedisError as e: return f"Error popping value from list '{name}': {str(e)}" @mcp.tool() -async def lrange(name: str, start: int, stop: int) -> list: +async def lrange(name: str, start: int, stop: int, host_id: Optional[str] = None) -> list: """Get elements from a Redis list within a specific range. - Returns: - str: A JSON string containing the list of elements or an error message. + Args: + name: The Redis list key. + start: Starting index. + stop: Ending index. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. + + Returns: + A JSON string containing the list of elements or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) values = r.lrange(name, start, stop) if not values: return f"List '{name}' is empty or does not exist." @@ -66,10 +109,18 @@ async def lrange(name: str, start: int, stop: int) -> list: return f"Error retrieving values from list '{name}': {str(e)}" @mcp.tool() -async def llen(name: str) -> int: - """Get the length of a Redis list.""" +async def llen(name: str, host_id: Optional[str] = None) -> int: + """Get the length of a Redis list. + + Args: + name: The Redis list key. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. + + Returns: + The length of the list or an error message. + """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) return r.llen(name) except RedisError as e: return f"Error retrieving length of list '{name}': {str(e)}" diff --git a/src/tools/misc.py b/src/tools/misc.py index 8f2390b..f1cf8ce 100644 --- a/src/tools/misc.py +++ b/src/tools/misc.py @@ -134,8 +134,18 @@ async def scan_keys(pattern: str = "*", count: int = 100, cursor: int = 0) -> di r = RedisConnectionManager.get_connection() cursor, keys = r.scan(cursor=cursor, match=pattern, count=count) - # Convert bytes to strings if needed - decoded_keys = [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys] + # Convert bytes to strings if needed - safer decoding + decoded_keys = [] + for key in keys: + if isinstance(key, bytes): + try: + decoded_keys.append(key.decode('utf-8')) + except UnicodeDecodeError: + decoded_keys.append(key.decode('utf-8', errors='replace')) + elif isinstance(key, str): + decoded_keys.append(key) + else: + decoded_keys.append(str(key)) return { 'cursor': cursor, @@ -171,11 +181,28 @@ async def scan_all_keys(pattern: str = "*", batch_size: int = 100) -> list: cursor = 0 while True: - cursor, keys = r.scan(cursor=cursor, match=pattern, count=batch_size) + scan_result = r.scan(cursor=cursor, match=pattern, count=batch_size) + + # Handle different return formats + if isinstance(scan_result, tuple) and len(scan_result) == 2: + cursor, keys = scan_result + else: + break # Convert bytes to strings if needed and add to results - decoded_keys = [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys] - all_keys.extend(decoded_keys) + if keys: + decoded_keys = [] + for key in keys: + if isinstance(key, bytes): + try: + decoded_keys.append(key.decode('utf-8')) + except UnicodeDecodeError: + decoded_keys.append(key.decode('utf-8', errors='replace')) + elif isinstance(key, str): + decoded_keys.append(key) + else: + decoded_keys.append(str(key)) + all_keys.extend(decoded_keys) # Break when scan is complete (cursor returns to 0) if cursor == 0: @@ -183,4 +210,6 @@ async def scan_all_keys(pattern: str = "*", batch_size: int = 100) -> list: return all_keys except RedisError as e: + return f"Error scanning all keys with pattern '{pattern}': {str(e)}" + except Exception as e: return f"Error scanning all keys with pattern '{pattern}': {str(e)}" \ No newline at end of file diff --git a/src/tools/redis_query_engine.py b/src/tools/redis_query_engine.py index d845fc6..992e450 100644 --- a/src/tools/redis_query_engine.py +++ b/src/tools/redis_query_engine.py @@ -24,7 +24,7 @@ async def get_indexes() -> str: @mcp.tool() -async def get_index_info(index_name: str) -> str: +async def get_index_info(index_name: str) -> str | dict: """Retrieve schema and information about a specific Redis index using FT.INFO. Args: @@ -41,7 +41,7 @@ async def get_index_info(index_name: str) -> str: @mcp.tool() -async def get_indexed_keys_number(index_name: str) -> str: +async def get_indexed_keys_number(index_name: str) -> int: """Retrieve the number of indexed keys by the index Args: diff --git a/src/tools/server_management.py b/src/tools/server_management.py index 19e9be8..b826c8c 100644 --- a/src/tools/server_management.py +++ b/src/tools/server_management.py @@ -1,30 +1,38 @@ +from typing import Optional from src.common.connection import RedisConnectionManager from redis.exceptions import RedisError from src.common.server import mcp @mcp.tool() -async def dbsize() -> int: +async def dbsize(host_id: Optional[str] = None) -> str | int: """Get the number of keys stored in the Redis database + + Args: + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. + + Returns: + The number of keys or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) return r.dbsize() except RedisError as e: return f"Error getting database size: {str(e)}" @mcp.tool() -async def info(section: str = "default") -> dict: +async def info(section: str = "default", host_id: Optional[str] = None) -> dict: """Get Redis server information and statistics. Args: section: The section of the info command (default, memory, cpu, etc.). + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: A dictionary of server information or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) info = r.info(section) return info except RedisError as e: @@ -32,10 +40,17 @@ async def info(section: str = "default") -> dict: @mcp.tool() -async def client_list() -> list: - """Get a list of connected clients to the Redis server.""" +async def client_list(host_id: Optional[str] = None) -> list: + """Get a list of connected clients to the Redis server. + + Args: + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. + + Returns: + A list of connected clients or an error message. + """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) clients = r.client_list() return clients except RedisError as e: diff --git a/src/tools/set.py b/src/tools/set.py index cfbcbba..4c44fde 100644 --- a/src/tools/set.py +++ b/src/tools/set.py @@ -1,22 +1,24 @@ +from typing import Optional from src.common.connection import RedisConnectionManager from redis.exceptions import RedisError from src.common.server import mcp @mcp.tool() -async def sadd(name: str, value: str, expire_seconds: int = None) -> str: +async def sadd(name: str, value: str, expire_seconds: int = None, host_id: Optional[str] = None) -> str: """Add a value to a Redis set with an optional expiration time. Args: name: The Redis set key. value: The value to add to the set. expire_seconds: Optional; time in seconds after which the set should expire. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: A success message or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) r.sadd(name, value) if expire_seconds is not None: @@ -29,18 +31,19 @@ async def sadd(name: str, value: str, expire_seconds: int = None) -> str: @mcp.tool() -async def srem(name: str, value: str) -> str: +async def srem(name: str, value: str, host_id: Optional[str] = None) -> str: """Remove a value from a Redis set. Args: name: The Redis set key. value: The value to remove from the set. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: A success message or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) removed = r.srem(name, value) return f"Value '{value}' removed from set '{name}'." if removed else f"Value '{value}' not found in set '{name}'." except RedisError as e: @@ -48,17 +51,18 @@ async def srem(name: str, value: str) -> str: @mcp.tool() -async def smembers(name: str) -> list: +async def smembers(name: str, host_id: Optional[str] = None) -> list: """Get all members of a Redis set. Args: name: The Redis set key. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: A list of values in the set or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) members = r.smembers(name) return list(members) if members else f"Set '{name}' is empty or does not exist." except RedisError as e: diff --git a/src/tools/string.py b/src/tools/string.py index 94e280c..42dc819 100644 --- a/src/tools/string.py +++ b/src/tools/string.py @@ -1,3 +1,4 @@ +from typing import Optional from src.common.connection import RedisConnectionManager from redis.exceptions import RedisError from src.common.server import mcp @@ -5,19 +6,20 @@ @mcp.tool() -async def set(key: str, value: EncodableT, expiration: int = None) -> str: +async def set(key: str, value: EncodableT, expiration: int = None, host_id: Optional[str] = None) -> str: """Set a Redis string value with an optional expiration time. Args: key (str): The key to set. value (str): The value to store. expiration (int, optional): Expiration time in seconds. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: str: Confirmation message or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) if expiration: r.setex(key, expiration, value) else: @@ -28,17 +30,18 @@ async def set(key: str, value: EncodableT, expiration: int = None) -> str: @mcp.tool() -async def get(key: str) -> str: +async def get(key: str, host_id: Optional[str] = None) -> str: """Get a Redis string value. Args: key (str): The key to retrieve. + host_id (str, optional): Redis host identifier. If not provided, uses the default connection. Returns: str: The stored value or an error message. """ try: - r = RedisConnectionManager.get_connection() + r = RedisConnectionManager.get_connection(host_id) value = r.get(key) return value if value else f"Key {key} does not exist" except RedisError as e: