transformer_lens.components#

Hooked Transformer Components.

This module contains all the components (e.g. Attention, MLP, LayerNorm) needed to create many different types of generative language models. They are used by transformer_lens.HookedTransformer.

class transformer_lens.components.Attention(cfg: Union[Dict, HookedTransformerConfig], attn_type: str = 'global', layer_id: Optional[int] = None)#

Bases: Module

property OV: FactoredMatrix#

OV-Circuit, as defined in A Mathematical Framework. Because there’s no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more)

Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry!

Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works.
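A minimal sketch in plain torch (toy dimensions, single head, not the library's FactoredMatrix implementation) of why the low-rank factorisation is equivalent to the full matrix:

```python
import torch

d_model, d_head = 8, 2
W_V = torch.randn(d_model, d_head)  # left factor, [d_model, d_head]
W_O = torch.randn(d_head, d_model)  # right factor, [d_head, d_model]

W_OV = W_V @ W_O  # full circuit matrix, [d_model, d_model], rank <= d_head

# Applying the factors one at a time is equivalent to applying W_OV directly,
# but never materialises the full d_model x d_model matrix.
resid = torch.randn(3, d_model)  # [pos, d_model], for a single head
out_factored = (resid @ W_V) @ W_O
out_full = resid @ W_OV
assert torch.allclose(out_factored, out_full, atol=1e-5)
assert torch.linalg.matrix_rank(W_OV) <= d_head
```

FactoredMatrix exploits exactly this: operations are done on the two thin factors, so the [d_model, d_model] matrix per head is never stored.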

property QK: FactoredMatrix#

QK-Circuit, as defined in A Mathematical Framework. Because there’s no non-linearity in the key-query dot product, the attention pattern is purely determined by the matrix W_QK = W_Q @ W_K.T, and not by W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual @ W_Q @ W_K.T @ source_residual.T, see the glossary for more.)

Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos]

Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works.
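A sketch of the equivalence in plain torch (toy dimensions, single head, pre-softmax scores only, ignoring the 1/sqrt(d_head) scale):

```python
import torch

d_model, d_head = 8, 2
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)

W_QK = W_Q @ W_K.T  # [d_model, d_model], rank <= d_head

dest = torch.randn(4, d_model)  # destination residuals, [dest_pos, d_model]
src = torch.randn(5, d_model)   # source residuals, [src_pos, d_model]

# scores[i, j] = dest[i] @ W_QK @ src[j]; computing via the thin factors
# (queries and keys) gives the same result without the full matrix.
scores_full = dest @ W_QK @ src.T
scores_factored = (dest @ W_Q) @ (src @ W_K).T
assert torch.allclose(scores_full, scores_factored, atol=1e-5)
assert scores_full.shape == (4, 5)  # [destination_pos, source_pos]
```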

__init__(cfg: Union[Dict, HookedTransformerConfig], attn_type: str = 'global', layer_id: Optional[int] = None)#

Attention Block - params have shape [head_index, d_model, d_head] (or [head_index, d_head, d_model] for W_O) and multiply on the right. attn_scores refers to query key dot product immediately before attention softmax

Convention: All attention pattern-style matrices have shape [batch, head_index, query_pos, key_pos]

Parameters:
  • cfg (Union[Dict, HookedTransformerConfig]) – Config

  • attn_type (str, optional) – “global” or “local”, used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to “global”.

  • layer_id (int, optional) – The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre-softmax by 1/(layer_id+1), for numerical stability reasons. Defaults to None.

apply_causal_mask(attn_scores: Float[Tensor, 'batch head_index pos pos_plus_past_kv_pos_offset'], past_kv_pos_offset: int = 0, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None)#
apply_rotary(x: Float[Tensor, 'batch pos head_index d_head'], past_kv_pos_offset=0, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None) Float[Tensor, 'batch pos head_index d_head']#
calculate_sin_cos_rotary(rotary_dim: int, n_ctx: int, base: int = 10000, dtype: dtype = torch.float32) Tuple[Float[Tensor, 'n_ctx rotary_dim'], Float[Tensor, 'n_ctx rotary_dim']]#

Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details

Note: For some inexplicable reason, GPT-J rotates each ADJACENT pair of elements in q and k, while GPT-NeoX pairs the elements at positions i and i+n//2 (i.e. it folds the full length in half and pairs elements accordingly). I have absolutely no clue why; the two should be completely equivalent. To resolve this, the code defaults to the GPT-J mode, but explicitly checks whether the model is GPT-NeoX and uses the GPT-NeoX pairing if it is.
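The frequency tables can be sketched as follows; sin_cos_rotary is a hypothetical helper with assumed defaults (base 10000, GPT-J-style adjacent pairing), not the library method:

```python
import torch

def sin_cos_rotary(rotary_dim: int, n_ctx: int, base: float = 10000.0):
    # Standard RoPE frequencies: dimension pair i rotates at base**(-2i/rotary_dim).
    inv_freq = 1.0 / base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)
    pos = torch.arange(n_ctx).float()
    half = pos[:, None] * inv_freq[None, :]       # [n_ctx, rotary_dim // 2]
    # GPT-J pairs ADJACENT elements, so each angle appears twice in a row.
    # GPT-NeoX pairs elements i and i + rotary_dim//2, which would instead be
    # torch.cat([half, half], dim=-1).
    angles = half.repeat_interleave(2, dim=-1)    # [n_ctx, rotary_dim]
    return torch.sin(angles), torch.cos(angles)

sin, cos = sin_cos_rotary(rotary_dim=8, n_ctx=16)
assert sin.shape == cos.shape == (16, 8)
assert torch.all(sin[0] == 0) and torch.all(cos[0] == 1)  # position 0 is unrotated
```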

forward(query_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']], key_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']], value_input: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']], past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, additive_attention_mask: Optional[Float[Tensor, 'batch 1 1 pos']] = None, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None) Float[Tensor, 'batch pos d_model']#

shortformer_pos_embed is only used if self.cfg.positional_embedding_type == “shortformer”; otherwise it defaults to None and is irrelevant. See HookedTransformerConfig for more details.

past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None.

additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.

attention_mask is the attention mask for padded tokens. Defaults to None.

rotate_every_two(x: Float[Tensor, '... rotary_dim']) Float[Tensor, '... rotary_dim']#

Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0]

The final axis of x must have even length.

GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details.
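A sketch of the pairing in plain torch (GPT-J convention, written as a standalone function rather than the method above):

```python
import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # Split the final axis into blocks of size 2 and map [x0, x1] -> [-x1, x0].
    # The final axis must have even length.
    assert x.shape[-1] % 2 == 0
    rot = torch.empty_like(x)
    rot[..., ::2] = -x[..., 1::2]   # even slots receive -x1
    rot[..., 1::2] = x[..., ::2]    # odd slots receive x0
    return rot

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
assert torch.equal(rotate_every_two(x), torch.tensor([-2.0, 1.0, -4.0, 3.0]))
# Applying it twice negates the input: each pair is rotated by 90 degrees,
# so two applications give a 180-degree rotation.
assert torch.equal(rotate_every_two(rotate_every_two(x)), -x)
```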

class transformer_lens.components.BertBlock(cfg: HookedTransformerConfig)#

Bases: Module

BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before.

forward(resid_pre: Float[Tensor, 'batch pos d_model'], additive_attention_mask: Optional[Float[Tensor, 'batch 1 1 pos']] = None)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.BertEmbed(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module

Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result.
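A sketch of this computation with plain torch.nn modules (toy sizes; the real module reads its dimensions from the config and exposes hooks):

```python
import torch

# Hypothetical toy sizes, for illustration only.
d_vocab, n_ctx, n_types, d_model = 50, 16, 2, 8
tok_emb = torch.nn.Embedding(d_vocab, d_model)
pos_emb = torch.nn.Embedding(n_ctx, d_model)
type_emb = torch.nn.Embedding(n_types, d_model)
ln = torch.nn.LayerNorm(d_model)

input_ids = torch.randint(0, d_vocab, (2, 10))  # [batch, pos]
token_type_ids = torch.zeros_like(input_ids)    # single-sequence input
positions = torch.arange(10).expand(2, 10)      # absolute positions

# Sum the token, positional and token-type embeddings, then layer-norm.
embed = ln(tok_emb(input_ids) + pos_emb(positions) + type_emb(token_type_ids))
assert embed.shape == (2, 10, d_model)
```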

forward(input_ids: Int[Tensor, 'batch pos'], token_type_ids: Optional[Int[Tensor, 'batch pos']] = None)#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.BertMLMHead(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module

Transforms BERT embeddings into logits. The purpose of this module is to predict masked tokens in a sentence.

forward(resid: Float[Tensor, 'batch pos d_model']) Tensor#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.Embed(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module

forward(tokens: Int[Tensor, 'batch pos']) Float[Tensor, 'batch pos d_model']#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.GatedMLP(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module

The equations of a gated MLP:

pre = x @ W_gate
pre_linear = x @ W_in
post = Gelu(pre) * pre_linear + b_in
mlp_out = post @ W_out + b_out

In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out
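The same equations in runnable form, with toy dimensions and plain torch (the real module reads its shapes from the config and uses the configured activation, assumed here to be GELU):

```python
import torch
import torch.nn.functional as F

d_model, d_mlp = 8, 16
x = torch.randn(2, 5, d_model)       # [batch, pos, d_model]
W_gate = torch.randn(d_model, d_mlp)
W_in = torch.randn(d_model, d_mlp)
W_out = torch.randn(d_mlp, d_model)
b_in = torch.zeros(d_mlp)
b_out = torch.zeros(d_model)

pre = x @ W_gate                      # gate branch
pre_linear = x @ W_in                 # linear branch
post = F.gelu(pre) * pre_linear + b_in
mlp_out = post @ W_out + b_out
assert mlp_out.shape == (2, 5, d_model)
```

The gate branch passes through the non-linearity while the linear branch does not; their elementwise product is what distinguishes this from a standard MLP.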

forward(x: Float[Tensor, 'batch pos d_model']) Float[Tensor, 'batch pos d_model']#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.LayerNorm(cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None)#

Bases: Module

__init__(cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None)#

LayerNorm with optional length parameter

length (Optional[int]): The dimension to normalise over. If not provided, assumed to be d_model.
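A sketch of the computation in plain torch, checked against torch.nn.functional.layer_norm (the eps value is an assumption; the real module reads it from the config):

```python
import torch

def layernorm(x, w, b, eps=1e-5):
    # Centre, scale to unit variance, then apply the learned affine transform.
    x = x - x.mean(dim=-1, keepdim=True)
    scale = (x.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
    return x / scale * w + b

length = 8  # the `length` parameter above; defaults to d_model
x = torch.randn(2, 5, length)
w, b = torch.ones(length), torch.zeros(length)
out = layernorm(x, w, b)
expected = torch.nn.functional.layer_norm(x, (length,), w, b)
assert torch.allclose(out, expected, atol=1e-5)
```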

forward(x: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']]) Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.LayerNormPre(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module

__init__(cfg: Union[Dict, HookedTransformerConfig])#

LayerNormPre - the ‘center and normalise’ part of LayerNorm. Length is normally d_model, but is d_mlp for softmax. Not needed as a parameter. This should only be used in inference mode after folding in LayerNorm weights

forward(x: Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']]) Union[Float[Tensor, 'batch pos d_model'], Float[Tensor, 'batch pos head_index d_model']]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.MLP(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module

forward(x: Float[Tensor, 'batch pos d_model']) Float[Tensor, 'batch pos d_model']#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.PosEmbed(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module

forward(tokens: Int[Tensor, 'batch pos'], past_kv_pos_offset: int = 0, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None) Float[Tensor, 'batch pos d_model']#

Forward pass for positional embeddings.

Parameters:
  • tokens (Int[torch.Tensor, "batch pos"]) – Input tokens.

  • past_kv_pos_offset (int, optional) – The length of tokens in the past_kv_cache. Defaults to 0.

  • attention_mask (Int[torch.Tensor, "batch pos"], optional) – The attention mask for padded tokens. Defaults to None.

Returns:

Absolute position embeddings.

Return type:

Float[torch.Tensor, “batch pos d_model”]

class transformer_lens.components.RMSNorm(cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None)#

Bases: Module

__init__(cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None)#

RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square)

length (Optional[int]): The dimension to normalise over. If not provided, assumed to be d_model.
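A sketch of the RMSNorm computation in plain torch (the eps value is an assumption; the real module reads it from the config):

```python
import torch

def rmsnorm(x, w, eps=1e-5):
    # No centering, no bias: just divide by the root-mean-square and scale.
    scale = (x.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
    return x / scale * w

length = 8
x = torch.randn(2, 5, length)
out = rmsnorm(x, torch.ones(length))

# With a unit scale, each output vector has (approximately) unit RMS.
rms = out.pow(2).mean(dim=-1).sqrt()
assert torch.allclose(rms, torch.ones(2, 5), atol=1e-2)
```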

forward(x: Float[Tensor, 'batch pos length']) Float[Tensor, 'batch pos length']#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.RMSNormPre(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module

__init__(cfg: Union[Dict, HookedTransformerConfig])#

RMSNormPre - LayerNormPre without the centering and bias (RMS = Root Mean Square)

forward(x: Float[Tensor, 'batch pos length']) Float[Tensor, 'batch pos length']#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.TokenTypeEmbed(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module

The token-type embedding takes binary ids indicating whether a token belongs to sequence A or B. For example, for the two sentences “[CLS] Sentence A [SEP] Sentence B [SEP]”, token_type_ids would be [0, 0, …, 0, 1, …, 1, 1]: 0 marks tokens from Sentence A, 1 marks tokens from Sentence B. If not provided, BERT assumes a single-sequence input. Typically, the shape is (batch_size, sequence_length).

See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf
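For illustration, building token_type_ids by hand for a two-sentence input (tokens here are placeholder strings, not a real BERT vocabulary):

```python
import torch

sent_a = ["[CLS]", "the", "cat", "[SEP]"]
sent_b = ["sat", "down", "[SEP]"]
tokens = sent_a + sent_b

# 0 for every position in sentence A (including [CLS] and its [SEP]),
# 1 for every position in sentence B.
token_type_ids = torch.tensor([[0] * len(sent_a) + [1] * len(sent_b)])
assert token_type_ids.shape == (1, len(tokens))  # (batch_size, sequence_length)
assert token_type_ids.tolist() == [[0, 0, 0, 0, 1, 1, 1]]
```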

forward(token_type_ids: Int[Tensor, 'batch pos'])#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class transformer_lens.components.TransformerBlock(cfg: Union[Dict, HookedTransformerConfig], block_index)#

Bases: Module

forward(resid_pre: Float[Tensor, 'batch pos d_model'], shortformer_pos_embed: Optional[Float[Tensor, 'batch pos d_model']] = None, past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, attention_mask: Optional[Int[Tensor, 'batch offset_pos']] = None) Float[Tensor, 'batch pos d_model']#

A single Transformer block.

Parameters:
  • resid_pre (torch.Tensor) – The residual stream - shape [batch, pos, d_model]

  • past_kv_cache_entry (HookedTransformerKeyValueCacheEntry, optional) – A cache of previous keys and values for this layer, used only when generating text. Defaults to None.

  • shortformer_pos_embed (torch.Tensor, optional) – Only used for positional_embeddings_type == “shortformer”. The positional embeddings. See HookedTransformerConfig for details. Defaults to None.

  • attention_mask (torch.Tensor, optional) – The attention mask for padded tokens. Defaults to None.

Returns:

The residual stream after applying the block, shape [batch, pos, d_model].

Return type:

Float[torch.Tensor, “batch pos d_model”]

class transformer_lens.components.Unembed(cfg: Union[Dict, HookedTransformerConfig])#

Bases: Module

forward(residual: Float[Tensor, 'batch pos d_model']) Float[Tensor, 'batch pos d_vocab_out']#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.