
Commit bce7641

remove _FileDigestTree from serialize_by_file.DigestSerializer (#364)
* remove _FileDigestTree from serialize_by_file.DigestSerializer

  Standardizes the DigestSerializer merge behavior to match what's done in
  the `serialize_by_file_shard` equivalent. In particular, the file hasher
  was turned into a factory, so the hashing can be done in parallel via
  max_workers. Merging the digests into a final digest now only looks at
  files, not at directories as well.

  Signed-off-by: Spencer Schrock <[email protected]>

* update callers

  Signed-off-by: Spencer Schrock <[email protected]>

* update goldens

  Signed-off-by: Spencer Schrock <[email protected]>

---------

Signed-off-by: Spencer Schrock <[email protected]>
1 parent d4fca92 commit bce7641
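The heart of the change is replacing a single, stateful file hasher with a per-file hasher factory. The sketch below (a stand-in using `hashlib` and `concurrent.futures`, not the project's own hashing classes) illustrates why the factory unlocks `max_workers`: a streaming hash engine carries internal state, so each parallel worker needs its own instance.

import concurrent.futures
import hashlib
import pathlib


def hasher_factory(path: pathlib.Path):
    """Stand-in for the new factory: every call yields independent state."""

    def hash_file() -> bytes:
        engine = hashlib.sha256()  # a fresh engine per file, safe per worker
        engine.update(path.read_bytes())
        return engine.digest()

    return hash_file


def digest_files(paths, max_workers=None):
    # Each task owns its hasher, so files hash concurrently. Sharing one
    # hasher instance is what forced the old max_workers=1 restriction.
    with concurrent.futures.ThreadPoolExecutor(max_workers) as pool:
        return list(pool.map(lambda p: hasher_factory(p)(), paths))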

File tree

19 files changed: +92 -235 lines changed


benchmarks/serialize.py

Lines changed: 8 additions & 15 deletions
@@ -122,22 +122,15 @@ def run(args: argparse.Namespace) -> Optional[in_toto.IntotoPayload]:
     if args.skip_manifest or args.single_digest:
         merge_hasher_factory = get_hash_engine_factory(args.merge_hasher)
         if args.use_shards:
-            serializer = serialize_by_file_shard.DigestSerializer(
-                hasher,
-                merge_hasher_factory(),  # pytype: disable=not-instantiable
-                max_workers=args.max_workers,
-            )
+            serializer_factory = serialize_by_file_shard.DigestSerializer
         else:
-            # This gets complicated because the API here is not matching the
-            # rest. We should fix this.
-            if args.max_workers is not None and args.max_workers != 1:
-                raise ValueError("Currently, only 1 worker is supported here")
-            serializer = serialize_by_file.DigestSerializer(
-                # pytype: disable=wrong-arg-count
-                hasher(pathlib.Path("unused")),
-                # pytype: enable=wrong-arg-count
-                merge_hasher_factory,
-            )
+            serializer_factory = serialize_by_file.DigestSerializer
+
+        serializer = serializer_factory(
+            hasher,
+            merge_hasher_factory(),  # pytype: disable=not-instantiable
+            max_workers=args.max_workers,
+        )
     else:
         if args.use_shards:
             serializer_factory = serialize_by_file_shard.ManifestSerializer

src/model_signing/api.py

Lines changed: 9 additions & 11 deletions
@@ -23,7 +23,7 @@
 import os
 import pathlib
 import sys
-from typing import Literal, Optional, cast
+from typing import Literal, Optional
 
 from model_signing.hashing import file
 from model_signing.hashing import hashing
@@ -257,6 +257,7 @@ def set_serialize_by_file_to_digest(
         hashing_algorithm: Literal["sha256", "blake2"] = "sha256",
         merge_algorithm: Literal["sha256", "blake2"] = "sha256",
         chunk_size: int = 1048576,
+        max_workers: Optional[int] = None,
         allow_symlinks: bool = False,
     ) -> Self:
         """Configures serialization to a single digest, at file granularity.
@@ -272,25 +273,22 @@
             chunk_size: The amount of file to read at once. Default is 1MB. A
                 special value of 0 signals to attempt to read everything in a
                 single call.
+            max_workers: Maximum number of workers to use in parallel. Default
+                is to defer to the `concurrent.futures` library.
             allow_symlinks: Controls whether symbolic links are included. If a
                 symlink is present but the flag is `False` (default) the
                 serialization would raise an error.
 
         Returns:
             The new hashing configuration with the new serialization method.
         """
-        # TODO: https://github.com/sigstore/model-transparency/issues/197 -
-        # Because the API for this case is different than the other ones, we
-        # have to perform additional steps here.
-        file_hasher = cast(
-            file.SimpleFileHasher,
+        self._serializer = serialize_by_file.DigestSerializer(
             self._build_file_hasher_factory(
                 hashing_algorithm, chunk_size=chunk_size
-            )(pathlib.Path("unused")),
-        )
-        merge_hasher = self._build_stream_hasher(merge_algorithm).__class__
-        self._serializer = serialize_by_file.DigestSerializer(
-            file_hasher, merge_hasher, allow_symlinks=allow_symlinks
+            ),
+            self._build_stream_hasher(merge_algorithm),
+            max_workers=max_workers,
+            allow_symlinks=allow_symlinks,
         )
         return self
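A call site for the updated method might look like the following sketch; the values are illustrative, and `hashing_config` stands in for an instance of whatever configuration class defines this method:

hashing_config = hashing_config.set_serialize_by_file_to_digest(
    hashing_algorithm="sha256",
    merge_algorithm="sha256",
    chunk_size=1048576,  # read 1 MB of a file at a time
    max_workers=4,       # new: hash up to 4 files in parallel
    allow_symlinks=False,
)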

src/model_signing/serialization/serialize_by_file.py

Lines changed: 30 additions & 109 deletions
@@ -64,23 +64,21 @@ def check_file_or_directory(
     )
 
 
-def _build_header(*, entry_name: str, entry_type: str) -> bytes:
-    """Builds a header to encode a path with given name and type.
+def _build_header(*, entry_name: str) -> bytes:
+    """Builds a header to encode a path with given name.
 
     Args:
         entry_name: The name of the entry to build the header for.
-        entry_type: The type of the entry (file or directory).
 
     Returns:
         A sequence of bytes that encodes all arguments as a sequence of UTF-8
         bytes. Each argument is separated by dots and the last byte is also a
         dot (so the file digest can be appended unambiguously).
     """
-    encoded_type = entry_type.encode("utf-8")
     # Prevent confusion if name has a "." inside by encoding to base64.
     encoded_name = base64.b64encode(entry_name.encode("utf-8"))
     # Note: empty string at the end, to terminate header with a "."
-    return b".".join([encoded_type, encoded_name, b""])
+    return b".".join([encoded_name, b""])
 
 
 def _ignored(path: pathlib.Path, ignore_paths: Iterable[pathlib.Path]) -> bool:
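To make the simplified header concrete, here is what it yields for a hypothetical entry name (the file name is illustrative, not taken from the diff):

import base64

# _build_header(entry_name="weights.bin") now produces
# base64(b"weights.bin") + b".":
header = b".".join([base64.b64encode(b"weights.bin"), b""])
assert header == b"d2VpZ2h0cy5iaW4=."
# Base64 output never contains ".", so the trailing dot unambiguously
# marks where the digest bytes appended after the header begin.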
@@ -243,122 +241,44 @@ def _build_manifest(
     return manifest.FileLevelManifest(items)
 
 
-class _FileDigestTree:
-    """A tree of files with their digests.
-
-    Every leaf in the tree is a file, paired with its digest. Every intermediate
-    node represents a directory. We need to pair every directory with a digest,
-    in a bottom-up fashion.
-    """
-
-    def __init__(
-        self, path: pathlib.PurePath, digest: Optional[hashing.Digest] = None
-    ):
-        """Builds a node in the digest tree.
-
-        Don't call this from outside of the class. Instead, use `build_tree`.
-
-        Args:
-            path: Path included in the node.
-            digest: Optional hash of the path. Files must have a digest,
-                directories never have one.
-        """
-        self._path = path
-        self._digest = digest
-        self._children: list[_FileDigestTree] = []
-
-    @classmethod
-    def build_tree(
-        cls, items: Iterable[manifest.FileManifestItem]
-    ) -> "_FileDigestTree":
-        """Builds a tree out of the sequence of manifest items."""
-        path_to_node: dict[pathlib.PurePath, _FileDigestTree] = {}
-
-        for file_item in items:
-            file = file_item.path
-            node = cls(file, file_item.digest)
-            for parent in file.parents:
-                if parent in path_to_node:
-                    parent_node = path_to_node[parent]
-                    parent_node._children.append(node)
-                    break  # everything else already exists
-
-                parent_node = cls(parent)  # no digest for directories
-                parent_node._children.append(node)
-                path_to_node[parent] = parent_node
-                node = parent_node
-
-        # Handle empty model
-        if not path_to_node:
-            return cls(pathlib.PurePosixPath())
-
-        return path_to_node[pathlib.PurePosixPath()]
-
-    def get_digest(
-        self, hasher_factory: Callable[[], hashing.StreamingHashEngine]
-    ) -> hashing.Digest:
-        """Returns the digest of this tree of files.
-
-        Args:
-            hasher_factory: A callable that returns a
-                `hashing.StreamingHashEngine` instance used to merge individual
-                digests to compute an aggregate digest.
-        """
-        hasher = hasher_factory()
-
-        for child in sorted(self._children, key=lambda c: c._path):
-            name = child._path.name
-            if child._digest is not None:
-                header = _build_header(entry_name=name, entry_type="file")
-                hasher.update(header)
-                hasher.update(child._digest.digest_value)
-            else:
-                header = _build_header(entry_name=name, entry_type="dir")
-                hasher.update(header)
-                digest = child.get_digest(hasher_factory)
-                hasher.update(digest.digest_value)
-
-        return hasher.compute()
-
-
 class DigestSerializer(FilesSerializer):
     """Serializer for a model that performs a traversal of the model directory.
 
     This serializer produces a single hash for the entire model. If the model is
     a file, the hash is the digest of the file. If the model is a directory, we
-    perform a depth-first traversal of the directory, hash each individual files
-    and aggregate the hashes together.
-
-    Currently, this has a different initialization than `FilesSerializer`, but
-    this will likely change in a subsequent change. Similarly, currently, this
-    only supports one single worker, but this will change in the future.
+    traverse the directory, hash each individual file and aggregate the hashes
+    together.
     """
 
     def __init__(
         self,
-        file_hasher: file.SimpleFileHasher,
-        merge_hasher_factory: Callable[[], hashing.StreamingHashEngine],
+        file_hasher_factory: Callable[[pathlib.Path], file.FileHasher],
+        merge_hasher: hashing.StreamingHashEngine,
         *,
+        max_workers: Optional[int] = None,
         allow_symlinks: bool = False,
     ):
         """Initializes an instance to serialize a model with this serializer.
 
         Args:
-            hasher: The hash engine used to hash the individual files.
-            merge_hasher_factory: A callable that returns a
-                `hashing.StreamingHashEngine` instance used to merge individual
-                file digests to compute an aggregate digest.
+            file_hasher_factory: A callable to build the hash engine used to
+                hash individual files. Because each file is processed in
+                parallel, every thread needs to call the factory to start
+                hashing.
+            merge_hasher: A `hashing.StreamingHashEngine` instance used to
+                merge individual file digests to compute an aggregate digest.
+            max_workers: Maximum number of workers to use in parallel. Default
+                is to defer to the `concurrent.futures` library.
             allow_symlinks: Controls whether symbolic links are included. If a
                 symlink is present but the flag is `False` (default) the
                 serialization would raise an error.
         """
-
-        def _factory(path: pathlib.Path) -> file.FileHasher:
-            file_hasher.set_file(path)
-            return file_hasher
-
-        super().__init__(_factory, max_workers=1, allow_symlinks=allow_symlinks)
-        self._merge_hasher_factory = merge_hasher_factory
+        super().__init__(
+            file_hasher_factory,
+            max_workers=max_workers,
+            allow_symlinks=allow_symlinks,
+        )
+        self._merge_hasher = merge_hasher
 
     @override
     def serialize(
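Construction now mirrors `FilesSerializer` directly. A hedged sketch of wiring up the refactored class, assuming the project's `memory.SHA256` engine and `file.SimpleFileHasher` (neither is shown in this diff):

import pathlib

from model_signing.hashing import file
from model_signing.hashing import memory
from model_signing.serialization import serialize_by_file


def file_hasher_factory(path: pathlib.Path) -> file.FileHasher:
    # A fresh hasher per path, so worker threads never share state.
    return file.SimpleFileHasher(path, memory.SHA256())


serializer = serialize_by_file.DigestSerializer(
    file_hasher_factory,
    memory.SHA256(),  # one merge engine suffices: merging runs serially
    max_workers=8,    # illustrative value
    allow_symlinks=False,
)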
@@ -395,17 +315,18 @@ def serialize(
     def _build_manifest(
         self, items: Iterable[manifest.FileManifestItem]
     ) -> manifest.DigestManifest:
-        # Note: we do several computations here to try and match the old
-        # behavior but these would be simplified in the future. Since we are
-        # defining the hashing behavior, we can freely change this.
-
         # If the model is just one file, return the hash of the file.
         # A model is a file if we have one item only and its path is empty.
         items = list(items)
         if len(items) == 1 and not items[0].path.name:
             return manifest.DigestManifest(items[0].digest)
 
-        # Otherwise, build a tree of files and compute the digests.
-        tree = _FileDigestTree.build_tree(items)
-        digest = tree.get_digest(self._merge_hasher_factory)
+        self._merge_hasher.reset()
+
+        for item in sorted(items, key=lambda i: i.path):
+            header = _build_header(entry_name=item.path.name)
+            self._merge_hasher.update(header)
+            self._merge_hasher.update(item.digest.digest_value)
+
+        digest = self._merge_hasher.compute()
         return manifest.DigestManifest(digest)
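The new merge step is simple enough to restate outside the class. A minimal stand-in (using `hashlib` in place of the project's `StreamingHashEngine`) for what `_build_manifest` now computes over a multi-file model:

import base64
import hashlib
import pathlib


def merge_digests(items: list[tuple[pathlib.PurePath, bytes]]) -> bytes:
    merge = hashlib.sha256()  # stand-in for the merge hash engine
    for path, digest in sorted(items):  # deterministic order by path
        header = b".".join([base64.b64encode(path.name.encode("utf-8")), b""])
        merge.update(header)  # base64(file name) + "."
        merge.update(digest)  # then the file's own digest bytes
    return merge.digest()

Because directory entries no longer contribute headers or digests, the aggregate depends only on the files, which is the `serialize_by_file_shard` behavior the commit message standardizes on.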
