From 6e4804a9a129e5cfcec58402da4096ef9a6ec24a Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Thu, 25 Sep 2025 11:10:21 -0700 Subject: [PATCH] change lint --- .clang-format | 127 +++++++++++ .flake8 | 3 + .github/workflows/lint.yaml | 12 +- .lintrunner.toml | 81 ++++++- .pyre_configuration | 8 +- .rustfmt.toml | 5 + CONTRIBUTING.md | 1 + docs/source/conf.py | 8 +- examples/monarch/README.md | 4 +- examples/monarch/train_distributed.py | 2 +- proto/torchft.proto | 117 +++++----- pyproject.toml | 18 ++ src/bin/lighthouse.rs | 3 +- src/lib.rs | 93 ++++---- src/lighthouse.rs | 88 ++++---- src/manager.rs | 83 ++++--- src/net.rs | 6 +- src/retry.rs | 9 +- src/timeout.rs | 12 +- tools/linter/adapters/pyre_linter.py | 4 +- tools/linter/adapters/rust_linter.py | 202 ++++++++++++++++++ torchft/_test/diloco_trainer.py | 4 +- torchft/_test/managed_work_test.py | 6 +- torchft/checkpointing/_rwlock.py | 43 ++-- torchft/checkpointing/http_transport.py | 11 +- torchft/checkpointing/http_transport_bench.py | 2 +- torchft/checkpointing/http_transport_test.py | 2 +- torchft/checkpointing/pg_transport.py | 6 +- torchft/checkpointing/pg_transport_bench.py | 2 +- torchft/checkpointing/pg_transport_test.py | 2 +- torchft/checkpointing/transport_test.py | 4 +- torchft/collectives.py | 9 +- torchft/collectives_test.py | 2 +- torchft/coordination.py | 2 +- torchft/ddp.py | 4 +- torchft/ddp_test.py | 2 +- torchft/device_mesh.py | 4 +- torchft/device_mesh_test.py | 2 +- torchft/diloco_regression_test.py | 8 +- torchft/examples/slurm/punisher.py | 2 +- torchft/examples/slurm/runner.py | 2 +- torchft/fsdp_test.py | 10 +- torchft/futures.py | 3 +- torchft/futures_test.py | 2 +- torchft/local_sgd.py | 1 + torchft/local_sgd_integ_test.py | 9 +- torchft/local_sgd_test.py | 6 +- torchft/manager.py | 4 +- torchft/manager_integ_test.py | 8 +- torchft/manager_test.py | 4 +- torchft/optim.py | 2 +- torchft/optim_test.py | 2 +- torchft/process_group.py | 6 +- torchft/process_group_test.py | 16 +- torchft/quantization_test.py | 3 +- train_ddp.py | 4 +- train_diloco.py | 2 +- 57 files changed, 780 insertions(+), 307 deletions(-) create mode 100644 .clang-format create mode 100644 .flake8 create mode 100644 tools/linter/adapters/rust_linter.py diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..67b722d9 --- /dev/null +++ b/.clang-format @@ -0,0 +1,127 @@ +--- +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' 
+CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ForEachMacros: + - FOR_EACH_RANGE + - FOR_EACH +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +Macros: + - >- + PyObject_HEAD_INIT(type)={ + /* this is not exactly match with PyObject_HEAD_INIT in Python source code + * but it is enough for clang-format */ + { 0xFFFFFFFF }, + (type) + }, + - >- + PyVarObject_HEAD_INIT(type, size)={ + { + /* manually expand PyObject_HEAD_INIT(type) above + * because clang-format do not support recursive expansion */ + { 0xFFFFFFFF }, + (type) + }, + (size) + }, +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 2000000 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: c++17 +StatementMacros: + - C10_DEFINE_bool + - C10_DEFINE_int + - C10_DEFINE_int32 + - C10_DEFINE_int64 + - C10_DEFINE_string + - C10_DEFINE_REGISTRY_WITHOUT_WARNING + - C10_REGISTER_CREATOR + - DEFINE_BINARY + - PyObject_HEAD + - PyObject_VAR_HEAD + - PyException_HEAD + - TORCH_DECLARE_bool + +TabWidth: 8 +UseTab: Never +--- +Language: ObjC +ColumnLimit: 120 +AlignAfterOpenBracket: Align +IndentWidth: 2 +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +... diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..f61a700c --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 256 +extend-ignore = E302, G004, SIM105, G201, SIM115, SIM904 diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index d4662a2f..08ba6f24 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -24,15 +24,21 @@ jobs: sudo apt-get install -y protobuf-compiler pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 - pip install .[dev] -v + # install recent version of Rust via rustup + curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=stable --profile=default -y + . 
"$HOME/.cargo/env" + + rustup install nightly + rustup component add --toolchain nightly-x86_64-unknown-linux-gnu rustfmt + lintrunner init - name: Run lintrunner run: | set -eux - lintrunner --skip PYRE --force-color --all-files + lintrunner --skip PYRE,FLAKE --force-color --all-files - name: Run pyre run: | set -eux @@ -42,4 +48,4 @@ jobs: run: | set -eux - cargo fmt --check + cargo +nightly fmt --check diff --git a/.lintrunner.toml b/.lintrunner.toml index 2a715170..8dddffad 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1,5 +1,62 @@ [[linter]] -code = 'BLACK-ISORT' +code = 'CLANGFORMAT' +include_patterns = [ + '**/*.proto', +] +exclude_patterns = [] +command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'clangformat_linter', + '--binary=clang-format', + '--fallback', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'pip_init', + '--dry-run={{DRYRUN}}', + 'clang-format==18.1.3', +] + +[[linter]] +code = 'UFMT' +include_patterns = [ + '*.py', + '**/*.py', + '**/*.pyi', +] +exclude_patterns = [] +command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'ufmt_linter', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'pip_init', + '--dry-run={{DRYRUN}}', + 'black==24.4.2', + 'ufmt==2.6.0', + 'usort==1.0.5', + 'ruff-api==0.1.0', +] + +[[linter]] +code = 'FLAKE' include_patterns = [ '*.py', '**/*.py', @@ -10,8 +67,8 @@ command = [ '-m', 'lintrunner_adapters', 'run', - 'black_isort_linter', - '--fast', + 'flake8_linter', + '--config=.flake8', '--', '@{{PATHSFILE}}', ] @@ -22,8 +79,17 @@ init_command = [ 'run', 'pip_init', '--dry-run={{DRYRUN}}', - 'black==24.10.0', # Use 24.x when ruff styles are updated - 'isort==5.13.2', + 'flake8==7.3.0', + 'flake8-bugbear==24.12.12', + 'flake8-comprehensions==3.16.0', + 'flake8-executable==2.1.3', + 'flake8-logging-format==2024.24.12', + 'flake8-pyi==25.5.0', + 'flake8-simplify==0.22.0', + 'mccabe==0.7.0', + 'pycodestyle==2.14.0', + 'pyflakes==3.4.0', + 'torchfix==0.4.0 ; python_version >= "3.10" and python_version < "3.13"', ] is_formatter = true @@ -34,10 +100,7 @@ include_patterns = [ ] command = [ 'python', - '-m', - 'lintrunner_adapters', - 'run', - 'rustfmt_linter', + 'tools/linter/adapters/rust_linter.py', '--binary=rustfmt', '--config-path=.rustfmt.toml', '--', diff --git a/.pyre_configuration b/.pyre_configuration index 04ce7f04..7913bbe1 100644 --- a/.pyre_configuration +++ b/.pyre_configuration @@ -8,7 +8,11 @@ } ], "search_path": [ - {"site-package": "torchx"}, - {"site-package": "parameterized"} + { + "site-package": "torchx" + }, + { + "site-package": "parameterized" + } ] } diff --git a/.rustfmt.toml b/.rustfmt.toml index 3a26366d..d55c2f22 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1 +1,6 @@ edition = "2021" +group_imports = "StdExternalCrate" +imports_granularity = "Item" +merge_derives = false +style_edition = "2024" +use_field_init_shorthand = true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 59862100..55191e08 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -63,6 +63,7 @@ We actively welcome your pull requests. 
pip install lintrunner lintrunner-adapters lintrunner init lintrunner -a +cargo +nightly fmt ``` ### Tests diff --git a/docs/source/conf.py b/docs/source/conf.py index a8a7d0fe..ce268dbe 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -31,12 +31,12 @@ from importlib.metadata import version import pytorch_sphinx_theme2 + +import torchft from docutils import nodes from sphinx import addnodes from sphinx.util.docfields import TypedField -import torchft - FBCODE = "fbcode" in os.getcwd() # -- General configuration ------------------------------------------------ @@ -236,9 +236,7 @@ def setup(app): # In Sphinx 1.8 it was renamed to `add_css_file`, 1.7 and prior it is # `add_stylesheet` (deprecated in 1.8). - add_css = getattr( - app, "add_css_file", getattr(app, "add_stylesheet", None) - ) # noqa B009 + add_css = getattr(app, "add_css_file", getattr(app, "add_stylesheet", None)) # noqa B009 for css_file in html_css_files: add_css(css_file) diff --git a/examples/monarch/README.md b/examples/monarch/README.md index 08a5f924..8de09528 100644 --- a/examples/monarch/README.md +++ b/examples/monarch/README.md @@ -1,6 +1,6 @@ ### Monarch-TorchFT-TorchTitan Distributed Training Orchestrator -#### Overview +#### Overview This script orchestrates fault-tolerant distributed training using TorchTitan and TorchMonarch frameworks. It manages multiple training replicas across SLURM-scheduled compute nodes with automatic failure recovery and TorchFT lighthouse coordination. @@ -47,4 +47,4 @@ You can also override the resource configuration manually: - TensorBoard metrics enabled by default ##### CLEANUP -All SLURM jobs are automatically terminated at script completion. \ No newline at end of file +All SLURM jobs are automatically terminated at script completion. diff --git a/examples/monarch/train_distributed.py b/examples/monarch/train_distributed.py index 02c8b527..3fb3b970 100644 --- a/examples/monarch/train_distributed.py +++ b/examples/monarch/train_distributed.py @@ -16,7 +16,7 @@ import torch from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer -from monarch.actor import Actor, ProcMesh, current_rank, endpoint, this_host +from monarch.actor import Actor, current_rank, endpoint, ProcMesh, this_host from monarch.tools import commands from monarch.tools.components import hyperactor from monarch.tools.config import Config diff --git a/proto/torchft.proto b/proto/torchft.proto index 7c086eb9..194ea44c 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -10,119 +10,120 @@ package torchft; import "google/protobuf/timestamp.proto"; message RaftMessageRequest { - // Request message contains the serialized Raft proto message. - bytes message = 1; + // Request message contains the serialized Raft proto message. 
+ bytes message = 1; } -message RaftMessageResponse { -} +message RaftMessageResponse {} message NodeInfo { - uint64 rank = 1; - string address = 2; + uint64 rank = 1; + string address = 2; } message InfoRequest { - NodeInfo requester = 1; + NodeInfo requester = 1; } message InfoResponse { - repeated NodeInfo peers = 1; + repeated NodeInfo peers = 1; } service CoordinatorService { - rpc RaftMessage (RaftMessageRequest) returns (RaftMessageResponse); - rpc Info (InfoRequest) returns (InfoResponse); + rpc RaftMessage(RaftMessageRequest) returns (RaftMessageResponse); + rpc Info(InfoRequest) returns (InfoResponse); } message QuorumMember { - string replica_id = 1; - string address = 2; - string store_address = 3; - int64 step = 4; - uint64 world_size = 5; - bool shrink_only = 6; - int64 commit_failures = 8; - // User passing in data stored as JSON string. - string data = 7; + string replica_id = 1; + string address = 2; + string store_address = 3; + int64 step = 4; + uint64 world_size = 5; + bool shrink_only = 6; + int64 commit_failures = 8; + // User passing in data stored as JSON string. + string data = 7; } message Quorum { - int64 quorum_id = 1; - repeated QuorumMember participants = 2; - google.protobuf.Timestamp created = 3; + int64 quorum_id = 1; + repeated QuorumMember participants = 2; + google.protobuf.Timestamp created = 3; } message LighthouseQuorumRequest { - QuorumMember requester = 1; + QuorumMember requester = 1; } message LighthouseQuorumResponse { - Quorum quorum = 1; + Quorum quorum = 1; } message LighthouseHeartbeatRequest { - string replica_id = 1; + string replica_id = 1; } message LighthouseHeartbeatResponse {} service LighthouseService { - rpc Quorum (LighthouseQuorumRequest) returns (LighthouseQuorumResponse); - rpc Heartbeat (LighthouseHeartbeatRequest) returns (LighthouseHeartbeatResponse); + rpc Quorum(LighthouseQuorumRequest) returns (LighthouseQuorumResponse); + rpc Heartbeat(LighthouseHeartbeatRequest) + returns (LighthouseHeartbeatResponse); } message ManagerQuorumRequest { - int64 group_rank = 1; - int64 step = 2; - string checkpoint_metadata = 3; - bool shrink_only = 4; - bool init_sync = 5; - int64 commit_failures = 6; + int64 group_rank = 1; + int64 step = 2; + string checkpoint_metadata = 3; + bool shrink_only = 4; + bool init_sync = 5; + int64 commit_failures = 6; } message ManagerQuorumResponse { - int64 quorum_id = 1; - string recover_src_manager_address = 2; - optional int64 recover_src_replica_rank = 3; - repeated int64 recover_dst_replica_ranks = 4; - string store_address = 5; - // These are information for the replicas which are at the max step. - int64 max_step = 6; - optional int64 max_replica_rank = 7; - int64 max_world_size = 8; - // These are information for all replicas including behind replicas. - int64 replica_rank = 9; - int64 replica_world_size = 10; - bool heal = 11; - int64 commit_failures = 12; + int64 quorum_id = 1; + string recover_src_manager_address = 2; + optional int64 recover_src_replica_rank = 3; + repeated int64 recover_dst_replica_ranks = 4; + string store_address = 5; + // These are information for the replicas which are at the max step. + int64 max_step = 6; + optional int64 max_replica_rank = 7; + int64 max_world_size = 8; + // These are information for all replicas including behind replicas. 
+ int64 replica_rank = 9; + int64 replica_world_size = 10; + bool heal = 11; + int64 commit_failures = 12; } message CheckpointMetadataRequest { - int64 rank = 1; + int64 rank = 1; } message CheckpointMetadataResponse { - string checkpoint_metadata = 1; + string checkpoint_metadata = 1; } message ShouldCommitRequest { - bool should_commit = 1; - int64 group_rank = 2; - int64 step = 3; + bool should_commit = 1; + int64 group_rank = 2; + int64 step = 3; } message ShouldCommitResponse { - bool should_commit = 1; + bool should_commit = 1; } message KillRequest { - string msg = 1; + string msg = 1; } message KillResponse {} service ManagerService { - rpc Quorum (ManagerQuorumRequest) returns (ManagerQuorumResponse); - rpc CheckpointMetadata(CheckpointMetadataRequest) returns (CheckpointMetadataResponse); - rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse); - rpc Kill(KillRequest) returns (KillResponse); + rpc Quorum(ManagerQuorumRequest) returns (ManagerQuorumResponse); + rpc CheckpointMetadata(CheckpointMetadataRequest) + returns (CheckpointMetadataResponse); + rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse); + rpc Kill(KillRequest) returns (KillResponse); } diff --git a/pyproject.toml b/pyproject.toml index 0c848b1e..b7c53e34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,9 +41,27 @@ module-name = "torchft._torchft" torchft_lighthouse = "torchft._torchft:lighthouse_main" [tool.isort] +extra_standard_library = ["typing_extensions"] +skip_gitignore = true +atomic = true +profile = "black" +indent = 4 +line_length = 88 +lines_after_imports = 2 multi_line_output = 3 +include_trailing_comma = true combine_as_imports = true +[tool.black] +target-version = ["py312"] +include = '\.pyi?$' + +[tool.usort] +first_party_detection = false + +[tool.ufmt] +formatter = "ruff-api" + [tool.pytest.ini_options] log_format = "%(asctime)s %(levelname)s %(message)s" log_date_format = "%Y-%m-%d %H:%M:%S" diff --git a/src/bin/lighthouse.rs b/src/bin/lighthouse.rs index dbce458b..09686fe2 100644 --- a/src/bin/lighthouse.rs +++ b/src/bin/lighthouse.rs @@ -5,7 +5,8 @@ // LICENSE file in the root directory of this source tree. 
use structopt::StructOpt; -use torchft::lighthouse::{Lighthouse, LighthouseOpt}; +use torchft::lighthouse::Lighthouse; +use torchft::lighthouse::LighthouseOpt; #[tokio::main(flavor = "multi_thread", worker_threads = 4)] async fn main() { diff --git a/src/lib.rs b/src/lib.rs index fc8f8eb5..4b4da042 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,36 +10,41 @@ mod net; mod retry; mod timeout; -use anyhow::Result; -use atty::Stream; use core::time::Duration; -use pyo3::exceptions::{PyRuntimeError, PyTimeoutError}; use std::cmp; use std::env; use std::sync::Arc; use std::thread::available_parallelism; + +use anyhow::Result; +use atty::Stream; +use chrono::Local; +use fern::colors::Color; +use fern::colors::ColoredLevelConfig; +use log::LevelFilter; +use pyo3::exceptions::PyRuntimeError; +use pyo3::exceptions::PyTimeoutError; use structopt::StructOpt; use tokio::runtime::Runtime; use tokio::task::JoinHandle; -use tonic::transport::Channel; use tonic::Status; - -use chrono::Local; -use fern::colors::{Color, ColoredLevelConfig}; -use log::LevelFilter; +use tonic::transport::Channel; pub mod torchftpb { tonic::include_proto!("torchft"); } +use pyo3::prelude::*; +use pyo3::types::PyDict; +use pyo3::types::PyString; + +use crate::torchftpb::CheckpointMetadataRequest; +use crate::torchftpb::LighthouseHeartbeatRequest; +use crate::torchftpb::LighthouseQuorumRequest; +use crate::torchftpb::ManagerQuorumRequest; +use crate::torchftpb::ShouldCommitRequest; use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; use crate::torchftpb::manager_service_client::ManagerServiceClient; -use crate::torchftpb::{ - CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest, - ManagerQuorumRequest, ShouldCommitRequest, -}; -use pyo3::prelude::*; -use pyo3::types::{PyDict, PyString}; // Get the number of threads to use for the tokio runtime fn num_threads() -> usize { @@ -115,8 +120,8 @@ impl ManagerServer { .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; let handle = runtime.spawn(manager.clone().run()); Ok(Self { - handle: handle, - manager: manager, + handle, + manager, _runtime: runtime, }) }) @@ -165,10 +170,7 @@ impl ManagerClient { .block_on(manager::manager_client_new(addr, connect_timeout)) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - Ok(Self { - runtime: runtime, - client: client, - }) + Ok(Self { runtime, client }) }) } @@ -185,12 +187,12 @@ impl ManagerClient { ) -> Result { py.allow_threads(move || { let mut request = tonic::Request::new(ManagerQuorumRequest { - group_rank: group_rank, - step: step, - checkpoint_metadata: checkpoint_metadata, - shrink_only: shrink_only, - init_sync: init_sync, - commit_failures: commit_failures, + group_rank, + step, + checkpoint_metadata, + shrink_only, + init_sync, + commit_failures, }); // This timeout is processed on the server side so we also enable @@ -222,7 +224,7 @@ impl ManagerClient { timeout: Duration, ) -> Result { py.allow_threads(move || { - let mut request = tonic::Request::new(CheckpointMetadataRequest { rank: rank }); + let mut request = tonic::Request::new(CheckpointMetadataRequest { rank }); // This timeout is processed on the server side so we also enable // keep alives to detect server health. 
@@ -260,9 +262,9 @@ impl ManagerClient { ) -> Result { py.allow_threads(move || { let mut request = tonic::Request::new(ShouldCommitRequest { - group_rank: group_rank, - step: step, - should_commit: should_commit, + group_rank, + step, + should_commit, }); // This notifies the server about the timeout but doesn't affect the @@ -466,7 +468,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult { Ok(Quorum { quorum_id: q.quorum_id, - participants: participants, + participants, created: Timestamp::from(q.created.unwrap()), }) } @@ -498,10 +500,7 @@ impl LighthouseClient { let client = runtime .block_on(manager::lighthouse_client_new(addr, connect_timeout)) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - Ok(Self { - client: client, - runtime: runtime, - }) + Ok(Self { client, runtime }) }) } @@ -545,12 +544,12 @@ impl LighthouseClient { let quorum: Result = py.allow_threads(move || { let mut request = tonic::Request::new(LighthouseQuorumRequest { requester: Some(torchftpb::QuorumMember { - replica_id: replica_id, - address: address, - store_address: store_address, - step: step, - world_size: world_size, - shrink_only: shrink_only, + replica_id, + address, + store_address, + step, + world_size, + shrink_only, data: data_string, commit_failures: 0, }), @@ -636,17 +635,17 @@ impl LighthouseServer { let lighthouse = rt .block_on(lighthouse::Lighthouse::new(lighthouse::LighthouseOpt { - bind: bind, - min_replicas: min_replicas, - join_timeout_ms: join_timeout_ms, - quorum_tick_ms: quorum_tick_ms, - heartbeat_timeout_ms: heartbeat_timeout_ms, + bind, + min_replicas, + join_timeout_ms, + quorum_tick_ms, + heartbeat_timeout_ms, })) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; Ok(Self { handle: rt.spawn(lighthouse.clone().run()), - lighthouse: lighthouse, + lighthouse, _runtime: rt, }) }) diff --git a/src/lighthouse.rs b/src/lighthouse.rs index a5760032..ea1aa3f9 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -8,37 +8,45 @@ use core::net::SocketAddr; use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; - use std::time::Duration; -use std::time::{Instant, SystemTime}; +use std::time::Instant; +use std::time::SystemTime; -use anyhow::{anyhow, Result}; +use anyhow::Result; +use anyhow::anyhow; use askama::Template; -use axum::{ - extract::Path, - http::StatusCode, - response::{Html, IntoResponse}, - routing::{get, post}, - Router, -}; +use axum::Router; +use axum::extract::Path; +use axum::http::StatusCode; +use axum::response::Html; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::routing::post; use gethostname::gethostname; -use log::{error, info}; +use log::error; +use log::info; use structopt::StructOpt; -use tokio::sync::broadcast; use tokio::sync::Mutex; +use tokio::sync::broadcast; use tokio::task::JoinSet; use tokio::time::interval; +use tonic::Request; +use tonic::Response; +use tonic::Status; use tonic::service::Routes; -use tonic::transport::server::TcpIncoming; use tonic::transport::Server; -use tonic::{Request, Response, Status}; +use tonic::transport::server::TcpIncoming; use crate::manager::manager_client_new; -use crate::torchftpb::{ - lighthouse_service_server::{LighthouseService, LighthouseServiceServer}, - KillRequest, LighthouseHeartbeatRequest, LighthouseHeartbeatResponse, LighthouseQuorumRequest, - LighthouseQuorumResponse, Quorum, QuorumMember, -}; +use crate::torchftpb::KillRequest; +use crate::torchftpb::LighthouseHeartbeatRequest; +use crate::torchftpb::LighthouseHeartbeatResponse; 
+use crate::torchftpb::LighthouseQuorumRequest; +use crate::torchftpb::LighthouseQuorumResponse; +use crate::torchftpb::Quorum; +use crate::torchftpb::QuorumMember; +use crate::torchftpb::lighthouse_service_server::LighthouseService; +use crate::torchftpb::lighthouse_service_server::LighthouseServiceServer; #[derive(Clone)] struct QuorumMemberDetails { @@ -274,7 +282,7 @@ impl Lighthouse { quorum_id: 0, heartbeats: HashMap::new(), }), - opt: opt, + opt, local_addr: listener.local_addr()?, listener: Mutex::new(Some(listener)), change_logger: ChangeLogger::new(), @@ -318,7 +326,7 @@ impl Lighthouse { let quorum = Quorum { quorum_id: state.quorum_id, - participants: participants, + participants, created: Some(SystemTime::now().into()), }; @@ -429,10 +437,10 @@ impl Lighthouse { StatusTemplate { quorum_id: state.quorum_id, - num_participants: num_participants, + num_participants, prev_quorum: state.prev_quorum.clone(), - quorum_status: quorum_status, - max_step: max_step, + quorum_status, + max_step, heartbeats: state.heartbeats.clone(), old_age_threshold: Instant::now() @@ -603,11 +611,11 @@ where #[cfg(test)] mod tests { - use super::*; use std::ops::Sub; use tonic::transport::Channel; + use super::*; use crate::net::connect; use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; @@ -1171,10 +1179,12 @@ mod tests { }); let second_response = client.quorum(second_request).await?; let second_quorum = second_response.into_inner().quorum.unwrap(); - assert!(second_quorum - .participants - .iter() - .all(|p| p.replica_id != "joiner")); + assert!( + second_quorum + .participants + .iter() + .all(|p| p.replica_id != "joiner") + ); assert_eq!(second_quorum.participants.len(), 2); assert_eq!(second_quorum.participants[0].replica_id, "replica0"); assert_eq!(second_quorum.participants[1].replica_id, "replica1"); @@ -1190,19 +1200,23 @@ mod tests { }); let third_response = client.quorum(second_request).await?; let third_quorum = third_response.into_inner().quorum.unwrap(); - assert!(third_quorum - .participants - .iter() - .any(|p| p.replica_id == "joiner")); + assert!( + third_quorum + .participants + .iter() + .any(|p| p.replica_id == "joiner") + ); assert_eq!(third_quorum.participants.len(), 3); assert_eq!(third_quorum.participants[2].step, 3); let join_result = joining_task.await?; let join_quorum = join_result.unwrap().into_inner().quorum.unwrap(); - assert!(join_quorum - .participants - .iter() - .any(|p| p.replica_id == "joiner")); + assert!( + join_quorum + .participants + .iter() + .any(|p| p.replica_id == "joiner") + ); assert_eq!(join_quorum.participants.len(), 3); assert_eq!(join_quorum.participants[2].step, 3); diff --git a/src/manager.rs b/src/manager.rs index 0c2a541a..d901caa6 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -7,36 +7,48 @@ use core::net::SocketAddr; use std::collections::HashMap; use std::collections::HashSet; +#[cfg(test)] +use std::println as info; +#[cfg(test)] +use std::println as warn; use std::sync::Arc; use std::time::Duration; use anyhow::Result; -use tokio::sync::broadcast; +#[cfg(not(test))] +use log::info; +#[cfg(not(test))] +use log::warn; use tokio::sync::Mutex; +use tokio::sync::broadcast; use tokio::task::JoinSet; use tokio::time::sleep; -use tonic::transport::server::TcpIncoming; +use tonic::Request; +use tonic::Response; +use tonic::Status; use tonic::transport::Channel; use tonic::transport::Server; -use tonic::{Request, Response, Status}; +use tonic::transport::server::TcpIncoming; use crate::net::connect; use 
crate::timeout::try_parse_grpc_timeout; +use crate::torchftpb::CheckpointMetadataRequest; +use crate::torchftpb::CheckpointMetadataResponse; +use crate::torchftpb::KillRequest; +use crate::torchftpb::KillResponse; +use crate::torchftpb::LighthouseHeartbeatRequest; +use crate::torchftpb::LighthouseQuorumRequest; +use crate::torchftpb::LighthouseQuorumResponse; +use crate::torchftpb::ManagerQuorumRequest; +use crate::torchftpb::ManagerQuorumResponse; +use crate::torchftpb::Quorum; +use crate::torchftpb::QuorumMember; +use crate::torchftpb::ShouldCommitRequest; +use crate::torchftpb::ShouldCommitResponse; use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; use crate::torchftpb::manager_service_client::ManagerServiceClient; -use crate::torchftpb::LighthouseQuorumResponse; -use crate::torchftpb::{ - manager_service_server::{ManagerService, ManagerServiceServer}, - CheckpointMetadataRequest, CheckpointMetadataResponse, KillRequest, KillResponse, - LighthouseHeartbeatRequest, LighthouseQuorumRequest, ManagerQuorumRequest, - ManagerQuorumResponse, Quorum, QuorumMember, ShouldCommitRequest, ShouldCommitResponse, -}; - -#[cfg(not(test))] -use log::{info, warn}; - -#[cfg(test)] -use std::{println as info, println as warn}; +use crate::torchftpb::manager_service_server::ManagerService; +use crate::torchftpb::manager_service_server::ManagerServiceServer; // The replica_id string is of the form {replica_name}:{uuid} or just {uuid} (see torchft/manager.py) // We can parse the replica_id if it exists, otherwise we just use the uuid @@ -124,13 +136,13 @@ impl Manager { let client = lighthouse_client_new(lighthouse_addr.clone(), connect_timeout).await?; Ok(Arc::new(Self { - replica_id: replica_id, + replica_id, lighthouse_addr, connect_timeout, - hostname: hostname, + hostname, store_address: store_addr, - world_size: world_size, - heartbeat_interval: heartbeat_interval, + world_size, + heartbeat_interval, state: Mutex::new(ManagerState { checkpoint_metadata: HashMap::new(), channel: tx, @@ -142,7 +154,7 @@ impl Manager { lighthouse_client: client, }), - local_addr: local_addr, + local_addr, listener: Mutex::new(Some(listener)), quorum_retries, })) @@ -462,9 +474,7 @@ impl ManagerService for Arc { .await .map_err(|e| Status::internal(e.to_string()))?; - let reply = ShouldCommitResponse { - should_commit: should_commit, - }; + let reply = ShouldCommitResponse { should_commit }; Ok(Response::new(reply)) } @@ -593,18 +603,18 @@ fn compute_quorum_results( Ok(ManagerQuorumResponse { quorum_id: quorum.quorum_id, // address is used for looking up the checkpoint server address. 
- recover_src_manager_address: recover_src_manager_address, - recover_src_replica_rank: recover_src_replica_rank, + recover_src_manager_address, + recover_src_replica_rank, recover_dst_replica_ranks: recovery_assignments .get(&replica_rank) .map_or_else(Vec::new, |v| v.clone()), store_address: primary.store_address.clone(), - max_step: max_step, - max_replica_rank: max_replica_rank, + max_step, + max_replica_rank, max_world_size: max_participants.len() as i64, replica_rank: replica_rank as i64, replica_world_size: participants.len() as i64, - heal: heal, + heal, commit_failures: participants .iter() .map(|p| p.commit_failures) @@ -615,13 +625,16 @@ fn compute_quorum_results( #[cfg(test)] mod tests { - use super::*; - use crate::lighthouse::{Lighthouse, LighthouseOpt}; - use crate::torchftpb::lighthouse_service_server::{LighthouseService, LighthouseServiceServer}; - use crate::torchftpb::LighthouseHeartbeatResponse; use tokio::net::TcpListener; use tonic::codegen::tokio_stream::wrappers::TcpListenerStream; + use super::*; + use crate::lighthouse::Lighthouse; + use crate::lighthouse::LighthouseOpt; + use crate::torchftpb::LighthouseHeartbeatResponse; + use crate::torchftpb::lighthouse_service_server::LighthouseService; + use crate::torchftpb::lighthouse_service_server::LighthouseServiceServer; + async fn should_commit(group_rank: i64, should_commit: bool) -> Result { let mut client = manager_client_new( "http://localhost:29531".to_string(), @@ -630,9 +643,9 @@ mod tests { .await?; let request = tonic::Request::new(ShouldCommitRequest { - group_rank: group_rank, + group_rank, step: 1, - should_commit: should_commit, + should_commit, }); let resp = client.should_commit(request).await?; diff --git a/src/net.rs b/src/net.rs index e6d9b690..12ce62c5 100644 --- a/src/net.rs +++ b/src/net.rs @@ -1,9 +1,11 @@ use std::time::Duration; use anyhow::Result; -use tonic::transport::{Channel, Endpoint}; +use tonic::transport::Channel; +use tonic::transport::Endpoint; -use crate::retry::{retry_backoff, ExponentialBackoff}; +use crate::retry::ExponentialBackoff; +use crate::retry::retry_backoff; pub async fn connect_once(addr: String, connect_timeout: Duration) -> Result { let conn = Endpoint::new(addr)? 
diff --git a/src/retry.rs b/src/retry.rs index 00f2f204..6a84cff6 100644 --- a/src/retry.rs +++ b/src/retry.rs @@ -1,7 +1,9 @@ -use anyhow::Result; use std::future::Future; use std::pin::Pin; -use std::time::{Duration, Instant}; +use std::time::Duration; +use std::time::Instant; + +use anyhow::Result; pub struct ExponentialBackoff { pub initial_backoff: Duration, @@ -42,10 +44,11 @@ where #[cfg(test)] mod tests { - use super::*; use std::sync::Arc; use std::sync::Mutex; + use super::*; + #[tokio::test] async fn test_retry_backoff() -> Result<()> { let count = Arc::new(Mutex::new(0)); diff --git a/src/timeout.rs b/src/timeout.rs index d966e88e..da8af4fb 100644 --- a/src/timeout.rs +++ b/src/timeout.rs @@ -1,7 +1,9 @@ use std::time::Duration; use anyhow::Result; -use tonic::metadata::{Ascii, MetadataMap, MetadataValue}; +use tonic::metadata::Ascii; +use tonic::metadata::MetadataMap; +use tonic::metadata::MetadataValue; const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; const SECONDS_IN_HOUR: u64 = 60 * 60; @@ -79,9 +81,11 @@ mod tests { assert!(try_parse_grpc_timeout(&map("3u")).unwrap() == Some(Duration::from_micros(3))); assert!(try_parse_grpc_timeout(&map("3n")).unwrap() == Some(Duration::from_nanos(3))); - assert!(try_parse_grpc_timeout(&MetadataMap::new()) - .unwrap() - .is_none()); + assert!( + try_parse_grpc_timeout(&MetadataMap::new()) + .unwrap() + .is_none() + ); assert!(try_parse_grpc_timeout(&map("")).is_err()); } } diff --git a/tools/linter/adapters/pyre_linter.py b/tools/linter/adapters/pyre_linter.py index d5c9ad84..99f2c3b8 100644 --- a/tools/linter/adapters/pyre_linter.py +++ b/tools/linter/adapters/pyre_linter.py @@ -109,7 +109,9 @@ def main() -> None: level=( logging.NOTSET if args.verbose - else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO + else logging.DEBUG + if len(args.filenames) < 1000 + else logging.INFO ), stream=sys.stderr, ) diff --git a/tools/linter/adapters/rust_linter.py b/tools/linter/adapters/rust_linter.py new file mode 100644 index 00000000..9c8e9201 --- /dev/null +++ b/tools/linter/adapters/rust_linter.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import argparse +import concurrent.futures +import logging +import os +import re +import subprocess +import sys +from typing import Pattern + +import lintrunner_adapters +from lintrunner_adapters import as_posix, LintMessage, LintSeverity, run_command + +LINTER_CODE = "RUSTFMT" + +SYNTAX_ERROR_ARROW_RE: Pattern[str] = re.compile( + r"(?m)^( +--> )(.+)(:(?P\d+):(?P\d+))\n" +) + +SYNTAX_ERROR_PARSE_RE: Pattern[str] = re.compile(r"(?m)^failed to parse .*\n") + + +def strip_path_from_error(error: str) -> str: + # Remove full paths from the description to have deterministic messages. + error = SYNTAX_ERROR_ARROW_RE.sub("", error, count=1) + error = SYNTAX_ERROR_PARSE_RE.sub("", error, count=1) + return error + + +def check_file( + filename: str, + *, + binary: str, + config_path: str | None, +) -> list[LintMessage]: + try: + with open(filename, "rb") as f: + original = f.read() + with open(filename, "rb") as f: + proc = run_command( + [ + binary, + "+nightly", + "--emit=stdout", + "--quiet", + ] + + (["--config-path", config_path] if config_path else []), + stdin=f, + check=True, + ) + except (OSError, subprocess.CalledProcessError) as err: + # https://github.com/rust-lang/rustfmt#running + # TODO: Fix the syntax error regexp to handle multiple issues and + # to handle the empty result case. 
+ if ( + isinstance(err, subprocess.CalledProcessError) + and err.returncode == 1 + and err.stderr + ): + line = None + char = None + description = err.stderr.decode("utf-8") + match = SYNTAX_ERROR_ARROW_RE.search(description) + if match: + line = int(match["line"]) + char = int(match["column"]) + description = f"```\n{strip_path_from_error(description)}\n```" + return [ + LintMessage( + path=filename, + line=line, + char=char, + code=LINTER_CODE, + severity=LintSeverity.ERROR, + name="parsing-error", + original=None, + replacement=None, + description=description, + ) + ] + + return [ + LintMessage( + path=filename, + line=None, + char=None, + code=LINTER_CODE, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=( + f"Failed due to {err.__class__.__name__}:\n{err}" + if not isinstance(err, subprocess.CalledProcessError) + else ( + "COMMAND (exit code {returncode})\n" + "{command}\n\n" + "STDERR\n{stderr}\n\n" + "STDOUT\n{stdout}" + ).format( + returncode=err.returncode, + command=" ".join(as_posix(x) for x in err.cmd), + stderr=err.stderr.decode("utf-8").strip() or "(empty)", + stdout=err.stdout.decode("utf-8").strip() or "(empty)", + ) + ), + ) + ] + + replacement = proc.stdout + if original == replacement: + return [] + + if proc.stderr.startswith(b"error: "): + clean_err = strip_path_from_error(proc.stderr.decode("utf-8")).strip() + return [ + LintMessage( + path=filename, + line=None, + char=None, + code=LINTER_CODE, + severity=LintSeverity.WARNING, + name="rustfmt-bug", + original=None, + replacement=None, + description=( + "Possible rustfmt bug. " + f"rustfmt returned error output but didn't fail:\n{clean_err}" + ), + ) + ] + + return [ + LintMessage( + path=filename, + line=1, + char=1, + code=LINTER_CODE, + severity=LintSeverity.WARNING, + name="format", + original=original.decode("utf-8"), + replacement=replacement.decode("utf-8"), + description="See https://github.com/rust-lang/rustfmt#tips", + ) + ] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Format rust files with rustfmt. 
Linter code: {LINTER_CODE}", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--binary", + required=True, + default="rustfmt", + help="rustfmt binary path", + ) + parser.add_argument( + "--config-path", + required=True, + default=None, + help="rustfmt config path", + ) + + lintrunner_adapters.add_default_options(parser) + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=( + logging.NOTSET + if args.verbose + else logging.DEBUG + if len(args.filenames) < 1000 + else logging.INFO + ), + stream=sys.stderr, + ) + + with concurrent.futures.ThreadPoolExecutor( + max_workers=os.cpu_count(), + thread_name_prefix="Thread", + ) as executor: + futures = { + executor.submit( + check_file, x, binary=args.binary, config_path=args.config_path + ): x + for x in args.filenames + } + for future in concurrent.futures.as_completed(futures): + try: + for lint_message in future.result(): + lint_message.display() + except Exception: + logging.critical('Failed at "%s".', futures[future]) + raise + + +if __name__ == "__main__": + main() diff --git a/torchft/_test/diloco_trainer.py b/torchft/_test/diloco_trainer.py index 832f3698..2cb0d36d 100644 --- a/torchft/_test/diloco_trainer.py +++ b/torchft/_test/diloco_trainer.py @@ -3,13 +3,13 @@ import os from contextlib import ExitStack from datetime import timedelta -from typing import Any, Dict, List, cast +from typing import Any, cast, Dict, List import torch from torch import nn from torch.distributed.tensor import DTensor -from torchft.device_mesh import ManagedDeviceMesh, ft_init_device_mesh +from torchft.device_mesh import ft_init_device_mesh, ManagedDeviceMesh from torchft.local_sgd import DiLoCo from torchft.manager import Manager from torchft.manager_integ_test import MyModel, Runner diff --git a/torchft/_test/managed_work_test.py b/torchft/_test/managed_work_test.py index 118daf98..8b360466 100644 --- a/torchft/_test/managed_work_test.py +++ b/torchft/_test/managed_work_test.py @@ -7,7 +7,7 @@ import types import unittest from datetime import timedelta -from typing import Callable, Dict, List, Optional, Tuple, TypeVar, cast +from typing import Callable, cast, Dict, List, Optional, Tuple, TypeVar # Define a type variable for the Future's value type T = TypeVar("T") @@ -17,7 +17,7 @@ from torch.distributed.distributed_c10d import Work from torch.futures import Future -from torchft.manager import Manager, _ManagedWork +from torchft.manager import _ManagedWork, Manager class SimpleWork(Work): @@ -278,7 +278,7 @@ def callback1(fut: Future[List[torch.Tensor]]) -> Dict[str, torch.Tensor]: # Second callback: Takes Dict[str, Tensor] and returns Tuple[Tensor, float] # Uses Future.value() to modify tensor2 def callback2( - fut: Future[Dict[str, torch.Tensor]] + fut: Future[Dict[str, torch.Tensor]], ) -> Tuple[torch.Tensor, float]: data = fut.value() # Modify tensor2 in-place using the value from the future diff --git a/torchft/checkpointing/_rwlock.py b/torchft/checkpointing/_rwlock.py index db8c370d..54b6d088 100644 --- a/torchft/checkpointing/_rwlock.py +++ b/torchft/checkpointing/_rwlock.py @@ -1,39 +1,38 @@ # -*- coding: utf-8 -*- -""" rwlock.py +"""rwlock.py - Adapted from: https://github.com/tylerneylon/rwlock/blob/main/rwlock.py +Adapted from: https://github.com/tylerneylon/rwlock/blob/main/rwlock.py - A class to implement read-write locks on top of the standard threading - library. +A class to implement read-write locks on top of the standard threading +library. 
- This is implemented with two mutexes (threading.Lock instances) as per this - wikipedia pseudocode: +This is implemented with two mutexes (threading.Lock instances) as per this +wikipedia pseudocode: - https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Using_two_mutexes +https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Using_two_mutexes - __________________________ - License info (MIT): +__________________________ +License info (MIT): - ******* +******* - Copyright 2023 Tyler Neylon and contributors +Copyright 2023 Tyler Neylon and contributors - Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated - documentation files (the "Software"), to deal in the Software without restriction, including without limitation the - rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit - persons to whom the Software is furnished to do so, subject to the following conditions: +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the Software without restriction, including without limitation the +rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +persons to whom the Software is furnished to do so, subject to the following conditions: - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE - WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR - OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- ******* +******* """ - from contextlib import contextmanager from threading import Lock from typing import Generator diff --git a/torchft/checkpointing/http_transport.py b/torchft/checkpointing/http_transport.py index 826f23d5..5613b2d5 100644 --- a/torchft/checkpointing/http_transport.py +++ b/torchft/checkpointing/http_transport.py @@ -13,10 +13,10 @@ from contextlib import contextmanager, nullcontext from datetime import timedelta from http.server import BaseHTTPRequestHandler -from typing import Generator, List, Optional, TypeVar, cast +from typing import cast, Generator, List, Optional, TypeVar import torch -from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten +from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec from torchft.checkpointing._rwlock import RWLock from torchft.checkpointing._serialization import _streaming_load, _streaming_save @@ -152,9 +152,10 @@ def _load_from_address(cls, address: str, timeout: timedelta) -> object: msg = f"fetching checkpoint from {address}" logger.info(msg) - with _time(msg), urllib.request.urlopen( - address, timeout=timeout.total_seconds() - ) as f: + with ( + _time(msg), + urllib.request.urlopen(address, timeout=timeout.total_seconds()) as f, + ): # We have to set weights_only to False as there are some non-tensor # states like lr_scheduler. # pyre-fixme[16]: needs torch>=2.7 diff --git a/torchft/checkpointing/http_transport_bench.py b/torchft/checkpointing/http_transport_bench.py index 4e52193c..f002895c 100644 --- a/torchft/checkpointing/http_transport_bench.py +++ b/torchft/checkpointing/http_transport_bench.py @@ -5,7 +5,7 @@ import torch -from torchft.checkpointing.http_transport import HTTPTransport, _time +from torchft.checkpointing.http_transport import _time, HTTPTransport logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchft/checkpointing/http_transport_test.py b/torchft/checkpointing/http_transport_test.py index 00d379e0..8129394c 100644 --- a/torchft/checkpointing/http_transport_test.py +++ b/torchft/checkpointing/http_transport_test.py @@ -7,7 +7,7 @@ import urllib.error from datetime import timedelta from typing import Dict -from unittest import TestCase, skipUnless +from unittest import skipUnless, TestCase from unittest.mock import MagicMock import torch diff --git a/torchft/checkpointing/pg_transport.py b/torchft/checkpointing/pg_transport.py index 5aaa3a9f..0c792af6 100644 --- a/torchft/checkpointing/pg_transport.py +++ b/torchft/checkpointing/pg_transport.py @@ -4,16 +4,16 @@ from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta -from typing import Callable, Generator, Optional, TypeVar, Union, cast +from typing import Callable, cast, Generator, Optional, TypeVar, Union import torch from torch.distributed import Work -from torch.distributed.tensor import DTensor, _DTensorSpec +from torch.distributed.tensor import _DTensorSpec, DTensor from torch.utils._pytree import ( KeyPath, - TreeSpec, tree_flatten_with_path, tree_unflatten, + TreeSpec, ) from torchft.checkpointing.transport import CheckpointTransport diff --git a/torchft/checkpointing/pg_transport_bench.py b/torchft/checkpointing/pg_transport_bench.py index 1bf385f8..71e27fdc 100644 --- a/torchft/checkpointing/pg_transport_bench.py +++ b/torchft/checkpointing/pg_transport_bench.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist -from torchft.checkpointing.pg_transport import PGTransport, _timeit +from torchft.checkpointing.pg_transport import _timeit, 
PGTransport from torchft.process_group import ProcessGroupBabyNCCL logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchft/checkpointing/pg_transport_test.py b/torchft/checkpointing/pg_transport_test.py index 250528b2..1d98d2b8 100644 --- a/torchft/checkpointing/pg_transport_test.py +++ b/torchft/checkpointing/pg_transport_test.py @@ -1,6 +1,6 @@ import sys from datetime import timedelta -from unittest import TestCase, skipIf, skipUnless +from unittest import skipIf, skipUnless, TestCase import torch from torch.distributed import TCPStore diff --git a/torchft/checkpointing/transport_test.py b/torchft/checkpointing/transport_test.py index 2eea1570..c33801a3 100644 --- a/torchft/checkpointing/transport_test.py +++ b/torchft/checkpointing/transport_test.py @@ -1,13 +1,13 @@ import threading import traceback -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import as_completed, ThreadPoolExecutor from datetime import timedelta from typing import Callable from unittest import TestCase import torch import torch.distributed as dist -from torch.distributed.tensor import DeviceMesh, DTensor, distribute_tensor +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor from torchft.checkpointing.transport import CheckpointTransport diff --git a/torchft/collectives.py b/torchft/collectives.py index cd84b0b9..513a2d99 100644 --- a/torchft/collectives.py +++ b/torchft/collectives.py @@ -254,7 +254,14 @@ def reduce_scatter_quantized( fut = work.get_future() def callback(fut: Future[list[torch.Tensor]]) -> None: - nonlocal inputs, quantized_inputs_out, world_size, sync_stream, rank, reduce_outputs, reducescatter_opts + nonlocal \ + inputs, \ + quantized_inputs_out, \ + world_size, \ + sync_stream, \ + rank, \ + reduce_outputs, \ + reducescatter_opts with torch.cuda.stream(sync_stream): # Setup stream dependency diff --git a/torchft/collectives_test.py b/torchft/collectives_test.py index b73a18b2..1536d0a2 100644 --- a/torchft/collectives_test.py +++ b/torchft/collectives_test.py @@ -6,7 +6,7 @@ import unittest from typing import Callable -from unittest import TestCase, skipUnless +from unittest import skipUnless, TestCase import torch import torch.distributed as dist diff --git a/torchft/coordination.py b/torchft/coordination.py index fcba08e0..0173b8e9 100644 --- a/torchft/coordination.py +++ b/torchft/coordination.py @@ -5,7 +5,7 @@ .. warning:: As torchft is still in development, the APIs in this module are subject to change. -This module exposes low level coordination APIs to allow you to build your own +This module exposes low level coordination APIs to allow you to build your own custom fault tolerance algorithms on top of torchft. 
 If you're looking for a more complete solution, please use the other modules in
diff --git a/torchft/ddp.py b/torchft/ddp.py
index 1af50876..78484ecb 100644
--- a/torchft/ddp.py
+++ b/torchft/ddp.py
@@ -14,7 +14,7 @@
 import os
 import sys
-from typing import TYPE_CHECKING, Optional, cast
+from typing import cast, Optional, TYPE_CHECKING
 from unittest.mock import patch

 import torch
@@ -26,7 +26,7 @@
 from torchft.process_group import ProcessGroup, ProcessGroupDummy, ProcessGroupGloo

 if TYPE_CHECKING:
-    from torchft.manager import Manager, _ManagedFuture
+    from torchft.manager import _ManagedFuture, Manager


 class DistributedDataParallel(parallel.DistributedDataParallel):
diff --git a/torchft/ddp_test.py b/torchft/ddp_test.py
index 8c9c7a19..118cc8d2 100644
--- a/torchft/ddp_test.py
+++ b/torchft/ddp_test.py
@@ -14,7 +14,7 @@
 from torch.futures import Future

 from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
-from torchft.manager import Manager, _ManagedWork
+from torchft.manager import _ManagedWork, Manager
 from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
 from torchft.work import _DummyWork
diff --git a/torchft/device_mesh.py b/torchft/device_mesh.py
index 252384cd..168189a6 100644
--- a/torchft/device_mesh.py
+++ b/torchft/device_mesh.py
@@ -1,14 +1,14 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, TYPE_CHECKING, Union

 import torch
 from torch._C._distributed_c10d import Backend as C10dBackend
 from torch.distributed import (
     DeviceMesh,
-    ProcessGroup as BaseProcessGroup,
     get_rank,
     init_device_mesh,
+    ProcessGroup as BaseProcessGroup,
 )
 from torch.distributed.tensor.device_mesh import _mesh_resources
diff --git a/torchft/device_mesh_test.py b/torchft/device_mesh_test.py
index 3a8ab5b8..757cab5e 100644
--- a/torchft/device_mesh_test.py
+++ b/torchft/device_mesh_test.py
@@ -16,9 +16,9 @@
 from torchft.manager import Manager
 from torchft.process_group import (
+    ft_init_device_mesh,
     ManagedProcessGroup,
     ProcessGroupGloo,
-    ft_init_device_mesh,
 )
diff --git a/torchft/diloco_regression_test.py b/torchft/diloco_regression_test.py
index 11ab9f62..63cb9744 100644
--- a/torchft/diloco_regression_test.py
+++ b/torchft/diloco_regression_test.py
@@ -5,12 +5,12 @@
 import os
 import sys
 import threading
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed, ThreadPoolExecutor
 from contextlib import ExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, cast, overload
-from unittest import TestCase, skipIf
+from typing import Any, Callable, cast, Dict, List, Optional, overload, Tuple
+from unittest import skipIf, TestCase

 import torch
 from parameterized import parameterized
@@ -226,7 +226,7 @@ def train_loop(self) -> Dict[str, Any]:
                 )
                 parameter_history["global_parameter_history"][local_step][
                     f"layers.{i}.weight"
-                ] = (value["weight"].data.clone().detach().cpu().tolist())
+                ] = value["weight"].data.clone().detach().cpu().tolist()

                 manager_steps.add(manager_curr_step)
diff --git a/torchft/examples/slurm/punisher.py b/torchft/examples/slurm/punisher.py
index 6cac7145..328d9f41 100644
--- a/torchft/examples/slurm/punisher.py
+++ b/torchft/examples/slurm/punisher.py
@@ -4,7 +4,7 @@
 import time

 from torchx import specs
-from torchx.runner import Runner, get_runner
+from torchx.runner import get_runner, Runner

 logging.basicConfig(level=logging.INFO)
 logger: logging.Logger = logging.getLogger(__name__)
diff --git a/torchft/examples/slurm/runner.py b/torchft/examples/slurm/runner.py
index a1f6184c..4111b3db 100644
--- a/torchft/examples/slurm/runner.py
+++ b/torchft/examples/slurm/runner.py
@@ -5,7 +5,7 @@
 from torchx import specs
 from torchx.components.dist import ddp
-from torchx.runner import Runner, get_runner
+from torchx.runner import get_runner, Runner

 logging.basicConfig(level=logging.INFO)
 logger: logging.Logger = logging.getLogger(__name__)
diff --git a/torchft/fsdp_test.py b/torchft/fsdp_test.py
index 7c3a7b65..f7a5c2ec 100644
--- a/torchft/fsdp_test.py
+++ b/torchft/fsdp_test.py
@@ -15,34 +15,34 @@
 import torch.distributed as dist
 from torch import nn
 from torch._C._distributed_c10d import (
+    _resolve_process_group,
     AllgatherOptions,
     AllreduceOptions,
     BroadcastOptions,
     ReduceOp,
-    _resolve_process_group,
 )
 from torch.distributed import (
+    _functional_collectives,
+    get_world_size,
     ReduceOp,
     TCPStore,
     Work,
-    _functional_collectives,
-    get_world_size,
 )
 from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
+    parallelize_module,
     PrepareModuleInput,
     RowwiseParallel,
     SequenceParallel,
-    parallelize_module,
 )

 from torchft.manager import Manager
 from torchft.process_group import (
+    ft_init_device_mesh,
     ManagedProcessGroup,
     ProcessGroupGloo,
-    ft_init_device_mesh,
 )
diff --git a/torchft/futures.py b/torchft/futures.py
index 247d687f..c45a4a36 100644
--- a/torchft/futures.py
+++ b/torchft/futures.py
@@ -256,8 +256,7 @@ def _clear_del_queue(self) -> int:
                 refcount = sys.getrefcount(item)
                 assert (
                     # 1 from item, 1 from getrefcount
-                    refcount
-                    == 2
+                    refcount == 2
                 ), f"items in del_queue reference should not have other references, found {refcount=}"
                 del item
diff --git a/torchft/futures_test.py b/torchft/futures_test.py
index 59ca73d5..679d497e 100644
--- a/torchft/futures_test.py
+++ b/torchft/futures_test.py
@@ -1,6 +1,6 @@
 import threading
 from datetime import timedelta
-from unittest import TestCase, skipUnless
+from unittest import skipUnless, TestCase
 from unittest.mock import Mock, patch

 import torch
diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index ade4dc5b..435376a8 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -8,6 +8,7 @@
 =========
 This module implements a fault tolerant version of LocalSGD and related methods.
 """
+
 import logging
 import math
 from contextlib import nullcontext
diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py
index 422188f8..c1a329db 100644
--- a/torchft/local_sgd_integ_test.py
+++ b/torchft/local_sgd_integ_test.py
@@ -5,17 +5,17 @@
 import sys
 import threading
 import traceback
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import as_completed, ThreadPoolExecutor
 from contextlib import ExitStack
 from dataclasses import field
 from datetime import timedelta
-from typing import Any, Dict, cast
-from unittest import TestCase, skipIf
+from typing import Any, cast, Dict
+from unittest import skipIf, TestCase

 import torch
 from parameterized import parameterized
 from torch import nn, optim
-from torch.distributed.pipelining import SplitPoint, pipeline
+from torch.distributed.pipelining import pipeline, SplitPoint
 from torch.distributed.tensor import DTensor, Replicate

 from torchft._test.diloco_trainer import DiLoCoTrainer, MultiMyModel
@@ -116,7 +116,6 @@ def diloco_train_loop(
     runner: Runner,
     train_loop_args: dict[str, Any] = {},
 ) -> Dict[str, Dict[str, object]]:
-
     model_state_dict = train_loop_args.get("model_state_dict", {})
     n_fragments = train_loop_args.get("n_fragments", 1)
     diloco_args = train_loop_args.get("diloco_args", {})
diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py
index 881b96ea..5561ce52 100644
--- a/torchft/local_sgd_test.py
+++ b/torchft/local_sgd_test.py
@@ -6,15 +6,15 @@
 from typing import Dict
 from unittest import TestCase
-from unittest.mock import MagicMock, create_autospec
+from unittest.mock import create_autospec, MagicMock

 import torch
 from parameterized import parameterized
-from torch import Tensor, nn, optim
+from torch import nn, optim, Tensor
 from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor

-from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
+from torchft.local_sgd import DiLoCo, extract_local_tensor, LocalSGD
 from torchft.manager import Manager
 from torchft.work import _DummyWork
diff --git a/torchft/manager.py b/torchft/manager.py
index 6e19c4c3..ae7ae41a 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -37,16 +37,16 @@
 from datetime import timedelta
 from enum import Enum
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
+    cast,
     Dict,
     List,
     Optional,
+    TYPE_CHECKING,
     TypeAlias,
     TypeVar,
     Union,
-    cast,
 )

 import torch
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index e75d5dde..30d2ef4c 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -4,13 +4,14 @@
 import time
 import traceback
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from contextlib import ExitStack, contextmanager
+from concurrent.futures import as_completed, ThreadPoolExecutor
+from contextlib import contextmanager, ExitStack
 from dataclasses import dataclass, field
 from datetime import timedelta
-from enum import Enum, auto
+from enum import auto, Enum
 from typing import (
     Any,
+    cast,
     Dict,
     Generator,
     List,
@@ -19,7 +20,6 @@
     Set,
     Tuple,
     TypeVar,
-    cast,
 )
 from unittest import TestCase
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index ca8a07e8..8ab7295a 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -10,7 +10,7 @@
 from datetime import timedelta
 from typing import Optional
 from unittest import TestCase
-from unittest.mock import MagicMock, create_autospec, patch
+from unittest.mock import create_autospec, MagicMock, patch

 import torch
 from torch.distributed import TCPStore
@@ -18,7 +18,7 @@
 from torchft._torchft import QuorumResult
 from torchft.checkpointing._rwlock import RWLock
 from torchft.checkpointing.transport import CheckpointTransport
-from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
+from torchft.manager import Manager, MANAGER_ADDR_KEY, REPLICA_ID_KEY, WorldSizeMode
 from torchft.process_group import ProcessGroup
 from torchft.work import _DummyWork
diff --git a/torchft/optim.py b/torchft/optim.py
index 1d7b187d..a2884392 100644
--- a/torchft/optim.py
+++ b/torchft/optim.py
@@ -12,7 +12,7 @@
 """

-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING

 import torch
 from torch.optim import Optimizer
diff --git a/torchft/optim_test.py b/torchft/optim_test.py
index 5dd69640..7938ab8f 100644
--- a/torchft/optim_test.py
+++ b/torchft/optim_test.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from unittest import TestCase
-from unittest.mock import MagicMock, create_autospec
+from unittest.mock import create_autospec, MagicMock

 import torch
 from torch.nn import Linear
diff --git a/torchft/process_group.py b/torchft/process_group.py
index f259cc37..4be8f51f 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -26,17 +26,17 @@
 from datetime import timedelta
 from multiprocessing.connection import Connection
 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
+    cast,
     Dict,
     Generator,
     List,
     Optional,
     Tuple,
+    TYPE_CHECKING,
     TypeVar,
     Union,
-    cast,
 )

 import torch
@@ -1308,7 +1308,7 @@ def set_stream(self) -> Generator[None, None, None]:


 def _maybe_share_tensors(
-    tensor: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]
+    tensor: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
 ) -> None:
     """Move a tensor / list of tensors to shared memory if not already in shared memory."""
     if isinstance(tensor, list):
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index bc364e5f..523bccdd 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -9,8 +9,8 @@
 import sys
 from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
 from datetime import timedelta
-from typing import Any, Callable, Dict, List, cast
-from unittest import TestCase, skipIf, skipUnless
+from typing import Any, Callable, cast, Dict, List
+from unittest import skipIf, skipUnless, TestCase
 from unittest.mock import Mock, patch

 import torch
@@ -18,6 +18,7 @@
 from parameterized import parameterized
 from torch import nn
 from torch._C._distributed_c10d import (
+    _resolve_process_group,
     AllgatherOptions,
     AllreduceCoalescedOptions,
     AllreduceOptions,
@@ -26,19 +27,21 @@
     BroadcastOptions,
     ReduceOp,
     ReduceScatterOptions,
-    _resolve_process_group,
 )
 from torch.distributed import (
-    ReduceOp,
-    TCPStore,
     _functional_collectives,
     get_world_size,
+    ReduceOp,
+    TCPStore,
 )
 from torch.distributed.device_mesh import init_device_mesh

 from torchft.manager import Manager
 from torchft.process_group import (
+    _ErrorSwallowingWork,
     ErrorSwallowingProcessGroupWrapper,
+    extend_device_mesh,
+    ft_init_device_mesh,
     ManagedProcessGroup,
     ProcessGroup,
     ProcessGroupBabyGloo,
@@ -47,9 +50,6 @@
     ProcessGroupGloo,
     ProcessGroupNCCL,
     ProcessGroupWrapper,
-    _ErrorSwallowingWork,
-    extend_device_mesh,
-    ft_init_device_mesh,
 )
 from torchft.work import _DummyWork
diff --git a/torchft/quantization_test.py b/torchft/quantization_test.py
index 02166ba5..cfd5c885 100644
--- a/torchft/quantization_test.py
+++ b/torchft/quantization_test.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from unittest import TestCase, skipUnless
+from unittest import skipUnless, TestCase

 import torch
 from parameterized import parameterized
@@ -33,7 +33,6 @@
     "CUDA is required for this test",
 )
 class QuantizationTest(TestCase):
-
     def run_test(
         self,
         world_size: int,
diff --git a/train_ddp.py b/train_ddp.py
index be93d9bb..ff405306 100644
--- a/train_ddp.py
+++ b/train_ddp.py
@@ -95,7 +95,9 @@ def state_dict():
         device=(
             "cuda"
             if torch.cuda.is_available()
-            else "xpu" if torch.xpu.is_available() else "cpu"
+            else "xpu"
+            if torch.xpu.is_available()
+            else "cpu"
         ),
     )
diff --git a/train_diloco.py b/train_diloco.py
index e207e73e..6625535a 100644
--- a/train_diloco.py
+++ b/train_diloco.py
@@ -22,7 +22,7 @@
 import torchvision.transforms as transforms
 from torch import nn, optim
 from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.pipelining import SplitPoint, pipeline
+from torch.distributed.pipelining import pipeline, SplitPoint
 from torch.export import export
 from torch.utils.tensorboard import SummaryWriter
 from torchdata.stateful_dataloader import StatefulDataLoader