 import re
 import os
 import json
+import shutil
 import torch
 from pathlib import Path

 from huggingface_hub import snapshot_download
 from sharktank.layers.configs import (
-    LlamaHParams,
     LlamaModelConfig,
     is_hugging_face_llama3_config,
 )
 from sharktank.types import *
+from sharktank.utils import verify_exactly_one_is_not_none
 from sharktank.utils.functools import compose
 from sharktank.utils.logging import get_logger
 from sharktank.transforms.dataset import wrap_in_list_if_inference_tensor
@@ -59,6 +60,7 @@ def import_hf_dataset(
     target_dtype=None,
     tensor_transform: Optional["InferenceTensorTransform"] = None,
     metadata_transform: MetadataTransform | None = None,
+    file_copy_map: dict[PathLike, PathLike] | None = None,
 ) -> Optional[Dataset]:
     import safetensors

@@ -86,23 +88,54 @@ def import_hf_dataset(

     theta = Theta(tensors)

+    if file_copy_map is not None:
+        for src, dst in file_copy_map.items():
+            Path(dst).parent.mkdir(parents=True, exist_ok=True)
+            shutil.copy(src, dst)
+
     dataset = Dataset(props, theta)
     if output_irpa_file is not None:
+        Path(output_irpa_file).parent.mkdir(parents=True, exist_ok=True)
         dataset.save(output_irpa_file, io_report_callback=logger.debug)
     return dataset


 def import_hf_dataset_from_hub(
-    repo_id_or_path: str,
+    repo_id_or_path: str | None = None,
     *,
     revision: str | None = None,
     subfolder: str | None = None,
     config_subpath: str | None = None,
     output_irpa_file: PathLike | None = None,
+    target_dtype: torch.dtype | None = None,
+    file_copy_map: dict[PathLike, PathLike] | None = None,
+    hf_dataset: str | None = None,
+    preset: str | None = None,
 ) -> Dataset | None:
-    model_dir = Path(repo_id_or_path)
-    if not model_dir.exists():
-        model_dir = Path(snapshot_download(repo_id=repo_id_or_path, revision=revision))
+    verify_exactly_one_is_not_none(
+        repo_id_or_path=repo_id_or_path, preset=preset, hf_dataset=hf_dataset
+    )
+    if preset is not None:
+        return import_hf_dataset_from_hub(**get_dataset_import_preset_kwargs(preset))
+
+    if hf_dataset is not None:
+        from sharktank.utils.hf_datasets import get_dataset
+
+        download_result_dict = get_dataset(hf_dataset).download()
+        downloaded_file_paths = [
+            p for paths in download_result_dict.values() for p in paths
+        ]
+        if len(downloaded_file_paths) > 1 or downloaded_file_paths[0].is_file():
+            assert (
+                subfolder is None
+            ), "Not robust in determining the model dir if doing a non-single model snapshot download and subfolder is specified."
+        model_dir = Path(os.path.commonpath([str(p) for p in downloaded_file_paths]))
+    else:
+        model_dir = Path(repo_id_or_path)
+        if not model_dir.exists():
+            model_dir = Path(
+                snapshot_download(repo_id=repo_id_or_path, revision=revision)
+            )

     if subfolder is not None:
         model_dir /= subfolder
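A minimal usage sketch of the reworked entry point in the hunk above, assuming these helpers live in `sharktank.utils.hf` (the diff does not show the file path) and using made-up local checkpoint and output paths. Exactly one of `repo_id_or_path`, `hf_dataset`, or `preset` must be given, as enforced by `verify_exactly_one_is_not_none`.

```python
# Sketch only: the module path and the local/output paths below are assumptions,
# not taken from this diff.
from sharktank.utils.hf import import_hf_dataset_from_hub

# Import from a local Hugging Face checkpoint directory, writing the IRPA file
# and copying the tokenizer next to it via the new file_copy_map parameter.
# file_copy_map keys are resolved relative to the model directory (see the
# rewrite in the next hunk); values are destination paths.
dataset = import_hf_dataset_from_hub(
    "/models/Llama-3.1-8B-Instruct",
    output_irpa_file="/export/llama3_1_8b/model.irpa",
    target_dtype=None,  # keep the checkpoint dtypes as-is
    file_copy_map={
        "tokenizer.json": "/export/llama3_1_8b/tokenizer.json",
        "tokenizer_config.json": "/export/llama3_1_8b/tokenizer_config.json",
    },
)
```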
@@ -115,15 +148,73 @@ def import_hf_dataset_from_hub(
         for file_name in os.listdir(model_dir)
         if (model_dir / file_name).is_file()
     ]
+
     param_paths = [p for p in file_paths if p.is_file() and p.suffix == ".safetensors"]

+    if file_copy_map is not None:
+        file_copy_map = {model_dir / src: dst for src, dst in file_copy_map.items()}
+
     return import_hf_dataset(
         config_json_path=config_json_path,
         param_paths=param_paths,
         output_irpa_file=output_irpa_file,
+        target_dtype=target_dtype,
+        file_copy_map=file_copy_map,
     )


+dataset_import_presets: dict[str, dict[str, Any]] = {}
+"""Declarative specification on how to import a HF dataset."""
+
+
+def register_default_llama_dataset_preset(
+    name: str,
+    *,
+    hf_dataset: str,
+    output_prefix_path: str,
+    target_dtype: torch.dtype | None = None,
+):
+    output_prefix_path = Path(output_prefix_path)
+    dataset_import_presets[name] = {
+        "hf_dataset": hf_dataset,
+        "output_irpa_file": output_prefix_path / "model.irpa",
+        "target_dtype": target_dtype,
+        "file_copy_map": {
+            "tokenizer.json": output_prefix_path / "tokenizer.json",
+            "tokenizer_config.json": output_prefix_path / "tokenizer_config.json",
+            "LICENSE": output_prefix_path / "LICENSE",
+        },
+    }
+
+
+def register_all_dataset_import_presets():
+    register_default_llama_dataset_preset(
+        name="meta_llama3_1_8b_instruct_f16",
+        hf_dataset="meta-llama/Llama-3.1-8B-Instruct",
+        output_prefix_path="llama3.1/8b/instruct/f16",
+        target_dtype=torch.float16,
+    )
+    register_default_llama_dataset_preset(
+        name="meta_llama3_1_70b_instruct_f16",
+        hf_dataset="meta-llama/Llama-3.1-70B-Instruct",
+        output_prefix_path="llama3.1/70b/instruct/f16",
+        target_dtype=torch.float16,
+    )
+    register_default_llama_dataset_preset(
+        name="meta_llama3_1_405b_instruct_f16",
+        hf_dataset="meta-llama/Llama-3.1-405B-Instruct",
+        output_prefix_path="llama3.1/405b/instruct/f16",
+        target_dtype=torch.float16,
+    )
+
+
+register_all_dataset_import_presets()
+
+
+def get_dataset_import_preset_kwargs(preset: str) -> dict[str, Any]:
+    return dataset_import_presets[preset]
+
+
 _llama3_hf_to_sharktank_tensor_name_map: dict[str, str] = {
     "model.embed_tokens.weight": "token_embd.weight",
     "lm_head.weight": "output.weight",