Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions api/rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ license = "MIT OR Apache-2.0"
authors = ["Mathieu Poumeyrol <kali@zoy.org>"]
description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference"
repository = "https://github.com/sonos/tract"
keywords = [ "NeuralNetworks" ]
categories = [ "science" ]
keywords = ["NeuralNetworks"]
categories = ["science"]
autobenches = false
edition = "2024"
rust-version = "1.85"
include = [ "Cargo.toml", "src/**/*.rs", "LICENSE*", "tests" ]
include = ["Cargo.toml", "src/**/*.rs", "LICENSE*", "tests"]

[dependencies]
anyhow.workspace = true
Expand All @@ -21,7 +21,6 @@ icu_normalizer.workspace = true
icu_properties.workspace = true
ndarray.workspace = true
tract-api.workspace = true
tract-cuda.workspace = true
tract-nnef.workspace = true
tract-onnx-opl.workspace = true
tract-onnx.workspace = true
Expand All @@ -34,6 +33,10 @@ serde_json.workspace = true
[target.'cfg(any(target_vendor = "apple"))'.dependencies]
tract-metal.workspace = true

[target.'cfg(any(target_os = "linux", target_os = "windows"))'.dependencies]
tract-cuda.workspace = true


[dev-dependencies]
reqwest.workspace = true
tempfile.workspace = true
Expand Down
1 change: 1 addition & 0 deletions api/rs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#[cfg(target_vendor = "apple")]
extern crate tract_metal;

#[cfg(any(target_os = "linux", target_os = "windows"))]
extern crate tract_cuda;
extern crate tract_transformers;

Expand Down
48 changes: 31 additions & 17 deletions cli/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
[package]
name = "tract"
version = "0.23.0-pre"
authors = [ "Romain Liautaud <romain.liautaud@snips.ai>", "Mathieu Poumeyrol <kali@zoy.org>"]
authors = [
"Romain Liautaud <romain.liautaud@snips.ai>",
"Mathieu Poumeyrol <kali@zoy.org>",
]
license = "MIT OR Apache-2.0"
description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference"
repository = "https://github.com/snipsco/tract"
keywords = [ "TensorFlow", "NeuralNetworks" ]
categories = [ "science" ]
keywords = ["TensorFlow", "NeuralNetworks"]
categories = ["science"]
autobenches = false
edition = "2024"
include = [ "Cargo.toml", "src/**/*.rs", "LICENSE*" ]
include = ["Cargo.toml", "src/**/*.rs", "LICENSE*"]

