
Commit e18016c

[sharktank] Llama 3.1 f16 HF import presets
Add a registry for import presets and populate it with Llama 3.1 f16 models. Add support for referencing an HF dataset during import. This decouples the specification of what needs to be downloaded from how to import the dataset after it is downloaded. Expand the HF datasets so that model files are specified not fully explicitly but through filters, as with `huggingface_hub.snapshot_download`.

Next steps:
1. Make the CI use this new mechanism.
2. Add importation of models with more complicated transformations, such as quantization.
1 parent 2a5b2a4 commit e18016c
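
For orientation, here is a minimal sketch (not part of the commit) of how the two halves compose: a named HF dataset records what to download, while an import preset records how to turn the download into an IRPA and which files to copy alongside it. The preset name below is hypothetical; the helpers are those added in sharktank/sharktank/utils/hf.py.

import torch
from sharktank.utils.hf import (
    import_hf_dataset_from_hub,
    register_default_llama_dataset_preset,
)

# Record how to import: output layout, dtype, and which extra files to copy.
# "my_llama3_1_8b_instruct_f16" is a made-up name for illustration.
register_default_llama_dataset_preset(
    name="my_llama3_1_8b_instruct_f16",
    hf_dataset="meta-llama/Llama-3.1-8B-Instruct",  # what to download (preregistered dataset)
    output_prefix_path="llama3.1/8b/instruct/f16",
    target_dtype=torch.float16,
)

# Importing by preset looks up the recorded kwargs and re-enters the same function.
import_hf_dataset_from_hub(preset="my_llama3_1_8b_instruct_f16")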

4 files changed: 226 additions & 10 deletions
sharktank/sharktank/tools/import_hf_dataset_from_hub.py

Lines changed: 42 additions & 2 deletions
@@ -11,10 +11,26 @@

 def main(argv: list[str] | None = None):

-    parser = cli.create_parser(description="Import a Hugging Face dataset.")
-    cli.add_output_dataset_options(parser)
+    parser = cli.create_parser(
+        description=(
+            "Import a Hugging Face dataset. "
+            "This includes downloading from the HF hub and transforming the model "
+            "parameters into an IRPA. "
+            "The import can either be specified through the various detailed "
+            "parameters or a preset name can be referenced. "
+            "The HF dataset details can also be referenced by a named preregistered dataset. "
+            "The HF dataset is only concerned with what files need to be downloaded from the hub."
+        )
+    )
+    parser.add_argument(
+        "--output-irpa-file",
+        type=Path,
+        help="IRPA file to save dataset to",
+    )
     parser.add_argument(
         "repo_id_or_path",
+        nargs="?",
+        default=None,
         type=str,
         help='Local path to the model or Hugging Face repo id, e.g. "meta-llama/Meta-Llama-3-8B-Instruct"',
     )
@@ -36,6 +52,28 @@ def main(argv: list[str] | None = None):
         default=None,
         help="Subpath inside the subfolder for the model config file. Defaults to config.json.",
     )
+    parser.add_argument(
+        "--hf-dataset",
+        type=str,
+        default=None,
+        help=(
+            "A name of a preset HF dataset. "
+            "This is mutually exclusive with specifying a repo id or a model path."
+        ),
+    )
+    parser.add_argument(
+        "--preset",
+        type=str,
+        default=None,
+        help=(
+            "A name of a preset to import. "
+            "This is different from a preset HF dataset, which only specifies the HF "
+            "files to download. "
+            "The import preset also specifies the import transformations. "
+            "When using a preset the output directory structure will be relative to "
+            "the current working directory."
+        ),
+    )
     args = cli.parse(parser, args=argv)

     import_hf_dataset_from_hub(
@@ -44,6 +82,8 @@ def main(argv: list[str] | None = None):
         subfolder=args.subfolder,
         config_subpath=args.config_subpath,
         output_irpa_file=args.output_irpa_file,
+        hf_dataset=args.hf_dataset,
+        preset=args.preset,
     )

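The resulting command-line surface, sketched below through the Python entry point, offers three mutually exclusive modes; this is illustrative only, since each call would trigger a large download.

from sharktank.tools.import_hf_dataset_from_hub import main

# 1. Explicit repo id (or local path) plus an explicit output file.
main(["meta-llama/Meta-Llama-3-8B-Instruct", "--output-irpa-file=model.irpa"])

# 2. A preregistered HF dataset: only the download source comes from the registry.
main(["--hf-dataset=meta-llama/Llama-3.1-8B-Instruct", "--output-irpa-file=model.irpa"])

# 3. An import preset: download source, dtype, and output layout all come from the
#    preset; outputs are written relative to the current working directory.
main(["--preset=meta_llama3_1_8b_instruct_f16"])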

sharktank/sharktank/utils/hf.py

Lines changed: 96 additions & 5 deletions
@@ -9,16 +9,17 @@
 import re
 import os
 import json
+import shutil
 import torch
 from pathlib import Path

 from huggingface_hub import snapshot_download
 from sharktank.layers.configs import (
-    LlamaHParams,
     LlamaModelConfig,
     is_hugging_face_llama3_config,
 )
 from sharktank.types import *
+from sharktank.utils import verify_exactly_one_is_not_none
 from sharktank.utils.functools import compose
 from sharktank.utils.logging import get_logger
 from sharktank.transforms.dataset import wrap_in_list_if_inference_tensor
@@ -59,6 +60,7 @@ def import_hf_dataset(
     target_dtype=None,
     tensor_transform: Optional["InferenceTensorTransform"] = None,
     metadata_transform: MetadataTransform | None = None,
+    file_copy_map: dict[PathLike, PathLike] | None = None,
 ) -> Optional[Dataset]:
     import safetensors

@@ -86,23 +88,54 @@ def import_hf_dataset(

     theta = Theta(tensors)

+    if file_copy_map is not None:
+        for src, dst in file_copy_map.items():
+            Path(dst).parent.mkdir(parents=True, exist_ok=True)
+            shutil.copy(src, dst)
+
     dataset = Dataset(props, theta)
     if output_irpa_file is not None:
+        Path(output_irpa_file).parent.mkdir(parents=True, exist_ok=True)
         dataset.save(output_irpa_file, io_report_callback=logger.debug)
     return dataset


 def import_hf_dataset_from_hub(
-    repo_id_or_path: str,
+    repo_id_or_path: str | None = None,
     *,
     revision: str | None = None,
     subfolder: str | None = None,
     config_subpath: str | None = None,
     output_irpa_file: PathLike | None = None,
+    target_dtype: torch.dtype | None = None,
+    file_copy_map: dict[PathLike, PathLike] | None = None,
+    hf_dataset: str | None = None,
+    preset: str | None = None,
 ) -> Dataset | None:
-    model_dir = Path(repo_id_or_path)
-    if not model_dir.exists():
-        model_dir = Path(snapshot_download(repo_id=repo_id_or_path, revision=revision))
+    verify_exactly_one_is_not_none(
+        repo_id_or_path=repo_id_or_path, preset=preset, hf_dataset=hf_dataset
+    )
+    if preset is not None:
+        return import_hf_dataset_from_hub(**get_dataset_import_preset_kwargs(preset))
+
+    if hf_dataset is not None:
+        from sharktank.utils.hf_datasets import get_dataset
+
+        download_result_dict = get_dataset(hf_dataset).download()
+        downloaded_file_paths = [
+            p for paths in download_result_dict.values() for p in paths
+        ]
+        if len(downloaded_file_paths) > 1 or downloaded_file_paths[0].is_file():
+            assert (
+                subfolder is None
+            ), "Not robust in determining the model dir if doing a non-single model snapshot download and subfolder is specified."
+        model_dir = Path(os.path.commonpath([str(p) for p in downloaded_file_paths]))
+    else:
+        model_dir = Path(repo_id_or_path)
+        if not model_dir.exists():
+            model_dir = Path(
+                snapshot_download(repo_id=repo_id_or_path, revision=revision)
+            )

     if subfolder is not None:
         model_dir /= subfolder
@@ -115,15 +148,73 @@ def import_hf_dataset_from_hub(
         for file_name in os.listdir(model_dir)
         if (model_dir / file_name).is_file()
     ]
+
     param_paths = [p for p in file_paths if p.is_file() and p.suffix == ".safetensors"]

+    if file_copy_map is not None:
+        file_copy_map = {model_dir / src: dst for src, dst in file_copy_map.items()}
+
     return import_hf_dataset(
         config_json_path=config_json_path,
         param_paths=param_paths,
         output_irpa_file=output_irpa_file,
+        target_dtype=target_dtype,
+        file_copy_map=file_copy_map,
     )


+dataset_import_presets: dict[str, dict[str, Any]] = {}
+"""Declarative specification on how to import a HF dataset."""
+
+
+def register_default_llama_dataset_preset(
+    name: str,
+    *,
+    hf_dataset: str,
+    output_prefix_path: str,
+    target_dtype: torch.dtype | None = None,
+):
+    output_prefix_path = Path(output_prefix_path)
+    dataset_import_presets[name] = {
+        "hf_dataset": hf_dataset,
+        "output_irpa_file": output_prefix_path / "model.irpa",
+        "target_dtype": target_dtype,
+        "file_copy_map": {
+            "tokenizer.json": output_prefix_path / "tokenizer.json",
+            "tokenizer_config.json": output_prefix_path / "tokenizer_config.json",
+            "LICENSE": output_prefix_path / "LICENSE",
+        },
+    }
+
+
+def register_all_dataset_import_presets():
+    register_default_llama_dataset_preset(
+        name="meta_llama3_1_8b_instruct_f16",
+        hf_dataset="meta-llama/Llama-3.1-8B-Instruct",
+        output_prefix_path="llama3.1/8b/instruct/f16",
+        target_dtype=torch.float16,
+    )
+    register_default_llama_dataset_preset(
+        name="meta_llama3_1_70b_instruct_f16",
+        hf_dataset="meta-llama/Llama-3.1-70B-Instruct",
+        output_prefix_path="llama3.1/70b/instruct/f16",
+        target_dtype=torch.float16,
+    )
+    register_default_llama_dataset_preset(
+        name="meta_llama3_1_405b_instruct_f16",
+        hf_dataset="meta-llama/Llama-3.1-405B-Instruct",
+        output_prefix_path="llama3.1/405b/instruct/f16",
+        target_dtype=torch.float16,
+    )
+
+
+register_all_dataset_import_presets()
+
+
+def get_dataset_import_preset_kwargs(preset: str) -> dict[str, Any]:
+    return dataset_import_presets[preset]
+
+
 _llama3_hf_to_sharktank_tensor_name_map: dict[str, str] = {
     "model.embed_tokens.weight": "token_embd.weight",
     "lm_head.weight": "output.weight",

sharktank/sharktank/utils/hf_datasets.py

Lines changed: 69 additions & 3 deletions
@@ -12,13 +12,13 @@
 This can be invoked as a tool in order to fetch a local dataset.
 """

-from typing import Dict, Optional, Sequence, Tuple
+from typing import Dict, Optional, Sequence, Tuple, Union

 import argparse
 from dataclasses import dataclass
 from pathlib import Path

-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, snapshot_download


 ################################################################################
@@ -30,10 +30,19 @@
 class RemoteFile:
     file_id: str
     repo_id: str
-    filename: str
+    revision: str | None = None
+    filename: str | None = None
+    allow_patterns: list[str] | str | None = None
+    ignore_patterns: list[str] | str | None = None
     extra_filenames: Sequence[str] = ()

     def download(self, *, local_dir: Optional[Path] = None) -> list[Path]:
+        if self.filename is None:
+            return self._download_snapshot(local_dir=local_dir)
+        else:
+            return self._download_files(local_dir=local_dir)
+
+    def _download_files(self, *, local_dir: Optional[Path] = None) -> list[Path]:
         res = []
         res.append(
             Path(
@@ -47,22 +56,43 @@ def download(self, *, local_dir: Optional[Path] = None) -> list[Path]:
                 Path(
                     hf_hub_download(
                         repo_id=self.repo_id,
+                        revision=self.revision,
                         filename=extra_filename,
                         local_dir=local_dir,
                     )
                 )
             )
         return res

+    def _download_snapshot(self, *, local_dir: Optional[Path] = None) -> list[Path]:
+        return [
+            Path(
+                snapshot_download(
+                    repo_id=self.repo_id,
+                    revision=self.revision,
+                    allow_patterns=self.allow_patterns,
+                    ignore_patterns=self.ignore_patterns,
+                )
+            )
+        ]
+

 @dataclass
 class Dataset:
     name: str
     files: Tuple[RemoteFile]
+    revision: str | None = None

     def __post_init__(self):
         if self.name in ALL_DATASETS:
             raise KeyError(f"Duplicate dataset name '{self.name}'")
+
+        # Set the revision of all files if a Dataset revision is provided.
+        if self.revision is not None:
+            for f in self.files:
+                assert f.revision is None
+                f.revision = self.revision
+
         ALL_DATASETS[self.name] = self

     def alias_to(self, to_name: str) -> "Dataset":
@@ -93,6 +123,42 @@ def alias_dataset(from_name: str, to_name: str):
 # Dataset definitions
 ################################################################################

+Dataset(
+    "meta-llama/Llama-3.1-8B-Instruct",
+    (
+        RemoteFile(
+            "all",
+            repo_id="meta-llama/Llama-3.1-8B-Instruct",
+            ignore_patterns="original/*",
+        ),
+    ),
+    revision="0e9e39f249a16976918f6564b8830bc894c89659",
+)
+
+Dataset(
+    "meta-llama/Llama-3.1-70B-Instruct",
+    (
+        RemoteFile(
+            "all",
+            repo_id="meta-llama/Llama-3.1-70B-Instruct",
+            ignore_patterns="original/*",
+        ),
+    ),
+    revision="1605565b47bb9346c5515c34102e054115b4f98b",
+)
+
+Dataset(
+    "meta-llama/Llama-3.1-405B-Instruct",
+    (
+        RemoteFile(
+            "all",
+            repo_id="meta-llama/Llama-3.1-405B-Instruct",
+            ignore_patterns="original/*",
+        ),
+    ),
+    revision="be673f326cab4cd22ccfef76109faf68e41aa5f1",
+)
+
 Dataset(
     "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF",
     (
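A sketch of defining a dataset in the new filter-based style, mirroring the Llama 3.1 entries above; the dataset name, repo id, and revision are placeholders:

from sharktank.utils.hf_datasets import Dataset, RemoteFile, get_dataset

Dataset(
    "example-org/Example-Model",
    (
        RemoteFile(
            "all",
            repo_id="example-org/Example-Model",
            # filename=None selects snapshot_download, filtered by these patterns.
            ignore_patterns="original/*",
        ),
    ),
    # Propagated to every RemoteFile in __post_init__.
    revision="main",
)

# download() maps each RemoteFile's file_id to its local paths; with a snapshot
# download that is a single directory.
paths = get_dataset("example-org/Example-Model").download()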

sharktank/tests/models/llama/test_llama.py

Lines changed: 19 additions & 0 deletions
@@ -15,6 +15,7 @@
 from pathlib import Path
 from sharktank.models.llm.llm import PagedLlmModelV1
 from sharktank.models.llama.toy_llama import generate
+from sharktank.utils import chdir
 from sharktank.utils.export_artifacts import IreeCompileException
 from sharktank.utils.testing import (
     is_mi300x,
@@ -210,3 +211,21 @@ def test_import_llama3_8B_instruct(tmp_path: Path):
         ]
     )
     assert irpa_path.exists()
+
+
+@pytest.mark.expensive
+def test_import_llama3_8B_instruct_from_preset(tmp_path: Path):
+    from sharktank.tools.import_hf_dataset_from_hub import main
+
+    irpa_path = tmp_path / "llama3.1/8b/instruct/f16/model.irpa"
+    tokenizer_path = tmp_path / "llama3.1/8b/instruct/f16/tokenizer.json"
+    tokenizer_config_path = tmp_path / "llama3.1/8b/instruct/f16/tokenizer_config.json"
+    with chdir(tmp_path):
+        main(
+            [
+                "--preset=meta_llama3_1_8b_instruct_f16",
+            ]
+        )
+    assert irpa_path.exists()
+    assert tokenizer_path.exists()
+    assert tokenizer_config_path.exists()
