Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ homepage = "https://github.com/neuralmagic/speculators"
source = "https://github.com/neuralmagic/speculators"
issues = "https://github.com/neuralmagic/speculators/issues"

# PEP 621: console scripts must be declared under [project.scripts];
# build backends are required to reject [project.entry-points.console_scripts].
[project.scripts]
speculators = "speculators.__main__:app"

# ************************************************
# ********** Code Quality Tools **********
# ************************************************
Expand Down
144 changes: 142 additions & 2 deletions src/speculators/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,148 @@
"""
Entry point for running speculators as a module.
CLI entrypoints for the Speculators library.

This module provides a command-line interface for creating and managing speculative
decoding models. The CLI is built using Typer and provides commands for model
conversion, version information, and other utilities.

The CLI can be accessed through the `speculators` command after installation, or by
running this module directly with `python -m speculators`.

Commands:
convert: Convert models from external repos/formats to supported Speculators models
version: Display the current version of the Speculators library

Usage:
$ speculators --help
$ speculators --version
$ speculators convert <model> [OPTIONS]
"""

from speculators.cli import app
import json
from importlib.metadata import version as pkg_version
from typing import Annotated, Any, Optional

import click
import typer

from speculators.convert import convert_model

# Public API of this module: the Typer application object, which is also the
# console-script entry point declared in pyproject.toml.
__all__ = ["app"]

# Configure the main Typer application.
# - add_completion=False: do not expose shell-completion install subcommands
# - no_args_is_help=True: invoking `speculators` with no arguments prints help
app = typer.Typer(
    name="speculators",
    help="Speculators - Tools for speculative decoding with LLMs",
    add_completion=False,
    no_args_is_help=True,
)


def version_callback(value: bool):
    """
    Print the installed Speculators version and terminate the CLI.

    Registered as the callback for the global ``--version`` option. When the
    flag is present on the command line, the version string is echoed and a
    ``typer.Exit`` is raised so no further command processing happens.

    :param value: True when ``--version`` was passed; falsy values make this
        callback a no-op so normal command handling continues.
    """
    if not value:
        return
    installed = pkg_version("speculators")
    typer.echo(f"speculators version: {installed}")
    raise typer.Exit


@app.callback()
def speculators(
    ctx: typer.Context,
    # Optional[bool] (not bool): the option defaults to None so that "flag not
    # given" is distinguishable from an explicit value. is_eager=True makes
    # Typer process --version before validating any other parameters, which is
    # the documented pattern for version flags.
    version: Optional[bool] = typer.Option(
        None,
        "--version",
        help="Show the installed Speculators version and exit.",
        callback=version_callback,
        is_eager=True,
    ),
):
    """
    Main entry point for the Speculators CLI application.

    This function serves as the root command callback and handles global options
    such as version display. It is automatically called by Typer when the CLI
    is invoked.

    :param ctx: The Typer context object containing runtime information.
    :param version: Boolean option to display version information and exit.
        Handled eagerly by ``version_callback`` before any subcommand runs.
    """


# Add convert command
@app.command()
def convert(
    model: str,
    output_path: str = "speculators_converted",
    config: Optional[str] = None,
    verifier: Optional[str] = None,
    validate_device: Optional[str] = None,
    # Choices are enforced at the CLI layer via click; "auto" defers algorithm
    # detection to convert_model.
    algorithm: Annotated[
        str, typer.Option(click_type=click.Choice(["auto", "eagle", "eagle2", "hass"]))
    ] = "auto",
    # The command-line value is a JSON string, parsed into a dict by json.loads.
    algorithm_kwargs: Annotated[
        Optional[dict[str, Any]], typer.Option(parser=json.loads)
    ] = None,
    cache_dir: Optional[str] = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: Optional[str] = None,
    revision: Optional[str] = None,
) -> None:
    """
    Convert external models to Speculators-compatible format.

    This command converts models from external research repositories or formats
    into the standardized Speculators format. Currently supports model formats
    from the list of research repositories below with automatic algorithm detection.

    Supported Research Repositories:
    - Eagle v1 and v2: https://github.com/SafeAILab/EAGLE
    - HASS: https://github.com/HArmonizedSS/HASS

    :param model: Path to model checkpoint or Hugging Face model ID to convert.
    :param output_path: Directory path where converted model will be saved.
    :param config: Path to config.json file or HF model ID for model configuration.
        If not provided, configuration will be inferred from the checkpoint.
    :param verifier: Path to verifier checkpoint or HF model ID to attach as the
        verification model for speculative decoding.
    :param validate_device: Device identifier (e.g., "cpu", "cuda") for post-conversion
        validation. If not provided, validation is skipped.
    :param algorithm: Conversion algorithm to use. "auto" enables automatic detection
        based on model type and configuration.
    :param algorithm_kwargs: Additional keyword arguments for the conversion algorithm
        as a JSON string. Passed directly to the converter class.
    :param cache_dir: Directory for caching downloaded models. Uses default HF cache
        if not specified.
    :param force_download: Force re-download of checkpoint and config files,
        bypassing cache.
    :param local_files_only: Use only local files without attempting downloads
        from Hugging Face Hub.
    :param token: Hugging Face authentication token for accessing private models.
    :param revision: Git revision (branch, tag, or commit hash) for model files
        from Hugging Face Hub.
    """
    # Thin wrapper: all conversion logic lives in speculators.convert.convert_model;
    # this function only adapts CLI options into keyword arguments.
    convert_model(
        model=model,
        output_path=output_path,
        config=config,
        verifier=verifier,
        validate_device=validate_device,
        algorithm=algorithm,  # type: ignore[arg-type]
        algorithm_kwargs=algorithm_kwargs,
        cache_dir=cache_dir,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
    )


