diff --git a/sharktank/sharktank/tools/import_hf_dataset_from_hub.py b/sharktank/sharktank/tools/import_hf_dataset_from_hub.py
index 58c3e116f7..46cec117ca 100644
--- a/sharktank/sharktank/tools/import_hf_dataset_from_hub.py
+++ b/sharktank/sharktank/tools/import_hf_dataset_from_hub.py
@@ -11,10 +11,26 @@ def main(argv: list[str] | None = None):
-    parser = cli.create_parser(description="Import a Hugging Face dataset.")
-    cli.add_output_dataset_options(parser)
+    parser = cli.create_parser(
+        description=(
+            "Import a Hugging Face dataset. "
+            "This includes downloading from the HF hub and transforming the model "
+            "parameters into an IRPA. "
+            "The import can be specified either through the various detailed "
+            "parameters or by referencing a preset name. "
+            "The HF dataset details can also be referenced by the name of a preregistered dataset. "
+            "The HF dataset is only concerned with what files need to be downloaded from the hub."
+        )
+    )
+    parser.add_argument(
+        "--output-irpa-file",
+        type=Path,
+        help="IRPA file to save dataset to",
+    )
     parser.add_argument(
         "repo_id_or_path",
+        nargs="?",
+        default=None,
         type=str,
         help='Local path to the model or Hugging Face repo id, e.g. "meta-llama/Meta-Llama-3-8B-Instruct"',
     )
@@ -36,6 +52,28 @@ def main(argv: list[str] | None = None):
         default=None,
         help="Subpath inside the subfolder for the model config file. Defaults to config.json.",
     )
+    parser.add_argument(
+        "--hf-dataset",
+        type=str,
+        default=None,
+        help=(
+            "A name of a preset HF dataset. "
+            "This is mutually exclusive with specifying a repo id or a model path."
+        ),
+    )
+    parser.add_argument(
+        "--preset",
+        type=str,
+        default=None,
+        help=(
+            "A name of a preset to import. "
+            "This is different from a preset HF dataset, which only specifies the HF "
+            "files to download. "
+            "The import preset also specifies the import transformations. "
+            "When using a preset, the output directory structure is relative to "
+            "the current working directory."
+        ),
+    )
     args = cli.parse(parser, args=argv)

     import_hf_dataset_from_hub(
@@ -44,6 +82,8 @@ def main(argv: list[str] | None = None):
         subfolder=args.subfolder,
         config_subpath=args.config_subpath,
         output_irpa_file=args.output_irpa_file,
+        hf_dataset=args.hf_dataset,
+        preset=args.preset,
     )
diff --git a/sharktank/sharktank/utils/hf.py b/sharktank/sharktank/utils/hf.py
index 051c962465..39f8da34f9 100644
--- a/sharktank/sharktank/utils/hf.py
+++ b/sharktank/sharktank/utils/hf.py
@@ -9,16 +9,17 @@
 import re
 import os
 import json
+import shutil
 import torch
 from pathlib import Path
 from huggingface_hub import snapshot_download
 from sharktank.layers.configs import (
-    LlamaHParams,
     LlamaModelConfig,
     is_hugging_face_llama3_config,
 )
 from sharktank.types import *
+from sharktank.utils import verify_exactly_one_is_not_none
 from sharktank.utils.functools import compose
 from sharktank.utils.logging import get_logger
 from sharktank.transforms.dataset import wrap_in_list_if_inference_tensor
@@ -59,6 +60,7 @@ def import_hf_dataset(
     target_dtype=None,
     tensor_transform: Optional["InferenceTensorTransform"] = None,
     metadata_transform: MetadataTransform | None = None,
+    file_copy_map: dict[PathLike, PathLike] | None = None,
 ) -> Optional[Dataset]:

     import safetensors
@@ -86,23 +88,54 @@ def import_hf_dataset(

     theta = Theta(tensors)

+    if file_copy_map is not None:
+        for src, dst in file_copy_map.items():
+            Path(dst).parent.mkdir(parents=True, exist_ok=True)
+            shutil.copy(src, dst)
+
     dataset = Dataset(props, theta)
     if output_irpa_file is not None:
+        Path(output_irpa_file).parent.mkdir(parents=True, exist_ok=True)
         dataset.save(output_irpa_file, io_report_callback=logger.debug)
     return dataset


 def import_hf_dataset_from_hub(
-    repo_id_or_path: str,
+    repo_id_or_path: str | None = None,
     *,
     revision: str | None = None,
     subfolder: str | None = None,
     config_subpath: str | None = None,
     output_irpa_file: PathLike | None = None,
+    target_dtype: torch.dtype | None = None,
+    file_copy_map: dict[PathLike, PathLike] | None = None,
+    hf_dataset: str | None = None,
+    preset: str | None = None,
 ) -> Dataset | None:
-    model_dir = Path(repo_id_or_path)
-    if not model_dir.exists():
-        model_dir = Path(snapshot_download(repo_id=repo_id_or_path, revision=revision))
+    verify_exactly_one_is_not_none(
+        repo_id_or_path=repo_id_or_path, preset=preset, hf_dataset=hf_dataset
+    )
+    if preset is not None:
+        return import_hf_dataset_from_hub(**get_dataset_import_preset_kwargs(preset))
+
+    if hf_dataset is not None:
+        from sharktank.utils.hf_datasets import get_dataset
+
+        download_result_dict = get_dataset(hf_dataset).download()
+        downloaded_file_paths = [
+            p for paths in download_result_dict.values() for p in paths
+        ]
+        if len(downloaded_file_paths) > 1 or downloaded_file_paths[0].is_file():
+            assert (
+                subfolder is None
+            ), "Cannot robustly determine the model dir when the download is not a single model snapshot and a subfolder is specified."
+        model_dir = Path(os.path.commonpath([str(p) for p in downloaded_file_paths]))
+    else:
+        model_dir = Path(repo_id_or_path)
+        if not model_dir.exists():
+            model_dir = Path(
+                snapshot_download(repo_id=repo_id_or_path, revision=revision)
+            )

     if subfolder is not None:
         model_dir /= subfolder
@@ -115,15 +148,73 @@ def import_hf_dataset_from_hub(
         for file_name in os.listdir(model_dir)
         if (model_dir / file_name).is_file()
     ]
+    param_paths = [p for p in file_paths if p.is_file() and p.suffix == ".safetensors"]
+    if file_copy_map is not None:
+        file_copy_map = {model_dir / src: dst for src, dst in file_copy_map.items()}
+
     return import_hf_dataset(
         config_json_path=config_json_path,
         param_paths=param_paths,
         output_irpa_file=output_irpa_file,
+        target_dtype=target_dtype,
+        file_copy_map=file_copy_map,
     )
+
+
+dataset_import_presets: dict[str, dict[str, Any]] = {}
+"""Declarative specifications of how to import HF datasets."""
+
+
+def register_default_llama_dataset_preset(
+    name: str,
+    *,
+    hf_dataset: str,
+    output_prefix_path: str,
+    target_dtype: torch.dtype | None = None,
+):
+    output_prefix_path = Path(output_prefix_path)
+    dataset_import_presets[name] = {
+        "hf_dataset": hf_dataset,
+        "output_irpa_file": output_prefix_path / "model.irpa",
+        "target_dtype": target_dtype,
+        "file_copy_map": {
+            "tokenizer.json": output_prefix_path / "tokenizer.json",
+            "tokenizer_config.json": output_prefix_path / "tokenizer_config.json",
+            "LICENSE": output_prefix_path / "LICENSE",
+        },
+    }
+
+
+def register_all_dataset_import_presets():
+    register_default_llama_dataset_preset(
+        name="meta_llama3_1_8b_instruct_f16",
+        hf_dataset="meta-llama/Llama-3.1-8B-Instruct",
+        output_prefix_path="llama3.1/8b/instruct/f16",
+        target_dtype=torch.float16,
+    )
+    register_default_llama_dataset_preset(
+        name="meta_llama3_1_70b_instruct_f16",
+        hf_dataset="meta-llama/Llama-3.1-70B-Instruct",
+        output_prefix_path="llama3.1/70b/instruct/f16",
+        target_dtype=torch.float16,
+    )
+    register_default_llama_dataset_preset(
+        name="meta_llama3_1_405b_instruct_f16",
+        hf_dataset="meta-llama/Llama-3.1-405B-Instruct",
+        output_prefix_path="llama3.1/405b/instruct/f16",
+        target_dtype=torch.float16,
+    )


+register_all_dataset_import_presets()
+
+
+def get_dataset_import_preset_kwargs(preset: str) -> dict[str, Any]:
+    return dataset_import_presets[preset]
+
+
 _llama3_hf_to_sharktank_tensor_name_map: dict[str, str] = {
     "model.embed_tokens.weight": "token_embd.weight",
     "lm_head.weight": "output.weight",
diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py
index 10476a3f4b..14d89f54df 100644
--- a/sharktank/sharktank/utils/hf_datasets.py
+++ b/sharktank/sharktank/utils/hf_datasets.py
@@ -12,13 +12,13 @@ This can be invoked as a tool in order to fetch a local dataset.
""" -from typing import Dict, Optional, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple, Union import argparse from dataclasses import dataclass from pathlib import Path -from huggingface_hub import hf_hub_download +from huggingface_hub import hf_hub_download, snapshot_download ################################################################################ @@ -30,10 +30,19 @@ class RemoteFile: file_id: str repo_id: str - filename: str + revision: str | None = None + filename: str | None = None + allow_patterns: list[str] | str | None = None + ignore_patterns: list[str] | str | None = None extra_filenames: Sequence[str] = () def download(self, *, local_dir: Optional[Path] = None) -> list[Path]: + if self.filename is None: + return self._download_snapshot(local_dir=local_dir) + else: + return self._download_files(local_dir=local_dir) + + def _download_files(self, *, local_dir: Optional[Path] = None) -> list[Path]: res = [] res.append( Path( @@ -47,6 +56,7 @@ def download(self, *, local_dir: Optional[Path] = None) -> list[Path]: Path( hf_hub_download( repo_id=self.repo_id, + revision=self.revision, filename=extra_filename, local_dir=local_dir, ) @@ -54,15 +64,35 @@ def download(self, *, local_dir: Optional[Path] = None) -> list[Path]: ) return res + def _download_snapshot(self, *, local_dir: Optional[Path] = None) -> list[Path]: + return [ + Path( + snapshot_download( + repo_id=self.repo_id, + revision=self.revision, + allow_patterns=self.allow_patterns, + ignore_patterns=self.ignore_patterns, + ) + ) + ] + @dataclass class Dataset: name: str files: Tuple[RemoteFile] + revision: str | None = None def __post_init__(self): if self.name in ALL_DATASETS: raise KeyError(f"Duplicate dataset name '{self.name}'") + + # Set the revision of all files if a Dataset revision is provided. 
+        if self.revision is not None:
+            for f in self.files:
+                assert f.revision is None
+                f.revision = self.revision
+
         ALL_DATASETS[self.name] = self

     def alias_to(self, to_name: str) -> "Dataset":
@@ -93,6 +123,42 @@ def alias_dataset(from_name: str, to_name: str):
 # Dataset definitions
 ################################################################################

+Dataset(
+    "meta-llama/Llama-3.1-8B-Instruct",
+    (
+        RemoteFile(
+            "all",
+            repo_id="meta-llama/Llama-3.1-8B-Instruct",
+            ignore_patterns="original/*",
+        ),
+    ),
+    revision="0e9e39f249a16976918f6564b8830bc894c89659",
+)
+
+Dataset(
+    "meta-llama/Llama-3.1-70B-Instruct",
+    (
+        RemoteFile(
+            "all",
+            repo_id="meta-llama/Llama-3.1-70B-Instruct",
+            ignore_patterns="original/*",
+        ),
+    ),
+    revision="1605565b47bb9346c5515c34102e054115b4f98b",
+)
+
+Dataset(
+    "meta-llama/Llama-3.1-405B-Instruct",
+    (
+        RemoteFile(
+            "all",
+            repo_id="meta-llama/Llama-3.1-405B-Instruct",
+            ignore_patterns="original/*",
+        ),
+    ),
+    revision="be673f326cab4cd22ccfef76109faf68e41aa5f1",
+)
+
 Dataset(
     "SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF",
     (
diff --git a/sharktank/tests/models/llama/test_llama.py b/sharktank/tests/models/llama/test_llama.py
index 387ccd3d0f..30712af272 100644
--- a/sharktank/tests/models/llama/test_llama.py
+++ b/sharktank/tests/models/llama/test_llama.py
@@ -15,6 +15,7 @@
 from pathlib import Path
 from sharktank.models.llm.llm import PagedLlmModelV1
 from sharktank.models.llama.toy_llama import generate
+from sharktank.utils import chdir
 from sharktank.utils.export_artifacts import IreeCompileException
 from sharktank.utils.testing import (
     is_mi300x,
@@ -210,3 +211,21 @@ def test_import_llama3_8B_instruct(tmp_path: Path):
         ]
     )
     assert irpa_path.exists()
+
+
+@pytest.mark.expensive
+def test_import_llama3_8B_instruct_from_preset(tmp_path: Path):
+    from sharktank.tools.import_hf_dataset_from_hub import main
+
+    irpa_path = tmp_path / "llama3.1/8b/instruct/f16/model.irpa"
+    tokenizer_path = tmp_path / "llama3.1/8b/instruct/f16/tokenizer.json"
+    tokenizer_config_path = tmp_path / "llama3.1/8b/instruct/f16/tokenizer_config.json"
+    with chdir(tmp_path):
+        main(
+            [
+                "--preset=meta_llama3_1_8b_instruct_f16",
+            ]
+        )
+    assert irpa_path.exists()
+    assert tokenizer_path.exists()
+    assert tokenizer_config_path.exists()
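Usage note (not part of the diff): a minimal sketch of how the new preset path could be exercised once this lands. The module path, function name, keyword argument, preset name, and output layout below are taken from the changes above; invoking the tool via "python -m" and running from the directory that should receive the output tree are assumptions.

    # Sketch only, mirroring the new test. Assumes the tool module is runnable
    # with "python -m" and that the current working directory is where the
    # llama3.1/... output tree should be written.
    #
    # CLI:
    #   python -m sharktank.tools.import_hf_dataset_from_hub --preset=meta_llama3_1_8b_instruct_f16
    #
    # Python API, using the preset registered in sharktank/sharktank/utils/hf.py:
    from sharktank.utils.hf import import_hf_dataset_from_hub

    dataset = import_hf_dataset_from_hub(preset="meta_llama3_1_8b_instruct_f16")

    # Expected outputs, per the preset's output_irpa_file and file_copy_map:
    #   llama3.1/8b/instruct/f16/model.irpa
    #   llama3.1/8b/instruct/f16/tokenizer.json
    #   llama3.1/8b/instruct/f16/tokenizer_config.json
    #   llama3.1/8b/instruct/f16/LICENSE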