diff --git a/docs/libp2p.security.pnet.rst b/docs/libp2p.security.pnet.rst new file mode 100644 index 000000000..9e7be3ea5 --- /dev/null +++ b/docs/libp2p.security.pnet.rst @@ -0,0 +1,29 @@ +libp2p.security.pnet package +================================ + +Submodules +---------- + +libp2p.security.pnet.protector module +------------------------------------- + +.. automodule:: libp2p.security.pnet.protector + :members: + :undoc-members: + :show-inheritance: + +libp2p.security.pnet.psk_conn module +------------------------------------ + +.. automodule:: libp2p.security.pnet.psk_conn + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: libp2p.security.pnet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/libp2p.security.rst b/docs/libp2p.security.rst index fc55df33b..41ea5f399 100644 --- a/docs/libp2p.security.rst +++ b/docs/libp2p.security.rst @@ -9,6 +9,7 @@ Subpackages libp2p.security.insecure libp2p.security.noise + libp2p.security.pnet libp2p.security.secio libp2p.security.tls diff --git a/examples/ping/ping.py b/examples/ping/ping.py index f62689aa5..130b63330 100644 --- a/examples/ping/ping.py +++ b/examples/ping/ping.py @@ -4,6 +4,7 @@ import multiaddr import trio +from examples.advanced.network_discover import get_optimal_binding_address from libp2p import ( new_host, ) @@ -25,6 +26,7 @@ PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0") PING_LENGTH = 32 RESP_TIMEOUT = 60 +PSK = "dffb7e3135399a8b1612b2aaca1c36a3a8ac2cd0cca51ceeb2ced87d308cac6d" async def handle_ping(stream: INetStream) -> None: @@ -60,18 +62,27 @@ async def send_ping(stream: INetStream) -> None: print(f"error occurred : {e}") -async def run(port: int, destination: str) -> None: +async def run(port: int, destination: str, psk: int, transport: str) -> None: from libp2p.utils.address_validation import ( find_free_port, get_available_interfaces, - get_optimal_binding_address, ) if port <= 0: port = find_free_port() - listen_addrs = get_available_interfaces(port) - host = new_host(listen_addrs=listen_addrs) + _ = get_available_interfaces(8000) + _ = get_optimal_binding_address(8000) + + if transport == "tcp": + listen_addrs = get_available_interfaces(port) + if transport == "ws": + listen_addrs = [multiaddr.Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/ws")] + + if psk == 1: + host = new_host(listen_addrs=listen_addrs, psk=PSK) + else: + host = new_host(listen_addrs=listen_addrs) async with host.run(listen_addrs=listen_addrs), trio.open_nursery() as nursery: # Start the peer-store cleanup task @@ -87,12 +98,9 @@ async def run(port: int, destination: str) -> None: for addr in all_addrs: print(f"{addr}") - # Use optimal address for the client command - optimal_addr = get_optimal_binding_address(port) - optimal_addr_with_peer = f"{optimal_addr}/p2p/{host.get_id().to_string()}" print( f"\nRun this from the same folder in another console:\n\n" - f"ping-demo -d {optimal_addr_with_peer}\n" + f"ping-demo -d {host.get_addrs()[0]} -psk {psk} -t {transport}\n" ) print("Waiting for incoming connection...") @@ -130,10 +138,23 @@ def main() -> None: type=str, help=f"destination multiaddr string, e.g. {example_maddr}", ) + + parser.add_argument( + "-psk", "--psk", default=0, type=int, help="Enable PSK in the transport layer" + ) + + parser.add_argument( + "-t", + "--transport", + default="tcp", + type=str, + help="Choose the transport layer for ping TCP/WS", + ) + args = parser.parse_args() try: - trio.run(run, *(args.port, args.destination)) + trio.run(run, *(args.port, args.destination, args.psk, args.transport)) except KeyboardInterrupt: pass diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 14075b988..7f9a82d42 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -199,6 +199,7 @@ def new_swarm( tls_client_config: ssl.SSLContext | None = None, tls_server_config: ssl.SSLContext | None = None, resource_manager: ResourceManager | None = None, + psk: str | None = None ) -> INetworkService: logger.debug(f"new_swarm: enable_quic={enable_quic}, listen_addrs={listen_addrs}") """ @@ -214,6 +215,7 @@ def new_swarm( :param quic_transport_opt: options for transport :param resource_manager: optional resource manager for connection/stream limits :type resource_manager: :class:`libp2p.rcmgr.ResourceManager` or None + :param psk: optional pre-shared key for PSK encryption in transport :return: return a default swarm instance Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer @@ -324,7 +326,8 @@ def new_swarm( upgrader, transport, retry_config=retry_config, - connection_config=connection_config + connection_config=connection_config, + psk=psk ) # Set resource manager if provided @@ -342,6 +345,21 @@ def new_swarm( return swarm + # Set resource manager if provided + # Auto-create a default ResourceManager if one was not provided + if resource_manager is None: + try: + from libp2p.rcmgr import new_resource_manager as _new_rm + + resource_manager = _new_rm() + except Exception: + resource_manager = None + + if resource_manager is not None: + swarm.set_resource_manager(resource_manager) + + return swarm + def new_host( key_pair: KeyPair | None = None, @@ -360,6 +378,7 @@ def new_host( tls_client_config: ssl.SSLContext | None = None, tls_server_config: ssl.SSLContext | None = None, resource_manager: ResourceManager | None = None, + psk: str | None = None ) -> IHost: """ Create a new libp2p host based on the given parameters. @@ -379,6 +398,7 @@ def new_host( :param tls_server_config: optional TLS server configuration for WebSocket transport :param resource_manager: optional resource manager for connection/stream limits :type resource_manager: :class:`libp2p.rcmgr.ResourceManager` or None + :param psk: optional pre-shared key (PSK) :return: return a host instance """ @@ -408,6 +428,7 @@ def new_host( tls_client_config=tls_client_config, tls_server_config=tls_server_config, resource_manager=resource_manager, + psk=psk ) if disc_opt is not None: diff --git a/libp2p/host/basic_host.py b/libp2p/host/basic_host.py index 54893e2ab..1682461fa 100644 --- a/libp2p/host/basic_host.py +++ b/libp2p/host/basic_host.py @@ -108,6 +108,7 @@ def __init__( default_protocols: OrderedDict[TProtocol, StreamHandlerFn] | None = None, negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, resource_manager: ResourceManager | None = None, + psk: str | None = None, ) -> None: """ Initialize a BasicHost instance. @@ -148,6 +149,7 @@ def __init__( self.bootstrap = None if bootstrap: self.bootstrap = BootstrapDiscovery(network, bootstrap) + self.psk = psk # Cache a signed-record if the local-node in the PeerStore envelope = create_signed_peer_record( diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index e0db2ffda..4853b090d 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -39,6 +39,7 @@ PeerStoreError, ) from libp2p.rcmgr.manager import ResourceManager +from libp2p.security.pnet.protector import new_protected_conn from libp2p.tools.async_service import ( Service, ) @@ -103,11 +104,13 @@ def __init__( transport: ITransport, retry_config: RetryConfig | None = None, connection_config: ConnectionConfig | QUICTransportConfig | None = None, + psk: str | None = None, ): self.self_id = peer_id self.peerstore = peerstore self.upgrader = upgrader self.transport = transport + self.psk = psk # Enhanced: Initialize retry and connection configuration self.retry_config = retry_config or RetryConfig() @@ -355,6 +358,10 @@ async def _dial_addr_single_attempt(self, addr: Multiaddr, peer_id: ID) -> INetC try: addr = Multiaddr(f"{addr}/p2p/{peer_id}") raw_conn = await self.transport.dial(addr) + + # Enable PNET if psk is provvided + if self.psk is not None: + raw_conn = new_protected_conn(raw_conn, self.psk) except OpenConnectionError as error: logger.debug("fail to dial peer %s over base transport", peer_id) # Release pre-upgrade scope on failure @@ -678,6 +685,10 @@ async def upgrade_inbound_raw_conn( :raises SwarmException: raised when security or muxer upgrade fails :return: network connection with security and multiplexing established """ + # Enable PNET is psk is provided + if self.psk is not None: + raw_conn = new_protected_conn(raw_conn, self.psk) + # secure the conn and then mux the conn try: secured_conn = await self.upgrader.upgrade_security(raw_conn, False) diff --git a/libp2p/security/pnet/__init__.py b/libp2p/security/pnet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libp2p/security/pnet/protector.py b/libp2p/security/pnet/protector.py new file mode 100644 index 000000000..af9143f0d --- /dev/null +++ b/libp2p/security/pnet/protector.py @@ -0,0 +1,10 @@ +from libp2p.abc import IRawConnection +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.security.pnet.psk_conn import PskConn + + +def new_protected_conn(conn: RawConnection | IRawConnection, psk: str) -> PskConn: + if len(psk) != 64: + raise ValueError("Expected 32-byte pre shared key (PSK)") + + return PskConn(conn, psk) diff --git a/libp2p/security/pnet/psk_conn.py b/libp2p/security/pnet/psk_conn.py new file mode 100644 index 000000000..5ef8a06c8 --- /dev/null +++ b/libp2p/security/pnet/psk_conn.py @@ -0,0 +1,58 @@ +import os + +from Crypto.Cipher import Salsa20 + +from libp2p.abc import IRawConnection +from libp2p.network.connection.raw_connection import RawConnection + + +class PskConn(RawConnection): + _psk: bytes + _conn: RawConnection | IRawConnection + + def __init__(self, conn: RawConnection | IRawConnection, psk: str) -> None: + self._psk = bytes.fromhex(psk) + self._conn = conn + + self.read_cipher: Salsa20.Salsa20Cipher | None = None + self.write_cipher: Salsa20.Salsa20Cipher | None = None + + async def write(self, data: bytes) -> None: + """ + Encrpyts and writes data to the stream. + On the first call, generates a 24-byte nonce and sends it first. + """ + if self.write_cipher is None: + nonce = os.urandom(8) + await self._conn.write(nonce) + self.write_cipher = Salsa20.new(key=self._psk, nonce=nonce) + + assert self.write_cipher is not None + ciphertext = self.write_cipher.encrypt(data) + + await self._conn.write(ciphertext) + + async def read(self, n: int | None = None) -> bytes: + """ + Reads and decrypts data. On the first call, it reads a 8-byte + nonce to initialize the decryption stream + """ + if self.read_cipher is None: + nonce = await self._conn.read(8) + if len(nonce) != 8: + raise ValueError("short nonce from stream") + + self.read_cipher = Salsa20.new(key=self._psk, nonce=nonce) + + data = await self._conn.read(n) + if not data: + return b"" + + plaintext = self.read_cipher.decrypt(data) + return plaintext + + async def close(self) -> None: + await self._conn.close() + + def get_remote_address(self) -> tuple[str, int] | None: + return self._conn.get_remote_address() diff --git a/tests/core/security/test_pnet.py b/tests/core/security/test_pnet.py new file mode 100644 index 000000000..e12ccfc44 --- /dev/null +++ b/tests/core/security/test_pnet.py @@ -0,0 +1,82 @@ +import pytest +import trio + +from libp2p.io.abc import ReadWriteCloser +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.security.pnet.protector import new_protected_conn + + +# --- MemoryPipe: implements ReadWriteCloser interface --- +class MemoryPipe(ReadWriteCloser): + """Wrap a pair of Trio memory channels into a ReadWriteCloser-like object.""" + + def __init__( + self, send: trio.MemorySendChannel, receive: trio.MemoryReceiveChannel + ): + self._send = send + self._receive = receive + + async def read(self, n: int | None = None) -> bytes: + """Read next chunk from receive channel.""" + return await self._receive.receive() + + async def write(self, data: bytes) -> None: + """Write a chunk to send channel.""" + await self._send.send(data) + + async def close(self) -> None: + """Close channels (noop for memory channels).""" + pass + + def get_remote_address(self) -> tuple[str, int] | None: + # Memory pipe doesn’t have a real address, so return None + return None + + +# --- Helper function to create a connected pair of PskConns --- +async def make_psk_pair(psk_hex: str): + send1, recv1 = trio.open_memory_channel(0) + send2, recv2 = trio.open_memory_channel(0) + + pipe1 = MemoryPipe(send1, recv2) + pipe2 = MemoryPipe(send2, recv1) + + raw1 = RawConnection(pipe1, False) + raw2 = RawConnection(pipe2, False) + + # NOTE: The new_protected_conn function needs to perform the handshake. + # We'll assume it does for this example. If not, a handshake() call + # might be needed here within a nursery. + psk_conn1 = new_protected_conn(raw1, psk_hex) + psk_conn2 = new_protected_conn(raw2, psk_hex) + + return psk_conn1, psk_conn2 + + +@pytest.mark.trio +async def test_psk_simple_message(): + # Use a fixed PSK for testing + psk_hex = "dffb7e3135399a8b1612b2aaca1c36a3a8ac2cd0cca51ceeb2ced87d308cac6d" + conn1, conn2 = await make_psk_pair(psk_hex) + + msg = b"hello world" + + async with trio.open_nursery() as nursery: + nursery.start_soon(conn1.write, msg) + received = await conn2.read(len(msg)) + + assert received == msg, "Decrypted message does not match original" + + +@pytest.mark.trio +async def test_psk_empty_message(): + # PSK for testing + psk_hex = "dffb7e3135399a8b1612b2aaca1c36a3a8ac2cd0cca51ceeb2ced87d308cac6d" + conn1, conn2 = await make_psk_pair(psk_hex) + + # Empty message should round-trip correctly + async with trio.open_nursery() as nursery: + nursery.start_soon(conn1.write, b"") + received = await conn2.read(0) + + assert received == b"", "Empty message failed"