|
| 1 | +""" |
| 2 | +Tests for the ast-grep tools implementation. |
| 3 | +""" |
| 4 | + |
| 5 | +import os |
| 6 | +import tempfile |
| 7 | +import pytest |
| 8 | +from unittest.mock import patch, MagicMock |
| 9 | + |
| 10 | +# Import the tools module - adjust the import path as needed |
| 11 | +from src.tools.ast_grep.tools import register_tools, validate_file_path, run_ast_grep_command |
| 12 | + |
| 13 | +# Create a mock MCP instance for testing |
| 14 | +class MockMCP: |
| 15 | + def __init__(self): |
| 16 | + self.registered_tools = {} |
| 17 | + |
| 18 | + def tool(self): |
| 19 | + def decorator(func): |
| 20 | + self.registered_tools[func.__name__] = func |
| 21 | + return func |
| 22 | + return decorator |
| 23 | + |
| 24 | +# Create a mock Context for testing |
| 25 | +class MockContext: |
| 26 | + async def info(self, message): |
| 27 | + pass |
| 28 | + |
| 29 | + async def error(self, message): |
| 30 | + pass |
| 31 | + |
| 32 | + async def warning(self, message): |
| 33 | + pass |
| 34 | + |
| 35 | + async def report_progress(self, current, total): |
| 36 | + pass |
| 37 | + |
| 38 | +@pytest.fixture |
| 39 | +def mcp(): |
| 40 | + return MockMCP() |
| 41 | + |
| 42 | +@pytest.fixture |
| 43 | +def context(): |
| 44 | + return MockContext() |
| 45 | + |
| 46 | +@pytest.fixture |
| 47 | +def temp_project_dir(): |
| 48 | + """Create a temporary directory for testing project path operations.""" |
| 49 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 50 | + # Create a test file |
| 51 | + test_file_path = os.path.join(temp_dir, "test_file.cs") |
| 52 | + with open(test_file_path, "w") as f: |
| 53 | + f.write("public class TestClass { }") |
| 54 | + |
| 55 | + yield temp_dir |
| 56 | + |
| 57 | +def test_validate_file_path(temp_project_dir): |
| 58 | + """Test the validate_file_path function.""" |
| 59 | + # Test with valid file |
| 60 | + valid_path = "test_file.cs" |
| 61 | + full_path = validate_file_path(valid_path, temp_project_dir) |
| 62 | + assert os.path.isfile(full_path) |
| 63 | + |
| 64 | + # Test with non-existent file |
| 65 | + with pytest.raises(ValueError, match="does not exist"): |
| 66 | + validate_file_path("non_existent.cs", temp_project_dir) |
| 67 | + |
| 68 | + # Test with path outside project directory |
| 69 | + with pytest.raises(ValueError, match="outside the project directory"): |
| 70 | + validate_file_path("../test.cs", temp_project_dir) |
| 71 | + |
| 72 | +@patch("subprocess.run") |
| 73 | +def test_run_ast_grep_command(mock_run): |
| 74 | + """Test the run_ast_grep_command function.""" |
| 75 | + # Mock subprocess.run return value |
| 76 | + mock_process = MagicMock() |
| 77 | + mock_process.stdout = '{"result": "success"}' |
| 78 | + mock_process.returncode = 0 |
| 79 | + mock_run.return_value = mock_process |
| 80 | + |
| 81 | + # Test with JSON output |
| 82 | + result = run_ast_grep_command(["ast-grep", "run"], "/tmp", expect_json=True) |
| 83 | + assert result == {"result": "success"} |
| 84 | + |
| 85 | + # Test without JSON output |
| 86 | + mock_process.stdout = "Command executed successfully" |
| 87 | + result = run_ast_grep_command(["ast-grep", "run"], "/tmp", expect_json=False) |
| 88 | + assert result["stdout"] == "Command executed successfully" |
| 89 | + assert result["returncode"] == 0 |
| 90 | + |
| 91 | +@pytest.mark.asyncio |
| 92 | +async def test_set_project_path(mcp, context): |
| 93 | + """Test the ast_grep_set_project_path tool.""" |
| 94 | + # Register the tools |
| 95 | + register_tools(mcp) |
| 96 | + |
| 97 | + # Get the set_project_path function |
| 98 | + set_project_path = mcp.registered_tools["ast_grep_set_project_path"] |
| 99 | + |
| 100 | + # Test with valid path (using monkeypatch to avoid actual directory checks) |
| 101 | + with patch("os.path.isdir", return_value=True): |
| 102 | + result = await set_project_path("/valid/path", context) |
| 103 | + assert result["success"] is True |
| 104 | + assert result["project_path"] == "/valid/path" |
| 105 | + |
| 106 | + # Test with invalid path |
| 107 | + with patch("os.path.isdir", return_value=False): |
| 108 | + result = await set_project_path("/invalid/path", context) |
| 109 | + assert result["success"] is False |
| 110 | + assert "Directory not found" in result["error"] |
| 111 | + |
| 112 | +@pytest.mark.asyncio |
| 113 | +async def test_parse_code(mcp, context): |
| 114 | + """Test the ast_grep_parse_code tool.""" |
| 115 | + # Register the tools |
| 116 | + register_tools(mcp) |
| 117 | + |
| 118 | + # Get the parse_code function |
| 119 | + parse_code = mcp.registered_tools["ast_grep_parse_code"] |
| 120 | + |
| 121 | + # Mock run_ast_grep_command to avoid actual command execution |
| 122 | + with patch("src.tools.ast_grep.tools.run_ast_grep_command") as mock_run: |
| 123 | + mock_run.return_value = {"matches": []} |
| 124 | + |
| 125 | + # Test with valid code |
| 126 | + result = await parse_code("public class Test {}", "csharp", context) |
| 127 | + assert result["success"] is True |
| 128 | + |
| 129 | + # Verify temp file creation and cleanup |
| 130 | + mock_run.assert_called_once() |
| 131 | + cmd_args = mock_run.call_args[0][0] |
| 132 | + assert cmd_args[0] == "ast-grep" |
| 133 | + assert cmd_args[1] == "run" |
| 134 | + |
| 135 | +@pytest.mark.asyncio |
| 136 | +async def test_find_pattern(mcp, context, temp_project_dir): |
| 137 | + """Test the ast_grep_find_pattern tool.""" |
| 138 | + # Register the tools |
| 139 | + register_tools(mcp) |
| 140 | + |
| 141 | + # Get the find_pattern function |
| 142 | + find_pattern = mcp.registered_tools["ast_grep_find_pattern"] |
| 143 | + |
| 144 | + # Set the global current_project_path |
| 145 | + import src.tools.ast_grep.tools |
| 146 | + src.tools.ast_grep.tools.current_project_path = temp_project_dir |
| 147 | + |
| 148 | + # Mock run_ast_grep_command to avoid actual command execution |
| 149 | + with patch("src.tools.ast_grep.tools.run_ast_grep_command") as mock_run: |
| 150 | + mock_run.return_value = [{"text": "public class TestClass", "start": 0, "end": 23}] |
| 151 | + |
| 152 | + # Test with valid pattern |
| 153 | + result = await find_pattern("test_file.cs", "class $NAME", None, context) |
| 154 | + assert result["success"] is True |
| 155 | + assert result["count"] == 1 |
| 156 | + assert len(result["matches"]) == 1 |
| 157 | + |
| 158 | +@pytest.mark.asyncio |
| 159 | +async def test_replace_pattern(mcp, context, temp_project_dir): |
| 160 | + """Test the ast_grep_replace_pattern tool.""" |
| 161 | + # Register the tools |
| 162 | + register_tools(mcp) |
| 163 | + |
| 164 | + # Get the replace_pattern function |
| 165 | + replace_pattern = mcp.registered_tools["ast_grep_replace_pattern"] |
| 166 | + |
| 167 | + # Set the global current_project_path |
| 168 | + import src.tools.ast_grep.tools |
| 169 | + src.tools.ast_grep.tools.current_project_path = temp_project_dir |
| 170 | + |
| 171 | + # Mock functions to avoid actual file operations |
| 172 | + with patch("src.tools.ast_grep.tools.run_ast_grep_command") as mock_run, \ |
| 173 | + patch("builtins.open", create=True), \ |
| 174 | + patch("shutil.copy2"): |
| 175 | + |
| 176 | + # Mock file content check |
| 177 | + mock_file = MagicMock() |
| 178 | + mock_file.__enter__.return_value.read.return_value = "public class TestClass : IInterface { }" |
| 179 | + with patch("builtins.open", return_value=mock_file): |
| 180 | + |
| 181 | + # Test with valid replacement |
| 182 | + result = await replace_pattern( |
| 183 | + "test_file.cs", |
| 184 | + "class TestClass", |
| 185 | + "class TestClass : IInterface", |
| 186 | + None, |
| 187 | + context |
| 188 | + ) |
| 189 | + assert result["success"] is True |
| 190 | + assert "backup_path" in result |
| 191 | + |
| 192 | +# Add more tests for other tools (run_yaml_rule, scan_project, initialize_project, test_rule) |
0 commit comments