@@ -14,8 +14,10 @@ use anyhow::Result;
1414use atty:: Stream ;
1515use core:: time:: Duration ;
1616use pyo3:: exceptions:: { PyRuntimeError , PyTimeoutError } ;
17+ use std:: cmp;
1718use std:: env;
1819use std:: sync:: Arc ;
20+ use std:: thread:: available_parallelism;
1921use structopt:: StructOpt ;
2022use tokio:: runtime:: Runtime ;
2123use tokio:: task:: JoinHandle ;
@@ -34,6 +36,21 @@ use crate::torchftpb::manager_service_client::ManagerServiceClient;
3436use crate :: torchftpb:: { CheckpointMetadataRequest , ManagerQuorumRequest , ShouldCommitRequest } ;
3537use pyo3:: prelude:: * ;
3638
39+ // Get the number of threads to use for the tokio runtime
40+ fn num_threads ( ) -> usize {
41+ let default_threads = 4 ;
42+ let num_cpus = available_parallelism ( )
43+ . and_then ( |p| Ok ( p. get ( ) ) )
44+ . unwrap_or ( default_threads) ;
45+
46+ let num_threads = env:: var ( "TOKIO_WORKER_THREADS" )
47+ . ok ( )
48+ . and_then ( |s| s. parse ( ) . ok ( ) )
49+ . unwrap_or ( cmp:: min ( default_threads, num_cpus) ) ;
50+
51+ num_threads
52+ }
53+
3754/// ManagerServer is a GRPC server for the manager service.
3855/// There should be one manager server per replica group (typically running on
3956/// the rank 0 host). The individual ranks within a replica group should use
@@ -71,7 +88,11 @@ impl ManagerServer {
7188 connect_timeout : Duration ,
7289 ) -> PyResult < Self > {
7390 py. allow_threads ( move || {
74- let runtime = Runtime :: new ( ) ?;
91+ let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
92+ . worker_threads ( num_threads ( ) )
93+ . thread_name ( "torchft-manager" )
94+ . enable_all ( )
95+ . build ( ) ?;
7596 let manager = runtime
7697 . block_on ( manager:: Manager :: new (
7798 replica_id,
@@ -127,7 +148,11 @@ impl ManagerClient {
127148 #[ new]
128149 fn new ( py : Python < ' _ > , addr : String , connect_timeout : Duration ) -> PyResult < Self > {
129150 py. allow_threads ( move || {
130- let runtime = Runtime :: new ( ) ?;
151+ let runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
152+ . worker_threads ( num_threads ( ) )
153+ . thread_name ( "torchft-mgrclnt" )
154+ . enable_all ( )
155+ . build ( ) ?;
131156 let client = runtime
132157 . block_on ( manager:: manager_client_new ( addr, connect_timeout) )
133158 . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
@@ -294,7 +319,11 @@ fn lighthouse_main(py: Python<'_>) -> PyResult<()> {
294319 let mut args = env:: args ( ) ;
295320 args. next ( ) ; // discard binary arg
296321 let opt = lighthouse:: LighthouseOpt :: from_iter ( args) ;
297- let rt = Runtime :: new ( ) ?;
322+ let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
323+ . thread_name ( "torchft-lighths" )
324+ . worker_threads ( num_threads ( ) )
325+ . enable_all ( )
326+ . build ( ) ?;
298327 rt. block_on ( lighthouse_main_async ( opt) )
299328 . map_err ( |e| PyRuntimeError :: new_err ( e. to_string ( ) ) ) ?;
300329 Ok ( ( ) )
@@ -345,7 +374,11 @@ impl LighthouseServer {
345374 let heartbeat_timeout_ms = heartbeat_timeout_ms. unwrap_or ( 5000 ) ;
346375
347376 py. allow_threads ( move || {
348- let rt = Runtime :: new ( ) ?;
377+ let rt = tokio:: runtime:: Builder :: new_multi_thread ( )
378+ . worker_threads ( num_threads ( ) )
379+ . thread_name ( "torchft-lighths" )
380+ . enable_all ( )
381+ . build ( ) ?;
349382
350383 let lighthouse = rt
351384 . block_on ( lighthouse:: Lighthouse :: new ( lighthouse:: LighthouseOpt {
0 commit comments