
Feature/tied decoders #38

Merged
curt-tigges merged 5 commits into main from feature/tied-decoders
Jul 1, 2025

Conversation

@curt-tigges
Owner

No description provided.

@curt-tigges
Owner Author

Tied Decoders Implementation Update

This document details the implementation of tied decoders for Cross-Layer Transcoders (CLTs), which significantly reduces the parameter count while maintaining model performance.

Overview

Tied decoders allow sharing of decoder weights across layers instead of having separate decoders for each (source, destination) layer pair. This can reduce the number of decoders from O(L²) to O(L), where L is the number of layers.

Configuration Changes

CLTConfig Updates (clt/config/clt_config.py)

Added new configuration parameters to support tied decoders:

# Tied decoder configuration
decoder_tying: Literal["none", "per_source", "per_target"] = "none"  # Decoder weight sharing strategy
enable_feature_offset: bool = False  # Enable per-feature bias (feature_offset)
enable_feature_scale: bool = False  # Enable per-feature scale (feature_scale)
skip_connection: bool = False  # Enable skip connection from input to output

Decoder Tying Options

  • "none" (default): Traditional untied decoders with separate decoder for each (source, destination) layer pair
  • "per_source": One decoder per source layer, shared across all destination layers
  • "per_target": One decoder per destination layer, shared across all source layers
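The three strategies above boil down to which index key selects a decoder weight for a given (source, destination) pair. A minimal sketch of that lookup (the helper name `decoder_index` is illustrative, not the repo's API):

```python
def decoder_index(src: int, tgt: int, tying: str) -> tuple:
    """Hypothetical helper: which decoder weight serves a (src, tgt) pair."""
    if tying == "none":
        return (src, tgt)   # one decoder per pair: L(L+1)/2 total
    elif tying == "per_source":
        return (src,)       # shared across all destination layers: L total
    elif tying == "per_target":
        return (tgt,)       # shared across all source layers: L total
    raise ValueError(f"Unknown decoder_tying: {tying!r}")
```

For example, under `per_source` the pairs (1, 1), (1, 5), and (1, 11) all resolve to the same decoder key `(1,)`.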

Additional Features

  • enable_feature_offset: Adds learnable per-feature bias terms (currently only supported for tied decoders)
  • enable_feature_scale: Adds learnable per-feature scaling (currently only supported for tied decoders)
  • skip_connection: Enables skip connections from source inputs to decoder outputs
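The per-feature offset and scale amount to an elementwise affine transform on the feature activations. A minimal sketch of how that transform behaves, assuming vectors of shape `(num_features,)` broadcast over the batch (function name is illustrative):

```python
import torch

def apply_feature_affine(a, feature_scale=None, feature_offset=None):
    """Hypothetical sketch: optional per-feature affine on activations
    (illustrative names, not the repo's exact API)."""
    if feature_scale is not None:
        a = a * feature_scale    # (num_features,) broadcast over batch
    if feature_offset is not None:
        a = a + feature_offset
    return a

acts = torch.ones(2, 4)
scale = torch.full((4,), 0.1)    # non-first layers initialize scale to 0.1
offset = torch.zeros(4)
out = apply_feature_affine(acts, scale, offset)
```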

Validation

Added validation for the new normalization_method field to ensure it is one of the valid options: ["none", "mean_std", "sqrt_d_model"].
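A check of this shape could be sketched as follows (the constant and function names are hypothetical, not necessarily what the config class uses):

```python
VALID_NORMALIZATION = ["none", "mean_std", "sqrt_d_model"]

def validate_normalization_method(method: str) -> None:
    # Hypothetical sketch of the validation described above.
    if method not in VALID_NORMALIZATION:
        raise ValueError(
            f"normalization_method must be one of {VALID_NORMALIZATION}, "
            f"got {method!r}"
        )
```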

Model Changes

CrossLayerTranscoder (clt/models/clt.py)

  1. Updated decode method signature:

    def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None) -> torch.Tensor:
    • Added source_inputs parameter to support skip connections
  2. Refactored skip connection handling:

    • Removed _apply_skip_connection method from CLT
    • Skip connections are now handled entirely within the decoder module
    • Source inputs are passed through to the decoder when skip connections are enabled
  3. Updated forward method:

    • Now passes source inputs to the decode method when skip connections are enabled
    • Creates a dictionary of source inputs for layers up to the current target layer
  4. Updated load_state_dict method:

    • Removed references to non-existent per_target_scale and per_target_bias parameters
    • Added handling for skip weights initialization
    • Maintains backward compatibility with old checkpoint formats
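Step 3 above, building the dictionary of source inputs for a given target layer, could be sketched like this (helper name is illustrative):

```python
def gather_source_inputs(inputs: dict, target_layer: int) -> dict:
    """Hypothetical sketch: collect inputs from every source layer up to
    and including the target layer, for passing into decode()."""
    return {src: x for src, x in inputs.items() if src <= target_layer}

layer_inputs = {0: "x0", 1: "x1", 2: "x2", 3: "x3"}
subset = gather_source_inputs(layer_inputs, 2)   # {0: "x0", 1: "x1", 2: "x2"}
```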

Decoder Module (clt/models/decoder.py)

  1. Skip Connection Implementation:

    • For tied decoders: One skip connection weight matrix per target layer
    • For untied decoders: One skip connection weight matrix per (source, target) pair
    • Skip weights initialized to zeros (not identity)
  2. Feature Affine Parameters:

    • feature_offset and feature_scale are only created for tied decoders
    • Indexed by target layer
    • Feature scale initialization:
      • First layer: initialized to ones
      • Other layers: initialized to 0.1 (to allow gradient flow)
  3. Updated decode method:

    • Added source_inputs parameter
    • Skip connections are applied after decoder transformations
    • For each source input, applies: source @ W_skip^T
    • Supports both tied and untied decoder architectures
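Putting the pieces together, a heavily simplified sketch of tied decoding for one target layer, with the skip contributions `source @ W_skip^T` summed onto the decoder output (all names are illustrative; the real module handles untied weights, affine parameters, and distributed layouts):

```python
import torch

def decode_with_skip(a, W_dec, source_inputs=None, W_skip=None):
    """Hypothetical sketch: decode feature activations from every source
    layer through one shared (tied) decoder for a target layer, then add
    skip contributions from the raw source inputs."""
    out = None
    for src, acts in a.items():
        term = acts @ W_dec.T              # shared decoder, shape (batch, d_model)
        out = term if out is None else out + term
    if source_inputs is not None and W_skip is not None:
        for src, x in source_inputs.items():
            out = out + x @ W_skip.T       # skip weights start at zero
    return out
```

With `W_skip` initialized to zeros, as described above, the skip path contributes nothing at initialization and only takes effect as training updates it.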

Skip Connection Details

The skip connection implementation follows these principles:

  1. Weight Matrices:

    • Shape: (d_model, d_model)
    • Initialized to zeros
    • Learnable parameters
  2. Application:

    • Applied after decoder transformation
    • Each source layer's input is transformed by its corresponding skip weight
    • All skip contributions are summed with the decoder output
  3. Memory Efficiency:

    • Tied decoders: L skip weight matrices
    • Untied decoders: L(L+1)/2 skip weight matrices
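The skip-matrix counts above can be checked with a one-liner (plain arithmetic, mirroring the text):

```python
def num_skip_matrices(L: int, tied: bool) -> int:
    # Count of (d_model, d_model) skip weight matrices, per the text above.
    return L if tied else L * (L + 1) // 2
```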

Performance Implications

Parameter Reduction

For a model with L layers:

  • Untied decoders: L(L+1)/2 decoders (e.g., 78 for GPT-2 small with 12 layers)
  • Per-source tying: L decoders (e.g., 12 for GPT-2 small)
  • Per-target tying: L decoders (e.g., 12 for GPT-2 small)

This represents an 84.6% reduction in decoder parameters for a 12-layer model.
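The arithmetic behind that figure, worked out for L = 12:

```python
# Decoder-count reduction for a 12-layer model (e.g. GPT-2 small).
L = 12
untied = L * (L + 1) // 2   # 78 decoders
tied = L                    # 12 decoders (per_source or per_target)
reduction = 1 - tied / untied
```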

Feature Affine Transformations

The feature offset and scale parameters add minimal overhead:

  • Per layer: 2 * num_features parameters
  • Total: 2 * L * num_features additional parameters
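Concretely, for the configuration used in the examples below (L = 12, num_features = 6144):

```python
# Feature affine overhead for L = 12, num_features = 6144.
L, num_features = 12, 6144
per_layer = 2 * num_features   # one offset vector + one scale vector
total = L * per_layer
```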

Usage Examples

Training with Per-Source Tied Decoders

python scripts/train_clt.py \
    --activation-source local_manifest \
    --activation-path ./activations/gpt2 \
    --output-dir ./output_tied \
    --model-name gpt2 \
    --num-features 6144 \
    --decoder-tying per_source \
    --enable-feature-scale \
    --skip-connection \
    --activation-fn batchtopk \
    --batchtopk-k 256 \
    --learning-rate 3e-4 \
    --training-steps 100000

Training with Per-Target Tied Decoders

python scripts/train_clt.py \
    --activation-source local_manifest \
    --activation-path ./activations/gpt2 \
    --output-dir ./output_tied \
    --model-name gpt2 \
    --num-features 6144 \
    --decoder-tying per_target \
    --enable-feature-offset \
    --enable-feature-scale \
    --activation-fn jumprelu \
    --learning-rate 3e-4 \
    --training-steps 100000

Backward Compatibility

The implementation maintains full backward compatibility:

  1. Default behavior: When decoder_tying="none", the model behaves identically to the previous implementation
  2. Checkpoint loading: Old checkpoints can be loaded into new models with appropriate conversion
  3. Configuration: Old configs without the new fields will use sensible defaults

Testing Updates

Fixed Tests (tests/unit/models/test_tied_decoders.py)

  1. Removed references to non-existent per_target_scale and per_target_bias parameters
  2. Updated skip weight initialization expectations (zeros instead of identity)
  3. Fixed feature affine parameter tests to only expect them for tied decoders
  4. Updated feature scale initialization expectations (0.1 for non-first layers)

Test Coverage

The implementation includes comprehensive tests for:

  • Decoder initialization (tied vs untied)
  • Skip connection functionality
  • Feature affine parameters
  • Decoding with tied decoders
  • Decoder norm computation
  • Backward compatibility
  • Checkpoint loading and conversion

Implementation Notes

  1. Tensor Parallelism: The implementation is compatible with distributed training and tensor parallelism
  2. Memory Efficiency: Tied decoders significantly reduce memory usage for large models
  3. Gradient Flow: Feature scale initialization to 0.1 (instead of 0) ensures proper gradient flow
  4. Flexibility: The architecture supports future extensions like per-layer tying strategies

Future Considerations

  1. Mixed Tying Strategies: Could support tying only certain layer ranges
  2. Dynamic Tying: Could implement learnable tying weights
  3. Hierarchical Tying: Could tie decoders based on layer similarity
  4. Feature Affine for Untied: Could extend feature affine support to untied decoders

@curt-tigges curt-tigges merged commit 0909d36 into main Jul 1, 2025
1 of 4 checks passed
