@@ -27,12 +27,12 @@ def read_file(blobpath: str) -> bytes:
2727 return resp .content
2828
2929
30- def check_hash (data : bytes , hash : str ) -> bool :
31- data_hash = hashlib .sha256 (data ).hexdigest ()
32- return data_hash == hash
30+ def check_hash (data : bytes , expected_hash : str ) -> bool :
31+ actual_hash = hashlib .sha256 (data ).hexdigest ()
32+ return actual_hash == expected_hash
3333
3434
35- def read_file_cached (blobpath : str , expected_hash : Optional [str ]= None ) -> bytes :
35+ def read_file_cached (blobpath : str , expected_hash : Optional [str ] = None ) -> bytes :
3636 user_specified_cache = True
3737 if "TIKTOKEN_CACHE_DIR" in os .environ :
3838 cache_dir = os .environ ["TIKTOKEN_CACHE_DIR" ]
@@ -52,13 +52,15 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
5252 if os .path .exists (cache_path ):
5353 with open (cache_path , "rb" ) as f :
5454 data = f .read ()
55- if expected_hash and not check_hash (data , expected_hash ):
56- raise ValueError (
57- f"Hash mismatch for cached data from { blobpath } (expected { expected_hash } ). "
58- f"Please delete the cache file at { cache_path } and try again."
59- )
55+ if expected_hash is None or check_hash (data , expected_hash ):
6056 return data
6157
58+ # the cached file does not match the hash, remove it and re-fetch
59+ try :
60+ os .remove (cache_path )
61+ except OSError :
62+ pass
63+
6264 contents = read_file (blobpath )
6365 if expected_hash and not check_hash (contents , expected_hash ):
6466 raise ValueError (
@@ -81,7 +83,10 @@ def read_file_cached(blobpath: str, expected_hash: Optional[str]=None) -> bytes:
8183
8284
8385def data_gym_to_mergeable_bpe_ranks (
84- vocab_bpe_file : str , encoder_json_file : str , vocab_bpe_hash : Optional [str ]= None , encoder_json_hash : Optional [str ]= None
86+ vocab_bpe_file : str ,
87+ encoder_json_file : str ,
88+ vocab_bpe_hash : Optional [str ] = None ,
89+ encoder_json_hash : Optional [str ] = None ,
8590) -> dict [bytes , int ]:
8691 # NB: do not add caching to this function
8792 rank_to_intbyte = [b for b in range (2 ** 8 ) if chr (b ).isprintable () and chr (b ) != " " ]
@@ -135,7 +140,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
135140 f .write (base64 .b64encode (token ) + b" " + str (rank ).encode () + b"\n " )
136141
137142
138- def load_tiktoken_bpe (tiktoken_bpe_file : str , expected_hash : Optional [str ]= None ) -> dict [bytes , int ]:
143+ def load_tiktoken_bpe (
144+ tiktoken_bpe_file : str , expected_hash : Optional [str ] = None
145+ ) -> dict [bytes , int ]:
139146 # NB: do not add caching to this function
140147 contents = read_file_cached (tiktoken_bpe_file , expected_hash )
141148 return {
0 commit comments