24.6. Extending With New Model Architectures
- ModelBackend Trait Interface
- Model Registration and Detection
- Example Implementation: Qwen3
- Custom Layers and Attention Mechanisms
- Position Embedding Strategies
- Testing New Model Implementations
- Step-by-Step Guide to Adding a New Architecture
The ModelBackend trait defines the core interface that all model implementations must conform to in order to be integrated into the system. This trait ensures a consistent API for forward computation across different model architectures.
```rust
pub trait ModelBackend: Send {
    fn forward_layered(&mut self, input: &Tensor, position: usize) -> Result<Tensor, String>;
}
```

This trait requires implementation of a single method:
**forward_layered**
- `input`: A reference to a `Tensor` containing the input token embeddings
- `position`: The current position in the sequence (used for KV cache management and position embeddings)
- Returns a `Result<Tensor, String>` representing the output logits or an error message
The trait is designed to support stateful models that maintain internal caches (e.g., KV caches for autoregressive generation). Implementations are expected to handle incremental decoding where the model processes one token at a time during generation.
The AnyModel wrapper struct provides a type-erased interface to any model that implements ModelBackend, allowing the system to work uniformly with different model types through dynamic dispatch.
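The wrapper itself is not reproduced on this page; as a rough sketch (the real AnyModel in model.rs may differ in detail, and the `decode_step` helper is purely illustrative), type erasure plus incremental decoding might look like this:

```rust
use candle::Tensor;

use crate::models::common::model::ModelBackend;

// Illustrative sketch only: a type-erased handle over any ModelBackend.
pub struct AnyModel(Box<dyn ModelBackend>);

impl AnyModel {
    pub fn forward_layered(&mut self, input: &Tensor, position: usize) -> Result<Tensor, String> {
        // Dynamic dispatch to whichever concrete backend was loaded.
        self.0.forward_layered(input, position)
    }
}

// During generation the model sees one token per step; advancing `position`
// lets the backend extend its KV cache instead of recomputing the prefix.
fn decode_step(model: &mut AnyModel, token: &Tensor, position: usize) -> Result<Tensor, String> {
    model.forward_layered(token, position)
}
```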
Section sources
- model.rs
Model registration is handled through the architecture detection system in registry.rs. The system uses metadata from model files to automatically detect the appropriate architecture type and instantiate the correct model implementation.
```rust
pub enum ArchKind {
    Qwen3,
}

pub fn detect_arch(metadata: &HashMap<String, candle::quantized::gguf_file::Value>) -> Option<ArchKind> {
    for (_k, v) in metadata.iter() {
        if let Ok(s) = v.to_string() {
            if s.to_lowercase().contains("qwen") {
                return Some(ArchKind::Qwen3);
            }
        }
    }
    None
}
```

**ArchKind**
An enumeration that lists all supported model architectures. Currently only Qwen3 is supported, but new variants should be added when implementing additional architectures.
**detect_arch**
A heuristic-based function that examines GGUF metadata to identify the model architecture. It performs case-insensitive string matching on metadata values to detect architecture indicators.
To register a new model architecture:
- Add a new variant to the `ArchKind` enum
- Extend the `detect_arch` function with detection logic for the new architecture
- Ensure the detection is reliable across different model variants and quantization levels
The registration system enables automatic model loading without requiring user specification of the architecture type, improving usability.
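As a hedged sketch of how detection can drive instantiation (the module paths follow the layout shown on this page, and the error text and loader wiring are illustrative, not taken from the source):

```rust
use std::io::{Read, Seek};

use candle::Device;
use candle::quantized::gguf_file::Content;

use crate::models::common::model::ModelBackend;
use crate::models::registry::{detect_arch, ArchKind};

// Sketch: map the detected architecture to the matching constructor.
fn load_model<R: Read + Seek>(
    content: Content,
    reader: &mut R,
    device: &Device,
    context_length: usize,
) -> Result<Box<dyn ModelBackend>, String> {
    let arch = detect_arch(&content.metadata);
    match arch {
        Some(ArchKind::Qwen3) => {
            let model = crate::models::qwen3::ModelWeights::from_gguf(
                content, reader, device, context_length, false,
            )?;
            Ok(Box::new(model))
        }
        None => Err("unrecognized model architecture".to_string()),
    }
}
```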
Section sources
- registry.rs
The Qwen3 implementation serves as a template for adding new model architectures. It demonstrates how to wrap an existing quantized model implementation and conform it to the ModelBackend trait.
```rust
impl crate::models::common::model::ModelBackend for ModelWeights {
    fn forward_layered(&mut self, input: &Tensor, position: usize) -> Result<Tensor, String> {
        self.inner.forward(input, position).map_err(|e| e.to_string())
    }
}
```

**ModelWeights**
The wrapper struct that contains the actual Qwen3 model from the candle-transformers crate. It provides a simplified interface to the underlying implementation.
**from_gguf**
The constructor method that loads a model from a GGUF file:
```rust
pub fn from_gguf<R: Read + Seek>(
    content: Content,
    reader: &mut R,
    device: &Device,
    _context_length: usize,
    _flag: bool,
) -> Result<Self, String>
```

Key aspects of the Qwen3 implementation:
- Delegates to `candle_transformers::models::quantized_qwen3::ModelWeights`
- Handles device placement and quantized tensor loading
- Implements the required `forward_layered` method by delegating to the inner model's `forward` method
- Translates errors from the underlying library into string messages
The implementation shows how to bridge between the system's expected interface and existing model implementations, minimizing duplication of low-level logic.
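The constructor body is not shown above; a minimal sketch of the delegation, assuming the upstream `from_gguf` takes the GGUF content, reader, and device (as candle-transformers' quantized models generally do), would be:

```rust
use std::io::{Read, Seek};

use candle::Device;
use candle::quantized::gguf_file::Content;

impl ModelWeights {
    pub fn from_gguf<R: Read + Seek>(
        content: Content,
        reader: &mut R,
        device: &Device,
        _context_length: usize,
        _flag: bool,
    ) -> Result<Self, String> {
        // Delegate tensor loading and device placement to candle-transformers,
        // translating its candle::Error into the String this crate expects.
        let inner = candle_transformers::models::quantized_qwen3::ModelWeights::from_gguf(
            content, reader, device,
        )
        .map_err(|e| e.to_string())?;
        Ok(Self { inner })
    }
}
```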
Section sources
- qwen3.rs
- quantized_qwen3.rs
When implementing new architectures with custom layers or attention mechanisms, you can follow the patterns established in the Qwen3 implementation. The quantized_qwen3.rs file contains detailed implementations of various transformer components that can serve as templates.
The AttentionWeights struct implements multi-head attention with key features:
```mermaid
classDiagram
    class AttentionWeights {
        +q_proj : QMatMul
        +k_proj : QMatMul
        +v_proj : QMatMul
        +o_proj : QMatMul
        +q_norm : RmsNorm
        +k_norm : RmsNorm
        +num_heads : usize
        +num_kv_heads : usize
        +kv_cache : KvCache
        +forward(x : &Tensor, mask : Option<&Tensor>, offset : usize) -> Result<Tensor>
    }
    class RotaryEmbedding {
        +sin : Tensor
        +cos : Tensor
        +apply(q : &Tensor, k : &Tensor, offset : usize) -> Result<(Tensor, Tensor)>
    }
    AttentionWeights --> RotaryEmbedding : "uses"
    AttentionWeights --> KvCache : "contains"
```
Diagram sources
- quantized_qwen3.rs
Key components:
- Q/K/V Projections: Quantized matrix multiplications for query, key, and value transformations
- RMS Normalization: Applied to queries and keys before attention computation
- KV Cache: Stores past key and value tensors for efficient autoregressive generation
- RoPE (Rotary Position Embeddings): Applied to queries and keys to incorporate positional information
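One detail worth noting is grouped-query attention: when `num_kv_heads` is smaller than `num_heads`, the cached keys and values must be repeated across each query group before the attention product. A self-contained sketch of that bookkeeping in candle (not the crate's actual code):

```rust
use candle::{DType, Device, Result, Tensor};

/// Repeat KV heads so several query heads can share one K/V head.
/// Standalone sketch of the standard GQA expansion, not upstream code.
fn repeat_kv(x: &Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        return Ok(x.clone());
    }
    let (b, kv_heads, seq, head_dim) = x.dims4()?;
    x.unsqueeze(2)?                                   // (b, kv_heads, 1, seq, hd)
        .expand((b, kv_heads, n_rep, seq, head_dim))? // broadcast over the group
        .reshape((b, kv_heads * n_rep, seq, head_dim))
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let k = Tensor::zeros((1, 4, 10, 64), DType::F32, &dev)?; // 4 KV heads
    let k = repeat_kv(&k, 4)?; // 16 query heads share them in groups of 4
    assert_eq!(k.dims(), &[1, 16, 10, 64]);
    Ok(())
}
```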
The MlpWeights struct implements the feed-forward network with gated activation:
```rust
struct MlpWeights {
    gate_proj: QMatMul,
    up_proj: QMatMul,
    down_proj: QMatMul,
    act_fn: Activation,
}
```

This implements a SwiGLU activation scheme where:
- `gate_proj` and `up_proj` compute the gate and up projections
- Element-wise multiplication creates the gated activation
- `down_proj` maps back to the output dimension
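Putting those pieces together, the forward pass reduces to a few quantized matmuls. The sketch below assumes SiLU as the configured `act_fn`; the real `MlpWeights::forward` may differ in detail:

```rust
use candle::quantized::QMatMul;
use candle::{Module, Result, Tensor};

// Sketch of a SwiGLU feed-forward pass: down(silu(gate(x)) * up(x)).
fn mlp_forward(
    gate_proj: &QMatMul,
    up_proj: &QMatMul,
    down_proj: &QMatMul,
    x: &Tensor,
) -> Result<Tensor> {
    let gate = gate_proj.forward(x)?.silu()?; // SiLU-activated gate branch
    let up = up_proj.forward(x)?;             // linear "up" branch
    down_proj.forward(&(gate * up)?)          // elementwise gating, then project down
}
```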
When creating custom layers:
- Define a struct to hold the layer's weights and parameters
- Implement a `new` constructor that loads weights from the GGUF file
- Implement a `forward` method that performs the computation
- Use appropriate normalization, activation functions, and caching mechanisms
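A minimal skeleton following those four steps (the tensor name "my_layer.weight" is illustrative, not a real GGUF key in this project):

```rust
use candle::quantized::{gguf_file, QMatMul};
use candle::{Device, Module, Result, Tensor};

// Hypothetical custom layer: one quantized projection.
struct MyLayer {
    proj: QMatMul,
}

impl MyLayer {
    fn new<R: std::io::Read + std::io::Seek>(
        content: &gguf_file::Content,
        reader: &mut R,
        device: &Device,
    ) -> Result<Self> {
        // Load the quantized tensor by name from the GGUF file.
        let w = content.tensor(reader, "my_layer.weight", device)?;
        Ok(Self { proj: QMatMul::from_qtensor(w)? })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.proj.forward(x)
    }
}
```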
Section sources
- quantized_qwen3.rs
The system implements Rotary Position Embeddings (RoPE) for position encoding, which is particularly effective for transformer models. The RotaryEmbedding struct pre-computes sine and cosine frequency components for efficient application during inference.
```rust
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}
```

**new**
Constructs rotary embeddings with:
- dtype: Data type for the embeddings
- head_dim: Dimension of each attention head
- max_position_embeddings: Maximum sequence length supported
- rope_theta: Frequency base for position encoding
- dev: Target device for tensor allocation
**apply**
Applies rotary embeddings to query and key tensors:
```rust
fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)>
```

The implementation follows the standard RoPE formula:
- Pre-computes frequency components based on `rope_theta`
- Generates position-specific sine and cosine tensors
- Applies the rotation to queries and keys using the `rope` function from `candle_nn`
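As a sketch of the precomputation step, following the standard RoPE formula rather than the exact upstream constructor:

```rust
use candle::{DType, Device, Result, Tensor};

// Build (max_pos, head_dim/2) sin/cos tables; a sketch, not the crate's code.
fn build_rope_tables(
    head_dim: usize,
    max_pos: usize,
    rope_theta: f32,
    dtype: DType,
    dev: &Device,
) -> Result<(Tensor, Tensor)> {
    // inv_freq[i] = theta^(-2i / head_dim) for i in 0..head_dim/2
    let inv_freq: Vec<f32> = (0..head_dim / 2)
        .map(|i| 1.0 / rope_theta.powf(2.0 * i as f32 / head_dim as f32))
        .collect();
    let inv_freq = Tensor::new(inv_freq.as_slice(), dev)?; // (head_dim/2,)
    let pos = Tensor::arange(0u32, max_pos as u32, dev)?.to_dtype(DType::F32)?;
    // Outer product: freqs[p, i] = p * inv_freq[i]
    let freqs = pos.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
    Ok((freqs.sin()?.to_dtype(dtype)?, freqs.cos()?.to_dtype(dtype)?))
}
```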
For models requiring different position embedding strategies:
- Create a new struct implementing the desired position encoding method
- Implement the same `apply` interface for consistency
- Integrate with the attention layer through the same interface
- Handle position offsets correctly for streaming/incremental inference
The current implementation supports variable context lengths through dynamic slicing of the pre-computed sinusoidal tables.
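Concretely, that slicing can be done with `narrow` on the position axis before rotating. A hedged sketch, assuming (batch, heads, seq, head_dim) inputs whose dtype matches the tables, as `candle_nn::rotary_emb::rope` expects:

```rust
use candle::{Result, Tensor};

// Slice the precomputed tables at the current offset, then rotate q and k.
fn apply_rope(
    sin: &Tensor,
    cos: &Tensor,
    q: &Tensor,
    k: &Tensor,
    offset: usize,
) -> Result<(Tensor, Tensor)> {
    let (_b, _h, seq_len, _d) = q.dims4()?;
    let sin = sin.narrow(0, offset, seq_len)?; // rows [offset, offset + seq_len)
    let cos = cos.narrow(0, offset, seq_len)?;
    let q = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
    let k = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
    Ok((q, k))
}
```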
Section sources
- quantized_qwen3.rs
Effective testing of new model implementations requires multiple levels of validation to ensure correctness, performance, and compatibility.
Test individual components in isolation:
- Layer forward passes with known inputs
- Weight loading from GGUF files
- Position embedding application
- Attention mechanism outputs
Example test structure:
```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_forward() {
        // Create test inputs and verify output shape and properties
    }

    #[test]
    fn test_rope_application() {
        // Verify rotary embeddings are applied correctly
    }
}
```

Test the complete model workflow:
- Model loading from GGUF file
- Forward pass with sample input
- Output shape and type verification
- Consistency across multiple forward passes
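One way to exercise this workflow without shipping a GGUF file is a stand-in backend with the right shapes. The mock below is purely illustrative and assumes a (batch, seq) token tensor as input; real integration tests would load an actual model instead:

```rust
#[cfg(test)]
mod integration_tests {
    use candle::{DType, Device, Tensor};

    use crate::models::common::model::ModelBackend;

    // Stand-in backend: returns zero logits of the expected shape.
    struct MockModel {
        vocab_size: usize,
    }

    impl ModelBackend for MockModel {
        fn forward_layered(&mut self, input: &Tensor, _position: usize) -> Result<Tensor, String> {
            let (b, t) = input.dims2().map_err(|e| e.to_string())?;
            Tensor::zeros((b, t, self.vocab_size), DType::F32, input.device())
                .map_err(|e| e.to_string())
        }
    }

    #[test]
    fn forward_pass_has_logit_shape() {
        let mut model = MockModel { vocab_size: 32 };
        let input = Tensor::zeros((1, 4), DType::U32, &Device::Cpu).unwrap();
        let logits = model.forward_layered(&input, 0).unwrap();
        assert_eq!(logits.dims(), &[1, 4, 32]);
    }
}
```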
Recommended validation strategies:
- Output Comparison: Compare outputs with reference implementations
- Shape Verification: Ensure tensors have expected dimensions
- Error Handling: Test edge cases and invalid inputs
- Performance Benchmarking: Measure inference speed and memory usage
**Model Loading**
- Valid GGUF file with correct architecture
- Corrupted or incomplete GGUF file
- Unsupported quantization format
- Missing required metadata
**Forward Pass**
- Single token input (streaming mode)
- Sequence input (batch mode)
- Maximum context length
- Position offset continuity
**Edge Cases**
- Empty input
- Invalid token IDs
- Device transfer (CPU/GPU)
- Multi-threaded access
Section sources
- quantized_qwen3.rs
Follow this comprehensive guide to add support for a new transformer architecture:
Add a new variant to the ArchKind enum in registry.rs:
```rust
pub enum ArchKind {
    Qwen3,
    NewArchitecture, // Add your new architecture
}
```

Extend the detect_arch function with detection logic:
```rust
pub fn detect_arch(metadata: &HashMap<String, Value>) -> Option<ArchKind> {
    for (_k, v) in metadata.iter() {
        if let Ok(s) = v.to_string() {
            if s.to_lowercase().contains("qwen") {
                return Some(ArchKind::Qwen3);
            }
            if s.to_lowercase().contains("newarch") {
                return Some(ArchKind::NewArchitecture);
            }
        }
    }
    None
}
```

Create a new .rs file in the models directory (e.g., newarch.rs) and implement the ModelBackend trait:
```rust
use std::io::{Read, Seek};

use candle::quantized::gguf_file::Content;
use candle::{Device, Tensor};

use crate::models::common::model::ModelBackend;

pub struct ModelWeights {
    inner: YourUnderlyingModelImplementation,
}

impl ModelWeights {
    pub fn from_gguf<R: Read + Seek>(
        content: Content,
        reader: &mut R,
        device: &Device,
        _context_length: usize,
        _flag: bool,
    ) -> Result<Self, String> {
        // Implement model loading logic here, e.g. by delegating to an
        // existing implementation and wrapping it in `Self { inner }`.
        todo!()
    }
}

impl ModelBackend for ModelWeights {
    fn forward_layered(&mut self, input: &Tensor, position: usize) -> Result<Tensor, String> {
        self.inner.forward(input, position).map_err(|e| e.to_string())
    }
}
```

Based on your architecture, implement:
- Embedding Layer: Token and position embeddings
- Transformer Blocks: Attention and feed-forward layers
- Normalization: RMSNorm or LayerNorm as required
- Head: Language modeling head for output
Follow the Qwen3 implementation as a template for organizing these components.
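For orientation, here is a typical pre-norm residual block, with `AttentionWeights` and `MlpWeights` standing in for the structs discussed earlier (they are assumed to be in scope; the exact wiring in quantized_qwen3.rs may differ):

```rust
use candle::{Module, Result, Tensor};
use candle_nn::RmsNorm;

// Sketch of one transformer block with pre-norm residual connections.
struct LayerWeights {
    attn_norm: RmsNorm,
    attn: AttentionWeights,
    ffn_norm: RmsNorm,
    mlp: MlpWeights,
}

impl LayerWeights {
    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        // Attention sub-block: normalize, attend, add residual.
        let h = self.attn_norm.forward(x)?;
        let h = (self.attn.forward(&h, mask, offset)? + x)?;
        // Feed-forward sub-block: normalize, MLP, add residual.
        let out = self.ffn_norm.forward(&h)?;
        self.mlp.forward(&out)? + h
    }
}
```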
Add the new model to the models/mod.rs file:
```rust
pub mod common;
pub mod qwen3;
pub mod newarch; // Add your model
pub mod registry;
```

Create comprehensive tests for your implementation:
- Unit tests for individual layers
- Integration tests for the complete model
- Performance benchmarks
- Edge case handling
Verify your implementation by:
- Loading a real model file
- Running forward passes with sample inputs
- Comparing outputs with reference implementations
- Testing in the full application context
Document your implementation with:
- Code comments explaining key design decisions
- Usage examples
- Performance characteristics
- Known limitations
This systematic approach ensures new architectures are integrated consistently and reliably into the system.
Section sources
- model.rs
- registry.rs
- qwen3.rs
- quantized_qwen3.rs
Referenced Files in This Document
- model.rs
- registry.rs
- qwen3.rs
- quantized_qwen3.rs