Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion crewai_tools/adapters/mcp_adapter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import logging
from datetime import timedelta
from typing import TYPE_CHECKING, Any

from crewai.tools import BaseTool

from crewai_tools.adapters.tool_collection import ToolCollection

"""
MCPServer for CrewAI.

Expand Down Expand Up @@ -70,6 +73,8 @@ def __init__(
self,
serverparams: StdioServerParameters | dict[str, Any],
*tool_names: str,
connect_timeout: int = 30,
client_session_timeout_seconds: float | timedelta | None = 5,
):
"""Initialize the MCP Server

Expand All @@ -78,6 +83,8 @@ def __init__(
`StdioServerParameters` or a `dict` respectively for STDIO and SSE.
*tool_names: Optional names of tools to filter. If provided, only tools with
matching names will be available.
connect_timeout: Timeout for connecting to the MCP server (default: 30 seconds).
client_session_timeout_seconds: Timeout for client sessions (default: 5 seconds).

"""

Expand Down Expand Up @@ -106,7 +113,12 @@ def __init__(

try:
self._serverparams = serverparams
self._adapter = MCPAdapt(self._serverparams, CrewAIAdapter())
self._adapter = MCPAdapt(
self._serverparams,
CrewAIAdapter(),
connect_timeout=connect_timeout,
client_session_timeout_seconds=client_session_timeout_seconds,
)
self.start()

except Exception as e:
Expand Down
66 changes: 61 additions & 5 deletions tests/adapters/mcp_adapter_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from textwrap import dedent

import pytest
Expand All @@ -6,6 +7,7 @@
from crewai_tools import MCPServerAdapter
from crewai_tools.adapters.tool_collection import ToolCollection


@pytest.fixture
def echo_server_script():
return dedent(
Expand Down Expand Up @@ -83,7 +85,8 @@ def test_context_manager_syntax(echo_server_script):
assert tools[0].name == "echo_tool"
assert tools[1].name == "calc_tool"
assert tools[0].run(text="hello") == "Echo: hello"
assert tools[1].run(a=5, b=3) == '8'
assert tools[1].run(a=5, b=3) == "8"


def test_context_manager_syntax_sse(echo_sse_server):
sse_serverparams = echo_sse_server
Expand All @@ -92,7 +95,8 @@ def test_context_manager_syntax_sse(echo_sse_server):
assert tools[0].name == "echo_tool"
assert tools[1].name == "calc_tool"
assert tools[0].run(text="hello") == "Echo: hello"
assert tools[1].run(a=5, b=3) == '8'
assert tools[1].run(a=5, b=3) == "8"


def test_try_finally_syntax(echo_server_script):
serverparams = StdioServerParameters(
Expand All @@ -105,10 +109,11 @@ def test_try_finally_syntax(echo_server_script):
assert tools[0].name == "echo_tool"
assert tools[1].name == "calc_tool"
assert tools[0].run(text="hello") == "Echo: hello"
assert tools[1].run(a=5, b=3) == '8'
assert tools[1].run(a=5, b=3) == "8"
finally:
mcp_server_adapter.stop()


def test_try_finally_syntax_sse(echo_sse_server):
sse_serverparams = echo_sse_server
mcp_server_adapter = MCPServerAdapter(sse_serverparams)
Expand All @@ -118,10 +123,11 @@ def test_try_finally_syntax_sse(echo_sse_server):
assert tools[0].name == "echo_tool"
assert tools[1].name == "calc_tool"
assert tools[0].run(text="hello") == "Echo: hello"
assert tools[1].run(a=5, b=3) == '8'
assert tools[1].run(a=5, b=3) == "8"
finally:
mcp_server_adapter.stop()


def test_context_manager_with_filtered_tools(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
Expand All @@ -138,20 +144,22 @@ def test_context_manager_with_filtered_tools(echo_server_script):
with pytest.raises(KeyError):
_ = tools["calc_tool"]


def test_context_manager_sse_with_filtered_tools(echo_sse_server):
sse_serverparams = echo_sse_server
# Only select the calc_tool
with MCPServerAdapter(sse_serverparams, "calc_tool") as tools:
assert isinstance(tools, ToolCollection)
assert len(tools) == 1
assert tools[0].name == "calc_tool"
assert tools[0].run(a=10, b=5) == '15'
assert tools[0].run(a=10, b=5) == "15"
# Check that echo_tool is not present
with pytest.raises(IndexError):
_ = tools[1]
with pytest.raises(KeyError):
_ = tools["echo_tool"]


def test_try_finally_with_filtered_tools(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
Expand All @@ -168,6 +176,7 @@ def test_try_finally_with_filtered_tools(echo_server_script):
finally:
mcp_server_adapter.stop()


def test_filter_with_nonexistent_tool(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
Expand All @@ -178,6 +187,7 @@ def test_filter_with_nonexistent_tool(echo_server_script):
assert len(tools) == 1
assert tools[0].name == "echo_tool"


def test_filter_with_only_nonexistent_tools(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
Expand All @@ -187,3 +197,49 @@ def test_filter_with_only_nonexistent_tools(echo_server_script):
# Should return an empty tool collection
assert isinstance(tools, ToolCollection)
assert len(tools) == 0


def test_timeout_parameters_are_set(echo_server_script):
"""Test that connect_timeout and client_session_timeout_seconds are properly set."""

serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
)

# Test with custom timeout values
connect_timeout = 45
client_session_timeout_seconds = 10.5

try:
mcp_server_adapter = MCPServerAdapter(
serverparams,
connect_timeout=connect_timeout,
client_session_timeout_seconds=client_session_timeout_seconds,
)

# Verify the timeout parameters are set on the adapter
assert mcp_server_adapter._adapter.connect_timeout == connect_timeout
assert (
mcp_server_adapter._adapter.client_session_timeout_seconds
== client_session_timeout_seconds
)

# Test with timedelta for client_session_timeout_seconds
client_session_timeout_timedelta = timedelta(seconds=15)
mcp_server_adapter_timedelta = MCPServerAdapter(
serverparams,
connect_timeout=60,
client_session_timeout_seconds=client_session_timeout_timedelta,
)

assert mcp_server_adapter_timedelta._adapter.connect_timeout == 60
assert (
mcp_server_adapter_timedelta._adapter.client_session_timeout_seconds
== client_session_timeout_timedelta
)

finally:
if "mcp_server_adapter" in locals():
mcp_server_adapter.stop()
if "mcp_server_adapter_timedelta" in locals():
mcp_server_adapter_timedelta.stop()