diff --git a/README.md b/README.md index 6509556..5e95b7b 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ This library is intended for the training and analysis of cross-layer sparse cod A Cross-Layer Transcoder (CLT) is a multi-layer dictionary learning model designed to extract sparse, interpretable features from transformers, using an encoder for each layer and a decoder for each (source layer, destination layer) pair (e.g., 12 encoders and 78 decoders for `gpt2-small`). This implementation focuses on the core functionality needed to train and use CLTs, leveraging `nnsight` for model introspection and `datasets` for data handling. +The library now supports **tied decoders**, which can significantly reduce the number of parameters by sharing decoder weights across layers. Instead of training separate decoders for each (source, destination) pair, tied decoders use either: +- **Per-source tying**: One decoder per source layer, shared across all destination layers +- **Per-target tying**: One decoder per destination layer, shared across all source layers + Training a CLT involves the following steps: 1. Pre-generate activations with `scripts/generate_activations` (though an implementation of `StreamingActivationStore` is on the way). 2. Train a CLT (start with an expansion factor of at least `32`) using this data. Metrics can be logged to WandB. NMSE should get below `0.25`, or ideally even below `0.10`. As mentioned above, I recommend `BatchTopK` training, and suggest keeping `K` low--`200` is a good place to start. @@ -85,6 +89,16 @@ Key configuration parameters are mapped to config classes via script arguments: - `relu`: Standard ReLU activation. - `batchtopk`: Selects a global top K features across all tokens in a batch, based on pre-activation values. The 'k' can be an absolute number or a fraction. This is often used as a training-time differentiable approximation that can later be converted to `jumprelu`. 
- `topk`: Selects top K features per token (row-wise top-k). + + **Decoder Tying Options** (`--decoder-tying`): + - `none` (default): Traditional untied decoders - separate decoder for each (source, destination) layer pair + - `per_source`: Share decoder weights per source layer - each source layer has one decoder used for all destinations + - `per_target`: Share decoder weights per destination layer - each destination layer has one decoder that combines features from all source layers + + **Additional Tied Decoder Features**: + - `--enable-feature-offset`: Add learnable per-feature bias terms + - `--enable-feature-scale`: Add learnable per-feature scaling + - `--skip-connection`: Enable skip connections from source inputs to decoder outputs - **TrainingConfig**: `--learning-rate`, `--training-steps`, `--train-batch-size-tokens`, `--activation-source`, `--activation-path` (for `local_manifest`), remote config fields (for `remote`, e.g. `--server-url`, `--dataset-id`), `--normalization-method`, `--sparsity-lambda`, `--preactivation-coef`, `--optimizer`, `--lr-scheduler`, `--log-interval`, `--eval-interval`, `--checkpoint-interval`, `--dead-feature-window`, WandB settings (`--enable-wandb`, `--wandb-project`, etc.). ### Single GPU Training Examples @@ -139,6 +153,38 @@ python scripts/train_clt.py \\ # Add other arguments as needed ``` +**Example: Training with Tied Decoders** + +Tied decoders can significantly reduce the parameter count while maintaining performance. 
Here's an example using per-source tying: + +```bash +python scripts/train_clt.py \ + --activation-source local_manifest \ + --activation-path ./tutorial_activations/gpt2/pile-uncopyrighted_train \ + --output-dir ./clt_output_tied \ + --model-name gpt2 \ + --num-features 6144 \ + --decoder-tying per_source \ + --enable-feature-scale \ + --skip-connection \ + --activation-fn batchtopk \ + --batchtopk-k 256 \ + --learning-rate 3e-4 \ + --training-steps 100000 \ + --train-batch-size-tokens 8192 \ + --sparsity-lambda 1e-3 \ + --log-interval 100 \ + --eval-interval 1000 \ + --checkpoint-interval 5000 \ + --enable-wandb --wandb-project clt_tied_training +``` + +This configuration: +- Uses `per_source` tying: 12 decoders instead of 78 for gpt2-small +- Enables feature scaling for better expressiveness +- Includes skip connections to preserve input information +- Uses BatchTopK with k=256 for training (can be converted to JumpReLU later) + ### Multi-GPU Training (Tensor Parallelism) This library supports feature-wise tensor parallelism using PyTorch Distributed Data Parallel (`torch.distributed`). This shards the model's parameters (encoders, decoders) across multiple GPUs, reducing memory usage per GPU and potentially speeding up computation. 
diff --git a/clt/config/clt_config.py b/clt/config/clt_config.py index 3e0d577..1de0141 100644 --- a/clt/config/clt_config.py +++ b/clt/config/clt_config.py @@ -15,7 +15,7 @@ class CLTConfig: num_layers: int # Number of transformer layers d_model: int # Dimension of model's hidden state model_name: Optional[str] = None # Optional name for the underlying model - normalization_method: Literal["auto", "estimated_mean_std", "none"] = ( + normalization_method: Literal["none", "mean_std", "sqrt_d_model"] = ( "none" # How activations were normalized during training ) activation_fn: Literal["jumprelu", "relu", "batchtopk", "topk"] = "jumprelu" @@ -27,6 +27,8 @@ class CLTConfig: topk_k: Optional[float] = None # Number or fraction of features to keep per token for TopK. # If < 1, treated as fraction. If >= 1, treated as int count. topk_straight_through: bool = True # Whether to use straight-through estimator for TopK. + # Top-K mode selection + topk_mode: Literal["global", "per_layer"] = "global" # How to apply top-k selection clt_dtype: Optional[str] = None # Optional dtype for the CLT model itself (e.g., "float16") expected_input_dtype: Optional[str] = None # Expected dtype of input activations mlp_input_template: Optional[str] = None # Module path template for MLP input activations @@ -34,13 +36,19 @@ class CLTConfig: tl_input_template: Optional[str] = None # TransformerLens hook point pattern before MLP tl_output_template: Optional[str] = None # TransformerLens hook point pattern after MLP # context_size: Optional[int] = None + + # Tied decoder configuration + decoder_tying: Literal["none", "per_source", "per_target"] = "none" # Decoder weight sharing strategy + enable_feature_offset: bool = False # Enable per-feature bias (feature_offset) + enable_feature_scale: bool = False # Enable per-feature scale (feature_scale) + skip_connection: bool = False # Enable skip connection from input to output def __post_init__(self): """Validate configuration parameters.""" assert 
self.num_features > 0, "Number of features must be positive" assert self.num_layers > 0, "Number of layers must be positive" assert self.d_model > 0, "Model dimension must be positive" - valid_norm_methods = ["auto", "estimated_mean_std", "none"] + valid_norm_methods = ["none", "mean_std", "sqrt_d_model"] assert ( self.normalization_method in valid_norm_methods ), f"Invalid normalization_method: {self.normalization_method}. Must be one of {valid_norm_methods}" @@ -60,6 +68,12 @@ def __post_init__(self): raise ValueError("topk_k must be specified for TopK activation function.") if self.topk_k is not None and self.topk_k <= 0: raise ValueError("topk_k must be positive if specified.") + + # Validate decoder tying configuration + valid_decoder_tying = ["none", "per_source", "per_target"] + assert ( + self.decoder_tying in valid_decoder_tying + ), f"Invalid decoder_tying: {self.decoder_tying}. Must be one of {valid_decoder_tying}" @classmethod def from_json(cls: Type[C], json_path: str) -> C: @@ -73,6 +87,30 @@ def from_json(cls: Type[C], json_path: str) -> C: """ with open(json_path, "r") as f: config_dict = json.load(f) + + # Handle backward compatibility for old configs + if "decoder_tying" not in config_dict: + config_dict["decoder_tying"] = "none" # Default to original behavior + if "enable_feature_offset" not in config_dict: + config_dict["enable_feature_offset"] = False + if "enable_feature_scale" not in config_dict: + config_dict["enable_feature_scale"] = False + + # Handle backwards compatibility for old normalization methods + if "normalization_method" in config_dict: + old_method = config_dict["normalization_method"] + # Map old values to new ones + if old_method in ["auto", "estimated_mean_std"]: + config_dict["normalization_method"] = "mean_std" + elif old_method in ["auto_sqrt_d_model", "estimated_mean_std_sqrt_d_model"]: + config_dict["normalization_method"] = "sqrt_d_model" + + # Handle old sqrt_d_model_normalize flag + if "sqrt_d_model_normalize" in 
config_dict: + sqrt_normalize = config_dict.pop("sqrt_d_model_normalize") + if sqrt_normalize: + config_dict["normalization_method"] = "sqrt_d_model" + return cls(**config_dict) def to_json(self, json_path: str) -> None: @@ -108,11 +146,11 @@ class TrainingConfig: debug_anomaly: bool = False # Normalization parameters - normalization_method: Literal["auto", "estimated_mean_std", "none"] = "auto" - # 'auto': Use pre-calculated from mapped store, or estimate for streaming store. - # 'estimated_mean_std': Always estimate for streaming store (ignored for mapped). - # 'none': Disable normalization. - normalization_estimation_batches: int = 50 # Batches for normalization estimation + normalization_method: Literal["none", "mean_std", "sqrt_d_model"] = "mean_std" + # 'none': No normalization. + # 'mean_std': Standard (x - mean) / std normalization using pre-calculated stats. + # 'sqrt_d_model': EleutherAI-style x * sqrt(d_model) normalization. + normalization_estimation_batches: int = 50 # Batches for normalization estimation (if needed) # --- Activation Store Source --- # activation_source: Literal["local_manifest", "remote"] = "local_manifest" @@ -221,6 +259,12 @@ def __post_init__(self): assert ( 0.0 <= self.sparsity_lambda_delay_frac < 1.0 ), "sparsity_lambda_delay_frac must be between 0.0 (inclusive) and 1.0 (exclusive)" + + # Validate normalization method + valid_norm_methods = ["none", "mean_std", "sqrt_d_model"] + assert ( + self.normalization_method in valid_norm_methods + ), f"Invalid normalization_method: {self.normalization_method}. 
Must be one of {valid_norm_methods}" @dataclass diff --git a/clt/models/clt.py b/clt/models/clt.py index 5ff9f4a..ff27f01 100644 --- a/clt/models/clt.py +++ b/clt/models/clt.py @@ -179,17 +179,24 @@ def encode(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor: ) return torch.zeros((expected_batch_dim, self.config.num_features), device=self.device, dtype=self.dtype) - def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: - return self.decoder_module.decode(a, layer_idx) + + def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None) -> torch.Tensor: + return self.decoder_module.decode(a, layer_idx, source_inputs) def forward(self, inputs: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]: activations = self.get_feature_activations(inputs) + + # Note: feature affine transformations are now applied in the decoder reconstructions = {} for layer_idx in range(self.config.num_layers): relevant_activations = {k: v for k, v in activations.items() if k <= layer_idx and v.numel() > 0} if layer_idx in inputs and relevant_activations: - reconstructions[layer_idx] = self.decode(relevant_activations, layer_idx) + # Pass source inputs for EleutherAI-style skip connections + source_inputs = {k: inputs[k] for k in range(layer_idx + 1) if k in inputs} if self.config.skip_connection else None + reconstruction = self.decode(relevant_activations, layer_idx, source_inputs) + + reconstructions[layer_idx] = reconstruction elif layer_idx in inputs: batch_size = 0 input_tensor = inputs[layer_idx] @@ -216,6 +223,17 @@ def get_feature_activations(self, inputs: Dict[int, torch.Tensor]) -> Dict[int, processed_inputs[layer_idx] = x_orig.to(device=self.device, dtype=self.dtype) if self.config.activation_fn == "batchtopk" or self.config.activation_fn == "topk": + # Check if we should use per-layer mode + if self.config.topk_mode == "per_layer": + # Use per-layer top-k by calling encode on each layer + 
activations = {} + for layer_idx in sorted(processed_inputs.keys()): + x_input = processed_inputs[layer_idx] + act = self.encode(x_input, layer_idx) + activations[layer_idx] = act + return activations + + # Otherwise use global top-k preactivations_dict, _ = self._encode_all_layers(processed_inputs) if not preactivations_dict: activations = {} @@ -325,3 +343,83 @@ def log_threshold(self, new_param: Optional[torch.nn.Parameter]) -> None: if not hasattr(self, "theta_manager") or self.theta_manager is None: raise AttributeError("ThetaManager is not initialised; cannot set log_threshold.") self.theta_manager.log_threshold = new_param + + def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True): + """Load state dict with backward compatibility for old checkpoints. + + Handles: + 1. Old untied decoder format -> new tied/untied format + 2. Missing theta_bias/theta_scale parameters + 3. Missing per_target_scale/per_target_bias parameters + """ + # Check if this is an old checkpoint by looking for decoder keys + old_format_decoder_keys = [k for k in state_dict.keys() if 'decoders.' in k and '->' in k] + is_old_checkpoint = len(old_format_decoder_keys) > 0 + + if is_old_checkpoint and self.config.decoder_tying == "per_source": + logger.warning( + "Loading old untied decoder checkpoint into tied decoder model. " + "This will use weights from the first target layer for each source layer." + ) + + # Convert old decoder weights to tied format + # For each source layer, use the weights from src->src decoder + new_state_dict = {} + for key, value in state_dict.items(): + if 'decoders.' 
in key and '->' in key: + # Extract source and target layer indices + # Key format: "decoder_module.decoders.{src}->{tgt}.weight" or ".bias" + parts = key.split('.') + decoder_key_idx = parts.index('decoders') + 1 + src_tgt = parts[decoder_key_idx].split('->') + src_layer = int(src_tgt[0]) + tgt_layer = int(src_tgt[1]) + param_type = parts[-1] # 'weight' or 'bias' + + # Only use diagonal decoders (src->src) for tied architecture + if src_layer == tgt_layer: + new_key = '.'.join(parts[:decoder_key_idx] + [str(src_layer), param_type]) + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + state_dict = new_state_dict + + # Handle feature affine parameters migration from encoder to decoder module + # (for backward compatibility with old checkpoints) + for i in range(self.config.num_layers): + old_offset_key = f"encoder_module.feature_offset.{i}" + new_offset_key = f"decoder_module.feature_offset.{i}" + if old_offset_key in state_dict and new_offset_key not in state_dict: + logger.info(f"Migrating {old_offset_key} to {new_offset_key}") + state_dict[new_offset_key] = state_dict.pop(old_offset_key) + + old_scale_key = f"encoder_module.feature_scale.{i}" + new_scale_key = f"decoder_module.feature_scale.{i}" + if old_scale_key in state_dict and new_scale_key not in state_dict: + logger.info(f"Migrating {old_scale_key} to {new_scale_key}") + state_dict[new_scale_key] = state_dict.pop(old_scale_key) + + # Handle missing feature affine parameters (now in decoder module) + if self.config.enable_feature_offset and hasattr(self.decoder_module, 'feature_offset') and self.decoder_module.feature_offset is not None: + for i in range(self.config.num_layers): + key = f"decoder_module.feature_offset.{i}" + if key not in state_dict: + logger.info(f"Initializing missing {key} to zeros") + # Don't add to state_dict to let it be initialized by the module + + if self.config.enable_feature_scale and hasattr(self.decoder_module, 'feature_scale') and 
self.decoder_module.feature_scale is not None: + for i in range(self.config.num_layers): + key = f"decoder_module.feature_scale.{i}" + if key not in state_dict: + logger.info(f"Initializing missing {key} (first target layer to ones, rest to zeros)") + # Don't add to state_dict to let it be initialized by the module + + # Handle missing skip weights + if self.config.skip_connection and hasattr(self.decoder_module, 'skip_weights'): + for i in range(self.config.num_layers): + key = f"decoder_module.skip_weights.{i}" + if key not in state_dict: + logger.info(f"Initializing missing {key} to identity matrix") + + # Call parent's load_state_dict + return super().load_state_dict(state_dict, strict=strict) diff --git a/clt/models/decoder.py b/clt/models/decoder.py index 68fd5c5..f6e9060 100644 --- a/clt/models/decoder.py +++ b/clt/models/decoder.py @@ -38,9 +38,11 @@ def __init__( self.world_size = dist_ops.get_world_size(process_group) self.rank = dist_ops.get_rank(process_group) - self.decoders = nn.ModuleDict( - { - f"{src_layer}->{tgt_layer}": RowParallelLinear( + # Initialize decoders based on tying configuration + if config.decoder_tying == "per_source": + # Tied decoders: one decoder per source layer + self.decoders = nn.ModuleList([ + RowParallelLinear( in_features=self.config.num_features, out_features=self.config.d_model, bias=True, @@ -51,13 +53,114 @@ def __init__( device=self.device, dtype=self.dtype, ) - for src_layer in range(self.config.num_layers) - for tgt_layer in range(src_layer, self.config.num_layers) - } - ) + for _ in range(self.config.num_layers) + ]) + elif config.decoder_tying == "per_target": + # Tied decoders: one decoder per target layer (EleutherAI style) + self.decoders = nn.ModuleList([ + RowParallelLinear( + in_features=self.config.num_features, + out_features=self.config.d_model, + bias=True, + process_group=self.process_group, + input_is_parallel=False, + d_model_for_init=self.config.d_model, + num_layers_for_init=self.config.num_layers, 
+ device=self.device, + dtype=self.dtype, + ) + for _ in range(self.config.num_layers) + ]) + + # Initialize decoder weights to zeros for tied decoders (both per_source and per_target) + if config.decoder_tying in ["per_source", "per_target"]: + for decoder in self.decoders: + nn.init.zeros_(decoder.weight) + if hasattr(decoder, 'bias_param') and decoder.bias_param is not None: + nn.init.zeros_(decoder.bias_param) + elif hasattr(decoder, 'bias') and decoder.bias is not None: + nn.init.zeros_(decoder.bias) + + # Note: EleutherAI doesn't have per-target scale/bias parameters + # These have been removed to match their architecture exactly + else: + # Original untied decoders: one decoder per (src, tgt) pair + self.decoders = nn.ModuleDict( + { + f"{src_layer}->{tgt_layer}": RowParallelLinear( + in_features=self.config.num_features, + out_features=self.config.d_model, + bias=True, + process_group=self.process_group, + input_is_parallel=False, + d_model_for_init=self.config.d_model, + num_layers_for_init=self.config.num_layers, + device=self.device, + dtype=self.dtype, + ) + for src_layer in range(self.config.num_layers) + for tgt_layer in range(src_layer, self.config.num_layers) + } + ) + # Note: EleutherAI doesn't have per-target scale/bias parameters + + # Initialize skip connection weights if enabled + if config.skip_connection: + if config.decoder_tying in ["per_source", "per_target"]: + # For tied decoders, one skip connection per target layer + self.skip_weights = nn.ParameterList([ + nn.Parameter(torch.zeros(self.config.d_model, self.config.d_model, + device=self.device, dtype=self.dtype)) + for _ in range(self.config.num_layers) + ]) + else: + # For untied decoders, one skip connection per src->tgt pair + self.skip_weights = nn.ParameterDict({ + f"{src_layer}->{tgt_layer}": nn.Parameter( + torch.zeros(self.config.d_model, self.config.d_model, + device=self.device, dtype=self.dtype) + ) + for src_layer in range(self.config.num_layers) + for tgt_layer in 
range(src_layer, self.config.num_layers) + }) + else: + self.skip_weights = None + + # Initialize feature_offset and feature_scale (indexed by target layer) + # These match EleutherAI's post_enc and post_enc_scale + # Note: Currently only implemented for tied decoders to match EleutherAI + # For per_source tying, these would need to be indexed differently + if config.decoder_tying in ["per_source", "per_target"]: + features_per_rank = config.num_features // self.world_size if self.world_size > 1 else config.num_features + + if config.enable_feature_offset: + # Initialize feature_offset for each target layer + self.feature_offset = nn.ParameterList([ + nn.Parameter(torch.zeros(features_per_rank, device=self.device, dtype=self.dtype)) + for _ in range(config.num_layers) + ]) + else: + self.feature_offset = None + + if config.enable_feature_scale: + # Initialize feature_scale for each target layer + # First target layer gets ones, rest get small non-zero values to allow gradient flow + self.feature_scale = nn.ParameterList([ + nn.Parameter( + torch.ones(features_per_rank, device=self.device, dtype=self.dtype) if i == 0 + else torch.full((features_per_rank,), 0.1, device=self.device, dtype=self.dtype) + ) + for i in range(config.num_layers) + ]) + else: + self.feature_scale = None + else: + self.feature_offset = None + self.feature_scale = None + self.register_buffer("_cached_decoder_norms", None, persistent=False) - def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: + def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None) -> torch.Tensor: """Decode the feature activations to reconstruct outputs at the specified layer. Input activations `a` are expected to be the *full* tensors. 
@@ -87,21 +190,150 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: reconstruction = torch.zeros((batch_dim_size, self.config.d_model), device=self.device, dtype=self.dtype) - for src_layer in range(layer_idx + 1): - if src_layer in a: - activation_tensor = a[src_layer].to(device=self.device, dtype=self.dtype) + if self.config.decoder_tying == "per_target": + # EleutherAI style: sum activations first, then decode once + summed_activation = torch.zeros((batch_dim_size, self.config.num_features), device=self.device, dtype=self.dtype) + + for src_layer in range(layer_idx + 1): + if src_layer in a: + activation_tensor = a[src_layer].to(device=self.device, dtype=self.dtype) - if activation_tensor.numel() == 0: - continue - if activation_tensor.shape[-1] != self.config.num_features: - logger.warning( - f"Rank {self.rank}: Activation tensor for layer {src_layer} has incorrect feature dimension {activation_tensor.shape[-1]}, expected {self.config.num_features}. Skipping decode contribution." - ) - continue + if activation_tensor.numel() == 0: + continue + if activation_tensor.shape[-1] != self.config.num_features: + logger.warning( + f"Rank {self.rank}: Activation tensor for layer {src_layer} has incorrect feature dimension {activation_tensor.shape[-1]}, expected {self.config.num_features}. Skipping decode contribution." 
+ ) + continue + + # Apply feature affine transformations (indexed by target layer) + # EleutherAI only applies these to non-zero (selected) features + if self.feature_offset is not None or self.feature_scale is not None: + # Get non-zero positions (selected features) + nonzero_mask = activation_tensor != 0 + + if nonzero_mask.any(): + # Apply transformations only to selected features + activation_tensor = activation_tensor.clone() + batch_indices, feature_indices = nonzero_mask.nonzero(as_tuple=True) + + if self.feature_offset is not None: + # Apply offset only to non-zero features + offset_values = self.feature_offset[layer_idx][feature_indices] + activation_tensor[batch_indices, feature_indices] += offset_values + + if self.feature_scale is not None: + # Apply scale only to non-zero features + scale_values = self.feature_scale[layer_idx][feature_indices] + activation_tensor[batch_indices, feature_indices] *= scale_values + + summed_activation += activation_tensor + + # Now decode ONCE with the summed activation + decoder = self.decoders[layer_idx] + reconstruction = decoder(summed_activation) + + # Apply skip connections from source inputs if enabled + if self.skip_weights is not None and source_inputs is not None: + skip_weight = self.skip_weights[layer_idx] + # Add skip connections from each source layer that contributed + for src_layer in range(layer_idx + 1): + if src_layer in source_inputs: + source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype) + # Flatten if needed + original_shape = source_input.shape + if source_input.dim() == 3: + source_input_2d = source_input.view(-1, source_input.shape[-1]) + else: + source_input_2d = source_input + # Apply skip: source @ W_skip^T + skip_contribution = source_input_2d @ skip_weight.T + # Reshape back if needed + if source_input.dim() == 3: + skip_contribution = skip_contribution.view(original_shape) + reconstruction += skip_contribution + + else: + # Original logic for per_source and 
untied decoders + for src_layer in range(layer_idx + 1): + if src_layer in a: + activation_tensor = a[src_layer].to(device=self.device, dtype=self.dtype) + + if activation_tensor.numel() == 0: + continue + if activation_tensor.shape[-1] != self.config.num_features: + logger.warning( + f"Rank {self.rank}: Activation tensor for layer {src_layer} has incorrect feature dimension {activation_tensor.shape[-1]}, expected {self.config.num_features}. Skipping decode contribution." + ) + continue + + # Apply feature affine transformations for per_source + if self.config.decoder_tying == "per_source": + # Get non-zero positions (selected features) + nonzero_mask = activation_tensor != 0 + + if nonzero_mask.any(): + # Apply transformations only to selected features + activation_tensor = activation_tensor.clone() + batch_indices, feature_indices = nonzero_mask.nonzero(as_tuple=True) + + if self.feature_offset is not None: + # Apply offset indexed by target layer + offset_values = self.feature_offset[layer_idx][feature_indices] + activation_tensor[batch_indices, feature_indices] += offset_values + + if self.feature_scale is not None: + # Apply scale indexed by target layer + scale_values = self.feature_scale[layer_idx][feature_indices] + activation_tensor[batch_indices, feature_indices] *= scale_values - decoder = self.decoders[f"{src_layer}->{layer_idx}"] - decoded = decoder(activation_tensor) - reconstruction += decoded + if self.config.decoder_tying == "per_source": + # Use tied decoder for the source layer + decoder = self.decoders[src_layer] + decoded = decoder(activation_tensor) + + # Apply skip connection from this source input if enabled + if self.skip_weights is not None and source_inputs is not None and src_layer in source_inputs: + skip_weight = self.skip_weights[layer_idx] + source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype) + # Flatten if needed + original_shape = source_input.shape + if source_input.dim() == 3: + source_input_2d = 
source_input.view(-1, source_input.shape[-1]) + else: + source_input_2d = source_input + # Apply skip: source @ W_skip^T + skip_contribution = source_input_2d @ skip_weight.T + # Reshape back if needed + if source_input.dim() == 3: + skip_contribution = skip_contribution.view(original_shape) + decoded += skip_contribution + else: + # Use untied decoder for (src, tgt) pair + decoder = self.decoders[f"{src_layer}->{layer_idx}"] + decoded = decoder(activation_tensor) + + # Apply skip connection from this source input if enabled + if self.skip_weights is not None and source_inputs is not None and src_layer in source_inputs: + skip_key = f"{src_layer}->{layer_idx}" + if skip_key in self.skip_weights: + skip_weight = self.skip_weights[skip_key] + source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype) + # Flatten if needed + original_shape = source_input.shape + if source_input.dim() == 3: + source_input_2d = source_input.view(-1, source_input.shape[-1]) + else: + source_input_2d = source_input + # Apply skip: source @ W_skip^T + skip_contribution = source_input_2d @ skip_weight.T + # Reshape back if needed + if source_input.dim() == 3: + skip_contribution = skip_contribution.view(original_shape) + decoded += skip_contribution + + reconstruction += decoded + return reconstruction def get_decoder_norms(self) -> torch.Tensor: @@ -143,10 +375,10 @@ def get_decoder_norms(self) -> torch.Tensor: for src_layer in range(self.config.num_layers): local_norms_sq_accum = torch.zeros(self.config.num_features, device=self.device, dtype=torch.float32) - for tgt_layer in range(src_layer, self.config.num_layers): - decoder_key = f"{src_layer}->{tgt_layer}" - decoder = self.decoders[decoder_key] - assert isinstance(decoder, RowParallelLinear), f"Decoder {decoder_key} is not RowParallelLinear" + if self.config.decoder_tying == "per_source": + # For tied decoders, compute norms once per source layer + decoder = self.decoders[src_layer] + assert isinstance(decoder, 
RowParallelLinear), f"Decoder {src_layer} is not RowParallelLinear" current_norms_sq = torch.norm(decoder.weight, dim=0).pow(2).to(torch.float32) @@ -161,7 +393,7 @@ def get_decoder_norms(self) -> torch.Tensor: pass elif local_dim_padded != actual_local_dim and local_dim_padded != features_per_rank: logger.warning( - f"Rank {self.rank}: Padded local dim ({local_dim_padded}) doesn't match calculated actual local dim ({actual_local_dim}) or features_per_rank ({features_per_rank}) for {decoder_key}. This might indicate an issue with RowParallelLinear partitioning." + f"Rank {self.rank}: Padded local dim ({local_dim_padded}) doesn't match calculated actual local dim ({actual_local_dim}) or features_per_rank ({features_per_rank}) for decoder {src_layer}. This might indicate an issue with RowParallelLinear partitioning." ) if actual_local_dim > 0: @@ -171,9 +403,75 @@ def get_decoder_norms(self) -> torch.Tensor: local_norms_sq_accum[global_slice] += valid_norms_sq else: logger.warning( - f"Rank {self.rank}: Shape mismatch in decoder norm calculation for {decoder_key}. " + f"Rank {self.rank}: Shape mismatch in decoder norm calculation for decoder {src_layer}. " f"Valid norms shape {valid_norms_sq.shape}, expected size {actual_local_dim}." 
) + elif self.config.decoder_tying == "per_target": + # For per_target tying, each decoder corresponds to a target layer + # We accumulate decoder norms from all target layers >= src_layer + for tgt_layer in range(src_layer, self.config.num_layers): + decoder = self.decoders[tgt_layer] + assert isinstance(decoder, RowParallelLinear), f"Decoder {tgt_layer} is not RowParallelLinear" + + current_norms_sq = torch.norm(decoder.weight, dim=0).pow(2).to(torch.float32) + + full_dim = decoder.full_in_features + features_per_rank = (full_dim + self.world_size - 1) // self.world_size + start_idx = self.rank * features_per_rank + end_idx = min(start_idx + features_per_rank, full_dim) + actual_local_dim = max(0, end_idx - start_idx) + local_dim_padded = decoder.local_in_features + + if local_dim_padded != features_per_rank and self.rank == self.world_size - 1: + pass + elif local_dim_padded != actual_local_dim and local_dim_padded != features_per_rank: + logger.warning( + f"Rank {self.rank}: Padded local dim ({local_dim_padded}) doesn't match calculated actual local dim ({actual_local_dim}) or features_per_rank ({features_per_rank}) for decoder {tgt_layer}. This might indicate an issue with RowParallelLinear partitioning." + ) + + if actual_local_dim > 0: + valid_norms_sq = current_norms_sq[:actual_local_dim] + if valid_norms_sq.shape[0] == actual_local_dim: + global_slice = slice(start_idx, end_idx) + local_norms_sq_accum[global_slice] += valid_norms_sq + else: + logger.warning( + f"Rank {self.rank}: Shape mismatch in decoder norm calculation for decoder {tgt_layer}. " + f"Valid norms shape {valid_norms_sq.shape}, expected size {actual_local_dim}." 
+ ) + else: + # For untied decoders, accumulate norms from all target layers + for tgt_layer in range(src_layer, self.config.num_layers): + decoder_key = f"{src_layer}->{tgt_layer}" + decoder = self.decoders[decoder_key] + assert isinstance(decoder, RowParallelLinear), f"Decoder {decoder_key} is not RowParallelLinear" + + current_norms_sq = torch.norm(decoder.weight, dim=0).pow(2).to(torch.float32) + + full_dim = decoder.full_in_features + features_per_rank = (full_dim + self.world_size - 1) // self.world_size + start_idx = self.rank * features_per_rank + end_idx = min(start_idx + features_per_rank, full_dim) + actual_local_dim = max(0, end_idx - start_idx) + local_dim_padded = decoder.local_in_features + + if local_dim_padded != features_per_rank and self.rank == self.world_size - 1: + pass + elif local_dim_padded != actual_local_dim and local_dim_padded != features_per_rank: + logger.warning( + f"Rank {self.rank}: Padded local dim ({local_dim_padded}) doesn't match calculated actual local dim ({actual_local_dim}) or features_per_rank ({features_per_rank}) for {decoder_key}. This might indicate an issue with RowParallelLinear partitioning." + ) + + if actual_local_dim > 0: + valid_norms_sq = current_norms_sq[:actual_local_dim] + if valid_norms_sq.shape[0] == actual_local_dim: + global_slice = slice(start_idx, end_idx) + local_norms_sq_accum[global_slice] += valid_norms_sq + else: + logger.warning( + f"Rank {self.rank}: Shape mismatch in decoder norm calculation for {decoder_key}. " + f"Valid norms shape {valid_norms_sq.shape}, expected size {actual_local_dim}." 
+ ) if self.process_group is not None and dist_ops.is_dist_initialized_and_available(): dist_ops.all_reduce(local_norms_sq_accum, op=dist_ops.SUM, group=self.process_group) diff --git a/clt/models/encoder.py b/clt/models/encoder.py index b9032e1..07b30c2 100644 --- a/clt/models/encoder.py +++ b/clt/models/encoder.py @@ -47,6 +47,9 @@ def __init__( for _ in range(config.num_layers) ] ) + + # Note: feature_offset and feature_scale have been moved to Decoder module + # to match EleutherAI's architecture where they are indexed by target layer def get_preactivations(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor: """Get pre-activation values (full tensor) for features at the specified layer.""" diff --git a/clt/training/data/manifest_activation_store.py b/clt/training/data/manifest_activation_store.py index da13cbb..0a01879 100644 --- a/clt/training/data/manifest_activation_store.py +++ b/clt/training/data/manifest_activation_store.py @@ -5,6 +5,7 @@ from collections import defaultdict import threading import queue +import math # import json # Unused from abc import ABC, abstractmethod @@ -367,6 +368,7 @@ def __init__( self.epoch = 0 self.prefetch_batches = max(1, prefetch_batches) self.sampling_strategy = sampling_strategy + self.normalization_method = normalization_method self.shard_data = shard_data # Device setup @@ -483,8 +485,17 @@ def __init__( self.apply_normalization = False if normalization_method == "none": self.apply_normalization = False - else: + elif normalization_method == "mean_std": + # mean_std requires normalization stats self.apply_normalization = bool(self.norm_stats_data) + elif normalization_method == "sqrt_d_model": + # sqrt_d_model doesn't need norm stats, just applies scaling + self.apply_normalization = True + else: + raise ValueError( + f"Invalid normalization_method: {normalization_method}. 
" + f"Must be one of ['none', 'mean_std', 'sqrt_d_model']" + ) if self.apply_normalization: self._prep_norm() @@ -556,9 +567,14 @@ def _prep_norm(self): self.mean_tg: Dict[int, torch.Tensor] = {} self.std_tg: Dict[int, torch.Tensor] = {} - if not self.norm_stats_data: - logger.warning("Normalization prep called but no stats data loaded.") - self.apply_normalization = False + # Only need to load stats for mean_std normalization + if self.normalization_method == "mean_std": + if not self.norm_stats_data: + logger.warning("mean_std normalization requested but no stats data loaded.") + self.apply_normalization = False + return + elif self.normalization_method == "sqrt_d_model": + # sqrt_d_model doesn't need stats, just return return missing_layers = set(self.layer_indices) @@ -901,10 +917,17 @@ def _fetch_and_parse_batch(self, idxs: np.ndarray) -> ActivationBatch: log_stats_this_batch["target_mean_in"] = self.mean_in[li].mean().item() log_stats_this_batch["target_std_in"] = self.std_in[li].mean().item() - if li in self.mean_in and li in self.std_in: - inputs_li = (inputs_li - self.mean_in[li]) / self.std_in[li] - if li in self.mean_tg and li in self.std_tg: - targets_li = (targets_li - self.mean_tg[li]) / self.std_tg[li] + if self.normalization_method == "mean_std": + # Standard normalization: (x - mean) / std + if li in self.mean_in and li in self.std_in: + inputs_li = (inputs_li - self.mean_in[li]) / self.std_in[li] + if li in self.mean_tg and li in self.std_tg: + targets_li = (targets_li - self.mean_tg[li]) / self.std_tg[li] + elif self.normalization_method == "sqrt_d_model": + # EleutherAI-style normalization: x * sqrt(d_model) + sqrt_d_model = math.sqrt(self.d_model) + inputs_li = inputs_li * sqrt_d_model + targets_li = targets_li * sqrt_d_model # Convert to final target dtype *after* normalization final_batch_inputs[li] = inputs_li.to(self.dtype) diff --git a/clt/training/evaluator.py b/clt/training/evaluator.py index eb77645..19570bc 100644 --- 
a/clt/training/evaluator.py +++ b/clt/training/evaluator.py @@ -35,6 +35,8 @@ def __init__( start_time: Optional[float] = None, mean_tg: Optional[Dict[int, torch.Tensor]] = None, std_tg: Optional[Dict[int, torch.Tensor]] = None, + normalization_method: str = "none", + d_model: Optional[int] = None, ): """Initialize the evaluator. @@ -44,6 +46,8 @@ def __init__( start_time: The initial time.time() from the trainer for elapsed time logging. mean_tg: Optional dictionary of per-layer target means for de-normalising outputs. std_tg: Optional dictionary of per-layer target stds for de-normalising outputs. + normalization_method: The normalization method being used. + d_model: Model dimension for sqrt_d_model normalization. """ self.model = model self.device = device @@ -51,6 +55,16 @@ def __init__( # Store normalisation stats if provided self.mean_tg = mean_tg or {} self.std_tg = std_tg or {} + + # Validate normalization method + valid_norm_methods = ["none", "mean_std", "sqrt_d_model"] + if normalization_method not in valid_norm_methods: + raise ValueError( + f"Invalid normalization_method: {normalization_method}. 
" + f"Must be one of {valid_norm_methods}" + ) + self.normalization_method = normalization_method + self.d_model = d_model self.metrics_history: List[Dict[str, Any]] = [] # For storing metrics over time if needed @staticmethod @@ -249,6 +263,10 @@ def _compute_reconstruction_metrics( total_explained_variance = 0.0 total_nmse = 0.0 num_layers = 0 + + # For layerwise metrics + layerwise_nmse = {} + layerwise_explained_variance = {} for layer_idx, target_act in targets.items(): if layer_idx not in reconstructions: @@ -256,15 +274,22 @@ def _compute_reconstruction_metrics( recon_act = reconstructions[layer_idx] - # --- De-normalise if stats available --- + # --- De-normalise based on normalization method --- target_act_denorm = target_act recon_act_denorm = recon_act - if layer_idx in self.mean_tg and layer_idx in self.std_tg: + + if self.normalization_method == "mean_std" and layer_idx in self.mean_tg and layer_idx in self.std_tg: + # Standard denormalization: x * std + mean mean = self.mean_tg[layer_idx].to(recon_act.device, recon_act.dtype) std = self.std_tg[layer_idx].to(recon_act.device, recon_act.dtype) # Ensure broadcast shape target_act_denorm = target_act * std + mean recon_act_denorm = recon_act * std + mean + elif self.normalization_method == "sqrt_d_model" and self.d_model is not None: + # sqrt_d_model denormalization: x / sqrt(d_model) + sqrt_d_model = (self.d_model ** 0.5) + target_act_denorm = target_act / sqrt_d_model + recon_act_denorm = recon_act / sqrt_d_model # --- End De-normalisation --- # Ensure shapes match (flatten if necessary) and up-cast to float32 for numerically stable metrics @@ -299,6 +324,10 @@ def _compute_reconstruction_metrics( else: # Target variance is zero but MSE is non-zero (implies error, NMSE is effectively infinite) nmse_layer = float("inf") # Or a large number, or handle as NaN depending on preference total_nmse += nmse_layer + + # Store layerwise metrics + layerwise_nmse[f"layer_{layer_idx}"] = nmse_layer + 
layerwise_explained_variance[f"layer_{layer_idx}"] = explained_variance_layer num_layers += 1 @@ -314,6 +343,8 @@ def _compute_reconstruction_metrics( return { "reconstruction/explained_variance": avg_explained_variance, "reconstruction/normalized_mean_reconstruction_error": avg_normalized_mean_reconstruction_error, + "layerwise/normalized_mse": layerwise_nmse, + "layerwise/explained_variance": layerwise_explained_variance, } def _compute_feature_density(self, activations: Dict[int, torch.Tensor]) -> Dict[str, Any]: diff --git a/clt/training/losses.py b/clt/training/losses.py index 6771148..f2245b8 100644 --- a/clt/training/losses.py +++ b/clt/training/losses.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from typing import Dict, Tuple, Optional -from clt.config import TrainingConfig +from clt.config import TrainingConfig, CLTConfig from clt.models.clt import CrossLayerTranscoder @@ -15,6 +15,7 @@ def __init__( config: TrainingConfig, mean_tg: Optional[Dict[int, torch.Tensor]] = None, std_tg: Optional[Dict[int, torch.Tensor]] = None, + clt_config: Optional['CLTConfig'] = None, ): """Initialize the loss manager. 
@@ -22,6 +23,7 @@ def __init__( config: Training configuration mean_tg: Optional dictionary of per-layer target means for de-normalising outputs std_tg: Optional dictionary of per-layer target stds for de-normalising outputs + clt_config: Optional CLT configuration for accessing d_model """ self.config = config self.reconstruction_loss_fn = nn.MSELoss() @@ -29,6 +31,7 @@ def __init__( # Store normalisation stats if provided self.mean_tg = mean_tg or {} self.std_tg = std_tg or {} + self.clt_config = clt_config self.aux_loss_factor = config.aux_loss_factor # New: coefficient for auxiliary loss self.apply_sparsity_penalty_to_batchtopk = config.apply_sparsity_penalty_to_batchtopk @@ -69,13 +72,19 @@ def compute_reconstruction_loss( pred_layer = predicted[layer_idx] tgt_layer = target[layer_idx] - # De-normalise if stats available for this layer - if layer_idx in self.mean_tg and layer_idx in self.std_tg: + # De-normalise based on normalization method + if self.config.normalization_method == "mean_std" and layer_idx in self.mean_tg and layer_idx in self.std_tg: + # Standard denormalization: x * std + mean mean = self.mean_tg[layer_idx].to(pred_layer.device, pred_layer.dtype) std = self.std_tg[layer_idx].to(pred_layer.device, pred_layer.dtype) # mean/std were stored with an added batch dim – ensure broadcast shape pred_layer = pred_layer * std + mean tgt_layer = tgt_layer * std + mean + elif self.config.normalization_method == "sqrt_d_model" and self.clt_config is not None: + # sqrt_d_model denormalization: x / sqrt(d_model) + sqrt_d_model = (self.clt_config.d_model ** 0.5) + pred_layer = pred_layer / sqrt_d_model + tgt_layer = tgt_layer / sqrt_d_model layer_loss = self.reconstruction_loss_fn(pred_layer, tgt_layer) total_loss += layer_loss @@ -360,10 +369,25 @@ def compute_total_loss( preactivation_loss = self.compute_preactivation_loss(model, inputs) # Compute residuals for auxiliary loss if needed + # Important: Compute residuals in denormalized (original) space for 
consistent auxiliary loss scale residuals = {} for layer_idx in predictions: if layer_idx in targets: - residuals[layer_idx] = targets[layer_idx] - predictions[layer_idx] + pred_layer = predictions[layer_idx] + tgt_layer = targets[layer_idx] + + # Denormalize before computing residuals + if self.config.normalization_method == "mean_std" and layer_idx in self.mean_tg and layer_idx in self.std_tg: + mean = self.mean_tg[layer_idx].to(pred_layer.device, pred_layer.dtype) + std = self.std_tg[layer_idx].to(pred_layer.device, pred_layer.dtype) + pred_layer = pred_layer * std + mean + tgt_layer = tgt_layer * std + mean + elif self.config.normalization_method == "sqrt_d_model" and self.clt_config is not None: + sqrt_d_model = (self.clt_config.d_model ** 0.5) + pred_layer = pred_layer / sqrt_d_model + tgt_layer = tgt_layer / sqrt_d_model + + residuals[layer_idx] = tgt_layer - pred_layer # Compute auxiliary loss (only if configured and using BatchTopK) aux_loss = torch.tensor(0.0, device=reconstruction_loss.device) diff --git a/clt/training/trainer.py b/clt/training/trainer.py index 32f7ec2..2b5dc08 100644 --- a/clt/training/trainer.py +++ b/clt/training/trainer.py @@ -324,6 +324,7 @@ def lr_lambda(current_step: int): training_config, mean_tg=mean_tg_stats, std_tg=std_tg_stats, + clt_config=clt_config, ) # Initialize Evaluator - Pass norm stats here too @@ -333,6 +334,8 @@ def lr_lambda(current_step: int): start_time=self.start_time, mean_tg=mean_tg_stats, # Pass the same stats std_tg=std_tg_stats, # Pass the same stats + normalization_method=training_config.normalization_method, + d_model=clt_config.d_model, ) # Initialize dead neuron counters (replicated for now, consider sharding later if needed) @@ -472,11 +475,15 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: logger.info(f"Distributed training with {self.world_size} processes (Tensor Parallelism)") # Check if using normalization and notify user - if self.training_config.normalization_method == 
"estimated_mean_std": - logger.info("\n>>> NORMALIZATION PHASE <<<") - logger.info("Normalization statistics are being estimated from dataset activations.") - logger.info("This may take some time, but happens only once before training begins.") - logger.info(f"Using {self.training_config.normalization_estimation_batches} batches for estimation.\n") + if self.training_config.normalization_method == "mean_std": + logger.info("\n>>> NORMALIZATION CONFIGURATION <<<") + logger.info("Using mean/std normalization with pre-calculated statistics from norm_stats.json") + elif self.training_config.normalization_method == "sqrt_d_model": + logger.info("\n>>> NORMALIZATION CONFIGURATION <<<") + logger.info("Using sqrt(d_model) normalization (EleutherAI-style)") + elif self.training_config.normalization_method == "none": + logger.info("\n>>> NORMALIZATION CONFIGURATION <<<") + logger.info("No normalization will be applied to activations") # Make sure we flush stdout to ensure prints appear immediately, # especially important in Jupyter/interactive environments diff --git a/clt/training/wandb_logger.py b/clt/training/wandb_logger.py index 3c351af..521ee58 100644 --- a/clt/training/wandb_logger.py +++ b/clt/training/wandb_logger.py @@ -68,21 +68,41 @@ def __init__( entity_name = os.environ.get("WANDB_ENTITY", training_config.wandb_entity) run_name = training_config.wandb_run_name # Can be None + # Prepare config for both new and resumed runs + wandb_config = self._create_wandb_config(training_config, clt_config) + if self.resume_wandb_id: logger.info(f"Attempting to resume WandB run with ID: {self.resume_wandb_id}") - # When resuming, wandb.init will use the passed id. Do not pass project/entity/name again if resuming. 
- self.wandb_run = wandb.init(id=self.resume_wandb_id, resume="allow") + # When resuming, we need to update the config after init + self.wandb_run = wandb.init( + id=self.resume_wandb_id, + resume="allow", + project=project_name, + entity=entity_name, + tags=training_config.wandb_tags, + ) + # Update config after resuming + if self.wandb_run: + wandb.config.update(wandb_config, allow_val_change=True) else: self.wandb_run = wandb.init( project=project_name, entity=entity_name, name=run_name, - config=self._create_wandb_config(training_config, clt_config), - reinit=True, # Allow re-initialization in the same process if needed + config=wandb_config, + tags=training_config.wandb_tags, ) if self.wandb_run: + self._run_id = self.wandb_run.id logger.info(f"WandB logging initialized: {self.wandb_run.name} (ID: {self.wandb_run.id})") + # Log config keys to verify they were set (only in debug mode) + if logger.isEnabledFor(logging.DEBUG): + if hasattr(wandb, 'config') and wandb.config: + config_keys = list(wandb.config.keys()) + logger.debug(f"WandB config keys: {config_keys}") + else: + logger.debug("WandB config appears to be empty or not accessible") else: logger.warning("Warning: WandB run initialization failed but no exception was raised.") diff --git a/scripts/convert_batchtopk_to_jumprelu.py b/scripts/convert_batchtopk_to_jumprelu.py index 2822e30..a198b1f 100644 --- a/scripts/convert_batchtopk_to_jumprelu.py +++ b/scripts/convert_batchtopk_to_jumprelu.py @@ -31,7 +31,16 @@ def _remap_checkpoint_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """Remaps old state_dict keys to the new format with module prefixes.""" + """Remaps old state_dict keys to the new format with module prefixes. 
+ + Handles both old-style keys (encoders.*, decoders.*) and new tied decoder parameters: + - encoder_module.feature_offset.{layer_idx}: ParameterList for per-feature bias + - encoder_module.feature_scale.{layer_idx}: ParameterList for per-feature scale + - decoder_module.skip_weights.{layer_idx}: ParameterList for tied decoders + - decoder_module.skip_weights.{src}->{tgt}: ParameterDict for untied decoders + - decoder_module.per_target_scale: Tensor for per src->tgt scale (tied decoders) + - decoder_module.per_target_bias: Tensor for per src->tgt bias (tied decoders) + """ new_state_dict = {} for key, value in state_dict.items(): if key.startswith("encoders."): @@ -41,6 +50,23 @@ def _remap_checkpoint_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor else: new_key = key new_state_dict[new_key] = value + + # Handle new parameter names that might not have the correct module prefix + # Create a list of keys to avoid modifying dict during iteration + keys_to_check = list(new_state_dict.keys()) + for key in keys_to_check: + # Handle feature_offset/feature_scale that might be saved without module prefix + if key.startswith("feature_offset.") and not key.startswith("encoder_module."): + new_state_dict[f"encoder_module.{key}"] = new_state_dict.pop(key) + elif key.startswith("feature_scale.") and not key.startswith("encoder_module."): + new_state_dict[f"encoder_module.{key}"] = new_state_dict.pop(key) + # Handle skip_weights that might be saved without module prefix + elif key.startswith("skip_weights.") and not key.startswith("decoder_module."): + new_state_dict[f"decoder_module.{key}"] = new_state_dict.pop(key) + # Handle per_target parameters + elif key in ["per_target_scale", "per_target_bias"] and not key.startswith("decoder_module."): + new_state_dict[f"decoder_module.{key}"] = new_state_dict.pop(key) + if not any(k.startswith("encoder_module.") or k.startswith("decoder_module.") for k in new_state_dict.keys()): if any(k.startswith("encoders.") or 
k.startswith("decoders.") for k in state_dict.keys()): logger.warning( diff --git a/scripts/train_clt.py b/scripts/train_clt.py index bfca521..7f7ab07 100644 --- a/scripts/train_clt.py +++ b/scripts/train_clt.py @@ -228,6 +228,13 @@ def parse_args(): action="store_true", # If flag is present, disable is true. Default behavior is enabled. help="Disable straight-through estimator for BatchTopK. (BatchTopK default is True).", ) + clt_group.add_argument( + "--topk-mode", + type=str, + choices=["global", "per_layer"], + default="global", + help="How to apply top-k selection: 'global' (across all layers) or 'per_layer' (each layer independently).", + ) clt_group.add_argument( "--topk-k", type=float, # As per CLTConfig, topk_k can be a float (fraction) or int (count) @@ -245,6 +252,38 @@ def parse_args(): default=None, help="Optional data type for the CLT model parameters (e.g., 'float16', 'bfloat16').", ) + clt_group.add_argument( + "--decoder-tying", + type=str, + choices=["none", "per_source", "per_target"], + default="none", + help="Decoder weight sharing strategy: 'none' (default), 'per_source' (tied per source layer), or 'per_target' (tied per target layer, EleutherAI style).", + ) + clt_group.add_argument( + "--per-target-scale", + action="store_true", + help="Enable learned scale for each src->tgt path when using tied decoders.", + ) + clt_group.add_argument( + "--per-target-bias", + action="store_true", + help="Enable learned bias for each src->tgt path when using tied decoders.", + ) + clt_group.add_argument( + "--enable-feature-offset", + action="store_true", + help="Enable per-feature bias (theta_bias) applied after encoding.", + ) + clt_group.add_argument( + "--enable-feature-scale", + action="store_true", + help="Enable per-feature scale (theta_scale) applied after encoding.", + ) + clt_group.add_argument( + "--skip-connection", + action="store_true", + help="Enable skip connection from input to output.", + ) # --- Training Hyperparameters (TrainingConfig) 
--- train_group = parser.add_argument_group("Training Hyperparameters (TrainingConfig)") @@ -281,11 +320,13 @@ def parse_args(): train_group.add_argument( "--normalization-method", type=str, - choices=["auto", "none", "estimated_mean_std"], # Added estimated_mean_std from TrainingConfig - default="auto", + choices=["none", "mean_std", "sqrt_d_model"], + default="mean_std", help=( - "Normalization for activation store. 'auto' expects server/local store to provide stats. " - "'estimated_mean_std' forces estimation (if store supports it). 'none' disables." + "Normalization method for activations. " + "'none': No normalization. " + "'mean_std': Standard (x - mean) / std normalization using pre-calculated stats. " + "'sqrt_d_model': EleutherAI-style x * sqrt(d_model) normalization." ), ) train_group.add_argument( @@ -608,6 +649,13 @@ def main(): clt_dtype=args.clt_dtype, topk_k=args.topk_k, topk_straight_through=(not args.disable_topk_straight_through), + decoder_tying=args.decoder_tying, + per_target_scale=args.per_target_scale, + per_target_bias=args.per_target_bias, + enable_feature_offset=args.enable_feature_offset, + enable_feature_scale=args.enable_feature_scale, + skip_connection=args.skip_connection, + topk_mode=args.topk_mode, ) logger.info(f"CLT Config: {clt_config}") diff --git a/tests/unit/data/test_data_integrity.py b/tests/unit/data/test_data_integrity.py index 9a5f7c9..d76893f 100644 --- a/tests/unit/data/test_data_integrity.py +++ b/tests/unit/data/test_data_integrity.py @@ -193,7 +193,7 @@ def test_normalization_application_correctness(self, tmp_path): store = LocalActivationStore( dataset_path=output_dir, train_batch_size_tokens=100, - normalization_method="standard", # Enable normalization + normalization_method="mean_std", # Enable normalization dtype="float32", device="cpu", ) diff --git a/tests/unit/models/test_tied_decoders.py b/tests/unit/models/test_tied_decoders.py new file mode 100644 index 0000000..9ef6df2 --- /dev/null +++ 
b/tests/unit/models/test_tied_decoders.py @@ -0,0 +1,274 @@ +"""Unit tests for tied decoder functionality in CLT models.""" + +import pytest +import torch +import torch.nn as nn +from typing import Dict + +from clt.config import CLTConfig +from clt.models.clt import CrossLayerTranscoder +from clt.models.decoder import Decoder +from clt.models.encoder import Encoder + + +class TestTiedDecoders: + """Test suite for tied decoder architecture.""" + + @pytest.fixture + def base_config(self): + """Base CLT configuration for testing.""" + return CLTConfig( + num_features=128, + num_layers=4, + d_model=64, + activation_fn="relu", + decoder_tying="none", # Default untied + ) + + @pytest.fixture + def tied_config(self): + """CLT configuration with tied decoders.""" + return CLTConfig( + num_features=128, + num_layers=4, + d_model=64, + activation_fn="relu", + decoder_tying="per_source", + ) + + def test_decoder_initialization_untied(self, base_config): + """Test that untied decoder creates correct number of decoder modules.""" + decoder = Decoder( + config=base_config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Should have decoders for each (src, tgt) pair where src <= tgt + # For 4 layers: 0->0, 0->1, 0->2, 0->3, 1->1, 1->2, 1->3, 2->2, 2->3, 3->3 + # Total: 4 + 3 + 2 + 1 = 10 + expected_decoder_count = sum(range(1, base_config.num_layers + 1)) + assert len(decoder.decoders) == expected_decoder_count + + # Check that all expected keys exist + for src in range(base_config.num_layers): + for tgt in range(src, base_config.num_layers): + assert f"{src}->{tgt}" in decoder.decoders + + def test_decoder_initialization_tied(self, tied_config): + """Test that tied decoder creates one decoder per source layer.""" + decoder = Decoder( + config=tied_config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Should have one decoder per source layer + assert len(decoder.decoders) == tied_config.num_layers + + # 
Check that decoders are indexed by layer + for layer in range(tied_config.num_layers): + assert isinstance(decoder.decoders[layer], nn.Module) + + def test_skip_connections(self, tied_config): + """Test skip connection functionality.""" + # Test with skip connections enabled + config_with_skip = CLTConfig( + **{**tied_config.__dict__, "skip_connection": True} + ) + decoder = Decoder( + config=config_with_skip, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Skip weights should be initialized + assert decoder.skip_weights is not None + assert len(decoder.skip_weights) == config_with_skip.num_layers + + # Each skip weight should have correct shape + for layer_idx in range(config_with_skip.num_layers): + skip_weight = decoder.skip_weights[layer_idx] + assert skip_weight.shape == (config_with_skip.d_model, config_with_skip.d_model) + # Should be initialized to zeros + expected = torch.zeros(config_with_skip.d_model, config_with_skip.d_model, dtype=torch.float32) + assert torch.allclose(skip_weight, expected) + + def test_feature_affine_parameters(self): + """Test feature offset and scale parameters in decoder.""" + config = CLTConfig( + num_features=128, + num_layers=4, + d_model=64, + activation_fn="relu", + enable_feature_offset=True, + enable_feature_scale=True, + decoder_tying="per_source", # Feature affine only works with tied decoders + ) + + decoder = Decoder( + config=config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Check feature_offset initialization + assert decoder.feature_offset is not None + assert len(decoder.feature_offset) == config.num_layers + for layer_idx in range(config.num_layers): + assert decoder.feature_offset[layer_idx].shape == (config.num_features,) + assert torch.allclose(decoder.feature_offset[layer_idx], torch.zeros_like(decoder.feature_offset[layer_idx])) + + # Check feature_scale initialization + assert decoder.feature_scale is not None + assert 
len(decoder.feature_scale) == config.num_layers + for layer_idx in range(config.num_layers): + assert decoder.feature_scale[layer_idx].shape == (config.num_features,) + # First layer should be ones, rest should be 0.1 for tied decoders + if layer_idx == 0: + assert torch.allclose(decoder.feature_scale[layer_idx], torch.ones_like(decoder.feature_scale[layer_idx])) + else: + expected = torch.full_like(decoder.feature_scale[layer_idx], 0.1) + assert torch.allclose(decoder.feature_scale[layer_idx], expected) + + def test_decode_with_tied_decoders(self, tied_config): + """Test decoding with tied decoders.""" + decoder = Decoder( + config=tied_config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Create test activations + batch_size = 8 + activations = { + 0: torch.randn(batch_size, tied_config.num_features), + 1: torch.randn(batch_size, tied_config.num_features), + } + + # Test reconstruction at layer 1 + reconstruction = decoder.decode(activations, layer_idx=1) + + assert reconstruction.shape == (batch_size, tied_config.d_model) + # With zero-initialized decoders (matching reference implementation), + # the output will be zeros initially + assert torch.allclose(reconstruction, torch.zeros_like(reconstruction)) + + # Verify that if we set non-zero weights, we get non-zero outputs + for decoder_module in decoder.decoders: + decoder_module.weight.data.fill_(0.1) + reconstruction2 = decoder.decode(activations, layer_idx=1) + assert not torch.allclose(reconstruction2, torch.zeros_like(reconstruction2)) + + def test_decoder_norms_tied(self, tied_config): + """Test decoder norm computation for tied decoders.""" + decoder = Decoder( + config=tied_config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + norms = decoder.get_decoder_norms() + + # Should have shape [num_layers, num_features] + assert norms.shape == (tied_config.num_layers, tied_config.num_features) + + # Norms should be positive + assert 
torch.all(norms >= 0) + + def test_feature_affine_transformation(self): + """Test feature affine transformation in decoder.""" + config = CLTConfig( + num_features=128, + num_layers=2, + d_model=64, + activation_fn="relu", + enable_feature_offset=True, + enable_feature_scale=True, + decoder_tying="per_source", + ) + + decoder = Decoder( + config=config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Create test activations + batch_size = 4 + test_activations = { + 0: torch.randn(batch_size, config.num_features), + 1: torch.randn(batch_size, config.num_features), + } + + # Set some specific values for testing + decoder.feature_offset[0].data.fill_(0.5) + decoder.feature_scale[0].data.fill_(2.0) + + # Decode at layer 1 (should use features from layers 0 and 1) + result = decoder.decode(test_activations, layer_idx=1) + + # Result should have correct shape + assert result.shape == (batch_size, config.d_model) + + def test_backward_compatibility_config(self): + """Test loading old config without new fields.""" + old_config_dict = { + "num_features": 128, + "num_layers": 4, + "d_model": 64, + "activation_fn": "relu", + # Missing: decoder_tying, enable_feature_offset, enable_feature_scale, skip_connection + } + + # Should not raise an error + config = CLTConfig(**old_config_dict) + + # Should have default values + assert config.decoder_tying == "none" + assert config.enable_feature_offset == False + assert config.enable_feature_scale == False + assert config.skip_connection == False + + def test_checkpoint_compatibility(self, base_config, tied_config): + """Test loading old untied checkpoint into tied model.""" + # Create untied model and save checkpoint + untied_model = CrossLayerTranscoder( + config=base_config, + process_group=None, + device=torch.device("cpu"), + ) + + # Get state dict from untied model + untied_state_dict = untied_model.state_dict() + + # Create tied model + tied_model = CrossLayerTranscoder( + 
config=tied_config, + process_group=None, + device=torch.device("cpu"), + ) + + # Should be able to load with custom logic + tied_model.load_state_dict(untied_state_dict, strict=False) + + # Tied model should have loaded the diagonal decoder weights + for src_layer in range(tied_config.num_layers): + tied_weight = tied_model.decoder_module.decoders[src_layer].weight + untied_key = f"decoder_module.decoders.{src_layer}->{src_layer}.weight" + if untied_key in untied_state_dict: + untied_weight = untied_state_dict[untied_key] + # Shapes might differ due to RowParallelLinear, so just check they're both tensors + assert isinstance(tied_weight, torch.Tensor) + assert isinstance(untied_weight, torch.Tensor) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tutorials/1B-end-to-end-training-pythia-batchtopk.py b/tutorials/1B-end-to-end-training-pythia-batchtopk.py index 5ef0177..248c03a 100644 --- a/tutorials/1B-end-to-end-training-pythia-batchtopk.py +++ b/tutorials/1B-end-to-end-training-pythia-batchtopk.py @@ -175,7 +175,7 @@ train_batch_size_tokens=_batch_size, sampling_strategy="sequential", # Normalization - normalization_method="auto", # Use pre-calculated stats + normalization_method="mean_std", # Use pre-calculated stats # Loss function coefficients sparsity_lambda=0.0, # Disable standard sparsity penalty sparsity_lambda_schedule="linear", @@ -194,7 +194,6 @@ max_features_for_diag_hist=1000, # optional cap per layer checkpoint_interval=500, dead_feature_window=200, - p # WandB (Optional) enable_wandb=True, wandb_project="clt-hp-sweeps-pythia-70m", diff --git a/tutorials/1F-end-to-end-training-pythia-tied-decoders copy.py b/tutorials/1F-end-to-end-training-pythia-tied-decoders copy.py new file mode 100644 index 0000000..bd1c9b1 --- /dev/null +++ b/tutorials/1F-end-to-end-training-pythia-tied-decoders copy.py @@ -0,0 +1,421 @@ +# %% [markdown] +# # Tutorial: End-to-End CLT Training with Tied Decoders and Feature 
Offset +# +# This tutorial demonstrates training a Cross-Layer Transcoder (CLT) using: +# - **Tied decoder architecture** to reduce memory usage +# - **Feature offset parameters** for per-feature bias +# - **BatchTopK activation** (same as Tutorial 1B) +# +# The tied decoder architecture uses one decoder per source layer (instead of one per source-target pair), +# significantly reducing memory usage from O(L²) to O(L) decoder parameters. +# +# We will: +# 1. Configure the CLT model with tied decoders and feature offset +# 2. Use the same pre-generated activations from Tutorial 1B +# 3. Train the model and compare memory usage +# 4. Demonstrate loading checkpoints with the new architecture + +# %% [markdown] +# ## 1. Imports and Setup + +# %% +import torch +import os +import time +import sys +import traceback +import json +from torch.distributed.checkpoint import load_state_dict as dist_load_state_dict +from torch.distributed.checkpoint.filesystem import FileSystemReader +from typing import Optional, Dict +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s") + +# Ensure tokenizers don't use parallelism +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Add project root to path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +try: + from clt.config import CLTConfig, TrainingConfig, ActivationConfig + from clt.activation_generation.generator import ActivationGenerator + from clt.training.trainer import CLTTrainer + from clt.models.clt import CrossLayerTranscoder + from clt.training.data import BaseActivationStore +except ImportError as e: + print(f"ImportError: {e}") + print("Please ensure the 'clt' library is installed or the clt directory is in your PYTHONPATH.") + raise + +# Device setup +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + 
device = "mps" +else: + device = "cpu" + +print(f"Using device: {device}") + +# Base model for activation extraction (same as Tutorial 1B) +BASE_MODEL_NAME = "EleutherAI/pythia-70m" + +# %% [markdown] +# ## 2. Configuration with Tied Decoders +# +# Key differences from Tutorial 1B: +# - `decoder_tying="per_source"` - Enables tied decoder architecture +# - `enable_feature_offset=True` - Adds learnable per-feature bias +# - Memory savings: For 6 layers, we go from 21 decoders to just 6 + +# %% +# --- CLT Architecture Configuration with Tied Decoders --- +num_layers = 6 +d_model = 512 +expansion_factor = 32 +clt_num_features = d_model * expansion_factor + +batchtopk_k = 200 + +clt_config = CLTConfig( + num_features=clt_num_features, + num_layers=num_layers, + d_model=d_model, + activation_fn="batchtopk", + batchtopk_k=batchtopk_k, + batchtopk_straight_through=True, + # NEW: Tied decoder configuration + decoder_tying="per_source", # Use one decoder per source layer + enable_feature_offset=True, # Enable per-feature bias (feature_offset) + enable_feature_scale=False, # Enable per-feature scale (feature_scale) + skip_connection=True, # Enable skip connection from input to output +) + +print("CLT Configuration (Tied Decoders with Feature Affine):") +print(f"- decoder_tying: {clt_config.decoder_tying}") +print(f"- enable_feature_offset: {clt_config.enable_feature_offset}") +print(f"- enable_feature_scale: {clt_config.enable_feature_scale}") +print(f"- skip_connection: {clt_config.skip_connection}") +print(f"- Number of features: {clt_config.num_features}") +print(f"- Number of layers: {clt_config.num_layers}") +print(f"- Activation function: {clt_config.activation_fn}") +print(f"- BatchTopK k: {clt_config.batchtopk_k}") + +# Calculate memory savings +untied_decoders = sum(range(1, num_layers + 1)) # 6 + 5 + 4 + 3 + 2 + 1 = 21 +tied_decoders = num_layers # 6 +print(f"\nMemory savings:") +print(f"- Untied decoders: {untied_decoders} decoder matrices") +print(f"- Tied 
decoders: {tied_decoders} decoder matrices") +print(f"- Reduction: {(1 - tied_decoders/untied_decoders)*100:.1f}%") + +# --- Use existing activations from Tutorial 1B --- +# We'll use the same activation directory as Tutorial 1B since the base model +# and dataset are identical - only the CLT architecture differs +activation_dir = "./tutorial_activations_local_1M_pythia" +dataset_name = "monology/pile-uncopyrighted" + +expected_activation_path = os.path.join( + activation_dir, + BASE_MODEL_NAME, + f"{os.path.basename(dataset_name)}_train", +) + +# Verify activations exist +metadata_path = os.path.join(expected_activation_path, "metadata.json") +manifest_path = os.path.join(expected_activation_path, "index.bin") + +if not (os.path.exists(metadata_path) and os.path.exists(manifest_path)): + print(f"\nERROR: Activations not found at {expected_activation_path}") + print("Please run Tutorial 1B first to generate the activations.") + raise FileNotFoundError("Activation dataset not found") +else: + print(f"\nUsing existing activations from: {expected_activation_path}") + +# --- Training Configuration --- +_lr = 1e-4 +_batch_size = 1024 + +# WandB run name includes tied decoder info +wdb_run_name = ( + f"{clt_config.num_features}-width-" + f"tied-decoders-" # Indicate tied decoder architecture + f"feat-offset-" # Indicate feature offset is enabled + f"batchtopk-k{batchtopk_k}-" + f"{_batch_size}-batch-" + f"{_lr:.1e}-lr" +) +print(f"\nGenerated WandB run name: {wdb_run_name}") + +training_config = TrainingConfig( + # Training loop parameters + learning_rate=_lr, + training_steps=1000, # Same as Tutorial 1B for comparison + seed=42, + # Activation source (using existing activations) + activation_source="local_manifest", + activation_path=expected_activation_path, + activation_dtype="float32", + # Training batch size + train_batch_size_tokens=_batch_size, + sampling_strategy="sequential", + # Normalization + normalization_method="sqrt_d_model", + # Loss function coefficients 
(same as Tutorial 1B) + sparsity_lambda=0.0, + sparsity_lambda_schedule="linear", + sparsity_c=0.0, + preactivation_coef=0, + aux_loss_factor=1 / 32, + apply_sparsity_penalty_to_batchtopk=False, + # Optimizer & Scheduler + optimizer="adamw", + lr_scheduler="linear_final20", + optimizer_beta2=0.98, + # Logging & Checkpointing + log_interval=10, + eval_interval=50, + diag_every_n_eval_steps=1, + max_features_for_diag_hist=1000, + checkpoint_interval=500, + dead_feature_window=200, + # WandB + enable_wandb=True, + wandb_project="clt-debug-pythia-70m", + wandb_run_name=wdb_run_name, +) + +print("\nTraining Configuration:") +print(f"- Learning rate: {training_config.learning_rate}") +print(f"- Training steps: {training_config.training_steps}") +print(f"- Batch size (tokens): {training_config.train_batch_size_tokens}") + +# %% [markdown] +# ## 3. Initialize Model and Check Architecture +# +# Let's create the model and verify the tied decoder architecture is set up correctly. + +# %% +print("\nInitializing CLT model with tied decoders...") + +# Create model instance to inspect architecture +model = CrossLayerTranscoder( + config=clt_config, + process_group=None, + device=torch.device(device), +) + +print("\nModel architecture inspection:") +print(f"- Encoder modules: {len(model.encoder_module.encoders)}") +print(f"- Decoder modules: {len(model.decoder_module.decoders)}") + +# Check feature offset parameters +if model.decoder_module.feature_offset is not None: + print(f"- Feature offset parameters per layer: {len(model.decoder_module.feature_offset)}") + print(f"- Feature offset shape (layer 0): {model.decoder_module.feature_offset[0].shape}") +else: + print("- Feature offset: Not enabled") + +# Count total parameters +total_params = sum(p.numel() for p in model.parameters()) +encoder_params = sum(p.numel() for p in model.encoder_module.parameters()) +decoder_params = sum(p.numel() for p in model.decoder_module.parameters()) +print(f"\nParameter counts:") +print(f"- Total 
parameters: {total_params:,}") +print(f"- Encoder parameters: {encoder_params:,}") +print(f"- Decoder parameters: {decoder_params:,}") + +# Compare with untied architecture (approximate) +untied_decoder_params_approx = decoder_params * (untied_decoders / tied_decoders) +print(f"\nEstimated decoder parameters if untied: {untied_decoder_params_approx:,}") +print(f"Memory savings in decoder: {(1 - decoder_params/untied_decoder_params_approx)*100:.1f}%") + +# Clean up the test model +del model + +# %% [markdown] +# ## 4. Training the CLT with Tied Decoders + +# %% +print("\nInitializing CLTTrainer for training with tied decoders...") + +log_dir = f"clt_training_logs/clt_pythia_tied_decoders_{int(time.time())}" +os.makedirs(log_dir, exist_ok=True) +print(f"Logs and checkpoints will be saved to: {log_dir}") + +try: + print("\nCreating CLTTrainer instance...") + trainer = CLTTrainer( + clt_config=clt_config, + training_config=training_config, + log_dir=log_dir, + device=device, + distributed=False, + ) + print("CLTTrainer instance created successfully.") +except Exception as e: + print(f"[ERROR] Failed to initialize CLTTrainer: {e}") + traceback.print_exc() + raise + +# Start training +print("\nBeginning training with tied decoders...") +print(f"Training for {training_config.training_steps} steps.") +print(f"Decoder tying: {clt_config.decoder_tying}") +print(f"Feature offset enabled: {clt_config.enable_feature_offset}") + +try: + start_train_time = time.time() + trained_clt_model = trainer.train(eval_every=training_config.eval_interval) + end_train_time = time.time() + print(f"\nTraining finished in {end_train_time - start_train_time:.2f} seconds.") +except Exception as train_err: + print(f"[ERROR] Training failed: {train_err}") + traceback.print_exc() + raise + +# %% [markdown] +# ## 5. 
Saving and Loading the Tied Decoder Model + +# %% +# Save the final model state and config +final_model_state_path = os.path.join(log_dir, "clt_tied_final_state.pt") +final_model_config_path = os.path.join(log_dir, "clt_tied_final_config.json") + +print(f"\nSaving final model state to: {final_model_state_path}") +print(f"Saving final model config to: {final_model_config_path}") + +torch.save(trained_clt_model.state_dict(), final_model_state_path) +with open(final_model_config_path, "w") as f: + json.dump(trained_clt_model.config.__dict__, f, indent=4) + +# Verify the saved config has tied decoder settings +with open(final_model_config_path, "r") as f: + saved_config = json.load(f) + print(f"\nSaved config verification:") + print(f"- decoder_tying: {saved_config['decoder_tying']}") + print(f"- enable_feature_offset: {saved_config['enable_feature_offset']}") + print(f"- activation_fn: {saved_config['activation_fn']} (converted from batchtopk)") + +# Load the model back +print("\nLoading the saved tied decoder model...") +loaded_config = CLTConfig(**saved_config) +loaded_model = CrossLayerTranscoder( + config=loaded_config, + process_group=None, + device=torch.device(device), +) +loaded_model.load_state_dict(torch.load(final_model_state_path, map_location=device)) +loaded_model.eval() + +print("Model loaded successfully.") +print(f"Loaded model decoder count: {len(loaded_model.decoder_module.decoders)}") + +# %% [markdown] +# ## 6. Backward Compatibility Test +# +# Test loading an old untied checkpoint into our tied decoder model. +# This demonstrates the backward compatibility feature. 
+ +# %% +print("\n=== Testing Backward Compatibility ===") + +# Create a simple untied model for testing +untied_config = CLTConfig( + num_features=clt_config.num_features, + num_layers=clt_config.num_layers, + d_model=clt_config.d_model, + activation_fn="relu", # Simple activation for testing + decoder_tying="none", # Untied decoders +) + +print("Creating untied model for compatibility test...") +untied_model = CrossLayerTranscoder( + config=untied_config, + process_group=None, + device=torch.device("cpu"), # Use CPU for this test +) + +# Save untied model state +untied_state_dict = untied_model.state_dict() +print(f"Untied model decoder keys (first 5): {list(k for k in untied_state_dict.keys() if 'decoder' in k)[:5]}") + +# Create tied model with same dimensions +tied_test_config = CLTConfig( + num_features=clt_config.num_features, + num_layers=clt_config.num_layers, + d_model=clt_config.d_model, + activation_fn="relu", + decoder_tying="per_source", # Tied decoders + enable_feature_offset=True, # This will be initialized to defaults +) + +tied_test_model = CrossLayerTranscoder( + config=tied_test_config, + process_group=None, + device=torch.device("cpu"), +) + +print("\nLoading untied checkpoint into tied model...") +try: + # This should work due to our custom load_state_dict + tied_test_model.load_state_dict(untied_state_dict, strict=False) + print("✓ Successfully loaded untied checkpoint into tied model!") + print(" The tied model uses diagonal decoder weights from the untied model.") +except Exception as e: + print(f"✗ Failed to load: {e}") + +# Clean up test models +del untied_model, tied_test_model + +# %% [markdown] +# ## 7. 
Performance Comparison Summary + +# %% +print("\n=== Tied Decoder Architecture Summary ===") +print(f"\nConfiguration used:") +print(f"- Model: {BASE_MODEL_NAME}") +print(f"- Layers: {num_layers}") +print(f"- Hidden dimension: {d_model}") +print(f"- Features per layer: {clt_num_features}") +print(f"- Decoder tying: {clt_config.decoder_tying}") +print(f"- Feature offset: {clt_config.enable_feature_offset}") + +print(f"\nMemory efficiency:") +print(f"- Traditional CLT: {untied_decoders} decoder matrices") +print(f"- Tied decoder CLT: {tied_decoders} decoder matrices") +print(f"- Memory reduction: ~{(1 - tied_decoders/untied_decoders)*100:.0f}%") + +print(f"\nKey benefits:") +print(f"1. Significant memory savings for decoder parameters") +print(f"2. Simpler feature interpretability (one decoder per source)") +print(f"3. Feature offset allows per-feature adaptation") +print(f"4. Backward compatible with existing checkpoints") + +print(f"\nTrade-offs:") +print(f"1. Less flexibility in source-target specific adaptations") +print(f"2. May require careful tuning of feature offset parameters") + +# %% [markdown] +# ## 8. 
Next Steps +# +# This tutorial demonstrated: +# - Training a CLT with tied decoder architecture +# - Using feature offset parameters for per-feature bias +# - Significant memory savings compared to traditional CLT +# - Backward compatibility with untied checkpoints +# +# You can experiment with: +# - `per_target_scale` and `per_target_bias` for more flexibility +# - `enable_feature_scale` for per-feature scaling +# - Different values of `k` for BatchTopK +# - Comparing reconstruction quality between tied and untied architectures + +# %% +print(f"\n✓ Tied Decoder Tutorial Complete!") +print(f"Model and logs saved to: {log_dir}") diff --git a/tutorials/1F-end-to-end-training-pythia-tied-decoders.py b/tutorials/1F-end-to-end-training-pythia-tied-decoders.py new file mode 100644 index 0000000..d5a6c1a --- /dev/null +++ b/tutorials/1F-end-to-end-training-pythia-tied-decoders.py @@ -0,0 +1,421 @@ +# %% [markdown] +# # Tutorial: End-to-End CLT Training with Tied Decoders and Feature Offset +# +# This tutorial demonstrates training a Cross-Layer Transcoder (CLT) using: +# - **Tied decoder architecture** to reduce memory usage +# - **Feature offset parameters** for per-feature bias +# - **BatchTopK activation** (same as Tutorial 1B) +# +# The tied decoder architecture uses one decoder per source layer (instead of one per source-target pair), +# significantly reducing memory usage from O(L²) to O(L) decoder parameters. +# +# We will: +# 1. Configure the CLT model with tied decoders and feature offset +# 2. Use the same pre-generated activations from Tutorial 1B +# 3. Train the model and compare memory usage +# 4. Demonstrate loading checkpoints with the new architecture + +# %% [markdown] +# ## 1. 
Imports and Setup + +# %% +import torch +import os +import time +import sys +import traceback +import json +from torch.distributed.checkpoint import load_state_dict as dist_load_state_dict +from torch.distributed.checkpoint.filesystem import FileSystemReader +from typing import Optional, Dict +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s") + +# Ensure tokenizers don't use parallelism +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Add project root to path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +try: + from clt.config import CLTConfig, TrainingConfig, ActivationConfig + from clt.activation_generation.generator import ActivationGenerator + from clt.training.trainer import CLTTrainer + from clt.models.clt import CrossLayerTranscoder + from clt.training.data import BaseActivationStore +except ImportError as e: + print(f"ImportError: {e}") + print("Please ensure the 'clt' library is installed or the clt directory is in your PYTHONPATH.") + raise + +# Device setup +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" + +print(f"Using device: {device}") + +# Base model for activation extraction (same as Tutorial 1B) +BASE_MODEL_NAME = "EleutherAI/pythia-70m" + +# %% [markdown] +# ## 2. 
Configuration with Tied Decoders +# +# Key differences from Tutorial 1B: +# - `decoder_tying="per_source"` - Enables tied decoder architecture +# - `enable_feature_offset=True` - Adds learnable per-feature bias +# - Memory savings: For 6 layers, we go from 21 decoders to just 6 + +# %% +# --- CLT Architecture Configuration with Tied Decoders --- +num_layers = 6 +d_model = 512 +expansion_factor = 32 +clt_num_features = d_model * expansion_factor + +batchtopk_k = 200 + +clt_config = CLTConfig( + num_features=clt_num_features, + num_layers=num_layers, + d_model=d_model, + activation_fn="batchtopk", + batchtopk_k=batchtopk_k, + batchtopk_straight_through=True, + # NEW: Tied decoder configuration + decoder_tying="per_source", # Use one decoder per source layer + enable_feature_offset=True, # Enable per-feature bias (feature_offset) + enable_feature_scale=False, # Enable per-feature scale (feature_scale) + skip_connection=True, # Enable skip connection from input to output +) + +print("CLT Configuration (Tied Decoders with Feature Affine):") +print(f"- decoder_tying: {clt_config.decoder_tying}") +print(f"- enable_feature_offset: {clt_config.enable_feature_offset}") +print(f"- enable_feature_scale: {clt_config.enable_feature_scale}") +print(f"- skip_connection: {clt_config.skip_connection}") +print(f"- Number of features: {clt_config.num_features}") +print(f"- Number of layers: {clt_config.num_layers}") +print(f"- Activation function: {clt_config.activation_fn}") +print(f"- BatchTopK k: {clt_config.batchtopk_k}") + +# Calculate memory savings +untied_decoders = sum(range(1, num_layers + 1)) # 6 + 5 + 4 + 3 + 2 + 1 = 21 +tied_decoders = num_layers # 6 +print(f"\nMemory savings:") +print(f"- Untied decoders: {untied_decoders} decoder matrices") +print(f"- Tied decoders: {tied_decoders} decoder matrices") +print(f"- Reduction: {(1 - tied_decoders/untied_decoders)*100:.1f}%") + +# --- Use existing activations from Tutorial 1B --- +# We'll use the same activation directory as 
Tutorial 1B since the base model +# and dataset are identical - only the CLT architecture differs +activation_dir = "./tutorial_activations_local_1M_pythia" +dataset_name = "monology/pile-uncopyrighted" + +expected_activation_path = os.path.join( + activation_dir, + BASE_MODEL_NAME, + f"{os.path.basename(dataset_name)}_train", +) + +# Verify activations exist +metadata_path = os.path.join(expected_activation_path, "metadata.json") +manifest_path = os.path.join(expected_activation_path, "index.bin") + +if not (os.path.exists(metadata_path) and os.path.exists(manifest_path)): + print(f"\nERROR: Activations not found at {expected_activation_path}") + print("Please run Tutorial 1B first to generate the activations.") + raise FileNotFoundError("Activation dataset not found") +else: + print(f"\nUsing existing activations from: {expected_activation_path}") + +# --- Training Configuration --- +_lr = 1e-4 +_batch_size = 1024 + +# WandB run name includes tied decoder info +wdb_run_name = ( + f"{clt_config.num_features}-width-" + f"tied-decoders-" # Indicate tied decoder architecture + f"feat-offset-" # Indicate feature offset is enabled + f"batchtopk-k{batchtopk_k}-" + f"{_batch_size}-batch-" + f"{_lr:.1e}-lr" +) +print(f"\nGenerated WandB run name: {wdb_run_name}") + +training_config = TrainingConfig( + # Training loop parameters + learning_rate=_lr, + training_steps=1000, # Same as Tutorial 1B for comparison + seed=42, + # Activation source (using existing activations) + activation_source="local_manifest", + activation_path=expected_activation_path, + activation_dtype="float32", + # Training batch size + train_batch_size_tokens=_batch_size, + sampling_strategy="sequential", + # Normalization + normalization_method="none", + # Loss function coefficients (same as Tutorial 1B) + sparsity_lambda=0.0, + sparsity_lambda_schedule="linear", + sparsity_c=0.0, + preactivation_coef=0, + aux_loss_factor=1 / 32, + apply_sparsity_penalty_to_batchtopk=False, + # Optimizer & Scheduler + 
optimizer="adamw", + lr_scheduler="linear_final20", + optimizer_beta2=0.98, + # Logging & Checkpointing + log_interval=10, + eval_interval=50, + diag_every_n_eval_steps=1, + max_features_for_diag_hist=1000, + checkpoint_interval=500, + dead_feature_window=200, + # WandB + enable_wandb=True, + wandb_project="clt-debug-pythia-70m", + wandb_run_name=wdb_run_name, +) + +print("\nTraining Configuration:") +print(f"- Learning rate: {training_config.learning_rate}") +print(f"- Training steps: {training_config.training_steps}") +print(f"- Batch size (tokens): {training_config.train_batch_size_tokens}") + +# %% [markdown] +# ## 3. Initialize Model and Check Architecture +# +# Let's create the model and verify the tied decoder architecture is set up correctly. + +# %% +print("\nInitializing CLT model with tied decoders...") + +# Create model instance to inspect architecture +model = CrossLayerTranscoder( + config=clt_config, + process_group=None, + device=torch.device(device), +) + +print("\nModel architecture inspection:") +print(f"- Encoder modules: {len(model.encoder_module.encoders)}") +print(f"- Decoder modules: {len(model.decoder_module.decoders)}") + +# Check feature offset parameters +if model.decoder_module.feature_offset is not None: + print(f"- Feature offset parameters per layer: {len(model.decoder_module.feature_offset)}") + print(f"- Feature offset shape (layer 0): {model.decoder_module.feature_offset[0].shape}") +else: + print("- Feature offset: Not enabled") + +# Count total parameters +total_params = sum(p.numel() for p in model.parameters()) +encoder_params = sum(p.numel() for p in model.encoder_module.parameters()) +decoder_params = sum(p.numel() for p in model.decoder_module.parameters()) +print(f"\nParameter counts:") +print(f"- Total parameters: {total_params:,}") +print(f"- Encoder parameters: {encoder_params:,}") +print(f"- Decoder parameters: {decoder_params:,}") + +# Compare with untied architecture (approximate) +untied_decoder_params_approx = 
decoder_params * (untied_decoders / tied_decoders) +print(f"\nEstimated decoder parameters if untied: {untied_decoder_params_approx:,}") +print(f"Memory savings in decoder: {(1 - decoder_params/untied_decoder_params_approx)*100:.1f}%") + +# Clean up the test model +del model + +# %% [markdown] +# ## 4. Training the CLT with Tied Decoders + +# %% +print("\nInitializing CLTTrainer for training with tied decoders...") + +log_dir = f"clt_training_logs/clt_pythia_tied_decoders_{int(time.time())}" +os.makedirs(log_dir, exist_ok=True) +print(f"Logs and checkpoints will be saved to: {log_dir}") + +try: + print("\nCreating CLTTrainer instance...") + trainer = CLTTrainer( + clt_config=clt_config, + training_config=training_config, + log_dir=log_dir, + device=device, + distributed=False, + ) + print("CLTTrainer instance created successfully.") +except Exception as e: + print(f"[ERROR] Failed to initialize CLTTrainer: {e}") + traceback.print_exc() + raise + +# Start training +print("\nBeginning training with tied decoders...") +print(f"Training for {training_config.training_steps} steps.") +print(f"Decoder tying: {clt_config.decoder_tying}") +print(f"Feature offset enabled: {clt_config.enable_feature_offset}") + +try: + start_train_time = time.time() + trained_clt_model = trainer.train(eval_every=training_config.eval_interval) + end_train_time = time.time() + print(f"\nTraining finished in {end_train_time - start_train_time:.2f} seconds.") +except Exception as train_err: + print(f"[ERROR] Training failed: {train_err}") + traceback.print_exc() + raise + +# %% [markdown] +# ## 5. 
Saving and Loading the Tied Decoder Model + +# %% +# Save the final model state and config +final_model_state_path = os.path.join(log_dir, "clt_tied_final_state.pt") +final_model_config_path = os.path.join(log_dir, "clt_tied_final_config.json") + +print(f"\nSaving final model state to: {final_model_state_path}") +print(f"Saving final model config to: {final_model_config_path}") + +torch.save(trained_clt_model.state_dict(), final_model_state_path) +with open(final_model_config_path, "w") as f: + json.dump(trained_clt_model.config.__dict__, f, indent=4) + +# Verify the saved config has tied decoder settings +with open(final_model_config_path, "r") as f: + saved_config = json.load(f) + print(f"\nSaved config verification:") + print(f"- decoder_tying: {saved_config['decoder_tying']}") + print(f"- enable_feature_offset: {saved_config['enable_feature_offset']}") + print(f"- activation_fn: {saved_config['activation_fn']} (converted from batchtopk)") + +# Load the model back +print("\nLoading the saved tied decoder model...") +loaded_config = CLTConfig(**saved_config) +loaded_model = CrossLayerTranscoder( + config=loaded_config, + process_group=None, + device=torch.device(device), +) +loaded_model.load_state_dict(torch.load(final_model_state_path, map_location=device)) +loaded_model.eval() + +print("Model loaded successfully.") +print(f"Loaded model decoder count: {len(loaded_model.decoder_module.decoders)}") + +# %% [markdown] +# ## 6. Backward Compatibility Test +# +# Test loading an old untied checkpoint into our tied decoder model. +# This demonstrates the backward compatibility feature. 
+ +# %% +print("\n=== Testing Backward Compatibility ===") + +# Create a simple untied model for testing +untied_config = CLTConfig( + num_features=clt_config.num_features, + num_layers=clt_config.num_layers, + d_model=clt_config.d_model, + activation_fn="relu", # Simple activation for testing + decoder_tying="none", # Untied decoders +) + +print("Creating untied model for compatibility test...") +untied_model = CrossLayerTranscoder( + config=untied_config, + process_group=None, + device=torch.device("cpu"), # Use CPU for this test +) + +# Save untied model state +untied_state_dict = untied_model.state_dict() +print(f"Untied model decoder keys (first 5): {list(k for k in untied_state_dict.keys() if 'decoder' in k)[:5]}") + +# Create tied model with same dimensions +tied_test_config = CLTConfig( + num_features=clt_config.num_features, + num_layers=clt_config.num_layers, + d_model=clt_config.d_model, + activation_fn="relu", + decoder_tying="per_source", # Tied decoders + enable_feature_offset=True, # This will be initialized to defaults +) + +tied_test_model = CrossLayerTranscoder( + config=tied_test_config, + process_group=None, + device=torch.device("cpu"), +) + +print("\nLoading untied checkpoint into tied model...") +try: + # This should work due to our custom load_state_dict + tied_test_model.load_state_dict(untied_state_dict, strict=False) + print("✓ Successfully loaded untied checkpoint into tied model!") + print(" The tied model uses diagonal decoder weights from the untied model.") +except Exception as e: + print(f"✗ Failed to load: {e}") + +# Clean up test models +del untied_model, tied_test_model + +# %% [markdown] +# ## 7. 
Performance Comparison Summary + +# %% +print("\n=== Tied Decoder Architecture Summary ===") +print(f"\nConfiguration used:") +print(f"- Model: {BASE_MODEL_NAME}") +print(f"- Layers: {num_layers}") +print(f"- Hidden dimension: {d_model}") +print(f"- Features per layer: {clt_num_features}") +print(f"- Decoder tying: {clt_config.decoder_tying}") +print(f"- Feature offset: {clt_config.enable_feature_offset}") + +print(f"\nMemory efficiency:") +print(f"- Traditional CLT: {untied_decoders} decoder matrices") +print(f"- Tied decoder CLT: {tied_decoders} decoder matrices") +print(f"- Memory reduction: ~{(1 - tied_decoders/untied_decoders)*100:.0f}%") + +print(f"\nKey benefits:") +print(f"1. Significant memory savings for decoder parameters") +print(f"2. Simpler feature interpretability (one decoder per source)") +print(f"3. Feature offset allows per-feature adaptation") +print(f"4. Backward compatible with existing checkpoints") + +print(f"\nTrade-offs:") +print(f"1. Less flexibility in source-target specific adaptations") +print(f"2. May require careful tuning of feature offset parameters") + +# %% [markdown] +# ## 8. 
Next Steps +# +# This tutorial demonstrated: +# - Training a CLT with tied decoder architecture +# - Using feature offset parameters for per-feature bias +# - Significant memory savings compared to traditional CLT +# - Backward compatibility with untied checkpoints +# +# You can experiment with: +# - `per_target_scale` and `per_target_bias` for more flexibility +# - `enable_feature_scale` for per-feature scaling +# - Different values of `k` for BatchTopK +# - Comparing reconstruction quality between tied and untied architectures + +# %% +print(f"\n✓ Tied Decoder Tutorial Complete!") +print(f"Model and logs saved to: {log_dir}") diff --git a/tutorials/1G-end-to-end-training-gpt2-batchtopk-fp16.py b/tutorials/1G-end-to-end-training-gpt2-batchtopk-fp16.py new file mode 100644 index 0000000..f563323 --- /dev/null +++ b/tutorials/1G-end-to-end-training-gpt2-batchtopk-fp16.py @@ -0,0 +1,594 @@ +# %% [markdown] +# # Tutorial: End-to-End CLT Training with GPT-2, BatchTopK, and FP16 +# +# This tutorial demonstrates training a Cross-Layer Transcoder (CLT) +# on **GPT-2** using the **BatchTopK** activation function and **FP16** precision. We will: +# 1. Configure the CLT model for GPT-2, BatchTopK, and FP16 training. +# 2. Generate FP16 activations locally (with manifest) using the ActivationGenerator. +# 3. Configure the trainer to use the locally stored FP16 activations. +# 4. Train the CLT model using BatchTopK activation in mixed precision. +# 5. Save and load the final trained model (which will be JumpReLU if converted). +# 6. Load a model from a distributed checkpoint. +# 7. Perform a post-hoc conversion sweep (θ scaling) on a BatchTopK checkpoint. + +# %% [markdown] +# ## 1. Imports and Setup +# +# First, let's import the necessary components and set up the device. 
+ +# %% +import torch +import os +import time +import sys +import traceback +import json +from torch.distributions.normal import Normal # For post-hoc sweep +from torch.distributed.checkpoint import load_state_dict as dist_load_state_dict +from torch.distributed.checkpoint.filesystem import FileSystemReader +from typing import Optional, Dict +import logging # Import logging + +# Configure logging to show INFO level messages for the notebook +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s") + +# Import from torch.distributed.checkpoint and related modules later, only when needed for that specific section +# from torch.distributed.checkpoint import load_state_dict +# from torch.distributed.checkpoint.filesystem import FileSystemReader + +# logging.basicConfig(level=logging.DEBUG) + +# Import components from the clt library +# (Ensure the 'clt' directory is in your Python path or installed) +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +try: + from clt.config import CLTConfig, TrainingConfig, ActivationConfig + from clt.activation_generation.generator import ActivationGenerator + from clt.training.trainer import CLTTrainer + from clt.models.clt import CrossLayerTranscoder + from clt.training.data import BaseActivationStore +except ImportError as e: + print(f"ImportError: {e}") + print("Please ensure the 'clt' library is installed or the clt directory is in your PYTHONPATH.") + raise + +# Device setup +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" + +print(f"Using device: {device}") + +# Base model for activation extraction +BASE_MODEL_NAME = "gpt2" + +# For post-hoc sweep N(0,1) assumption +std_normal = Normal(0, 1) + +# %% [markdown] +# ## 2. 
Configuration +# +# We configure the CLT, Activation Generation, and Training for GPT-2 with FP16. +# Key changes: `CLTConfig` matches GPT-2 dims, `ActivationConfig` and `TrainingConfig` use FP16. + +# %% +# --- CLT Architecture Configuration --- +num_layers = 12 # GPT-2 small +d_model = 768 # GPT-2 small +expansion_factor = 32 +clt_num_features = 16384 # d_model * expansion_factor + +batchtopk_k = 200 + +clt_config = CLTConfig( + num_features=clt_num_features, + num_layers=num_layers, + d_model=d_model, + activation_fn="batchtopk", # Use BatchTopK activation + batchtopk_k=batchtopk_k, # Specify k directly + batchtopk_straight_through=True, # Use STE for gradients + # jumprelu_threshold is not used for batchtopk +) +print("CLT Configuration (BatchTopK for GPT-2):") +print(clt_config) + +# --- Activation Generation Configuration --- +# Generate FP16 activations from GPT-2 +activation_dir = "./tutorial_activations_local_1M_fp16" +dataset_name = "monology/pile-uncopyrighted" +activation_config = ActivationConfig( + # Model Source + model_name=BASE_MODEL_NAME, + mlp_input_module_path_template="transformer.h.{}.ln_2.input", + mlp_output_module_path_template="transformer.h.{}.mlp.output", + model_dtype=None, + # Dataset Source + dataset_path=dataset_name, + dataset_split="train", + dataset_text_column="text", + # Generation Parameters + context_size=128, + inference_batch_size=192, + exclude_special_tokens=True, + prepend_bos=True, + # Dataset Handling + streaming=True, + dataset_trust_remote_code=False, + cache_path=None, + # Generation Output Control + target_total_tokens=1_000_000, # Keep it small for tutorial + # Storage Parameters + activation_dir=activation_dir, + output_format="hdf5", + compression="gzip", + chunk_token_threshold=16_000, + activation_dtype="float16", # Store activations in FP16 + # Normalization + compute_norm_stats=True, + # NNsight args + nnsight_tracer_kwargs={}, + nnsight_invoker_args={}, +) +print("Activation Generation Configuration:") 
+print(activation_config) + +# --- Training Configuration --- +expected_activation_path = os.path.join( + activation_config.activation_dir, + activation_config.model_name, + f"{os.path.basename(activation_config.dataset_path)}_{activation_config.dataset_split}", +) + +# --- Determine WandB Run Name (using config values) --- +_lr = 1e-4 +_batch_size = 1024 +_k_int = clt_config.batchtopk_k + +wdb_run_name = ( + f"gpt2-{clt_config.num_features}-width-" f"batchtopk-k{_k_int}-" f"{_batch_size}-batch-" f"{_lr:.1e}-lr-fp16" +) +print("\nGenerated WandB run name: " + wdb_run_name) + +training_config = TrainingConfig( + # Training loop parameters + learning_rate=_lr, + training_steps=1000, # Reduced steps for tutorial + seed=42, + # Activation source + activation_source="local_manifest", + activation_path=expected_activation_path, + activation_dtype="float16", # Load activations in FP16 + # Training batch size + train_batch_size_tokens=_batch_size, + sampling_strategy="sequential", + precision="fp16", # Enable mixed-precision training + # Normalization + normalization_method="mean_std", # Use pre-calculated stats + # Loss function coefficients + sparsity_lambda=0.0, # Disable standard sparsity penalty + sparsity_lambda_schedule="linear", + sparsity_c=0.0, # Disable standard sparsity penalty + preactivation_coef=0, # Disable preactivation loss (AuxK handles dead latents) + aux_loss_factor=1 / 32, # Enable AuxK loss with typical factor from paper + apply_sparsity_penalty_to_batchtopk=False, # Ensure standard sparsity penalty is off for BatchTopK + # Optimizer & Scheduler + optimizer="adamw", + lr_scheduler="linear_final20", + optimizer_beta2=0.98, + # Logging & Checkpointing + log_interval=10, + eval_interval=50, + diag_every_n_eval_steps=1, # run diagnostics every eval + max_features_for_diag_hist=1000, # optional cap per layer + checkpoint_interval=500, + dead_feature_window=200, + # WandB (Optional) + enable_wandb=True, + wandb_project="clt-debug-gpt2", + 
wandb_run_name=wdb_run_name, +) +print("\nTraining Configuration (BatchTopK, FP16):") +print(training_config) + + +# %% [markdown] +# ## 3. Generate Activations (One-Time Step) +# +# Generate the activation dataset for GPT-2 in FP16, including the manifest file. + +# %% +print("Step 1: Generating/Verifying Activations (including manifest)...") + +metadata_path = os.path.join(expected_activation_path, "metadata.json") +manifest_path = os.path.join(expected_activation_path, "index.bin") + +if os.path.exists(metadata_path) and os.path.exists(manifest_path): + print(f"Activations and manifest already found at: {expected_activation_path}") + print("Skipping generation. Delete the directory to regenerate.") +else: + print(f"Activations or manifest not found. Generating them now at: {expected_activation_path}") + try: + generator = ActivationGenerator( + cfg=activation_config, + device=device, + ) + generation_start_time = time.time() + generator.generate_and_save() + generation_end_time = time.time() + print(f"Activation generation complete in {generation_end_time - generation_start_time:.2f}s.") + except Exception as gen_err: + print(f"[ERROR] Activation generation failed: {gen_err}") + traceback.print_exc() + raise + +# %% [markdown] +# ## 4. Training the CLT with BatchTopK Activation and FP16 +# +# Instantiate the `CLTTrainer` for FP16 training. 
+ +# %% +print("Initializing CLTTrainer for training with BatchTopK and FP16...") + +log_dir = f"clt_training_logs/clt_gpt2_batchtopk_fp16_train_{int(time.time())}" +os.makedirs(log_dir, exist_ok=True) +print(f"Logs and checkpoints will be saved to: {log_dir}") + +try: + print("Creating CLTTrainer instance...") + print(f"- Using device: {device}") + print(f"- CLT config (BatchTopK): {vars(clt_config)}") + print(f"- Activation Source: {training_config.activation_source}") + print(f"- Reading activations from: {training_config.activation_path}") + print(f"- Training precision: {training_config.precision}") + + trainer = CLTTrainer( + clt_config=clt_config, + training_config=training_config, + log_dir=log_dir, + device=device, + distributed=False, + ) + print("CLTTrainer instance created successfully.") +except Exception as e: + print(f"[ERROR] Failed to initialize CLTTrainer: {e}") + traceback.print_exc() + raise + +# Start training +print("Beginning training using BatchTopK activation and FP16...") +print(f"Training for {training_config.training_steps} steps.") +print(f"Normalization method set to: {training_config.normalization_method}") +print( + f"Standard sparsity penalty applied to BatchTopK activations: {training_config.apply_sparsity_penalty_to_batchtopk}" +) + +try: + start_train_time = time.time() + trained_clt_model = trainer.train(eval_every=training_config.eval_interval) + end_train_time = time.time() + print(f"Training finished in {end_train_time - start_train_time:.2f} seconds.") +except Exception as train_err: + print(f"[ERROR] Training failed: {train_err}") + traceback.print_exc() + raise + +# %% [markdown] +# ## 5. Saving and Loading the Final Trained Model +# +# The `CLTTrainer` automatically saves the final model and its configuration (cfg.json) +# in the `log_dir/final/` directory. If the training started with BatchTopK, +# the trainer converts the model to JumpReLU before this final save. 
+# Here, we'll also demonstrate a manual save of the model state and its config as Python dict, +# and then load it back. This manually saved model will be the one returned by trainer.train(), +# so it will also be JumpReLU if conversion occurred. + +# %% +# The trained_clt_model is what trainer.train() returned. +# If clt_config.activation_fn was 'batchtopk', trainer.train() converts it to JumpReLU in-place. +final_model_state_path = os.path.join(log_dir, "clt_final_manual_state.pt") +final_model_config_path = os.path.join(log_dir, "clt_final_manual_config.json") + +print(f"\nManually saving final model state to: {final_model_state_path}") +print(f"Manually saving final model config to: {final_model_config_path}") + +torch.save(trained_clt_model.state_dict(), final_model_state_path) +with open(final_model_config_path, "w") as f: + # The config on trained_clt_model will reflect 'jumprelu' if conversion happened + json.dump(trained_clt_model.config.__dict__, f, indent=4) + +print(f"\nContents of log directory ({log_dir}):") +for item in os.listdir(log_dir): + print(f"- {item}") + +# --- Loading the manually saved model --- +print("\nLoading the manually saved model...") + +# 1. Load the saved configuration +with open(final_model_config_path, "r") as f: + loaded_config_dict_manual = json.load(f) +loaded_clt_config_manual = CLTConfig(**loaded_config_dict_manual) + +print(f"Loaded manual config, activation_fn: {loaded_clt_config_manual.activation_fn}") + +# 2. 
Instantiate model with this loaded config and load state dict +loaded_clt_model_manual = CrossLayerTranscoder( + config=loaded_clt_config_manual, + process_group=None, # Assuming non-distributed for this load + device=torch.device(device), +) +loaded_clt_model_manual.load_state_dict(torch.load(final_model_state_path, map_location=device)) +loaded_clt_model_manual.eval() # Set to evaluation mode + +print("Manually saved model loaded successfully.") +print(f"Loaded model is on device: {next(loaded_clt_model_manual.parameters()).device}") + + +# %% [markdown] +# ## 6. Loading from Distributed Checkpoint (DC) +# +# The trainer saves checkpoints in a distributed-compatible format (using `torch.distributed.checkpoint`) +# in `log_dir/step_/` and `log_dir/final/`. We can load the `final` one. +# This model will also be in JumpReLU format if the original training was BatchTopK. + +# %% +# Imports moved to top: +# from torch.distributed.checkpoint import load_state_dict as dist_load_state_dict +# from torch.distributed.checkpoint.filesystem import FileSystemReader + +# Path to the 'final' directory created by the trainer +# This contains the sharded checkpoint and the cfg.json (which reflects JumpReLU if converted) +dc_final_checkpoint_dir = os.path.join(log_dir, "final") + +print(f"\nLoading model from distributed checkpoint: {dc_final_checkpoint_dir}") + +# 1. Load the config from cfg.json in that directory +dc_config_path = os.path.join(dc_final_checkpoint_dir, "cfg.json") +if not os.path.exists(dc_config_path): + print(f"ERROR: cfg.json not found in {dc_final_checkpoint_dir}. Cannot load distributed checkpoint correctly.") +else: + with open(dc_config_path, "r") as f: + loaded_config_dict_dc = json.load(f) + loaded_clt_config_dc = CLTConfig(**loaded_config_dict_dc) + print(f"Loaded DC config, activation_fn: {loaded_clt_config_dc.activation_fn}") + + # 2. 
Instantiate the model with this config + # Determine device (mps not directly supported by some distributed ops, fallback to cpu if necessary for loading) + device_to_load_on = device if device != "mps" else "cpu" + print(f"Instantiating model on device: {device_to_load_on} for DC load") + + model_for_dc_load = CrossLayerTranscoder( + config=loaded_clt_config_dc, + process_group=None, # For non-distributed loading of a dist checkpoint + device=torch.device(device_to_load_on), + ) + model_for_dc_load.eval() + + # 3. Create an empty state dict and load into it + state_dict_to_populate_dc = model_for_dc_load.state_dict() + + try: + dist_load_state_dict( + state_dict=state_dict_to_populate_dc, + storage_reader=FileSystemReader(dc_final_checkpoint_dir), + no_dist=True, # Important for loading a sharded checkpoint into a non-distributed model + ) + # Load the populated state dict into the model + model_for_dc_load.load_state_dict(state_dict_to_populate_dc) + print("Model loaded successfully from distributed checkpoint.") + print(f"Model is on device: {next(model_for_dc_load.parameters()).device}") + except Exception as e_dc: + print(f"ERROR loading distributed checkpoint: {e_dc}") + traceback.print_exc() + +# %% [markdown] +# ## 7. Post-hoc Conversion Sweep (θ scaling) from a BatchTopK Checkpoint +# +# To experiment with different θ scaling factors for BatchTopK-to-JumpReLU conversion, +# we need a model checkpoint that was saved *before* any automatic conversion by the trainer. +# The trainer saves checkpoints periodically (e.g., `clt_checkpoint_500.pt`). +# We'll load one of these, assuming it's still in BatchTopK format. + +# %% + +# Path to a BatchTopK checkpoint (e.g., one saved mid-training) +# Ensure this checkpoint was saved when the model was still BatchTopK. +# The trainer converts to JumpReLU only at the very end of training if the original was BatchTopK. +# So, a checkpoint from step 500 should be BatchTopK. 
+# Note: This part uses the log_dir defined in Section 4. If you are running this
+# section independently, you'll need to set log_dir to a valid path.
+batchtopk_checkpoint_path = os.path.join(log_dir, "clt_checkpoint_500.pt")
+
+if not os.path.exists(batchtopk_checkpoint_path):
+    print(f"WARNING: BatchTopK checkpoint {batchtopk_checkpoint_path} not found. Skipping sweep.")
+    print("Ensure your training ran for at least 500 steps and saved a checkpoint.")
+else:
+    print(f"\nLoading BatchTopK model from checkpoint: {batchtopk_checkpoint_path} for sweep...")
+
+    # clt_config_for_sweep is defined INSIDE the sweep loop below so each iteration starts fresh
+
+    # Load the BatchTopK model state (the matching config is re-created per sweep iteration)
+    batchtopk_model_state = torch.load(batchtopk_checkpoint_path, map_location=device)
+
+    # This is the StateDict from the BatchTopK model
+    # It will be used as the starting point for each conversion in the sweep.
+
+    # std_normal is already defined at the top of the script; the two lines below are redundant but harmless
+    from torch.distributions.normal import Normal  # NOTE: redundant re-import (also imported at file top)
+
+    std_normal = Normal(0, 1)
+
+    # Define quick_l0_checks here
+    def quick_l0_checks(
+        model: CrossLayerTranscoder, sample_batch_inputs: Dict[int, torch.Tensor], num_tokens_for_l0_check: int = 100
+    ) -> tuple[float, float]:
+        """Return (avg_empirical_l0_layer0, expected_l0)
+        using an average over random tokens from sample_batch_inputs for empirical L0."""
+        model.eval()
+        avg_empirical_l0_layer0 = float("nan")
+        std_normal_dist = torch.distributions.normal.Normal(0, 1)
+
+        # Assume sample_batch_inputs[0] is valid if this function is called after store initialization
+        layer0_inputs_all_tokens = sample_batch_inputs.get(0)  # Use .get() for safety, though we assume it exists
+
+        if layer0_inputs_all_tokens is None or layer0_inputs_all_tokens.numel() == 0:
+            print("Warning: quick_l0_checks received no valid input for layer 0. 
Empirical L0 will be NaN.") + else: + layer0_inputs_all_tokens = layer0_inputs_all_tokens.to(device=model.device, dtype=model.dtype) + if layer0_inputs_all_tokens.dim() == 3: # B, S, D + num_tokens_in_batch = layer0_inputs_all_tokens.shape[0] * layer0_inputs_all_tokens.shape[1] + layer0_inputs_flat = layer0_inputs_all_tokens.reshape(num_tokens_in_batch, model.config.d_model) + elif layer0_inputs_all_tokens.dim() == 2: # Already [num_tokens, d_model] + num_tokens_in_batch = layer0_inputs_all_tokens.shape[0] + layer0_inputs_flat = layer0_inputs_all_tokens + else: + print( + f"Warning: quick_l0_checks received unexpected input shape {layer0_inputs_all_tokens.shape} for layer 0. Empirical L0 will be NaN." + ) + layer0_inputs_flat = None + + if layer0_inputs_flat is not None and num_tokens_in_batch > 0: + num_to_sample = min(num_tokens_for_l0_check, num_tokens_in_batch) + indices = torch.randperm(num_tokens_in_batch, device=model.device)[:num_to_sample] + selected_tokens_for_l0 = layer0_inputs_flat[indices] + if selected_tokens_for_l0.numel() > 0: + acts_layer0_selected = model.encode(selected_tokens_for_l0, layer_idx=0) + l0_per_token_selected = (acts_layer0_selected > 1e-6).sum(dim=1).float() + avg_empirical_l0_layer0 = l0_per_token_selected.mean().item() + else: + print( + "Warning: No tokens selected for empirical L0 check after sampling. Empirical L0 will be NaN." + ) + # Removed redundant checks for layer0_inputs_flat being None or num_tokens_in_batch == 0, covered by outer if/else + + expected_l0 = float("nan") + if hasattr(model, "log_threshold") and model.log_threshold is not None: + theta = model.log_threshold.exp().cpu() + p_fire = 1.0 - std_normal_dist.cdf(theta.float()) + expected_l0 = p_fire.sum().item() + else: + print("Warning: Model does not have log_threshold. 
Cannot compute expected_l0.") + return avg_empirical_l0_layer0, expected_l0 + + # Initialize LocalActivationStore for the sweep, assuming training_config is available from earlier cells + print("Initializing LocalActivationStore for theta estimation sweep...") + posthoc_activation_store: Optional[BaseActivationStore] = None + try: + from clt.training.data.local_activation_store import LocalActivationStore # Ensure import + + if training_config.activation_path is None: # This check is still good practice + raise ValueError("training_config.activation_path is None. Cannot initialize activation store for sweep.") + + posthoc_activation_store = LocalActivationStore( + dataset_path=training_config.activation_path, + train_batch_size_tokens=1024, # Can use a reasonable batch size for estimation + device=torch.device(device), + dtype=training_config.activation_dtype, + rank=0, + world=1, + seed=42, + sampling_strategy="sequential", + normalization_method="auto", + ) + print(f"Successfully initialized LocalActivationStore from: {training_config.activation_path}") + except NameError: # Handles case where training_config might not be defined if cells are run out of order + print("Error: 'training_config' not defined. 
Please ensure previous cells initializing it have been run.")
+        print("Skipping post-hoc theta scaling sweep.")
+    except Exception as e_store_init:
+        print(f"Error initializing LocalActivationStore for post-hoc sweep: {e_store_init}")
+        print("Skipping post-hoc theta scaling sweep.")
+
+    if posthoc_activation_store:
+        scale_factors = [1.0]  # single factor for the tutorial; extend (e.g. [0.5, 1.0, 1.5, 2.0]) for a real sweep
+        n_batches_for_theta_estimation = 1  # Number of batches to use for theta estimation
+
+        print("\n=== θ-scaling sweep (from BatchTopK checkpoint) using estimate_theta_posthoc ===")
+        print(f"Using {n_batches_for_theta_estimation} batches for theta estimation in each iteration.")
+
+        # Import tqdm for the progress bar
+        from tqdm.auto import tqdm
+
+        for sf in tqdm(scale_factors, desc="Scaling Factor Sweep"):
+            # Define clt_config_for_sweep INSIDE the loop
+            # to ensure a fresh BatchTopK config for each iteration.
+            clt_config_for_sweep = CLTConfig(
+                num_features=clt_num_features,
+                num_layers=num_layers,
+                d_model=d_model,
+                activation_fn="batchtopk",  # Start with BatchTopK config
+                batchtopk_k=batchtopk_k,  # Specify k directly
+                batchtopk_straight_through=True,
+                clt_dtype="float32",  # Match model dtype for consistency during load
+            )
+
+            tmp_model_for_sweep = CrossLayerTranscoder(
+                config=clt_config_for_sweep,
+                process_group=None,
+                device=torch.device(device),
+            )
+            # Load the original BatchTopK state dict
+            tmp_model_for_sweep.load_state_dict(batchtopk_model_state)
+            tmp_model_for_sweep.eval()
+
+            print(f"Estimating theta and converting with scale_factor = {sf:.2f}...")
+            try:
+                # Ensure the data iterator is reset or re-created if it's a one-shot iterator
+                # For this tutorial, assuming posthoc_activation_store can be iterated multiple times
+                # or we re-initialize it if it's a generator type that gets exhausted. 
+ data_iterator_for_estimation = iter(posthoc_activation_store) + + estimated_thetas = tmp_model_for_sweep.estimate_theta_posthoc( + data_iter=data_iterator_for_estimation, + num_batches=n_batches_for_theta_estimation, + scale_factor=sf, + default_theta_value=1e6, # Default from convert_to_jumprelu_inplace + ) + # estimate_theta_posthoc now calls convert_to_jumprelu_inplace internally + print(f"Estimated theta shape: {estimated_thetas.shape}") + # Now tmp_model_for_sweep is a JumpReLU model + + # Get a sample batch for quick_l0_checks + # We need to be careful if data_iterator_for_estimation was exhausted + # For simplicity, let's try to get one more batch or re-initialize iterator for this check + sample_batch_for_l0_check_inputs: Dict[int, torch.Tensor] = {} + try: + sample_inputs_l0, _ = next(data_iterator_for_estimation) # Try to get next from current iterator + sample_batch_for_l0_check_inputs = sample_inputs_l0 + except StopIteration: + print("Warning: data_iterator_for_estimation exhausted. Re-initializing for L0 check.") + try: + reinitialized_iterator = iter(posthoc_activation_store) + sample_inputs_l0, _ = next(reinitialized_iterator) + sample_batch_for_l0_check_inputs = sample_inputs_l0 + except Exception as e_reinit_fetch: + print(f"Error re-fetching batch for L0 check: {e_reinit_fetch}. L0 check might use zeros.") + except Exception as e_fetch_l0_batch: + print(f"Error fetching batch for L0 check: {e_fetch_l0_batch}. 
L0 check might use zeros.") + + d_l0, exp_l0 = quick_l0_checks(tmp_model_for_sweep, sample_batch_for_l0_check_inputs) + print( + f"scale {sf:4.2f} | dummy-L0 {d_l0:6.0f} | expected-L0 {exp_l0:7.1f} (num_features={tmp_model_for_sweep.config.num_features}, num_layers={tmp_model_for_sweep.config.num_layers})" + ) + except Exception as e_sweep_iter: + print(f"ERROR during sweep iteration for scale_factor={sf:.2f}: {e_sweep_iter}") + traceback.print_exc() + continue # Continue to next scale factor + else: + print("Skipping post-hoc theta scaling sweep as activation store could not be initialized.") + +# %% [markdown] +# ## 8. Next Steps +# +# This tutorial showed how to train a CLT for GPT-2 using BatchTopK activation and FP16, +# save/load models, and perform a post-hoc analysis. + +# %% +print("\nGPT-2 FP16 BatchTopK Tutorial Complete!") +print(f"Logs and checkpoints are saved in: {log_dir}")