if __name__ == "__main__":
    # Support `python -m speculators` by invoking the Typer app directly.
    app()
54 changes: 0 additions & 54 deletions src/speculators/cli.py

This file was deleted.

48 changes: 37 additions & 11 deletions src/speculators/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@
from typing import Any, ClassVar, Optional, Union

from pydantic import BaseModel, ConfigDict, Field
from transformers import PretrainedConfig
from transformers import PretrainedConfig, PreTrainedModel

from speculators.utils import PydanticClassRegistryMixin, ReloadableBaseModel
from speculators.utils import (
PydanticClassRegistryMixin,
ReloadableBaseModel,
load_model_config,
)

__all__ = [
"SpeculatorModelConfig",
Expand Down Expand Up @@ -78,28 +82,50 @@ class VerifierConfig(BaseModel):
"""

@classmethod
def from_config(
cls, config: PretrainedConfig, name_or_path: Optional[str] = "UNSET"
def from_pretrained(
cls,
config: Optional[
Union[str, os.PathLike, PreTrainedModel, PretrainedConfig, dict]
],
name_or_path: Optional[str] = "UNSET",
**kwargs,
) -> "VerifierConfig":
"""
Create a VerifierConfig from a PretrainedConfig object.
Create a VerifierConfig from a PretrainedConfig.
Used to extract the required parameters from the original verifier
config and create a VerifierConfig object.

:param config: The PretrainedConfig object to extract the parameters from.
:param config: The PretrainedConfig object or a path/huggingface model id
to the original verifier model config. If None, the config will be empty.
If a string or path is provided, it will be loaded as a PretrainedConfig.
If a PretrainedConfig is provided, it will be used directly.
:param name_or_path: The name or path for the verifier model.
Set to None to not add a specific name_or_path.
If not provided, the name_or_path from the config will be used.
:param kwargs: Additional keyword arguments to pass to AutoConfig for loading.
:return: A VerifierConfig object with the extracted parameters.
"""
config_dict = config.to_dict()
config_pretrained: Optional[Union[PretrainedConfig, dict]] = (
load_model_config(config, **kwargs) # type: ignore[assignment]
if config and not isinstance(config, dict)
else config
)
config_dict: dict = (
config_pretrained.to_dict() # type: ignore[assignment]
if config_pretrained and isinstance(config_pretrained, PretrainedConfig)
else config_pretrained
)
if not config_dict:
config_dict = {}

if name_or_path == "UNSET":
name_or_path = (
getattr(config, "name_or_path", None)
or config_dict.get("_name_or_path", None)
or config_dict.get("name_or_path", None)
config_name_or_path = (
getattr(config, "name_or_path", None) if config else None
)
config_dict_name_or_path = config_dict.get(
"_name_or_path", None
) or config_dict.get("name_or_path", None)
name_or_path = config_name_or_path or config_dict_name_or_path

return cls(
name_or_path=name_or_path,
Expand Down
36 changes: 32 additions & 4 deletions src/speculators/convert/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,38 @@
"""
Checkpoint conversion utilities for Speculators.

This module provides tools to convert existing speculator checkpoints
(Eagle, HASS, etc.) into the standardized speculators format.
This module provides tools to convert existing speculator checkpoints from external
research repositories (Eagle, HASS, etc.) into the standardized Speculators format.
The conversion process handles model architecture adaptation, configuration translation,
and optional verifier attachment for speculative decoding.

The primary entry point is the `convert_model` function, which supports automatic
algorithm detection and conversion from various input formats including local
checkpoints, Hugging Face model IDs, and PyTorch model instances.

Supported Research Repositories:
- Eagle v1 and v2: https://github.com/SafeAILab/EAGLE
- HASS: https://github.com/HArmonizedSS/HASS

Functions:
convert_model: Convert external model checkpoints to Speculators-compatible format

Usage:
::
from speculators.convert import convert_model

# Convert with automatic algorithm detection
model = convert_model("path/to/checkpoint", output_path="converted_model")

# Convert with specific algorithm and verifier
model = convert_model(
model="hf_model_id",
verifier="verifier_model_id",
output_path="my_speculator"
)
"""

from speculators.convert.eagle.eagle_converter import EagleConverter
from .converters import SpeculatorConverter
from .entrypoints import convert_model

__all__ = ["EagleConverter"]
__all__ = ["SpeculatorConverter", "convert_model"]
Loading