@@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
60
60
if os .path .exists (download_target ) and not os .path .isfile (download_target ):
61
61
raise RuntimeError (f"{ download_target } exists and is not a regular file" )
62
62
63
+ def compute_sha256 (file_path : str ) -> str :
64
+ sha256 = hashlib .sha256 ()
65
+ with open (file_path , "rb" ) as f :
66
+ for chunk in iter (lambda : f .read (8192 ), b"" ):
67
+ sha256 .update (chunk )
68
+ return sha256 .hexdigest ()
69
+
63
70
if os .path .isfile (download_target ):
64
- with open (download_target , "rb" ) as f :
65
- model_bytes = f .read ()
66
- if hashlib .sha256 (model_bytes ).hexdigest () == expected_sha256 :
67
- return model_bytes if in_memory else download_target
71
+ if compute_sha256 (download_target ) == expected_sha256 :
72
+ if in_memory :
73
+ with open (download_target , "rb" ) as f :
74
+ return f .read ()
75
+ else :
76
+ return download_target
68
77
else :
69
78
warnings .warn (
70
79
f"{ download_target } exists, but the SHA256 checksum does not match; re-downloading the file"
@@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
86
95
output .write (buffer )
87
96
loop .update (len (buffer ))
88
97
89
- model_bytes = open (download_target , "rb" ).read ()
90
- if hashlib .sha256 (model_bytes ).hexdigest () != expected_sha256 :
98
+ if compute_sha256 (download_target ) != expected_sha256 :
91
99
raise RuntimeError (
92
100
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
93
101
)
94
102
95
- return model_bytes if in_memory else download_target
103
+ if in_memory :
104
+ with open (download_target , "rb" ) as f :
105
+ return f .read ()
106
+ else :
107
+ return download_target
96
108
97
109
98
110
def available_models () -> List [str ]:
@@ -147,7 +159,7 @@ def load_model(
147
159
with (
148
160
io .BytesIO (checkpoint_file ) if in_memory else open (checkpoint_file , "rb" )
149
161
) as fp :
150
- checkpoint = torch .load (fp , map_location = device )
162
+ checkpoint = torch .load (fp , map_location = device , weights_only = True )
151
163
del checkpoint_file
152
164
153
165
dims = ModelDimensions (** checkpoint ["dims" ])
@@ -157,4 +169,4 @@ def load_model(
157
169
if alignment_heads is not None :
158
170
model .set_alignment_heads (alignment_heads )
159
171
160
- return model .to (device )
172
+ return model .to (device )
0 commit comments