[badges]
maintenance = { status = "actively-developed" }
Expand All @@ -20,7 +23,6 @@ box_drawing.workspace = true
clap.workspace = true
criterion.workspace = true
colorous.workspace = true
cudarc.workspace = true
env_logger.workspace = true
flate2.workspace = true
fs-err.workspace = true
Expand Down Expand Up @@ -51,7 +53,6 @@ tract-nnef.workspace = true
tract-nnef-resources.workspace = true
tract-libcli.workspace = true
tract-gpu.workspace = true
tract-cuda.workspace = true
tract-extra = { workspace = true, optional = true }
tract-pulse = { workspace = true, optional = true }
tract-pulse-opl = { workspace = true, optional = true }
Expand All @@ -66,15 +67,28 @@ levenshtein-diff.workspace = true
[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
tract-metal.workspace = true

[target.'cfg(any(target_os = "linux", target_os = "windows"))'.dependencies]
cudarc.workspace = true
tract-cuda.workspace = true

[features]
default = ["onnx", "tf", "pulse", "pulse-opl", "tflite", "transformers", "extra"]
apple-amx-ios = [ "tract-linalg/apple-amx-ios" ]
onnx = [ "tract-onnx", "tract-libcli/hir", "tract-libcli/onnx" ]
extra = [ "tract-extra" ]
pulse-opl = [ "tract-pulse-opl" ]
pulse = [ "tract-pulse", "tract-pulse-opl" ]
tf = [ "tract-tensorflow", "tract-libcli/hir" ]
tflite = [ "tract-tflite" ]
transformers = [ "tract-transformers", "tract-libcli/transformers" ]
conform = [ "tract-tensorflow/conform" ]
multithread-mm = [ "tract-linalg/multithread-mm" ]
default = [
"onnx",
"tf",
"pulse",
"pulse-opl",
"tflite",
"transformers",
"extra",
]
apple-amx-ios = ["tract-linalg/apple-amx-ios"]
onnx = ["tract-onnx", "tract-libcli/hir", "tract-libcli/onnx"]
extra = ["tract-extra"]
pulse-opl = ["tract-pulse-opl"]
pulse = ["tract-pulse", "tract-pulse-opl"]
tf = ["tract-tensorflow", "tract-libcli/hir"]
tflite = ["tract-tflite"]
transformers = ["tract-transformers", "tract-libcli/transformers"]
conform = ["tract-tensorflow/conform"]
multithread-mm = ["tract-linalg/multithread-mm"]
inventory-registry = []
1 change: 1 addition & 0 deletions cli/src/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub fn handle(
limits.warmup(&params.req_runnable()?, &inputs)?;

let (iters, dur) = {
#[cfg(any(target_os = "linux", target_os = "windows"))]
let _profiler =
sub_matches.is_present("cuda-gpu-trace").then(cudarc::driver::safe::Profiler::new);
limits.bench(&params.req_runnable()?, &inputs)?
Expand Down
1 change: 1 addition & 0 deletions cli/src/dump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use tract_core::ops::matmul::optimized::{OptMatMul, ProtoFusedSpec};
use tract_core::ops::matmul::pack::DynPackedOpaqueFact;
use tract_core::ops::scan::OptScan;
#[allow(unused_imports)]
#[cfg(any(target_os = "linux", target_os = "windows"))]
use tract_cuda::utils::ensure_cuda_runtime_dependencies;
use tract_hir::internal::*;
use tract_itertools::Itertools;
Expand Down
3 changes: 2 additions & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ rustfft.workspace = true
smallvec.workspace = true
tract-linalg.workspace = true
tract-data.workspace = true
inventory.workspace = true
inventory = { workspace = true, optional = true }

[features]
default = [ ]
inventory-registry = [ "inventory" ]
complex = [ "tract-data/complex", "tract-linalg/complex" ]
blas = [ "cblas" ]
accelerate = [ "blas", "accelerate-src" ]
Expand Down
1 change: 1 addition & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ pub mod prelude {

/// This prelude is meant for code extending tract (like implementing new ops).
pub mod internal {
#[cfg(feature = "inventory-registry")]
pub extern crate inventory;
pub use crate::axes::{AxesMapping, Axis};
pub use crate::late_bind::*;
Expand Down
12 changes: 12 additions & 0 deletions core/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,33 @@ impl Debug for InventorizedRuntime {
}
}

#[cfg(feature = "inventory-registry")]
inventory::collect!(InventorizedRuntime);

/// Iterates over the runtimes submitted to the `inventory` registry, keeping
/// only those whose `check()` call succeeds.
///
/// Compiled only when the `inventory-registry` feature is enabled.
#[cfg(feature = "inventory-registry")]
pub fn runtimes() -> impl Iterator<Item = &'static dyn Runtime> {
    inventory::iter::<InventorizedRuntime>()
        // A failing check drops the entry; a passing one yields the wrapped runtime.
        .filter_map(|entry| entry.check().ok().map(|_| entry.0))
}

/// Fallback used when the `inventory-registry` feature is disabled: the only
/// runtime yielded is a single static `DefaultRuntime`.
#[cfg(not(feature = "inventory-registry"))]
pub fn runtimes() -> impl Iterator<Item = &'static dyn Runtime> {
    static DEFAULT: DefaultRuntime = DefaultRuntime;
    let only: &'static dyn Runtime = &DEFAULT;
    ::std::iter::once(only)
}

/// Looks up a runtime by name.
///
/// Returns the first runtime yielded by `runtimes()` whose `name()` compares
/// equal to `s`, or `None` when nothing matches.
pub fn runtime_for_name(s: &str) -> Option<&'static dyn Runtime> {
    for candidate in runtimes() {
        if candidate.name() == s {
            return Some(candidate);
        }
    }
    None
}

/// Registers a runtime value with the `inventory`-based runtime registry.
///
/// Usage: `register_runtime!(MyRuntime = MyRuntime);` — the left-hand side is
/// the type of the static, the right-hand side the expression initialising it.
///
/// With the `inventory-registry` feature enabled, the expansion declares a
/// static named `D` holding the value and submits an `InventorizedRuntime`
/// pointing at it. Without the feature it expands to a no-op constant.
///
/// Notes:
/// * The `cfg` attributes are part of the expansion, so they are resolved
///   against the features of the crate that *invokes* the macro; that crate
///   must therefore declare an `inventory-registry` feature of its own.
/// * `inventory` is reached through `$crate::internal`, so the invoking crate
///   does not need `inventory` in scope (same convention as the transform
///   registration macros).
/// * The static is always called `D`, so two invocations in the same module
///   would presumably collide — keep it to one per module.
#[macro_export]
macro_rules! register_runtime {
    ($type: ty= $val:expr) => {
        #[cfg(feature = "inventory-registry")]
        static D: $type = $val;
        #[cfg(feature = "inventory-registry")]
        $crate::internal::inventory::submit! { $crate::runtime::InventorizedRuntime(&D) }
        #[cfg(not(feature = "inventory-registry"))]
        const _: () = (); // no-op when inventory is disabled
    };
}

Expand Down
95 changes: 77 additions & 18 deletions core/src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,44 +102,103 @@ pub struct ModelTransformFactory {
pub builder: fn(spec: &str) -> TractResult<Option<Box<dyn ModelTransform>>>,
}

#[cfg(feature = "inventory-registry")]
inventory::collect!(ModelTransformFactory);

/// Registers a model transform that takes no parameters from its spec string.
///
/// * `$name` — the name stored in the factory's `name` field.
/// * `$type` — despite the identifier, this is an *expression*; it is boxed
///   and returned each time the builder runs.
///
/// With the `inventory-registry` feature enabled, a `ModelTransformFactory`
/// is submitted to `inventory`. Without it the macro expands to a no-op
/// constant, so nothing is registered.
///
/// The `cfg` attributes are emitted as part of the expansion, so they are
/// resolved against the features of the crate invoking the macro.
#[macro_export]
macro_rules! register_simple_model_transform {
($name: expr, $type: expr) => {
#[cfg(feature = "inventory-registry")]
$crate::internal::inventory::submit! {
$crate::transform::ModelTransformFactory {
name: $name,
// The spec argument is ignored: the transform has no options to parse.
builder: |_| Ok(Some(Box::new($type)))
}
}
#[cfg(not(feature = "inventory-registry"))]
const _: () = (); // no-op when inventory is disabled
};
}

pub fn get_transform(spec: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
for factory in inventory::iter::<ModelTransformFactory>() {
if spec.starts_with(factory.name) {
return (factory.builder)(spec);
/// Declare a set of transform factories once, and generate both
/// inventory registrations and a non-inventory fallback `get_transform`.
#[macro_export]
macro_rules! declare_transform_factories {
( $fname:ident, $( $(#[$m:meta])? ($name:expr, $builder:expr) ),+ $(,)? ) => {
$(
$(#[$m])?
#[cfg(feature = "inventory-registry")]
$crate::internal::inventory::submit! {
$crate::transform::ModelTransformFactory { name: $name, builder: $builder }
}
)+

#[cfg(not(feature = "inventory-registry"))]
pub fn $fname(
spec: &str,
) -> ::std::result::Result<
Option<Box<dyn $crate::transform::ModelTransform>>,
Box<dyn ::std::error::Error + Send + Sync + 'static>,
> {
$(
$(#[$m])?
if spec.starts_with($name) {
return ($builder)(spec);
}
)+
Ok(None)
}
}
Ok(None)
}

register_simple_model_transform!("softmax-fast-compact", SoftmaxFastCompact);
#[cfg(feature = "blas")]
register_simple_model_transform!("as-blas", AsBlas);
register_simple_model_transform!("block-quant", BlockQuantTransform);

inventory::submit! {
ModelTransformFactory {
name: "f32-to-f16",
builder: |spec| Ok(build_float_translator::<f32,f16>(spec.strip_prefix("f32-to-f16")))
/// Declares a list of `(name, Type)` transforms in one place.
///
/// Each `Type` is instantiated with `<Type>::default()`, so it must provide
/// a `default()` constructor (typically via `Default`).
///
/// The expansion contains:
/// * one `register_simple_model_transform!` call per entry (which submits to
///   `inventory` when the `inventory-registry` feature is enabled), and
/// * when `inventory-registry` is disabled, a `pub fn get_transform(spec)`
///   defined in the invoking module. It returns the first entry, in
///   declaration order, whose name is a prefix of `spec`, or `Ok(None)` if
///   none is. Its error type is a boxed `std::error::Error`.
///
/// The generated function is always named `get_transform`, so a module can
/// hold at most one invocation of this macro in that configuration.
#[macro_export]
macro_rules! declare_model_transforms {
( $( ($name:expr, $ty:ty) ),+ $(,)? ) => {
// Registry path: one factory submission per declared transform.
$(
$crate::register_simple_model_transform!($name, <$ty>::default());
)+

// Static path: a hand-rolled lookup over the same list.
#[cfg(not(feature = "inventory-registry"))]
pub fn get_transform(
spec: &str,
) -> ::std::result::Result<
Option<Box<dyn $crate::transform::ModelTransform>>,
Box<dyn ::std::error::Error + Send + Sync + 'static>,
> {
$(
// Prefix match, tried in declaration order; first hit wins.
if spec.starts_with($name) {
return Ok(Some(Box::new(<$ty>::default())));
}
)+
Ok(None)
}
}
}

inventory::submit! {
ModelTransformFactory {
name: "f16-to-f32",
builder: |spec| Ok(build_float_translator::<f16,f32>(spec.strip_prefix("f16-to-f32")))
/// Resolves a transform specification string to a model transform.
///
/// A factory is selected when `spec` starts with the factory's name; its
/// builder is then called with the complete `spec`. `Ok(None)` is returned
/// when no name matches.
pub fn get_transform(spec: &str) -> TractResult<Option<Box<dyn ModelTransform>>> {
    #[cfg(feature = "inventory-registry")]
    {
        // First registered factory whose name prefixes the spec wins.
        match inventory::iter::<ModelTransformFactory>().find(|f| spec.starts_with(f.name)) {
            Some(factory) => (factory.builder)(spec),
            None => Ok(None),
        }
    }
    #[cfg(not(feature = "inventory-registry"))]
    {
        // No registry available: use the statically generated core lookup and
        // convert its boxed error into this function's error type.
        lookup_core_transforms(spec).map_err(|err| anyhow::anyhow!(err))
    }
}

// Table of the transforms provided by this crate. The macro turns it into
// inventory submissions when `inventory-registry` is enabled, and into the
// `lookup_core_transforms` function (called by `get_transform`) when it is not.
// Names are matched as prefixes of the spec; the two float translators strip
// their own name from the spec and hand the remainder to
// `build_float_translator`.
declare_transform_factories! {
lookup_core_transforms,
("softmax-fast-compact", |_| Ok(Some(Box::new(SoftmaxFastCompact) as Box<dyn ModelTransform>))),
("block-quant", |_| Ok(Some(Box::new(BlockQuantTransform) as Box<dyn ModelTransform>))),
#[cfg(feature = "blas")]
("as-blas", |_| Ok(Some(Box::new(AsBlas) as Box<dyn ModelTransform>))),
("f32-to-f16", |spec: &str| Ok(build_float_translator::<f32,f16>(spec.strip_prefix("f32-to-f16")))),
("f16-to-f32", |spec: &str| Ok(build_float_translator::<f16,f32>(spec.strip_prefix("f16-to-f32")))),
}
2 changes: 2 additions & 0 deletions cuda/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ rand.workspace = true
[features]
cuda-12060 = ["cudarc/cuda-12060"]
default = ["cuda-12060"]
# Dummy feature for cfg used in macros that may expand here.
inventory-registry = []

[[bench]]
name = "cuda_flash"
Expand Down
7 changes: 7 additions & 0 deletions cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod utils;
pub use context::CUDA_STREAM;
use tract_core::internal::*;
use tract_core::transform::ModelTransform;
use tract_core::declare_transform_factories;
pub use transform::CudaTransform;

use crate::utils::ensure_cuda_runtime_dependencies;
Expand Down Expand Up @@ -50,3 +51,9 @@ impl Runtime for CudaRuntime {
}

register_runtime!(CudaRuntime = CudaRuntime);

// Register CUDA-specific transforms via central macro (inventory/no-inventory).
declare_transform_factories! {
lookup_cuda_transforms,
("cuda-transform", |_| Ok(Some(Box::new(CudaTransform::default()) as Box<dyn ModelTransform>))),
}
15 changes: 8 additions & 7 deletions libcli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ license = "MIT OR Apache-2.0"
authors = ["Mathieu Poumeyrol <kali@zoy.org>"]
description = "Tiny, no-nonsense, self contained, TensorFlow and ONNX inference"
repository = "https://github.com/snipsco/tract"
keywords = [ "TensorFlow", "NeuralNetworks" ]
categories = [ "science" ]
keywords = ["TensorFlow", "NeuralNetworks"]
categories = ["science"]
edition = "2024"

[badges]
Expand All @@ -20,7 +20,6 @@ lazy_static.workspace = true
log.workspace = true
ndarray-npy.workspace = true
nu-ansi-term.workspace = true
cudarc.workspace = true
py_literal.workspace = true
rand.workspace = true
serde.workspace = true
Expand All @@ -29,18 +28,20 @@ tract-core.workspace = true
tract-hir.workspace = true
tract-onnx = { workspace = true, optional = true }
tract-tflite.workspace = true
tract-cuda.workspace = true
tract-gpu.workspace = true
tract-transformers = {workspace= true, optional = true}
tract-transformers = { workspace = true, optional = true }

[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
tract-metal = { workspace = true }

[target.'cfg(any(target_os = "linux", target_os = "windows"))'.dependencies]
cudarc.workspace = true
tract-cuda.workspace = true

[features]
default = ["transformers"]
# hir = ["tract-hir"]
hir = []
onnx = [ "tract-onnx" ]
complex = [ "tract-core/complex" ]
onnx = ["tract-onnx"]
complex = ["tract-core/complex"]
transformers = ["tract-transformers"]
Loading
Loading