@@ -64,23 +64,21 @@ def check_file_or_directory(
6464 )
6565
6666
67- def _build_header (* , entry_name : str , entry_type : str ) -> bytes :
68- """Builds a header to encode a path with given name and type .
67+ def _build_header (* , entry_name : str ) -> bytes :
68+ """Builds a header to encode a path with given name.
6969
7070 Args:
7171 entry_name: The name of the entry to build the header for.
72- entry_type: The type of the entry (file or directory).
7372
7473 Returns:
7574 A sequence of bytes holding the base64 encoding of the UTF-8 bytes
7675 of the entry name. The last byte is always a dot (so the file digest
7776 can be appended unambiguously).
7877 """
79- encoded_type = entry_type .encode ("utf-8" )
8078 # Prevent confusion if name has a "." inside by encoding to base64.
8179 encoded_name = base64 .b64encode (entry_name .encode ("utf-8" ))
8280 # Note: empty string at the end, to terminate header with a "."
83- return b"." .join ([encoded_type , encoded_name , b"" ])
81+ return b"." .join ([encoded_name , b"" ])
8482
8583
8684def _ignored (path : pathlib .Path , ignore_paths : Iterable [pathlib .Path ]) -> bool :
@@ -243,122 +241,44 @@ def _build_manifest(
243241 return manifest .FileLevelManifest (items )
244242
245243
246- class _FileDigestTree :
247- """A tree of files with their digests.
248-
249- Every leaf in the tree is a file, paired with its digest. Every intermediate
250- node represents a directory. We need to pair every directory with a digest,
251- in a bottom-up fashion.
252- """
253-
254- def __init__ (
255- self , path : pathlib .PurePath , digest : Optional [hashing .Digest ] = None
256- ):
257- """Builds a node in the digest tree.
258-
259- Don't call this from outside of the class. Instead, use `build_tree`.
260-
261- Args:
262- path: Path included in the node.
263- digest: Optional hash of the path. Files must have a digest,
264- directories never have one.
265- """
266- self ._path = path
267- self ._digest = digest
268- self ._children : list [_FileDigestTree ] = []
269-
270- @classmethod
271- def build_tree (
272- cls , items : Iterable [manifest .FileManifestItem ]
273- ) -> "_FileDigestTree" :
274- """Builds a tree out of the sequence of manifest items."""
275- path_to_node : dict [pathlib .PurePath , _FileDigestTree ] = {}
276-
277- for file_item in items :
278- file = file_item .path
279- node = cls (file , file_item .digest )
280- for parent in file .parents :
281- if parent in path_to_node :
282- parent_node = path_to_node [parent ]
283- parent_node ._children .append (node )
284- break # everything else already exists
285-
286- parent_node = cls (parent ) # no digest for directories
287- parent_node ._children .append (node )
288- path_to_node [parent ] = parent_node
289- node = parent_node
290-
291- # Handle empty model
292- if not path_to_node :
293- return cls (pathlib .PurePosixPath ())
294-
295- return path_to_node [pathlib .PurePosixPath ()]
296-
297- def get_digest (
298- self , hasher_factory : Callable [[], hashing .StreamingHashEngine ]
299- ) -> hashing .Digest :
300- """Returns the digest of this tree of files.
301-
302- Args:
303- hasher_factory: A callable that returns a
304- `hashing.StreamingHashEngine` instance used to merge individual
305- digests to compute an aggregate digest.
306- """
307- hasher = hasher_factory ()
308-
309- for child in sorted (self ._children , key = lambda c : c ._path ):
310- name = child ._path .name
311- if child ._digest is not None :
312- header = _build_header (entry_name = name , entry_type = "file" )
313- hasher .update (header )
314- hasher .update (child ._digest .digest_value )
315- else :
316- header = _build_header (entry_name = name , entry_type = "dir" )
317- hasher .update (header )
318- digest = child .get_digest (hasher_factory )
319- hasher .update (digest .digest_value )
320-
321- return hasher .compute ()
322-
323-
324244class DigestSerializer (FilesSerializer ):
325245 """Serializer for a model that performs a traversal of the model directory.
326246
327247 This serializer produces a single hash for the entire model. If the model is
328248 a file, the hash is the digest of the file. If the model is a directory, we
329- perform a depth-first traversal of the directory, hash each individual files
330- and aggregate the hashes together.
331-
332- Currently, this has a different initialization than `FilesSerializer`, but
333- this will likely change in a subsequent change. Similarly, currently, this
334- only supports one single worker, but this will change in the future.
249+ traverse the directory, hash each individual file and aggregate the hashes
250+ together.
335251 """
336252
337253 def __init__ (
338254 self ,
339- file_hasher : file .SimpleFileHasher ,
340- merge_hasher_factory : Callable [[], hashing .StreamingHashEngine ] ,
255+ file_hasher_factory : Callable [[ pathlib . Path ], file .FileHasher ] ,
256+ merge_hasher : hashing .StreamingHashEngine ,
341257 * ,
258+ max_workers : Optional [int ] = None ,
342259 allow_symlinks : bool = False ,
343260 ):
344261 """Initializes an instance to serialize a model with this serializer.
345262
346263 Args:
347- hasher: The hash engine used to hash the individual files.
348- merge_hasher_factory: A callable that returns a
349- `hashing.StreamingHashEngine` instance used to merge individual
350- file digests to compute an aggregate digest.
264+ file_hasher_factory: A callable to build the hash engine used to
265+ hash individual files. Because each file is processed in
266+ parallel, every thread needs to call the factory to start
267+ hashing.
268+ merge_hasher: A `hashing.StreamingHashEngine` instance used to
269+ merge individual file digests to compute an aggregate digest.
270+ max_workers: Maximum number of workers to use in parallel. Default
271+ is to defer to the `concurrent.futures` library.
351272 allow_symlinks: Controls whether symbolic links are included. If a
352273 symlink is present but the flag is `False` (default) the
353274 serialization would raise an error.
354275 """
355-
356- def _factory (path : pathlib .Path ) -> file .FileHasher :
357- file_hasher .set_file (path )
358- return file_hasher
359-
360- super ().__init__ (_factory , max_workers = 1 , allow_symlinks = allow_symlinks )
361- self ._merge_hasher_factory = merge_hasher_factory
276+ super ().__init__ (
277+ file_hasher_factory ,
278+ max_workers = max_workers ,
279+ allow_symlinks = allow_symlinks ,
280+ )
281+ self ._merge_hasher = merge_hasher
362282
363283 @override
364284 def serialize (
@@ -395,17 +315,18 @@ def serialize(
395315 def _build_manifest (
396316 self , items : Iterable [manifest .FileManifestItem ]
397317 ) -> manifest .DigestManifest :
398- # Note: we do several computations here to try and match the old
399- # behavior but these would be simplified in the future. Since we are
400- # defining the hashing behavior, we can freely change this.
401-
402318 # If the model is just one file, return the hash of the file.
403319 # A model is a file if we have one item only and its path is empty.
404320 items = list (items )
405321 if len (items ) == 1 and not items [0 ].path .name :
406322 return manifest .DigestManifest (items [0 ].digest )
407323
408- # Otherwise, build a tree of files and compute the digests.
409- tree = _FileDigestTree .build_tree (items )
410- digest = tree .get_digest (self ._merge_hasher_factory )
324+ self ._merge_hasher .reset ()
325+
326+ for item in sorted (items , key = lambda i : i .path ):
327+ header = _build_header (entry_name = item .path .name )
328+ self ._merge_hasher .update (header )
329+ self ._merge_hasher .update (item .digest .digest_value )
330+
331+ digest = self ._merge_hasher .compute ()
411332 return manifest .DigestManifest (digest )
0 commit comments