10. Universal Weights Loader
Changes Made
- Updated Core Architecture section to reflect the new ModelFactory pattern
- Added detailed explanation of the unified model building approach
- Updated API Interfaces to show the new ModelFactory-based loading pattern
- Revised Implementation Details to include ModelFactory integration
- Updated Integration Patterns to reflect the unified loading approach
- Added new section on ModelFactory architecture
- Updated Practical Examples to show current implementation
- Revised Troubleshooting Guide to include ModelFactory-related issues
- Updated Performance Considerations with new architectural insights
This document covers the following topics:
- Introduction
- Core Architecture
- ModelFactory Architecture
- API Interfaces
- Implementation Details
- Integration Patterns
- Practical Examples
- Troubleshooting Guide
- Performance Considerations
The Universal Weights Loader is a comprehensive system designed to load machine learning model weights from safetensors files, supporting both local and remote (Hugging Face Hub) sources. This module provides a unified interface for loading model weights in the safetensors format, handling various model architectures and device configurations. The system is built on the Candle framework and integrates with candle-transformers for model instantiation.
The loader supports multiple loading strategies including memory mapping for efficient resource usage, and provides automatic dtype selection based on the target device. It handles both sharded models (using index files) and single-file models, making it versatile for different deployment scenarios.
Section sources
- safetensors.rs
- safetensors.rs
The Universal Weights Loader follows a modular architecture with clear separation of concerns between file discovery, weight loading, and model instantiation. The system is organized into several key components that work together to provide a seamless model loading experience.
```mermaid
graph TD
A[Model Loading Request] --> B{Source Type}
B --> |Hugging Face Hub| C[Hub Model Loader]
B --> |Local Path| D[Local Model Loader]
C --> E[Hub Weights Utilities]
D --> F[Local Weights Utilities]
E --> G[Universal Weights Module]
F --> G
G --> H[VarBuilder Creation]
H --> I[ModelFactory]
I --> J[Architecture Detection]
J --> K[Model Builder]
K --> L[Final Model]
```
Diagram sources
- hub_safetensors.rs
- local_safetensors.rs
- weights.rs
Section sources
- hub_safetensors.rs
- local_safetensors.rs
The recent update introduces a unified ModelFactory pattern that consolidates model building for both GGUF and safetensors formats. This architectural change provides a consistent interface for model instantiation regardless of the source format.
The ModelFactory is implemented as a singleton that manages model builders for different architectures. It provides a centralized registry for model builders and handles the model creation process.
```rust
pub struct ModelFactory {
    builders: HashMap<ArchKind, ModelBuilder>,
}
```

The factory supports registration of model builders and provides methods for building models from different sources:
```rust
impl ModelFactory {
    pub fn register_builder(&mut self, builder: ModelBuilder) {
        let arch_kind = builder.arch_kind();
        self.builders.insert(arch_kind, builder);
    }

    pub fn build_from_gguf<R: Read + Seek>(
        &self,
        arch: ArchKind,
        content: candle::quantized::gguf_file::Content,
        reader: &mut R,
        device: &Device,
        context_length: usize,
        flag: bool,
    ) -> BuildResult<Box<dyn ModelBackend>> {
        // Implementation details
    }

    pub fn build_from_safetensors(
        &self,
        arch: ArchKind,
        filenames: &[String],
        config: &serde_json::Value,
        device: &Device,
        dtype: DType,
    ) -> BuildResult<Box<dyn ModelBackend>> {
        // Implementation details
    }
}
```

The system uses a global `ModelFactory` instance managed through a `OnceLock`:
```rust
static MODEL_FACTORY: OnceLock<ModelFactory> = OnceLock::new();

pub fn get_model_factory() -> &'static ModelFactory {
    MODEL_FACTORY.get_or_init(|| {
        let mut factory = ModelFactory::new();
        // Register Qwen3 builder
        factory.register_builder(crate::models::common::builder::ModelBuilder::Qwen3(Qwen3ModelBuilder::new()));
        factory
    })
}
```

Section sources
- builder.rs
- registry.rs
The Universal Weights Loader provides a comprehensive API for loading models from safetensors files, with functions designed for different use cases and source types. The API is organized into a hierarchical structure that separates concerns between source-specific loading and universal weight management.
The main entry points for model loading are provided through the model loading module, which now integrates with the ModelFactory pattern:
```rust
pub fn load_hub_safetensors_model(
    guard: &mut ModelState<Box<dyn ModelBackend + Send>>,
    repo_id: String,
    revision: Option<String>,
    context_length: usize,
    device_pref: Option<DevicePreference>,
) -> Result<(), String>

pub fn load_local_safetensors_model(
    guard: &mut ModelState<Box<dyn ModelBackend + Send>>,
    model_path: String,
    context_length: usize,
    device_pref: Option<DevicePreference>,
) -> Result<(), String>
```

The core weights module provides universal functions for handling safetensors files regardless of source:

```rust
pub fn hub_list_safetensors(api: &hf_hub::api::sync::ApiRepo) -> Result<Vec<String>, String>
pub fn local_list_safetensors<P: AsRef<Path>>(path: P) -> Result<Vec<String>, String>
pub fn build_varbuilder(safetensors_paths: &[String], device: &Device) -> Result<VarBuilder<'static>, String>
pub fn hub_cache_safetensors(api: &hf_hub::api::sync::ApiRepo, safetensors_files: &[String]) -> Result<Vec<String>, String>
pub fn validate_safetensors_files(safetensors_paths: &[String]) -> Result<(), String>
```

The underlying candle-core implementation provides fundamental operations for safetensors file handling:

```rust
pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>>
pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>>
pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(tensors: &HashMap<K, Tensor>, filename: P) -> Result<()>
```

Section sources
- safetensors.rs
- weights.rs
- safetensors.rs
The Universal Weights Loader implementation is built on a layered architecture that separates file discovery, weight loading, and model instantiation. This section details the key implementation aspects of each component.
The system uses a consistent approach for discovering safetensors files, whether from local paths or the Hugging Face Hub. The process follows a priority order (a minimal sketch follows this list):
- Look for `model.safetensors.index.json` to identify sharded models
- If no index file is found, look for a single `model.safetensors` file
- Validate that all referenced files exist and are accessible
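As an illustration of this priority order, here is a minimal sketch of the discovery step, assuming a local model directory and serde_json for parsing the index. The helper name and error strings are illustrative, not the module's actual internals:

```rust
use std::collections::BTreeSet;
use std::path::{Path, PathBuf};

/// Hypothetical helper mirroring the priority order described above:
/// prefer the shard index, fall back to a single model.safetensors file.
fn discover_safetensors(dir: &Path) -> Result<Vec<PathBuf>, String> {
    let index = dir.join("model.safetensors.index.json");
    if index.exists() {
        // Sharded model: collect the unique file names referenced by weight_map.
        let text = std::fs::read_to_string(&index).map_err(|e| e.to_string())?;
        let json: serde_json::Value =
            serde_json::from_str(&text).map_err(|e| e.to_string())?;
        let weight_map = json["weight_map"]
            .as_object()
            .ok_or("index.json has no weight_map")?;
        let files: BTreeSet<&str> = weight_map.values().filter_map(|v| v.as_str()).collect();
        let paths: Vec<PathBuf> = files.iter().map(|f| dir.join(f)).collect();
        // Validate that every referenced shard actually exists.
        for p in &paths {
            if !p.exists() {
                return Err(format!("missing shard: {}", p.display()));
            }
        }
        return Ok(paths);
    }
    // Single-file model.
    let single = dir.join("model.safetensors");
    if single.exists() {
        Ok(vec![single])
    } else {
        Err("No safetensors files found".to_string())
    }
}
```

For Hub repositories, the same priority is applied to the remote file listing via `hub_list_safetensors`, with the selected files cached locally before validation.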
```mermaid
flowchart TD
Start([Start]) --> CheckIndex["Check for model.safetensors.index.json"]
CheckIndex --> IndexExists{Index File Exists?}
IndexExists --> |Yes| ParseIndex["Parse index.json to get weight_map"]
ParseIndex --> ExtractFiles["Extract unique file names from weight_map"]
ExtractFiles --> ValidateFiles["Validate file existence"]
ValidateFiles --> ReturnFiles["Return full file paths"]
IndexExists --> |No| CheckSingle["Check for model.safetensors"]
CheckSingle --> SingleExists{Single File Exists?}
SingleExists --> |Yes| ReturnSingle["Return model.safetensors path"]
SingleExists --> |No| Error["Return error: No safetensors files found"]
ReturnFiles --> End([End])
ReturnSingle --> End
Error --> End
```
Diagram sources
- weights.rs
Section sources
- weights.rs
The system creates a VarBuilder instance from the discovered safetensors files, applying a unified dtype policy based on the target device:
```mermaid
sequenceDiagram
participant Client as "Client Application"
participant Loader as "Universal Weights Loader"
participant VarBuilder as "VarBuilder"
participant Device as "Target Device"
Client->>Loader : Request model loading
Loader->>Loader : Discover safetensors files
Loader->>Loader : Validate file existence
Loader->>Loader : Apply dtype policy
alt CUDA/Metal Device
Loader->>Loader : Select BF16 dtype
else CPU Device
Loader->>Loader : Select F32 dtype
end
Loader->>VarBuilder : Create from mmaped safetensors
VarBuilder-->>Loader : Return VarBuilder instance
Loader-->>Client : Return loaded model
```
Diagram sources
- weights.rs
- safetensors.rs
Section sources
- weights.rs
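A condensed sketch of what this step amounts to, assuming candle-nn's `VarBuilder::from_mmaped_safetensors`; the real `build_varbuilder` may differ in naming and error handling:

```rust
use candle::{DType, Device};
use candle_nn::VarBuilder;

/// Illustrative version of the unified dtype policy plus memory-mapped loading.
fn build_varbuilder_sketch(
    safetensors_paths: &[String],
    device: &Device,
) -> Result<VarBuilder<'static>, String> {
    // BF16 on accelerators, F32 on CPU (the policy described above).
    let dtype = if device.is_cuda() || device.is_metal() {
        DType::BF16
    } else {
        DType::F32
    };
    // Safety: the mapped files must remain valid and unmodified while in use.
    unsafe {
        VarBuilder::from_mmaped_safetensors(safetensors_paths, dtype, device)
            .map_err(|e| format!("failed to mmap safetensors: {e}"))
    }
}
```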
The system uses the ModelFactory to detect model architectures and instantiate the appropriate model builder:
```mermaid
classDiagram
class ModelFactory {
+builders : HashMap~ArchKind, ModelBuilder~
+register_builder(builder : ModelBuilder)
+build_from_safetensors(arch, filenames, config, device, dtype)
+build_from_gguf(arch, content, reader, device, context_length, flag)
+detect_gguf_arch(metadata)
+detect_config_arch(config)
}
class ModelBuilder {
<<enum>>
+from_gguf(content, reader, device, context_length, flag)
+from_varbuilder(vb, config, device, dtype)
+detect_gguf_arch(metadata)
+detect_config_arch(config)
+arch_kind()
}
class Qwen3ModelBuilder {
+from_gguf()
+from_varbuilder()
+detect_gguf_arch()
+detect_config_arch()
+arch_kind()
}
ModelFactory --> ModelBuilder : "contains"
Qwen3ModelBuilder --|> ModelBuilder : "implements"
ModelFactory --> Qwen3ModelBuilder : "uses"
```
Diagram sources
- builder.rs
- registry.rs
Section sources
- builder.rs
- registry.rs
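For illustration only, config-based architecture detection might look roughly like the following. The `model_type` field is the one mentioned in the Troubleshooting Guide; the helper name and the reduced `ArchKind` enum here are assumptions:

```rust
use serde_json::Value;

// ArchKind as referenced elsewhere in this document; the real registry
// may define more variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum ArchKind {
    Qwen3,
}

/// Hypothetical config-based detection: inspect model_type and map it to an ArchKind.
fn detect_arch_from_config_sketch(config: &Value) -> Option<ArchKind> {
    match config["model_type"].as_str()? {
        "qwen3" => Some(ArchKind::Qwen3),
        _ => None,
    }
}
```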
The Universal Weights Loader supports several integration patterns for different use cases, from direct API usage to custom model loading workflows.
For models hosted on the Hugging Face Hub, the loader follows a specific pattern (sketched in code after this list):
- Initialize the Hugging Face API client
- Resolve the repository and revision
- Discover safetensors files using `hub_list_safetensors`
- Cache the required files locally using `hub_cache_safetensors`
- Validate the downloaded files
- Create a VarBuilder with `build_varbuilder`
- Use ModelFactory to instantiate the model
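The following sketch stitches those steps together, assuming the hf_hub sync API for repository access; error conversions and the exact registry signatures are simplified:

```rust
use candle::{DType, Device};
use hf_hub::{api::sync::Api, Repo, RepoType};

use crate::core::weights;
use crate::models::registry::{detect_arch_from_config, get_model_factory};

fn load_from_hub_sketch(repo_id: &str, revision: &str) -> Result<(), String> {
    // 1. Initialize the API client and resolve the repository/revision.
    let api = Api::new().map_err(|e| e.to_string())?;
    let repo = api.repo(Repo::with_revision(
        repo_id.to_string(),
        RepoType::Model,
        revision.to_string(),
    ));

    // 2. Discover, cache, and validate the safetensors files.
    let files = weights::hub_list_safetensors(&repo)?;
    let local_paths = weights::hub_cache_safetensors(&repo, &files)?;
    weights::validate_safetensors_files(&local_paths)?;

    // 3. Build the VarBuilder and instantiate the model via the ModelFactory.
    let device = Device::cuda_if_available(0).map_err(|e| e.to_string())?;
    let _vb = weights::build_varbuilder(&local_paths, &device)?;
    let config_path = repo.get("config.json").map_err(|e| e.to_string())?;
    let config_text = std::fs::read_to_string(config_path).map_err(|e| e.to_string())?;
    let config: serde_json::Value =
        serde_json::from_str(&config_text).map_err(|e| e.to_string())?;
    let arch = detect_arch_from_config(&config).ok_or("Unsupported model architecture")?;
    match get_model_factory().build_from_safetensors(
        arch,
        &local_paths,
        &config,
        &device,
        if device.is_cuda() || device.is_metal() { DType::BF16 } else { DType::F32 },
    ) {
        Ok(_model) => Ok(()),
        Err(_e) => Err("model build failed".to_string()),
    }
}
```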
For locally stored models, the pattern is similar but with local file operations:
- Validate the model path exists and is accessible
- Discover safetensors files using `local_list_safetensors`
- Validate the local files
- Create a VarBuilder with `build_varbuilder`
- Use ModelFactory to instantiate the model
Both Hub and local loading follow a unified pattern through the core weights module and ModelFactory:
```mermaid
flowchart LR
A[Source Detection] --> B[File Discovery]
B --> C[File Validation]
C --> D[dtype Policy Application]
D --> E[VarBuilder Creation]
E --> F[Architecture Detection]
F --> G[ModelFactory.build_from_safetensors]
G --> H[Model Instantiation]
H --> I[Model Wrapping]
I --> J[State Update]
```
Section sources
- hub_safetensors.rs
- local_safetensors.rs
- weights.rs
This section provides practical examples of using the Universal Weights Loader for common scenarios.
Loading a model from the Hugging Face Hub:

```rust
use crate::api::model_loading::hub_safetensors;

let mut model_state = ModelState::default();
let result = hub_safetensors::load_hub_safetensors_model(
    &mut model_state,
    "Qwen/Qwen1.5-0.5B-Chat".to_string(),
    Some("main".to_string()),
    2048,
    None, // Use default device preference
);
match result {
    Ok(()) => println!("Model loaded successfully"),
    Err(e) => println!("Failed to load model: {}", e),
}
```

Loading a model from a local directory follows the same pattern:

```rust
use crate::api::model_loading::local_safetensors;

let mut model_state = ModelState::default();
let result = local_safetensors::load_local_safetensors_model(
    &mut model_state,
    "/path/to/local/model".to_string(),
    2048,
    None, // Use default device preference
);
match result {
    Ok(()) => println!("Model loaded successfully"),
    Err(e) => println!("Failed to load model: {}", e),
}
```

For lower-level control, the core weights module and ModelFactory can be used directly:

```rust
use candle::{DType, Device};

use crate::core::weights;
use crate::models::registry::{get_model_factory, detect_arch_from_config};
// List safetensors files from a local directory
let safetensors_files = weights::local_list_safetensors("/path/to/model")?;
// Validate the files
weights::validate_safetensors_files(&safetensors_files)?;
// Load and parse config.json
let config_json_str = std::fs::read_to_string("/path/to/model/config.json")?;
let config: serde_json::Value = serde_json::from_str(&config_json_str)?;
// Detect architecture
let arch = detect_arch_from_config(&config).ok_or("Unsupported model architecture")?;
// Create a VarBuilder with automatic dtype selection
let device = Device::cuda_if_available(0).unwrap();
let var_builder = weights::build_varbuilder(&safetensors_files, &device)?;
// Use ModelFactory to build the model
let model_factory = get_model_factory();
let model = model_factory.build_from_safetensors(
    arch,
    &safetensors_files,
    &config,
    &device,
    if device.is_cuda() || device.is_metal() { DType::BF16 } else { DType::F32 },
)?;
```

Section sources
- hub_safetensors.rs
- local_safetensors.rs
- weights.rs
This section addresses common issues encountered when using the Universal Weights Loader and provides solutions.
Issue: "No safetensors files found"
- Cause: The loader cannot find `model.safetensors.index.json` or `model.safetensors` in the specified location
- Solution:
  - Verify the model path is correct
  - Check that the model directory contains the required safetensors files
  - Ensure the model is in safetensors format (not GGUF or other formats)
Issue: "Safetensors file not found" during VarBuilder creation
- Cause: Files referenced in the index.json do not exist in the expected locations
- Solution:
  - Verify all files listed in `model.safetensors.index.json` exist
  - Check file permissions and accessibility
  - Ensure the model was downloaded completely
Issue: "Failed to parse config.json"
- Cause: The config.json file is malformed or not valid JSON
- Solution:
  - Validate the config.json file with a JSON validator
  - Check for encoding issues
  - Ensure the file is not corrupted
Issue: "Unsupported model architecture"
- Cause: The model architecture is not supported by the current implementation
- Solution:
  - Check the model's config.json for the model_type field
  - Verify the architecture is in the supported list (Qwen3, Llama, etc.)
  - Consider implementing a new model builder if needed
Issue: "No builder registered for architecture"
- Cause: The ModelFactory does not have a registered builder for the detected architecture
- Solution:
  - Check that the architecture is supported by the system
  - Verify that the appropriate model builder is registered in the ModelFactory
  - For Qwen3 models, ensure the Qwen3ModelBuilder is properly registered
General debugging tips (a small device-selection sketch follows this list):
- Enable verbose logging to see detailed loading steps
- Use the validation functions to check file integrity before loading
- Verify device availability and compatibility
- Check dtype compatibility between the model and target device
- Ensure sufficient memory is available for the model size
- Verify that the ModelFactory has the appropriate builders registered
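For the device and dtype checks in particular, here is a small sketch of explicit device selection with a CPU fallback, assuming candle's `Device` API; the loader's own selection logic may differ:

```rust
use candle::{DType, Device};

/// Pick a CUDA device when available, otherwise fall back to CPU,
/// and report the dtype the unified policy would choose for it.
fn pick_device_sketch() -> (Device, DType) {
    let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
    let dtype = if device.is_cuda() || device.is_metal() {
        DType::BF16
    } else {
        DType::F32
    };
    (device, dtype)
}
```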
Section sources
- weights.rs
- hub_safetensors.rs
- local_safetensors.rs
- builder.rs
The Universal Weights Loader incorporates several performance optimizations to ensure efficient model loading and memory usage.
The system uses memory mapping (mmap) for safetensors files, which provides several benefits:
- Reduces memory footprint by loading only needed portions of files
- Enables efficient random access to tensor data
- Avoids copying large amounts of data into memory
The loader applies a unified dtype policy based on the target device:
- CUDA/Metal devices: Uses BF16 for better performance and reduced memory usage
- CPU devices: Uses F32 for maximum compatibility and precision
This automatic dtype selection ensures optimal performance without requiring manual configuration.
For Hugging Face Hub models, the loader implements a caching strategy:
- Downloads required safetensors files to local cache
- Reuses cached files for subsequent loads
- Validates file integrity before use
The introduction of the ModelFactory pattern provides several performance and architectural benefits:
- Centralized model builder management
- Consistent interface for different model formats
- Efficient architecture detection
- Reduced code duplication
- Easier extension for new model architectures
The system is designed to support parallel operations where possible:
- Multiple safetensors files can be processed concurrently
- File downloads can occur in parallel
- Model loading can be performed asynchronously
These performance considerations make the Universal Weights Loader suitable for both development and production environments, providing efficient resource usage while maintaining flexibility and ease of use.
Section sources
- weights.rs
- safetensors.rs
- builder.rs
Referenced Files in This Document
- safetensors.rs - Core safetensors implementation
- safetensors.rs - Updated in recent commit with ModelFactory integration
- weights.rs - Updated in recent commit with unified loading approach
- registry.rs - Contains ModelFactory access and architecture detection
- builder.rs - New ModelFactory implementation for unified model building
- qwen3_builder.rs - Qwen3 model builder implementation
- model.rs - ModelBackend trait definition