44 changes: 42 additions & 2 deletions sharktank/sharktank/tools/import_hf_dataset_from_hub.py
@@ -11,10 +11,26 @@

def main(argv: list[str] | None = None):

parser = cli.create_parser(description="Import a Hugging Face dataset.")
cli.add_output_dataset_options(parser)
parser = cli.create_parser(
description=(
"Import a Hugging Face dataset. "
"This includes downloading from the HF hub and transforming the model "
"parameters into an IRPA. "
"The import can either be specified through the various detailed "
"parameters or a preset name can be reference. "
"The HF dataset details can also be reference by a named preregistered dataset. "
"The HF dataset is only concerned with what files need to be downloaded from the hub."
)
)
parser.add_argument(
"--output-irpa-file",
Contributor:
I would hesitate to add this flag, as we currently cannot directly consume it. sharktank expects naming conventions from the GGUF format, not Hugging Face.

@sogartar (Contributor Author), Sep 26, 2025:
Where are we supposed to consume it in GGUF format? The converter does on-the-fly conversion to our format, which is derived from GGUF. We save only in IRPA. We can't save in GGUF; we can only read it as sharktank.types.Dataset.

type=Path,
help="IRPA file to save dataset to",
)
parser.add_argument(
"repo_id_or_path",
nargs="?",
default=None,
type=str,
help='Local path to the model or Hugging Face repo id, e.g. "meta-llama/Meta-Llama-3-8B-Instruct"',
)
@@ -36,6 +52,28 @@ def main(argv: list[str] | None = None):
default=None,
help="Subpath inside the subfolder for the model config file. Defaults to config.json.",
)
parser.add_argument(
"--hf-dataset",
type=str,
default=None,
help=(
"A name of a preset HF dataset. "
"This is mutually exclusive with specifying repo id or a model path."
),
)
parser.add_argument(
"--preset",
type=str,
default=None,
help=(
"A name of a preset to import. "
"This is different form a preset HF dataset, which only specifies the HF "
"files to download. "
"The import preset also specifies the import transformations. "
"When using a preset the output directory structure would be relative to "
"the current working directory."
),
)
args = cli.parse(parser, args=argv)

import_hf_dataset_from_hub(
@@ -44,6 +82,8 @@ def main(argv: list[str] | None = None):
subfolder=args.subfolder,
config_subpath=args.config_subpath,
output_irpa_file=args.output_irpa_file,
hf_dataset=args.hf_dataset,
preset=args.preset,
)


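As a quick illustration of the new flags, a minimal usage sketch is shown below. It mirrors the tests later in this PR; it is not part of the diff, and the output path in the second call is a hypothetical placeholder.

# Minimal usage sketch of the updated tool; not part of this diff.
from sharktank.tools.import_hf_dataset_from_hub import main

# Import via a registered preset; outputs land relative to the current
# working directory (llama3.1/8b/instruct/f16/model.irpa, tokenizer files, ...).
main(["--preset=meta_llama3_1_8b_instruct_f16"])

# Or import via a preregistered HF dataset name with an explicit IRPA path
# (the /tmp path is a placeholder).
main(
    [
        "--hf-dataset=meta-llama/Llama-3.1-8B-Instruct",
        "--output-irpa-file=/tmp/llama3-8b-instruct.irpa",
    ]
)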
101 changes: 96 additions & 5 deletions sharktank/sharktank/utils/hf.py
@@ -9,16 +9,17 @@
import re
import os
import json
import shutil
import torch
from pathlib import Path

from huggingface_hub import snapshot_download
from sharktank.layers.configs import (
LlamaHParams,
LlamaModelConfig,
is_hugging_face_llama3_config,
)
from sharktank.types import *
from sharktank.utils import verify_exactly_one_is_not_none
from sharktank.utils.functools import compose
from sharktank.utils.logging import get_logger
from sharktank.transforms.dataset import wrap_in_list_if_inference_tensor
@@ -59,6 +60,7 @@ def import_hf_dataset(
target_dtype=None,
tensor_transform: Optional["InferenceTensorTransform"] = None,
metadata_transform: MetadataTransform | None = None,
file_copy_map: dict[PathLike, PathLike] | None = None,
) -> Optional[Dataset]:
import safetensors

Expand Down Expand Up @@ -86,23 +88,54 @@ def import_hf_dataset(

theta = Theta(tensors)

if file_copy_map is not None:
for src, dst in file_copy_map.items():
Path(dst).parent.mkdir(parents=True, exist_ok=True)
shutil.copy(src, dst)

dataset = Dataset(props, theta)
if output_irpa_file is not None:
Path(output_irpa_file).parent.mkdir(parents=True, exist_ok=True)
dataset.save(output_irpa_file, io_report_callback=logger.debug)
return dataset


def import_hf_dataset_from_hub(
repo_id_or_path: str,
repo_id_or_path: str | None = None,
*,
revision: str | None = None,
subfolder: str | None = None,
config_subpath: str | None = None,
output_irpa_file: PathLike | None = None,
target_dtype: torch.dtype | None = None,
file_copy_map: dict[PathLike, PathLike] | None = None,
hf_dataset: str | None = None,
preset: str | None = None,
) -> Dataset | None:
model_dir = Path(repo_id_or_path)
if not model_dir.exists():
model_dir = Path(snapshot_download(repo_id=repo_id_or_path, revision=revision))
verify_exactly_one_is_not_none(
repo_id_or_path=repo_id_or_path, preset=preset, hf_dataset=hf_dataset
)
if preset is not None:
return import_hf_dataset_from_hub(**get_dataset_import_preset_kwargs(preset))

if hf_dataset is not None:
from sharktank.utils.hf_datasets import get_dataset

download_result_dict = get_dataset(hf_dataset).download()
downloaded_file_paths = [
p for paths in download_result_dict.values() for p in paths
]
if len(downloaded_file_paths) > 1 or downloaded_file_paths[0].is_file():
assert (
subfolder is None
), "Not robust in determining the model dir if doing a non-single model snapshot download and subfolder is specified."
model_dir = Path(os.path.commonpath([str(p) for p in downloaded_file_paths]))
else:
model_dir = Path(repo_id_or_path)
if not model_dir.exists():
model_dir = Path(
snapshot_download(repo_id=repo_id_or_path, revision=revision)
)

if subfolder is not None:
model_dir /= subfolder
@@ -115,15 +148,73 @@ def import_hf_dataset_from_hub(
for file_name in os.listdir(model_dir)
if (model_dir / file_name).is_file()
]

param_paths = [p for p in file_paths if p.is_file() and p.suffix == ".safetensors"]

if file_copy_map is not None:
file_copy_map = {model_dir / src: dst for src, dst in file_copy_map.items()}

return import_hf_dataset(
config_json_path=config_json_path,
param_paths=param_paths,
output_irpa_file=output_irpa_file,
target_dtype=target_dtype,
file_copy_map=file_copy_map,
)


dataset_import_presets: dict[str, dict[str, Any]] = {}
"""Declarative specification on how to import a HF dataset."""


def register_default_llama_dataset_preset(
name: str,
*,
hf_dataset: str,
output_prefix_path: str,
target_dtype: torch.dtype | None = None,
):
output_prefix_path = Path(output_prefix_path)
dataset_import_presets[name] = {
"hf_dataset": hf_dataset,
"output_irpa_file": output_prefix_path / "model.irpa",
"target_dtype": target_dtype,
"file_copy_map": {
"tokenizer.json": output_prefix_path / "tokenizer.json",
"tokenizer_config.json": output_prefix_path / "tokenizer_config.json",
"LICENSE": output_prefix_path / "LICENSE",
},
}


def register_all_dataset_import_presets():
register_default_llama_dataset_preset(
name="meta_llama3_1_8b_instruct_f16",
hf_dataset="meta-llama/Llama-3.1-8B-Instruct",
output_prefix_path="llama3.1/8b/instruct/f16",
target_dtype=torch.float16,
)
register_default_llama_dataset_preset(
name="meta_llama3_1_70b_instruct_f16",
hf_dataset="meta-llama/Llama-3.1-70B-Instruct",
output_prefix_path="llama3.1/70b/instruct/f16",
target_dtype=torch.float16,
)
register_default_llama_dataset_preset(
name="meta_llama3_1_405b_instruct_f16",
hf_dataset="meta-llama/Llama-3.1-405B-Instruct",
output_prefix_path="llama3.1/405b/instruct/f16",
target_dtype=torch.float16,
)


register_all_dataset_import_presets()


def get_dataset_import_preset_kwargs(preset: str) -> dict[str, Any]:
return dataset_import_presets[preset]


_llama3_hf_to_sharktank_tensor_name_map: dict[str, str] = {
"model.embed_tokens.weight": "token_embd.weight",
"lm_head.weight": "output.weight",
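To sketch how the preset registry above is meant to be consumed, a hypothetical registration is shown below. The preset name, repo id, and output prefix are invented for illustration; only functions introduced in this diff are used.

# Hypothetical preset registration; not part of this diff.
import torch

from sharktank.utils.hf import (
    get_dataset_import_preset_kwargs,
    import_hf_dataset_from_hub,
    register_default_llama_dataset_preset,
)

# "my-org/My-Llama-Variant" would have to be a Dataset registered in
# sharktank.utils.hf_datasets for the download step to resolve.
register_default_llama_dataset_preset(
    name="my_llama_variant_f16",
    hf_dataset="my-org/My-Llama-Variant",
    output_prefix_path="my-llama/variant/f16",
    target_dtype=torch.float16,
)

# A preset is just a kwargs dict for import_hf_dataset_from_hub; this is what
# the tool does when --preset is given.
dataset = import_hf_dataset_from_hub(
    **get_dataset_import_preset_kwargs("my_llama_variant_f16")
)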
72 changes: 69 additions & 3 deletions sharktank/sharktank/utils/hf_datasets.py
@@ -12,13 +12,13 @@
This can be invoked as a tool in order to fetch a local dataset.
"""

from typing import Dict, Optional, Sequence, Tuple
from typing import Dict, Optional, Sequence, Tuple, Union

import argparse
from dataclasses import dataclass
from pathlib import Path

from huggingface_hub import hf_hub_download
from huggingface_hub import hf_hub_download, snapshot_download


################################################################################
@@ -30,10 +30,19 @@
class RemoteFile:
file_id: str
repo_id: str
filename: str
revision: str | None = None
filename: str | None = None
allow_patterns: list[str] | str | None = None
ignore_patterns: list[str] | str | None = None
extra_filenames: Sequence[str] = ()

def download(self, *, local_dir: Optional[Path] = None) -> list[Path]:
if self.filename is None:
return self._download_snapshot(local_dir=local_dir)
else:
return self._download_files(local_dir=local_dir)

def _download_files(self, *, local_dir: Optional[Path] = None) -> list[Path]:
res = []
res.append(
Path(
@@ -47,22 +56,43 @@ def download(self, *, local_dir: Optional[Path] = None) -> list[Path]:
Path(
hf_hub_download(
repo_id=self.repo_id,
revision=self.revision,
filename=extra_filename,
local_dir=local_dir,
)
)
)
return res

def _download_snapshot(self, *, local_dir: Optional[Path] = None) -> list[Path]:
return [
Path(
snapshot_download(
repo_id=self.repo_id,
revision=self.revision,
allow_patterns=self.allow_patterns,
ignore_patterns=self.ignore_patterns,
)
)
]


@dataclass
class Dataset:
name: str
files: Tuple[RemoteFile]
revision: str | None = None

def __post_init__(self):
if self.name in ALL_DATASETS:
raise KeyError(f"Duplicate dataset name '{self.name}'")

# Set the revision of all files if a Dataset revision is provided.
if self.revision is not None:
for f in self.files:
assert f.revision is None
f.revision = self.revision

ALL_DATASETS[self.name] = self

def alias_to(self, to_name: str) -> "Dataset":
@@ -93,6 +123,42 @@ def alias_dataset(from_name: str, to_name: str):
# Dataset definitions
################################################################################

Dataset(
"meta-llama/Llama-3.1-8B-Instruct",
(
RemoteFile(
"all",
repo_id="meta-llama/Llama-3.1-8B-Instruct",
ignore_patterns="original/*",
),
),
revision="0e9e39f249a16976918f6564b8830bc894c89659",
)

Dataset(
"meta-llama/Llama-3.1-70B-Instruct",
(
RemoteFile(
"all",
repo_id="meta-llama/Llama-3.1-70B-Instruct",
ignore_patterns="original/*",
),
),
revision="1605565b47bb9346c5515c34102e054115b4f98b",
)

Dataset(
"meta-llama/Llama-3.1-405B-Instruct",
(
RemoteFile(
"all",
repo_id="meta-llama/Llama-3.1-405B-Instruct",
ignore_patterns="original/*",
),
),
revision="be673f326cab4cd22ccfef76109faf68e41aa5f1",
)

Dataset(
"SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF",
(
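A short sketch of how the snapshot-download path and the Dataset-level revision added above fit together; the repo id and revision below are placeholders, whereas the real entries pin exact commits as shown in the diff.

# Hypothetical dataset registration; not part of this diff.
from sharktank.utils.hf_datasets import Dataset, RemoteFile, get_dataset

# Omitting `filename` makes RemoteFile.download() fall back to
# snapshot_download(), and Dataset.__post_init__ propagates the revision
# to every RemoteFile.
Dataset(
    "example-org/Example-Model",
    (
        RemoteFile(
            "all",
            repo_id="example-org/Example-Model",
            ignore_patterns="original/*",
        ),
    ),
    revision="0000000000000000000000000000000000000000",
)

# Dataset.download() (unchanged by this diff) returns the downloaded paths per
# file id; for a snapshot entry this is the local snapshot directory.
result = get_dataset("example-org/Example-Model").download()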
19 changes: 19 additions & 0 deletions sharktank/tests/models/llama/test_llama.py
@@ -15,6 +15,7 @@
from pathlib import Path
from sharktank.models.llm.llm import PagedLlmModelV1
from sharktank.models.llama.toy_llama import generate
from sharktank.utils import chdir
from sharktank.utils.export_artifacts import IreeCompileException
from sharktank.utils.testing import (
is_mi300x,
@@ -210,3 +211,21 @@ def test_import_llama3_8B_instruct(tmp_path: Path):
]
)
assert irpa_path.exists()


@pytest.mark.expensive
def test_import_llama3_8B_instruct_from_preset(tmp_path: Path):
from sharktank.tools.import_hf_dataset_from_hub import main

irpa_path = tmp_path / "llama3.1/8b/instruct/f16/model.irpa"
tokenizer_path = tmp_path / "llama3.1/8b/instruct/f16/tokenizer.json"
tokenizer_config_path = tmp_path / "llama3.1/8b/instruct/f16/tokenizer_config.json"
with chdir(tmp_path):
main(
[
"--preset=meta_llama3_1_8b_instruct_f16",
]
)
assert irpa_path.exists()
Contributor:
We should verify these files in some form instead of just checking that they exist, in case of a bad download or something. Maybe have an md5sum that we can compare?

@sogartar (Contributor Author), Sep 26, 2025:
It is possible that it fails silently in such a way. It would be pretty sad if the HF hub package failed there, since robust downloading, I would assume, is a major goal.
I actually have not decided yet what our attitude towards changing IRPA files should be. Should we force a strict manual file hash update? Meaning that we make the explicit assumption that the file does not change.
It may be a problem, for example, if we add another field to the model config, which would change the IRPA metadata and its hash.

Collaborator:
This works for now. But like Ian pointed out, if there are more reliable ways to verify that the IRPA generation was complete/successful, that would be great.

Contributor Author (@sogartar):
What I think should ultimately happen is to run a model import job before running the CI test jobs.

Another option is to run a nightly job that imports from HF and then uploads to Azure so that other runners can update their model cache. This has the problem that we may overwrite existing model files with faulty ones if a bug appears. In that scenario a more thorough model validation would be needed before uploading.

assert tokenizer_path.exists()
assert tokenizer_config_path.exists()
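A minimal sketch of the checksum verification suggested in the thread above, assuming a known-good digest is recorded for the downloaded files; the expected value in the commented assertion is a placeholder.

# Hypothetical helper for the md5 verification idea; not part of this diff.
import hashlib
from pathlib import Path


def file_md5(path: Path, chunk_size: int = 1 << 20) -> str:
    # Stream the file so large artifacts do not need to fit in memory.
    digest = hashlib.md5()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Usage inside the test would compare against a recorded value, e.g.:
# assert file_md5(tokenizer_path) == "<known-good md5>"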