diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 583ce1aea..e9343e619 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -174,7 +174,9 @@ def __init__( else: # If no tokenizer name is provided, we assume we're training on an algorithmic task and # will pass in tokens directly. In this case, we don't need a tokenizer. - assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided" + assert ( + self.cfg.d_vocab != -1 + ), "Must provide a tokenizer if d_vocab is not provided" self.tokenizer = None if default_padding_side != "right": logging.warning( @@ -192,7 +194,10 @@ def __init__( self.hook_tokens = HookPoint() # [batch, pos] self.blocks = nn.ModuleList( - [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)] + [ + TransformerBlock(self.cfg, block_index) + for block_index in range(self.cfg.n_layers) + ] ) if self.cfg.normalization_type == "RMS": @@ -214,7 +219,9 @@ def __init__( # If it's None, don't create either layer pass else: - logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type) + logging.warning( + "Invalid normalization_type passed in %s", self.cfg.normalization_type + ) self.unembed = Unembed(self.cfg) if self.cfg.init_weights: @@ -340,7 +347,9 @@ def input_to_embed( self, input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, + padding_side: Optional[ + Union[Literal["left", "right"], None] + ] = USE_DEFAULT_VALUE, attention_mask: Optional[torch.Tensor] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, ) -> Tuple[ @@ -369,7 +378,9 @@ def input_to_embed( self.tokenizer is not None ), "Must provide a tokenizer if passing a string to the model" # This is only intended to support passing in a single string - tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) + tokens = self.to_tokens( + input, prepend_bos=prepend_bos, padding_side=padding_side + ) else: tokens = input if len(tokens.shape) == 1: @@ -391,14 +402,20 @@ def input_to_embed( if prepend_bos is USE_DEFAULT_VALUE: prepend_bos = self.cfg.default_prepend_bos if self.tokenizer is None: - raise ValueError("Cannot compute attention mask without a tokenizer.") - attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) + raise ValueError( + "Cannot compute attention mask without a tokenizer." + ) + attention_mask = utils.get_attention_mask( + self.tokenizer, tokens, prepend_bos + ) assert attention_mask.shape == tokens.shape, ( f"Attention mask shape {attention_mask.shape} does not match tokens shape " f"{tokens.shape}" ) - attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg)) + attention_mask = attention_mask.to( + devices.get_device_for_block_index(0, self.cfg) + ) if past_kv_cache is not None: # past_kv_cache is not None, so we're doing caching. # We need to extend the previous attention_mask. @@ -432,15 +449,18 @@ def forward( return_type: Literal["logits"], loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, + padding_side: Optional[ + Union[Literal["left", "right"], None] + ] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, - shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, + shortformer_pos_embed: Optional[ + Float[torch.Tensor, "batch pos d_model"] + ] = None, attention_mask: Optional[torch.Tensor] = None, # [batch pos] stop_at_layer: Optional[int] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, - ) -> Loss: - ... + ) -> Loss: ... @overload def forward( @@ -449,15 +469,18 @@ def forward( return_type: Literal["loss"], loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, + padding_side: Optional[ + Union[Literal["left", "right"], None] + ] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, - shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, + shortformer_pos_embed: Optional[ + Float[torch.Tensor, "batch pos d_model"] + ] = None, attention_mask: Optional[torch.Tensor] = None, # [batch pos] stop_at_layer: Optional[int] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, - ) -> Loss: - ... + ) -> Loss: ... @overload def forward( @@ -466,15 +489,18 @@ def forward( return_type: Literal["both"], loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, + padding_side: Optional[ + Union[Literal["left", "right"], None] + ] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, - shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, + shortformer_pos_embed: Optional[ + Float[torch.Tensor, "batch pos d_model"] + ] = None, attention_mask: Optional[torch.Tensor] = None, # [batch pos] stop_at_layer: Optional[int] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, - ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]: - ... + ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]: ... @overload def forward( @@ -483,15 +509,18 @@ def forward( return_type: Literal[None], loss_per_token: bool = False, prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, + padding_side: Optional[ + Union[Literal["left", "right"], None] + ] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, - shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, + shortformer_pos_embed: Optional[ + Float[torch.Tensor, "batch pos d_model"] + ] = None, attention_mask: Optional[torch.Tensor] = None, # [batch pos] stop_at_layer: Optional[int] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, - ) -> None: - ... + ) -> None: ... def forward( self, @@ -507,7 +536,9 @@ def forward( padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, start_at_layer: Optional[int] = None, tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, - shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, + shortformer_pos_embed: Optional[ + Float[torch.Tensor, "batch pos d_model"] + ] = None, attention_mask: Optional[torch.Tensor] = None, # [batch pos] stop_at_layer: Optional[int] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, @@ -621,7 +652,9 @@ def forward( residual, # Cache contains a list of HookedTransformerKeyValueCache objects, one for each # block - past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None, + past_kv_cache_entry=( + past_kv_cache[i] if past_kv_cache is not None else None + ), shortformer_pos_embed=shortformer_pos_embed, attention_mask=attention_mask, ) # [batch, pos, d_model] @@ -646,7 +679,9 @@ def forward( assert ( tokens is not None ), "tokens must be passed in if return_type is 'loss' or 'both'" - loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token) + loss = self.loss_fn( + logits, tokens, attention_mask, per_token=loss_per_token + ) if return_type == "loss": return loss elif return_type == "both": @@ -673,14 +708,12 @@ def loss_fn( @overload def run_with_cache( self, *model_args, return_cache_object: Literal[True] = True, **kwargs - ) -> Tuple[Output, ActivationCache]: - ... + ) -> Tuple[Output, ActivationCache]: ... @overload def run_with_cache( self, *model_args, return_cache_object: Literal[False], **kwargs - ) -> Tuple[Output, Dict[str, torch.Tensor]]: - ... + ) -> Tuple[Output, Dict[str, torch.Tensor]]: ... def run_with_cache( self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs @@ -703,7 +736,9 @@ def run_with_cache( *model_args, remove_batch_dim=remove_batch_dim, **kwargs ) if return_cache_object: - cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) + cache = ActivationCache( + cache_dict, self, has_batch_dim=not remove_batch_dim + ) return out, cache else: return out, cache_dict @@ -759,7 +794,9 @@ def to_tokens( self, input: Union[str, List[str]], prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, + padding_side: Optional[ + Union[Literal["left", "right"], None] + ] = USE_DEFAULT_VALUE, move_to_device: bool = True, truncate: bool = True, ) -> Int[torch.Tensor, "batch pos"]: @@ -796,14 +833,18 @@ def to_tokens( with utils.LocallyOverridenDefaults( self, prepend_bos=prepend_bos, padding_side=padding_side ): - assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" + assert ( + self.tokenizer is not None + ), "Cannot use to_tokens without a tokenizer" assert ( self.cfg.tokenizer_prepends_bos is not None ), "Set the tokenizer for the model by calling set_tokenizer" if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos: # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually - input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input) + input = utils.get_input_with_manually_prepended_bos( + self.tokenizer, input + ) tokens = self.tokenizer( input, @@ -848,7 +889,9 @@ def to_string( # it's set, then tokenization is no longer invertible, and some tokens # with a bunch of whitespace get collapsed together if len(tokens.shape) == 2: - return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) + return self.tokenizer.batch_decode( + tokens, clean_up_tokenization_spaces=False + ) elif len(tokens.shape) <= 1: return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False) else: @@ -865,7 +908,9 @@ def to_str_tokens( list, ], prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, + padding_side: Optional[ + Union[Literal["left", "right"], None] + ] = USE_DEFAULT_VALUE, ) -> Union[List[str], List[List[str]]]: """Map text, a list of text or tokens to a list of tokens as strings. @@ -906,14 +951,16 @@ def to_str_tokens( if isinstance(input, list): return list( map( - lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side), + lambda tokens: self.to_str_tokens( + tokens, prepend_bos, padding_side + ), input, ) ) # type: ignore elif isinstance(input, str): - tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[ - 0 - ] + tokens = self.to_tokens( + input, prepend_bos=prepend_bos, padding_side=padding_side + )[0] # Gemma tokenizer expects a batch dimension if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1: tokens = tokens.unsqueeze(1) @@ -937,7 +984,9 @@ def to_str_tokens( ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" else: raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") - str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) + str_tokens = self.tokenizer.batch_decode( + tokens, clean_up_tokenization_spaces=False + ) return str_tokens def to_single_token(self, string): @@ -962,10 +1011,14 @@ def to_single_str_token(self, int_token: int) -> str: def get_token_position( self, single_token: Union[str, int], - input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]], + input: Union[ + str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]] + ], mode="first", prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, - padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, + padding_side: Optional[ + Union[Literal["left", "right"], None] + ] = USE_DEFAULT_VALUE, ): """Get the position of a single_token in a string or sequence of tokens. @@ -996,7 +1049,9 @@ def get_token_position( """ if isinstance(input, str): # If the input is a string, convert to tensor - tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) + tokens = self.to_tokens( + input, prepend_bos=prepend_bos, padding_side=padding_side + ) else: tokens = input @@ -1013,7 +1068,9 @@ def get_token_position( elif isinstance(single_token, torch.Tensor): single_token = single_token.item() - indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token] + indices = torch.arange(len(tokens), device=tokens.device)[ + tokens == single_token + ] assert len(indices) > 0, "The token does not occur in the prompt" if mode == "first": return indices[0].item() @@ -1301,7 +1358,8 @@ def from_pretrained( quant_method = qc.get("quant_method", "") assert not load_in_8bit, "8-bit quantization is not supported" assert not ( - load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) + load_in_4bit + and (version.parse(torch.__version__) < version.parse("2.1.1")) ), "Quantization is only supported for torch versions >= 2.1.1" assert not ( load_in_4bit and ("llama" not in model_name.lower()) @@ -1325,7 +1383,9 @@ def from_pretrained( (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16) or dtype == torch.float16 ) and device in ["cpu", None]: - logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") + logging.warning( + "float16 models may not work on CPU. Consider using a GPU or bfloat16." + ) # Get the model name used in HuggingFace, rather than the alias. official_model_name = loading.get_official_model_name(model_name) @@ -1530,9 +1590,13 @@ def _init_weights_kaiming(self, dist_type="uniform"): for name, param in self.named_parameters(): if "W_" in name: if dist_type == "uniform": - init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in") + init_kaiming_uniform_( + param, gain=gain, nonlinearity="relu", mode="fan_in" + ) elif dist_type == "normal": - init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in") + init_kaiming_normal_( + param, gain=gain, nonlinearity="relu", mode="fan_in" + ) def _init_weights_muP(self, dist_type="uniform"): """ @@ -1618,16 +1682,44 @@ def load_and_process_state_dict( logging.warning( "You are using MoE, so the layer norm weights can't be folded! Skipping" ) - elif self.cfg.normalization_type in ["LN", "LNPre"]: - state_dict = self.fold_layer_norm(state_dict) - elif self.cfg.normalization_type in ["RMS", "RMSPre"]: - state_dict = self.fold_layer_norm( - state_dict, fold_biases=False, center_weights=False - ) - else: + elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: logging.warning( "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping" ) + else: + ln_keys_present = any( + k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict + ) + if not ln_keys_present: + logging.warning( + "fold_ln=True but no LayerNorm weights found in state_dict. " + "The model may have been saved with already-folded LayerNorms. " + "Skipping fold." + ) + else: + if self.cfg.normalization_type == "LN": + self.cfg.normalization_type = "LNPre" + self.ln_final = LayerNormPre(self.cfg) + for layer in self.blocks: + layer.ln1 = LayerNormPre(self.cfg) + layer.ln2 = LayerNormPre(self.cfg) + if self.cfg.is_layer_norm_activation(): + layer.mlp.ln = LayerNormPre(self.cfg) + elif self.cfg.normalization_type == "RMS": + self.cfg.normalization_type = "RMSPre" + self.ln_final = RMSNormPre(self.cfg) + for layer in self.blocks: + layer.ln1 = RMSNormPre(self.cfg) + layer.ln2 = RMSNormPre(self.cfg) + if self.cfg.is_layer_norm_activation(): + layer.mlp.ln = RMSNormPre(self.cfg) + + if self.cfg.normalization_type in ["LNPre"]: + state_dict = self.fold_layer_norm(state_dict) + elif self.cfg.normalization_type in ["RMSPre"]: + state_dict = self.fold_layer_norm( + state_dict, fold_biases=False, center_weights=False + ) if center_writing_weights: if self.cfg.normalization_type not in ["LN", "LNPre"]: @@ -1658,6 +1750,9 @@ def load_and_process_state_dict( self.load_state_dict({key: state_dict[key]}, strict=False) del state_dict[key] + if fold_ln: + self.setup() + def fill_missing_keys(self, state_dict): return loading.fill_missing_keys(self, state_dict) @@ -1688,10 +1783,14 @@ def fold_layer_norm( # the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need # to sum along axis -2, which is the residual stream space axis. if fold_biases: - state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + ( + state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[ + f"blocks.{l}.attn.b_Q" + ] + ( state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.b"][None, :, None] - ).sum(-2) + ).sum( + -2 + ) state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[ f"blocks.{l}.attn.{gqa}b_K" ] + ( @@ -1711,7 +1810,8 @@ def fold_layer_norm( del state_dict[f"blocks.{l}.ln1.b"] state_dict[f"blocks.{l}.attn.W_Q"] = ( - state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None] + state_dict[f"blocks.{l}.attn.W_Q"] + * state_dict[f"blocks.{l}.ln1.w"][None, :, None] ) state_dict[f"blocks.{l}.attn.{gqa}W_K"] = ( state_dict[f"blocks.{l}.attn.{gqa}W_K"] @@ -1748,14 +1848,19 @@ def fold_layer_norm( # Fold ln2 into MLP if not self.cfg.attn_only: if fold_biases: - state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[f"blocks.{l}.mlp.b_in"] + ( + state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[ + f"blocks.{l}.mlp.b_in" + ] + ( state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.b"][:, None] - ).sum(-2) + ).sum( + -2 + ) del state_dict[f"blocks.{l}.ln2.b"] state_dict[f"blocks.{l}.mlp.W_in"] = ( - state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None] + state_dict[f"blocks.{l}.mlp.W_in"] + * state_dict[f"blocks.{l}.ln2.w"][:, None] ) if self.cfg.gated_mlp: @@ -1812,7 +1917,9 @@ def fold_layer_norm( ).sum(dim=-2) del state_dict[f"ln_final.b"] - state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None] + state_dict[f"unembed.W_U"] = ( + state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None] + ) del state_dict[f"ln_final.w"] if center_weights: @@ -1830,28 +1937,30 @@ def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]): W_out. This is done by subtracting the mean of the weights from the weights themselves. This is done in-place. See fold_layer_norm for more details. """ - state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean( - -1, keepdim=True - ) + state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict[ + "embed.W_E" + ].mean(-1, keepdim=True) if self.cfg.positional_embedding_type != "rotary": state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[ "pos_embed.W_pos" ].mean(-1, keepdim=True) for l in range(self.cfg.n_layers): - state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[ + state_dict[f"blocks.{l}.attn.W_O"] = state_dict[ f"blocks.{l}.attn.W_O" - ].mean( + ] - state_dict[f"blocks.{l}.attn.W_O"].mean( -1, keepdim=True ) # W_O is [head_index, d_model, d_head] state_dict[f"blocks.{l}.attn.b_O"] = ( - state_dict[f"blocks.{l}.attn.b_O"] - state_dict[f"blocks.{l}.attn.b_O"].mean() + state_dict[f"blocks.{l}.attn.b_O"] + - state_dict[f"blocks.{l}.attn.b_O"].mean() ) # b_O is [d_model] if not self.cfg.attn_only: state_dict[f"blocks.{l}.mlp.W_out"] = state_dict[ f"blocks.{l}.mlp.W_out" ] - state_dict[f"blocks.{l}.mlp.W_out"].mean(-1, keepdim=True) state_dict[f"blocks.{l}.mlp.b_out"] = ( - state_dict[f"blocks.{l}.mlp.b_out"] - state_dict[f"blocks.{l}.mlp.b_out"].mean() + state_dict[f"blocks.{l}.mlp.b_out"] + - state_dict[f"blocks.{l}.mlp.b_out"].mean() ) return state_dict @@ -1864,10 +1973,12 @@ def center_unembed(self, state_dict: Dict[str, torch.Tensor]): how components contribute to the logits, we'll be less misled by components that just add something to every logit. """ - state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict["unembed.W_U"].mean( - -1, keepdim=True + state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict[ + "unembed.W_U" + ].mean(-1, keepdim=True) + state_dict["unembed.b_U"] = ( + state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean() ) - state_dict["unembed.b_U"] = state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean() return state_dict def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]): @@ -1984,7 +2095,9 @@ def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]): b_O = state_dict[f"blocks.{l}.attn.b_O"] # Add singleton dimension for broadcasting - b_V_expanded = einops.rearrange(b_V, "head_index d_head -> head_index d_head 1") + b_V_expanded = einops.rearrange( + b_V, "head_index d_head -> head_index d_head 1" + ) # Element-wise multiplication of b_V and W_O b_V_times_W_O = b_V_expanded * W_O @@ -2032,7 +2145,9 @@ def set_use_attn_in(self, use_attn_in: bool): ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead" self.cfg.use_attn_in = use_attn_in - def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool): + def set_ungroup_grouped_query_attention( + self, ungroup_grouped_query_attention: bool + ): """ Toggles whether to ungroup the grouped key and value heads in models with grouped query attention (GQA). """ @@ -2052,31 +2167,6 @@ def process_weights_( version of the same model. """ state_dict = self.state_dict() - if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1: - # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing - # A warning is already issued in `load_and_process_state_dict` - pass - elif fold_ln and self.cfg.normalization_type == "LN": - # If we're folding the LN into the weights, we need to replace all the layernorm layers - # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky, - # but it's the easiest way to do it. - self.cfg.normalization_type = "LNPre" - self.ln_final = LayerNormPre(self.cfg) - for layer in self.blocks: - layer.ln1 = LayerNormPre(self.cfg) - layer.ln2 = LayerNormPre(self.cfg) - if self.cfg.is_layer_norm_activation(): - layer.mlp.ln = LayerNormPre(self.cfg) - elif fold_ln and self.cfg.normalization_type == "RMS": - # We do the same for RMSNorm if used - self.cfg.normalization_type = "RMSPre" - self.ln_final = RMSNormPre(self.cfg) - for layer in self.blocks: - layer.ln1 = RMSNormPre(self.cfg) - layer.ln2 = RMSNormPre(self.cfg) - if self.cfg.is_layer_norm_activation(): - layer.mlp.ln = RMSNormPre(self.cfg) - self.load_and_process_state_dict( state_dict, fold_ln=fold_ln, @@ -2193,7 +2283,9 @@ def generate( assert ( self.tokenizer is not None ), "Must provide a tokenizer if passing a string to the model" - input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) + input = self.to_tokens( + input, prepend_bos=prepend_bos, padding_side=padding_side + ) elif input.ndim == 2: input_type = "tokens" else: @@ -2220,7 +2312,8 @@ def generate( assert self.tokenizer is not None if stop_at_eos: tokenizer_has_eos_token = ( - self.tokenizer is not None and self.tokenizer.eos_token_id is not None + self.tokenizer is not None + and self.tokenizer.eos_token_id is not None ) if eos_token_id is None: assert ( @@ -2236,11 +2329,15 @@ def generate( # eos_token_id is a Sequence (e.g. list or tuple) stop_tokens = eos_token_id eos_token_for_padding = ( - self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] + self.tokenizer.eos_token_id + if tokenizer_has_eos_token + else eos_token_id[0] ) # An array to track which sequences in the batch have finished. - finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) + finished_sequences = torch.zeros( + batch_size, dtype=torch.bool, device=self.cfg.device + ) # Currently nothing in HookedTransformer changes with eval, but this is here in case # that changes in the future. @@ -2251,7 +2348,9 @@ def generate( tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int) attention_mask = utils.get_attention_mask( - self.tokenizer, tokens, False if prepend_bos is None else prepend_bos + self.tokenizer, + tokens, + False if prepend_bos is None else prepend_bos, ).to(device) residual, shortformer_pos_embed = self.get_residual( embeds, @@ -2314,7 +2413,11 @@ def generate( freq_penalty=freq_penalty, tokens=( torch.cat( - (input_tokens, torch.cat(sampled_tokens_list, dim=1)), dim=1 + ( + input_tokens, + torch.cat(sampled_tokens_list, dim=1), + ), + dim=1, ) if "sampled_tokens" in locals() else input_tokens @@ -2322,7 +2425,10 @@ def generate( ).to(devices.get_device_for_block_index(0, self.cfg)) else: sampled_tokens = utils.sample_logits( - final_logits, top_k=top_k, top_p=top_p, temperature=temperature + final_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, ).to(devices.get_device_for_block_index(0, self.cfg)) else: sampled_tokens = final_logits.argmax(-1).to( @@ -2341,7 +2447,9 @@ def generate( ) ) - embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))]) + embeds = torch.hstack( + [embeds, self.embed(sampled_tokens.unsqueeze(-1))] + ) if stop_at_eos and finished_sequences.all(): break @@ -2513,7 +2621,9 @@ def accumulated_bias( if include_mlp_biases: accumulated_bias += cast(torch.Tensor, block.mlp.b_out) if mlp_input: - assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" + assert ( + layer < self.cfg.n_layers + ), "Cannot include attn_bias from beyond the final layer" block = cast(TransformerBlock, self.blocks[layer]) accumulated_bias += cast(torch.Tensor, block.attn.b_O) return accumulated_bias @@ -2548,14 +2658,20 @@ def all_composition_scores( # layer than the left head. mask = ( torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None] - < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None] + < torch.arange(self.cfg.n_layers, device=self.cfg.device)[ + None, None, :, None + ] ) scores = torch.where(mask, scores, torch.zeros_like(scores)) return scores def all_head_labels(self): """Returns a list of all head names in the model.""" - return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] + return [ + f"L{l}H{h}" + for l in range(self.cfg.n_layers) + for h in range(self.cfg.n_heads) + ] def load_sample_training_dataset(self, **kwargs): """Load Sample Training Dataset.