diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index c358d517..e1c4cf8e 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -18,9 +18,15 @@ import urllib.request from http.server import BaseHTTPRequestHandler from typing import Callable, Generic, TypeVar +from dataclasses import dataclass +import pickle +from io import BufferedIOBase +from typing import Tuple +import struct import torch - +from torch.utils._pytree import tree_flatten, tree_unflatten +from hashlib import sha256 from torchft.http import _IPv6HTTPServer logger: logging.Logger = logging.getLogger(__name__) @@ -28,6 +34,83 @@ T = TypeVar("T") +@dataclass +class TensorMetadata: + nbytes: int + dtype: torch.dtype + storage_offset: int + size: Tuple[int, ...] + stride: Tuple[int, ...] + + +def write_state_dict(state_dict: object, f: BufferedIOBase) -> None: + """ + Write the state_dict to the file-like object. + """ + values, spec = tree_flatten(state_dict) + + storages = [] + non_tensor_values = [] + for value in values: + if isinstance(value, torch.Tensor): + storage = value.untyped_storage() + storages.append(storage) + non_tensor_values.append( + TensorMetadata( + nbytes=storage.nbytes(), + dtype=value.dtype, + storage_offset=value.storage_offset(), + size=value.size(), + stride=value.stride(), + ) + ) + else: + non_tensor_values.append(value) + + meta_buf = pickle.dumps((non_tensor_values, spec)) + checksum = sha256(meta_buf).hexdigest() + total_length = len(meta_buf) + len(checksum) + + f.write(struct.pack(" object: + """ + Read the state_dict from the file-like object. + """ + + total_length = struct.unpack(" None: logger.error(msg) @@ -100,7 +183,8 @@ def load_from_address(cls, address: str) -> T: data = f.read() reader = io.BytesIO(data) - return torch.load(reader, weights_only=True) + state_dict = read_state_dict(reader) + return state_dict def address(self) -> str: """ diff --git a/torchft/checkpointing_test.py b/torchft/checkpointing_test.py index f27392bc..9c7f6d7b 100644 --- a/torchft/checkpointing_test.py +++ b/torchft/checkpointing_test.py @@ -4,11 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import unittest import urllib.error from unittest import TestCase from unittest.mock import MagicMock - -from torchft.checkpointing import CheckpointServer +from io import BytesIO +import torch +from typing import Tuple +from checkpointing import CheckpointServer, TensorMetadata, write_state_dict, read_state_dict class TestCheckpointing(TestCase): @@ -33,3 +36,67 @@ def test_checkpoint_server(self) -> None: CheckpointServer.load_from_address(addr) server.shutdown() + + def setUp(self): + self.file = BytesIO() + + def test_scalar_tensor(self): + tensor = torch.tensor(42, dtype=torch.int32) + state_dict = {'scalar': tensor} + write_state_dict(state_dict, self.file) + self.file.seek(0) + + result = read_state_dict(self.file) + self.assertTrue(torch.equal(result['scalar'], tensor)) + + def test_strided_tensor(self): + base_tensor = torch.arange(16, dtype=torch.float32).reshape(4, 4) + strided_tensor = base_tensor[::2, ::2] + state_dict = {'strided': strided_tensor} + write_state_dict(state_dict, self.file) + self.file.seek(0) + + result = read_state_dict(self.file) + self.assertTrue(torch.equal(result['strided'], strided_tensor)) + + def test_tensor_with_offset(self): + base_tensor = torch.arange(10, dtype=torch.float64) + offset_tensor = base_tensor[2:] + state_dict = {'offset': offset_tensor} + write_state_dict(state_dict, self.file) + self.file.seek(0) + + result = read_state_dict(self.file) + self.assertTrue(torch.equal(result['offset'], offset_tensor)) + + def test_nested_tensors(self): + tensor1 = torch.tensor([1, 2, 3], dtype=torch.int32) + tensor2 = torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float64) + state_dict = {'nested': {'tensor1': tensor1, 'tensor2': tensor2}} + write_state_dict(state_dict, self.file) + self.file.seek(0) + + result = read_state_dict(self.file) + self.assertTrue(torch.equal(result['nested']['tensor1'], tensor1)) + self.assertTrue(torch.equal(result['nested']['tensor2'], tensor2)) + + def test_various_data_types(self): + tensor_float32 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + tensor_int16 = torch.tensor([1, 2, 3], dtype=torch.int16) + tensor_bool = torch.tensor([True, False, True], dtype=torch.bool) + state_dict = { + 'float32': tensor_float32, + 'int16': tensor_int16, + 'bool': tensor_bool, + } + write_state_dict(state_dict, self.file) + self.file.seek(0) + + result = read_state_dict(self.file) + self.assertTrue(torch.equal(result['float32'], tensor_float32)) + self.assertTrue(torch.equal(result['int16'], tensor_int16)) + self.assertTrue(torch.equal(result['bool'], tensor_bool)) + + +if __name__ == '__main__': + unittest.main()