
Extending with New Model Architectures

Table of Contents

  1. ModelBackend Trait Interface
  2. Model Registration and Detection
  3. Example Implementation: Qwen3
  4. Custom Layers and Attention Mechanisms
  5. Position Embedding Strategies
  6. Testing New Model Implementations
  7. Step-by-Step Guide to Adding a New Architecture

ModelBackend Trait Interface

The ModelBackend trait defines the core interface that all model implementations must conform to in order to be integrated into the system. This trait ensures a consistent API for forward computation across different model architectures.

pub trait ModelBackend: Send {
    fn forward_layered(&mut self, input: &Tensor, position: usize) -> Result<Tensor, String>;
}

This trait requires implementation of a single method:

:forward_layered

  • input: A reference to a Tensor containing the input token IDs for the current decoding step
  • position: The current position in the sequence (used for KV cache management and position embeddings)
  • Returns a Result<Tensor, String> representing the output logits or an error message

The trait is designed to support stateful models that maintain internal caches (e.g., KV caches for autoregressive generation). Implementations are expected to handle incremental decoding where the model processes one token at a time during generation.

The AnyModel wrapper struct provides a type-erased interface to any model that implements ModelBackend, allowing the system to work uniformly with different model types through dynamic dispatch.
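
For orientation, the wrapper can be pictured as a boxed trait object. The following is a minimal sketch, assuming AnyModel simply owns a Box<dyn ModelBackend> (the exact definition lives in model.rs):

pub struct AnyModel {
    inner: Box<dyn ModelBackend>,
}

impl AnyModel {
    pub fn new(inner: Box<dyn ModelBackend>) -> Self {
        Self { inner }
    }

    pub fn forward_layered(&mut self, input: &Tensor, position: usize) -> Result<Tensor, String> {
        // Dynamic dispatch to whichever architecture was loaded.
        self.inner.forward_layered(input, position)
    }
}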

Section sources

  • model.rs

Model Registration and Detection

Model registration is handled through the architecture detection system in registry.rs. The system uses metadata from model files to automatically detect the appropriate architecture type and instantiate the correct model implementation.

pub enum ArchKind {
    Qwen3,
}

pub fn detect_arch(metadata: &HashMap<String, candle::quantized::gguf_file::Value>) -> Option<ArchKind> {
    for (_k, v) in metadata.iter() {
        if let Ok(s) = v.to_string() {
            if s.to_lowercase().contains("qwen") {
                return Some(ArchKind::Qwen3);
            }
        }
    }
    None
}

:ArchKind
An enumeration that lists all supported model architectures. Currently only Qwen3 is supported, but new variants should be added when implementing additional architectures.

:detect_arch
A heuristic-based function that examines GGUF metadata to identify the model architecture. It performs case-insensitive string matching on metadata values to detect architecture indicators.

To register a new model architecture:

  1. Add a new variant to the ArchKind enum
  2. Extend the detect_arch function with detection logic for the new architecture
  3. Ensure the detection is reliable across different model variants and quantization levels

The registration system enables automatic model loading without requiring user specification of the architecture type, improving usability.
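
To illustrate how detection feeds into loading, a loader might dispatch on the detected ArchKind roughly as follows. This is a hedged sketch: the actual entry point is not shown in this document, and load_model is a hypothetical name.

pub fn load_model<R: Read + Seek>(
    content: Content,
    reader: &mut R,
    device: &Device,
    context_length: usize,
) -> Result<AnyModel, String> {
    // GGUF metadata lives on the parsed Content.
    match detect_arch(&content.metadata) {
        Some(ArchKind::Qwen3) => {
            let m = crate::models::qwen3::ModelWeights::from_gguf(
                content, reader, device, context_length, false,
            )?;
            Ok(AnyModel::new(Box::new(m)))
        }
        None => Err("unsupported or unrecognized architecture".to_string()),
    }
}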

Section sources

  • registry.rs

Example Implementation: Qwen3

The Qwen3 implementation serves as a template for adding new model architectures. It demonstrates how to wrap an existing quantized model implementation and conform it to the ModelBackend trait.

impl crate::models::common::model::ModelBackend for ModelWeights {
    fn forward_layered(&mut self, input: &Tensor, position: usize) -> Result<Tensor, String> {
        self.inner.forward(input, position).map_err(|e| e.to_string())
    }
}

:ModelWeights
The wrapper struct that contains the actual Qwen3 model from the candle-transformers crate. It provides a simplified interface to the underlying implementation.

:from_gguf
The constructor method that loads a model from a GGUF file:

pub fn from_gguf<R: Read + Seek>(
    content: Content, 
    reader: &mut R, 
    device: &Device, 
    _context_length: usize, 
    _flag: bool
) -> Result<Self, String>
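
The body is a thin delegation. A sketch consistent with the signature above, assuming candle_transformers::models::quantized_qwen3::ModelWeights::from_gguf(content, reader, device) as the underlying loader:

pub fn from_gguf<R: Read + Seek>(
    content: Content,
    reader: &mut R,
    device: &Device,
    _context_length: usize,
    _flag: bool,
) -> Result<Self, String> {
    // Delegate loading to candle-transformers and surface any error as a
    // plain string, matching the ModelBackend error type.
    let inner = candle_transformers::models::quantized_qwen3::ModelWeights::from_gguf(
        content, reader, device,
    )
    .map_err(|e| e.to_string())?;
    Ok(Self { inner })
}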

Key aspects of the Qwen3 implementation:

  • Delegates to candle_transformers::models::quantized_qwen3::ModelWeights
  • Handles device placement and quantized tensor loading
  • Implements the required forward_layered method by delegating to the inner model's forward method
  • Translates errors from the underlying library into string messages

The implementation shows how to bridge between the system's expected interface and existing model implementations, minimizing duplication of low-level logic.

Section sources

  • qwen3.rs
  • quantized_qwen3.rs

Custom Layers and Attention Mechanisms

When implementing new architectures with custom layers or attention mechanisms, you can follow the patterns established in the Qwen3 implementation. The quantized_qwen3.rs file contains detailed implementations of various transformer components that can serve as templates.

Attention Implementation

The AttentionWeights struct implements multi-head attention with key features:

classDiagram
    class AttentionWeights {
        +q_proj : QMatMul
        +k_proj : QMatMul
        +v_proj : QMatMul
        +o_proj : QMatMul
        +q_norm : RmsNorm
        +k_norm : RmsNorm
        +num_heads : usize
        +num_kv_heads : usize
        +kv_cache : KvCache
        +forward(x : &Tensor, mask : Option<&Tensor>, offset : usize) -> Result<Tensor>
    }
    class RotaryEmbedding {
        +sin : Tensor
        +cos : Tensor
        +apply(q : &Tensor, k : &Tensor, offset : usize) -> Result<(Tensor, Tensor)>
    }
    AttentionWeights --> RotaryEmbedding : "uses"
    AttentionWeights --> KvCache : "contains"

Diagram sources

  • quantized_qwen3.rs

Key components (a code sketch follows this list):

  • Q/K/V Projections: Quantized matrix multiplications for query, key, and value transformations
  • RMS Normalization: Applied to queries and keys before attention computation
  • KV Cache: Stores past key and value tensors for efficient autoregressive generation
  • RoPE (Rotary Position Embeddings): Applied to queries and keys to incorporate positional information
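
Putting these pieces together, the attention forward pass can be sketched as follows. This is a simplified sketch, not the exact quantized_qwen3.rs code: the head_dim and rotary fields are assumed (only their owners appear in the diagram), repeat_kv is taken from candle_transformers::utils, and candle_nn::Module is assumed in scope for the norms.

impl AttentionWeights {
    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> candle::Result<Tensor> {
        let (b, l, _) = x.dims3()?;

        // Quantized projections, then split into heads: (b, l, h * d) -> (b, h, l, d).
        let q = self.q_proj.forward(x)?
            .reshape((b, l, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = self.k_proj.forward(x)?
            .reshape((b, l, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = self.v_proj.forward(x)?
            .reshape((b, l, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        // Per-head RMS normalization of queries and keys, then RoPE.
        let q = self.q_norm.forward(&q.contiguous()?)?;
        let k = self.k_norm.forward(&k.contiguous()?)?;
        let (q, k) = self.rotary.apply(&q, &k, offset)?;

        // Extend the KV cache; it returns all keys/values seen so far.
        let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;

        // Grouped-query attention: repeat KV heads to match the query heads.
        let n_rep = self.num_heads / self.num_kv_heads;
        let k = candle_transformers::utils::repeat_kv(k, n_rep)?;
        let v = candle_transformers::utils::repeat_kv(v, n_rep)?;

        // Scaled dot-product attention with an optional causal mask.
        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let mut attn = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
        if let Some(mask) = mask {
            attn = attn.broadcast_add(mask)?;
        }
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;

        // Merge heads back: (b, h, l, d) -> (b, l, h * d), then project out.
        let out = attn.matmul(&v)?
            .transpose(1, 2)?
            .reshape((b, l, self.num_heads * self.head_dim))?;
        self.o_proj.forward(&out)
    }
}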

MLP Implementation

The MlpWeights struct implements the feed-forward network with gated activation:

struct MlpWeights {
    gate_proj: QMatMul,
    up_proj: QMatMul,
    down_proj: QMatMul,
    act_fn: Activation,
}

This implements the SwiGLU activation scheme (sketched in code below):

  • gate_proj and up_proj compute the gate and up projections
  • act_fn (SiLU) is applied to the gate projection, which is then multiplied element-wise with the up projection
  • down_proj maps the result back to the hidden dimension
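
A minimal sketch of that forward pass, assuming candle's QMatMul and Activation types (with candle_nn::Module in scope for act_fn):

impl MlpWeights {
    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
        // SwiGLU: down( act(gate(x)) * up(x) )
        let gate = self.act_fn.forward(&self.gate_proj.forward(x)?)?;
        let up = self.up_proj.forward(x)?;
        self.down_proj.forward(&(gate * up)?)
    }
}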

When creating custom layers:

  1. Define a struct to hold the layer's weights and parameters
  2. Implement a new constructor that loads weights from the GGUF file
  3. Implement a forward method that performs the computation
  4. Use appropriate normalization, activation functions, and caching mechanisms

Section sources

  • quantized_qwen3.rs

Position Embedding Strategies

The system implements Rotary Position Embeddings (RoPE) for position encoding, which is particularly effective for transformer models. The RotaryEmbedding struct pre-computes sine and cosine frequency components for efficient application during inference.

struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

:new
Constructs rotary embeddings (sketched after the parameter list) with:

  • dtype: Data type for the embeddings
  • head_dim: Dimension of each attention head
  • max_position_embeddings: Maximum sequence length supported
  • rope_theta: Frequency base for position encoding
  • dev: Target device for tensor allocation
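
Putting those parameters together, a sketch of the precomputation (a simplified version, not the exact quantized_qwen3.rs code):

use candle::{DType, Device, Tensor};

impl RotaryEmbedding {
    fn new(dtype: DType, head_dim: usize, max_position_embeddings: usize, rope_theta: f32, dev: &Device) -> candle::Result<Self> {
        // inv_freq[j] = rope_theta^(-2j / head_dim) for j in 0..head_dim/2.
        let inv_freq: Vec<f32> = (0..head_dim)
            .step_by(2)
            .map(|i| 1.0 / rope_theta.powf(i as f32 / head_dim as f32))
            .collect();
        let inv_freq = Tensor::new(inv_freq, dev)?;
        let t = Tensor::arange(0u32, max_position_embeddings as u32, dev)?.to_dtype(DType::F32)?;
        // Outer product of positions and frequencies: (max_pos, head_dim/2).
        let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
        Ok(Self {
            sin: freqs.sin()?.to_dtype(dtype)?,
            cos: freqs.cos()?.to_dtype(dtype)?,
        })
    }
}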

:apply
Applies rotary embeddings to query and key tensors:

fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)>

The implementation follows the standard RoPE formula:

  • Pre-computes frequency components based on rope_theta
  • Generates position-specific sine and cosine tensors
  • Applies the rotation to queries and keys using the rope function from candle_nn
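
In code, the apply step can be sketched as follows, assuming candle_nn::rotary_emb::rope(xs, cos, sin) and query/key tensors shaped (batch, heads, seq_len, head_dim):

fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> candle::Result<(Tensor, Tensor)> {
    let (_b, _h, seq_len, _d) = q.dims4()?;
    // Slice the precomputed tables to the positions covered by this call;
    // this is what makes incremental decoding with an offset work.
    let cos = self.cos.narrow(0, offset, seq_len)?;
    let sin = self.sin.narrow(0, offset, seq_len)?;
    let q = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
    let k = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
    Ok((q, k))
}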

For models requiring different position embedding strategies:

  1. Create a new struct implementing the desired position encoding method
  2. Implement the same apply interface for consistency
  3. Integrate with the attention layer through the same interface
  4. Handle position offsets correctly for streaming/incremental inference

The current implementation supports variable context lengths through dynamic slicing of the pre-computed sinusoidal tables.

Section sources

  • quantized_qwen3.rs

Testing New Model Implementations

Effective testing of new model implementations requires multiple levels of validation to ensure correctness, performance, and compatibility.

Unit Testing

Test individual components in isolation:

  • Layer forward passes with known inputs
  • Weight loading from GGUF files
  • Position embedding application
  • Attention mechanism outputs

Example test structure:

#[cfg(test)]
mod tests {
    use super::*;
    
    #[test]
    fn test_attention_forward() {
        // Create test inputs and verify output shape and properties
    }
    
    #[test]
    fn test_rope_application() {
        // Verify rotary embeddings are applied correctly
    }
}
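
As a concrete example, a shape-preservation test for the rotary embedding (assuming the RotaryEmbedding sketch shown earlier):

#[test]
fn test_rope_preserves_shape() -> candle::Result<()> {
    let dev = Device::Cpu;
    let rope = RotaryEmbedding::new(DType::F32, 64, 512, 10_000.0, &dev)?;
    // (batch, heads, seq_len, head_dim)
    let q = Tensor::zeros((1, 8, 4, 64), DType::F32, &dev)?;
    let k = Tensor::zeros((1, 8, 4, 64), DType::F32, &dev)?;
    let (q_rot, k_rot) = rope.apply(&q, &k, 0)?;
    assert_eq!(q_rot.dims(), q.dims());
    assert_eq!(k_rot.dims(), k.dims());
    Ok(())
}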

Integration Testing

Test the complete model workflow (a sketch of such a test follows the list):

  1. Model loading from GGUF file
  2. Forward pass with sample input
  3. Output shape and type verification
  4. Consistency across multiple forward passes
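
A sketch along these lines; the fixture path is hypothetical and must point at a real quantized model file:

#[test]
fn test_load_and_forward() -> Result<(), String> {
    use crate::models::common::model::ModelBackend;

    let path = "tests/fixtures/qwen3-tiny.gguf"; // hypothetical fixture
    let mut file = std::fs::File::open(path).map_err(|e| e.to_string())?;
    let content = candle::quantized::gguf_file::Content::read(&mut file)
        .map_err(|e| e.to_string())?;
    let device = Device::Cpu;
    let mut model = ModelWeights::from_gguf(content, &mut file, &device, 4096, false)?;

    // Single-token forward pass at position 0, as in streaming generation.
    let input = Tensor::new(&[[1u32]], &device).map_err(|e| e.to_string())?;
    let logits = model.forward_layered(&input, 0)?;
    assert!(logits.dims().iter().all(|&d| d > 0));
    Ok(())
}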

Validation Strategy

  1. Output Comparison: Compare outputs with reference implementations
  2. Shape Verification: Ensure tensors have expected dimensions
  3. Error Handling: Test edge cases and invalid inputs
  4. Performance Benchmarking: Measure inference speed and memory usage

Recommended Test Cases

:Model Loading

  • Valid GGUF file with correct architecture
  • Corrupted or incomplete GGUF file
  • Unsupported quantization format
  • Missing required metadata

:Forward Pass

  • Single token input (streaming mode)
  • Sequence input (batch mode)
  • Maximum context length
  • Position offset continuity

:Edge Cases

  • Empty input
  • Invalid token IDs
  • Device transfer (CPU/GPU)
  • Multi-threaded access

Section sources

  • quantized_qwen3.rs

Step-by-Step Guide to Adding a New Architecture

Follow this comprehensive guide to add support for a new transformer architecture:

Step 1: Define the Architecture Type

Add a new variant to the ArchKind enum in registry.rs:

pub enum ArchKind {
    Qwen3,
    NewArchitecture,  // Add your new architecture
}

Step 2: Implement Architecture Detection

Extend the detect_arch function with detection logic:

pub fn detect_arch(metadata: &HashMap<String, Value>) -> Option<ArchKind> {
    for (_k, v) in metadata.iter() {
        if let Ok(s) = v.to_string() {
            let s = s.to_lowercase();
            if s.contains("qwen") {
                return Some(ArchKind::Qwen3);
            }
            // Match a metadata marker that is unique to the new architecture.
            if s.contains("newarch") {
                return Some(ArchKind::NewArchitecture);
            }
        }
    }
    None
}

Step 3: Create Model Wrapper

Create a new .rs file in the models directory (e.g., newarch.rs) and implement the ModelBackend trait:

use std::io::{Read, Seek};
use candle::quantized::gguf_file::Content;
use candle::{Device, Tensor};
use crate::models::common::model::ModelBackend;

pub struct ModelWeights {
    inner: YourUnderlyingModelImplementation, // placeholder for the wrapped model type
}

impl ModelWeights {
    pub fn from_gguf<R: Read + Seek>(
        content: Content,
        reader: &mut R,
        device: &Device,
        _context_length: usize,
        _flag: bool,
    ) -> Result<Self, String> {
        // Load tensors and metadata from `content`/`reader`, construct the
        // inner model on `device`, then return Ok(Self { inner }).
        todo!()
    }
}

impl ModelBackend for ModelWeights {
    fn forward_layered(&mut self, input: &Tensor, position: usize) -> Result<Tensor, String> {
        self.inner.forward(input, position).map_err(|e| e.to_string())
    }
}

Step 4: Implement Core Components

Based on your architecture, implement:

  1. Embedding Layer: Token and position embeddings
  2. Transformer Blocks: Attention and feed-forward layers
  3. Normalization: RMSNorm or LayerNorm as required
  4. Head: Language modeling head for output

Follow the Qwen3 implementation as a template for organizing these components.
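
For instance, a transformer block can be organized as below. This is a hypothetical skeleton (BlockWeights is an illustrative name; it reuses the AttentionWeights and MlpWeights sketches from earlier sections, with candle_nn::Module in scope):

struct BlockWeights {
    input_norm: RmsNorm,
    attn: AttentionWeights,
    post_attn_norm: RmsNorm,
    mlp: MlpWeights,
}

impl BlockWeights {
    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> candle::Result<Tensor> {
        // Pre-norm residual layout, as used by Qwen3-style models.
        let h = (x + self.attn.forward(&self.input_norm.forward(x)?, mask, offset)?)?;
        let out = (&h + self.mlp.forward(&self.post_attn_norm.forward(&h)?)?)?;
        Ok(out)
    }
}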

Step 5: Register the Model

Add the new model to the models/mod.rs file:

pub mod common;
pub mod qwen3;
pub mod newarch;  // Add your model
pub mod registry;

Step 6: Implement Testing

Create comprehensive tests for your implementation:

  1. Unit tests for individual layers
  2. Integration tests for the complete model
  3. Performance benchmarks
  4. Edge case handling

Step 7: Validate Implementation

Verify your implementation by:

  1. Loading a real model file
  2. Running forward passes with sample inputs
  3. Comparing outputs with reference implementations
  4. Testing in the full application context

Step 8: Documentation

Document your implementation with:

  • Code comments explaining key design decisions
  • Usage examples
  • Performance characteristics
  • Known limitations

This systematic approach ensures new architectures are integrated consistently and reliably into the system.

Section sources

  • model.rs
  • registry.rs
  • qwen3.rs
  • quantized_qwen3.rs

Referenced Files in This Document

  • model.rs
  • registry.rs
  • qwen3.rs
  • quantized_qwen3.rs
