API

ActivationsStore

Class for streaming tokens and generating and storing activations while training SAEs.

Source code in sae_lens/training/activations_store.py
class ActivationsStore:
    """
    Class for streaming tokens and generating and storing activations
    while training SAEs.
    """

    model: HookedRootModule
    dataset: HfDataset
    cached_activations_path: str | None
    cached_activation_dataset: Dataset | None = None
    tokens_column: Literal["tokens", "input_ids", "text", "problem"]
    hook_name: str
    hook_layer: int
    hook_head_index: int | None
    _dataloader: Iterator[Any] | None = None
    _storage_buffer: torch.Tensor | None = None
    device: torch.device

    @classmethod
    def from_config(
        cls,
        model: HookedRootModule,
        cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig,
        override_dataset: HfDataset | None = None,
    ) -> "ActivationsStore":
        cached_activations_path = cfg.cached_activations_path
        # set cached_activations_path to None if we're not using cached activations
        if (
            isinstance(cfg, LanguageModelSAERunnerConfig)
            and not cfg.use_cached_activations
        ):
            cached_activations_path = None

        if override_dataset is None and cfg.dataset_path == "":
            raise ValueError(
                "You must either pass in a dataset or specify a dataset_path in your configutation."
            )

        return cls(
            model=model,
            dataset=override_dataset or cfg.dataset_path,
            streaming=cfg.streaming,
            hook_name=cfg.hook_name,
            hook_layer=cfg.hook_layer,
            hook_head_index=cfg.hook_head_index,
            context_size=cfg.context_size,
            d_in=cfg.d_in,
            n_batches_in_buffer=cfg.n_batches_in_buffer,
            total_training_tokens=cfg.training_tokens,
            store_batch_size_prompts=cfg.store_batch_size_prompts,
            train_batch_size_tokens=cfg.train_batch_size_tokens,
            prepend_bos=cfg.prepend_bos,
            normalize_activations=cfg.normalize_activations,
            device=torch.device(cfg.act_store_device),
            dtype=cfg.dtype,
            cached_activations_path=cached_activations_path,
            model_kwargs=cfg.model_kwargs,
            autocast_lm=cfg.autocast_lm,
            dataset_trust_remote_code=cfg.dataset_trust_remote_code,
            seqpos_slice=cfg.seqpos_slice,
        )

    @classmethod
    def from_sae(
        cls,
        model: HookedRootModule,
        sae: SAE,
        context_size: int | None = None,
        dataset: HfDataset | str | None = None,
        streaming: bool = True,
        store_batch_size_prompts: int = 8,
        n_batches_in_buffer: int = 8,
        train_batch_size_tokens: int = 4096,
        total_tokens: int = 10**9,
        device: str = "cpu",
    ) -> "ActivationsStore":
        return cls(
            model=model,
            dataset=sae.cfg.dataset_path if dataset is None else dataset,
            d_in=sae.cfg.d_in,
            hook_name=sae.cfg.hook_name,
            hook_layer=sae.cfg.hook_layer,
            hook_head_index=sae.cfg.hook_head_index,
            context_size=sae.cfg.context_size if context_size is None else context_size,
            prepend_bos=sae.cfg.prepend_bos,
            streaming=streaming,
            store_batch_size_prompts=store_batch_size_prompts,
            train_batch_size_tokens=train_batch_size_tokens,
            n_batches_in_buffer=n_batches_in_buffer,
            total_training_tokens=total_tokens,
            normalize_activations=sae.cfg.normalize_activations,
            dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
            dtype=sae.cfg.dtype,
            device=torch.device(device),
            seqpos_slice=sae.cfg.seqpos_slice,
        )

    def __init__(
        self,
        model: HookedRootModule,
        dataset: HfDataset | str,
        streaming: bool,
        hook_name: str,
        hook_layer: int,
        hook_head_index: int | None,
        context_size: int,
        d_in: int,
        n_batches_in_buffer: int,
        total_training_tokens: int,
        store_batch_size_prompts: int,
        train_batch_size_tokens: int,
        prepend_bos: bool,
        normalize_activations: str,
        device: torch.device,
        dtype: str,
        cached_activations_path: str | None = None,
        model_kwargs: dict[str, Any] | None = None,
        autocast_lm: bool = False,
        dataset_trust_remote_code: bool | None = None,
        seqpos_slice: tuple[int | None, ...] = (None,),
    ):
        self.model = model
        if model_kwargs is None:
            model_kwargs = {}
        self.model_kwargs = model_kwargs
        self.dataset = (
            load_dataset(
                dataset,
                split="train",
                streaming=streaming,
                trust_remote_code=dataset_trust_remote_code,  # type: ignore
            )
            if isinstance(dataset, str)
            else dataset
        )

        if isinstance(dataset, (Dataset, DatasetDict)):
            self.dataset = cast(Dataset | DatasetDict, self.dataset)
            n_samples = len(self.dataset)

            if n_samples < total_training_tokens:
                warnings.warn(
                    f"The training dataset contains fewer samples ({n_samples}) than the number of samples required by your training configuration ({total_training_tokens}). This will result in multiple training epochs and some samples being used more than once."
                )

        self.hook_name = hook_name
        self.hook_layer = hook_layer
        self.hook_head_index = hook_head_index
        self.context_size = context_size
        self.d_in = d_in
        self.n_batches_in_buffer = n_batches_in_buffer
        self.half_buffer_size = n_batches_in_buffer // 2
        self.total_training_tokens = total_training_tokens
        self.store_batch_size_prompts = store_batch_size_prompts
        self.train_batch_size_tokens = train_batch_size_tokens
        self.prepend_bos = prepend_bos
        self.normalize_activations = normalize_activations
        self.device = torch.device(device)
        self.dtype = DTYPE_MAP[dtype]
        self.cached_activations_path = cached_activations_path
        self.autocast_lm = autocast_lm
        self.seqpos_slice = seqpos_slice

        self.n_dataset_processed = 0

        self.estimated_norm_scaling_factor = 1.0

        # Check if dataset is tokenized
        dataset_sample = next(iter(self.dataset))

        # check if it's tokenized
        if "tokens" in dataset_sample.keys():
            self.is_dataset_tokenized = True
            self.tokens_column = "tokens"
        elif "input_ids" in dataset_sample.keys():
            self.is_dataset_tokenized = True
            self.tokens_column = "input_ids"
        elif "text" in dataset_sample.keys():
            self.is_dataset_tokenized = False
            self.tokens_column = "text"
        elif "problem" in dataset_sample.keys():
            self.is_dataset_tokenized = False
            self.tokens_column = "problem"
        else:
            raise ValueError(
                "Dataset must have a 'tokens', 'input_ids', 'text', or 'problem' column."
            )
        if self.is_dataset_tokenized:
            ds_context_size = len(dataset_sample[self.tokens_column])
            if ds_context_size < self.context_size:
                raise ValueError(
                    f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}.
                    The context_size {ds_context_size} is expected to be larger than or equal to the provided context size {self.context_size}."""
                )
            if self.context_size != ds_context_size:
                warnings.warn(
                    f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. Some data will be discarded in this case.""",
                    RuntimeWarning,
                )
            # TODO: investigate if this can work for iterable datasets, or if this is even worthwhile as a perf improvement
            if hasattr(self.dataset, "set_format"):
                self.dataset.set_format(type="torch", columns=[self.tokens_column])  # type: ignore

            if (
                isinstance(dataset, str)
                and hasattr(model, "tokenizer")
                and model.tokenizer is not None
            ):
                validate_pretokenized_dataset_tokenizer(
                    dataset_path=dataset, model_tokenizer=model.tokenizer
                )
        else:
            warnings.warn(
                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
            )

        self.iterable_sequences = self._iterate_tokenized_sequences()

        self.cached_activation_dataset = self.load_cached_activation_dataset()

        # TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF)

    def _iterate_raw_dataset(
        self,
    ) -> Generator[torch.Tensor | list[int] | str, None, None]:
        """
        Helper to iterate over the dataset while incrementing n_dataset_processed
        """
        for row in self.dataset:
            # typing datasets is difficult
            yield row[self.tokens_column]  # type: ignore
            self.n_dataset_processed += 1

    def _iterate_raw_dataset_tokens(self) -> Generator[torch.Tensor, None, None]:
        """
        Helper to create an iterator which tokenizes raw text from the dataset on the fly
        """
        for row in self._iterate_raw_dataset():
            tokens = (
                self.model.to_tokens(
                    row,
                    truncate=False,
                    move_to_device=False,  # we move to device below
                    prepend_bos=False,
                )
                .squeeze(0)
                .to(self.device)
            )
            assert (
                len(tokens.shape) == 1
            ), f"tokens.shape should be 1D but was {tokens.shape}"
            yield tokens

    def _iterate_tokenized_sequences(self) -> Generator[torch.Tensor, None, None]:
        """
        Generator which iterates over full sequence of context_size tokens
        """
        # If the dataset is pretokenized, we will slice the dataset to the length of the context window if needed. Otherwise, no further processing is needed.
        # We assume that all necessary BOS/EOS/SEP tokens have been added during pretokenization.
        if self.is_dataset_tokenized:
            for row in self._iterate_raw_dataset():
                yield torch.tensor(
                    row[
                        : self.context_size
                    ],  # If self.context_size = None, this line simply returns the whole row
                    dtype=torch.long,
                    device=self.device,
                    requires_grad=False,
                )
        # If the dataset isn't tokenized, we'll tokenize, concat, and batch on the fly
        else:
            tokenizer = getattr(self.model, "tokenizer", None)
            bos_token_id = None if tokenizer is None else tokenizer.bos_token_id
            yield from concat_and_batch_sequences(
                tokens_iterator=self._iterate_raw_dataset_tokens(),
                context_size=self.context_size,
                begin_batch_token_id=(bos_token_id if self.prepend_bos else None),
                begin_sequence_token_id=None,
                sequence_separator_token_id=(
                    bos_token_id if self.prepend_bos else None
                ),
            )

    def load_cached_activation_dataset(self) -> Dataset | None:
        """
        Load the cached activation dataset from disk.

        - If cached_activations_path is set, returns Huggingface Dataset else None
        - Checks that the loaded dataset has activations for the hooks in the config and that shapes match.
        """
        if self.cached_activations_path is None:
            return None

        assert self.cached_activations_path is not None  # keep pyright happy
        # Sanity check: does the cache directory exist?
        assert os.path.exists(
            self.cached_activations_path
        ), f"Cache directory {self.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."

        # ---
        # Actual code
        activations_dataset = datasets.load_from_disk(self.cached_activations_path)
        activations_dataset.set_format(
            type="torch", columns=[self.hook_name], device=self.device, dtype=self.dtype
        )
        self.current_row_idx = 0  # idx to load next batch from
        # ---

        assert isinstance(activations_dataset, Dataset)

        # multiple hooks in future
        if not set([self.hook_name]).issubset(activations_dataset.column_names):
            raise ValueError(
                f"loaded dataset does not include hook activations, got {activations_dataset.column_names}"
            )

        if activations_dataset.features[self.hook_name].shape != (
            self.context_size,
            self.d_in,
        ):
            raise ValueError(
                f"Given dataset of shape {activations_dataset.features[self.hook_name].shape} does not match context_size ({self.context_size}) and d_in ({self.d_in})"
            )

        return activations_dataset

    def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
        return activations * self.estimated_norm_scaling_factor

    def unscale(self, activations: torch.Tensor) -> torch.Tensor:
        return activations / self.estimated_norm_scaling_factor

    def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
        return (self.d_in**0.5) / activations.norm(dim=-1).mean()

    @torch.no_grad()
    def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
        norms_per_batch = []
        for _ in tqdm(
            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
        ):
            acts = self.next_batch()
            norms_per_batch.append(acts.norm(dim=-1).mean().item())
        mean_norm = np.mean(norms_per_batch)
        scaling_factor = np.sqrt(self.d_in) / mean_norm

        return scaling_factor

    def shuffle_input_dataset(self, seed: int, buffer_size: int = 1):
        """
        This applies a shuffle to the huggingface dataset that is the input to the activations store. This
        also shuffles the shards of the dataset, which is especially useful for evaluating on different
        sections of very large streaming datasets. Buffer size is only relevant for streaming datasets.
        The default buffer_size of 1 means that only the shard will be shuffled; larger buffer sizes will
        additionally shuffle individual elements within the shard.
        """
        if type(self.dataset) == IterableDataset:
            self.dataset = self.dataset.shuffle(seed=seed, buffer_size=buffer_size)
        else:
            self.dataset = self.dataset.shuffle(seed=seed)
        self.iterable_dataset = iter(self.dataset)

    def reset_input_dataset(self):
        """
        Resets the input dataset iterator to the beginning.
        """
        self.iterable_dataset = iter(self.dataset)

    @property
    def storage_buffer(self) -> torch.Tensor:
        if self._storage_buffer is None:
            self._storage_buffer = self.get_buffer(self.half_buffer_size)

        return self._storage_buffer

    @property
    def dataloader(self) -> Iterator[Any]:
        if self._dataloader is None:
            self._dataloader = self.get_data_loader()
        return self._dataloader

    def get_batch_tokens(
        self, batch_size: int | None = None, raise_at_epoch_end: bool = False
    ):
        """
        Streams a batch of tokens from a dataset.

        If raise_at_epoch_end is true we will reset the dataset at the end of each epoch and raise a StopIteration. Otherwise we will reset silently.
        """
        if not batch_size:
            batch_size = self.store_batch_size_prompts
        sequences = []
        # the sequences iterator yields fully formed tokens of size context_size, so we just need to cat these into a batch
        for _ in range(batch_size):
            try:
                sequences.append(next(self.iterable_sequences))
            except StopIteration:
                self.iterable_sequences = self._iterate_tokenized_sequences()
                if raise_at_epoch_end:
                    raise StopIteration(
                        f"Ran out of tokens in dataset after {self.n_dataset_processed} samples, beginning the next epoch."
                    )
                else:
                    sequences.append(next(self.iterable_sequences))

        return torch.stack(sequences, dim=0).to(_get_model_device(self.model))

    @torch.no_grad()
    def get_activations(self, batch_tokens: torch.Tensor):
        """
        Returns activations of shape (batches, context, num_layers, d_in)

        d_in may result from a concatenated head dimension.
        """

        # Setup autocast if using
        if self.autocast_lm:
            autocast_if_enabled = torch.autocast(
                device_type="cuda",
                dtype=torch.bfloat16,
                enabled=self.autocast_lm,
            )
        else:
            autocast_if_enabled = contextlib.nullcontext()

        with autocast_if_enabled:
            layerwise_activations_cache = self.model.run_with_cache(
                batch_tokens,
                names_filter=[self.hook_name],
                stop_at_layer=self.hook_layer + 1,
                prepend_bos=False,
                **self.model_kwargs,
            )[1]

        layerwise_activations = layerwise_activations_cache[self.hook_name][
            :, slice(*self.seqpos_slice)
        ]

        n_batches, n_context = layerwise_activations.shape[:2]

        stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))

        if self.hook_head_index is not None:
            stacked_activations[:, :, 0] = layerwise_activations[
                :, :, self.hook_head_index
            ]
        elif layerwise_activations.ndim > 3:  # if we have a head dimension
            try:
                stacked_activations[:, :, 0] = layerwise_activations.view(
                    n_batches, n_context, -1
                )
            except RuntimeError as e:
                print(f"Error during view operation: {e}")
                print("Attempting to use reshape instead...")
                stacked_activations[:, :, 0] = layerwise_activations.reshape(
                    n_batches, n_context, -1
                )
        else:
            stacked_activations[:, :, 0] = layerwise_activations

        return stacked_activations

    def _load_buffer_from_cached(
        self,
        total_size: int,
        context_size: int,
        num_layers: int,
        d_in: int,
        raise_on_epoch_end: bool,
    ) -> Float[torch.Tensor, "(total_size context_size) num_layers d_in"]:
        """
        Loads `total_size` activations from `cached_activation_dataset`

        The dataset has columns for each hook_name,
        each containing activations of shape (context_size, d_in).

        raises StopIteration
        """
        assert self.cached_activation_dataset is not None
        # In future, could be a list of multiple hook names
        hook_names = [self.hook_name]
        assert set(hook_names).issubset(self.cached_activation_dataset.column_names)

        if self.current_row_idx > len(self.cached_activation_dataset) - total_size:
            self.current_row_idx = 0
            if raise_on_epoch_end:
                raise StopIteration

        new_buffer = []
        for hook_name in hook_names:
            # Load activations for each hook.
            # Usually faster to first slice dataset then pick column
            _hook_buffer = self.cached_activation_dataset[
                self.current_row_idx : self.current_row_idx + total_size
            ][hook_name]
            assert _hook_buffer.shape == (total_size, context_size, d_in)
            new_buffer.append(_hook_buffer)

        # Stack across num_layers dimension
        # list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
        new_buffer = torch.stack(new_buffer, dim=2)
        assert new_buffer.shape == (total_size, context_size, num_layers, d_in)

        self.current_row_idx += total_size
        return new_buffer.reshape(total_size * context_size, num_layers, d_in)

    @torch.no_grad()
    def get_buffer(
        self,
        n_batches_in_buffer: int,
        raise_on_epoch_end: bool = False,
        shuffle: bool = True,
    ) -> torch.Tensor:
        """
        Loads the next n_batches_in_buffer batches of activations into a tensor and returns half of it.

        The primary purpose here is maintaining a shuffling buffer.

        If raise_on_epoch_end is True, when the dataset is exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
        """
        context_size = self.context_size
        training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
        batch_size = self.store_batch_size_prompts
        d_in = self.d_in
        total_size = batch_size * n_batches_in_buffer
        num_layers = 1

        if self.cached_activation_dataset is not None:
            return self._load_buffer_from_cached(
                total_size, context_size, num_layers, d_in, raise_on_epoch_end
            )

        refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
        # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
        new_buffer = torch.zeros(
            (total_size, training_context_size, num_layers, d_in),
            dtype=self.dtype,  # type: ignore
            device=self.device,
        )

        for refill_batch_idx_start in refill_iterator:
            # move batch toks to gpu for model
            refill_batch_tokens = self.get_batch_tokens(
                raise_at_epoch_end=raise_on_epoch_end
            ).to(_get_model_device(self.model))
            refill_activations = self.get_activations(refill_batch_tokens)
            # move acts back to cpu
            refill_activations = refill_activations.to(self.device)
            new_buffer[
                refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
            ] = refill_activations

            # pbar.update(1)

        new_buffer = new_buffer.reshape(-1, num_layers, d_in)
        if shuffle:
            new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]

        # every buffer should be normalized:
        if self.normalize_activations == "expected_average_only_in":
            new_buffer = self.apply_norm_scaling_factor(new_buffer)

        return new_buffer

    def get_data_loader(
        self,
    ) -> Iterator[Any]:
        """
        Return a torch.utils.dataloader which you can get batches from.

        Should automatically refill the buffer when it gets to n % full.
        (better mixing if you refill and shuffle regularly).

        """

        batch_size = self.train_batch_size_tokens

        try:
            new_samples = self.get_buffer(
                self.half_buffer_size, raise_on_epoch_end=True
            )
        except StopIteration:
            warnings.warn(
                "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
            )
            self._storage_buffer = (
                None  # dump the current buffer so samples do not leak between epochs
            )
            try:
                new_samples = self.get_buffer(self.half_buffer_size)
            except StopIteration:
                raise ValueError(
                    "We were unable to fill up the buffer directly after starting a new epoch. This could indicate that there are less samples in the dataset than are required to fill up the buffer. Consider reducing batch_size or n_batches_in_buffer. "
                )

        # 1. # create new buffer by mixing stored and new buffer
        mixing_buffer = torch.cat(
            [new_samples, self.storage_buffer],
            dim=0,
        )

        mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

        # 2.  put 50 % in storage
        self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]

        # 3. put other 50 % in a dataloader
        dataloader = iter(
            DataLoader(
                # TODO: seems like a typing bug?
                cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
                batch_size=batch_size,
                shuffle=True,
            )
        )

        return dataloader

    def next_batch(self):
        """
        Get the next batch from the current DataLoader.
        If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
        """
        try:
            # Try to get the next batch
            return next(self.dataloader)
        except StopIteration:
            # If the DataLoader is exhausted, create a new one
            self._dataloader = self.get_data_loader()
            return next(self.dataloader)

    def state_dict(self) -> dict[str, torch.Tensor]:
        result = {
            "n_dataset_processed": torch.tensor(self.n_dataset_processed),
        }
        if self._storage_buffer is not None:  # first time might be None
            result["storage_buffer"] = self._storage_buffer
        return result

    def save(self, file_path: str):
        save_file(self.state_dict(), file_path)
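
A minimal usage sketch (not from the package itself), assuming a TransformerLens HookedTransformer and a pretrained SAE; the model name, release, and SAE id below are placeholders:

from transformer_lens import HookedTransformer
from sae_lens import SAE
from sae_lens.training.activations_store import ActivationsStore

# Placeholder model and SAE; from_pretrained returns (sae, cfg_dict, sparsity) in recent versions.
model = HookedTransformer.from_pretrained("gpt2")
sae, _, _ = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,                # stream the dataset named in sae.cfg.dataset_path
    store_batch_size_prompts=8,    # prompts per forward pass when refilling the buffer
    train_batch_size_tokens=4096,  # tokens per batch returned by next_batch()
    device="cpu",                  # where buffered activations are kept
)

acts = store.next_batch()          # shape: (train_batch_size_tokens, 1, d_in)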

get_activations(batch_tokens)

Returns activations of shape (batches, context, num_layers, d_in)

d_in may result from a concatenated head dimension.

Source code in sae_lens/training/activations_store.py
@torch.no_grad()
def get_activations(self, batch_tokens: torch.Tensor):
    """
    Returns activations of shape (batches, context, num_layers, d_in)

    d_in may result from a concatenated head dimension.
    """

    # Setup autocast if using
    if self.autocast_lm:
        autocast_if_enabled = torch.autocast(
            device_type="cuda",
            dtype=torch.bfloat16,
            enabled=self.autocast_lm,
        )
    else:
        autocast_if_enabled = contextlib.nullcontext()

    with autocast_if_enabled:
        layerwise_activations_cache = self.model.run_with_cache(
            batch_tokens,
            names_filter=[self.hook_name],
            stop_at_layer=self.hook_layer + 1,
            prepend_bos=False,
            **self.model_kwargs,
        )[1]

    layerwise_activations = layerwise_activations_cache[self.hook_name][
        :, slice(*self.seqpos_slice)
    ]

    n_batches, n_context = layerwise_activations.shape[:2]

    stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))

    if self.hook_head_index is not None:
        stacked_activations[:, :, 0] = layerwise_activations[
            :, :, self.hook_head_index
        ]
    elif layerwise_activations.ndim > 3:  # if we have a head dimension
        try:
            stacked_activations[:, :, 0] = layerwise_activations.view(
                n_batches, n_context, -1
            )
        except RuntimeError as e:
            print(f"Error during view operation: {e}")
            print("Attempting to use reshape instead...")
            stacked_activations[:, :, 0] = layerwise_activations.reshape(
                n_batches, n_context, -1
            )
    else:
        stacked_activations[:, :, 0] = layerwise_activations

    return stacked_activations
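
For example, a sketch assuming store is an ActivationsStore constructed as above:

tokens = store.get_batch_tokens(batch_size=4)  # (4, context_size)
acts = store.get_activations(tokens)           # (4, context_size, 1, d_in) with the default seqpos_slice
assert acts.shape[-1] == store.d_in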

get_batch_tokens(batch_size=None, raise_at_epoch_end=False)

Streams a batch of tokens from a dataset.

If raise_at_epoch_end is true we will reset the dataset at the end of each epoch and raise a StopIteration. Otherwise we will reset silently.

Source code in sae_lens/training/activations_store.py
def get_batch_tokens(
    self, batch_size: int | None = None, raise_at_epoch_end: bool = False
):
    """
    Streams a batch of tokens from a dataset.

    If raise_at_epoch_end is true we will reset the dataset at the end of each epoch and raise a StopIteration. Otherwise we will reset silently.
    """
    if not batch_size:
        batch_size = self.store_batch_size_prompts
    sequences = []
    # the sequences iterator yields fully formed tokens of size context_size, so we just need to cat these into a batch
    for _ in range(batch_size):
        try:
            sequences.append(next(self.iterable_sequences))
        except StopIteration:
            self.iterable_sequences = self._iterate_tokenized_sequences()
            if raise_at_epoch_end:
                raise StopIteration(
                    f"Ran out of tokens in dataset after {self.n_dataset_processed} samples, beginning the next epoch."
                )
            else:
                sequences.append(next(self.iterable_sequences))

    return torch.stack(sequences, dim=0).to(_get_model_device(self.model))
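
A sketch of the epoch-handling behaviour (assuming store as above):

# Default: silently restart from the beginning of the dataset when it is exhausted.
tokens = store.get_batch_tokens()  # (store_batch_size_prompts, context_size)

# Opt in to being notified when an epoch ends.
try:
    tokens = store.get_batch_tokens(raise_at_epoch_end=True)
except StopIteration:
    # the internal sequence iterator has already been reset, so a retry succeeds
    tokens = store.get_batch_tokens()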

get_buffer(n_batches_in_buffer, raise_on_epoch_end=False, shuffle=True)

Loads the next n_batches_in_buffer batches of activations into a tensor and returns half of it.

The primary purpose here is maintaining a shuffling buffer.

If raise_on_epoch_end is True, when the dataset is exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.

Source code in sae_lens/training/activations_store.py
@torch.no_grad()
def get_buffer(
    self,
    n_batches_in_buffer: int,
    raise_on_epoch_end: bool = False,
    shuffle: bool = True,
) -> torch.Tensor:
    """
    Loads the next n_batches_in_buffer batches of activations into a tensor and returns half of it.

    The primary purpose here is maintaining a shuffling buffer.

    If raise_on_epoch_end is True, when the dataset is exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
    """
    context_size = self.context_size
    training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
    batch_size = self.store_batch_size_prompts
    d_in = self.d_in
    total_size = batch_size * n_batches_in_buffer
    num_layers = 1

    if self.cached_activation_dataset is not None:
        return self._load_buffer_from_cached(
            total_size, context_size, num_layers, d_in, raise_on_epoch_end
        )

    refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
    # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
    new_buffer = torch.zeros(
        (total_size, training_context_size, num_layers, d_in),
        dtype=self.dtype,  # type: ignore
        device=self.device,
    )

    for refill_batch_idx_start in refill_iterator:
        # move batch toks to gpu for model
        refill_batch_tokens = self.get_batch_tokens(
            raise_at_epoch_end=raise_on_epoch_end
        ).to(_get_model_device(self.model))
        refill_activations = self.get_activations(refill_batch_tokens)
        # move acts back to cpu
        refill_activations = refill_activations.to(self.device)
        new_buffer[
            refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
        ] = refill_activations

        # pbar.update(1)

    new_buffer = new_buffer.reshape(-1, num_layers, d_in)
    if shuffle:
        new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]

    # every buffer should be normalized:
    if self.normalize_activations == "expected_average_only_in":
        new_buffer = self.apply_norm_scaling_factor(new_buffer)

    return new_buffer
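
A shape sketch (assuming no cached activation dataset and the default seqpos_slice):

buffer = store.get_buffer(store.half_buffer_size)
# (store_batch_size_prompts * half_buffer_size * context_size, 1, d_in),
# shuffled, and scaled if normalize_activations == "expected_average_only_in"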

get_data_loader()

Return a torch.utils.dataloader which you can get batches from.

Should automatically refill the buffer when it gets to n % full. (better mixing if you refill and shuffle regularly).

Source code in sae_lens/training/activations_store.py
def get_data_loader(
    self,
) -> Iterator[Any]:
    """
    Return a torch.utils.dataloader which you can get batches from.

    Should automatically refill the buffer when it gets to n % full.
    (better mixing if you refill and shuffle regularly).

    """

    batch_size = self.train_batch_size_tokens

    try:
        new_samples = self.get_buffer(
            self.half_buffer_size, raise_on_epoch_end=True
        )
    except StopIteration:
        warnings.warn(
            "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
        )
        self._storage_buffer = (
            None  # dump the current buffer so samples do not leak between epochs
        )
        try:
            new_samples = self.get_buffer(self.half_buffer_size)
        except StopIteration:
            raise ValueError(
                "We were unable to fill up the buffer directly after starting a new epoch. This could indicate that there are less samples in the dataset than are required to fill up the buffer. Consider reducing batch_size or n_batches_in_buffer. "
            )

    # 1. # create new buffer by mixing stored and new buffer
    mixing_buffer = torch.cat(
        [new_samples, self.storage_buffer],
        dim=0,
    )

    mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

    # 2.  put 50 % in storage
    self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]

    # 3. put other 50 % in a dataloader
    dataloader = iter(
        DataLoader(
            # TODO: seems like a typing bug?
            cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
            batch_size=batch_size,
            shuffle=True,
        )
    )

    return dataloader
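
A sketch of direct use; in practice next_batch() is usually called instead, since it recreates the loader when it runs out:

loader = store.get_data_loader()
acts = next(loader)  # (train_batch_size_tokens, 1, d_in)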

load_cached_activation_dataset()

Load the cached activation dataset from disk.

  • If cached_activations_path is set, returns Huggingface Dataset else None
  • Checks that the loaded dataset has activations for the hooks in the config and that shapes match.
Source code in sae_lens/training/activations_store.py
def load_cached_activation_dataset(self) -> Dataset | None:
    """
    Load the cached activation dataset from disk.

    - If cached_activations_path is set, returns Huggingface Dataset else None
    - Checks that the loaded dataset has activations for the hooks in the config and that shapes match.
    """
    if self.cached_activations_path is None:
        return None

    assert self.cached_activations_path is not None  # keep pyright happy
    # Sanity check: does the cache directory exist?
    assert os.path.exists(
        self.cached_activations_path
    ), f"Cache directory {self.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."

    # ---
    # Actual code
    activations_dataset = datasets.load_from_disk(self.cached_activations_path)
    activations_dataset.set_format(
        type="torch", columns=[self.hook_name], device=self.device, dtype=self.dtype
    )
    self.current_row_idx = 0  # idx to load next batch from
    # ---

    assert isinstance(activations_dataset, Dataset)

    # multiple hooks in future
    if not set([self.hook_name]).issubset(activations_dataset.column_names):
        raise ValueError(
            f"loaded dataset does not include hook activations, got {activations_dataset.column_names}"
        )

    if activations_dataset.features[self.hook_name].shape != (
        self.context_size,
        self.d_in,
    ):
        raise ValueError(
            f"Given dataset of shape {activations_dataset.features[self.hook_name].shape} does not match context_size ({self.context_size}) and d_in ({self.d_in})"
        )

    return activations_dataset
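
A sketch: this only returns a Dataset when the store was configured with a cached_activations_path.

ds = store.load_cached_activation_dataset()
if ds is not None:
    assert store.hook_name in ds.column_names
    # each row holds activations for one sequence, with shape (context_size, d_in)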

next_batch()

Get the next batch from the current DataLoader. If the DataLoader is exhausted, refill the buffer and create a new DataLoader.

Source code in sae_lens/training/activations_store.py
def next_batch(self):
    """
    Get the next batch from the current DataLoader.
    If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
    """
    try:
        # Try to get the next batch
        return next(self.dataloader)
    except StopIteration:
        # If the DataLoader is exhausted, create a new one
        self._dataloader = self.get_data_loader()
        return next(self.dataloader)
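
This is the call a training loop typically makes; a sketch (n_steps is a placeholder):

for _ in range(n_steps):
    acts = store.next_batch()  # (train_batch_size_tokens, 1, d_in)
    # ... run one SAE optimization step on acts ...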

reset_input_dataset()

Resets the input dataset iterator to the beginning.

Source code in sae_lens/training/activations_store.py
def reset_input_dataset(self):
    """
    Resets the input dataset iterator to the beginning.
    """
    self.iterable_dataset = iter(self.dataset)

shuffle_input_dataset(seed, buffer_size=1)

This applies a shuffle to the huggingface dataset that is the input to the activations store. This also shuffles the shards of the dataset, which is especially useful for evaluating on different sections of very large streaming datasets. Buffer size is only relevant for streaming datasets. The default buffer_size of 1 means that only the shard will be shuffled; larger buffer sizes will additionally shuffle individual elements within the shard.

Source code in sae_lens/training/activations_store.py
def shuffle_input_dataset(self, seed: int, buffer_size: int = 1):
    """
    This applies a shuffle to the huggingface dataset that is the input to the activations store. This
    also shuffles the shards of the dataset, which is especially useful for evaluating on different
    sections of very large streaming datasets. Buffer size is only relevant for streaming datasets.
    The default buffer_size of 1 means that only the shard will be shuffled; larger buffer sizes will
    additionally shuffle individual elements within the shard.
    """
    if type(self.dataset) == IterableDataset:
        self.dataset = self.dataset.shuffle(seed=seed, buffer_size=buffer_size)
    else:
        self.dataset = self.dataset.shuffle(seed=seed)
    self.iterable_dataset = iter(self.dataset)
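
A sketch; for streaming datasets a larger buffer_size also shuffles rows within each shard:

store.shuffle_input_dataset(seed=42, buffer_size=1000)
# later passes over store.dataset (e.g. a new epoch) will see the shuffled order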

CacheActivationsRunner

Source code in sae_lens/cache_activations_runner.py
class CacheActivationsRunner:
    def __init__(self, cfg: CacheActivationsRunnerConfig):
        self.cfg = cfg
        self.model = load_model(
            model_class_name=cfg.model_class_name,
            model_name=cfg.model_name,
            device=cfg.device,
            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs,
        )
        self.activations_store = ActivationsStore.from_config(
            self.model,
            cfg,
        )
        ctx_size = _get_sliced_context_size(self.cfg)
        self.features = Features(
            {
                f"{self.cfg.hook_name}": Array2D(
                    shape=(ctx_size, self.cfg.d_in), dtype=self.cfg.dtype
                )
            }
        )
        self.tokens_in_buffer = (
            self.cfg.n_batches_in_buffer * self.cfg.store_batch_size_prompts * ctx_size
        )
        self.n_buffers = math.ceil(self.cfg.training_tokens / self.tokens_in_buffer)

    def __str__(self):
        """
        Print the number of tokens to be cached.
        Print the number of buffers, and the number of tokens per buffer.
        Print the disk space required to store the activations.

        """

        bytes_per_token = (
            self.cfg.d_in * self.cfg.dtype.itemsize
            if isinstance(self.cfg.dtype, torch.dtype)
            else DTYPE_MAP[self.cfg.dtype].itemsize
        )
        total_training_tokens = self.cfg.training_tokens
        total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

        return (
            f"Activation Cache Runner:\n"
            f"Total training tokens: {total_training_tokens}\n"
            f"Number of buffers: {self.n_buffers}\n"
            f"Tokens per buffer: {self.tokens_in_buffer}\n"
            f"Disk space required: {total_disk_space_gb:.2f} GB\n"
            f"Configuration:\n"
            f"{self.cfg}"
        )

    @staticmethod
    def _consolidate_shards(
        source_dir: Path, output_dir: Path, copy_files: bool = True
    ) -> Dataset:
        """Consolidate sharded datasets into a single directory without rewriting data.

        Each of the shards must be of the same format, aka the full dataset must be able to
        be recreated like so:

        ```
        ds = concatenate_datasets(
            [Dataset.load_from_disk(str(shard_dir)) for shard_dir in sorted(source_dir.iterdir())]
        )

        ```

        Sharded dataset format:
        ```
        source_dir/
            shard_00000/
                dataset_info.json
                state.json
                data-00000-of-00002.arrow
                data-00001-of-00002.arrow
            shard_00001/
                dataset_info.json
                state.json
                data-00000-of-00001.arrow
        ```

        And flattens them into the format:

        ```
        output_dir/
            dataset_info.json
            state.json
            data-00000-of-00003.arrow
            data-00001-of-00003.arrow
            data-00002-of-00003.arrow
        ```

        allowing the dataset to be loaded like so:

        ```
        ds = datasets.load_from_disk(output_dir)
        ```

        Args:
            source_dir: Directory containing the sharded datasets
            output_dir: Directory to consolidate the shards into
            copy_files: If True, copy files; if False, move them and delete source_dir
        """
        first_shard_dir_name = "shard_00000"  # shard_{i:05d}

        assert source_dir.exists() and source_dir.is_dir()
        assert (
            output_dir.exists()
            and output_dir.is_dir()
            and not any(p for p in output_dir.iterdir() if not p.name == ".tmp_shards")
        )
        if not (source_dir / first_shard_dir_name).exists():
            raise Exception(f"No shards in {source_dir} exist!")

        transfer_fn = shutil.copy2 if copy_files else shutil.move

        # Move dataset_info.json from any shard (all the same)
        transfer_fn(
            source_dir / first_shard_dir_name / "dataset_info.json",
            output_dir / "dataset_info.json",
        )

        arrow_files = []
        file_count = 0

        for shard_dir in sorted(source_dir.iterdir()):
            if not shard_dir.name.startswith("shard_"):
                continue

            # state.json contains arrow filenames
            state = json.loads((shard_dir / "state.json").read_text())

            for data_file in state["_data_files"]:
                src = shard_dir / data_file["filename"]
                new_name = f"data-{file_count:05d}-of-{len(list(source_dir.iterdir())):05d}.arrow"
                dst = output_dir / new_name
                transfer_fn(src, dst)
                arrow_files.append({"filename": new_name})
                file_count += 1

        new_state = {
            "_data_files": arrow_files,
            "_fingerprint": None,  # temporary
            "_format_columns": None,
            "_format_kwargs": {},
            "_format_type": None,
            "_output_all_columns": False,
            "_split": None,
        }

        # fingerprint is generated from dataset.__getstate__ (not including _fingerprint)
        with open(output_dir / "state.json", "w") as f:
            json.dump(new_state, f, indent=2)

        ds = Dataset.load_from_disk(str(output_dir))
        fingerprint = generate_fingerprint(ds)
        del ds

        with open(output_dir / "state.json", "r+") as f:
            state = json.loads(f.read())
            state["_fingerprint"] = fingerprint
            f.seek(0)
            json.dump(state, f, indent=2)
            f.truncate()

        if not copy_files:  # cleanup source dir
            shutil.rmtree(source_dir)

        return Dataset.load_from_disk(output_dir)

    @torch.no_grad()
    def _create_shard(
        self,
        buffer: Float[torch.Tensor, "(bs context_size) num_layers d_in"],
    ) -> Dataset:
        hook_names = [self.cfg.hook_name]  # allow multiple hooks in future

        buffer = einops.rearrange(
            buffer,
            "(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
            bs=self.cfg.n_batches_in_buffer * self.cfg.store_batch_size_prompts,
            context_size=_get_sliced_context_size(self.cfg),
            d_in=self.cfg.d_in,
            num_layers=len(hook_names),
        )
        shard = Dataset.from_dict(
            {hook_name: act for hook_name, act in zip(hook_names, buffer)},
            features=self.features,
        )
        return shard

    @torch.no_grad()
    def run(self) -> Dataset:
        ### Paths setup
        assert self.cfg.new_cached_activations_path is not None
        final_cached_activation_path = Path(self.cfg.new_cached_activations_path)
        final_cached_activation_path.mkdir(exist_ok=True, parents=True)
        if any(final_cached_activation_path.iterdir()):
            raise Exception(
                f"Activations directory ({final_cached_activation_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files."
            )

        tmp_cached_activation_path = final_cached_activation_path / ".tmp_shards/"
        tmp_cached_activation_path.mkdir(exist_ok=False, parents=False)

        ### Create temporary sharded datasets

        print(f"Started caching {self.cfg.training_tokens} activations")

        for i in tqdm(range(self.n_buffers), desc="Caching activations"):
            try:
                # num activations in a single shard: n_batches_in_buffer * store_batch_size_prompts
                buffer = self.activations_store.get_buffer(
                    self.cfg.n_batches_in_buffer, shuffle=self.cfg.shuffle
                )
                shard = self._create_shard(buffer)
                shard.save_to_disk(
                    f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
                )
                del buffer, shard

            except StopIteration:
                print(
                    f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.n_buffers} batches. No more caching will occur."
                )
                break

        ### Concat sharded datasets together, shuffle and push to hub

        dataset = self._consolidate_shards(
            tmp_cached_activation_path, final_cached_activation_path, copy_files=False
        )

        if self.cfg.shuffle:
            print("Shuffling...")
            dataset = dataset.shuffle(seed=self.cfg.seed)

        if self.cfg.hf_repo_id:
            print("Pushing to hub...")
            dataset.push_to_hub(
                repo_id=self.cfg.hf_repo_id,
                num_shards=self.cfg.hf_num_shards or self.n_buffers,
                private=self.cfg.hf_is_private_repo,
                revision=self.cfg.hf_revision,
            )

            meta_io = io.BytesIO()
            meta_contents = json.dumps(
                asdict(self.cfg), indent=2, ensure_ascii=False
            ).encode("utf-8")
            meta_io.write(meta_contents)
            meta_io.seek(0)

            api = HfApi()
            api.upload_file(
                path_or_fileobj=meta_io,
                path_in_repo="cache_activations_runner_cfg.json",
                repo_id=self.cfg.hf_repo_id,
                repo_type="dataset",
                commit_message="Add cache_activations_runner metadata",
            )

        return dataset
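
An end-to-end sketch; the model name, hook, dataset path, and sizes are placeholders, and new_cached_activations_path is autofilled by the config's __post_init__ when left unset:

from sae_lens.config import CacheActivationsRunnerConfig
from sae_lens.cache_activations_runner import CacheActivationsRunner

cfg = CacheActivationsRunnerConfig(
    model_name="gpt2",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    dataset_path="<your-pretokenized-dataset>",
    d_in=768,
    context_size=128,
    training_tokens=1_000_000,
)

runner = CacheActivationsRunner(cfg)
print(runner)           # tokens, buffers, and estimated disk space
dataset = runner.run()  # writes shards, consolidates them, optionally pushes to the hub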

__str__()

Print the number of tokens to be cached. Print the number of buffers, and the number of tokens per buffer. Print the disk space required to store the activations.

Source code in sae_lens/cache_activations_runner.py
def __str__(self):
    """
    Print the number of tokens to be cached.
    Print the number of buffers, and the number of tokens per buffer.
    Print the disk space required to store the activations.

    """

    bytes_per_token = (
        self.cfg.d_in * self.cfg.dtype.itemsize
        if isinstance(self.cfg.dtype, torch.dtype)
        else DTYPE_MAP[self.cfg.dtype].itemsize
    )
    total_training_tokens = self.cfg.training_tokens
    total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

    return (
        f"Activation Cache Runner:\n"
        f"Total training tokens: {total_training_tokens}\n"
        f"Number of buffers: {self.n_buffers}\n"
        f"Tokens per buffer: {self.tokens_in_buffer}\n"
        f"Disk space required: {total_disk_space_gb:.2f} GB\n"
        f"Configuration:\n"
        f"{self.cfg}"
    )

CacheActivationsRunnerConfig dataclass

Configuration for caching activations of an LLM.

Source code in sae_lens/config.py
@dataclass
class CacheActivationsRunnerConfig:
    """
    Configuration for caching activations of an LLM.
    """

    # Data Generating Function (Model + Training Distribution)
    model_name: str = "gelu-2l"
    model_class_name: str = "HookedTransformer"
    hook_name: str = "blocks.{layer}.hook_mlp_out"
    hook_layer: int = 0
    hook_head_index: Optional[int] = None
    dataset_path: str = ""
    dataset_trust_remote_code: bool | None = None
    streaming: bool = True
    is_dataset_tokenized: bool = True
    context_size: int = 128
    new_cached_activations_path: Optional[str] = (
        None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}"
    )

    # if saving to huggingface, set hf_repo_id
    hf_repo_id: Optional[str] = None
    hf_num_shards: int | None = None
    hf_revision: str = "main"
    hf_is_private_repo: bool = False

    # don't specify this since you don't want to load from disk with the cache runner.
    cached_activations_path: Optional[str] = None
    # SAE Parameters
    d_in: int = 512

    # Activation Store Parameters
    n_batches_in_buffer: int = 20
    training_tokens: int = 2_000_000
    store_batch_size_prompts: int = 32
    train_batch_size_tokens: int = 4096
    normalize_activations: str = "none"  # should always be none for activation caching
    seqpos_slice: tuple[int | None, ...] = (None,)

    # Misc
    device: str = "cpu"
    act_store_device: str = "with_model"  # will be set by post init if with_model
    seed: int = 42
    dtype: str = "float32"
    prepend_bos: bool = True
    autocast_lm: bool = False  # autocast lm during activation fetching

    # Shuffle activations
    shuffle: bool = True

    model_kwargs: dict[str, Any] = field(default_factory=dict)
    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Autofill cached_activations_path unless the user overrode it
        if self.new_cached_activations_path is None:
            self.new_cached_activations_path = _default_cached_activations_path(
                self.dataset_path,
                self.model_name,
                self.hook_name,
                self.hook_head_index,
            )

        if self.act_store_device == "with_model":
            self.act_store_device = self.device

        if self.context_size < 0:
            raise ValueError(
                f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
            )

        _validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)
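
A small sketch of the defaults applied in __post_init__ (values are placeholders):

cfg = CacheActivationsRunnerConfig(
    model_name="gpt2",
    dataset_path="<your-dataset>",
    hook_name="blocks.0.hook_mlp_out",
    d_in=512,
    device="cuda",
)
print(cfg.new_cached_activations_path)  # autofilled from the dataset, model, and hook names
print(cfg.act_store_device)             # "cuda": "with_model" resolves to cfg.device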

HookedSAETransformer

Bases: HookedTransformer
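
A usage sketch (model name, release, and SAE id are placeholders): splice an SAE into its hook point for a single forward pass, or attach it persistently.

from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gpt2")
sae, _, _ = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

# Temporary: the SAE is attached for this forward pass only, then removed.
logits = model.run_with_saes("The quick brown fox", saes=[sae])

# Persistent: attached until reset_saes() is called.
model.add_sae(sae)
model.reset_saes()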

Source code in sae_lens/analysis/hooked_sae_transformer.py
class HookedSAETransformer(HookedTransformer):

    def __init__(
        self,
        *model_args: Any,
        **model_kwargs: Any,
    ):
        """Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

        Note that if you want to load the model from pretrained weights, you should use
        :meth:`from_pretrained` instead.

        Args:
            *model_args: Positional arguments for HookedTransformer initialization
            **model_kwargs: Keyword arguments for HookedTransformer initialization
        """
        super().__init__(*model_args, **model_kwargs)
        self.acts_to_saes: Dict[str, SAE] = {}

    def add_sae(self, sae: SAE, use_error_term: Optional[bool] = None):
        """Attaches an SAE to the model

        WARNING: This SAE will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.

        Args:
            sae: SparseAutoencoderBase. The SAE to attach to the model
            use_error_term: (Optional[bool]) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
        """
        act_name = sae.cfg.hook_name
        if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
            logging.warning(
                f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
            )
            return

        if use_error_term is not None:
            if not hasattr(sae, "_original_use_error_term"):
                sae._original_use_error_term = sae.use_error_term  # type: ignore
            sae.use_error_term = use_error_term
        self.acts_to_saes[act_name] = sae
        set_deep_attr(self, act_name, sae)
        self.setup()

    def _reset_sae(self, act_name: str, prev_sae: Optional[SAE] = None):
        """Resets an SAE that was attached to the model

        By default will remove the SAE from that hook_point.
        If prev_sae is provided, will replace the current SAE with the provided one.
        This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes)

        Args:
            act_name: str. The hook_name of the SAE to reset
            prev_sae: Optional[HookedSAE]. The SAE to replace the current one with. If None, will just remove the SAE from this hook point. Defaults to None
        """
        if act_name not in self.acts_to_saes:
            logging.warning(
                f"No SAE is attached to {act_name}. There's nothing to reset."
            )
            return

        current_sae = self.acts_to_saes[act_name]
        if hasattr(current_sae, "_original_use_error_term"):
            current_sae.use_error_term = current_sae._original_use_error_term
            delattr(current_sae, "_original_use_error_term")

        if prev_sae:
            set_deep_attr(self, act_name, prev_sae)
            self.acts_to_saes[act_name] = prev_sae
        else:
            set_deep_attr(self, act_name, HookPoint())
            del self.acts_to_saes[act_name]

    def reset_saes(
        self,
        act_names: Optional[Union[str, List[str]]] = None,
        prev_saes: Optional[List[Union[SAE, None]]] = None,
    ):
        """Reset the SAEs attached to the model

        If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.
        Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).

        Args:
            act_names (Optional[Union[str, List[str]]): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
            prev_saes (Optional[List[Union[HookedSAE, None]]]): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.
        """
        if isinstance(act_names, str):
            act_names = [act_names]
        elif act_names is None:
            act_names = list(self.acts_to_saes.keys())

        if prev_saes:
            assert len(act_names) == len(
                prev_saes
            ), "act_names and prev_saes must have the same length"
        else:
            prev_saes = [None] * len(act_names)  # type: ignore

        for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore
            self._reset_sae(act_name, prev_sae)

        self.setup()

    def run_with_saes(
        self,
        *model_args: Any,
        saes: Union[SAE, List[SAE]] = [],
        reset_saes_end: bool = True,
        use_error_term: Optional[bool] = None,
        **model_kwargs: Any,
    ) -> Union[
        None,
        Float[torch.Tensor, "batch pos d_vocab"],
        Loss,
        Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
    ]:
        """Wrapper around HookedTransformer forward pass.

        Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
            reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
            use_error_term: (Optional[bool]) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
            **model_kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(
            saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
        ):
            return self(*model_args, **model_kwargs)

    def run_with_cache_with_saes(
        self,
        *model_args: Any,
        saes: Union[SAE, List[SAE]] = [],
        reset_saes_end: bool = True,
        use_error_term: Optional[bool] = None,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs: Any,
    ) -> Tuple[
        Union[
            None,
            Float[torch.Tensor, "batch pos d_vocab"],
            Loss,
            Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
        ],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """Wrapper around 'run_with_cache' in HookedTransformer.

        Attaches given SAEs before running the model with cache and then removes them.
        By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
            reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
            use_error_term: (Optional[bool]) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
            return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of
                useful HookedTransformer specific methods, otherwise it will return a dictionary of
                activations as in HookedRootModule.
            remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            **kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(
            saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
        ):
            return self.run_with_cache(  # type: ignore
                *model_args,
                return_cache_object=return_cache_object,  # type: ignore
                remove_batch_dim=remove_batch_dim,
                **kwargs,
            )

    def run_with_hooks_with_saes(
        self,
        *model_args: Any,
        saes: Union[SAE, List[SAE]] = [],
        reset_saes_end: bool = True,
        fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],  # type: ignore
        bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],  # type: ignore
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """Wrapper around 'run_with_hooks' in HookedTransformer.

        Attaches the given SAEs to the model before running the model with hooks and then removes them.
        By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
            reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)
            fwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of forward hooks to apply
            bwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of backward hooks to apply
            reset_hooks_end: (bool) Whether to reset the hooks at the end of the forward pass (default: True)
            clear_contexts: (bool) Whether to clear the contexts at the end of the forward pass (default: False)
            **model_kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(saes=saes, reset_saes_end=reset_saes_end):
            return self.run_with_hooks(
                *model_args,
                fwd_hooks=fwd_hooks,
                bwd_hooks=bwd_hooks,
                reset_hooks_end=reset_hooks_end,
                clear_contexts=clear_contexts,
                **model_kwargs,
            )

    @contextmanager
    def saes(
        self,
        saes: Union[SAE, List[SAE]] = [],
        reset_saes_end: bool = True,
        use_error_term: Optional[bool] = None,
    ):
        """
        A context manager for adding temporary SAEs to the model.
        See HookedTransformer.hooks for a similar context manager for hooks.
        By default will keep track of previously attached SAEs, and restore them when the context manager exits.

        Example:

        .. code-block:: python

            from transformer_lens import HookedSAETransformer, HookedSAE, HookedSAEConfig

            model = HookedSAETransformer.from_pretrained('gpt2-small')
            sae_cfg = HookedSAEConfig(...)
            sae = HookedSAE(sae_cfg)
            with model.saes(saes=[sae]):
                spliced_logits = model(text)


        Args:
            saes (Union[HookedSAE, List[HookedSAE]]): SAEs to be attached.
            reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
            use_error_term (Optional[bool]): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
        """
        act_names_to_reset = []
        prev_saes = []
        if isinstance(saes, SAE):
            saes = [saes]
        try:
            for sae in saes:
                act_names_to_reset.append(sae.cfg.hook_name)
                prev_sae = self.acts_to_saes.get(sae.cfg.hook_name, None)
                prev_saes.append(prev_sae)
                self.add_sae(sae, use_error_term=use_error_term)
            yield self
        finally:
            if reset_saes_end:
                self.reset_saes(act_names_to_reset, prev_saes)
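
For orientation before the per-method reference below, a minimal usage sketch. The `sae` object is a placeholder (as in the docstring example, you would construct or load it separately), and the top-level import path is an assumption:

import torch
from sae_lens import HookedSAETransformer, SAE  # assumed top-level imports

model = HookedSAETransformer.from_pretrained("gpt2-small")
sae: SAE = ...  # placeholder: an SAE whose cfg.hook_name matches a hook on this model

tokens = model.to_tokens("The quick brown fox")
logits_with_sae = model.run_with_saes(tokens, saes=[sae])  # SAE is detached again afterwards
logits_plain = model(tokens)
print(torch.allclose(logits_with_sae, logits_plain))  # typically False: the SAE swaps in its reconstruction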

__init__(*model_args, **model_kwargs)

Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

Note that if you want to load the model from pretrained weights, you should use :meth:from_pretrained instead.

Parameters:

Name Type Description Default
*model_args Any

Positional arguments for HookedTransformer initialization

()
**model_kwargs Any

Keyword arguments for HookedTransformer initialization

{}
Source code in sae_lens/analysis/hooked_sae_transformer.py
def __init__(
    self,
    *model_args: Any,
    **model_kwargs: Any,
):
    """Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

    Note that if you want to load the model from pretrained weights, you should use
    :meth:`from_pretrained` instead.

    Args:
        *model_args: Positional arguments for HookedTransformer initialization
        **model_kwargs: Keyword arguments for HookedTransformer initialization
    """
    super().__init__(*model_args, **model_kwargs)
    self.acts_to_saes: Dict[str, SAE] = {}

add_sae(sae, use_error_term=None)

Attaches an SAE to the model

WARNING: This SAE will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.

Parameters:

Name Type Description Default
sae SAE

SparseAutoencoderBase. The SAE to attach to the model

required
use_error_term Optional[bool]

(Optional[bool]) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.

None
Source code in sae_lens/analysis/hooked_sae_transformer.py
def add_sae(self, sae: SAE, use_error_term: Optional[bool] = None):
    """Attaches an SAE to the model

    WARNING: This SAE will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.

    Args:
        sae: SparseAutoencoderBase. The SAE to attach to the model
        use_error_term: (Optional[bool]) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
    """
    act_name = sae.cfg.hook_name
    if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
        logging.warning(
            f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
        )
        return

    if use_error_term is not None:
        if not hasattr(sae, "_original_use_error_term"):
            sae._original_use_error_term = sae.use_error_term  # type: ignore
        sae.use_error_term = use_error_term
    self.acts_to_saes[act_name] = sae
    set_deep_attr(self, act_name, sae)
    self.setup()
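
Continuing the placeholder `model` and `sae` from the sketch above, `add_sae` and `reset_saes` (documented next) give you the persistent attach-then-remove lifecycle:

model.add_sae(sae)                  # spliced in for every subsequent forward pass
logits_a = model("First prompt")
logits_b = model("Second prompt")   # still running through the SAE

model.reset_saes()                  # remove all attached SAEs (or pass specific hook names)
logits_c = model("Second prompt")   # back to the unmodified model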

reset_saes(act_names=None, prev_saes=None)

Reset the SAEs attached to the model

If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model. Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).

Parameters:

Name Type Description Default
act_names Optional[Union[str, List[str]]

The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.

None
prev_saes Optional[List[Union[HookedSAE, None]]]

List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.

None
Source code in sae_lens/analysis/hooked_sae_transformer.py
def reset_saes(
    self,
    act_names: Optional[Union[str, List[str]]] = None,
    prev_saes: Optional[List[Union[SAE, None]]] = None,
):
    """Reset the SAEs attached to the model

    If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.
    Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).

    Args:
        act_names (Optional[Union[str, List[str]]): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
        prev_saes (Optional[List[Union[HookedSAE, None]]]): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.
    """
    if isinstance(act_names, str):
        act_names = [act_names]
    elif act_names is None:
        act_names = list(self.acts_to_saes.keys())

    if prev_saes:
        assert len(act_names) == len(
            prev_saes
        ), "act_names and prev_saes must have the same length"
    else:
        prev_saes = [None] * len(act_names)  # type: ignore

    for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore
        self._reset_sae(act_name, prev_sae)

    self.setup()

run_with_cache_with_saes(*model_args, saes=[], reset_saes_end=True, use_error_term=None, return_cache_object=True, remove_batch_dim=False, **kwargs)

Wrapper around 'run_with_cache' in HookedTransformer.

Attaches given SAEs before running the model with cache and then removes them. By default, will reset all SAEs to original state after.

Parameters:

Name Type Description Default
*model_args Any

Positional arguments for the model forward pass

()
saes Union[SAE, List[SAE]]

(Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass

[]
reset_saes_end bool

(bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.

True
use_error_term Optional[bool]

(Optional[bool]) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.

None
return_cache_object bool

(bool) if True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule.

True
remove_batch_dim bool

(bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.

False
**kwargs Any

Keyword arguments for the model forward pass

{}
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_cache_with_saes(
    self,
    *model_args: Any,
    saes: Union[SAE, List[SAE]] = [],
    reset_saes_end: bool = True,
    use_error_term: Optional[bool] = None,
    return_cache_object: bool = True,
    remove_batch_dim: bool = False,
    **kwargs: Any,
) -> Tuple[
    Union[
        None,
        Float[torch.Tensor, "batch pos d_vocab"],
        Loss,
        Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
    ],
    Union[ActivationCache, Dict[str, torch.Tensor]],
]:
    """Wrapper around 'run_with_cache' in HookedTransformer.

    Attaches given SAEs before running the model with cache and then removes them.
    By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
        reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
        use_error_term: (Optional[bool]) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
        return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of
            useful HookedTransformer specific methods, otherwise it will return a dictionary of
            activations as in HookedRootModule.
        remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
        **kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(
        saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
    ):
        return self.run_with_cache(  # type: ignore
            *model_args,
            return_cache_object=return_cache_object,  # type: ignore
            remove_batch_dim=remove_batch_dim,
            **kwargs,
        )
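
For example, continuing the placeholder `model`, `sae`, and `tokens` from above, you can inspect what the SAE contributes to the cache. The exact names of the SAE's internal hook points are not listed in this reference, so the sketch only filters cache keys by prefix rather than assuming them:

logits, cache = model.run_with_cache_with_saes(tokens, saes=[sae])

# Entries under the SAE's hook point come from the SAE's own hook points.
sae_keys = [k for k in cache.keys() if k.startswith(sae.cfg.hook_name)]
print(sae_keys)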

run_with_hooks_with_saes(*model_args, saes=[], reset_saes_end=True, fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts=False, **model_kwargs)

Wrapper around 'run_with_hooks' in HookedTransformer.

Attaches the given SAEs to the model before running the model with hooks and then removes them. By default, will reset all SAEs to original state after.

Parameters:

Name Type Description Default
*model_args Any

Positional arguments for the model forward pass

()
saes Union[SAE, List[SAE]]

(Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass

[]
reset_saes_end bool

(bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)

True
fwd_hooks List[Tuple[Union[str, Callable], Callable]]

(List[Tuple[Union[str, Callable], Callable]]) List of forward hooks to apply

[]
bwd_hooks List[Tuple[Union[str, Callable], Callable]]

(List[Tuple[Union[str, Callable], Callable]]) List of backward hooks to apply

[]
reset_hooks_end bool

(bool) Whether to reset the hooks at the end of the forward pass (default: True)

True
clear_contexts bool

(bool) Whether to clear the contexts at the end of the forward pass (default: False)

False
**model_kwargs Any

Keyword arguments for the model forward pass

{}
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_hooks_with_saes(
    self,
    *model_args: Any,
    saes: Union[SAE, List[SAE]] = [],
    reset_saes_end: bool = True,
    fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],  # type: ignore
    bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],  # type: ignore
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    **model_kwargs: Any,
):
    """Wrapper around 'run_with_hooks' in HookedTransformer.

    Attaches the given SAEs to the model before running the model with hooks and then removes them.
    By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
        reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)
        fwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of forward hooks to apply
        bwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of backward hooks to apply
        reset_hooks_end: (bool) Whether to reset the hooks at the end of the forward pass (default: True)
        clear_contexts: (bool) Whether to clear the contexts at the end of the forward pass (default: False)
        **model_kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(saes=saes, reset_saes_end=reset_saes_end):
        return self.run_with_hooks(
            *model_args,
            fwd_hooks=fwd_hooks,
            bwd_hooks=bwd_hooks,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
            **model_kwargs,
        )
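
A short sketch combining a temporary SAE with an ordinary TransformerLens forward hook, continuing the placeholders from above. The hook point name is only an example and would need to exist on your model:

def zero_ablate(activation, hook):
    # Standard TransformerLens hook signature: return the (possibly modified) activation.
    return activation * 0.0

logits = model.run_with_hooks_with_saes(
    tokens,
    saes=[sae],
    fwd_hooks=[("blocks.1.hook_mlp_out", zero_ablate)],  # example hook name
)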

run_with_saes(*model_args, saes=[], reset_saes_end=True, use_error_term=None, **model_kwargs)

Wrapper around HookedTransformer forward pass.

Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

Parameters:

Name Type Description Default
*model_args Any

Positional arguments for the model forward pass

()
saes Union[SAE, List[SAE]]

(Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass

[]
reset_saes_end bool

If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.

True
use_error_term Optional[bool]

(Optional[bool]) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.

None
**model_kwargs Any

Keyword arguments for the model forward pass

{}
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_saes(
    self,
    *model_args: Any,
    saes: Union[SAE, List[SAE]] = [],
    reset_saes_end: bool = True,
    use_error_term: Optional[bool] = None,
    **model_kwargs: Any,
) -> Union[
    None,
    Float[torch.Tensor, "batch pos d_vocab"],
    Loss,
    Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
]:
    """Wrapper around HookedTransformer forward pass.

    Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
        reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
        use_error_term: (Optional[bool]) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
        **model_kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(
        saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
    ):
        return self(*model_args, **model_kwargs)
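
The `use_error_term` flag is the main knob here: when it is True the SAE returns its input (reconstruction plus error), so the forward pass should match the unmodified model up to numerical noise; when it is False the model sees the SAE's reconstruction instead. Continuing the placeholders from above:

logits_err = model.run_with_saes(tokens, saes=[sae], use_error_term=True)     # ~ identical to the plain model
logits_recon = model.run_with_saes(tokens, saes=[sae], use_error_term=False)  # uses the reconstruction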

saes(saes=[], reset_saes_end=True, use_error_term=None)

A context manager for adding temporary SAEs to the model. See HookedTransformer.hooks for a similar context manager for hooks. By default will keep track of previously attached SAEs, and restore them when the context manager exits.

Example:

.. code-block:: python

from transformer_lens import HookedSAETransformer, HookedSAE, HookedSAEConfig

model = HookedSAETransformer.from_pretrained('gpt2-small')
sae_cfg = HookedSAEConfig(...)
sae = HookedSAE(sae_cfg)
with model.saes(saes=[sae]):
    spliced_logits = model(text)

Parameters:

Name Type Description Default
saes Union[HookedSAE, List[HookedSAE]]

SAEs to be attached.

[]
reset_saes_end bool

If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.

True
use_error_term Optional[bool]

If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.

None
Source code in sae_lens/analysis/hooked_sae_transformer.py
@contextmanager
def saes(
    self,
    saes: Union[SAE, List[SAE]] = [],
    reset_saes_end: bool = True,
    use_error_term: Optional[bool] = None,
):
    """
    A context manager for adding temporary SAEs to the model.
    See HookedTransformer.hooks for a similar context manager for hooks.
    By default will keep track of previously attached SAEs, and restore them when the context manager exits.

    Example:

    .. code-block:: python

        from transformer_lens import HookedSAETransformer, HookedSAE, HookedSAEConfig

        model = HookedSAETransformer.from_pretrained('gpt2-small')
        sae_cfg = HookedSAEConfig(...)
        sae = HookedSAE(sae_cfg)
        with model.saes(saes=[sae]):
            spliced_logits = model(text)


    Args:
        saes (Union[HookedSAE, List[HookedSAE]]): SAEs to be attached.
        reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
        use_error_term (Optional[bool]): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
    """
    act_names_to_reset = []
    prev_saes = []
    if isinstance(saes, SAE):
        saes = [saes]
    try:
        for sae in saes:
            act_names_to_reset.append(sae.cfg.hook_name)
            prev_sae = self.acts_to_saes.get(sae.cfg.hook_name, None)
            prev_saes.append(prev_sae)
            self.add_sae(sae, use_error_term=use_error_term)
        yield self
    finally:
        if reset_saes_end:
            self.reset_saes(act_names_to_reset, prev_saes)

LanguageModelSAERunnerConfig dataclass

Configuration for training a sparse autoencoder on a language model.

Parameters:

Name Type Description Default
architecture str

The architecture to use, either "standard", "gated", or "jumprelu".

'standard'
model_name str

The name of the model to use. This should be the name of the model in the Hugging Face model hub.

'gelu-2l'
model_class_name str

The name of the class of the model to use. This should be either HookedTransformer or HookedMamba.

'HookedTransformer'
hook_name str

The name of the hook to use. This should be a valid TransformerLens hook.

'blocks.0.hook_mlp_out'
hook_eval str

NOT CURRENTLY IN USE. The name of the hook to use for evaluation.

'NOT_IN_USE'
hook_layer int

The index of the layer to hook. Used to stop forward passes early and speed up processing.

0
hook_head_index int

When the hook is for an activation with a head index, we can specify a specific head to use here.

None
dataset_path str

A Hugging Face dataset path.

''
dataset_trust_remote_code bool

Whether to trust remote code when loading datasets from Huggingface.

True
streaming bool

Whether to stream the dataset. Streaming large datasets is usually practical.

True
is_dataset_tokenized bool

NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized.

True
context_size int

The context size to use when generating activations on which to train the SAE.

128
use_cached_activations bool

Whether to use cached activations. This is useful when doing sweeps over the same activations.

False
cached_activations_path str

The path to the cached activations.

None
d_in int

The input dimension of the SAE.

512
d_sae int

The output dimension of the SAE. If None, defaults to d_in * expansion_factor.

None
b_dec_init_method str

The method to use to initialize the decoder bias. Zeros is likely fine.

'geometric_median'
expansion_factor int

The expansion factor. Larger is better but more computationally expensive. Default is 4.

None
activation_fn str

The activation function to use. Relu is standard.

'relu'
normalize_sae_decoder bool

Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.

True
noise_scale float

Using noise to induce sparsity is supported but not recommended.

0.0
from_pretrained_path str

The path to a pretrained SAE. We can finetune an existing SAE if needed.

None
apply_b_dec_to_input bool

Whether to apply the decoder bias to the input. Not currently advised.

True
decoder_orthogonal_init bool

Whether to use orthogonal initialization for the decoder. Not currently advised.

False
decoder_heuristic_init bool

Whether to use heuristic initialization for the decoder. See Anthropic April Update.

False
init_encoder_as_decoder_transpose bool

Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.

False
n_batches_in_buffer int

The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.

20
training_tokens int

The number of training tokens.

2000000
finetuning_tokens int

The number of finetuning tokens. See https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes

0
store_batch_size_prompts int

The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.

32
train_batch_size_tokens int

The batch size for training. This controls the batch size of the SAE Training loop.

4096
normalize_activations str

Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it following the Anthropic April update -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).

'none'
seqpos_slice tuple

Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.

(None,)
device str

The device to use. Usually cuda.

'cpu'
act_store_device str

The device to use for the activation store. CPU is advised in order to save vram.

'with_model'
seed int

The seed to use.

42
dtype str

The data type to use.

'float32'
prepend_bos bool

Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.

True
jumprelu_init_threshold float

The threshold to initialize for training JumpReLU SAEs.

0.001
jumprelu_bandwidth float

Bandwidth for training JumpReLU SAEs.

0.001
autocast bool

Whether to use autocast during training. Saves vram.

False
autocast_lm bool

Whether to use autocast during activation fetching.

False
compile_llm bool

Whether to compile the LLM.

False
llm_compilation_mode str

The compilation mode to use for the LLM.

None
compile_sae bool

Whether to compile the SAE.

False
sae_compilation_mode str

The compilation mode to use for the SAE.

None
adam_beta1 float

The beta1 parameter for Adam.

0
adam_beta2 float

The beta2 parameter for Adam.

0.999
mse_loss_normalization str

The normalization to use for the MSE loss.

None
l1_coefficient float

The L1 coefficient.

0.001
lp_norm float

The Lp norm.

1
scale_sparsity_penalty_by_decoder_norm bool

Whether to scale the sparsity penalty by the decoder norm.

False
l1_warm_up_steps int

The number of warm-up steps for the L1 loss.

0
lr float

The learning rate.

0.0003
lr_scheduler_name str

The name of the learning rate scheduler to use.

'constant'
lr_warm_up_steps int

The number of warm-up steps for the learning rate.

0
lr_end float

The end learning rate if lr_decay_steps is set. Default is lr / 10.

None
lr_decay_steps int

The number of decay steps for the learning rate.

0
n_restart_cycles int

The number of restart cycles for the cosine annealing warm restarts scheduler.

1
finetuning_method str

The method to use for finetuning.

None
use_ghost_grads bool

Whether to use ghost gradients.

False
feature_sampling_window int

The feature sampling window.

2000
dead_feature_window int

The dead feature window.

1000
dead_feature_threshold float

The dead feature threshold.

1e-08
n_eval_batches int

The number of evaluation batches.

10
eval_batch_size_prompts int

The batch size for evaluation.

None
log_to_wandb bool

Whether to log to Weights & Biases.

True
log_activations_store_to_wandb bool

NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.

False
log_optimizer_state_to_wandb bool

NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.

False
wandb_project str

The Weights & Biases project to log to.

'mats_sae_training_language_model'
wandb_id str

The Weights & Biases ID.

None
run_name str

The name of the run.

None
wandb_entity str

The Weights & Biases entity.

None
wandb_log_frequency int

The frequency to log to Weights & Biases.

10
eval_every_n_wandb_logs int

The frequency to evaluate.

100
resume bool

Whether to resume training.

False
n_checkpoints int

The number of checkpoints.

0
checkpoint_path str

The path to save checkpoints.

'checkpoints'
verbose bool

Whether to print verbose output.

True
model_kwargs dict[str, Any]

Additional keyword arguments for the model.

dict()
model_from_pretrained_kwargs dict[str, Any]

Additional keyword arguments for the model from pretrained.

None
Source code in sae_lens/config.py
@dataclass
class LanguageModelSAERunnerConfig:
    """
    Configuration for training a sparse autoencoder on a language model.

    Args:
        architecture (str): The architecture to use, either "standard", "gated", or "jumprelu".
        model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
        model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
        hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
        hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
        hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
        hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
        dataset_path (str): A Hugging Face dataset path.
        dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
        streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
        is_dataset_tokenized (bool): NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized.
        context_size (int): The context size to use when generating activations on which to train the SAE.
        use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
        cached_activations_path (str, optional): The path to the cached activations.
        d_in (int): The input dimension of the SAE.
        d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
        b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
        expansion_factor (int): The expansion factor. Larger is better but more computationally expensive. Default is 4.
        activation_fn (str): The activation function to use. Relu is standard.
        normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
        noise_scale (float): Using noise to induce sparsity is supported but not recommended.
        from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
        apply_b_dec_to_input (bool): Whether to apply the decoder bias to the input. Not currently advised.
        decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
        decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
        init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.
        training_tokens (int): The number of training tokens.
        finetuning_tokens (int): The number of finetuning tokens. See [here](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes)
        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
        train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
        normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it following the Anthropic April update -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
        seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
        device (str): The device to use. Usually cuda.
        act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
        seed (int): The seed to use.
        dtype (str): The data type to use.
        prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
        jumprelu_init_threshold (float): The threshold to initialize for training JumpReLU SAEs.
        jumprelu_bandwidth (float): Bandwidth for training JumpReLU SAEs.
        autocast (bool): Whether to use autocast during training. Saves vram.
        autocast_lm (bool): Whether to use autocast during activation fetching.
        compile_llm (bool): Whether to compile the LLM.
        llm_compilation_mode (str): The compilation mode to use for the LLM.
        compile_sae (bool): Whether to compile the SAE.
        sae_compilation_mode (str): The compilation mode to use for the SAE.
        adam_beta1 (float): The beta1 parameter for Adam.
        adam_beta2 (float): The beta2 parameter for Adam.
        mse_loss_normalization (str): The normalization to use for the MSE loss.
        l1_coefficient (float): The L1 coefficient.
        lp_norm (float): The Lp norm.
        scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
        l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
        lr (float): The learning rate.
        lr_scheduler_name (str): The name of the learning rate scheduler to use.
        lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
        lr_end (float): The end learning rate if lr_decay_steps is set. Default is lr / 10.
        lr_decay_steps (int): The number of decay steps for the learning rate.
        n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
        finetuning_method (str): The method to use for finetuning.
        use_ghost_grads (bool): Whether to use ghost gradients.
        feature_sampling_window (int): The feature sampling window.
        dead_feature_window (int): The dead feature window.
        dead_feature_threshold (float): The dead feature threshold.
        n_eval_batches (int): The number of evaluation batches.
        eval_batch_size_prompts (int): The batch size for evaluation.
        log_to_wandb (bool): Whether to log to Weights & Biases.
        log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
        log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
        wandb_project (str): The Weights & Biases project to log to.
        wandb_id (str): The Weights & Biases ID.
        run_name (str): The name of the run.
        wandb_entity (str): The Weights & Biases entity.
        wandb_log_frequency (int): The frequency to log to Weights & Biases.
        eval_every_n_wandb_logs (int): The frequency to evaluate.
        resume (bool): Whether to resume training.
        n_checkpoints (int): The number of checkpoints.
        checkpoint_path (str): The path to save checkpoints.
        verbose (bool): Whether to print verbose output.
        model_kwargs (dict[str, Any]): Additional keyword arguments for the model.
        model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments for the model from pretrained.
    """

    # Data Generating Function (Model + Training Distribution)
    model_name: str = "gelu-2l"
    model_class_name: str = "HookedTransformer"
    hook_name: str = "blocks.0.hook_mlp_out"
    hook_eval: str = "NOT_IN_USE"
    hook_layer: int = 0
    hook_head_index: Optional[int] = None
    dataset_path: str = ""
    dataset_trust_remote_code: bool = True
    streaming: bool = True
    is_dataset_tokenized: bool = True
    context_size: int = 128
    use_cached_activations: bool = False
    cached_activations_path: Optional[str] = (
        None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}"
    )

    # SAE Parameters
    architecture: Literal["standard", "gated", "jumprelu"] = "standard"
    d_in: int = 512
    d_sae: Optional[int] = None
    b_dec_init_method: str = "geometric_median"
    expansion_factor: Optional[int] = (
        None  # defaults to 4 if both d_sae and expansion_factor are None
    )
    activation_fn: str = "relu"  # relu, tanh-relu, topk
    activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)  # for topk
    normalize_sae_decoder: bool = True
    noise_scale: float = 0.0
    from_pretrained_path: Optional[str] = None
    apply_b_dec_to_input: bool = True
    decoder_orthogonal_init: bool = False
    decoder_heuristic_init: bool = False
    init_encoder_as_decoder_transpose: bool = False

    # Activation Store Parameters
    n_batches_in_buffer: int = 20
    training_tokens: int = 2_000_000
    finetuning_tokens: int = 0
    store_batch_size_prompts: int = 32
    train_batch_size_tokens: int = 4096
    normalize_activations: str = (
        "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
    )
    seqpos_slice: tuple[int | None, ...] = (None,)

    # Misc
    device: str = "cpu"
    act_store_device: str = "with_model"  # will be set by post init if with_model
    seed: int = 42
    dtype: str = "float32"  # type: ignore #
    prepend_bos: bool = True
    jumprelu_init_threshold: float = 0.001
    jumprelu_bandwidth: float = 0.001

    # Performance - see compilation section of lm_runner.py for info
    autocast: bool = False  # autocast to autocast_dtype during training
    autocast_lm: bool = False  # autocast lm during activation fetching
    compile_llm: bool = False  # use torch.compile on the LLM
    llm_compilation_mode: str | None = None  # which torch.compile mode to use
    compile_sae: bool = False  # use torch.compile on the SAE
    sae_compilation_mode: str | None = None

    # Training Parameters

    ## Batch size
    train_batch_size_tokens: int = 4096

    ## Adam
    adam_beta1: float = 0
    adam_beta2: float = 0.999

    ## Loss Function
    mse_loss_normalization: Optional[str] = None
    l1_coefficient: float = 1e-3
    lp_norm: float = 1
    scale_sparsity_penalty_by_decoder_norm: bool = False
    l1_warm_up_steps: int = 0

    ## Learning Rate Schedule
    lr: float = 3e-4
    lr_scheduler_name: str = (
        "constant"  # constant, cosineannealing, cosineannealingwarmrestarts
    )
    lr_warm_up_steps: int = 0
    lr_end: Optional[float] = None  # only used for cosine annealing, default is lr / 10
    lr_decay_steps: int = 0
    n_restart_cycles: int = 1  # used only for cosineannealingwarmrestarts

    ## FineTuning
    finetuning_method: Optional[str] = None  # scale, decoder or unrotated_decoder

    # Resampling protocol args
    use_ghost_grads: bool = False  # want to change this to true on some timeline.
    feature_sampling_window: int = 2000
    dead_feature_window: int = 1000  # should be smaller than feature_sampling_window

    dead_feature_threshold: float = 1e-8

    # Evals
    n_eval_batches: int = 10
    eval_batch_size_prompts: int | None = None  # useful if evals cause OOM

    # WANDB
    log_to_wandb: bool = True
    log_activations_store_to_wandb: bool = False
    log_optimizer_state_to_wandb: bool = False
    wandb_project: str = "mats_sae_training_language_model"
    wandb_id: Optional[str] = None
    run_name: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_log_frequency: int = 10
    eval_every_n_wandb_logs: int = 100  # evals every 1000 steps at the default wandb_log_frequency of 10

    # Misc
    resume: bool = False
    n_checkpoints: int = 0
    checkpoint_path: str = "checkpoints"
    verbose: bool = True
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    model_from_pretrained_kwargs: dict[str, Any] | None = None
    sae_lens_version: str = field(default_factory=lambda: __version__)
    sae_lens_training_version: str = field(default_factory=lambda: __version__)

    def __post_init__(self):
        if self.resume:
            raise ValueError(
                "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
                + "If you want to load an SAE with resume=True in the config, please manually set resume=False in that config."
            )

        if self.use_cached_activations and self.cached_activations_path is None:
            self.cached_activations_path = _default_cached_activations_path(
                self.dataset_path,
                self.model_name,
                self.hook_name,
                self.hook_head_index,
            )

        if self.d_sae is not None and self.expansion_factor is not None:
            raise ValueError("You can't set both d_sae and expansion_factor.")

        if self.d_sae is None and self.expansion_factor is None:
            self.expansion_factor = 4

        if self.d_sae is None and self.expansion_factor is not None:
            self.d_sae = self.d_in * self.expansion_factor
        self.tokens_per_buffer = (
            self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
        )

        if self.run_name is None:
            self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"

        if self.model_from_pretrained_kwargs is None:
            if self.model_class_name == "HookedTransformer":
                self.model_from_pretrained_kwargs = {"center_writing_weights": False}
            else:
                self.model_from_pretrained_kwargs = {}

        if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
            raise ValueError(
                f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
            )

        if self.normalize_sae_decoder and self.decoder_heuristic_init:
            raise ValueError(
                "You can't normalize the decoder and use heuristic initialization."
            )

        if self.normalize_sae_decoder and self.scale_sparsity_penalty_by_decoder_norm:
            raise ValueError(
                "Weighting loss by decoder norm makes no sense if you are normalizing the decoder weight norms to 1"
            )

        # if we use decoder fine tuning, we can't be applying b_dec to the input
        if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
            raise ValueError(
                "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
            )

        if self.normalize_activations not in [
            "none",
            "expected_average_only_in",
            "constant_norm_rescale",
            "layer_norm",
        ]:
            raise ValueError(
                f"normalize_activations must be none, layer_norm, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
            )

        if self.act_store_device == "with_model":
            self.act_store_device = self.device

        if self.lr_end is None:
            self.lr_end = self.lr / 10

        unique_id = self.wandb_id
        if unique_id is None:
            unique_id = cast(
                Any, wandb
            ).util.generate_id()  # not sure why this type is erroring
        self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

        if self.verbose:
            print(
                f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
            )
            # Print out some useful info:
            n_tokens_per_buffer = (
                self.store_batch_size_prompts
                * self.context_size
                * self.n_batches_in_buffer
            )
            print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 ** 6}")
            n_contexts_per_buffer = (
                self.store_batch_size_prompts * self.n_batches_in_buffer
            )
            print(
                f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 ** 6}"
            )

            total_training_steps = (
                self.training_tokens + self.finetuning_tokens
            ) // self.train_batch_size_tokens
            print(f"Total training steps: {total_training_steps}")

            total_wandb_updates = total_training_steps // self.wandb_log_frequency
            print(f"Total wandb updates: {total_wandb_updates}")

            # how many times will we sample dead neurons?
            # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
            n_feature_window_samples = (
                total_training_steps // self.feature_sampling_window
            )
            print(
                f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size_tokens) / 10 ** 6}"
            )
            print(
                f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size_tokens) / 10 ** 6}"
            )
            print(
                f"We will reset the sparsity calculation {n_feature_window_samples} times."
            )
            # print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size_tokens)
            print(
                f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
            )

        if self.use_ghost_grads:
            print("Using Ghost Grads.")

        if self.context_size < 0:
            raise ValueError(
                f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
            )

        _validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)

    @property
    def total_training_tokens(self) -> int:
        return self.training_tokens + self.finetuning_tokens

    @property
    def total_training_steps(self) -> int:
        return self.total_training_tokens // self.train_batch_size_tokens

    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
        return {
            # TEMP
            "architecture": self.architecture,
            "d_in": self.d_in,
            "d_sae": self.d_sae,
            "dtype": self.dtype,
            "device": self.device,
            "model_name": self.model_name,
            "hook_name": self.hook_name,
            "hook_layer": self.hook_layer,
            "hook_head_index": self.hook_head_index,
            "activation_fn_str": self.activation_fn,
            "apply_b_dec_to_input": self.apply_b_dec_to_input,
            "context_size": self.context_size,
            "prepend_bos": self.prepend_bos,
            "dataset_path": self.dataset_path,
            "dataset_trust_remote_code": self.dataset_trust_remote_code,
            "finetuning_scaling_factor": self.finetuning_method is not None,
            "sae_lens_training_version": self.sae_lens_training_version,
            "normalize_activations": self.normalize_activations,
            "activation_fn_kwargs": self.activation_fn_kwargs,
            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
            "seqpos_slice": self.seqpos_slice,
        }

    def get_training_sae_cfg_dict(self) -> dict[str, Any]:
        return {
            **self.get_base_sae_cfg_dict(),
            "l1_coefficient": self.l1_coefficient,
            "lp_norm": self.lp_norm,
            "use_ghost_grads": self.use_ghost_grads,
            "normalize_sae_decoder": self.normalize_sae_decoder,
            "noise_scale": self.noise_scale,
            "decoder_orthogonal_init": self.decoder_orthogonal_init,
            "mse_loss_normalization": self.mse_loss_normalization,
            "decoder_heuristic_init": self.decoder_heuristic_init,
            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
            "normalize_activations": self.normalize_activations,
            "jumprelu_init_threshold": self.jumprelu_init_threshold,
            "jumprelu_bandwidth": self.jumprelu_bandwidth,
            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
        }

    def to_dict(self) -> dict[str, Any]:
        cfg_dict = {
            **self.__dict__,
            # some args may not be serializable by default
            "dtype": str(self.dtype),
            "device": str(self.device),
            "act_store_device": str(self.act_store_device),
        }

        return cfg_dict

    def to_json(self, path: str) -> None:
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path))

        with open(path + "cfg.json", "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
        with open(path + "cfg.json", "r") as f:
            cfg = json.load(f)

        # ensure that seqpos_slice is a tuple
        if "seqpos_slice" in cfg:
            if isinstance(cfg["seqpos_slice"], list):
                cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"])
            elif not isinstance(cfg["seqpos_slice"], tuple):
                cfg["seqpos_slice"] = (cfg["seqpos_slice"],)

        return cls(**cfg)
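
A minimal round-trip sketch (assuming cfg is an already-constructed LanguageModelSAERunnerConfig and that the top-level sae_lens import is available); note that to_json/from_json append "cfg.json" to the given path, so a directory-style prefix ending in a separator is expected:

from sae_lens import LanguageModelSAERunnerConfig

cfg.to_json("out/my_run/")  # writes out/my_run/cfg.json
restored = LanguageModelSAERunnerConfig.from_json("out/my_run/")  # reads out/my_run/cfg.json
# from_json coerces seqpos_slice back to a tuple after JSON serialization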

PretokenizeRunner

Runner to pretokenize a dataset using a given tokenizer, and optionally upload to Huggingface.

Source code in sae_lens/pretokenize_runner.py
class PretokenizeRunner:
    """
    Runner to pretokenize a dataset using a given tokenizer, and optionally upload to Huggingface.
    """

    def __init__(self, cfg: PretokenizeRunnerConfig):
        self.cfg = cfg

    def run(self):
        """
        Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
        """
        dataset = load_dataset(
            self.cfg.dataset_path,
            data_dir=self.cfg.data_dir,
            data_files=self.cfg.data_files,
            split=self.cfg.split,
            streaming=self.cfg.streaming,
        )
        if isinstance(dataset, DatasetDict):
            raise ValueError(
                "Dataset has multiple splits. Must provide a 'split' param."
            )
        tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
        tokenizer.model_max_length = sys.maxsize
        tokenized_dataset = pretokenize_dataset(
            cast(Dataset, dataset), tokenizer, self.cfg
        )

        if self.cfg.save_path is not None:
            tokenized_dataset.save_to_disk(self.cfg.save_path)
            metadata = metadata_from_config(self.cfg)
            metadata_path = Path(self.cfg.save_path) / "sae_lens.json"
            with open(metadata_path, "w") as f:
                json.dump(metadata.__dict__, f, indent=2, ensure_ascii=False)

        if self.cfg.hf_repo_id is not None:
            push_to_hugging_face_hub(tokenized_dataset, self.cfg)

        return tokenized_dataset

run()

Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.

Source code in sae_lens/pretokenize_runner.py
def run(self):
    """
    Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
    """
    dataset = load_dataset(
        self.cfg.dataset_path,
        data_dir=self.cfg.data_dir,
        data_files=self.cfg.data_files,
        split=self.cfg.split,
        streaming=self.cfg.streaming,
    )
    if isinstance(dataset, DatasetDict):
        raise ValueError(
            "Dataset has multiple splits. Must provide a 'split' param."
        )
    tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
    tokenizer.model_max_length = sys.maxsize
    tokenized_dataset = pretokenize_dataset(
        cast(Dataset, dataset), tokenizer, self.cfg
    )

    if self.cfg.save_path is not None:
        tokenized_dataset.save_to_disk(self.cfg.save_path)
        metadata = metadata_from_config(self.cfg)
        metadata_path = Path(self.cfg.save_path) / "sae_lens.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata.__dict__, f, indent=2, ensure_ascii=False)

    if self.cfg.hf_repo_id is not None:
        push_to_hugging_face_hub(tokenized_dataset, self.cfg)

    return tokenized_dataset
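
A minimal usage sketch (the tokenizer, dataset, and paths below are illustrative placeholders, only a subset of PretokenizeRunnerConfig fields is shown, and the top-level sae_lens imports are assumed):

from sae_lens import PretokenizeRunner, PretokenizeRunnerConfig

cfg = PretokenizeRunnerConfig(
    tokenizer_name="gpt2",
    dataset_path="NeelNanda/c4-10k",      # any HF dataset with a text column
    split="train",
    save_path="./c4-10k-tokenized-gpt2",  # save to disk; set hf_repo_id to also push to the Hub
)
tokenized_dataset = PretokenizeRunner(cfg).run()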

SAE

Bases: HookedRootModule

Core Sparse Autoencoder (SAE) class used for inference. For training, see TrainingSAE.

Source code in sae_lens/sae.py
class SAE(HookedRootModule):
    """
    Core Sparse Autoencoder (SAE) class used for inference. For training, see `TrainingSAE`.
    """

    cfg: SAEConfig
    dtype: torch.dtype
    device: torch.device

    # analysis
    use_error_term: bool

    def __init__(
        self,
        cfg: SAEConfig,
        use_error_term: bool = False,
    ):
        super().__init__()

        self.cfg = cfg

        if cfg.model_from_pretrained_kwargs:
            warnings.warn(
                "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                "\nFor optimal performance, load the model like so:\n"
                "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
                category=UserWarning,
                stacklevel=1,
            )

        self.activation_fn = get_activation_fn(
            cfg.activation_fn_str, **cfg.activation_fn_kwargs or {}
        )
        self.dtype = DTYPE_MAP[cfg.dtype]
        self.device = torch.device(cfg.device)
        self.use_error_term = use_error_term

        if self.cfg.architecture == "standard":
            self.initialize_weights_basic()
            self.encode = self.encode_standard
        elif self.cfg.architecture == "gated":
            self.initialize_weights_gated()
            self.encode = self.encode_gated
        elif self.cfg.architecture == "jumprelu":
            self.initialize_weights_jumprelu()
            self.encode = self.encode_jumprelu
        else:
            raise ValueError(f"Invalid architecture: {self.cfg.architecture}")

        # handle presence / absence of scaling factor.
        if self.cfg.finetuning_scaling_factor:
            self.apply_finetuning_scaling_factor = (
                lambda x: x * self.finetuning_scaling_factor
            )
        else:
            self.apply_finetuning_scaling_factor = lambda x: x

        # set up hooks
        self.hook_sae_input = HookPoint()
        self.hook_sae_acts_pre = HookPoint()
        self.hook_sae_acts_post = HookPoint()
        self.hook_sae_output = HookPoint()
        self.hook_sae_recons = HookPoint()
        self.hook_sae_error = HookPoint()

        # handle hook_z reshaping if needed.
        # this is very cursed and should be refactored. it exists so that we can reshape out
        # the z activations for hook_z SAEs, but we don't know d_head if we split up the forward pass
        # into separate encode and decode functions.
        # this will cause errors if we call decode before encode.
        if self.cfg.hook_name.endswith("_z"):
            self.turn_on_forward_pass_hook_z_reshaping()
        else:
            # need to default the reshape fns
            self.turn_off_forward_pass_hook_z_reshaping()

        # handle run time activation normalization if needed:
        if self.cfg.normalize_activations == "constant_norm_rescale":

            #  we need to scale the norm of the input and store the scaling factor
            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
                self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
                x = x * self.x_norm_coeff
                return x

            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
                x = x / self.x_norm_coeff
                del self.x_norm_coeff  # prevents reusing
                return x

            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out

        elif self.cfg.normalize_activations == "layer_norm":

            #  we need to scale the norm of the input and store the scaling factor
            def run_time_activation_ln_in(
                x: torch.Tensor, eps: float = 1e-5
            ) -> torch.Tensor:
                mu = x.mean(dim=-1, keepdim=True)
                x = x - mu
                std = x.std(dim=-1, keepdim=True)
                x = x / (std + eps)
                self.ln_mu = mu
                self.ln_std = std
                return x

            def run_time_activation_ln_out(x: torch.Tensor, eps: float = 1e-5):
                return x * self.ln_std + self.ln_mu

            self.run_time_activation_norm_fn_in = run_time_activation_ln_in
            self.run_time_activation_norm_fn_out = run_time_activation_ln_out
        else:
            self.run_time_activation_norm_fn_in = lambda x: x
            self.run_time_activation_norm_fn_out = lambda x: x

        self.setup()  # Required for `HookedRootModule`s

    def initialize_weights_basic(self):

        # no config option changes the encoder bias init for now.
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        # Start with the default init strategy:
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
                )
            )
        )

        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
                )
            )
        )

        # methods which change b_dec as a function of the dataset are implemented after init.
        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
        )

        # scaling factor for fine-tuning (not to be used in initial training)
        # TODO: Make this optional and not included with all SAEs by default (but maintain backwards compatibility)
        if self.cfg.finetuning_scaling_factor:
            self.finetuning_scaling_factor = nn.Parameter(
                torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
            )

    def initialize_weights_gated(self):
        # Initialize the weights and biases for the gated encoder
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
                )
            )
        )

        self.b_gate = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        self.r_mag = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        self.b_mag = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
                )
            )
        )

        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
        )

    def initialize_weights_jumprelu(self):
        # The params are identical to the standard SAE
        # except we use a threshold parameter too
        self.threshold = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )
        self.initialize_weights_basic()

    @overload
    def to(
        self: T,
        device: Optional[Union[torch.device, str]] = ...,
        dtype: Optional[torch.dtype] = ...,
        non_blocking: bool = ...,
    ) -> T: ...

    @overload
    def to(self: T, dtype: torch.dtype, non_blocking: bool = ...) -> T: ...

    @overload
    def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T: ...

    def to(self, *args: Any, **kwargs: Any) -> "SAE":  # type: ignore
        device_arg = None
        dtype_arg = None

        # Check args
        for arg in args:
            if isinstance(arg, (torch.device, str)):
                device_arg = arg
            elif isinstance(arg, torch.dtype):
                dtype_arg = arg
            elif isinstance(arg, torch.Tensor):
                device_arg = arg.device
                dtype_arg = arg.dtype

        # Check kwargs
        device_arg = kwargs.get("device", device_arg)
        dtype_arg = kwargs.get("dtype", dtype_arg)

        if device_arg is not None:
            # Convert device to torch.device if it's a string
            device = (
                torch.device(device_arg) if isinstance(device_arg, str) else device_arg
            )

            # Update the cfg.device
            self.cfg.device = str(device)

            # Update the .device property
            self.device = device

        if dtype_arg is not None:
            # Update the cfg.dtype
            self.cfg.dtype = str(dtype_arg)

            # Update the .dtype property
            self.dtype = dtype_arg

        # Call the parent class's to() method to handle all cases (device, dtype, tensor)
        return super().to(*args, **kwargs)

    # Basic Forward Pass Functionality.
    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        feature_acts = self.encode(x)
        sae_out = self.decode(feature_acts)

        # TEMP
        if self.use_error_term:
            with torch.no_grad():
                # Recompute everything without hooks to get true error term
                # Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct
                # This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail
                # NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since this would mean intervening on features would mean ablating features still results in perfect reconstruction.
                with _disable_hooks(self):
                    feature_acts_clean = self.encode(x)
                    x_reconstruct_clean = self.decode(feature_acts_clean)
                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
            sae_out = sae_out + sae_error
        return self.hook_sae_output(sae_out)

    def encode_gated(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        sae_in = self.process_sae_in(x)

        # Gating path
        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
        active_features = (gating_pre_activation > 0).to(self.dtype)

        # Magnitude path with weight sharing
        magnitude_pre_activation = self.hook_sae_acts_pre(
            sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
        )
        feature_magnitudes = self.activation_fn(magnitude_pre_activation)

        feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)

        return feature_acts

    def encode_jumprelu(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Calculate SAE features from inputs
        """
        sae_in = self.process_sae_in(x)

        # "... d_in, d_in d_sae -> ... d_sae",
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

        feature_acts = self.hook_sae_acts_post(
            self.activation_fn(hidden_pre) * (hidden_pre > self.threshold)
        )

        return feature_acts

    def encode_standard(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Calculate SAE features from inputs
        """
        sae_in = self.process_sae_in(x)

        # "... d_in, d_in d_sae -> ... d_sae",
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))

        return feature_acts

    def process_sae_in(
        self, sae_in: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_in"]:
        sae_in = sae_in.to(self.dtype)
        sae_in = self.reshape_fn_in(sae_in)
        sae_in = self.hook_sae_input(sae_in)
        sae_in = self.run_time_activation_norm_fn_in(sae_in)
        sae_in = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)
        return sae_in

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """Decodes SAE feature activation tensor into a reconstructed input activation tensor."""
        # "... d_sae, d_sae d_in -> ... d_in",
        sae_out = self.hook_sae_recons(
            self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
        )

        # handle run time activation normalization if needed
        # will fail if you call this twice without calling encode in between.
        sae_out = self.run_time_activation_norm_fn_out(sae_out)

        # handle hook z reshaping if needed.
        sae_out = self.reshape_fn_out(sae_out, self.d_head)  # type: ignore

        return sae_out

    @torch.no_grad()
    def fold_W_dec_norm(self):
        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T
        if self.cfg.architecture == "gated":
            self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
            self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
            self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
        else:
            self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()

    @torch.no_grad()
    def fold_activation_norm_scaling_factor(
        self, activation_norm_scaling_factor: float
    ):
        self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor
        # previously weren't doing this.
        self.W_dec.data = self.W_dec.data / activation_norm_scaling_factor
        self.b_dec.data = self.b_dec.data / activation_norm_scaling_factor

        # once we normalize, we shouldn't need to scale activations.
        self.cfg.normalize_activations = "none"

    def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):
        if not os.path.exists(path):
            os.mkdir(path)

        # generate the weights
        state_dict = self.state_dict()
        self.process_state_dict_for_saving(state_dict)
        save_file(state_dict, f"{path}/{SAE_WEIGHTS_PATH}")

        # save the config
        config = self.cfg.to_dict()

        with open(f"{path}/{SAE_CFG_PATH}", "w") as f:
            json.dump(config, f)

        if sparsity is not None:
            sparsity_in_dict = {"sparsity": sparsity}
            save_file(sparsity_in_dict, f"{path}/{SPARSITY_PATH}")  # type: ignore

    # overwrite this in subclasses to modify the state_dict in-place before saving
    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
        pass

    # overwrite this in subclasses to modify the state_dict in-place after loading
    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
        pass

    @classmethod
    def load_from_pretrained(
        cls, path: str, device: str = "cpu", dtype: str | None = None
    ) -> "SAE":

        # get the config
        config_path = os.path.join(path, SAE_CFG_PATH)
        with open(config_path, "r") as f:
            cfg_dict = json.load(f)
        cfg_dict = handle_config_defaulting(cfg_dict)
        cfg_dict["device"] = device
        if dtype is not None:
            cfg_dict["dtype"] = dtype

        weight_path = os.path.join(path, SAE_WEIGHTS_PATH)
        cfg_dict, state_dict = read_sae_from_disk(
            cfg_dict=cfg_dict,
            weight_path=weight_path,
            device=device,
        )

        sae_cfg = SAEConfig.from_dict(cfg_dict)

        sae = cls(sae_cfg)
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict)

        return sae

    @classmethod
    def from_pretrained(
        cls,
        release: str,
        sae_id: str,
        device: str = "cpu",
    ) -> Tuple["SAE", dict[str, Any], Optional[torch.Tensor]]:
        """

        Load a pretrained SAE from the Hugging Face model hub.

        Args:
            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
            sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
            device: The device to load the SAE on.
            return_sparsity_if_present: If True, will return the log sparsity tensor if it is present in the model directory in the Hugging Face model hub.
        """

        # get sae directory
        sae_directory = get_pretrained_saes_directory()

        # get the repo id and path to the SAE
        if release not in sae_directory:
            if "/" not in release:
                raise ValueError(
                    f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
                )
        elif sae_id not in sae_directory[release].saes_map:
            # If using Gemma Scope and not the canonical release, give a hint to use it
            if (
                "gemma-scope" in release
                and "canonical" not in release
                and f"{release}-canonical" in sae_directory
            ):
                canonical_ids = list(
                    sae_directory[release + "-canonical"].saes_map.keys()
                )
                # Shorten the lengthy string of valid IDs
                if len(canonical_ids) > 5:
                    str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
                else:
                    str_canonical_ids = str(canonical_ids)
                value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
            else:
                value_suffix = ""

            valid_ids = list(sae_directory[release].saes_map.keys())
            # Shorten the lengthy string of valid IDs
            if len(valid_ids) > 5:
                str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
            else:
                str_valid_ids = str(valid_ids)

            raise ValueError(
                f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
                + value_suffix
            )
        sae_info = sae_directory.get(release, None)
        config_overrides = sae_info.config_overrides if sae_info is not None else None

        conversion_loader_name = get_conversion_loader_name(sae_info)
        conversion_loader = NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]

        cfg_dict, state_dict, log_sparsities = conversion_loader(
            release,
            sae_id=sae_id,
            device=device,
            force_download=False,
            cfg_overrides=config_overrides,
        )

        sae = cls(SAEConfig.from_dict(cfg_dict))
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict)

        # Check if normalization is 'expected_average_only_in'
        if cfg_dict.get("normalize_activations") == "expected_average_only_in":
            norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
            if norm_scaling_factor is not None:
                sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
                cfg_dict["normalize_activations"] = "none"
            else:
                warnings.warn(
                    f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
                )

        return sae, cfg_dict, log_sparsities

    def get_name(self):
        sae_name = f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"
        return sae_name

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "SAE":
        return cls(SAEConfig.from_dict(config_dict))

    def turn_on_forward_pass_hook_z_reshaping(self):

        assert self.cfg.hook_name.endswith(
            "_z"
        ), "This method should only be called for hook_z SAEs."

        def reshape_fn_in(x: torch.Tensor):
            self.d_head = x.shape[-1]  # type: ignore
            self.reshape_fn_in = lambda x: einops.rearrange(
                x, "... n_heads d_head -> ... (n_heads d_head)"
            )
            return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")

        self.reshape_fn_in = reshape_fn_in

        self.reshape_fn_out = lambda x, d_head: einops.rearrange(
            x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
        )
        self.hook_z_reshaping_mode = True

    def turn_off_forward_pass_hook_z_reshaping(self):
        self.reshape_fn_in = lambda x: x
        self.reshape_fn_out = lambda x, d_head: x
        self.d_head = None
        self.hook_z_reshaping_mode = False

decode(feature_acts)

Decodes SAE feature activation tensor into a reconstructed input activation tensor.

Source code in sae_lens/sae.py
def decode(
    self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_in"]:
    """Decodes SAE feature activation tensor into a reconstructed input activation tensor."""
    # "... d_sae, d_sae d_in -> ... d_in",
    sae_out = self.hook_sae_recons(
        self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
    )

    # handle run time activation normalization if needed
    # will fail if you call this twice without calling encode in between.
    sae_out = self.run_time_activation_norm_fn_out(sae_out)

    # handle hook z reshaping if needed.
    sae_out = self.reshape_fn_out(sae_out, self.d_head)  # type: ignore

    return sae_out

encode_jumprelu(x)

Calculate SAE features from inputs

Source code in sae_lens/sae.py
def encode_jumprelu(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """
    Calculate SAE features from inputs
    """
    sae_in = self.process_sae_in(x)

    # "... d_in, d_in d_sae -> ... d_sae",
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

    feature_acts = self.hook_sae_acts_post(
        self.activation_fn(hidden_pre) * (hidden_pre > self.threshold)
    )

    return feature_acts

encode_standard(x)

Calculate SAE features from inputs

Source code in sae_lens/sae.py
def encode_standard(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """
    Calculate SAE features from inputs
    """
    sae_in = self.process_sae_in(x)

    # "... d_in, d_in d_sae -> ... d_sae",
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
    feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    return feature_acts
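
A small sketch of the encode/decode round trip, assuming sae is an already-loaded, non-hook_z SAE instance and the activations match its cfg.d_in:

import torch

acts = torch.randn(4, 16, sae.cfg.d_in, device=sae.device, dtype=sae.dtype)
feature_acts = sae.encode(acts)    # (..., d_sae) sparse feature activations
recon = sae.decode(feature_acts)   # (..., d_in) reconstruction
recon_mse = (recon - acts).pow(2).sum(dim=-1).mean()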

from_pretrained(release, sae_id, device='cpu') classmethod

Load a pretrained SAE from the Hugging Face model hub.

Parameters:

    release (str, required): The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
    sae_id (str, required): The id of the SAE to load. This will be mapped to a path in the huggingface repo.
    device (str, default 'cpu'): The device to load the SAE on.
    return_sparsity_if_present: If True, will return the log sparsity tensor if it is present in the model directory in the Hugging Face model hub.
Source code in sae_lens/sae.py
@classmethod
def from_pretrained(
    cls,
    release: str,
    sae_id: str,
    device: str = "cpu",
) -> Tuple["SAE", dict[str, Any], Optional[torch.Tensor]]:
    """

    Load a pretrained SAE from the Hugging Face model hub.

    Args:
        release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
        sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
        device: The device to load the SAE on.
        return_sparsity_if_present: If True, will return the log sparsity tensor if it is present in the model directory in the Hugging Face model hub.
    """

    # get sae directory
    sae_directory = get_pretrained_saes_directory()

    # get the repo id and path to the SAE
    if release not in sae_directory:
        if "/" not in release:
            raise ValueError(
                f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
            )
    elif sae_id not in sae_directory[release].saes_map:
        # If using Gemma Scope and not the canonical release, give a hint to use it
        if (
            "gemma-scope" in release
            and "canonical" not in release
            and f"{release}-canonical" in sae_directory
        ):
            canonical_ids = list(
                sae_directory[release + "-canonical"].saes_map.keys()
            )
            # Shorten the lengthy string of valid IDs
            if len(canonical_ids) > 5:
                str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
            else:
                str_canonical_ids = str(canonical_ids)
            value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
        else:
            value_suffix = ""

        valid_ids = list(sae_directory[release].saes_map.keys())
        # Shorten the lengthy string of valid IDs
        if len(valid_ids) > 5:
            str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
        else:
            str_valid_ids = str(valid_ids)

        raise ValueError(
            f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
            + value_suffix
        )
    sae_info = sae_directory.get(release, None)
    config_overrides = sae_info.config_overrides if sae_info is not None else None

    conversion_loader_name = get_conversion_loader_name(sae_info)
    conversion_loader = NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]

    cfg_dict, state_dict, log_sparsities = conversion_loader(
        release,
        sae_id=sae_id,
        device=device,
        force_download=False,
        cfg_overrides=config_overrides,
    )

    sae = cls(SAEConfig.from_dict(cfg_dict))
    sae.process_state_dict_for_loading(state_dict)
    sae.load_state_dict(state_dict)

    # Check if normalization is 'expected_average_only_in'
    if cfg_dict.get("normalize_activations") == "expected_average_only_in":
        norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
        if norm_scaling_factor is not None:
            sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
            cfg_dict["normalize_activations"] = "none"
        else:
            warnings.warn(
                f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
            )

    return sae, cfg_dict, log_sparsities
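
A minimal loading sketch; the release and sae_id strings below are illustrative and should be replaced with entries from the pretrained SAEs directory (pretrained_saes.yaml):

from sae_lens import SAE

sae, cfg_dict, log_sparsities = SAE.from_pretrained(
    release="gpt2-small-res-jb",        # illustrative release name
    sae_id="blocks.8.hook_resid_pre",   # illustrative SAE id within that release
    device="cpu",
)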

SAETrainingRunner

Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.

Source code in sae_lens/sae_training_runner.py
class SAETrainingRunner:
    """
    Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
    """

    cfg: LanguageModelSAERunnerConfig
    model: HookedRootModule
    sae: TrainingSAE
    activations_store: ActivationsStore

    def __init__(
        self,
        cfg: LanguageModelSAERunnerConfig,
        override_dataset: HfDataset | None = None,
        override_model: HookedRootModule | None = None,
        override_sae: TrainingSAE | None = None,
    ):
        if override_dataset is not None:
            logging.warning(
                f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
            )
        if override_model is not None:
            logging.warning(
                f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
            )

        self.cfg = cfg

        if override_model is None:
            self.model = load_model(
                self.cfg.model_class_name,
                self.cfg.model_name,
                device=self.cfg.device,
                model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
            )
        else:
            self.model = override_model

        self.activations_store = ActivationsStore.from_config(
            self.model,
            self.cfg,
            override_dataset=override_dataset,
        )

        if override_sae is None:
            if self.cfg.from_pretrained_path is not None:
                self.sae = TrainingSAE.load_from_pretrained(
                    self.cfg.from_pretrained_path, self.cfg.device
                )
            else:
                self.sae = TrainingSAE(
                    TrainingSAEConfig.from_dict(
                        self.cfg.get_training_sae_cfg_dict(),
                    )
                )
                self._init_sae_group_b_decs()
        else:
            self.sae = override_sae

    def run(self):
        """
        Run the training of the SAE.
        """

        if self.cfg.log_to_wandb:
            wandb.init(
                project=self.cfg.wandb_project,
                entity=self.cfg.wandb_entity,
                config=cast(Any, self.cfg),
                name=self.cfg.run_name,
                id=self.cfg.wandb_id,
            )

        trainer = SAETrainer(
            model=self.model,
            sae=self.sae,
            activation_store=self.activations_store,
            save_checkpoint_fn=self.save_checkpoint,
            cfg=self.cfg,
        )

        self._compile_if_needed()
        sae = self.run_trainer_with_interruption_handling(trainer)

        if self.cfg.log_to_wandb:
            # remove this type ignore comment after https://github.com/wandb/wandb/issues/8248 is resolved
            wandb.finish()  # type: ignore

        return sae

    def _compile_if_needed(self):

        # Compile model and SAE
        #  torch.compile can provide significant speedups (10-20% in testing)
        # using max-autotune gives the best speedups but:
        # (a) increases VRAM usage,
        # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
        # (c) takes some time to compile
        # optimal settings seem to be:
        # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
        # (also pylance seems to really hate this)
        if self.cfg.compile_llm:
            self.model = torch.compile(
                self.model,
                mode=self.cfg.llm_compilation_mode,
            )  # type: ignore

        if self.cfg.compile_sae:
            if self.cfg.device == "mps":
                backend = "aot_eager"
            else:
                backend = "inductor"

            self.sae.training_forward_pass = torch.compile(  # type: ignore
                self.sae.training_forward_pass,
                mode=self.cfg.sae_compilation_mode,
                backend=backend,
            )  # type: ignore

    def run_trainer_with_interruption_handling(self, trainer: SAETrainer):
        try:
            # signal handlers (if preempted)
            signal.signal(signal.SIGINT, interrupt_callback)
            signal.signal(signal.SIGTERM, interrupt_callback)

            # train SAE
            sae = trainer.fit()

        except (KeyboardInterrupt, InterruptedException):
            print("interrupted, saving progress")
            checkpoint_name = trainer.n_training_tokens
            self.save_checkpoint(trainer, checkpoint_name=checkpoint_name)
            print("done saving")
            raise

        return sae

    # TODO: move this into the SAE trainer or Training SAE class
    def _init_sae_group_b_decs(
        self,
    ) -> None:
        """
        extract all activations at a certain layer and use for sae b_dec initialization
        """

        if self.cfg.b_dec_init_method == "geometric_median":
            layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
            # get geometric median of the activations if we're using those.
            median = compute_geometric_median(
                layer_acts,
                maxiter=100,
            ).median
            self.sae.initialize_b_dec_with_precalculated(median)  # type: ignore
        elif self.cfg.b_dec_init_method == "mean":
            layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
            self.sae.initialize_b_dec_with_mean(layer_acts)  # type: ignore

    def save_checkpoint(
        self,
        trainer: SAETrainer,
        checkpoint_name: int | str,
        wandb_aliases: list[str] | None = None,
    ) -> str:

        checkpoint_path = f"{trainer.cfg.checkpoint_path}/{checkpoint_name}"

        os.makedirs(checkpoint_path, exist_ok=True)

        path = f"{checkpoint_path}"
        os.makedirs(path, exist_ok=True)

        if self.sae.cfg.normalize_sae_decoder:
            self.sae.set_decoder_norm_to_unit_norm()
        self.sae.save_model(path)

        # let's overwrite the cfg file with the trainer cfg, which is a superset of the original cfg.
        # this should not cause issues and gives us more info about SAEs we trained in SAE Lens.
        config = trainer.cfg.to_dict()
        with open(f"{path}/cfg.json", "w") as f:
            json.dump(config, f)

        log_feature_sparsities = {"sparsity": trainer.log_feature_sparsity}

        log_feature_sparsity_path = f"{path}/{SPARSITY_PATH}"
        save_file(log_feature_sparsities, log_feature_sparsity_path)

        if trainer.cfg.log_to_wandb and os.path.exists(log_feature_sparsity_path):
            # Avoid wandb saving errors such as:
            #   ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
            sae_name = self.sae.get_name().replace("/", "__")

            model_artifact = wandb.Artifact(
                sae_name,
                type="model",
                metadata=dict(trainer.cfg.__dict__),
            )

            model_artifact.add_file(f"{path}/{SAE_WEIGHTS_PATH}")
            model_artifact.add_file(f"{path}/{SAE_CFG_PATH}")

            # remove this type ignore comment after https://github.com/wandb/wandb/issues/8248 is resolved
            wandb.log_artifact(model_artifact, aliases=wandb_aliases)  # type: ignore

            sparsity_artifact = wandb.Artifact(
                f"{sae_name}_log_feature_sparsity",
                type="log_feature_sparsity",
                metadata=dict(trainer.cfg.__dict__),
            )
            sparsity_artifact.add_file(log_feature_sparsity_path)
            # remove this type ignore comment after https://github.com/wandb/wandb/issues/8248 is resolved
            wandb.log_artifact(sparsity_artifact)  # type: ignore

        return checkpoint_path

run()

Run the training of the SAE.

Source code in sae_lens/sae_training_runner.py
def run(self):
    """
    Run the training of the SAE.
    """

    if self.cfg.log_to_wandb:
        wandb.init(
            project=self.cfg.wandb_project,
            entity=self.cfg.wandb_entity,
            config=cast(Any, self.cfg),
            name=self.cfg.run_name,
            id=self.cfg.wandb_id,
        )

    trainer = SAETrainer(
        model=self.model,
        sae=self.sae,
        activation_store=self.activations_store,
        save_checkpoint_fn=self.save_checkpoint,
        cfg=self.cfg,
    )

    self._compile_if_needed()
    sae = self.run_trainer_with_interruption_handling(trainer)

    if self.cfg.log_to_wandb:
        # remove this type ignore comment after https://github.com/wandb/wandb/issues/8248 is resolved
        wandb.finish()  # type: ignore

    return sae
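
A minimal end-to-end training sketch; only a few illustrative config fields are shown and the values are placeholders rather than recommended settings:

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    d_in=768,                         # residual stream width of gpt2-small
    dataset_path="NeelNanda/c4-10k",  # illustrative dataset
    training_tokens=1_000_000,
    log_to_wandb=False,
)
sae = SAETrainingRunner(cfg).run()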

TrainingSAE

Bases: SAE

A SAE used for training. This class provides a training_forward_pass method which calculates losses used for training.
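
A minimal sketch of a single training step, assuming sae is a TrainingSAE and batch is a (batch, d_in) activation tensor on the matching device and dtype:

out = sae.training_forward_pass(
    sae_in=batch,
    current_l1_coefficient=sae.cfg.l1_coefficient,
    dead_neuron_mask=None,
)
out.loss.backward()             # total loss used for the optimizer step
mse = out.losses["mse_loss"]    # individual loss terms are also returned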

Source code in sae_lens/training/training_sae.py
class TrainingSAE(SAE):
    """
    A SAE used for training. This class provides a `training_forward_pass` method which calculates
    losses used for training.
    """

    cfg: TrainingSAEConfig
    use_error_term: bool
    dtype: torch.dtype
    device: torch.device

    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
        base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict())
        super().__init__(base_sae_cfg)
        self.cfg = cfg  # type: ignore

        if cfg.architecture == "standard":
            self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre
        elif cfg.architecture == "gated":
            self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_gated
        elif cfg.architecture == "jumprelu":
            self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_jumprelu
            self.bandwidth = cfg.jumprelu_bandwidth
            self.log_threshold.data = torch.ones(
                self.cfg.d_sae, dtype=self.dtype, device=self.device
            ) * np.log(cfg.jumprelu_init_threshold)

        else:
            raise ValueError(f"Unknown architecture: {cfg.architecture}")

        self.check_cfg_compatibility()

        self.use_error_term = use_error_term

        self.initialize_weights_complex()

        # The training SAE will assume that the activation store handles
        # reshaping.
        self.turn_off_forward_pass_hook_z_reshaping()

        self.mse_loss_fn = self._get_mse_loss_fn()

    def initialize_weights_jumprelu(self):
        # same as the superclass, except we use a log_threshold parameter instead of threshold
        self.log_threshold = nn.Parameter(
            torch.empty(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )
        self.initialize_weights_basic()

    @property
    def threshold(self) -> torch.Tensor:
        if self.cfg.architecture != "jumprelu":
            raise ValueError("Threshold is only defined for Jumprelu SAEs")
        return torch.exp(self.log_threshold)

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE":
        return cls(TrainingSAEConfig.from_dict(config_dict))

    def check_cfg_compatibility(self):
        if self.cfg.architecture == "gated":
            assert (
                self.cfg.use_ghost_grads is False
            ), "Gated SAEs do not support ghost grads"
            assert self.use_error_term is False, "Gated SAEs do not support error terms"

    def encode_standard(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Calculate SAE features from inputs
        """
        feature_acts, _ = self.encode_with_hidden_pre_fn(x)
        return feature_acts

    def encode_with_hidden_pre_jumprelu(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        sae_in = self.process_sae_in(x)

        hidden_pre = sae_in @ self.W_enc + self.b_enc

        if self.training:
            hidden_pre = (
                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
            )

        threshold = torch.exp(self.log_threshold)

        feature_acts = JumpReLU.apply(hidden_pre, threshold, self.bandwidth)

        return feature_acts, hidden_pre  # type: ignore

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        sae_in = self.process_sae_in(x)

        # "... d_in, d_in d_sae -> ... d_sae",
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        hidden_pre_noised = hidden_pre + (
            torch.randn_like(hidden_pre) * self.cfg.noise_scale * self.training
        )
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))

        return feature_acts, hidden_pre_noised

    def encode_with_hidden_pre_gated(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        sae_in = self.process_sae_in(x)

        # apply b_dec_to_input if using that method.
        sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input)

        # Gating path with Heaviside step function
        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
        active_features = (gating_pre_activation > 0).to(self.dtype)

        # Magnitude path with weight sharing
        magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
        # magnitude_pre_activation_noised = magnitude_pre_activation + (
        #     torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale * self.training
        # )
        feature_magnitudes = self.activation_fn(
            magnitude_pre_activation
        )  # magnitude_pre_activation_noised)

        # Return both the gated feature activations and the magnitude pre-activations
        return (
            active_features * feature_magnitudes,
            magnitude_pre_activation,
        )  # magnitude_pre_activation_noised

    def forward(
        self,
        x: Float[torch.Tensor, "... d_in"],
    ) -> Float[torch.Tensor, "... d_in"]:

        feature_acts, _ = self.encode_with_hidden_pre_fn(x)
        sae_out = self.decode(feature_acts)

        return sae_out

    def training_forward_pass(
        self,
        sae_in: torch.Tensor,
        current_l1_coefficient: float,
        dead_neuron_mask: Optional[torch.Tensor] = None,
    ) -> TrainStepOutput:

        # do a forward pass to get SAE out, but we also need the
        # hidden pre.
        feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
        sae_out = self.decode(feature_acts)

        # MSE LOSS
        per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in)
        mse_loss = per_item_mse_loss.sum(dim=-1).mean()

        losses: dict[str, float | torch.Tensor] = {}

        if self.cfg.architecture == "gated":
            # Gated SAE Loss Calculation

            # Shared variables
            sae_in_centered = (
                self.reshape_fn_in(sae_in) - self.b_dec * self.cfg.apply_b_dec_to_input
            )
            pi_gate = sae_in_centered @ self.W_enc + self.b_gate
            pi_gate_act = torch.relu(pi_gate)

            # SFN sparsity loss - summed over the feature dimension and averaged over the batch
            l1_loss = (
                current_l1_coefficient
                * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
            )

            # Auxiliary reconstruction loss - summed over the feature dimension and averaged over the batch
            via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
            aux_reconstruction_loss = torch.sum(
                (via_gate_reconstruction - sae_in) ** 2, dim=-1
            ).mean()
            loss = mse_loss + l1_loss + aux_reconstruction_loss
            losses["auxiliary_reconstruction_loss"] = aux_reconstruction_loss
            losses["l1_loss"] = l1_loss
        elif self.cfg.architecture == "jumprelu":
            threshold = torch.exp(self.log_threshold)
            l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1)  # type: ignore
            l0_loss = (current_l1_coefficient * l0).mean()
            loss = mse_loss + l0_loss
            losses["l0_loss"] = l0_loss
        else:
            # default SAE sparsity loss
            weighted_feature_acts = feature_acts
            if self.cfg.scale_sparsity_penalty_by_decoder_norm:
                weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
            sparsity = weighted_feature_acts.norm(
                p=self.cfg.lp_norm, dim=-1
            )  # sum over the feature dimension

            l1_loss = (current_l1_coefficient * sparsity).mean()
            loss = mse_loss + l1_loss
            if (
                self.cfg.use_ghost_grads
                and self.training
                and dead_neuron_mask is not None
            ):
                ghost_grad_loss = self.calculate_ghost_grad_loss(
                    x=sae_in,
                    sae_out=sae_out,
                    per_item_mse_loss=per_item_mse_loss,
                    hidden_pre=hidden_pre,
                    dead_neuron_mask=dead_neuron_mask,
                )
                losses["ghost_grad_loss"] = ghost_grad_loss
                loss = loss + ghost_grad_loss
            losses["l1_loss"] = l1_loss

        losses["mse_loss"] = mse_loss

        return TrainStepOutput(
            sae_in=sae_in,
            sae_out=sae_out,
            feature_acts=feature_acts,
            loss=loss,
            losses=losses,
        )

    def calculate_ghost_grad_loss(
        self,
        x: torch.Tensor,
        sae_out: torch.Tensor,
        per_item_mse_loss: torch.Tensor,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor,
    ) -> torch.Tensor:

        # 1.
        residual = x - sae_out
        l2_norm_residual = torch.norm(residual, dim=-1)

        # 2.
        # ghost grads use an exponential activation function, ignoring whatever
        # the activation function is in the SAE. The forward pass uses the dead neurons only.
        feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_neuron_mask])
        ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask, :]
        l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
        norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
        ghost_out = ghost_out * norm_scaling_factor[:, None].detach()

        # 3. There is some fairly complex rescaling here to make sure that the loss
        # is comparable to the original loss. This is because the ghost grads are
        # only calculated for the dead neurons, so we need to rescale the loss to
        # make sure that the loss is comparable to the original loss.
        # There have been methodological improvements that are not implemented here yet
        # see here: https://www.lesswrong.com/posts/C5KAZQib3bzzpeyrg/full-post-progress-update-1-from-the-gdm-mech-interp-team#Improving_ghost_grads
        per_item_mse_loss_ghost_resid = self.mse_loss_fn(ghost_out, residual.detach())
        mse_rescaling_factor = (
            per_item_mse_loss / (per_item_mse_loss_ghost_resid + 1e-6)
        ).detach()
        per_item_mse_loss_ghost_resid = (
            mse_rescaling_factor * per_item_mse_loss_ghost_resid
        )

        return per_item_mse_loss_ghost_resid.mean()

    @torch.no_grad()
    def _get_mse_loss_fn(self) -> Any:

        def standard_mse_loss_fn(
            preds: torch.Tensor, target: torch.Tensor
        ) -> torch.Tensor:
            return torch.nn.functional.mse_loss(preds, target, reduction="none")

        def batch_norm_mse_loss_fn(
            preds: torch.Tensor, target: torch.Tensor
        ) -> torch.Tensor:
            target_centered = target - target.mean(dim=0, keepdim=True)
            normalization = target_centered.norm(dim=-1, keepdim=True)
            return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
                normalization + 1e-6
            )

        if self.cfg.mse_loss_normalization == "dense_batch":
            return batch_norm_mse_loss_fn
        else:
            return standard_mse_loss_fn

    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
        if self.cfg.architecture == "jumprelu" and "log_threshold" in state_dict:
            threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
            del state_dict["log_threshold"]
            state_dict["threshold"] = threshold

    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
        if self.cfg.architecture == "jumprelu" and "threshold" in state_dict:
            threshold = state_dict["threshold"]
            del state_dict["threshold"]
            state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()

    @classmethod
    def load_from_pretrained(
        cls,
        path: str,
        device: str = "cpu",
        dtype: str | None = None,
    ) -> "TrainingSAE":

        # get the config
        config_path = os.path.join(path, SAE_CFG_PATH)
        with open(config_path, "r") as f:
            cfg_dict = json.load(f)
        cfg_dict = handle_config_defaulting(cfg_dict)
        cfg_dict["device"] = device
        if dtype is not None:
            cfg_dict["dtype"] = dtype

        weight_path = os.path.join(path, SAE_WEIGHTS_PATH)
        cfg_dict, state_dict = read_sae_from_disk(
            cfg_dict=cfg_dict,
            weight_path=weight_path,
            device=device,
        )
        sae_cfg = TrainingSAEConfig.from_dict(cfg_dict)

        sae = cls(sae_cfg)
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict)

        return sae

    def initialize_weights_complex(self):
        """ """

        if self.cfg.decoder_orthogonal_init:
            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T

        elif self.cfg.decoder_heuristic_init:
            self.W_dec = nn.Parameter(
                torch.rand(
                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
                )
            )
            self.initialize_decoder_norm_constant_norm()

        # Then we initialize the encoder weights (either as the transpose of decoder or not)
        if self.cfg.init_encoder_as_decoder_transpose:
            self.W_enc.data = self.W_dec.data.T.clone().contiguous()
        else:
            self.W_enc = nn.Parameter(
                torch.nn.init.kaiming_uniform_(
                    torch.empty(
                        self.cfg.d_in,
                        self.cfg.d_sae,
                        dtype=self.dtype,
                        device=self.device,
                    )
                )
            )

        if self.cfg.normalize_sae_decoder:
            with torch.no_grad():
                # Anthropic normalize this to have unit columns
                self.set_decoder_norm_to_unit_norm()

    ## Initialization Methods
    @torch.no_grad()
    def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
        out = torch.tensor(origin, dtype=self.dtype, device=self.device)
        self.b_dec.data = out

    @torch.no_grad()
    def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
        previous_b_dec = self.b_dec.clone().cpu()
        out = all_activations.mean(dim=0)

        previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
        distances = torch.norm(all_activations - out, dim=-1)

        print("Reinitializing b_dec with mean of activations")
        print(
            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
        )
        print(f"New distances: {distances.median(0).values.mean().item()}")

        self.b_dec.data = out.to(self.dtype).to(self.device)

    ## Training Utils
    @torch.no_grad()
    def set_decoder_norm_to_unit_norm(self):
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)

    @torch.no_grad()
    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
        """
        A heuristic procedure inspired by:
        https://transformer-circuits.pub/2024/april-update/index.html#training-saes
        """
        # TODO: Parameterise this as a function of m and n

        # ensure W_dec norms at unit norm
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
        self.W_dec.data *= norm  # will break tests but do this for now.

    @torch.no_grad()
    def remove_gradient_parallel_to_decoder_directions(self):
        """
        Update grads so that they remove the parallel component
            (d_sae, d_in) shape
        """
        assert self.W_dec.grad is not None  # keep pyright happy

        parallel_component = einops.einsum(
            self.W_dec.grad,
            self.W_dec.data,
            "d_sae d_in, d_sae d_in -> d_sae",
        )
        self.W_dec.grad -= einops.einsum(
            parallel_component,
            self.W_dec.data,
            "d_sae, d_sae d_in -> d_sae d_in",
        )
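
The class method load_from_pretrained above rebuilds a TrainingSAE from a saved config and weights directory. A minimal usage sketch, assuming a checkpoint directory written by a previous training run (the path below is a placeholder; the import path follows the source location shown in this page):

import torch

from sae_lens.training.training_sae import TrainingSAE

# Hypothetical directory containing the saved config and weights files.
checkpoint_dir = "checkpoints/my_sae_run/final"

# Reads the config, applies config defaulting, converts any saved "threshold"
# back to "log_threshold" for jumprelu SAEs, and loads the state dict.
sae = TrainingSAE.load_from_pretrained(checkpoint_dir, device="cpu")
print(sae.cfg.d_in, sae.cfg.d_sae)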

encode_standard(x)

Calculate SAE features from inputs

Source code in sae_lens/training/training_sae.py
def encode_standard(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """
    Calculate SAE features from inputs
    """
    feature_acts, _ = self.encode_with_hidden_pre_fn(x)
    return feature_acts
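
A brief usage sketch, continuing from a TrainingSAE instance as in the load_from_pretrained example above; the random batch is only a stand-in for activations collected at the hook point:

import torch

acts = torch.randn(32, sae.cfg.d_in)  # stand-in for hook-point activations

# encode_standard wraps encode_with_hidden_pre_fn and discards the
# pre-activation values, returning only the feature activations.
feature_acts = sae.encode_standard(acts)
print(feature_acts.shape)  # torch.Size([32, d_sae])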

initialize_decoder_norm_constant_norm(norm=0.1)

A heuristic procedure inspired by: https://transformer-circuits.pub/2024/april-update/index.html#training-saes

Source code in sae_lens/training/training_sae.py
@torch.no_grad()
def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
    """
    A heuristic procedure inspired by:
    https://transformer-circuits.pub/2024/april-update/index.html#training-saes
    """
    # TODO: Parameterise this as a function of m and n

    # ensure W_dec norms at unit norm
    self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
    self.W_dec.data *= norm  # will break tests but do this for now.
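
After the call, every row of W_dec has L2 norm equal to the requested constant (0.1 by default). A quick check, assuming an already-constructed sae as in the examples above:

sae.initialize_decoder_norm_constant_norm(norm=0.1)

row_norms = sae.W_dec.data.norm(dim=1)
print(row_norms.min().item(), row_norms.max().item())  # both ~0.1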

initialize_weights_complex()

Source code in sae_lens/training/training_sae.py
def initialize_weights_complex(self):
    """ """

    if self.cfg.decoder_orthogonal_init:
        self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T

    elif self.cfg.decoder_heuristic_init:
        self.W_dec = nn.Parameter(
            torch.rand(
                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
            )
        )
        self.initialize_decoder_norm_constant_norm()

    # Then we initialize the encoder weights (either as the transpose of decoder or not)
    if self.cfg.init_encoder_as_decoder_transpose:
        self.W_enc.data = self.W_dec.data.T.clone().contiguous()
    else:
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_in,
                    self.cfg.d_sae,
                    dtype=self.dtype,
                    device=self.device,
                )
            )
        )

    if self.cfg.normalize_sae_decoder:
        with torch.no_grad():
            # Anthropic normalize this to have unit columns
            self.set_decoder_norm_to_unit_norm()
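
The heuristic branch above amounts to: a random decoder, rows rescaled to a small constant norm, and the encoder tied to the decoder transpose. A standalone re-derivation in plain torch (the sizes are hypothetical and this is not the library's API):

import torch

d_in, d_sae = 512, 4096
W_dec = torch.rand(d_sae, d_in)
W_dec /= W_dec.norm(dim=1, keepdim=True)  # unit-norm rows
W_dec *= 0.1                              # constant-norm heuristic
W_enc = W_dec.T.clone().contiguous()      # init_encoder_as_decoder_transpose

assert torch.allclose(W_dec.norm(dim=1), torch.full((d_sae,), 0.1), atol=1e-6)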

remove_gradient_parallel_to_decoder_directions()

Update W_dec grads (shape (d_sae, d_in)) so that the component parallel to each decoder direction is removed

Source code in sae_lens/training/training_sae.py
@torch.no_grad()
def remove_gradient_parallel_to_decoder_directions(self):
    """
    Update grads so that they remove the parallel component
        (d_sae, d_in) shape
    """
    assert self.W_dec.grad is not None  # keep pyright happy

    parallel_component = einops.einsum(
        self.W_dec.grad,
        self.W_dec.data,
        "d_sae d_in, d_sae d_in -> d_sae",
    )
    self.W_dec.grad -= einops.einsum(
        parallel_component,
        self.W_dec.data,
        "d_sae, d_sae d_in -> d_sae d_in",
    )
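
A small standalone check of the projection math, using plain tensors rather than the class (dimensions are hypothetical): after the update, each gradient row is orthogonal to its unit-norm decoder row, so to first order a gradient step leaves decoder row norms unchanged.

import einops
import torch

d_sae, d_in = 8, 4
W_dec = torch.nn.functional.normalize(torch.randn(d_sae, d_in), dim=1)
grad = torch.randn(d_sae, d_in)

# Same einsum pattern as remove_gradient_parallel_to_decoder_directions.
parallel = einops.einsum(grad, W_dec, "d_sae d_in, d_sae d_in -> d_sae")
grad = grad - einops.einsum(parallel, W_dec, "d_sae, d_sae d_in -> d_sae d_in")

print((grad * W_dec).sum(dim=1))  # ~0 for every row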