Universal Weights Loader

Update Summary

Changes Made

  • Updated Core Architecture section to reflect the new ModelFactory pattern
  • Added detailed explanation of the unified model building approach
  • Updated API Interfaces to show the new ModelFactory-based loading pattern
  • Revised Implementation Details to include ModelFactory integration
  • Updated Integration Patterns to reflect the unified loading approach
  • Added new section on ModelFactory architecture
  • Updated Practical Examples to show current implementation
  • Revised Troubleshooting Guide to include ModelFactory-related issues
  • Updated Performance Considerations with new architectural insights

Table of Contents

  1. Introduction
  2. Core Architecture
  3. ModelFactory Architecture
  4. API Interfaces
  5. Implementation Details
  6. Integration Patterns
  7. Practical Examples
  8. Troubleshooting Guide
  9. Performance Considerations

Introduction

The Universal Weights Loader is a comprehensive system designed to load machine learning model weights from safetensors files, supporting both local and remote (Hugging Face Hub) sources. This module provides a unified interface for loading model weights in the safetensors format, handling various model architectures and device configurations. The system is built on the Candle framework and integrates with candle-transformers for model instantiation.

The loader supports multiple loading strategies including memory mapping for efficient resource usage, and provides automatic dtype selection based on the target device. It handles both sharded models (using index files) and single-file models, making it versatile for different deployment scenarios.

Section sources

  • safetensors.rs
  • safetensors.rs

Core Architecture

The Universal Weights Loader follows a modular architecture with clear separation of concerns between file discovery, weight loading, and model instantiation. The system is organized into several key components that work together to provide a seamless model loading experience.

mermaid
graph TD
A[Model Loading Request] --> B{Source Type}
B --> |Hugging Face Hub| C[Hub Model Loader]
B --> |Local Path| D[Local Model Loader]
C --> E[Hub Weights Utilities]
D --> F[Local Weights Utilities]
E --> G[Universal Weights Module]
F --> G
G --> H[VarBuilder Creation]
H --> I[ModelFactory]
I --> J[Architecture Detection]
J --> K[Model Builder]
K --> L[Final Model]

Diagram sources

  • hub_safetensors.rs
  • local_safetensors.rs
  • weights.rs

Section sources

  • hub_safetensors.rs
  • local_safetensors.rs

ModelFactory Architecture

The recent update introduces a unified ModelFactory pattern that consolidates model building for both GGUF and safetensors formats. This architectural change provides a consistent interface for model instantiation regardless of the source format.

ModelFactory Implementation

The ModelFactory is implemented as a singleton that manages model builders for different architectures. It provides a centralized registry for model builders and handles the model creation process.

pub struct ModelFactory {
    builders: HashMap<ArchKind, ModelBuilder>,
}

The factory supports registration of model builders and provides methods for building models from different sources:

impl ModelFactory {
    pub fn register_builder(&mut self, builder: ModelBuilder) {
        let arch_kind = builder.arch_kind();
        self.builders.insert(arch_kind, builder);
    }
    
    pub fn build_from_gguf<R: Read + Seek>(
        &self,
        arch: ArchKind,
        content: candle::quantized::gguf_file::Content,
        reader: &mut R,
        device: &Device,
        context_length: usize,
        flag: bool,
    ) -> BuildResult<Box<dyn ModelBackend>> {
        // Implementation details
    }
    
    pub fn build_from_safetensors(
        &self,
        arch: ArchKind,
        filenames: &[String],
        config: &serde_json::Value,
        device: &Device,
        dtype: DType,
    ) -> BuildResult<Box<dyn ModelBackend>> {
        // Implementation details
    }
}
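
The elided bodies follow the same pattern: the factory looks up the builder registered for the requested ArchKind and delegates construction to it. The sketch below is illustrative only; it simplifies the error type to String and assumes the VarBuilder is memory-mapped directly inside the factory, so the real BuildResult-based code in builder.rs may differ:

// Illustrative sketch only: the error type is simplified to String and the
// body is an assumption about how the factory dispatches, not the confirmed code.
impl ModelFactory {
    pub fn build_from_safetensors(
        &self,
        arch: ArchKind,
        filenames: &[String],
        config: &serde_json::Value,
        device: &Device,
        dtype: DType,
    ) -> Result<Box<dyn ModelBackend>, String> {
        // Look up the builder registered for this architecture.
        let builder = self
            .builders
            .get(&arch)
            .ok_or_else(|| "No builder registered for architecture".to_string())?;

        // Memory-map the safetensors files with the requested dtype.
        // SAFETY: the files must not be modified while they are mapped.
        let vb = unsafe {
            candle_nn::VarBuilder::from_mmaped_safetensors(filenames, dtype, device)
        }
        .map_err(|e| format!("Failed to mmap safetensors: {e}"))?;

        // Delegate model construction to the architecture-specific builder.
        builder.from_varbuilder(vb, config, device, dtype)
    }
}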

Global Factory Instance

The system uses a global ModelFactory instance managed through a OnceLock:

static MODEL_FACTORY: OnceLock<ModelFactory> = OnceLock::new();

pub fn get_model_factory() -> &'static ModelFactory {
    MODEL_FACTORY.get_or_init(|| {
        let mut factory = ModelFactory::new();
        
        // Register Qwen3 builder
        factory.register_builder(crate::models::common::builder::ModelBuilder::Qwen3(Qwen3ModelBuilder::new()));
        
        factory
    })
}
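
Because the factory is populated once on first access, supporting an additional architecture amounts to registering another builder inside this initializer. In the sketch below, the ModelBuilder::Llama variant and LlamaModelBuilder type are hypothetical and only illustrate the extension point:

MODEL_FACTORY.get_or_init(|| {
    let mut factory = ModelFactory::new();

    // Existing registration, as shown above.
    factory.register_builder(crate::models::common::builder::ModelBuilder::Qwen3(Qwen3ModelBuilder::new()));

    // Hypothetical: a second architecture would be registered the same way.
    // factory.register_builder(crate::models::common::builder::ModelBuilder::Llama(LlamaModelBuilder::new()));

    factory
})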

Section sources

  • builder.rs
  • registry.rs

API Interfaces

The Universal Weights Loader provides a comprehensive API for loading models from safetensors files, with functions designed for different use cases and source types. The API is organized into a hierarchical structure that separates concerns between source-specific loading and universal weight management.

Primary Loading Functions

The main entry points for model loading are provided through the model loading module, which now integrates with the ModelFactory pattern:

pub fn load_hub_safetensors_model(
    guard: &mut ModelState<Box<dyn ModelBackend + Send>>,
    repo_id: String,
    revision: Option<String>,
    context_length: usize,
    device_pref: Option<DevicePreference>,
) -> Result<(), String>

pub fn load_local_safetensors_model(
    guard: &mut ModelState<Box<dyn ModelBackend + Send>>,
    model_path: String,
    context_length: usize,
    device_pref: Option<DevicePreference>,
) -> Result<(), String>

Universal Weights Management

The core weights module provides universal functions for handling safetensors files regardless of source:

pub fn hub_list_safetensors(api: &hf_hub::api::sync::ApiRepo) -> Result<Vec<String>, String>
pub fn local_list_safetensors<P: AsRef<Path>>(path: P) -> Result<Vec<String>, String>
pub fn build_varbuilder(safetensors_paths: &[String], device: &Device) -> Result<VarBuilder<'static>, String>
pub fn hub_cache_safetensors(api: &hf_hub::api::sync::ApiRepo, safetensors_files: &[String]) -> Result<Vec<String>, String>
pub fn validate_safetensors_files(safetensors_paths: &[String]) -> Result<(), String>

Core safetensors Operations

The underlying candle-core implementation provides fundamental operations for safetensors file handling:

pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>>
pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>>
pub fn save<K: AsRef<str> + Ord + std::fmt::Display, P: AsRef<Path>>(tensors: &HashMap<K, Tensor>, filename: P) -> Result<()>
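
These candle-core functions can also be used on their own, for example to inspect a checkpoint's tensors before handing the files to the loader. A small standalone example (the crate is referenced here as candle_core; parts of the codebase alias it as candle):

use candle_core::{safetensors, Device, Result};

// Load every tensor from a single safetensors file onto the CPU and
// print its name, shape, and dtype.
fn inspect(path: &str) -> Result<()> {
    let tensors = safetensors::load(path, &Device::Cpu)?;
    for (name, tensor) in &tensors {
        println!("{name}: shape={:?} dtype={:?}", tensor.dims(), tensor.dtype());
    }
    Ok(())
}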

Section sources

  • safetensors.rs
  • weights.rs
  • safetensors.rs

Implementation Details

The Universal Weights Loader implementation is built on a layered architecture that separates file discovery, weight loading, and model instantiation. This section details the key implementation aspects of each component.

File Discovery and Validation

The system uses a consistent approach for discovering safetensors files, whether from local paths or the Hugging Face Hub. The process follows a priority order:

  1. Look for model.safetensors.index.json to identify sharded models
  2. If no index file is found, look for a single model.safetensors file
  3. Validate that all referenced files exist and are accessible

mermaid
flowchart TD
Start([Start]) --> CheckIndex["Check for model.safetensors.index.json"]
CheckIndex --> IndexExists{Index File Exists?}
IndexExists --> |Yes| ParseIndex["Parse index.json to get weight_map"]
ParseIndex --> ExtractFiles["Extract unique file names from weight_map"]
ExtractFiles --> ValidateFiles["Validate file existence"]
ValidateFiles --> ReturnFiles["Return full file paths"]
IndexExists --> |No| CheckSingle["Check for model.safetensors"]
CheckSingle --> SingleExists{Single File Exists?}
SingleExists --> |Yes| ReturnSingle["Return model.safetensors path"]
SingleExists --> |No| Error["Return error: No safetensors files found"]
ReturnFiles --> End([End])
ReturnSingle --> End
Error --> End

Diagram sources

  • weights.rs

Section sources

  • weights.rs
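
The same priority order can be expressed compactly in code. The following is a simplified sketch of the local-path case, assuming serde_json is used to read the weight_map from model.safetensors.index.json; the actual implementation in weights.rs may differ in details:

use std::collections::BTreeSet;
use std::path::Path;

// Simplified sketch of local safetensors discovery (not the actual weights.rs code).
fn discover_safetensors(dir: &Path) -> Result<Vec<String>, String> {
    let index = dir.join("model.safetensors.index.json");
    if index.exists() {
        // Sharded model: collect the unique shard names from the weight_map.
        let text = std::fs::read_to_string(&index).map_err(|e| e.to_string())?;
        let json: serde_json::Value = serde_json::from_str(&text).map_err(|e| e.to_string())?;
        let weight_map = json["weight_map"]
            .as_object()
            .ok_or("index.json has no weight_map")?;
        let files: BTreeSet<&str> = weight_map.values().filter_map(|v| v.as_str()).collect();
        return Ok(files
            .into_iter()
            .map(|f| dir.join(f).to_string_lossy().into_owned())
            .collect());
    }
    // Single-file model.
    let single = dir.join("model.safetensors");
    if single.exists() {
        return Ok(vec![single.to_string_lossy().into_owned()]);
    }
    Err("No safetensors files found".to_string())
}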

Weight Loading and VarBuilder Creation

The system creates a VarBuilder instance from the discovered safetensors files, applying a unified dtype policy based on the target device:

mermaid
sequenceDiagram
participant Client as "Client Application"
participant Loader as "Universal Weights Loader"
participant VarBuilder as "VarBuilder"
participant Device as "Target Device"
Client->>Loader : Request model loading
Loader->>Loader : Discover safetensors files
Loader->>Loader : Validate file existence
Loader->>Loader : Apply dtype policy
alt CUDA/Metal Device
Loader->>Loader : Select BF16 dtype
else CPU Device
Loader->>Loader : Select F32 dtype
end
Loader->>VarBuilder : Create from mmaped safetensors
VarBuilder-->>Loader : Return VarBuilder instance
Loader-->>Client : Return loaded model

Diagram sources

  • weights.rs
  • safetensors.rs

Section sources

  • weights.rs
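
In candle, the dtype policy and memory-mapped VarBuilder creation shown in the sequence diagram typically reduce to a few lines. A hedged sketch of what build_varbuilder is expected to do (the actual code in weights.rs may differ):

use candle_core::{DType, Device};
use candle_nn::VarBuilder;

// Sketch of the unified dtype policy plus memory-mapped VarBuilder creation.
fn build_varbuilder(paths: &[String], device: &Device) -> Result<VarBuilder<'static>, String> {
    // BF16 on CUDA/Metal, F32 on CPU, as described by the dtype policy.
    let dtype = if device.is_cuda() || device.is_metal() {
        DType::BF16
    } else {
        DType::F32
    };
    // SAFETY: the safetensors files must not be modified while mapped.
    unsafe { VarBuilder::from_mmaped_safetensors(paths, dtype, device) }
        .map_err(|e| format!("Failed to mmap safetensors: {e}"))
}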

Model Instantiation with ModelFactory

The system uses the ModelFactory to detect the model architecture and dispatch to the appropriate registered model builder, which instantiates the model:

mermaid
classDiagram
class ModelFactory {
+builders : HashMap~ArchKind, ModelBuilder~
+register_builder(builder : ModelBuilder)
+build_from_safetensors(arch, filenames, config, device, dtype)
+build_from_gguf(arch, content, reader, device, context_length, flag)
+detect_gguf_arch(metadata)
+detect_config_arch(config)
}
class ModelBuilder {
<<enum>>
+from_gguf(content, reader, device, context_length, flag)
+from_varbuilder(vb, config, device, dtype)
+detect_gguf_arch(metadata)
+detect_config_arch(config)
+arch_kind()
}
class Qwen3ModelBuilder {
+from_gguf()
+from_varbuilder()
+detect_gguf_arch()
+detect_config_arch()
+arch_kind()
}
ModelFactory --> ModelBuilder : "contains"
Qwen3ModelBuilder --|> ModelBuilder : "implements"
ModelFactory --> Qwen3ModelBuilder : "uses"

Diagram sources

  • builder.rs
  • registry.rs

Section sources

  • builder.rs
  • registry.rs
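
Architecture detection for safetensors checkpoints is driven by config.json; as noted in the troubleshooting guide, the model_type field is the relevant key. The sketch below shows what a config-based check such as detect_arch_from_config might look like; the real registry code may instead consult the registered builders' detect_config_arch methods, and only the Qwen3 mapping is confirmed:

// Illustrative sketch: only the qwen3 mapping is confirmed by the registered builders.
fn detect_arch_from_config(config: &serde_json::Value) -> Option<ArchKind> {
    match config.get("model_type")?.as_str()? {
        "qwen3" => Some(ArchKind::Qwen3),
        // "llama" => Some(ArchKind::Llama), // hypothetical additional arm
        _ => None,
    }
}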

Integration Patterns

The Universal Weights Loader supports several integration patterns for different use cases, from direct API usage to custom model loading workflows.

Hugging Face Hub Integration

For models hosted on the Hugging Face Hub, the loader follows a specific pattern (sketched in code after the list):

  1. Initialize the Hugging Face API client
  2. Resolve the repository and revision
  3. Discover safetensors files using hub_list_safetensors
  4. Cache the required files locally using hub_cache_safetensors
  5. Validate the downloaded files
  6. Create a VarBuilder with build_varbuilder
  7. Use ModelFactory to instantiate the model
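
A condensed sketch of these steps, using the hf_hub crate's synchronous API together with the helpers from the core weights module (error handling abbreviated; the function name and repo id are placeholders):

use candle_core::Device;
use hf_hub::{api::sync::Api, Repo, RepoType};

use crate::core::weights;

// Condensed sketch of the Hub loading steps described above.
fn load_from_hub(repo_id: &str, revision: &str, device: &Device) -> Result<(), String> {
    // 1-2. Initialize the API client and resolve the repository/revision.
    let api = Api::new().map_err(|e| e.to_string())?;
    let repo = api.repo(Repo::with_revision(
        repo_id.to_string(),
        RepoType::Model,
        revision.to_string(),
    ));
    // 3. Discover the safetensors files listed for the repository.
    let files = weights::hub_list_safetensors(&repo)?;
    // 4. Download the required files (or reuse the local cache).
    let local_paths = weights::hub_cache_safetensors(&repo, &files)?;
    // 5. Validate the downloaded files.
    weights::validate_safetensors_files(&local_paths)?;
    // 6. Create a VarBuilder with the unified dtype policy.
    let _vb = weights::build_varbuilder(&local_paths, device)?;
    // 7. Hand the files and config to the ModelFactory (see the Practical Examples section).
    Ok(())
}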

Local Model Integration

For locally stored models, the pattern is similar but with local file operations:

  1. Validate the model path exists and is accessible
  2. Discover safetensors files using local_list_safetensors
  3. Validate the local files
  4. Create a VarBuilder with build_varbuilder
  5. Use ModelFactory to instantiate the model

Unified Loading Pattern

Both Hub and local loading follow a unified pattern through the core weights module and ModelFactory:

mermaid
flowchart LR
A[Source Detection] --> B[File Discovery]
B --> C[File Validation]
C --> D[dtype Policy Application]
D --> E[VarBuilder Creation]
E --> F[Architecture Detection]
F --> G[ModelFactory.build_from_safetensors]
G --> H[Model Instantiation]
H --> I[Model Wrapping]
I --> J[State Update]

Section sources

  • hub_safetensors.rs
  • local_safetensors.rs
  • weights.rs

Practical Examples

This section provides practical examples of using the Universal Weights Loader for common scenarios.

Loading a Model from Hugging Face Hub

use crate::api::model_loading::hub_safetensors;

let mut model_state = ModelState::default();
let result = hub_safetensors::load_hub_safetensors_model(
    &mut model_state,
    "Qwen/Qwen1.5-0.5B-Chat".to_string(),
    Some("main".to_string()),
    2048,
    None, // Use default device preference
);

match result {
    Ok(()) => println!("Model loaded successfully"),
    Err(e) => println!("Failed to load model: {}", e),
}

Loading a Model from Local Path

use crate::api::model_loading::local_safetensors;

let mut model_state = ModelState::default();
let result = local_safetensors::load_local_safetensors_model(
    &mut model_state,
    "/path/to/local/model".to_string(),
    2048,
    None, // Use default device preference
);

match result {
    Ok(()) => println!("Model loaded successfully"),
    Err(e) => println!("Failed to load model: {}", e),
}

Manual Weight Loading with ModelFactory

use crate::core::weights;
use crate::models::registry::{get_model_factory, detect_arch_from_config};

// List safetensors files from a local directory
let safetensors_files = weights::local_list_safetensors("/path/to/model")?;

// Validate the files
weights::validate_safetensors_files(&safetensors_files)?;

// Load and parse config.json
let config_json_str = std::fs::read_to_string("/path/to/model/config.json")?;
let config: serde_json::Value = serde_json::from_str(&config_json_str)?;

// Detect architecture
let arch = detect_arch_from_config(&config).ok_or("Unsupported model architecture")?;

// Select the device; `?` keeps error handling consistent with the rest of the example
let device = Device::cuda_if_available(0)?;

// Optionally create a VarBuilder with automatic dtype selection (useful for manual
// tensor access); build_from_safetensors below takes the file list directly
let _var_builder = weights::build_varbuilder(&safetensors_files, &device)?;

// Use ModelFactory to build the model
let model_factory = get_model_factory();
let model = model_factory.build_from_safetensors(
    arch,
    &safetensors_files,
    &config,
    &device,
    if device.is_cuda() || device.is_metal() { DType::BF16 } else { DType::F32 }
)?;

Section sources

  • hub_safetensors.rs
  • local_safetensors.rs
  • weights.rs

Troubleshooting Guide

This section addresses common issues encountered when using the Universal Weights Loader and provides solutions.

Common Issues and Solutions

Issue: "No safetensors files found"

  • Cause: The loader cannot find model.safetensors.index.json or model.safetensors in the specified location
  • Solution:
    • Verify the model path is correct
    • Check that the model directory contains the required safetensors files
    • Ensure the model is in safetensors format (not GGUF or other formats)

Issue: "Safetensors file not found" during VarBuilder creation

  • Cause: Files referenced in the index.json do not exist in the expected locations
  • Solution:
    • Verify all files listed in model.safetensors.index.json exist
    • Check file permissions and accessibility
    • Ensure the model was downloaded completely

Issue: "Failed to parse config.json"

  • Cause: The config.json file is malformed or not valid JSON
  • Solution:
    • Validate the config.json file with a JSON validator
    • Check for encoding issues
    • Ensure the file is not corrupted

Issue: "Unsupported model architecture"

  • Cause: The model architecture is not supported by the current implementation
  • Solution:
    • Check the model's config.json for the model_type field
    • Verify the architecture is in the supported list (Qwen3, Llama, etc.)
    • Consider implementing a new model builder if needed

Issue: "No builder registered for architecture"

  • Cause: The ModelFactory does not have a registered builder for the detected architecture
  • Solution:
    • Check that the architecture is supported by the system
    • Verify that the appropriate model builder is registered in the ModelFactory
    • For Qwen3 models, ensure the Qwen3ModelBuilder is properly registered

Debugging Tips

  1. Enable verbose logging to see detailed loading steps
  2. Use the validation functions to check file integrity before loading (see the sketch after this list)
  3. Verify device availability and compatibility
  4. Check dtype compatibility between the model and target device
  5. Ensure sufficient memory is available for the model size
  6. Verify that the ModelFactory has the appropriate builders registered
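
Several of these tips can be folded into a small pre-flight check that runs before any weights are touched, using only the module's discovery and validation helpers (a sketch; the messages are illustrative):

use crate::core::weights;

// Pre-flight check before loading: discover, validate, and report files.
fn preflight(model_dir: &str) -> Result<(), String> {
    let files = weights::local_list_safetensors(model_dir)?;
    println!("Found {} safetensors file(s)", files.len());
    // Fails early with a descriptive error if any referenced shard is missing.
    weights::validate_safetensors_files(&files)?;
    // Confirm config.json is present and parseable before model construction.
    let config_path = std::path::Path::new(model_dir).join("config.json");
    let text = std::fs::read_to_string(&config_path)
        .map_err(|e| format!("Failed to read config.json: {e}"))?;
    let _config: serde_json::Value = serde_json::from_str(&text)
        .map_err(|e| format!("Failed to parse config.json: {e}"))?;
    Ok(())
}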

Section sources

  • weights.rs
  • hub_safetensors.rs
  • local_safetensors.rs
  • builder.rs

Performance Considerations

The Universal Weights Loader incorporates several performance optimizations to ensure efficient model loading and memory usage.

Memory Mapping

The system uses memory mapping (mmap) for safetensors files, which provides several benefits:

  • Reduces the memory footprint: file pages are mapped and paged in on demand rather than read up front
  • Enables efficient random access to tensor data
  • Avoids copying large amounts of data into process memory

dtype Optimization

The loader applies a unified dtype policy based on the target device:

  • CUDA/Metal devices: Uses BF16 for better performance and reduced memory usage
  • CPU devices: Uses F32 for maximum compatibility and precision

This automatic dtype selection ensures optimal performance without requiring manual configuration.

Caching Strategy

For Hugging Face Hub models, the loader implements a caching strategy:

  • Downloads required safetensors files to local cache
  • Reuses cached files for subsequent loads
  • Validates file integrity before use

ModelFactory Benefits

The introduction of the ModelFactory pattern provides several performance and architectural benefits:

  • Centralized model builder management
  • Consistent interface for different model formats
  • Efficient architecture detection
  • Reduced code duplication
  • Easier extension for new model architectures

Parallel Operations

The system is designed to support parallel operations where possible (a brief sketch follows the list):

  • Multiple safetensors files can be processed concurrently
  • File downloads can occur in parallel
  • Model loading can be performed asynchronously
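
As one example of the first point, per-file validation can be fanned out across scoped threads using only the standard library; this is an illustrative sketch, not necessarily how the loader parallelizes the step internally:

use std::thread;

use crate::core::weights;

// Illustrative sketch: validate several safetensors files concurrently.
fn validate_in_parallel(paths: &[String]) -> Result<(), String> {
    thread::scope(|s| {
        // Spawn one scoped validation task per file.
        let handles: Vec<_> = paths
            .iter()
            .map(|p| s.spawn(move || weights::validate_safetensors_files(std::slice::from_ref(p))))
            .collect();
        // Propagate the first validation error, if any.
        for handle in handles {
            handle
                .join()
                .map_err(|_| "validation thread panicked".to_string())??;
        }
        Ok(())
    })
}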

These performance considerations make the Universal Weights Loader suitable for both development and production environments, providing efficient resource usage while maintaining flexibility and ease of use.

Section sources

  • weights.rs
  • safetensors.rs
  • builder.rs

Referenced Files in This Document

  • safetensors.rs - Core safetensors implementation
  • safetensors.rs - Updated in recent commit with ModelFactory integration
  • weights.rs - Updated in recent commit with unified loading approach
  • registry.rs - Contains ModelFactory access and architecture detection
  • builder.rs - New ModelFactory implementation for unified model building
  • qwen3_builder.rs - Qwen3 model builder implementation
  • model.rs - ModelBackend trait definition
