diff --git a/pybatfish/mcp/__init__.py b/pybatfish/mcp/__init__.py new file mode 100644 index 00000000..b2091c76 --- /dev/null +++ b/pybatfish/mcp/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2018 The Batfish Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MCP (Model Context Protocol) server for Batfish (Beta). + +.. warning:: + This MCP server is currently in **beta**. The tool names, parameters, and + return formats may change in future releases without prior notice. + +This package provides an official MCP server that exposes Batfish network +analysis capabilities as MCP tools, enabling AI agents (such as Claude, +Cursor, and other MCP-compatible clients) to interact with Batfish for +network configuration analysis and verification. + +Usage:: + + # Run the MCP server (stdio transport, for use with MCP clients): + python -m pybatfish.mcp + + # Or run with a specific Batfish host: + BATFISH_HOST=my-batfish-host python -m pybatfish.mcp +""" + +from pybatfish.mcp.server import create_server + +__all__ = ["create_server"] diff --git a/pybatfish/mcp/__main__.py b/pybatfish/mcp/__main__.py new file mode 100644 index 00000000..422442ea --- /dev/null +++ b/pybatfish/mcp/__main__.py @@ -0,0 +1,44 @@ +# Copyright 2018 The Batfish Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Entry point for running the Batfish MCP server (Beta). + +.. warning:: + This MCP server is currently in **beta**. The tool names, parameters, and + return formats may change in future releases without prior notice. + +Run with:: + + python -m pybatfish.mcp + +Or, after installing pybatfish with the ``mcp`` extra:: + + batfish-mcp + +Environment variables: + +* ``BATFISH_HOST`` — hostname of the Batfish server (default: ``localhost``). +""" + +from pybatfish.mcp.server import create_server + + +def main() -> None: + """Start the Batfish MCP server using stdio transport.""" + server = create_server() + server.run(transport="stdio") + + +if __name__ == "__main__": + main() diff --git a/pybatfish/mcp/server.py b/pybatfish/mcp/server.py new file mode 100644 index 00000000..6747753c --- /dev/null +++ b/pybatfish/mcp/server.py @@ -0,0 +1,971 @@ +# Copyright 2018 The Batfish Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Batfish MCP server implementation (Beta). + +.. warning:: + This MCP server is currently in **beta**. The tool names, parameters, and + return formats may change in future releases without prior notice. + +Exposes Batfish network analysis capabilities as MCP (Model Context Protocol) +tools, allowing AI agents to perform snapshot management, reachability +analysis, traceroute simulation, ACL/filter inspection, and routing queries. +""" + +from __future__ import annotations + +import json +import os +import threading +from typing import Any + +try: + from mcp.server.fastmcp import FastMCP +except ImportError as e: + raise ImportError( + "The 'mcp' package is required to use the Batfish MCP server. Install it with: pip install 'pybatfish[mcp]'" + ) from e + +from pybatfish.client.session import Session +from pybatfish.datamodel import HeaderConstraints, Interface + +# Legacy next-hop column names that Batfish is deprecating. The structured +# ``Next_Hop`` column contains the same information in a richer format and +# should be preferred. We drop these from route results so that consumers +# of this MCP server are not exposed to the deprecated columns. +_LEGACY_NEXTHOP_COLUMNS: frozenset[str] = frozenset( + [ + "Next_Hop_IP", + "Next_Hop_Interface", + "NextHopIp", + "NextHopInterface", + ] +) + +# Per-host Session cache. Question templates are downloaded from the Batfish +# service exactly once per host per process lifetime, covering both management +# and analysis operations. +_session_cache: dict[str, Session] = {} +_session_cache_lock = threading.Lock() + + +def _get_session(host: str) -> Session: + """Return the cached Batfish Session for the given host. + + The session is retrieved from (or added to) a process-level cache keyed by + *host*, so that question templates are downloaded from the Batfish service + **at most once per process** rather than on every tool call. + """ + with _session_cache_lock: + if host not in _session_cache: + _session_cache[host] = Session(host=host) + return _session_cache[host] + + +def _clear_session_cache() -> None: + """Clear the per-host session cache. + + Intended for use in tests and in situations where the caller wants to + force question templates to be re-fetched from the Batfish service. + """ + with _session_cache_lock: + _session_cache.clear() + + +def _resolve_host(host: str) -> str: + """Return the effective Batfish hostname. + + Returns *host* if non-empty; otherwise falls back to the + ``BATFISH_HOST`` environment variable, and finally to ``'localhost'``. + """ + return host or os.environ.get("BATFISH_HOST", "localhost") + + +def _mgmt_session(host: str, network: str = "") -> Session: + """Return the cached session with an optional network set. + + Resolves the effective hostname, fetches (or creates) the per-host cached + session, and optionally calls :meth:`~Session.set_network` when *network* + is provided. Use this for tools that perform network or snapshot management + operations (e.g. ``list_networks``, ``init_snapshot``, ``delete_snapshot``). + """ + bf = _get_session(_resolve_host(host)) + if network: + bf.set_network(network) + return bf + + +def _analysis_session(host: str, network: str, snapshot: str) -> Session: + """Return the cached session with network and snapshot set. + + Resolves the effective hostname, fetches (or creates) the per-host cached + session, then calls :meth:`~Session.set_network` and + :meth:`~Session.set_snapshot`. Use this for all tools that invoke Batfish + questions. + """ + bf = _get_session(_resolve_host(host)) + bf.set_network(network) + bf.set_snapshot(snapshot) + return bf + + +def _df_to_json(df: Any) -> str: + """Convert a pandas DataFrame (or any value) to a JSON string.""" + if hasattr(df, "to_json"): + result: str | None = df.to_json(orient="records", default_handler=str) + return result or "[]" + return json.dumps(df, default=str) + + +def _drop_legacy_nexthop_columns(df: Any) -> Any: + """Drop deprecated next-hop columns from a routes DataFrame. + + Keeps only the structured ``Next_Hop`` column and removes the legacy + ``Next_Hop_IP`` and ``Next_Hop_Interface`` columns (and their camelCase + variants) that Batfish is deprecating. + """ + if not hasattr(df, "columns"): + return df + cols_to_drop = [c for c in df.columns if c in _LEGACY_NEXTHOP_COLUMNS] + if cols_to_drop: + return df.drop(columns=cols_to_drop) + return df + + +def create_server(name: str = "Batfish") -> FastMCP: + """Create and return a configured Batfish MCP server (Beta). + + .. warning:: + This MCP server is currently in **beta**. Tool names, parameters, and + return formats may change in future releases without prior notice. + + :param name: Name for the MCP server (default: "Batfish") + :return: Configured FastMCP server instance + """ + mcp = FastMCP( + name, + instructions=( + "[BETA] This server provides tools to interact with a Batfish network analysis service. " + "Note: this MCP server is in beta — tool names and parameters may change in future releases. " + "Use these tools to load network snapshots, run traceroutes, analyze reachability, " + "inspect ACLs/firewall rules, query routing tables, and compare snapshots. " + "Most tools require a 'host' parameter (Batfish server hostname, defaults to " + "the BATFISH_HOST environment variable or 'localhost'), a 'network' parameter " + "(the network name in Batfish), and a 'snapshot' parameter (the snapshot name). " + "Start by listing networks or initializing a snapshot, then run analysis tools." + ), + ) + + # ------------------------------------------------------------------------- + # Network management tools + # ------------------------------------------------------------------------- + + @mcp.tool() + def list_networks(host: str = "") -> str: + """List all available networks on the Batfish server. + + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of network names. + """ + bf = _mgmt_session(host) + return json.dumps(bf.list_networks()) + + @mcp.tool() + def set_network(network: str, host: str = "") -> str: + """Create or select a network on the Batfish server. + + :param network: Name of the network to create or select. + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON object with the active network name. + """ + bf = _mgmt_session(host) + name = bf.set_network(network) + return json.dumps({"network": name}) + + @mcp.tool() + def delete_network(network: str, host: str = "") -> str: + """Delete a network from the Batfish server. + + :param network: Name of the network to delete. + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON object confirming deletion. + """ + bf = _mgmt_session(host) + bf.delete_network(network) + return json.dumps({"deleted": network}) + + # ------------------------------------------------------------------------- + # Snapshot management tools + # ------------------------------------------------------------------------- + + @mcp.tool() + def list_snapshots(network: str, host: str = "") -> str: + """List all snapshots within a network. + + :param network: Name of the network. + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of snapshot names. + """ + bf = _mgmt_session(host, network) + return json.dumps(bf.list_snapshots()) + + @mcp.tool() + def init_snapshot( + network: str, + snapshot_path: str, + snapshot_name: str = "", + overwrite: bool = False, + host: str = "", + ) -> str: + """Initialize a new snapshot from a local directory or zip file. + + The snapshot directory or zip file should contain device configuration + files under a ``configs/`` sub-directory. + + :param network: Name of the network to add the snapshot to. + :param snapshot_path: Local path to a snapshot directory or zip file. + :param snapshot_name: Optional name for the snapshot. Auto-generated if empty. + :param overwrite: Whether to overwrite an existing snapshot with the same name. + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON object with the initialized snapshot name. + """ + bf = _mgmt_session(host, network) + name = bf.init_snapshot( + snapshot_path, + name=snapshot_name or None, + overwrite=overwrite, + ) + return json.dumps({"snapshot": name}) + + @mcp.tool() + def init_snapshot_from_text( + network: str, + config_text: str, + filename: str = "config", + snapshot_name: str = "", + platform: str = "", + overwrite: bool = False, + host: str = "", + ) -> str: + """Initialize a single-device snapshot from configuration text. + + Useful for quickly loading one device's configuration without needing + a local file or zip archive. + + :param network: Name of the network to add the snapshot to. + :param config_text: Raw configuration text (e.g. output of "show running-config"). + :param filename: Filename to use inside the snapshot (default: 'config'). + :param snapshot_name: Optional name for the snapshot. Auto-generated if empty. + :param platform: RANCID platform string (e.g. 'cisco-nx', 'arista', 'juniper'). + If empty, the platform is inferred from the configuration header. + :param overwrite: Whether to overwrite an existing snapshot with the same name. + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON object with the initialized snapshot name. + """ + bf = _mgmt_session(host, network) + name = bf.init_snapshot_from_text( + config_text, + filename=filename, + snapshot_name=snapshot_name or None, + platform=platform or None, + overwrite=overwrite, + ) + return json.dumps({"snapshot": name}) + + @mcp.tool() + def delete_snapshot(network: str, snapshot: str, host: str = "") -> str: + """Delete a snapshot from a network. + + :param network: Name of the network containing the snapshot. + :param snapshot: Name of the snapshot to delete. + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON object confirming deletion. + """ + bf = _mgmt_session(host, network) + bf.delete_snapshot(snapshot) + return json.dumps({"deleted": snapshot}) + + @mcp.tool() + def fork_snapshot( + network: str, + base_snapshot: str, + new_snapshot: str = "", + deactivate_nodes: str = "", + deactivate_interfaces: str = "", + restore_nodes: str = "", + restore_interfaces: str = "", + overwrite: bool = False, + host: str = "", + ) -> str: + """Fork an existing snapshot, optionally deactivating or restoring nodes/interfaces. + + Use this to simulate failure scenarios (e.g. deactivate a node or link) or + to restore previously deactivated elements. + + :param network: Name of the network containing the base snapshot. + :param base_snapshot: Name of the snapshot to fork from. + :param new_snapshot: Name for the new forked snapshot. Auto-generated if empty. + :param deactivate_nodes: Comma-separated list of node names to deactivate. + :param deactivate_interfaces: Comma-separated list of 'node[interface]' pairs to deactivate. + :param restore_nodes: Comma-separated list of node names to restore. + :param restore_interfaces: Comma-separated list of 'node[interface]' pairs to restore. + :param overwrite: Whether to overwrite an existing snapshot with the same name. + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON object with the forked snapshot name. + """ + bf = _mgmt_session(host, network) + + deactivate_nodes_list = [n.strip() for n in deactivate_nodes.split(",") if n.strip()] or None + restore_nodes_list = [n.strip() for n in restore_nodes.split(",") if n.strip()] or None + + deactivate_ifaces = _parse_interfaces(deactivate_interfaces) + restore_ifaces = _parse_interfaces(restore_interfaces) + + name = bf.fork_snapshot( + base_snapshot, + name=new_snapshot or None, + overwrite=overwrite, + deactivate_nodes=deactivate_nodes_list, + deactivate_interfaces=deactivate_ifaces or None, + restore_nodes=restore_nodes_list, + restore_interfaces=restore_ifaces or None, + ) + return json.dumps({"snapshot": name}) + + # ------------------------------------------------------------------------- + # Reachability and traceroute tools + # ------------------------------------------------------------------------- + + @mcp.tool() + def run_traceroute( + network: str, + snapshot: str, + start_location: str, + dst_ips: str, + src_ips: str = "", + applications: str = "", + ip_protocols: str = "", + src_ports: str = "", + dst_ports: str = "", + host: str = "", + ) -> str: + """Simulate a traceroute from a location to a destination IP address. + + Returns the forwarding path(s) a packet would take through the network, + including all hops, interfaces, and any access-list hits. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param start_location: Source location specifier (e.g. node name, interface). + :param dst_ips: Destination IP address or prefix (e.g. '10.0.0.1'). + :param src_ips: Source IP address or prefix (optional). + :param applications: Application specifier, e.g. 'ssh', 'HTTP' (optional). + :param ip_protocols: IP protocol(s) e.g. 'TCP' (optional). + :param src_ports: Source port(s) e.g. '1024-65535' (optional). + :param dst_ports: Destination port(s) e.g. '22' (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of traceroute result rows. + """ + bf = _analysis_session(host, network, snapshot) + + headers = _build_header_constraints( + dst_ips=dst_ips, + src_ips=src_ips, + applications=applications, + ip_protocols=ip_protocols, + src_ports=src_ports, + dst_ports=dst_ports, + ) + result = bf.q.traceroute(startLocation=start_location, headers=headers).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + @mcp.tool() + def run_bidirectional_traceroute( + network: str, + snapshot: str, + start_location: str, + dst_ips: str, + src_ips: str = "", + applications: str = "", + ip_protocols: str = "", + src_ports: str = "", + dst_ports: str = "", + host: str = "", + ) -> str: + """Simulate a bidirectional traceroute (forward + reverse paths). + + Returns both the forward path from source to destination and the + reverse path from destination back to source. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param start_location: Source location specifier (e.g. node name, interface). + :param dst_ips: Destination IP address or prefix. + :param src_ips: Source IP address or prefix (optional). + :param applications: Application specifier (optional). + :param ip_protocols: IP protocol(s) (optional). + :param src_ports: Source port(s) (optional). + :param dst_ports: Destination port(s) (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of bidirectional traceroute result rows. + """ + bf = _analysis_session(host, network, snapshot) + + headers = _build_header_constraints( + dst_ips=dst_ips, + src_ips=src_ips, + applications=applications, + ip_protocols=ip_protocols, + src_ports=src_ports, + dst_ports=dst_ports, + ) + result = bf.q.bidirectionalTraceroute(startLocation=start_location, headers=headers).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + @mcp.tool() + def check_reachability( + network: str, + snapshot: str, + src_locations: str = "", + dst_ips: str = "", + src_ips: str = "", + applications: str = "", + ip_protocols: str = "", + src_ports: str = "", + dst_ports: str = "", + actions: str = "", + host: str = "", + ) -> str: + """Check reachability between network locations. + + Determines which flows can successfully reach the destination, and which + are dropped, denied, or otherwise blocked. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param src_locations: Source location specifier (optional). + :param dst_ips: Destination IP address or prefix (optional). + :param src_ips: Source IP address or prefix (optional). + :param applications: Application specifier (optional). + :param ip_protocols: IP protocol(s) (optional). + :param src_ports: Source port(s) (optional). + :param dst_ports: Destination port(s) (optional). + :param actions: Disposition filter, e.g. 'DENIED_IN,DENIED_OUT,DROP' (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of reachability result rows. + """ + bf = _analysis_session(host, network, snapshot) + + headers = _build_header_constraints( + dst_ips=dst_ips, + src_ips=src_ips, + applications=applications, + ip_protocols=ip_protocols, + src_ports=src_ports, + dst_ports=dst_ports, + ) + kwargs: dict[str, Any] = {"headers": headers} + if src_locations: + kwargs["pathConstraints"] = {"startLocation": src_locations} + if actions: + kwargs["actions"] = actions + + result = bf.q.reachability(**kwargs).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + # ------------------------------------------------------------------------- + # ACL / filter analysis tools + # ------------------------------------------------------------------------- + + @mcp.tool() + def analyze_acl( + network: str, + snapshot: str, + filters: str = "", + nodes: str = "", + host: str = "", + ) -> str: + """Identify unreachable (shadowed) lines in ACLs and firewall rules. + + Reports lines that can never be matched because earlier lines in the + same filter already match the same traffic. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param filters: Filter specifier to restrict analysis (optional). + :param nodes: Node specifier to restrict analysis (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of unreachable ACL/filter line rows. + """ + bf = _analysis_session(host, network, snapshot) + + kwargs: dict[str, Any] = {} + if filters: + kwargs["filters"] = filters + if nodes: + kwargs["nodes"] = nodes + + result = bf.q.filterLineReachability(**kwargs).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + @mcp.tool() + def search_filters( + network: str, + snapshot: str, + filters: str = "", + nodes: str = "", + dst_ips: str = "", + src_ips: str = "", + applications: str = "", + ip_protocols: str = "", + src_ports: str = "", + dst_ports: str = "", + action: str = "", + host: str = "", + ) -> str: + """Search for flows that match specific filter (ACL/firewall) criteria. + + Finds concrete example flows that are permitted or denied by the + specified filters, useful for validating ACL intent. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param filters: Filter specifier (optional). + :param nodes: Node specifier (optional). + :param dst_ips: Destination IP address or prefix to match (optional). + :param src_ips: Source IP address or prefix to match (optional). + :param applications: Application specifier (optional). + :param ip_protocols: IP protocol(s) (optional). + :param src_ports: Source port(s) (optional). + :param dst_ports: Destination port(s) (optional). + :param action: Filter action: 'PERMIT' or 'DENY' (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of matched flow rows. + """ + bf = _analysis_session(host, network, snapshot) + + headers = _build_header_constraints( + dst_ips=dst_ips, + src_ips=src_ips, + applications=applications, + ip_protocols=ip_protocols, + src_ports=src_ports, + dst_ports=dst_ports, + ) + kwargs: dict[str, Any] = {"headers": headers} + if filters: + kwargs["filters"] = filters + if nodes: + kwargs["nodes"] = nodes + if action: + kwargs["action"] = action + + result = bf.q.searchFilters(**kwargs).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + # ------------------------------------------------------------------------- + # Routing table tools + # ------------------------------------------------------------------------- + + @mcp.tool() + def get_routes( + network: str, + snapshot: str, + nodes: str = "", + vrfs: str = "", + network_prefix: str = "", + protocols: str = "", + host: str = "", + ) -> str: + """Retrieve the routing table (RIB) from one or more devices. + + Legacy next-hop columns (Next_Hop_IP, Next_Hop_Interface) + are omitted from the results; use the structured Next_Hop column instead. + + :param network: Name of the Batfish network. + :param snapshot: Name of the snapshot. + :param nodes: Node specifier to restrict results (optional). + :param vrfs: VRF specifier to restrict results (optional). + :param network_prefix: Prefix to filter routes by (optional). + :param protocols: Routing protocol(s) to filter by, e.g. 'bgp,ospf' (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of routing table rows. + """ + bf = _analysis_session(host, network, snapshot) + + kwargs: dict[str, Any] = {} + if nodes: + kwargs["nodes"] = nodes + if vrfs: + kwargs["vrfs"] = vrfs + if network_prefix: + kwargs["network"] = network_prefix + if protocols: + kwargs["protocols"] = protocols + + result = _drop_legacy_nexthop_columns(bf.q.routes(**kwargs).answer().frame()) # type: ignore[attr-defined] + return _df_to_json(result) + + @mcp.tool() + def compare_routes( + network: str, + snapshot: str, + reference_snapshot: str, + nodes: str = "", + vrfs: str = "", + network_prefix: str = "", + protocols: str = "", + host: str = "", + ) -> str: + """Compare routing tables between two snapshots to identify route changes. + + Useful for validating that a configuration change produces the expected + routing changes (and no unintended ones). + + Legacy next-hop columns (Next_Hop_IP, Next_Hop_Interface) + are omitted from the results; use the structured Next_Hop column instead. + + :param network: Name of the Batfish network. + :param snapshot: Name of the candidate (new) snapshot. + :param reference_snapshot: Name of the reference (baseline) snapshot. + :param nodes: Node specifier to restrict results (optional). + :param vrfs: VRF specifier to restrict results (optional). + :param network_prefix: Prefix to filter routes by (optional). + :param protocols: Routing protocol(s) to filter by (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array showing route differences (added/removed routes). + """ + bf = _analysis_session(host, network, snapshot) + + kwargs: dict[str, Any] = {} + if nodes: + kwargs["nodes"] = nodes + if vrfs: + kwargs["vrfs"] = vrfs + if network_prefix: + kwargs["network"] = network_prefix + if protocols: + kwargs["protocols"] = protocols + + result = _drop_legacy_nexthop_columns( + bf.q.routes(**kwargs).answer(snapshot=snapshot, reference_snapshot=reference_snapshot).frame() # type: ignore[attr-defined] + ) + return _df_to_json(result) + + # ------------------------------------------------------------------------- + # BGP tools + # ------------------------------------------------------------------------- + + @mcp.tool() + def get_bgp_session_status( + network: str, + snapshot: str, + nodes: str = "", + remote_nodes: str = "", + status: str = "", + host: str = "", + ) -> str: + """Get the status of BGP sessions in a snapshot. + + Reports which BGP sessions are established, incompatible, or missing. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param nodes: Node specifier for local BGP speakers (optional). + :param remote_nodes: Node specifier for remote BGP speakers (optional). + :param status: BGP session status specifier to filter by (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of BGP session status rows. + """ + bf = _analysis_session(host, network, snapshot) + + kwargs: dict[str, Any] = {} + if nodes: + kwargs["nodes"] = nodes + if remote_nodes: + kwargs["remoteNodes"] = remote_nodes + if status: + kwargs["status"] = status + + result = bf.q.bgpSessionStatus(**kwargs).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + @mcp.tool() + def get_bgp_session_compatibility( + network: str, + snapshot: str, + nodes: str = "", + remote_nodes: str = "", + status: str = "", + host: str = "", + ) -> str: + """Check BGP session compatibility between peers. + + Returns the full BGP session compatibility table for the snapshot. + Each row represents a BGP session and includes its compatibility status + (e.g., UNIQUE_MATCH, NO_MATCH, DYNAMIC_MATCH) along with details about + address families, authentication, and other parameters. Use the + *status* parameter to filter results to a specific compatibility status, + such as ``NO_MATCH`` or ``NO_LOCAL_AS``, when you are interested in only + mis-configured or incompatible sessions. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param nodes: Node specifier for local BGP speakers (optional). + :param remote_nodes: Node specifier for remote BGP speakers (optional). + :param status: BGP compatibility status specifier to filter by (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of BGP compatibility rows. + """ + bf = _analysis_session(host, network, snapshot) + + kwargs: dict[str, Any] = {} + if nodes: + kwargs["nodes"] = nodes + if remote_nodes: + kwargs["remoteNodes"] = remote_nodes + if status: + kwargs["status"] = status + + result = bf.q.bgpSessionCompatibility(**kwargs).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + # ------------------------------------------------------------------------- + # Node and interface information tools + # ------------------------------------------------------------------------- + + @mcp.tool() + def get_node_properties( + network: str, + snapshot: str, + nodes: str = "", + properties: str = "", + host: str = "", + ) -> str: + """Retrieve configuration properties of network nodes (routers/switches). + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param nodes: Node specifier to restrict results (optional). + :param properties: Comma-separated list of property names to retrieve (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of node property rows. + """ + bf = _analysis_session(host, network, snapshot) + + kwargs: dict[str, Any] = {} + if nodes: + kwargs["nodes"] = nodes + if properties: + kwargs["properties"] = properties + + result = bf.q.nodeProperties(**kwargs).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + @mcp.tool() + def get_interface_properties( + network: str, + snapshot: str, + nodes: str = "", + interfaces: str = "", + properties: str = "", + host: str = "", + ) -> str: + """Retrieve configuration properties of network interfaces. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param nodes: Node specifier to restrict results (optional). + :param interfaces: Interface specifier to restrict results (optional). + :param properties: Comma-separated list of property names to retrieve (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of interface property rows. + """ + bf = _analysis_session(host, network, snapshot) + + kwargs: dict[str, Any] = {} + if nodes: + kwargs["nodes"] = nodes + if interfaces: + kwargs["interfaces"] = interfaces + if properties: + kwargs["properties"] = properties + + result = bf.q.interfaceProperties(**kwargs).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + @mcp.tool() + def get_ip_owners( + network: str, + snapshot: str, + duplicates_only: bool = False, + host: str = "", + ) -> str: + """Get the mapping of IP addresses to network interfaces. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param duplicates_only: If True, return only IPs assigned to multiple interfaces. + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of IP ownership rows. + """ + bf = _analysis_session(host, network, snapshot) + + result = bf.q.ipOwners(duplicatesOnly=duplicates_only).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + # ------------------------------------------------------------------------- + # Snapshot comparison tools + # ------------------------------------------------------------------------- + + @mcp.tool() + def compare_filters( + network: str, + snapshot: str, + reference_snapshot: str, + filters: str = "", + nodes: str = "", + host: str = "", + ) -> str: + """Compare ACL/firewall filter behavior between two snapshots. + + Identifies flows that are treated differently (permitted vs. denied) + between the candidate snapshot and the reference (baseline) snapshot. + + :param network: Name of the network. + :param snapshot: Name of the candidate (new) snapshot. + :param reference_snapshot: Name of the reference (baseline) snapshot. + :param filters: Filter specifier to restrict comparison (optional). + :param nodes: Node specifier to restrict comparison (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of filter difference rows. + """ + bf = _analysis_session(host, network, snapshot) + + kwargs: dict[str, Any] = {} + if filters: + kwargs["filters"] = filters + if nodes: + kwargs["nodes"] = nodes + + result = bf.q.compareFilters(**kwargs).answer(snapshot=snapshot, reference_snapshot=reference_snapshot).frame() # type: ignore[attr-defined] + return _df_to_json(result) + + @mcp.tool() + def get_undefined_references( + network: str, + snapshot: str, + nodes: str = "", + host: str = "", + ) -> str: + """Find undefined references in device configurations. + + Reports references to named objects (e.g. ACLs, route-maps, prefix-lists) + that are used but never defined, which can indicate configuration errors. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param nodes: Node specifier to restrict results (optional). + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of undefined reference rows. + """ + bf = _analysis_session(host, network, snapshot) + + kwargs: dict[str, Any] = {} + if nodes: + kwargs["nodes"] = nodes + + result = bf.q.undefinedReferences(**kwargs).answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + @mcp.tool() + def detect_loops( + network: str, + snapshot: str, + host: str = "", + ) -> str: + """Detect forwarding loops in the network snapshot. + + Identifies any packet flows that would loop indefinitely through the + network without being delivered. + + :param network: Name of the network. + :param snapshot: Name of the snapshot. + :param host: Batfish server hostname. Defaults to BATFISH_HOST env var or 'localhost'. + :return: JSON array of forwarding loop rows (empty if no loops found). + """ + bf = _analysis_session(host, network, snapshot) + + result = bf.q.detectLoops().answer().frame() # type: ignore[attr-defined] + return _df_to_json(result) + + return mcp + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + + +def _parse_interfaces(interfaces_str: str) -> list[Interface]: + """Parse a comma-separated 'node[iface]' string into Interface objects. + + Each token must follow the ``node[interface]`` format. Bare node names + (e.g. ``"router1"`` without a bracketed interface) are rejected with a + :exc:`ValueError`; use the *deactivate_nodes* / *restore_nodes* parameters + for node-level operations instead. + + :param interfaces_str: Comma-separated interface specifiers, e.g. + ``"r1[Gi0/0], r2[Ethernet1/1]"``. + :raises ValueError: If a token does not match the ``node[interface]`` format. + """ + result = [] + for item in interfaces_str.split(","): + item = item.strip() + if not item: + continue + if "[" in item and item.endswith("]"): + node, iface = item[:-1].split("[", 1) + result.append(Interface(hostname=node.strip(), interface=iface.strip())) + else: + raise ValueError( + f"Invalid interface specifier {item!r}: expected 'node[interface]' format. " + "Use deactivate_nodes/restore_nodes for node-level operations." + ) + return result + + +def _build_header_constraints( + dst_ips: str = "", + src_ips: str = "", + applications: str = "", + ip_protocols: str = "", + src_ports: str = "", + dst_ports: str = "", +) -> HeaderConstraints: + """Build a HeaderConstraints object from string parameters.""" + kwargs: dict[str, Any] = {} + if dst_ips: + kwargs["dstIps"] = dst_ips + if src_ips: + kwargs["srcIps"] = src_ips + if applications: + kwargs["applications"] = applications + if ip_protocols: + kwargs["ipProtocols"] = ip_protocols + if src_ports: + kwargs["srcPorts"] = src_ports + if dst_ports: + kwargs["dstPorts"] = dst_ports + return HeaderConstraints(**kwargs) diff --git a/pyproject.toml b/pyproject.toml index 964c1f1f..0a0380aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,9 @@ capirca = [ "capirca", "absl-py>=0.8.0", ] +mcp = [ + "mcp>=1.23.0", +] dev = [ "ruff", "cerberus", @@ -92,11 +95,15 @@ dev = [ "types-simplejson", "capirca", "absl-py>=0.8.0", + "mcp>=1.23.0", ] [project.entry-points."batfish_session"] bf = "pybatfish.client.session:Session" +[project.scripts] +batfish-mcp = "pybatfish.mcp.__main__:main" + [tool.setuptools] packages = {find = {exclude = ["contrib", "docs", "tests"]}} package-data = {"pybatfish" = ["py.typed"]} diff --git a/tests/integration/test_mcp_server.py b/tests/integration/test_mcp_server.py new file mode 100644 index 00000000..cc7193bc --- /dev/null +++ b/tests/integration/test_mcp_server.py @@ -0,0 +1,235 @@ +# Copyright 2018 The Batfish Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for the Batfish MCP server. + +These tests run against a live Batfish service and are executed as part of +the integration_tests CI job (see .github/workflows/reusable-precommit.yml). +""" + +from __future__ import annotations + +import asyncio +import json +import typing +import uuid +from os.path import abspath, dirname, join, realpath +from typing import Any + +import pytest +from mcp.server.fastmcp import FastMCP +from mcp.types import TextContent + +from pybatfish.mcp.server import create_server + +_this_dir = abspath(dirname(realpath(__file__))) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +_MCP_NETWORK = "mcp_integration_test_network" + + +def _call_tool(server: Any, tool_name: str, args: dict[str, Any]) -> Any: + """Call an MCP tool synchronously and return the parsed JSON result.""" + # call_tool returns (list[ContentBlock], meta_dict) at runtime + result = asyncio.run(server.call_tool(tool_name, args)) + content = result[0] + first = content[0] + assert isinstance(first, TextContent) + return json.loads(first.text) + + +@pytest.fixture(scope="module") +def mcp() -> FastMCP: + """Return a configured MCP server instance.""" + return create_server() + + +@pytest.fixture(scope="module") +def network(mcp: FastMCP) -> typing.Generator[str, None, None]: + """Create a Batfish network for testing and clean it up afterwards.""" + _call_tool(mcp, "set_network", {"network": _MCP_NETWORK}) + yield _MCP_NETWORK + try: + _call_tool(mcp, "delete_network", {"network": _MCP_NETWORK}) + except Exception: + pass + + +@pytest.fixture(scope="module") +def snapshot(mcp: FastMCP, network: str) -> typing.Generator[str, None, None]: + """Initialize a snapshot from the tracert snapshot directory. + + We use the tracert snapshot (which has real router configurations) rather + than the minimal ``snapshot`` directory, because some Batfish questions + (e.g. ``detect_loops``) require at least one interface to be present. + """ + snap_name = "mcp_snap_" + uuid.uuid4().hex[:8] + _call_tool( + mcp, + "init_snapshot", + { + "network": network, + "snapshot_path": join(_this_dir, "tracert_snapshot"), + "snapshot_name": snap_name, + }, + ) + yield snap_name + try: + _call_tool(mcp, "delete_snapshot", {"network": network, "snapshot": snap_name}) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Network management tests +# --------------------------------------------------------------------------- + + +def test_list_networks_includes_test_network(mcp: FastMCP, network: str) -> None: + """list_networks should include our test network.""" + data = _call_tool(mcp, "list_networks", {}) + assert isinstance(data, list) + assert network in data + + +def test_set_network_returns_name(mcp: FastMCP) -> None: + """set_network should return a JSON object with the network name.""" + result = _call_tool(mcp, "set_network", {"network": _MCP_NETWORK}) + assert result == {"network": _MCP_NETWORK} + + +# --------------------------------------------------------------------------- +# Snapshot management tests +# --------------------------------------------------------------------------- + + +def test_list_snapshots_includes_test_snapshot(mcp: FastMCP, network: str, snapshot: str) -> None: + """list_snapshots should include the snapshot we just created.""" + data = _call_tool(mcp, "list_snapshots", {"network": network}) + assert isinstance(data, list) + assert snapshot in data + + +def test_init_snapshot_from_text(mcp: FastMCP, network: str) -> None: + """init_snapshot_from_text should succeed and return a snapshot name.""" + config = "! Cisco IOS XE\nhostname test-router\n" + snap_name = "mcp_text_snap_" + uuid.uuid4().hex[:8] + try: + result = _call_tool( + mcp, + "init_snapshot_from_text", + { + "network": network, + "config_text": config, + "snapshot_name": snap_name, + "filename": "test-router.cfg", + }, + ) + assert result == {"snapshot": snap_name} + # Snapshot should now appear in the list + snaps = _call_tool(mcp, "list_snapshots", {"network": network}) + assert snap_name in snaps + finally: + try: + _call_tool(mcp, "delete_snapshot", {"network": network, "snapshot": snap_name}) + except Exception: + pass + + +def test_delete_snapshot(mcp: FastMCP, network: str) -> None: + """delete_snapshot should remove the snapshot from the network.""" + snap_name = "mcp_del_snap_" + uuid.uuid4().hex[:8] + _call_tool( + mcp, + "init_snapshot", + { + "network": network, + "snapshot_path": join(_this_dir, "snapshot"), + "snapshot_name": snap_name, + }, + ) + # Snapshot should exist + snaps_before = _call_tool(mcp, "list_snapshots", {"network": network}) + assert snap_name in snaps_before + + result = _call_tool(mcp, "delete_snapshot", {"network": network, "snapshot": snap_name}) + assert result == {"deleted": snap_name} + + snaps_after = _call_tool(mcp, "list_snapshots", {"network": network}) + assert snap_name not in snaps_after + + +# --------------------------------------------------------------------------- +# Analysis tool tests +# --------------------------------------------------------------------------- + + +def test_get_ip_owners_returns_list(mcp: FastMCP, network: str, snapshot: str) -> None: + """get_ip_owners should return a JSON array of IP ownership rows.""" + data = _call_tool(mcp, "get_ip_owners", {"network": network, "snapshot": snapshot}) + assert isinstance(data, list) + + +def test_get_node_properties_returns_list(mcp: FastMCP, network: str, snapshot: str) -> None: + """get_node_properties should return a JSON array of node property rows.""" + data = _call_tool(mcp, "get_node_properties", {"network": network, "snapshot": snapshot}) + assert isinstance(data, list) + + +def test_get_interface_properties_returns_list(mcp: FastMCP, network: str, snapshot: str) -> None: + """get_interface_properties should return a JSON array.""" + data = _call_tool(mcp, "get_interface_properties", {"network": network, "snapshot": snapshot}) + assert isinstance(data, list) + + +def test_get_routes_returns_list(mcp: FastMCP, network: str, snapshot: str) -> None: + """get_routes should return a JSON array of route rows.""" + data = _call_tool(mcp, "get_routes", {"network": network, "snapshot": snapshot}) + assert isinstance(data, list) + + +def test_get_routes_no_legacy_nexthop_columns(mcp: FastMCP, network: str, snapshot: str) -> None: + """get_routes must not include deprecated next-hop columns in results.""" + data = _call_tool(mcp, "get_routes", {"network": network, "snapshot": snapshot}) + legacy_cols = {"Next_Hop_IP", "Next_Hop_Interface", "NextHopIp", "NextHopInterface"} + for row in data: + for col in legacy_cols: + assert col not in row, f"Deprecated column '{col}' found in routes result" + + +def test_get_bgp_session_status_returns_list(mcp: FastMCP, network: str, snapshot: str) -> None: + """get_bgp_session_status should return a JSON array.""" + data = _call_tool(mcp, "get_bgp_session_status", {"network": network, "snapshot": snapshot}) + assert isinstance(data, list) + + +def test_detect_loops_returns_list(mcp: FastMCP, network: str, snapshot: str) -> None: + """detect_loops should return a JSON array (empty if no loops found).""" + data = _call_tool(mcp, "detect_loops", {"network": network, "snapshot": snapshot}) + assert isinstance(data, list) + + +def test_get_undefined_references_returns_list(mcp: FastMCP, network: str, snapshot: str) -> None: + """get_undefined_references should return a JSON array.""" + data = _call_tool(mcp, "get_undefined_references", {"network": network, "snapshot": snapshot}) + assert isinstance(data, list) + + +def test_analyze_acl_returns_list(mcp: FastMCP, network: str, snapshot: str) -> None: + """analyze_acl should return a JSON array of unreachable ACL line rows.""" + data = _call_tool(mcp, "analyze_acl", {"network": network, "snapshot": snapshot}) + assert isinstance(data, list) diff --git a/tests/mcp/__init__.py b/tests/mcp/__init__.py new file mode 100644 index 00000000..e7fc41bf --- /dev/null +++ b/tests/mcp/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 The Batfish Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/mcp/test_server.py b/tests/mcp/test_server.py new file mode 100644 index 00000000..941b0e5f --- /dev/null +++ b/tests/mcp/test_server.py @@ -0,0 +1,1138 @@ +# Copyright 2018 The Batfish Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the Batfish MCP server.""" + +from __future__ import annotations + +import asyncio +import json +from typing import Any +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from pybatfish.datamodel import HeaderConstraints, Interface +from pybatfish.mcp.server import ( + _analysis_session, + _build_header_constraints, + _clear_session_cache, + _df_to_json, + _drop_legacy_nexthop_columns, + _mgmt_session, + _parse_interfaces, + _resolve_host, + create_server, +) + +# --------------------------------------------------------------------------- +# Helper factories +# --------------------------------------------------------------------------- + + +def _make_answer_frame(rows: list[dict[str, Any]]) -> MagicMock: + """Return a mock that behaves like a Batfish question answering chain.""" + df = pd.DataFrame(rows) + mock_frame = MagicMock() + mock_frame.frame.return_value = df + mock_answer = MagicMock() + mock_answer.answer.return_value = mock_frame + return mock_answer + + +def _make_session_mock(list_networks: list[str] | None = None, list_snapshots: list[str] | None = None) -> MagicMock: + """Return a mock Session whose question chain is pre-configured.""" + session = MagicMock() + session.list_networks.return_value = list_networks or [] + session.list_snapshots.return_value = list_snapshots or [] + return session + + +def _call_tool(server: Any, tool_name: str, args: dict[str, Any]) -> Any: + """Call an MCP tool and return the parsed JSON result from the first content item.""" + content, _meta = asyncio.run(server.call_tool(tool_name, args)) + return json.loads(content[0].text) + + +# --------------------------------------------------------------------------- +# Unit tests for private helpers +# --------------------------------------------------------------------------- + + +class TestDfToJson: + def test_converts_dataframe(self): + df = pd.DataFrame([{"a": 1, "b": "x"}, {"a": 2, "b": "y"}]) + result = json.loads(_df_to_json(df)) + assert result == [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}] + + def test_converts_non_dataframe(self): + result = json.loads(_df_to_json({"key": "value"})) + assert result == {"key": "value"} + + def test_handles_empty_dataframe(self): + df = pd.DataFrame() + result = json.loads(_df_to_json(df)) + assert result == [] + + +class TestParseInterfaces: + def test_empty_string(self): + assert _parse_interfaces("") == [] + + def test_single_interface(self): + result = _parse_interfaces("router1[GigabitEthernet0/0]") + assert len(result) == 1 + assert result[0].hostname == "router1" + assert result[0].interface == "GigabitEthernet0/0" + + def test_multiple_interfaces(self): + result = _parse_interfaces("r1[Gi0/0], r2[Gi0/1]") + assert len(result) == 2 + assert result[0].hostname == "r1" + assert result[1].hostname == "r2" + + def test_node_without_interface_raises(self): + """Bare node tokens (no [interface]) must raise ValueError.""" + with pytest.raises(ValueError, match="node\\[interface\\]"): + _parse_interfaces("router1") + + def test_mixed_entries_raises(self): + """Any bare node token in a list must raise ValueError.""" + with pytest.raises(ValueError, match="node\\[interface\\]"): + _parse_interfaces("r1[Gi0/0], r2") + + def test_invalid_format_error_message(self): + """ValueError message should include the bad token.""" + with pytest.raises(ValueError, match="bare-node"): + _parse_interfaces("bare-node") + + +class TestBuildHeaderConstraints: + def test_empty_returns_empty_constraints(self): + hc = _build_header_constraints() + assert isinstance(hc, HeaderConstraints) + assert hc.dstIps is None + assert hc.srcIps is None + + def test_dst_ips_set(self): + hc = _build_header_constraints(dst_ips="10.0.0.1") + assert hc.dstIps == "10.0.0.1" + + def test_all_fields(self): + hc = _build_header_constraints( + dst_ips="10.0.0.1", + src_ips="192.168.1.0/24", + applications="SSH", + ip_protocols="TCP", + src_ports="1024-65535", + dst_ports="22", + ) + assert hc.dstIps == "10.0.0.1" + assert hc.srcIps == "192.168.1.0/24" + # HeaderConstraints normalises single-string values to lists + assert "SSH" in hc.applications + assert "TCP" in hc.ipProtocols + assert hc.srcPorts == "1024-65535" + assert hc.dstPorts == "22" + + +class TestSessionCache: + """Tests for the per-host session cache in _get_session.""" + + def setup_method(self): + """Clear the cache before each test to ensure isolation.""" + _clear_session_cache() + + def teardown_method(self): + """Clear the cache after each test.""" + _clear_session_cache() + + def test_session_is_cached(self): + """Session must be created only once for the same host.""" + mock_session = MagicMock() + with patch("pybatfish.mcp.server.Session", return_value=mock_session) as MockSession: + from pybatfish.mcp.server import _get_session + + s1 = _get_session("bf-host") + s2 = _get_session("bf-host") + + # Session constructor called only once + assert MockSession.call_count == 1 + # Both calls return the same cached object + assert s1 is s2 + + def test_different_hosts_get_different_cached_sessions(self): + """Each host gets its own independent cache entry.""" + mock_a = MagicMock() + mock_b = MagicMock() + sessions = [mock_a, mock_b] + with patch("pybatfish.mcp.server.Session", side_effect=sessions) as MockSession: + from pybatfish.mcp.server import _get_session + + sa = _get_session("host-a") + sb = _get_session("host-b") + + assert MockSession.call_count == 2 + assert sa is not sb + + def test_clear_session_cache_forces_new_session(self): + """After _clear_session_cache(), the next call creates a fresh session.""" + from pybatfish.mcp.server import _clear_session_cache, _get_session + + mock_first = MagicMock() + mock_second = MagicMock() + sessions = [mock_first, mock_second] + with patch("pybatfish.mcp.server.Session", side_effect=sessions) as MockSession: + s1 = _get_session("bf-host") + _clear_session_cache() + s2 = _get_session("bf-host") + + assert MockSession.call_count == 2 + assert s1 is not s2 + + +class TestResolveHost: + """Tests for the _resolve_host() helper.""" + + def test_returns_explicit_host(self): + assert _resolve_host("my-host") == "my-host" + + def test_falls_back_to_env_var(self, monkeypatch): + monkeypatch.setenv("BATFISH_HOST", "env-host") + assert _resolve_host("") == "env-host" + + def test_falls_back_to_localhost(self, monkeypatch): + monkeypatch.delenv("BATFISH_HOST", raising=False) + assert _resolve_host("") == "localhost" + + +class TestMgmtSession: + """Tests for the _mgmt_session() helper.""" + + def setup_method(self): + _clear_session_cache() + + def teardown_method(self): + _clear_session_cache() + + def test_creates_cached_session(self): + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.return_value = MagicMock() + _mgmt_session("localhost") + MockSession.assert_called_once_with(host="localhost") + + def test_sets_network_when_provided(self): + mock_session = MagicMock() + with patch("pybatfish.mcp.server.Session", return_value=mock_session): + _mgmt_session("localhost", "my-network") + mock_session.set_network.assert_called_once_with("my-network") + + def test_skips_set_network_when_empty(self): + mock_session = MagicMock() + with patch("pybatfish.mcp.server.Session", return_value=mock_session): + _mgmt_session("localhost", "") + mock_session.set_network.assert_not_called() + + def test_resolves_host_from_env(self, monkeypatch): + monkeypatch.setenv("BATFISH_HOST", "env-bf") + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.return_value = MagicMock() + _mgmt_session("") + MockSession.assert_called_once_with(host="env-bf") + + def test_shares_cache_with_analysis_session(self): + """_mgmt_session and _analysis_session must return the same cached session.""" + mock_session = MagicMock() + with patch("pybatfish.mcp.server.Session", return_value=mock_session) as MockSession: + s1 = _mgmt_session("localhost") + s2 = _analysis_session("localhost", "net1", "snap1") + # Session constructor called only once — both helpers share the cache + assert MockSession.call_count == 1 + assert s1 is s2 + + +class TestAnalysisSession: + """Tests for the _analysis_session() helper.""" + + def setup_method(self): + _clear_session_cache() + + def teardown_method(self): + _clear_session_cache() + + def test_sets_network_and_snapshot(self): + mock_session = MagicMock() + with patch("pybatfish.mcp.server.Session", return_value=mock_session): + _analysis_session("localhost", "net1", "snap1") + mock_session.set_network.assert_called_once_with("net1") + mock_session.set_snapshot.assert_called_once_with("snap1") + + def test_creates_cached_session(self): + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.return_value = MagicMock() + _analysis_session("localhost", "net1", "snap1") + MockSession.assert_called_once_with(host="localhost") + + def test_resolves_host_from_env(self, monkeypatch): + _clear_session_cache() + monkeypatch.setenv("BATFISH_HOST", "env-bf") + with patch("pybatfish.mcp.server.Session") as MockSession: + MockSession.return_value = MagicMock() + _analysis_session("", "net1", "snap1") + MockSession.assert_called_once_with(host="env-bf") + + +class TestDropLegacyNexthopColumns: + def test_drops_known_legacy_columns(self): + df = pd.DataFrame( + [ + { + "Node": "r1", + "Network": "10.0.0.0/8", + "Next_Hop": "ip 1.2.3.4", + "Next_Hop_IP": "1.2.3.4", + "Next_Hop_Interface": "GigabitEthernet0/0", + } + ] + ) + result = _drop_legacy_nexthop_columns(df) + assert "Next_Hop" in result.columns + assert "Next_Hop_IP" not in result.columns + assert "Next_Hop_Interface" not in result.columns + + def test_leaves_unrelated_columns_intact(self): + df = pd.DataFrame([{"Node": "r1", "Network": "10.0.0.0/8", "Next_Hop": "ip 1.2.3.4"}]) + result = _drop_legacy_nexthop_columns(df) + assert list(result.columns) == ["Node", "Network", "Next_Hop"] + + def test_handles_non_dataframe_gracefully(self): + result = _drop_legacy_nexthop_columns({"key": "value"}) + assert result == {"key": "value"} + + def test_drops_camelcase_variants(self): + df = pd.DataFrame([{"Node": "r1", "NextHopIp": "1.2.3.4", "NextHopInterface": "Gi0/0"}]) + result = _drop_legacy_nexthop_columns(df) + assert "NextHopIp" not in result.columns + assert "NextHopInterface" not in result.columns + + +# --------------------------------------------------------------------------- +# Integration-style tests for MCP server tools (Session is mocked) +# --------------------------------------------------------------------------- + + +PATCH_TARGET = "pybatfish.mcp.server._get_session" + + +class TestCreateServer: + def test_returns_fastmcp_instance(self): + from mcp.server.fastmcp import FastMCP + + server = create_server() + assert isinstance(server, FastMCP) + + def test_custom_name(self): + server = create_server(name="MyBatfish") + assert server.name == "MyBatfish" + + +class TestListNetworksTool: + def test_returns_network_list(self): + mock_session = _make_session_mock(list_networks=["net1", "net2"]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool(server, "list_networks", {"host": "localhost"}) + assert data == ["net1", "net2"] + + def test_uses_env_host(self, monkeypatch): + monkeypatch.setenv("BATFISH_HOST", "my-bf-host") + mock_session = _make_session_mock(list_networks=["net1"]) + with patch(PATCH_TARGET, return_value=mock_session) as mock_get: + server = create_server() + _call_tool(server, "list_networks", {}) + mock_get.assert_called_once_with("my-bf-host") + + +class TestSetNetworkTool: + def test_returns_network_name(self): + mock_session = MagicMock() + mock_session.set_network.return_value = "my-network" + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool(server, "set_network", {"network": "my-network", "host": "localhost"}) + assert data == {"network": "my-network"} + + +class TestDeleteNetworkTool: + def test_returns_deleted_name(self): + mock_session = MagicMock() + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool(server, "delete_network", {"network": "old-net", "host": "localhost"}) + assert data == {"deleted": "old-net"} + mock_session.delete_network.assert_called_once_with("old-net") + + +class TestListSnapshotsTool: + def test_returns_snapshot_list(self): + mock_session = _make_session_mock(list_snapshots=["snap1", "snap2"]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool(server, "list_snapshots", {"network": "net1", "host": "localhost"}) + assert data == ["snap1", "snap2"] + + +class TestInitSnapshotTool: + def test_returns_snapshot_name(self): + mock_session = MagicMock() + mock_session.init_snapshot.return_value = "my-snap" + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "init_snapshot", + {"network": "net1", "snapshot_path": "/path/to/snap", "host": "localhost"}, + ) + assert data == {"snapshot": "my-snap"} + + def test_passes_name_and_overwrite(self): + mock_session = MagicMock() + mock_session.init_snapshot.return_value = "named-snap" + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "init_snapshot", + { + "network": "net1", + "snapshot_path": "/path", + "snapshot_name": "named-snap", + "overwrite": True, + "host": "localhost", + }, + ) + mock_session.init_snapshot.assert_called_once_with("/path", name="named-snap", overwrite=True) + + +class TestInitSnapshotFromTextTool: + def test_returns_snapshot_name(self): + mock_session = MagicMock() + mock_session.init_snapshot_from_text.return_value = "text-snap" + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "init_snapshot_from_text", + {"network": "net1", "config_text": "hostname router1", "host": "localhost"}, + ) + assert data == {"snapshot": "text-snap"} + + def test_passes_platform_when_set(self): + mock_session = MagicMock() + mock_session.init_snapshot_from_text.return_value = "snap" + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "init_snapshot_from_text", + { + "network": "net1", + "config_text": "config", + "platform": "arista", + "host": "localhost", + }, + ) + call_kwargs = mock_session.init_snapshot_from_text.call_args[1] + assert call_kwargs["platform"] == "arista" + + def test_passes_none_platform_when_empty(self): + mock_session = MagicMock() + mock_session.init_snapshot_from_text.return_value = "snap" + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "init_snapshot_from_text", + {"network": "net1", "config_text": "config", "host": "localhost"}, + ) + call_kwargs = mock_session.init_snapshot_from_text.call_args[1] + assert call_kwargs["platform"] is None + + +class TestDeleteSnapshotTool: + def test_returns_deleted_name(self): + mock_session = MagicMock() + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool(server, "delete_snapshot", {"network": "net1", "snapshot": "snap1", "host": "localhost"}) + assert data == {"deleted": "snap1"} + mock_session.delete_snapshot.assert_called_once_with("snap1") + + +class TestForkSnapshotTool: + def test_basic_fork(self): + mock_session = MagicMock() + mock_session.fork_snapshot.return_value = "forked-snap" + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "fork_snapshot", + {"network": "net1", "base_snapshot": "base", "new_snapshot": "forked", "host": "localhost"}, + ) + assert data == {"snapshot": "forked-snap"} + + def test_deactivate_nodes(self): + mock_session = MagicMock() + mock_session.fork_snapshot.return_value = "forked" + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "fork_snapshot", + { + "network": "net1", + "base_snapshot": "base", + "deactivate_nodes": "r1,r2", + "host": "localhost", + }, + ) + call_kwargs = mock_session.fork_snapshot.call_args[1] + assert call_kwargs["deactivate_nodes"] == ["r1", "r2"] + + def test_deactivate_interfaces(self): + mock_session = MagicMock() + mock_session.fork_snapshot.return_value = "forked" + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "fork_snapshot", + { + "network": "net1", + "base_snapshot": "base", + "deactivate_interfaces": "r1[Gi0/0]", + "host": "localhost", + }, + ) + call_kwargs = mock_session.fork_snapshot.call_args[1] + assert call_kwargs["deactivate_interfaces"] == [Interface(hostname="r1", interface="Gi0/0")] + + +class TestRunTracerouteTool: + def test_returns_json_rows(self): + rows = [{"Flow": "f1", "Traces": "t1"}] + mock_session = MagicMock() + mock_session.q.traceroute.return_value = _make_answer_frame(rows) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "run_traceroute", + { + "network": "net1", + "snapshot": "snap1", + "start_location": "router1", + "dst_ips": "10.0.0.1", + "host": "localhost", + }, + ) + assert len(data) == 1 + assert data[0]["Flow"] == "f1" + + def test_optional_header_params_passed(self): + mock_session = MagicMock() + mock_session.q.traceroute.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "run_traceroute", + { + "network": "net1", + "snapshot": "snap1", + "start_location": "router1", + "dst_ips": "10.0.0.1", + "src_ips": "192.168.0.1", + "applications": "ssh", + "ip_protocols": "TCP", + "src_ports": "1024", + "dst_ports": "22", + "host": "localhost", + }, + ) + call_kwargs = mock_session.q.traceroute.call_args[1] + assert call_kwargs["headers"].dstIps == "10.0.0.1" + assert call_kwargs["headers"].srcIps == "192.168.0.1" + + +class TestRunBidirectionalTracerouteTool: + def test_returns_json_rows(self): + rows = [{"Forward_Flow": "f1", "Reverse_Flow": "f2"}] + mock_session = MagicMock() + mock_session.q.bidirectionalTraceroute.return_value = _make_answer_frame(rows) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "run_bidirectional_traceroute", + { + "network": "net1", + "snapshot": "snap1", + "start_location": "router1", + "dst_ips": "10.0.0.1", + "host": "localhost", + }, + ) + assert len(data) == 1 + assert data[0]["Forward_Flow"] == "f1" + + def test_optional_header_params_passed(self): + mock_session = MagicMock() + mock_session.q.bidirectionalTraceroute.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "run_bidirectional_traceroute", + { + "network": "net1", + "snapshot": "snap1", + "start_location": "router1", + "dst_ips": "10.0.0.1", + "src_ips": "192.168.0.1", + "applications": "ssh", + "ip_protocols": "TCP", + "src_ports": "1024", + "dst_ports": "22", + "host": "localhost", + }, + ) + call_kwargs = mock_session.q.bidirectionalTraceroute.call_args[1] + assert call_kwargs["headers"].dstIps == "10.0.0.1" + assert call_kwargs["headers"].srcIps == "192.168.0.1" + + +class TestCheckReachabilityTool: + def test_basic_call(self): + mock_session = MagicMock() + mock_session.q.reachability.return_value = _make_answer_frame([{"Flow": "f", "Action": "ACCEPT"}]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "check_reachability", + {"network": "net1", "snapshot": "snap1", "dst_ips": "8.8.8.8", "host": "localhost"}, + ) + assert data[0]["Action"] == "ACCEPT" + + def test_optional_params_passed(self): + mock_session = MagicMock() + mock_session.q.reachability.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "check_reachability", + { + "network": "net1", + "snapshot": "snap1", + "src_locations": "router1", + "actions": "DENIED_IN,DROP", + "host": "localhost", + }, + ) + call_kwargs = mock_session.q.reachability.call_args[1] + assert call_kwargs["pathConstraints"] == {"startLocation": "router1"} + assert call_kwargs["actions"] == "DENIED_IN,DROP" + + +class TestAnalyzeAclTool: + def test_returns_acl_rows(self): + mock_session = MagicMock() + mock_session.q.filterLineReachability.return_value = _make_answer_frame([{"Filter": "acl1"}]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "analyze_acl", + {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + ) + assert data[0]["Filter"] == "acl1" + + def test_optional_params_passed(self): + mock_session = MagicMock() + mock_session.q.filterLineReachability.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "analyze_acl", + {"network": "net1", "snapshot": "snap1", "filters": "acl1", "nodes": "r1", "host": "localhost"}, + ) + call_kwargs = mock_session.q.filterLineReachability.call_args[1] + assert call_kwargs["filters"] == "acl1" + assert call_kwargs["nodes"] == "r1" + + +class TestSearchFiltersTool: + def test_permit_action_passed(self): + mock_session = MagicMock() + mock_session.q.searchFilters.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "search_filters", + {"network": "net1", "snapshot": "snap1", "action": "PERMIT", "host": "localhost"}, + ) + call_kwargs = mock_session.q.searchFilters.call_args[1] + assert call_kwargs["action"] == "PERMIT" + + def test_optional_filters_and_nodes_passed(self): + mock_session = MagicMock() + mock_session.q.searchFilters.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "search_filters", + { + "network": "net1", + "snapshot": "snap1", + "filters": "acl1", + "nodes": "r1", + "host": "localhost", + }, + ) + call_kwargs = mock_session.q.searchFilters.call_args[1] + assert call_kwargs["filters"] == "acl1" + assert call_kwargs["nodes"] == "r1" + + +class TestGetRoutesTool: + def test_returns_routes(self): + mock_session = MagicMock() + mock_session.q.routes.return_value = _make_answer_frame([{"Node": "r1", "Network": "0.0.0.0/0"}]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "get_routes", + {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + ) + assert data[0]["Node"] == "r1" + + def test_legacy_nexthop_columns_dropped(self): + mock_session = MagicMock() + mock_session.q.routes.return_value = _make_answer_frame( + [ + { + "Node": "r1", + "Network": "0.0.0.0/0", + "Next_Hop": "ip 1.2.3.4", + "Next_Hop_IP": "1.2.3.4", + "Next_Hop_Interface": "GigabitEthernet0/0", + } + ] + ) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "get_routes", + {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + ) + assert "Next_Hop" in data[0] + assert "Next_Hop_IP" not in data[0] + assert "Next_Hop_Interface" not in data[0] + + def test_filters_passed(self): + mock_session = MagicMock() + mock_session.q.routes.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "get_routes", + { + "network": "net1", + "snapshot": "snap1", + "nodes": "r1", + "vrfs": "default", + "network_prefix": "10.0.0.0/8", + "protocols": "bgp", + "host": "localhost", + }, + ) + call_kwargs = mock_session.q.routes.call_args[1] + assert call_kwargs["nodes"] == "r1" + assert call_kwargs["vrfs"] == "default" + assert call_kwargs["network"] == "10.0.0.0/8" + assert call_kwargs["protocols"] == "bgp" + + +class TestCompareRoutesTool: + def test_calls_differential_answer(self): + mock_frame_obj = MagicMock() + mock_frame_obj.frame.return_value = pd.DataFrame([{"Node": "r1"}]) + mock_answer_obj = MagicMock() + mock_answer_obj.answer.return_value = mock_frame_obj + mock_session = MagicMock() + mock_session.q.routes.return_value = mock_answer_obj + + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "compare_routes", + { + "network": "net1", + "snapshot": "snap-new", + "reference_snapshot": "snap-old", + "host": "localhost", + }, + ) + mock_answer_obj.answer.assert_called_once_with(snapshot="snap-new", reference_snapshot="snap-old") + + def test_optional_filters_passed(self): + mock_frame_obj = MagicMock() + mock_frame_obj.frame.return_value = pd.DataFrame([]) + mock_answer_obj = MagicMock() + mock_answer_obj.answer.return_value = mock_frame_obj + mock_session = MagicMock() + mock_session.q.routes.return_value = mock_answer_obj + + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "compare_routes", + { + "network": "net1", + "snapshot": "snap-new", + "reference_snapshot": "snap-old", + "nodes": "r1", + "vrfs": "default", + "network_prefix": "10.0.0.0/8", + "protocols": "bgp", + "host": "localhost", + }, + ) + call_kwargs = mock_session.q.routes.call_args[1] + assert call_kwargs["nodes"] == "r1" + assert call_kwargs["vrfs"] == "default" + assert call_kwargs["network"] == "10.0.0.0/8" + assert call_kwargs["protocols"] == "bgp" + + +class TestGetBgpSessionStatusTool: + def test_returns_bgp_rows(self): + mock_session = MagicMock() + mock_session.q.bgpSessionStatus.return_value = _make_answer_frame([{"Node": "r1", "Status": "ESTABLISHED"}]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "get_bgp_session_status", + {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + ) + assert data[0]["Status"] == "ESTABLISHED" + + def test_optional_params_passed(self): + mock_session = MagicMock() + mock_session.q.bgpSessionStatus.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "get_bgp_session_status", + { + "network": "net1", + "snapshot": "snap1", + "nodes": "r1", + "remote_nodes": "r2", + "status": "ESTABLISHED", + "host": "localhost", + }, + ) + call_kwargs = mock_session.q.bgpSessionStatus.call_args[1] + assert call_kwargs["nodes"] == "r1" + assert call_kwargs["remoteNodes"] == "r2" + assert call_kwargs["status"] == "ESTABLISHED" + + +class TestGetBgpSessionCompatibilityTool: + def test_returns_compat_rows(self): + mock_session = MagicMock() + mock_session.q.bgpSessionCompatibility.return_value = _make_answer_frame([{"Node": "r1"}]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "get_bgp_session_compatibility", + {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + ) + assert data[0]["Node"] == "r1" + + def test_optional_params_passed(self): + mock_session = MagicMock() + mock_session.q.bgpSessionCompatibility.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "get_bgp_session_compatibility", + { + "network": "net1", + "snapshot": "snap1", + "nodes": "r1", + "remote_nodes": "r2", + "status": "UNIQUE_MATCH", + "host": "localhost", + }, + ) + call_kwargs = mock_session.q.bgpSessionCompatibility.call_args[1] + assert call_kwargs["nodes"] == "r1" + assert call_kwargs["remoteNodes"] == "r2" + assert call_kwargs["status"] == "UNIQUE_MATCH" + + +class TestGetNodePropertiesTool: + def test_returns_node_properties(self): + mock_session = MagicMock() + mock_session.q.nodeProperties.return_value = _make_answer_frame([{"Node": "r1", "AS_Path_Access_Lists": []}]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "get_node_properties", + {"network": "net1", "snapshot": "snap1", "nodes": "r1", "host": "localhost"}, + ) + assert data[0]["Node"] == "r1" + + def test_properties_param_passed(self): + mock_session = MagicMock() + mock_session.q.nodeProperties.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "get_node_properties", + {"network": "net1", "snapshot": "snap1", "properties": "Hostname,NTP_Servers", "host": "localhost"}, + ) + call_kwargs = mock_session.q.nodeProperties.call_args[1] + assert call_kwargs["properties"] == "Hostname,NTP_Servers" + + +class TestGetInterfacePropertiesTool: + def test_returns_interface_properties(self): + mock_session = MagicMock() + mock_session.q.interfaceProperties.return_value = _make_answer_frame([{"Interface": "r1[Gi0/0]"}]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "get_interface_properties", + {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + ) + assert data[0]["Interface"] == "r1[Gi0/0]" + + def test_optional_params_passed(self): + mock_session = MagicMock() + mock_session.q.interfaceProperties.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "get_interface_properties", + { + "network": "net1", + "snapshot": "snap1", + "nodes": "r1", + "interfaces": "Gi0/0", + "properties": "Active,Description", + "host": "localhost", + }, + ) + call_kwargs = mock_session.q.interfaceProperties.call_args[1] + assert call_kwargs["nodes"] == "r1" + assert call_kwargs["interfaces"] == "Gi0/0" + assert call_kwargs["properties"] == "Active,Description" + + +class TestGetIpOwnersTool: + def test_returns_ip_rows(self): + mock_session = MagicMock() + mock_session.q.ipOwners.return_value = _make_answer_frame([{"IP": "10.0.0.1", "Node": "r1"}]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "get_ip_owners", + {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + ) + assert data[0]["IP"] == "10.0.0.1" + + def test_duplicates_only_flag(self): + mock_session = MagicMock() + mock_session.q.ipOwners.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "get_ip_owners", + {"network": "net1", "snapshot": "snap1", "duplicates_only": True, "host": "localhost"}, + ) + mock_session.q.ipOwners.assert_called_once_with(duplicatesOnly=True) + + +class TestCompareFiltersTool: + def test_calls_differential_answer(self): + mock_frame_obj = MagicMock() + mock_frame_obj.frame.return_value = pd.DataFrame([{"Node": "r1"}]) + mock_answer_obj = MagicMock() + mock_answer_obj.answer.return_value = mock_frame_obj + mock_session = MagicMock() + mock_session.q.compareFilters.return_value = mock_answer_obj + + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "compare_filters", + { + "network": "net1", + "snapshot": "snap-new", + "reference_snapshot": "snap-old", + "host": "localhost", + }, + ) + mock_answer_obj.answer.assert_called_once_with(snapshot="snap-new", reference_snapshot="snap-old") + + def test_optional_params_passed(self): + mock_frame_obj = MagicMock() + mock_frame_obj.frame.return_value = pd.DataFrame([]) + mock_answer_obj = MagicMock() + mock_answer_obj.answer.return_value = mock_frame_obj + mock_session = MagicMock() + mock_session.q.compareFilters.return_value = mock_answer_obj + + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "compare_filters", + { + "network": "net1", + "snapshot": "snap-new", + "reference_snapshot": "snap-old", + "filters": "acl1", + "nodes": "r1", + "host": "localhost", + }, + ) + call_kwargs = mock_session.q.compareFilters.call_args[1] + assert call_kwargs["filters"] == "acl1" + assert call_kwargs["nodes"] == "r1" + + +class TestGetUndefinedReferencesTool: + def test_returns_reference_rows(self): + mock_session = MagicMock() + mock_session.q.undefinedReferences.return_value = _make_answer_frame([{"Node": "r1", "Ref_Name": "acl-foo"}]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "get_undefined_references", + {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + ) + assert data[0]["Ref_Name"] == "acl-foo" + + def test_nodes_param_passed(self): + mock_session = MagicMock() + mock_session.q.undefinedReferences.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + _call_tool( + server, + "get_undefined_references", + {"network": "net1", "snapshot": "snap1", "nodes": "r1", "host": "localhost"}, + ) + call_kwargs = mock_session.q.undefinedReferences.call_args[1] + assert call_kwargs["nodes"] == "r1" + + +class TestDetectLoopsTool: + def test_returns_loop_rows(self): + mock_session = MagicMock() + mock_session.q.detectLoops.return_value = _make_answer_frame([{"Node": "r1"}]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "detect_loops", + {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + ) + assert data[0]["Node"] == "r1" + + def test_no_loops(self): + mock_session = MagicMock() + mock_session.q.detectLoops.return_value = _make_answer_frame([]) + with patch(PATCH_TARGET, return_value=mock_session): + server = create_server() + data = _call_tool( + server, + "detect_loops", + {"network": "net1", "snapshot": "snap1", "host": "localhost"}, + ) + assert data == [] + + +class TestToolListCompleteness: + """Verify the server exposes the expected set of tools.""" + + EXPECTED_TOOLS = { + "list_networks", + "set_network", + "delete_network", + "list_snapshots", + "init_snapshot", + "init_snapshot_from_text", + "delete_snapshot", + "fork_snapshot", + "run_traceroute", + "run_bidirectional_traceroute", + "check_reachability", + "analyze_acl", + "search_filters", + "get_routes", + "compare_routes", + "get_bgp_session_status", + "get_bgp_session_compatibility", + "get_node_properties", + "get_interface_properties", + "get_ip_owners", + "compare_filters", + "get_undefined_references", + "detect_loops", + } + + def test_all_expected_tools_registered(self): + server = create_server() + tools = asyncio.run(server.list_tools()) + tool_names = {t.name for t in tools} + assert self.EXPECTED_TOOLS == tool_names