@@ -220,6 +220,8 @@ pub enum EndpointResult {
220220 ///
221221 /// It is common to simply log this error and move on.
222222 IncomingError ,
223+ /// Unable to find connection for the given `NodeId`
224+ ConnectionTypeError ,
223225}
224226
225227/// Attempts to bind the endpoint to the provided IPv4 and IPv6 address.
@@ -713,83 +715,6 @@ pub fn endpoint_accept_any_cb(
713715 } ) ;
714716}
715717
716- /// Run a callback once you have a direct connection to a peer
717- ///
718- /// Does not block. The provided callback will be called when we have a direct
719- /// connection to the peer associated with the `node_id`, or the timeout has occurred.
720- ///
721- /// To wait indefinitely, provide -1 for the timeout parameter.
722- ///
723- /// `ctx` is passed along to the callback, to allow passing context, it must be thread safe as the callback is
724- /// called from another thread.
725- #[ ffi_export]
726- pub fn endpoint_direct_conn_cb (
727- ep : repr_c:: Box < Endpoint > ,
728- ctx : * const c_void ,
729- node_id : & PublicKey ,
730- timeout : isize ,
731- cb : unsafe extern "C" fn ( ctx : * const c_void , res : EndpointResult ) ,
732- ) {
733- // hack around the fact that `*const c_void` is not Send
734- struct CtxPtr ( * const c_void ) ;
735- unsafe impl Send for CtxPtr { }
736- let ctx_ptr = CtxPtr ( ctx) ;
737-
738- let node_id: NodeId = node_id. into ( ) ;
739-
740- TOKIO_EXECUTOR . spawn ( async move {
741- // make the compiler happy
742- let _ = & ctx_ptr;
743-
744- async fn connect ( ep : repr_c:: Box < Endpoint > , node_id : NodeId ) -> anyhow:: Result < ( ) > {
745- ep. ep
746- . read ( )
747- . await
748- . as_ref ( )
749- . expect ( "endpoint not initalized" )
750- . add_node_addr ( iroh:: NodeAddr :: new ( node_id) ) ?;
751-
752- let mut stream = ep
753- . ep
754- . read ( )
755- . await
756- . as_ref ( )
757- . expect ( "endpoint not initalized" )
758- . conn_type ( node_id) ?
759- . stream ( ) ;
760-
761- while let Some ( conn_type) = stream. next ( ) . await {
762- if matches ! ( conn_type, iroh:: endpoint:: ConnectionType :: Direct ( _) ) {
763- return Ok ( ( ) ) ;
764- }
765- }
766- anyhow:: bail!( "stream ended before getting a direct connection" ) ;
767- }
768-
769- let res = match timeout {
770- -1 => connect ( ep, node_id) . await ,
771- _ => {
772- let timeout = Duration :: from_millis ( timeout as u64 ) ;
773- match tokio:: time:: timeout ( timeout, connect ( ep, node_id) ) . await {
774- Ok ( Ok ( _) ) => Ok ( ( ) ) ,
775- Ok ( Err ( err) ) => Err ( err) ,
776- Err ( _) => Err ( anyhow:: anyhow!( "timeout" ) ) ,
777- }
778- }
779- } ;
780-
781- match res {
782- Ok ( _) => unsafe {
783- cb ( ctx_ptr. 0 , EndpointResult :: Ok ) ;
784- } ,
785- Err ( err) => unsafe {
786- warn ! ( "accept failed: {:?}" , err) ;
787- cb ( ctx_ptr. 0 , EndpointResult :: AcceptFailed ) ;
788- } ,
789- }
790- } ) ;
791- }
792-
793718#[ derive_ReprC]
794719#[ repr( u8 ) ]
795720#[ derive( Debug , Copy , Clone , PartialEq , Eq ) ]
@@ -845,20 +770,6 @@ pub fn endpoint_conn_type_cb(
845770 // make the compiler happy
846771 let _ = & ctx_ptr;
847772
848- let res = ep
849- . ep
850- . read ( )
851- . await
852- . as_ref ( )
853- . expect ( "endpoint not initalized" )
854- . add_node_addr ( iroh:: NodeAddr :: new ( node_id) ) ;
855- if res. is_err ( ) {
856- unsafe {
857- cb ( ctx_ptr. 0 , EndpointResult :: AddrError , ConnectionType :: None ) ;
858- }
859- return ;
860- }
861-
862773 let mut stream = match ep
863774 . ep
864775 . read ( )
@@ -869,7 +780,11 @@ pub fn endpoint_conn_type_cb(
869780 {
870781 Err ( _) => {
871782 unsafe {
872- cb ( ctx_ptr. 0 , EndpointResult :: AddrError , ConnectionType :: None ) ;
783+ cb (
784+ ctx_ptr. 0 ,
785+ EndpointResult :: ConnectionTypeError ,
786+ ConnectionType :: None ,
787+ ) ;
873788 }
874789 return ;
875790 }
@@ -1619,147 +1534,6 @@ mod tests {
16191534 client_thread. join ( ) . unwrap ( ) ;
16201535 }
16211536
1622- unsafe extern "C" fn direct_conn_callback ( ctx : * const c_void , res : EndpointResult ) {
1623- // unsafe b/c dereferencing a raw pointer
1624- let sender: & tokio:: sync:: mpsc:: Sender < EndpointResult > =
1625- unsafe { & ( * ( ctx as * const tokio:: sync:: mpsc:: Sender < EndpointResult > ) ) } ;
1626- sender
1627- . try_send ( res)
1628- . expect ( "receiver dropped or channel full" ) ;
1629- }
1630-
1631- #[ test]
1632- fn test_direct_conn_cb ( ) {
1633- let alpn: vec:: Vec < u8 > = b"/cool/alpn/1" . to_vec ( ) . into ( ) ;
1634-
1635- // create config
1636- let mut config_server = endpoint_config_default ( ) ;
1637- endpoint_config_add_alpn ( & mut config_server, alpn. as_ref ( ) ) ;
1638-
1639- let mut config_client = endpoint_config_default ( ) ;
1640- endpoint_config_add_alpn ( & mut config_client, alpn. as_ref ( ) ) ;
1641-
1642- let ( s, r) = std:: sync:: mpsc:: channel ( ) ;
1643- let ( client_s, client_r) = std:: sync:: mpsc:: channel ( ) ;
1644-
1645- // setup server
1646- let alpn_s = alpn. clone ( ) ;
1647- let server_thread = std:: thread:: spawn ( move || {
1648- // create magic endpoint and bind
1649- let ep = endpoint_default ( ) ;
1650- let bind_res = endpoint_bind ( & config_server, None , None , & ep) ;
1651- assert_eq ! ( bind_res, EndpointResult :: Ok ) ;
1652-
1653- let mut node_addr = node_addr_default ( ) ;
1654- let res = endpoint_node_addr ( & ep, & mut node_addr) ;
1655- assert_eq ! ( res, EndpointResult :: Ok ) ;
1656-
1657- s. send ( node_addr) . unwrap ( ) ;
1658-
1659- let ep = Arc :: new ( ep) ;
1660- let alpn_s = alpn_s. clone ( ) ;
1661-
1662- // accept connection
1663- println ! ( "[s] accepting conn" ) ;
1664- let conn = connection_default ( ) ;
1665- let mut alpn = vec:: Vec :: EMPTY ;
1666- let res = endpoint_accept_any ( & ep, & mut alpn, & conn) ;
1667- assert_eq ! ( res, EndpointResult :: Ok ) ;
1668-
1669- if alpn. as_ref ( ) != alpn_s. as_ref ( ) {
1670- panic ! ( "unexpectd alpn: {:?}" , alpn) ;
1671- } ;
1672-
1673- let mut send_stream = send_stream_default ( ) ;
1674- let mut recv_stream = recv_stream_default ( ) ;
1675- let accept_res = connection_accept_bi ( & conn, & mut send_stream, & mut recv_stream) ;
1676- assert_eq ! ( accept_res, EndpointResult :: Ok ) ;
1677-
1678- println ! ( "[s] reading" ) ;
1679-
1680- let mut recv_buffer = vec ! [ 0u8 ; 1024 ] ;
1681- let read_res = recv_stream_read ( & mut recv_stream, ( & mut recv_buffer[ ..] ) . into ( ) ) ;
1682- assert ! ( read_res > 0 ) ;
1683- assert_eq ! (
1684- std:: str :: from_utf8( & recv_buffer[ ..read_res as usize ] ) . unwrap( ) ,
1685- "hello world" ,
1686- ) ;
1687-
1688- println ! ( "[s] sending" ) ;
1689- let send_res = send_stream_write ( & mut send_stream, "hello client" . as_bytes ( ) . into ( ) ) ;
1690- assert_eq ! ( send_res, EndpointResult :: Ok ) ;
1691-
1692- let res = send_stream_finish ( send_stream) ;
1693- assert_eq ! ( res, EndpointResult :: Ok ) ;
1694- client_r. recv ( ) . unwrap ( ) ;
1695- } ) ;
1696-
1697- let ( direct_conn_s, mut direct_conn_r) : (
1698- tokio:: sync:: mpsc:: Sender < EndpointResult > ,
1699- tokio:: sync:: mpsc:: Receiver < EndpointResult > ,
1700- ) = tokio:: sync:: mpsc:: channel ( 1 ) ;
1701-
1702- // setup client
1703- let client_thread = std:: thread:: spawn ( move || {
1704- // create magic endpoint and bind
1705- let ep = endpoint_default ( ) ;
1706- let bind_res = endpoint_bind ( & config_client, None , None , & ep) ;
1707- assert_eq ! ( bind_res, EndpointResult :: Ok ) ;
1708-
1709- // wait for addr from server
1710- let node_addr = r. recv ( ) . unwrap ( ) ;
1711-
1712- let alpn = alpn. clone ( ) ;
1713-
1714- // wait for a moment to make sure the server is ready
1715- std:: thread:: sleep ( std:: time:: Duration :: from_millis ( 100 ) ) ;
1716-
1717- println ! ( "[c] dialing" ) ;
1718- // connect to server
1719- let conn = connection_default ( ) ;
1720- let connect_res = endpoint_connect ( & ep, alpn. as_ref ( ) , node_addr. clone ( ) , & conn) ;
1721- assert_eq ! ( connect_res, EndpointResult :: Ok ) ;
1722-
1723- let mut send_stream = send_stream_default ( ) ;
1724- let mut recv_stream = recv_stream_default ( ) ;
1725- let open_res = connection_open_bi ( & conn, & mut send_stream, & mut recv_stream) ;
1726- assert_eq ! ( open_res, EndpointResult :: Ok ) ;
1727-
1728- let s_ptr: * const c_void = & direct_conn_s as * const _ as * const c_void ;
1729- endpoint_direct_conn_cb ( ep, s_ptr, & node_addr. node_id , 5000 , direct_conn_callback) ;
1730-
1731- println ! ( "[c] sending" ) ;
1732- let send_res = send_stream_write ( & mut send_stream, "hello world" . as_bytes ( ) . into ( ) ) ;
1733- assert_eq ! ( send_res, EndpointResult :: Ok ) ;
1734-
1735- println ! ( "[c] reading" ) ;
1736-
1737- let mut recv_buffer = vec ! [ 0u8 ; 1024 ] ;
1738- let read_res = recv_stream_read ( & mut recv_stream, ( & mut recv_buffer[ ..] ) . into ( ) ) ;
1739- assert ! ( read_res > 0 ) ;
1740- assert_eq ! (
1741- std:: str :: from_utf8( & recv_buffer[ ..read_res as usize ] ) . unwrap( ) ,
1742- "hello client"
1743- ) ;
1744-
1745- let finish_res = send_stream_finish ( send_stream) ;
1746- assert_eq ! ( finish_res, EndpointResult :: Ok ) ;
1747- client_s. send ( ( ) ) . unwrap ( ) ;
1748- } ) ;
1749-
1750- server_thread. join ( ) . unwrap ( ) ;
1751- client_thread. join ( ) . unwrap ( ) ;
1752- let res = direct_conn_r. blocking_recv ( ) . unwrap ( ) ;
1753- match res {
1754- EndpointResult :: Ok => {
1755- println ! ( "got direct connection!" ) ;
1756- }
1757- _ => {
1758- panic ! ( "did not get a direct connection: {res:?}" ) ;
1759- }
1760- }
1761- }
1762-
17631537 type CallbackRes = ( EndpointResult , ConnectionType ) ;
17641538
17651539 unsafe extern "C" fn conn_type_callback (
0 commit comments