API

ActivationsStore

Class for streaming tokens and generating and storing activations while training SAEs.
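
A minimal usage sketch: build a store from a pretrained SAE and pull a batch of activations. The model, SAE release, and dataset names below are illustrative, and the exact return type of SAE.from_pretrained can vary between sae_lens versions.

from transformer_lens import HookedTransformer
from sae_lens import SAE
from sae_lens.training.activations_store import ActivationsStore

model = HookedTransformer.from_pretrained("gpt2")
# Illustrative pretrained SAE; depending on your sae_lens version,
# from_pretrained may return the SAE directly or a tuple.
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    dataset="NeelNanda/pile-10k",  # any HF dataset with a "tokens", "input_ids", or "text" column
    context_size=128,
    device="cpu",
)
acts = store.next_batch()  # a batch of activations ready to train an SAE on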

Source code in sae_lens/training/activations_store.py
class ActivationsStore:
    """
    Class for streaming tokens and generating and storing activations
    while training SAEs.
    """

    model: HookedRootModule
    dataset: HfDataset
    cached_activations_path: str | None
    cached_activation_dataset: Dataset | None = None
    tokens_column: Literal["tokens", "input_ids", "text", "problem"]
    hook_name: str
    hook_layer: int
    hook_head_index: int | None
    _dataloader: Iterator[Any] | None = None
    _storage_buffer: torch.Tensor | None = None
    exclude_special_tokens: torch.Tensor | None = None
    device: torch.device

    @classmethod
    def from_cache_activations(
        cls,
        model: HookedRootModule,
        cfg: CacheActivationsRunnerConfig,
    ) -> ActivationsStore:
        """
        Public api to create an ActivationsStore from a cached activations dataset.
        """
        return cls(
            cached_activations_path=cfg.new_cached_activations_path,
            dtype=cfg.dtype,
            hook_name=cfg.hook_name,
            hook_layer=cfg.hook_layer,
            context_size=cfg.context_size,
            d_in=cfg.d_in,
            n_batches_in_buffer=cfg.n_batches_in_buffer,
            total_training_tokens=cfg.training_tokens,
            store_batch_size_prompts=cfg.model_batch_size,  # get_buffer
            train_batch_size_tokens=cfg.model_batch_size,  # dataloader
            seqpos_slice=(None,),
            device=torch.device(cfg.device),  # since we're sending these to SAE
            # NOOP
            prepend_bos=False,
            hook_head_index=None,
            dataset=cfg.dataset_path,
            streaming=False,
            model=model,
            normalize_activations="none",
            model_kwargs=None,
            autocast_lm=False,
            dataset_trust_remote_code=None,
            exclude_special_tokens=None,
        )

    @classmethod
    def from_config(
        cls,
        model: HookedRootModule,
        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]
        | CacheActivationsRunnerConfig,
        override_dataset: HfDataset | None = None,
    ) -> ActivationsStore:
        if isinstance(cfg, CacheActivationsRunnerConfig):
            return cls.from_cache_activations(model, cfg)

        cached_activations_path = cfg.cached_activations_path
        # set cached_activations_path to None if we're not using cached activations
        if (
            isinstance(cfg, LanguageModelSAERunnerConfig)
            and not cfg.use_cached_activations
        ):
            cached_activations_path = None

        if override_dataset is None and cfg.dataset_path == "":
            raise ValueError(
                "You must either pass in a dataset or specify a dataset_path in your configutation."
            )

        device = torch.device(cfg.act_store_device)
        exclude_special_tokens = cfg.exclude_special_tokens
        if exclude_special_tokens is False:
            exclude_special_tokens = None
        if exclude_special_tokens is True:
            exclude_special_tokens = _get_special_token_ids(model.tokenizer)  # type: ignore
        if exclude_special_tokens is not None:
            exclude_special_tokens = torch.tensor(
                exclude_special_tokens, dtype=torch.long, device=device
            )
        return cls(
            model=model,
            dataset=override_dataset or cfg.dataset_path,
            streaming=cfg.streaming,
            hook_name=cfg.hook_name,
            hook_layer=cfg.hook_layer,
            hook_head_index=cfg.hook_head_index,
            context_size=cfg.context_size,
            d_in=cfg.d_in
            if isinstance(cfg, CacheActivationsRunnerConfig)
            else cfg.sae.d_in,
            n_batches_in_buffer=cfg.n_batches_in_buffer,
            total_training_tokens=cfg.training_tokens,
            store_batch_size_prompts=cfg.store_batch_size_prompts,
            train_batch_size_tokens=cfg.train_batch_size_tokens,
            prepend_bos=cfg.prepend_bos,
            normalize_activations=cfg.sae.normalize_activations,
            device=device,
            dtype=cfg.dtype,
            cached_activations_path=cached_activations_path,
            model_kwargs=cfg.model_kwargs,
            autocast_lm=cfg.autocast_lm,
            dataset_trust_remote_code=cfg.dataset_trust_remote_code,
            seqpos_slice=cfg.seqpos_slice,
            exclude_special_tokens=exclude_special_tokens,
        )

    @classmethod
    def from_sae(
        cls,
        model: HookedRootModule,
        sae: SAE[T_SAE_CONFIG],
        dataset: HfDataset | str,
        dataset_trust_remote_code: bool = False,
        context_size: int | None = None,
        streaming: bool = True,
        store_batch_size_prompts: int = 8,
        n_batches_in_buffer: int = 8,
        train_batch_size_tokens: int = 4096,
        total_tokens: int = 10**9,
        device: str = "cpu",
    ) -> ActivationsStore:
        if sae.cfg.metadata.hook_name is None:
            raise ValueError("hook_name is required")
        if sae.cfg.metadata.hook_layer is None:
            raise ValueError("hook_layer is required")
        if sae.cfg.metadata.hook_head_index is None:
            raise ValueError("hook_head_index is required")
        if sae.cfg.metadata.context_size is None:
            raise ValueError("context_size is required")
        if sae.cfg.metadata.prepend_bos is None:
            raise ValueError("prepend_bos is required")
        return cls(
            model=model,
            dataset=dataset,
            d_in=sae.cfg.d_in,
            hook_name=sae.cfg.metadata.hook_name,
            hook_layer=sae.cfg.metadata.hook_layer,
            hook_head_index=sae.cfg.metadata.hook_head_index,
            context_size=sae.cfg.metadata.context_size
            if context_size is None
            else context_size,
            prepend_bos=sae.cfg.metadata.prepend_bos,
            streaming=streaming,
            store_batch_size_prompts=store_batch_size_prompts,
            train_batch_size_tokens=train_batch_size_tokens,
            n_batches_in_buffer=n_batches_in_buffer,
            total_training_tokens=total_tokens,
            normalize_activations=sae.cfg.normalize_activations,
            dataset_trust_remote_code=dataset_trust_remote_code,
            dtype=sae.cfg.dtype,
            device=torch.device(device),
            seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
        )

    def __init__(
        self,
        model: HookedRootModule,
        dataset: HfDataset | str,
        streaming: bool,
        hook_name: str,
        hook_layer: int,
        hook_head_index: int | None,
        context_size: int,
        d_in: int,
        n_batches_in_buffer: int,
        total_training_tokens: int,
        store_batch_size_prompts: int,
        train_batch_size_tokens: int,
        prepend_bos: bool,
        normalize_activations: str,
        device: torch.device,
        dtype: str,
        cached_activations_path: str | None = None,
        model_kwargs: dict[str, Any] | None = None,
        autocast_lm: bool = False,
        dataset_trust_remote_code: bool | None = None,
        seqpos_slice: tuple[int | None, ...] = (None,),
        exclude_special_tokens: torch.Tensor | None = None,
    ):
        self.model = model
        if model_kwargs is None:
            model_kwargs = {}
        self.model_kwargs = model_kwargs
        self.dataset = (
            load_dataset(
                dataset,
                split="train",
                streaming=streaming,
                trust_remote_code=dataset_trust_remote_code,  # type: ignore
            )
            if isinstance(dataset, str)
            else dataset
        )

        if isinstance(dataset, (Dataset, DatasetDict)):
            self.dataset = cast(Dataset | DatasetDict, self.dataset)
            n_samples = len(self.dataset)

            if n_samples < total_training_tokens:
                warnings.warn(
                    f"The training dataset contains fewer samples ({n_samples}) than the number of samples required by your training configuration ({total_training_tokens}). This will result in multiple training epochs and some samples being used more than once."
                )

        self.hook_name = hook_name
        self.hook_layer = hook_layer
        self.hook_head_index = hook_head_index
        self.context_size = context_size
        self.d_in = d_in
        self.n_batches_in_buffer = n_batches_in_buffer
        self.half_buffer_size = n_batches_in_buffer // 2
        self.total_training_tokens = total_training_tokens
        self.store_batch_size_prompts = store_batch_size_prompts
        self.train_batch_size_tokens = train_batch_size_tokens
        self.prepend_bos = prepend_bos
        self.normalize_activations = normalize_activations
        self.device = torch.device(device)
        self.dtype = DTYPE_MAP[dtype]
        self.cached_activations_path = cached_activations_path
        self.autocast_lm = autocast_lm
        self.seqpos_slice = seqpos_slice
        self.exclude_special_tokens = exclude_special_tokens

        self.n_dataset_processed = 0

        self.estimated_norm_scaling_factor = None

        # Check if dataset is tokenized
        dataset_sample = next(iter(self.dataset))

        # check if it's tokenized
        if "tokens" in dataset_sample:
            self.is_dataset_tokenized = True
            self.tokens_column = "tokens"
        elif "input_ids" in dataset_sample:
            self.is_dataset_tokenized = True
            self.tokens_column = "input_ids"
        elif "text" in dataset_sample:
            self.is_dataset_tokenized = False
            self.tokens_column = "text"
        elif "problem" in dataset_sample:
            self.is_dataset_tokenized = False
            self.tokens_column = "problem"
        else:
            raise ValueError(
                "Dataset must have a 'tokens', 'input_ids', 'text', or 'problem' column."
            )
        if self.is_dataset_tokenized:
            ds_context_size = len(dataset_sample[self.tokens_column])
            if ds_context_size < self.context_size:
                raise ValueError(
                    f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}.
                    The dataset's context_size must be greater than or equal to the provided context_size."""
                )
            if self.context_size != ds_context_size:
                warnings.warn(
                    f"""pretokenized dataset has context_size {ds_context_size}, but the provided context_size is {self.context_size}. Some data will be discarded in this case.""",
                    RuntimeWarning,
                )
            # TODO: investigate if this can work for iterable datasets, or if this is even worthwhile as a perf improvement
            if hasattr(self.dataset, "set_format"):
                self.dataset.set_format(type="torch", columns=[self.tokens_column])  # type: ignore

            if (
                isinstance(dataset, str)
                and hasattr(model, "tokenizer")
                and model.tokenizer is not None
            ):
                validate_pretokenized_dataset_tokenizer(
                    dataset_path=dataset,
                    model_tokenizer=model.tokenizer,  # type: ignore
                )
        else:
            warnings.warn(
                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
            )

        self.iterable_sequences = self._iterate_tokenized_sequences()

        self.cached_activation_dataset = self.load_cached_activation_dataset()

        # TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF)

    def _iterate_raw_dataset(
        self,
    ) -> Generator[torch.Tensor | list[int] | str, None, None]:
        """
        Helper to iterate over the dataset while incrementing n_dataset_processed
        """
        for row in self.dataset:
            # typing datasets is difficult
            yield row[self.tokens_column]  # type: ignore
            self.n_dataset_processed += 1

    def _iterate_raw_dataset_tokens(self) -> Generator[torch.Tensor, None, None]:
        """
        Helper to create an iterator which tokenizes raw text from the dataset on the fly
        """
        for row in self._iterate_raw_dataset():
            tokens = (
                self.model.to_tokens(
                    row,
                    truncate=False,
                    move_to_device=False,  # we move to device below
                    prepend_bos=False,
                )  # type: ignore
                .squeeze(0)
                .to(self.device)
            )
            if len(tokens.shape) != 1:
                raise ValueError(f"tokens.shape should be 1D but was {tokens.shape}")
            yield tokens

    def _iterate_tokenized_sequences(self) -> Generator[torch.Tensor, None, None]:
        """
        Generator which iterates over full sequence of context_size tokens
        """
        # If the dataset is pretokenized, we will slice the dataset to the length of the context window if needed. Otherwise, no further processing is needed.
        # We assume that all necessary BOS/EOS/SEP tokens have been added during pretokenization.
        if self.is_dataset_tokenized:
            for row in self._iterate_raw_dataset():
                yield torch.tensor(
                    row[
                        : self.context_size
                    ],  # If self.context_size = None, this line simply returns the whole row
                    dtype=torch.long,
                    device=self.device,
                    requires_grad=False,
                )
        # If the dataset isn't tokenized, we'll tokenize, concat, and batch on the fly
        else:
            tokenizer = getattr(self.model, "tokenizer", None)
            bos_token_id = None if tokenizer is None else tokenizer.bos_token_id
            yield from concat_and_batch_sequences(
                tokens_iterator=self._iterate_raw_dataset_tokens(),
                context_size=self.context_size,
                begin_batch_token_id=(bos_token_id if self.prepend_bos else None),
                begin_sequence_token_id=None,
                sequence_separator_token_id=(
                    bos_token_id if self.prepend_bos else None
                ),
            )

    def load_cached_activation_dataset(self) -> Dataset | None:
        """
        Load the cached activation dataset from disk.

        - If cached_activations_path is set, returns Huggingface Dataset else None
        - Checks that the loaded dataset has activations for the hooks in the config and that the shapes match.
        """
        if self.cached_activations_path is None:
            return None

        assert self.cached_activations_path is not None  # keep pyright happy
        # Sanity check: does the cache directory exist?
        if not os.path.exists(self.cached_activations_path):
            raise FileNotFoundError(
                f"Cache directory {self.cached_activations_path} does not exist. "
                "Consider double-checking your dataset, model, and hook names."
            )

        # ---
        # Actual code
        activations_dataset = datasets.load_from_disk(self.cached_activations_path)
        columns = [self.hook_name]
        if "token_ids" in activations_dataset.column_names:
            columns.append("token_ids")
        activations_dataset.set_format(
            type="torch", columns=columns, device=self.device, dtype=self.dtype
        )
        self.current_row_idx = 0  # idx to load next batch from
        # ---

        assert isinstance(activations_dataset, Dataset)

        # may support multiple hook names in future
        if not set([self.hook_name]).issubset(activations_dataset.column_names):
            raise ValueError(
                f"loaded dataset does not include hook activations, got {activations_dataset.column_names}"
            )

        if activations_dataset.features[self.hook_name].shape != (
            self.context_size,
            self.d_in,
        ):
            raise ValueError(
                f"Given dataset of shape {activations_dataset.features[self.hook_name].shape} does not match context_size ({self.context_size}) and d_in ({self.d_in})"
            )

        return activations_dataset

    def set_norm_scaling_factor_if_needed(self):
        if (
            self.normalize_activations == "expected_average_only_in"
            and self.estimated_norm_scaling_factor is None
        ):
            self.estimated_norm_scaling_factor = self.estimate_norm_scaling_factor()

    def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
        if self.estimated_norm_scaling_factor is None:
            raise ValueError(
                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
            )
        return activations * self.estimated_norm_scaling_factor

    def unscale(self, activations: torch.Tensor) -> torch.Tensor:
        if self.estimated_norm_scaling_factor is None:
            raise ValueError(
                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
            )
        return activations / self.estimated_norm_scaling_factor

    def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
        return (self.d_in**0.5) / activations.norm(dim=-1).mean()

    @torch.no_grad()
    def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
        norms_per_batch = []
        for _ in tqdm(
            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
        ):
            # temporarily set estimated_norm_scaling_factor to 1.0 so the dataloader works
            self.estimated_norm_scaling_factor = 1.0
            acts = self.next_batch()[0]
            self.estimated_norm_scaling_factor = None
            norms_per_batch.append(acts.norm(dim=-1).mean().item())
        mean_norm = np.mean(norms_per_batch)
        return np.sqrt(self.d_in) / mean_norm

    def shuffle_input_dataset(self, seed: int, buffer_size: int = 1):
        """
        This applies a shuffle to the huggingface dataset that is the input to the activations store. This
        also shuffles the shards of the dataset, which is especially useful for evaluating on different
        sections of very large streaming datasets. Buffer size is only relevant for streaming datasets.
        The default buffer_size of 1 means that only the shard will be shuffled; larger buffer sizes will
        additionally shuffle individual elements within the shard.
        """
        if isinstance(self.dataset, IterableDataset):
            self.dataset = self.dataset.shuffle(seed=seed, buffer_size=buffer_size)
        else:
            self.dataset = self.dataset.shuffle(seed=seed)
        self.iterable_dataset = iter(self.dataset)

    def reset_input_dataset(self):
        """
        Resets the input dataset iterator to the beginning.
        """
        self.iterable_dataset = iter(self.dataset)

    @property
    def storage_buffer(self) -> torch.Tensor:
        if self._storage_buffer is None:
            self._storage_buffer = _filter_buffer_acts(
                self.get_buffer(self.half_buffer_size), self.exclude_special_tokens
            )

        return self._storage_buffer

    @property
    def dataloader(self) -> Iterator[Any]:
        if self._dataloader is None:
            self._dataloader = self.get_data_loader()
        return self._dataloader

    def get_batch_tokens(
        self, batch_size: int | None = None, raise_at_epoch_end: bool = False
    ):
        """
        Streams a batch of tokens from a dataset.

        If raise_at_epoch_end is true we will reset the dataset at the end of each epoch and raise a StopIteration. Otherwise we will reset silently.
        """
        if not batch_size:
            batch_size = self.store_batch_size_prompts
        sequences = []
        # the sequences iterator yields fully formed tokens of size context_size, so we just need to cat these into a batch
        for _ in range(batch_size):
            try:
                sequences.append(next(self.iterable_sequences))
            except StopIteration:
                self.iterable_sequences = self._iterate_tokenized_sequences()
                if raise_at_epoch_end:
                    raise StopIteration(
                        f"Ran out of tokens in dataset after {self.n_dataset_processed} samples, beginning the next epoch."
                    )
                sequences.append(next(self.iterable_sequences))

        return torch.stack(sequences, dim=0).to(_get_model_device(self.model))

    @torch.no_grad()
    def get_activations(self, batch_tokens: torch.Tensor):
        """
        Returns activations of shape (batches, context, num_layers, d_in)

        d_in may result from a concatenated head dimension.
        """

        # Setup autocast if using
        if self.autocast_lm:
            autocast_if_enabled = torch.autocast(
                device_type="cuda",
                dtype=torch.bfloat16,
                enabled=self.autocast_lm,
            )
        else:
            autocast_if_enabled = contextlib.nullcontext()

        with autocast_if_enabled:
            layerwise_activations_cache = self.model.run_with_cache(
                batch_tokens,
                names_filter=[self.hook_name],
                stop_at_layer=self.hook_layer + 1,
                prepend_bos=False,
                **self.model_kwargs,
            )[1]

        layerwise_activations = layerwise_activations_cache[self.hook_name][
            :, slice(*self.seqpos_slice)
        ]

        n_batches, n_context = layerwise_activations.shape[:2]

        stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))

        if self.hook_head_index is not None:
            stacked_activations[:, :, 0] = layerwise_activations[
                :, :, self.hook_head_index
            ]
        elif layerwise_activations.ndim > 3:  # if we have a head dimension
            try:
                stacked_activations[:, :, 0] = layerwise_activations.view(
                    n_batches, n_context, -1
                )
            except RuntimeError as e:
                logger.error(f"Error during view operation: {e}")
                logger.info("Attempting to use reshape instead...")
                stacked_activations[:, :, 0] = layerwise_activations.reshape(
                    n_batches, n_context, -1
                )
        else:
            stacked_activations[:, :, 0] = layerwise_activations

        return stacked_activations

    def _load_buffer_from_cached(
        self,
        total_size: int,
        context_size: int,
        num_layers: int,
        d_in: int,
        raise_on_epoch_end: bool,
    ) -> tuple[
        Float[torch.Tensor, "(total_size context_size) num_layers d_in"],
        Int[torch.Tensor, "(total_size context_size)"] | None,
    ]:
        """
        Loads `total_size` activations from `cached_activation_dataset`

        The dataset has columns for each hook_name,
        each containing activations of shape (context_size, d_in).

        raises StopIteration
        """
        assert self.cached_activation_dataset is not None
        # In future, could be a list of multiple hook names
        hook_names = [self.hook_name]
        if not set(hook_names).issubset(self.cached_activation_dataset.column_names):
            raise ValueError(
                f"Missing columns in dataset. Expected {hook_names}, "
                f"got {self.cached_activation_dataset.column_names}."
            )

        if self.current_row_idx > len(self.cached_activation_dataset) - total_size:
            self.current_row_idx = 0
            if raise_on_epoch_end:
                raise StopIteration

        new_buffer = []
        ds_slice = self.cached_activation_dataset[
            self.current_row_idx : self.current_row_idx + total_size
        ]
        for hook_name in hook_names:
            # Load activations for each hook.
            # Usually faster to first slice dataset then pick column
            _hook_buffer = ds_slice[hook_name]
            if _hook_buffer.shape != (total_size, context_size, d_in):
                raise ValueError(
                    f"_hook_buffer has shape {_hook_buffer.shape}, "
                    f"but expected ({total_size}, {context_size}, {d_in})."
                )
            new_buffer.append(_hook_buffer)

        # Stack across num_layers dimension
        # list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
        new_buffer = torch.stack(new_buffer, dim=2)
        if new_buffer.shape != (total_size, context_size, num_layers, d_in):
            raise ValueError(
                f"new_buffer has shape {new_buffer.shape}, "
                f"but expected ({total_size}, {context_size}, {num_layers}, {d_in})."
            )

        self.current_row_idx += total_size
        acts_buffer = new_buffer.reshape(total_size * context_size, num_layers, d_in)

        if "token_ids" not in self.cached_activation_dataset.column_names:
            return acts_buffer, None

        token_ids_buffer = ds_slice["token_ids"]
        if token_ids_buffer.shape != (total_size, context_size):
            raise ValueError(
                f"token_ids_buffer has shape {token_ids_buffer.shape}, "
                f"but expected ({total_size}, {context_size})."
            )
        token_ids_buffer = token_ids_buffer.reshape(total_size * context_size)
        return acts_buffer, token_ids_buffer

    @torch.no_grad()
    def get_buffer(
        self,
        n_batches_in_buffer: int,
        raise_on_epoch_end: bool = False,
        shuffle: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Loads the next n_batches_in_buffer batches of activations into a tensor and returns it.

        The primary purpose here is maintaining a shuffling buffer.

        If raise_on_epoch_end is True, when the dataset is exhausted it will automatically reset the dataset and then raise a StopIteration so that the caller has a chance to react.
        """
        context_size = self.context_size
        training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
        batch_size = self.store_batch_size_prompts
        d_in = self.d_in
        total_size = batch_size * n_batches_in_buffer
        num_layers = 1

        if self.cached_activation_dataset is not None:
            return self._load_buffer_from_cached(
                total_size, context_size, num_layers, d_in, raise_on_epoch_end
            )

        refill_iterator = range(0, total_size, batch_size)
        # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
        new_buffer_activations = torch.zeros(
            (total_size, training_context_size, num_layers, d_in),
            dtype=self.dtype,  # type: ignore
            device=self.device,
        )
        new_buffer_token_ids = torch.zeros(
            (total_size, training_context_size),
            dtype=torch.long,
            device=self.device,
        )

        for refill_batch_idx_start in tqdm(
            refill_iterator, leave=False, desc="Refilling buffer"
        ):
            # move batch toks to gpu for model
            refill_batch_tokens = self.get_batch_tokens(
                raise_at_epoch_end=raise_on_epoch_end
            ).to(_get_model_device(self.model))
            refill_activations = self.get_activations(refill_batch_tokens)
            # move acts back to the activations store device
            refill_activations = refill_activations.to(self.device)
            new_buffer_activations[
                refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
            ] = refill_activations

            # handle seqpos_slice, this is done for activations in get_activations
            refill_batch_tokens = refill_batch_tokens[:, slice(*self.seqpos_slice)]
            new_buffer_token_ids[
                refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
            ] = refill_batch_tokens

        new_buffer_activations = new_buffer_activations.reshape(-1, num_layers, d_in)
        new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
        if shuffle:
            new_buffer_activations, new_buffer_token_ids = permute_together(
                [new_buffer_activations, new_buffer_token_ids]
            )

        # every buffer should be normalized:
        if self.normalize_activations == "expected_average_only_in":
            new_buffer_activations = self.apply_norm_scaling_factor(
                new_buffer_activations
            )

        return (
            new_buffer_activations,
            new_buffer_token_ids,
        )

    def get_data_loader(
        self,
    ) -> Iterator[Any]:
        """
        Return a torch.utils.data.DataLoader which you can get batches from.

        Should automatically refill the buffer when it runs low: half of the shuffled
        buffer is kept in storage and re-mixed with fresh activations on each refill,
        which gives better mixing.

        """

        batch_size = self.train_batch_size_tokens

        try:
            new_samples = _filter_buffer_acts(
                self.get_buffer(self.half_buffer_size, raise_on_epoch_end=True),
                self.exclude_special_tokens,
            )
        except StopIteration:
            warnings.warn(
                "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
            )
            self._storage_buffer = (
                None  # dump the current buffer so samples do not leak between epochs
            )
            try:
                new_samples = _filter_buffer_acts(
                    self.get_buffer(self.half_buffer_size),
                    self.exclude_special_tokens,
                )
            except StopIteration:
                raise ValueError(
                    "We were unable to fill up the buffer directly after starting a new epoch. This could indicate that there are less samples in the dataset than are required to fill up the buffer. Consider reducing batch_size or n_batches_in_buffer. "
                )

        # 1. # create new buffer by mixing stored and new buffer
        mixing_buffer = torch.cat(
            [new_samples, self.storage_buffer],
            dim=0,
        )

        mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

        # 2.  put 50 % in storage
        self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]

        # 3. put other 50 % in a dataloader
        return iter(
            DataLoader(
                # TODO: seems like a typing bug?
                cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
                batch_size=batch_size,
                shuffle=True,
            )
        )

    def next_batch(self) -> torch.Tensor:
        """
        Get the next batch from the current DataLoader.
        If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
        """
        try:
            # Try to get the next batch
            return next(self.dataloader)
        except StopIteration:
            # If the DataLoader is exhausted, create a new one
            self._dataloader = self.get_data_loader()
            return next(self.dataloader)

    def state_dict(self) -> dict[str, torch.Tensor]:
        result = {
            "n_dataset_processed": torch.tensor(self.n_dataset_processed),
        }
        if self._storage_buffer is not None:  # first time might be None
            result["storage_buffer_activations"] = self._storage_buffer[0]
            if self._storage_buffer[1] is not None:
                result["storage_buffer_tokens"] = self._storage_buffer[1]
        if self.estimated_norm_scaling_factor is not None:
            result["estimated_norm_scaling_factor"] = torch.tensor(
                self.estimated_norm_scaling_factor
            )
        return result

    def save(self, file_path: str):
        """save the state dict to a file in safetensors format"""
        save_file(self.state_dict(), file_path)

from_cache_activations(model, cfg) classmethod

Public api to create an ActivationsStore from a cached activations dataset.

Source code in sae_lens/training/activations_store.py
@classmethod
def from_cache_activations(
    cls,
    model: HookedRootModule,
    cfg: CacheActivationsRunnerConfig,
) -> ActivationsStore:
    """
    Public api to create an ActivationsStore from a cached activations dataset.
    """
    return cls(
        cached_activations_path=cfg.new_cached_activations_path,
        dtype=cfg.dtype,
        hook_name=cfg.hook_name,
        hook_layer=cfg.hook_layer,
        context_size=cfg.context_size,
        d_in=cfg.d_in,
        n_batches_in_buffer=cfg.n_batches_in_buffer,
        total_training_tokens=cfg.training_tokens,
        store_batch_size_prompts=cfg.model_batch_size,  # get_buffer
        train_batch_size_tokens=cfg.model_batch_size,  # dataloader
        seqpos_slice=(None,),
        device=torch.device(cfg.device),  # since we're sending these to SAE
        # NOOP
        prepend_bos=False,
        hook_head_index=None,
        dataset=cfg.dataset_path,
        streaming=False,
        model=model,
        normalize_activations="none",
        model_kwargs=None,
        autocast_lm=False,
        dataset_trust_remote_code=None,
        exclude_special_tokens=None,
    )
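
For example, assuming activations have already been written to disk by CacheActivationsRunner: the config fields below mirror the attributes this method reads, but the concrete values (and whether every field is a constructor argument in your sae_lens version) are illustrative assumptions.

from sae_lens import CacheActivationsRunnerConfig
from sae_lens.training.activations_store import ActivationsStore

cfg = CacheActivationsRunnerConfig(
    model_name="gpt2",
    hook_name="blocks.8.hook_resid_pre",
    d_in=768,
    dataset_path="NeelNanda/pile-10k",
    new_cached_activations_path="./cached_activations/gpt2",  # directory written by CacheActivationsRunner
    context_size=128,
    training_tokens=1_000_000,
)
store = ActivationsStore.from_cache_activations(model, cfg)  # `model` loaded as in the example above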

get_activations(batch_tokens)

Returns activations of shape (batches, context, num_layers, d_in)

d_in may result from a concatenated head dimension.

Source code in sae_lens/training/activations_store.py
@torch.no_grad()
def get_activations(self, batch_tokens: torch.Tensor):
    """
    Returns activations of shape (batches, context, num_layers, d_in)

    d_in may result from a concatenated head dimension.
    """

    # Setup autocast if using
    if self.autocast_lm:
        autocast_if_enabled = torch.autocast(
            device_type="cuda",
            dtype=torch.bfloat16,
            enabled=self.autocast_lm,
        )
    else:
        autocast_if_enabled = contextlib.nullcontext()

    with autocast_if_enabled:
        layerwise_activations_cache = self.model.run_with_cache(
            batch_tokens,
            names_filter=[self.hook_name],
            stop_at_layer=self.hook_layer + 1,
            prepend_bos=False,
            **self.model_kwargs,
        )[1]

    layerwise_activations = layerwise_activations_cache[self.hook_name][
        :, slice(*self.seqpos_slice)
    ]

    n_batches, n_context = layerwise_activations.shape[:2]

    stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))

    if self.hook_head_index is not None:
        stacked_activations[:, :, 0] = layerwise_activations[
            :, :, self.hook_head_index
        ]
    elif layerwise_activations.ndim > 3:  # if we have a head dimension
        try:
            stacked_activations[:, :, 0] = layerwise_activations.view(
                n_batches, n_context, -1
            )
        except RuntimeError as e:
            logger.error(f"Error during view operation: {e}")
            logger.info("Attempting to use reshape instead...")
            stacked_activations[:, :, 0] = layerwise_activations.reshape(
                n_batches, n_context, -1
            )
    else:
        stacked_activations[:, :, 0] = layerwise_activations

    return stacked_activations

get_batch_tokens(batch_size=None, raise_at_epoch_end=False)

Streams a batch of tokens from a dataset.

If raise_at_epoch_end is true we will reset the dataset at the end of each epoch and raise a StopIteration. Otherwise we will reset silently.

Source code in sae_lens/training/activations_store.py
def get_batch_tokens(
    self, batch_size: int | None = None, raise_at_epoch_end: bool = False
):
    """
    Streams a batch of tokens from a dataset.

    If raise_at_epoch_end is true we will reset the dataset at the end of each epoch and raise a StopIteration. Otherwise we will reset silently.
    """
    if not batch_size:
        batch_size = self.store_batch_size_prompts
    sequences = []
    # the sequences iterator yields fully formed tokens of size context_size, so we just need to cat these into a batch
    for _ in range(batch_size):
        try:
            sequences.append(next(self.iterable_sequences))
        except StopIteration:
            self.iterable_sequences = self._iterate_tokenized_sequences()
            if raise_at_epoch_end:
                raise StopIteration(
                    f"Ran out of tokens in dataset after {self.n_dataset_processed} samples, beginning the next epoch."
                )
            sequences.append(next(self.iterable_sequences))

    return torch.stack(sequences, dim=0).to(_get_model_device(self.model))
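
Together with get_activations above, a short sketch of the token-to-activation path (shapes follow the implementation shown here; `store` is an ActivationsStore instance as constructed earlier):

tokens = store.get_batch_tokens()     # (store_batch_size_prompts, context_size)
acts = store.get_activations(tokens)  # (batch, context, 1, d_in), after seqpos_slice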

get_buffer(n_batches_in_buffer, raise_on_epoch_end=False, shuffle=True)

Loads the next n_batches_in_buffer batches of activations into a tensor and returns it.

The primary purpose here is maintaining a shuffling buffer.

If raise_on_epoch_end is True, when the dataset is exhausted it will automatically reset the dataset and then raise a StopIteration so that the caller has a chance to react.

Source code in sae_lens/training/activations_store.py
@torch.no_grad()
def get_buffer(
    self,
    n_batches_in_buffer: int,
    raise_on_epoch_end: bool = False,
    shuffle: bool = True,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Loads the next n_batches_in_buffer batches of activations into a tensor and returns it.

    The primary purpose here is maintaining a shuffling buffer.

    If raise_on_epoch_end is True, when the dataset is exhausted it will automatically reset the dataset and then raise a StopIteration so that the caller has a chance to react.
    """
    context_size = self.context_size
    training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
    batch_size = self.store_batch_size_prompts
    d_in = self.d_in
    total_size = batch_size * n_batches_in_buffer
    num_layers = 1

    if self.cached_activation_dataset is not None:
        return self._load_buffer_from_cached(
            total_size, context_size, num_layers, d_in, raise_on_epoch_end
        )

    refill_iterator = range(0, total_size, batch_size)
    # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
    new_buffer_activations = torch.zeros(
        (total_size, training_context_size, num_layers, d_in),
        dtype=self.dtype,  # type: ignore
        device=self.device,
    )
    new_buffer_token_ids = torch.zeros(
        (total_size, training_context_size),
        dtype=torch.long,
        device=self.device,
    )

    for refill_batch_idx_start in tqdm(
        refill_iterator, leave=False, desc="Refilling buffer"
    ):
        # move batch toks to gpu for model
        refill_batch_tokens = self.get_batch_tokens(
            raise_at_epoch_end=raise_on_epoch_end
        ).to(_get_model_device(self.model))
        refill_activations = self.get_activations(refill_batch_tokens)
        # move acts back to the activations store device
        refill_activations = refill_activations.to(self.device)
        new_buffer_activations[
            refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
        ] = refill_activations

        # handle seqpos_slice, this is done for activations in get_activations
        refill_batch_tokens = refill_batch_tokens[:, slice(*self.seqpos_slice)]
        new_buffer_token_ids[
            refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
        ] = refill_batch_tokens

    new_buffer_activations = new_buffer_activations.reshape(-1, num_layers, d_in)
    new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
    if shuffle:
        new_buffer_activations, new_buffer_token_ids = permute_together(
            [new_buffer_activations, new_buffer_token_ids]
        )

    # every buffer should be normalized:
    if self.normalize_activations == "expected_average_only_in":
        new_buffer_activations = self.apply_norm_scaling_factor(
            new_buffer_activations
        )

    return (
        new_buffer_activations,
        new_buffer_token_ids,
    )
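
For instance, filling half of the configured buffer (as the store does internally when it tops itself up) returns flattened activations plus the token ids they came from:

acts, token_ids = store.get_buffer(store.half_buffer_size)
# acts:      (n_prompts * context, 1, d_in), shuffled unless shuffle=False
# token_ids: matching token ids, or None for cached activation datasets without a "token_ids" column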

get_data_loader()

Return a torch.utils.data.DataLoader which you can get batches from.

Should automatically refill the buffer when it runs low: half of the shuffled buffer is kept in storage and re-mixed with fresh activations on each refill, which gives better mixing.

Source code in sae_lens/training/activations_store.py
def get_data_loader(
    self,
) -> Iterator[Any]:
    """
    Return a torch.utils.data.DataLoader which you can get batches from.

    Should automatically refill the buffer when it runs low: half of the shuffled
    buffer is kept in storage and re-mixed with fresh activations on each refill,
    which gives better mixing.

    """

    batch_size = self.train_batch_size_tokens

    try:
        new_samples = _filter_buffer_acts(
            self.get_buffer(self.half_buffer_size, raise_on_epoch_end=True),
            self.exclude_special_tokens,
        )
    except StopIteration:
        warnings.warn(
            "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
        )
        self._storage_buffer = (
            None  # dump the current buffer so samples do not leak between epochs
        )
        try:
            new_samples = _filter_buffer_acts(
                self.get_buffer(self.half_buffer_size),
                self.exclude_special_tokens,
            )
        except StopIteration:
            raise ValueError(
                "We were unable to fill up the buffer directly after starting a new epoch. This could indicate that there are less samples in the dataset than are required to fill up the buffer. Consider reducing batch_size or n_batches_in_buffer. "
            )

    # 1. # create new buffer by mixing stored and new buffer
    mixing_buffer = torch.cat(
        [new_samples, self.storage_buffer],
        dim=0,
    )

    mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

    # 2.  put 50 % in storage
    self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]

    # 3. put other 50 % in a dataloader
    return iter(
        DataLoader(
            # TODO: seems like a typing bug?
            cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
            batch_size=batch_size,
            shuffle=True,
        )
    )
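
Most callers go through next_batch below, but the loader can also be driven directly; a sketch:

loader = store.get_data_loader()
batch = next(loader)  # (train_batch_size_tokens, 1, d_in)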

load_cached_activation_dataset()

Load the cached activation dataset from disk.

  • If cached_activations_path is set, returns Huggingface Dataset else None
  • Checks that the loaded dataset has activations for the hooks in the config and that the shapes match.
Source code in sae_lens/training/activations_store.py
def load_cached_activation_dataset(self) -> Dataset | None:
    """
    Load the cached activation dataset from disk.

    - If cached_activations_path is set, returns Huggingface Dataset else None
    - Checks that the loaded dataset has activations for the hooks in the config and that the shapes match.
    """
    if self.cached_activations_path is None:
        return None

    assert self.cached_activations_path is not None  # keep pyright happy
    # Sanity check: does the cache directory exist?
    if not os.path.exists(self.cached_activations_path):
        raise FileNotFoundError(
            f"Cache directory {self.cached_activations_path} does not exist. "
            "Consider double-checking your dataset, model, and hook names."
        )

    # ---
    # Actual code
    activations_dataset = datasets.load_from_disk(self.cached_activations_path)
    columns = [self.hook_name]
    if "token_ids" in activations_dataset.column_names:
        columns.append("token_ids")
    activations_dataset.set_format(
        type="torch", columns=columns, device=self.device, dtype=self.dtype
    )
    self.current_row_idx = 0  # idx to load next batch from
    # ---

    assert isinstance(activations_dataset, Dataset)

    # may support multiple hook names in future
    if not set([self.hook_name]).issubset(activations_dataset.column_names):
        raise ValueError(
            f"loaded dataset does not include hook activations, got {activations_dataset.column_names}"
        )

    if activations_dataset.features[self.hook_name].shape != (
        self.context_size,
        self.d_in,
    ):
        raise ValueError(
            f"Given dataset of shape {activations_dataset.features[self.hook_name].shape} does not match context_size ({self.context_size}) and d_in ({self.d_in})"
        )

    return activations_dataset

next_batch()

Get the next batch from the current DataLoader. If the DataLoader is exhausted, refill the buffer and create a new DataLoader.

Source code in sae_lens/training/activations_store.py
def next_batch(self) -> torch.Tensor:
    """
    Get the next batch from the current DataLoader.
    If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
    """
    try:
        # Try to get the next batch
        return next(self.dataloader)
    except StopIteration:
        # If the DataLoader is exhausted, create a new one
        self._dataloader = self.get_data_loader()
        return next(self.dataloader)
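
A typical consumption loop is simply repeated calls to next_batch; the training step itself is a placeholder here:

for step in range(1000):
    acts = store.next_batch()  # (train_batch_size_tokens, 1, d_in)
    sae_in = acts.squeeze(1)   # drop the singleton layer dimension
    # ... your SAE training step on `sae_in` goes here ...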

reset_input_dataset()

Resets the input dataset iterator to the beginning.

Source code in sae_lens/training/activations_store.py
def reset_input_dataset(self):
    """
    Resets the input dataset iterator to the beginning.
    """
    self.iterable_dataset = iter(self.dataset)

save(file_path)

save the state dict to a file in safetensors format

Source code in sae_lens/training/activations_store.py
def save(self, file_path: str):
    """save the state dict to a file in safetensors format"""
    save_file(self.state_dict(), file_path)
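
For example, to checkpoint the store's state during a long run (the filename is illustrative):

store.save("activations_store_state.safetensors")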

shuffle_input_dataset(seed, buffer_size=1)

This applies a shuffle to the huggingface dataset that is the input to the activations store. This also shuffles the shards of the dataset, which is especially useful for evaluating on different sections of very large streaming datasets. Buffer size is only relevant for streaming datasets. The default buffer_size of 1 means that only the shard will be shuffled; larger buffer sizes will additionally shuffle individual elements within the shard.

Source code in sae_lens/training/activations_store.py
def shuffle_input_dataset(self, seed: int, buffer_size: int = 1):
    """
    This applies a shuffle to the huggingface dataset that is the input to the activations store. This
    also shuffles the shards of the dataset, which is especially useful for evaluating on different
    sections of very large streaming datasets. Buffer size is only relevant for streaming datasets.
    The default buffer_size of 1 means that only the shard will be shuffled; larger buffer sizes will
    additionally shuffle individual elements within the shard.
    """
    if isinstance(self.dataset, IterableDataset):
        self.dataset = self.dataset.shuffle(seed=seed, buffer_size=buffer_size)
    else:
        self.dataset = self.dataset.shuffle(seed=seed)
    self.iterable_dataset = iter(self.dataset)
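
For streaming datasets, a larger buffer_size also shuffles elements within each shard; for example:

store.shuffle_input_dataset(seed=42, buffer_size=1_000)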

CacheActivationsRunner

Source code in sae_lens/cache_activations_runner.py
class CacheActivationsRunner:
    def __init__(
        self,
        cfg: CacheActivationsRunnerConfig,
        override_dataset: Dataset | None = None,
    ):
        self.cfg = cfg
        self.model: HookedRootModule = load_model(
            model_class_name=self.cfg.model_class_name,
            model_name=self.cfg.model_name,
            device=self.cfg.device,
            model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
        )
        if self.cfg.compile_llm:
            self.model = torch.compile(self.model, mode=self.cfg.llm_compilation_mode)  # type: ignore
        self.activations_store = _mk_activations_store(
            self.model,
            self.cfg,
            override_dataset=override_dataset,
        )
        self.context_size = self._get_sliced_context_size(
            self.cfg.context_size, self.cfg.seqpos_slice
        )
        features_dict: dict[str, Array2D | Sequence] = {
            hook_name: Array2D(
                shape=(self.context_size, self.cfg.d_in), dtype=self.cfg.dtype
            )
            for hook_name in [self.cfg.hook_name]
        }
        features_dict["token_ids"] = Sequence(
            Value(dtype="int32"), length=self.context_size
        )
        self.features = Features(features_dict)

    def __str__(self):
        """
        Print the number of tokens to be cached.
        Print the number of buffers, and the number of tokens per buffer.
        Print the disk space required to store the activations.

        """

        bytes_per_token = (
            self.cfg.d_in * self.cfg.dtype.itemsize
            if isinstance(self.cfg.dtype, torch.dtype)
            else DTYPE_MAP[self.cfg.dtype].itemsize
        )
        total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
        total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

        return (
            f"Activation Cache Runner:\n"
            f"Total training tokens: {total_training_tokens}\n"
            f"Number of buffers: {self.cfg.n_buffers}\n"
            f"Tokens per buffer: {self.cfg.n_tokens_in_buffer}\n"
            f"Disk space required: {total_disk_space_gb:.2f} GB\n"
            f"Configuration:\n"
            f"{self.cfg}"
        )

    @staticmethod
    def _consolidate_shards(
        source_dir: Path, output_dir: Path, copy_files: bool = True
    ) -> Dataset:
        """Consolidate sharded datasets into a single directory without rewriting data.

        Each of the shards must be of the same format, aka the full dataset must be able to
        be recreated like so:

        ```
        ds = concatenate_datasets(
            [Dataset.load_from_disk(str(shard_dir)) for shard_dir in sorted(source_dir.iterdir())]
        )

        ```

        Sharded dataset format:
        ```
        source_dir/
            shard_00000/
                dataset_info.json
                state.json
                data-00000-of-00002.arrow
                data-00001-of-00002.arrow
            shard_00001/
                dataset_info.json
                state.json
                data-00000-of-00001.arrow
        ```

        And flattens them into the format:

        ```
        output_dir/
            dataset_info.json
            state.json
            data-00000-of-00003.arrow
            data-00001-of-00003.arrow
            data-00002-of-00003.arrow
        ```

        allowing the dataset to be loaded like so:

        ```
        ds = datasets.load_from_disk(output_dir)
        ```

        Args:
            source_dir: Directory containing the sharded datasets
            output_dir: Directory to consolidate the shards into
            copy_files: If True, copy files; if False, move them and delete source_dir
        """
        first_shard_dir_name = "shard_00000"  # shard_{i:05d}

        if not source_dir.exists() or not source_dir.is_dir():
            raise NotADirectoryError(
                f"source_dir is not an existing directory: {source_dir}"
            )

        if not output_dir.exists() or not output_dir.is_dir():
            raise NotADirectoryError(
                f"output_dir is not an existing directory: {output_dir}"
            )

        other_items = [p for p in output_dir.iterdir() if p.name != ".tmp_shards"]
        if other_items:
            raise FileExistsError(
                f"output_dir must be empty (besides .tmp_shards). Found: {other_items}"
            )

        if not (source_dir / first_shard_dir_name).exists():
            raise Exception(f"No shards in {source_dir} exist!")

        transfer_fn = shutil.copy2 if copy_files else shutil.move

        # Move dataset_info.json from any shard (all the same)
        transfer_fn(
            source_dir / first_shard_dir_name / "dataset_info.json",
            output_dir / "dataset_info.json",
        )

        arrow_files = []
        file_count = 0

        for shard_dir in sorted(source_dir.iterdir()):
            if not shard_dir.name.startswith("shard_"):
                continue

            # state.json contains arrow filenames
            state = json.loads((shard_dir / "state.json").read_text())

            for data_file in state["_data_files"]:
                src = shard_dir / data_file["filename"]
                new_name = f"data-{file_count:05d}-of-{len(list(source_dir.iterdir())):05d}.arrow"
                dst = output_dir / new_name
                transfer_fn(src, dst)
                arrow_files.append({"filename": new_name})
                file_count += 1

        new_state = {
            "_data_files": arrow_files,
            "_fingerprint": None,  # temporary
            "_format_columns": None,
            "_format_kwargs": {},
            "_format_type": None,
            "_output_all_columns": False,
            "_split": None,
        }

        # fingerprint is generated from dataset.__getstate__ (not including _fingerprint)
        with open(output_dir / "state.json", "w") as f:
            json.dump(new_state, f, indent=2)

        ds = Dataset.load_from_disk(str(output_dir))
        fingerprint = generate_fingerprint(ds)
        del ds

        with open(output_dir / "state.json", "r+") as f:
            state = json.loads(f.read())
            state["_fingerprint"] = fingerprint
            f.seek(0)
            json.dump(state, f, indent=2)
            f.truncate()

        if not copy_files:  # cleanup source dir
            shutil.rmtree(source_dir)

        return Dataset.load_from_disk(output_dir)

    @torch.no_grad()
    def run(self) -> Dataset:
        activation_save_path = self.cfg.new_cached_activations_path
        assert activation_save_path is not None

        ### Paths setup
        final_cached_activation_path = Path(activation_save_path)
        final_cached_activation_path.mkdir(exist_ok=True, parents=True)
        if any(final_cached_activation_path.iterdir()):
            raise Exception(
                f"Activations directory ({final_cached_activation_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files."
            )

        tmp_cached_activation_path = final_cached_activation_path / ".tmp_shards/"
        tmp_cached_activation_path.mkdir(exist_ok=False, parents=False)

        ### Create temporary sharded datasets

        logger.info(f"Started caching activations for {self.cfg.dataset_path}")

        for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
            try:
                buffer = self.activations_store.get_buffer(
                    self.cfg.n_batches_in_buffer, shuffle=False
                )
                shard = self._create_shard(buffer)
                shard.save_to_disk(
                    f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
                )
                del buffer, shard
            except StopIteration:
                logger.warning(
                    f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
                )
                break

        ### Concatenate shards and push to Huggingface Hub

        dataset = self._consolidate_shards(
            tmp_cached_activation_path, final_cached_activation_path, copy_files=False
        )

        if self.cfg.shuffle:
            logger.info("Shuffling...")
            dataset = dataset.shuffle(seed=self.cfg.seed)

        if self.cfg.hf_repo_id:
            logger.info("Pushing to Huggingface Hub...")
            dataset.push_to_hub(
                repo_id=self.cfg.hf_repo_id,
                num_shards=self.cfg.hf_num_shards,
                private=self.cfg.hf_is_private_repo,
                revision=self.cfg.hf_revision,
            )

            meta_io = io.BytesIO()
            meta_contents = json.dumps(
                asdict(self.cfg), indent=2, ensure_ascii=False
            ).encode("utf-8")
            meta_io.write(meta_contents)
            meta_io.seek(0)

            api = HfApi()
            api.upload_file(
                path_or_fileobj=meta_io,
                path_in_repo="cache_activations_runner_cfg.json",
                repo_id=self.cfg.hf_repo_id,
                repo_type="dataset",
                commit_message="Add cache_activations_runner metadata",
            )

        return dataset

    def _create_shard(
        self,
        buffer: tuple[
            Float[torch.Tensor, "(bs context_size) num_layers d_in"],
            Int[torch.Tensor, "(bs context_size)"] | None,
        ],
    ) -> Dataset:
        hook_names = [self.cfg.hook_name]
        acts, token_ids = buffer
        acts = einops.rearrange(
            acts,
            "(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
            bs=self.cfg.n_seq_in_buffer,
            context_size=self.context_size,
            d_in=self.cfg.d_in,
            num_layers=len(hook_names),
        )
        shard_dict = {hook_name: act for hook_name, act in zip(hook_names, acts)}

        if token_ids is not None:
            token_ids = einops.rearrange(
                token_ids,
                "(bs context_size) -> bs context_size",
                bs=self.cfg.n_seq_in_buffer,
                context_size=self.context_size,
            )
            shard_dict["token_ids"] = token_ids.to(torch.int32)
        return Dataset.from_dict(
            shard_dict,
            features=self.features,
        )

    @staticmethod
    def _get_sliced_context_size(
        context_size: int, seqpos_slice: tuple[int | None, ...] | None
    ) -> int:
        if seqpos_slice is not None:
            context_size = len(range(context_size)[slice(*seqpos_slice)])
        return context_size
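
For orientation, a minimal end-to-end sketch of driving the runner shown above (assuming the class is sae_lens.CacheActivationsRunner and is constructed from a single config object; all field values and the dataset path are illustrative, not recommendations):

# Hypothetical usage sketch; adjust paths, model, and token counts to your setup.
from sae_lens import CacheActivationsRunner, CacheActivationsRunnerConfig

cfg = CacheActivationsRunnerConfig(
    dataset_path="your-org/your-pretokenized-dataset",   # assumed pre-tokenized dataset
    model_name="gpt2",
    model_batch_size=16,
    hook_name="blocks.6.hook_resid_pre",
    hook_layer=6,
    d_in=768,
    training_tokens=1_000_000,
    new_cached_activations_path="./activations/gpt2_blocks.6.hook_resid_pre",
)

runner = CacheActivationsRunner(cfg)
print(runner)            # summary from __str__ below: buffers, tokens, disk space
dataset = runner.run()   # writes shards to .tmp_shards/, then consolidates them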

__str__()

Print the number of tokens to be cached. Print the number of buffers, and the number of tokens per buffer. Print the disk space required to store the activations.

Source code in sae_lens/cache_activations_runner.py
def __str__(self):
    """
    Print the number of tokens to be cached.
    Print the number of buffers, and the number of tokens per buffer.
    Print the disk space required to store the activations.

    """

    bytes_per_token = (
        self.cfg.d_in * self.cfg.dtype.itemsize
        if isinstance(self.cfg.dtype, torch.dtype)
        else DTYPE_MAP[self.cfg.dtype].itemsize
    )
    total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
    total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

    return (
        f"Activation Cache Runner:\n"
        f"Total training tokens: {total_training_tokens}\n"
        f"Number of buffers: {self.cfg.n_buffers}\n"
        f"Tokens per buffer: {self.cfg.n_tokens_in_buffer}\n"
        f"Disk space required: {total_disk_space_gb:.2f} GB\n"
        f"Configuration:\n"
        f"{self.cfg}"
    )
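
For a rough sense of scale, here is the same arithmetic with illustrative numbers (not values from the source):

d_in, itemsize = 768, 4                      # float32 -> 4 bytes per element
bytes_per_token = d_in * itemsize            # 3072 bytes
total_training_tokens = 10_000_000
print(total_training_tokens * bytes_per_token / 10**9)  # ~30.7 GB of activations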

CacheActivationsRunnerConfig dataclass

Configuration for creating and caching activations of an LLM.

Parameters:

    dataset_path (str): The path to the Hugging Face dataset. This may be tokenized or not. Required.
    model_name (str): The name of the model to use. Required.
    model_batch_size (int): How many prompts are in the batch of the language model when generating activations. Required.
    hook_name (str): The name of the hook to use. Required.
    hook_layer (int): The layer of the hook. Only a single hook is currently supported, so this should be the layer referenced by hook_name. Required.
    d_in (int): Dimension of the model. Required.
    training_tokens (int): Total number of tokens to process. Required.
    context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized. Default: -1.
    model_class_name (str): The name of the class of the model to use. This should be either HookedTransformer or HookedMamba. Default: 'HookedTransformer'.
    new_cached_activations_path (str, optional): The path to save the activations. If None, a default path is derived from the dataset, model, and hook name. Default: None.
    shuffle (bool): Whether to shuffle the dataset. Default: True.
    seed (int): The seed to use for shuffling. Default: 42.
    dtype (str): Datatype of activations to be stored. Default: 'float32'.
    device (str): The device for the model. Default: 'cuda' if torch.cuda.is_available() else 'cpu'.
    buffer_size_gb (float): The buffer size in GB. This should be < 2 GB. Default: 2.0.
    hf_repo_id (str, optional): The Hugging Face repository id to save the activations to. Default: None.
    hf_num_shards (int, optional): The number of shards to save the activations to. Default: None.
    hf_revision (str): The revision to save the activations to. Default: 'main'.
    hf_is_private_repo (bool): Whether the Hugging Face repository is private. Default: False.
    model_kwargs (dict): Keyword arguments for model.run_with_cache. Default: dict().
    model_from_pretrained_kwargs (dict): Keyword arguments for the from_pretrained method of the model. Default: dict().
    compile_llm (bool): Whether to compile the LLM. Default: False.
    llm_compilation_mode (str, optional): The torch.compile mode to use. Default: None.
    prepend_bos (bool): Whether to prepend the beginning-of-sequence token. You should use whatever the model was trained with. Default: True.
    seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note: step_size > 0. Default: (None,).
    streaming (bool): Whether to stream the dataset. Streaming is usually practical for large datasets. Default: True.
    autocast_lm (bool): Whether to use autocast during activation fetching. Default: False.
    dataset_trust_remote_code (bool, optional): Whether to trust remote code when loading datasets from Huggingface. Default: None.
Source code in sae_lens/config.py
@dataclass
class CacheActivationsRunnerConfig:
    """
    Configuration for creating and caching activations of an LLM.

    Args:
        dataset_path (str): The path to the Hugging Face dataset. This may be tokenized or not.
        model_name (str): The name of the model to use.
        model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
        hook_name (str): The name of the hook to use.
        hook_layer (int): The layer of the hook. Only a single hook is currently supported, so this should be the layer referenced by hook_name.
        d_in (int): Dimension of the model.
        training_tokens (int): Total number of tokens to process.
        context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
        model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
        new_cached_activations_path (str, optional): The path to save the activations.
        shuffle (bool): Whether to shuffle the dataset.
        seed (int): The seed to use for shuffling.
        dtype (str): Datatype of activations to be stored.
        device (str): The device for the model.
        buffer_size_gb (float): The buffer size in GB. This should be < 2GB.
        hf_repo_id (str, optional): The Hugging Face repository id to save the activations to.
        hf_num_shards (int, optional): The number of shards to save the activations to.
        hf_revision (str): The revision to save the activations to.
        hf_is_private_repo (bool): Whether the Hugging Face repository is private.
        model_kwargs (dict): Keyword arguments for `model.run_with_cache`.
        model_from_pretrained_kwargs (dict): Keyword arguments for the `from_pretrained` method of the model.
        compile_llm (bool): Whether to compile the LLM.
        llm_compilation_mode (str): The torch.compile mode to use.
        prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
        seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
        streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
        autocast_lm (bool): Whether to use autocast during activation fetching.
        dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
    """

    dataset_path: str
    model_name: str
    model_batch_size: int
    hook_name: str
    hook_layer: int
    d_in: int
    training_tokens: int

    context_size: int = -1  # Required if dataset is not tokenized
    model_class_name: str = "HookedTransformer"
    # defaults to "activations/{dataset}/{model}/{hook_name}
    new_cached_activations_path: str | None = None
    shuffle: bool = True
    seed: int = 42
    dtype: str = "float32"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    buffer_size_gb: float = 2.0  # HF datasets writer have problems with shards > 2GB

    # Huggingface Integration
    hf_repo_id: str | None = None
    hf_num_shards: int | None = None
    hf_revision: str = "main"
    hf_is_private_repo: bool = False

    # Model
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
    compile_llm: bool = False
    llm_compilation_mode: str | None = None  # which torch.compile mode to use

    # Activation Store
    prepend_bos: bool = True
    seqpos_slice: tuple[int | None, ...] = (None,)
    streaming: bool = True
    autocast_lm: bool = False
    dataset_trust_remote_code: bool | None = None

    def __post_init__(self):
        # Automatically determine context_size if dataset is tokenized
        if self.context_size == -1:
            ds = load_dataset(self.dataset_path, split="train", streaming=True)
            assert isinstance(ds, IterableDataset)
            first_sample = next(iter(ds))
            toks = first_sample.get("tokens") or first_sample.get("input_ids") or None
            if toks is None:
                raise ValueError(
                    "Dataset is not tokenized. Please specify context_size."
                )
            token_length = len(toks)
            self.context_size = token_length

        if self.context_size == -1:
            raise ValueError("context_size is still -1 after dataset inspection.")

        if self.seqpos_slice is not None:
            _validate_seqpos(
                seqpos=self.seqpos_slice,
                context_size=self.context_size,
            )

        if self.new_cached_activations_path is None:
            self.new_cached_activations_path = _default_cached_activations_path(  # type: ignore
                self.dataset_path, self.model_name, self.hook_name, None
            )

    @property
    def sliced_context_size(self) -> int:
        if self.seqpos_slice is not None:
            return len(range(self.context_size)[slice(*self.seqpos_slice)])
        return self.context_size

    @property
    def bytes_per_token(self) -> int:
        return self.d_in * DTYPE_MAP[self.dtype].itemsize

    @property
    def n_tokens_in_buffer(self) -> int:
        # Calculate raw tokens per buffer based on memory constraints
        _tokens_per_buffer = int(self.buffer_size_gb * 1e9) // self.bytes_per_token
        # Round down to nearest multiple of batch_token_size
        return _tokens_per_buffer - (_tokens_per_buffer % self.n_tokens_in_batch)

    @property
    def n_tokens_in_batch(self) -> int:
        return self.model_batch_size * self.sliced_context_size

    @property
    def n_batches_in_buffer(self) -> int:
        return self.n_tokens_in_buffer // self.n_tokens_in_batch

    @property
    def n_seq_in_dataset(self) -> int:
        return self.training_tokens // self.sliced_context_size

    @property
    def n_seq_in_buffer(self) -> int:
        return self.n_tokens_in_buffer // self.sliced_context_size

    @property
    def n_buffers(self) -> int:
        return math.ceil(self.training_tokens / self.n_tokens_in_buffer)
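
To make the derived properties above concrete, a short sketch of how buffer_size_gb and seqpos_slice feed into the token bookkeeping (illustrative values only, mirroring the property definitions):

d_in, itemsize = 768, 4                        # bytes_per_token = 3072
context_size, seqpos_slice = 128, (5, -5)
model_batch_size = 16
buffer_size_gb = 2.0
training_tokens = 10_000_000

sliced_context_size = len(range(context_size)[slice(*seqpos_slice)])   # 118
bytes_per_token = d_in * itemsize
n_tokens_in_batch = model_batch_size * sliced_context_size             # 1888
_raw = int(buffer_size_gb * 1e9) // bytes_per_token                    # 651_041
n_tokens_in_buffer = _raw - (_raw % n_tokens_in_batch)                 # 649_472
n_buffers = -(-training_tokens // n_tokens_in_buffer)                  # ceil -> 16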

GatedSAE

Bases: SAE[GatedSAEConfig]

GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE) using a gated linear encoder and a standard linear decoder.

Source code in sae_lens/saes/gated_sae.py
class GatedSAE(SAE[GatedSAEConfig]):
    """
    GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
    using a gated linear encoder and a standard linear decoder.
    """

    b_gate: nn.Parameter
    b_mag: nn.Parameter
    r_mag: nn.Parameter

    def __init__(self, cfg: GatedSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)
        # Ensure b_enc does not exist for the gated architecture
        self.b_enc = None

    @override
    def initialize_weights(self) -> None:
        super().initialize_weights()
        _init_weights_gated(self)

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Encode the input tensor into the feature space using a gated encoder.
        This must match the original encode_gated implementation from SAE class.
        """
        # Preprocess the SAE input (casting type, applying hooks, normalization)
        sae_in = self.process_sae_in(x)

        # Gating path exactly as in original SAE.encode_gated
        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
        active_features = (gating_pre_activation > 0).to(self.dtype)

        # Magnitude path (weight sharing with gated encoder)
        magnitude_pre_activation = self.hook_sae_acts_pre(
            sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
        )
        feature_magnitudes = self.activation_fn(magnitude_pre_activation)

        # Combine gating and magnitudes
        return self.hook_sae_acts_post(active_features * feature_magnitudes)

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Decode the feature activations back into the input space:
          1) Apply optional finetuning scaling.
          2) Linear transform plus bias.
          3) Run any reconstruction hooks and out-normalization if configured.
          4) If the SAE was reshaping hook_z activations, reshape back.
        """
        # 1) optional finetuning scaling
        # 2) linear transform
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        # 3) hooking and normalization
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        # 4) reshape if needed (hook_z)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """Override to handle gated-specific parameters."""
        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T

        # Gated-specific parameters need special handling
        self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
        self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
        self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()

    @torch.no_grad()
    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
        """Initialize decoder with constant norm."""
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
        self.W_dec.data *= norm
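
In equation form, the gated forward pass implemented by encode and decode above (x denotes the preprocessed SAE input, sigma the configured activation function, and exp(r_mag) rescales each encoder column):

\pi_{\text{gate}}(x) = x W_{\text{enc}} + b_{\text{gate}}, \qquad
\pi_{\text{mag}}(x) = x \bigl(W_{\text{enc}} \odot \exp(r_{\text{mag}})\bigr) + b_{\text{mag}}

f(x) = \mathbf{1}\bigl[\pi_{\text{gate}}(x) > 0\bigr] \odot \sigma\bigl(\pi_{\text{mag}}(x)\bigr), \qquad
\hat{x} = f(x)\, W_{\text{dec}} + b_{\text{dec}}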

decode(feature_acts)

Decode the feature activations back into the input space

1) Apply optional finetuning scaling.
2) Linear transform plus bias.
3) Run any reconstruction hooks and out-normalization if configured.
4) If the SAE was reshaping hook_z activations, reshape back.

Source code in sae_lens/saes/gated_sae.py
def decode(
    self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_in"]:
    """
    Decode the feature activations back into the input space:
      1) Apply optional finetuning scaling.
      2) Linear transform plus bias.
      3) Run any reconstruction hooks and out-normalization if configured.
      4) If the SAE was reshaping hook_z activations, reshape back.
    """
    # 1) optional finetuning scaling
    # 2) linear transform
    sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    # 3) hooking and normalization
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    # 4) reshape if needed (hook_z)
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode(x)

Encode the input tensor into the feature space using a gated encoder. This must match the original encode_gated implementation from SAE class.

Source code in sae_lens/saes/gated_sae.py
def encode(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """
    Encode the input tensor into the feature space using a gated encoder.
    This must match the original encode_gated implementation from SAE class.
    """
    # Preprocess the SAE input (casting type, applying hooks, normalization)
    sae_in = self.process_sae_in(x)

    # Gating path exactly as in original SAE.encode_gated
    gating_pre_activation = sae_in @ self.W_enc + self.b_gate
    active_features = (gating_pre_activation > 0).to(self.dtype)

    # Magnitude path (weight sharing with gated encoder)
    magnitude_pre_activation = self.hook_sae_acts_pre(
        sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
    )
    feature_magnitudes = self.activation_fn(magnitude_pre_activation)

    # Combine gating and magnitudes
    return self.hook_sae_acts_post(active_features * feature_magnitudes)

fold_W_dec_norm()

Override to handle gated-specific parameters.

Source code in sae_lens/saes/gated_sae.py
@torch.no_grad()
def fold_W_dec_norm(self):
    """Override to handle gated-specific parameters."""
    W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
    self.W_dec.data = self.W_dec.data / W_dec_norms
    self.W_enc.data = self.W_enc.data * W_dec_norms.T

    # Gated-specific parameters need special handling
    self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
    self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
    self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()

initialize_decoder_norm_constant_norm(norm=0.1)

Initialize decoder with constant norm.

Source code in sae_lens/saes/gated_sae.py
@torch.no_grad()
def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
    """Initialize decoder with constant norm."""
    self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
    self.W_dec.data *= norm

GatedSAEConfig dataclass

Bases: SAEConfig

Configuration class for a GatedSAE.

Source code in sae_lens/saes/gated_sae.py
@dataclass
class GatedSAEConfig(SAEConfig):
    """
    Configuration class for a GatedSAE.
    """

    @override
    @classmethod
    def architecture(cls) -> str:
        return "gated"

GatedTrainingSAE

Bases: TrainingSAE[GatedTrainingSAEConfig]

GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture. It implements:

- initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
- encode: calls encode_with_hidden_pre (standard training approach).
- decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
- encode_with_hidden_pre: gating logic + optional noise injection for training.
- calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
- training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.

Source code in sae_lens/saes/gated_sae.py
class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
    """
    GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
    It implements:
      - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
      - encode: calls encode_with_hidden_pre (standard training approach).
      - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
      - encode_with_hidden_pre: gating logic + optional noise injection for training.
      - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
      - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
    """

    b_gate: nn.Parameter  # type: ignore
    b_mag: nn.Parameter  # type: ignore
    r_mag: nn.Parameter  # type: ignore

    def __init__(self, cfg: GatedTrainingSAEConfig, use_error_term: bool = False):
        if use_error_term:
            raise ValueError(
                "GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
            )
        super().__init__(cfg, use_error_term)

    def initialize_weights(self) -> None:
        super().initialize_weights()
        _init_weights_gated(self)

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        """
        Gated forward pass with pre-activation (for training).
        We also inject noise if self.training is True.
        """
        sae_in = self.process_sae_in(x)

        # Gating path
        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
        active_features = (gating_pre_activation > 0).to(self.dtype)

        # Magnitude path
        magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
        if self.training and self.cfg.noise_scale > 0:
            magnitude_pre_activation += (
                torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale
            )
        magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)

        feature_magnitudes = self.activation_fn(magnitude_pre_activation)

        # Combine gating path and magnitude path
        feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)

        # Return both the final feature activations and the pre-activation (for logging or penalty)
        return feature_acts, magnitude_pre_activation

    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        # Re-center the input if apply_b_dec_to_input is set
        sae_in_centered = step_input.sae_in - (
            self.b_dec * self.cfg.apply_b_dec_to_input
        )

        # The gating pre-activation (pi_gate) for the auxiliary path
        pi_gate = sae_in_centered @ self.W_enc + self.b_gate
        pi_gate_act = torch.relu(pi_gate)

        # L1-like penalty scaled by W_dec norms
        l1_loss = (
            step_input.coefficients["l1"]
            * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
        )

        # Aux reconstruction: reconstruct x purely from gating path
        via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
        aux_recon_loss = (
            (via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean()
        )

        # Return both losses separately
        return {"l1_loss": l1_loss, "auxiliary_reconstruction_loss": aux_recon_loss}

    def log_histograms(self) -> dict[str, NDArray[Any]]:
        """Log histograms of the weights and biases."""
        b_gate_dist = self.b_gate.detach().float().cpu().numpy()
        b_mag_dist = self.b_mag.detach().float().cpu().numpy()
        return {
            **super().log_histograms(),
            "weights/b_gate": b_gate_dist,
            "weights/b_mag": b_mag_dist,
        }

    @torch.no_grad()
    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
        """Initialize decoder with constant norm"""
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
        self.W_dec.data *= norm

    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        return {
            "l1": TrainCoefficientConfig(
                value=self.cfg.l1_coefficient,
                warm_up_steps=self.cfg.l1_warm_up_steps,
            ),
        }

    def to_inference_config_dict(self) -> dict[str, Any]:
        return filter_valid_dataclass_fields(
            self.cfg.to_dict(), GatedSAEConfig, ["architecture"]
        )
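
The two terms returned by calculate_aux_loss above can be written as follows (x is the SAE input, x_c = x - b_dec when apply_b_dec_to_input is set, and lambda is the "l1" coefficient; per the class docstring they are added to the usual MSE reconstruction loss in the training forward pass):

\mathcal{L}_{\text{l1}} = \lambda\, \mathbb{E}_x\Bigl[\sum_j \mathrm{ReLU}\bigl(x_c W_{\text{enc}} + b_{\text{gate}}\bigr)_j\, \lVert W_{\text{dec},j} \rVert_2\Bigr]

\mathcal{L}_{\text{aux}} = \mathbb{E}_x\Bigl[\bigl\lVert \mathrm{ReLU}\bigl(x_c W_{\text{enc}} + b_{\text{gate}}\bigr)\, W_{\text{dec}} + b_{\text{dec}} - x \bigr\rVert_2^2\Bigr]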

encode_with_hidden_pre(x)

Gated forward pass with pre-activation (for training). We also inject noise if self.training is True.

Source code in sae_lens/saes/gated_sae.py
def encode_with_hidden_pre(
    self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
    """
    Gated forward pass with pre-activation (for training).
    We also inject noise if self.training is True.
    """
    sae_in = self.process_sae_in(x)

    # Gating path
    gating_pre_activation = sae_in @ self.W_enc + self.b_gate
    active_features = (gating_pre_activation > 0).to(self.dtype)

    # Magnitude path
    magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
    if self.training and self.cfg.noise_scale > 0:
        magnitude_pre_activation += (
            torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale
        )
    magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)

    feature_magnitudes = self.activation_fn(magnitude_pre_activation)

    # Combine gating path and magnitude path
    feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)

    # Return both the final feature activations and the pre-activation (for logging or penalty)
    return feature_acts, magnitude_pre_activation

initialize_decoder_norm_constant_norm(norm=0.1)

Initialize decoder with constant norm

Source code in sae_lens/saes/gated_sae.py
@torch.no_grad()
def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
    """Initialize decoder with constant norm"""
    self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
    self.W_dec.data *= norm

log_histograms()

Log histograms of the weights and biases.

Source code in sae_lens/saes/gated_sae.py
def log_histograms(self) -> dict[str, NDArray[Any]]:
    """Log histograms of the weights and biases."""
    b_gate_dist = self.b_gate.detach().float().cpu().numpy()
    b_mag_dist = self.b_mag.detach().float().cpu().numpy()
    return {
        **super().log_histograms(),
        "weights/b_gate": b_gate_dist,
        "weights/b_mag": b_mag_dist,
    }

GatedTrainingSAEConfig dataclass

Bases: TrainingSAEConfig

Configuration class for training a GatedTrainingSAE.

Source code in sae_lens/saes/gated_sae.py
@dataclass
class GatedTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a GatedTrainingSAE.
    """

    l1_coefficient: float = 1.0
    l1_warm_up_steps: int = 0

    @override
    @classmethod
    def architecture(cls) -> str:
        return "gated"

HookedSAETransformer

Bases: HookedTransformer

Source code in sae_lens/analysis/hooked_sae_transformer.py
class HookedSAETransformer(HookedTransformer):
    def __init__(
        self,
        *model_args: Any,
        **model_kwargs: Any,
    ):
        """Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

        Note that if you want to load the model from pretrained weights, you should use
        :meth:`from_pretrained` instead.

        Args:
            *model_args: Positional arguments for HookedTransformer initialization
            **model_kwargs: Keyword arguments for HookedTransformer initialization
        """
        super().__init__(*model_args, **model_kwargs)
        self.acts_to_saes: dict[str, SAE] = {}  # type: ignore

    def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
        """Attaches an SAE to the model

        WARNING: This SAE will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.

        Args:
            sae: SparseAutoencoderBase. The SAE to attach to the model
            use_error_term: (bool | None) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
        """
        act_name = sae.cfg.metadata.hook_name
        if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
            logging.warning(
                f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
            )
            return

        if use_error_term is not None:
            if not hasattr(sae, "_original_use_error_term"):
                sae._original_use_error_term = sae.use_error_term  # type: ignore
            sae.use_error_term = use_error_term
        self.acts_to_saes[act_name] = sae
        set_deep_attr(self, act_name, sae)
        self.setup()

    def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None):
        """Resets an SAE that was attached to the model

        By default will remove the SAE from that hook_point.
        If prev_sae is provided, will replace the current SAE with the provided one.
        This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes)

        Args:
            act_name: str. The hook_name of the SAE to reset
            prev_sae: SAE | None. The SAE to replace the current one with. If None, will just remove the SAE from this hook point. Defaults to None
        """
        if act_name not in self.acts_to_saes:
            logging.warning(
                f"No SAE is attached to {act_name}. There's nothing to reset."
            )
            return

        current_sae = self.acts_to_saes[act_name]
        if hasattr(current_sae, "_original_use_error_term"):
            current_sae.use_error_term = current_sae._original_use_error_term  # type: ignore
            delattr(current_sae, "_original_use_error_term")

        if prev_sae:
            set_deep_attr(self, act_name, prev_sae)
            self.acts_to_saes[act_name] = prev_sae
        else:
            set_deep_attr(self, act_name, HookPoint())
            del self.acts_to_saes[act_name]

    def reset_saes(
        self,
        act_names: str | list[str] | None = None,
        prev_saes: list[SAE[Any] | None] | None = None,
    ):
        """Reset the SAEs attached to the model

        If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.
        Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).

        Args:
            act_names (str | list[str] | None): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
            prev_saes (list[SAE | None] | None): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.
        """
        if isinstance(act_names, str):
            act_names = [act_names]
        elif act_names is None:
            act_names = list(self.acts_to_saes.keys())

        if prev_saes:
            if len(act_names) != len(prev_saes):
                raise ValueError("act_names and prev_saes must have the same length")
        else:
            prev_saes = [None] * len(act_names)  # type: ignore

        for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore
            self._reset_sae(act_name, prev_sae)

        self.setup()

    def run_with_saes(
        self,
        *model_args: Any,
        saes: SAE[Any] | list[SAE[Any]] = [],
        reset_saes_end: bool = True,
        use_error_term: bool | None = None,
        **model_kwargs: Any,
    ) -> (
        None
        | Float[torch.Tensor, "batch pos d_vocab"]
        | Loss
        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
    ):
        """Wrapper around HookedTransformer forward pass.

        Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
            reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
            use_error_term: (bool | None) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
            **model_kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(
            saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
        ):
            return self(*model_args, **model_kwargs)

    def run_with_cache_with_saes(
        self,
        *model_args: Any,
        saes: SAE[Any] | list[SAE[Any]] = [],
        reset_saes_end: bool = True,
        use_error_term: bool | None = None,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs: Any,
    ) -> tuple[
        None
        | Float[torch.Tensor, "batch pos d_vocab"]
        | Loss
        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
        ActivationCache | dict[str, torch.Tensor],
    ]:
        """Wrapper around 'run_with_cache' in HookedTransformer.

        Attaches given SAEs before running the model with cache and then removes them.
        By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
            reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
            use_error_term: (bool | None) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
            return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of
                useful HookedTransformer specific methods, otherwise it will return a dictionary of
                activations as in HookedRootModule.
            remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            **kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(
            saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
        ):
            return self.run_with_cache(  # type: ignore
                *model_args,
                return_cache_object=return_cache_object,  # type: ignore
                remove_batch_dim=remove_batch_dim,
                **kwargs,
            )

    def run_with_hooks_with_saes(
        self,
        *model_args: Any,
        saes: SAE[Any] | list[SAE[Any]] = [],
        reset_saes_end: bool = True,
        fwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
        bwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """Wrapper around 'run_with_hooks' in HookedTransformer.

        Attaches the given SAEs to the model before running the model with hooks and then removes them.
        By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
            reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)
            fwd_hooks: (list[tuple[str | Callable, Callable]]) List of forward hooks to apply
            bwd_hooks: (list[tuple[str | Callable, Callable]]) List of backward hooks to apply
            reset_hooks_end: (bool) Whether to reset the hooks at the end of the forward pass (default: True)
            clear_contexts: (bool) Whether to clear the contexts at the end of the forward pass (default: False)
            **model_kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(saes=saes, reset_saes_end=reset_saes_end):
            return self.run_with_hooks(
                *model_args,
                fwd_hooks=fwd_hooks,
                bwd_hooks=bwd_hooks,
                reset_hooks_end=reset_hooks_end,
                clear_contexts=clear_contexts,
                **model_kwargs,
            )

    @contextmanager
    def saes(
        self,
        saes: SAE[Any] | list[SAE[Any]] = [],
        reset_saes_end: bool = True,
        use_error_term: bool | None = None,
    ):
        """
        A context manager for adding temporary SAEs to the model.
        See HookedTransformer.hooks for a similar context manager for hooks.
        By default will keep track of previously attached SAEs, and restore them when the context manager exits.

        Example:

        .. code-block:: python

            from transformer_lens import HookedSAETransformer
            from sae_lens.saes.sae import SAE

            model = HookedSAETransformer.from_pretrained('gpt2-small')
            sae_cfg = SAEConfig(...)
            sae = SAE(sae_cfg)
            with model.saes(saes=[sae]):
                spliced_logits = model(text)


        Args:
            saes (SAE | list[SAE]): SAEs to be attached.
            reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
            use_error_term (bool | None): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
        """
        act_names_to_reset = []
        prev_saes = []
        if isinstance(saes, SAE):
            saes = [saes]
        try:
            for sae in saes:
                act_names_to_reset.append(sae.cfg.metadata.hook_name)
                prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
                prev_saes.append(prev_sae)
                self.add_sae(sae, use_error_term=use_error_term)
            yield self
        finally:
            if reset_saes_end:
                self.reset_saes(act_names_to_reset, prev_saes)
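
A brief sketch of the typical caching workflow with this class (a hedged example: it assumes you already have an SAE instance whose cfg.metadata.hook_name exists in the model's hook_dict; the cache key naming follows the hook_sae_acts_post hook used by the SAE classes above):

from sae_lens import HookedSAETransformer
from sae_lens.saes.sae import SAE

def feature_acts_for_prompt(model: HookedSAETransformer, sae: SAE, prompt: str):
    # Run one forward pass with the SAE spliced in, caching all activations.
    logits, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
    hook_name = sae.cfg.metadata.hook_name
    # Nested hook names follow "<hook_name>.hook_sae_acts_post" for the SAE's output hook.
    return cache[f"{hook_name}.hook_sae_acts_post"]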

__init__(*model_args, **model_kwargs)

Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

Note that if you want to load the model from pretrained weights, you should use :meth:from_pretrained instead.

Parameters:

    *model_args (Any): Positional arguments for HookedTransformer initialization.
    **model_kwargs (Any): Keyword arguments for HookedTransformer initialization.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def __init__(
    self,
    *model_args: Any,
    **model_kwargs: Any,
):
    """Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

    Note that if you want to load the model from pretrained weights, you should use
    :meth:`from_pretrained` instead.

    Args:
        *model_args: Positional arguments for HookedTransformer initialization
        **model_kwargs: Keyword arguments for HookedTransformer initialization
    """
    super().__init__(*model_args, **model_kwargs)
    self.acts_to_saes: dict[str, SAE] = {}  # type: ignore

add_sae(sae, use_error_term=None)

Attaches an SAE to the model

WARNING: This SAE will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.

Parameters:

    sae (SAE[Any]): The SAE to attach to the model. Required.
    use_error_term (bool | None): If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Default: None.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
    """Attaches an SAE to the model

    WARNING: This SAE will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.

    Args:
        sae: SparseAutoencoderBase. The SAE to attach to the model
        use_error_term: (bool | None) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
    """
    act_name = sae.cfg.metadata.hook_name
    if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
        logging.warning(
            f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
        )
        return

    if use_error_term is not None:
        if not hasattr(sae, "_original_use_error_term"):
            sae._original_use_error_term = sae.use_error_term  # type: ignore
        sae.use_error_term = use_error_term
    self.acts_to_saes[act_name] = sae
    set_deep_attr(self, act_name, sae)
    self.setup()

reset_saes(act_names=None, prev_saes=None)

Reset the SAEs attached to the model

If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model. Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).

Parameters:

    act_names (str | list[str] | None): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Default: None.
    prev_saes (list[SAE | None] | None): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Default: None.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def reset_saes(
    self,
    act_names: str | list[str] | None = None,
    prev_saes: list[SAE[Any] | None] | None = None,
):
    """Reset the SAEs attached to the model

    If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.
    Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).

    Args:
        act_names (str | list[str] | None): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
        prev_saes (list[SAE | None] | None): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.
    """
    if isinstance(act_names, str):
        act_names = [act_names]
    elif act_names is None:
        act_names = list(self.acts_to_saes.keys())

    if prev_saes:
        if len(act_names) != len(prev_saes):
            raise ValueError("act_names and prev_saes must have the same length")
    else:
        prev_saes = [None] * len(act_names)  # type: ignore

    for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore
        self._reset_sae(act_name, prev_sae)

    self.setup()
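
For persistent attachment, as opposed to the temporary context-manager or run_with_* variants below, the pattern is roughly the following sketch (model and sae are assumed to exist already):

def run_with_persistent_sae(model, sae, prompt: str):
    model.add_sae(sae)             # attached at sae.cfg.metadata.hook_name until reset
    logits = model(prompt)         # every forward pass now routes through the SAE
    model.reset_saes()             # detach all SAEs (or pass specific hook names)
    return logits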

run_with_cache_with_saes(*model_args, saes=[], reset_saes_end=True, use_error_term=None, return_cache_object=True, remove_batch_dim=False, **kwargs)

Wrapper around 'run_with_cache' in HookedTransformer.

Attaches given SAEs before running the model with cache and then removes them. By default, will reset all SAEs to original state after.

Parameters:

    *model_args (Any): Positional arguments for the model forward pass.
    saes (SAE[Any] | list[SAE[Any]]): The SAEs to be attached for this forward pass. Default: [].
    reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default: True.
    use_error_term (bool | None): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Determines whether the SAE returns input or reconstruction. Default: None.
    return_cache_object (bool): If True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. Default: True.
    remove_batch_dim (bool): Whether to remove the batch dimension (only works for batch_size==1). Default: False.
    **kwargs (Any): Keyword arguments for the model forward pass.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_cache_with_saes(
    self,
    *model_args: Any,
    saes: SAE[Any] | list[SAE[Any]] = [],
    reset_saes_end: bool = True,
    use_error_term: bool | None = None,
    return_cache_object: bool = True,
    remove_batch_dim: bool = False,
    **kwargs: Any,
) -> tuple[
    None
    | Float[torch.Tensor, "batch pos d_vocab"]
    | Loss
    | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
    ActivationCache | dict[str, torch.Tensor],
]:
    """Wrapper around 'run_with_cache' in HookedTransformer.

    Attaches given SAEs before running the model with cache and then removes them.
    By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
        reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
        use_error_term: (bool | None) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
        return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of
            useful HookedTransformer specific methods, otherwise it will return a dictionary of
            activations as in HookedRootModule.
        remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
        **kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(
        saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
    ):
        return self.run_with_cache(  # type: ignore
            *model_args,
            return_cache_object=return_cache_object,  # type: ignore
            remove_batch_dim=remove_batch_dim,
            **kwargs,
        )

run_with_hooks_with_saes(*model_args, saes=[], reset_saes_end=True, fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts=False, **model_kwargs)

Wrapper around 'run_with_hooks' in HookedTransformer.

Attaches the given SAEs to the model before running the model with hooks and then removes them. By default, will reset all SAEs to original state after.

Parameters:

    *model_args (Any): Positional arguments for the model forward pass.
    saes (SAE[Any] | list[SAE[Any]]): The SAEs to be attached for this forward pass. Default: [].
    reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default: True.
    fwd_hooks (list[tuple[str | Callable, Callable]]): List of forward hooks to apply. Default: [].
    bwd_hooks (list[tuple[str | Callable, Callable]]): List of backward hooks to apply. Default: [].
    reset_hooks_end (bool): Whether to reset the hooks at the end of the forward pass. Default: True.
    clear_contexts (bool): Whether to clear the contexts at the end of the forward pass. Default: False.
    **model_kwargs (Any): Keyword arguments for the model forward pass.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_hooks_with_saes(
    self,
    *model_args: Any,
    saes: SAE[Any] | list[SAE[Any]] = [],
    reset_saes_end: bool = True,
    fwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
    bwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    **model_kwargs: Any,
):
    """Wrapper around 'run_with_hooks' in HookedTransformer.

    Attaches the given SAEs to the model before running the model with hooks and then removes them.
    By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
        reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)
        fwd_hooks: (list[tuple[str | Callable, Callable]]) List of forward hooks to apply
        bwd_hooks: (list[tuple[str | Callable, Callable]]) List of backward hooks to apply
        reset_hooks_end: (bool) Whether to reset the hooks at the end of the forward pass (default: True)
        clear_contexts: (bool) Whether to clear the contexts at the end of the forward pass (default: False)
        **model_kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(saes=saes, reset_saes_end=reset_saes_end):
        return self.run_with_hooks(
            *model_args,
            fwd_hooks=fwd_hooks,
            bwd_hooks=bwd_hooks,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
            **model_kwargs,
        )

run_with_saes(*model_args, saes=[], reset_saes_end=True, use_error_term=None, **model_kwargs)

Wrapper around HookedTransformer forward pass.

Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

Parameters:

    *model_args (Any): Positional arguments for the model forward pass.
    saes (SAE[Any] | list[SAE[Any]]): The SAEs to be attached for this forward pass. Default: [].
    reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default: True.
    use_error_term (bool | None): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Default: None.
    **model_kwargs (Any): Keyword arguments for the model forward pass.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_saes(
    self,
    *model_args: Any,
    saes: SAE[Any] | list[SAE[Any]] = [],
    reset_saes_end: bool = True,
    use_error_term: bool | None = None,
    **model_kwargs: Any,
) -> (
    None
    | Float[torch.Tensor, "batch pos d_vocab"]
    | Loss
    | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
):
    """Wrapper around HookedTransformer forward pass.

    Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (SAE | list[SAE]) The SAEs to be attached for this forward pass
        reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
        use_error_term: (bool | None) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
        **model_kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(
        saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
    ):
        return self(*model_args, **model_kwargs)

saes(saes=[], reset_saes_end=True, use_error_term=None)

A context manager for adding temporary SAEs to the model. See HookedTransformer.hooks for a similar context manager for hooks. By default will keep track of previously attached SAEs, and restore them when the context manager exits.

Example:

from transformer_lens import HookedSAETransformer
from sae_lens.saes.sae import SAE

model = HookedSAETransformer.from_pretrained('gpt2-small')
sae_cfg = SAEConfig(...)
sae = SAE(sae_cfg)
with model.saes(saes=[sae]):
    spliced_logits = model(text)

Parameters:

    saes (SAE | list[SAE]): SAEs to be attached. Default: [].
    reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state. Default: True.
    use_error_term (bool | None): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Default: None.
Source code in sae_lens/analysis/hooked_sae_transformer.py
@contextmanager
def saes(
    self,
    saes: SAE[Any] | list[SAE[Any]] = [],
    reset_saes_end: bool = True,
    use_error_term: bool | None = None,
):
    """
    A context manager for adding temporary SAEs to the model.
    See HookedTransformer.hooks for a similar context manager for hooks.
    By default will keep track of previously attached SAEs, and restore them when the context manager exits.

    Example:

    .. code-block:: python

        from transformer_lens import HookedSAETransformer
        from sae_lens.saes.sae import SAE

        model = HookedSAETransformer.from_pretrained('gpt2-small')
        sae_cfg = SAEConfig(...)
        sae = SAE(sae_cfg)
        with model.saes(saes=[sae]):
            spliced_logits = model(text)


    Args:
        saes (SAE | list[SAE]): SAEs to be attached.
        reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
        use_error_term (bool | None): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
    """
    act_names_to_reset = []
    prev_saes = []
    if isinstance(saes, SAE):
        saes = [saes]
    try:
        for sae in saes:
            act_names_to_reset.append(sae.cfg.metadata.hook_name)
            prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
            prev_saes.append(prev_sae)
            self.add_sae(sae, use_error_term=use_error_term)
        yield self
    finally:
        if reset_saes_end:
            self.reset_saes(act_names_to_reset, prev_saes)

JumpReLUSAE

Bases: SAE[JumpReLUSAEConfig]

JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE) using a JumpReLU activation. For each unit, if its pre-activation is <= threshold, that unit is zeroed out; otherwise, it follows a user-specified activation function (e.g., ReLU, tanh-relu, etc.).

It implements
  • initialize_weights: sets up parameters, including a threshold.
  • encode: computes the feature activations using JumpReLU.
  • decode: reconstructs the input from the feature activations.

The BaseSAE.forward() method automatically calls encode and decode, including any error-term processing if configured.
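
A minimal plain-PyTorch sketch of the gating described above (not using the class below): with a ReLU base activation, any unit whose pre-activation does not exceed its threshold is zeroed.

import torch

pre = torch.tensor([-1.0, 0.2, 0.7, 3.0])
threshold = torch.tensor([0.0, 0.5, 0.5, 0.5])

base = torch.relu(pre)                   # base activation function
mask = (pre > threshold).to(base.dtype)  # gate: keep only units above threshold
acts = base * mask                       # tensor([0.0, 0.0, 0.7, 3.0])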

Source code in sae_lens/saes/jumprelu_sae.py
class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
    """
    JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
    using a JumpReLU activation. For each unit, if its pre-activation is
    <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
    activation function (e.g., ReLU, tanh-relu, etc.).

    It implements:
      - initialize_weights: sets up parameters, including a threshold.
      - encode: computes the feature activations using JumpReLU.
      - decode: reconstructs the input from the feature activations.

    The BaseSAE.forward() method automatically calls encode and decode,
    including any error-term processing if configured.
    """

    b_enc: nn.Parameter
    threshold: nn.Parameter

    def __init__(self, cfg: JumpReLUSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        super().initialize_weights()
        self.threshold = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Encode the input tensor into the feature space using JumpReLU.
        The threshold parameter determines which units remain active.
        """
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

        # 1) Apply the base "activation_fn" from config (e.g., ReLU, tanh-relu).
        base_acts = self.activation_fn(hidden_pre)

        # 2) Zero out any unit whose (hidden_pre <= threshold).
        #    We cast the boolean mask to the same dtype for safe multiplication.
        jump_relu_mask = (hidden_pre > self.threshold).to(base_acts.dtype)

        # 3) Multiply the normally activated units by that mask.
        return self.hook_sae_acts_post(base_acts * jump_relu_mask)

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Decode the feature activations back to the input space.
        Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
        """
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """
        Override to properly handle threshold adjustment with W_dec norms.
        When we scale the encoder weights, we need to scale the threshold
        by the same factor to maintain the same sparsity pattern.
        """
        # Save the current threshold before calling parent method
        current_thresh = self.threshold.clone()

        # Get W_dec norms that will be used for scaling
        W_dec_norms = self.W_dec.norm(dim=-1)

        # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
        super().fold_W_dec_norm()

        # Scale the threshold by the same factor as we scaled b_enc
        # This ensures the same features remain active/inactive after folding
        self.threshold.data = current_thresh * W_dec_norms

decode(feature_acts)

Decode the feature activations back to the input space. Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.

Source code in sae_lens/saes/jumprelu_sae.py
def decode(
    self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_in"]:
    """
    Decode the feature activations back to the input space.
    Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
    """
    sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode(x)

Encode the input tensor into the feature space using JumpReLU. The threshold parameter determines which units remain active.

Source code in sae_lens/saes/jumprelu_sae.py
def encode(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """
    Encode the input tensor into the feature space using JumpReLU.
    The threshold parameter determines which units remain active.
    """
    sae_in = self.process_sae_in(x)
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

    # 1) Apply the base "activation_fn" from config (e.g., ReLU, tanh-relu).
    base_acts = self.activation_fn(hidden_pre)

    # 2) Zero out any unit whose (hidden_pre <= threshold).
    #    We cast the boolean mask to the same dtype for safe multiplication.
    jump_relu_mask = (hidden_pre > self.threshold).to(base_acts.dtype)

    # 3) Multiply the normally activated units by that mask.
    return self.hook_sae_acts_post(base_acts * jump_relu_mask)

fold_W_dec_norm()

Override to properly handle threshold adjustment with W_dec norms. When we scale the encoder weights, we need to scale the threshold by the same factor to maintain the same sparsity pattern.

Source code in sae_lens/saes/jumprelu_sae.py
@torch.no_grad()
def fold_W_dec_norm(self):
    """
    Override to properly handle threshold adjustment with W_dec norms.
    When we scale the encoder weights, we need to scale the threshold
    by the same factor to maintain the same sparsity pattern.
    """
    # Save the current threshold before calling parent method
    current_thresh = self.threshold.clone()

    # Get W_dec norms that will be used for scaling
    W_dec_norms = self.W_dec.norm(dim=-1)

    # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
    super().fold_W_dec_norm()

    # Scale the threshold by the same factor as we scaled b_enc
    # This ensures the same features remain active/inactive after folding
    self.threshold.data = current_thresh * W_dec_norms
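
To see why the threshold is rescaled, here is a minimal plain-tensor sketch mirroring what the base fold does to W_enc and b_enc: pre-activations scale by the per-feature decoder norm, so scaling the threshold by the same factor keeps the set of active units unchanged.

import torch

d_in, d_sae = 4, 3
W_enc = torch.randn(d_in, d_sae)
W_dec = torch.randn(d_sae, d_in)
b_enc = torch.randn(d_sae)
threshold = torch.rand(d_sae) + 0.01
x = torch.randn(d_in)

pre = x @ W_enc + b_enc
active_before = pre > threshold

norms = W_dec.norm(dim=-1)                          # ||W_dec_i|| per feature
pre_folded = x @ (W_enc * norms) + b_enc * norms    # pre-activations scale by norms
active_after = pre_folded > threshold * norms       # threshold scaled identically

assert torch.equal(active_before, active_after)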

JumpReLUSAEConfig dataclass

Bases: SAEConfig

Configuration class for a JumpReLUSAE.

Source code in sae_lens/saes/jumprelu_sae.py
@dataclass
class JumpReLUSAEConfig(SAEConfig):
    """
    Configuration class for a JumpReLUSAE.
    """

    @override
    @classmethod
    def architecture(cls) -> str:
        return "jumprelu"

JumpReLUTrainingSAE

Bases: TrainingSAE[JumpReLUTrainingSAEConfig]

JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.

Similar to the inference-only JumpReLUSAE, but with:
  - A learnable log-threshold parameter (instead of a raw threshold).
  - Forward passes that add noise during training, if configured.
  - A specialized auxiliary loss term for sparsity (L0 or similar).

Methods of interest include:
  - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
  - encode_with_hidden_pre_jumprelu: runs a forward pass for training, optionally adding noise.
  - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.

Source code in sae_lens/saes/jumprelu_sae.py
class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
    """
    JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.

    Similar to the inference-only JumpReLUSAE, but with:
      - A learnable log-threshold parameter (instead of a raw threshold).
      - Forward passes that add noise during training, if configured.
      - A specialized auxiliary loss term for sparsity (L0 or similar).

    Methods of interest include:
    - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
    - encode_with_hidden_pre_jumprelu: runs a forward pass for training, optionally adding noise.
    - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
    """

    b_enc: nn.Parameter
    log_threshold: nn.Parameter

    def __init__(self, cfg: JumpReLUTrainingSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

        # We'll store a bandwidth for the training approach, if needed
        self.bandwidth = cfg.jumprelu_bandwidth

        # In typical JumpReLU training code, we may track a log_threshold:
        self.log_threshold = nn.Parameter(
            torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
            * np.log(cfg.jumprelu_init_threshold)
        )

    @override
    def initialize_weights(self) -> None:
        """
        Initialize parameters like the base SAE, but also add log_threshold.
        """
        super().initialize_weights()
        # Encoder Bias
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

    @property
    def threshold(self) -> torch.Tensor:
        """
        Returns the parameterized threshold > 0 for each unit.
        threshold = exp(log_threshold).
        """
        return torch.exp(self.log_threshold)

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        sae_in = self.process_sae_in(x)

        hidden_pre = sae_in @ self.W_enc + self.b_enc

        if self.training and self.cfg.noise_scale > 0:
            hidden_pre = (
                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
            )

        feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)

        return feature_acts, hidden_pre  # type: ignore

    @override
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Calculate architecture-specific auxiliary loss terms."""
        l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1)  # type: ignore
        l0_loss = (step_input.coefficients["l0"] * l0).mean()
        return {"l0_loss": l0_loss}

    @override
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        return {
            "l0": TrainCoefficientConfig(
                value=self.cfg.l0_coefficient,
                warm_up_steps=self.cfg.l0_warm_up_steps,
            ),
        }

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """
        Override to properly handle threshold adjustment with W_dec norms.
        """
        # Save the current threshold before we call the parent method
        current_thresh = self.threshold.clone()

        # Get W_dec norms
        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)

        # Call parent implementation to handle W_enc and W_dec adjustment
        super().fold_W_dec_norm()

        # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
        self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())

    def _create_train_step_output(
        self,
        sae_in: torch.Tensor,
        sae_out: torch.Tensor,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        loss: torch.Tensor,
        losses: dict[str, torch.Tensor],
    ) -> TrainStepOutput:
        """
        Helper to produce a TrainStepOutput from the trainer.
        The old code expects a method named _create_train_step_output().
        """
        return TrainStepOutput(
            sae_in=sae_in,
            sae_out=sae_out,
            feature_acts=feature_acts,
            hidden_pre=hidden_pre,
            loss=loss,
            losses=losses,
        )

    @torch.no_grad()
    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
        """Initialize decoder with constant norm"""
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
        self.W_dec.data *= norm

    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
        """Convert log_threshold to threshold for saving"""
        if "log_threshold" in state_dict:
            threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
            del state_dict["log_threshold"]
            state_dict["threshold"] = threshold

    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
        """Convert threshold to log_threshold for loading"""
        if "threshold" in state_dict:
            threshold = state_dict["threshold"]
            del state_dict["threshold"]
            state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()

    def to_inference_config_dict(self) -> dict[str, Any]:
        return filter_valid_dataclass_fields(
            self.cfg.to_dict(), JumpReLUSAEConfig, ["architecture"]
        )

threshold: torch.Tensor property

Returns the parameterized threshold > 0 for each unit. threshold = exp(log_threshold).

calculate_aux_loss(step_input, feature_acts, hidden_pre, sae_out)

Calculate architecture-specific auxiliary loss terms.

Source code in sae_lens/saes/jumprelu_sae.py
@override
def calculate_aux_loss(
    self,
    step_input: TrainStepInput,
    feature_acts: torch.Tensor,
    hidden_pre: torch.Tensor,
    sae_out: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Calculate architecture-specific auxiliary loss terms."""
    l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1)  # type: ignore
    l0_loss = (step_input.coefficients["l0"] * l0).mean()
    return {"l0_loss": l0_loss}

fold_W_dec_norm()

Override to properly handle threshold adjustment with W_dec norms.

Source code in sae_lens/saes/jumprelu_sae.py
@torch.no_grad()
def fold_W_dec_norm(self):
    """
    Override to properly handle threshold adjustment with W_dec norms.
    """
    # Save the current threshold before we call the parent method
    current_thresh = self.threshold.clone()

    # Get W_dec norms
    W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)

    # Call parent implementation to handle W_enc and W_dec adjustment
    super().fold_W_dec_norm()

    # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
    self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())

initialize_decoder_norm_constant_norm(norm=0.1)

Initialize decoder with constant norm

Source code in sae_lens/saes/jumprelu_sae.py
@torch.no_grad()
def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
    """Initialize decoder with constant norm"""
    self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
    self.W_dec.data *= norm

initialize_weights()

Initialize parameters like the base SAE, but also add log_threshold.

Source code in sae_lens/saes/jumprelu_sae.py
@override
def initialize_weights(self) -> None:
    """
    Initialize parameters like the base SAE, but also add log_threshold.
    """
    super().initialize_weights()
    # Encoder Bias
    self.b_enc = nn.Parameter(
        torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
    )

process_state_dict_for_loading(state_dict)

Convert threshold to log_threshold for loading

Source code in sae_lens/saes/jumprelu_sae.py
def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
    """Convert threshold to log_threshold for loading"""
    if "threshold" in state_dict:
        threshold = state_dict["threshold"]
        del state_dict["threshold"]
        state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()

process_state_dict_for_saving(state_dict)

Convert log_threshold to threshold for saving

Source code in sae_lens/saes/jumprelu_sae.py
def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
    """Convert log_threshold to threshold for saving"""
    if "log_threshold" in state_dict:
        threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
        del state_dict["log_threshold"]
        state_dict["threshold"] = threshold

JumpReLUTrainingSAEConfig dataclass

Bases: TrainingSAEConfig

Configuration class for training a JumpReLUTrainingSAE.

Source code in sae_lens/saes/jumprelu_sae.py
@dataclass
class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a JumpReLUTrainingSAE.
    """

    jumprelu_init_threshold: float = 0.001
    jumprelu_bandwidth: float = 0.001
    l0_coefficient: float = 1.0
    l0_warm_up_steps: int = 0

    @override
    @classmethod
    def architecture(cls) -> str:
        return "jumprelu"

LanguageModelSAERunnerConfig dataclass

Bases: Generic[T_TRAINING_SAE_CONFIG]

Configuration for training a sparse autoencoder on a language model.

Parameters:

Name Type Description Default
sae T_TRAINING_SAE_CONFIG

The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).

required
model_name str

The name of the model to use. This should be the name of the model in the Hugging Face model hub.

'gelu-2l'
model_class_name str

The name of the class of the model to use. This should be either HookedTransformer or HookedMamba.

'HookedTransformer'
hook_name str

The name of the hook to use. This should be a valid TransformerLens hook.

'blocks.0.hook_mlp_out'
hook_eval str

NOT CURRENTLY IN USE. The name of the hook to use for evaluation.

'NOT_IN_USE'
hook_layer int

The index of the layer to hook. Used to stop forward passes early and speed up processing.

0
hook_head_index int

When the hook is for an activation with a head index, we can specify a specific head to use here.

None
dataset_path str

A Hugging Face dataset path.

''
dataset_trust_remote_code bool

Whether to trust remote code when loading datasets from Huggingface.

True
streaming bool

Whether to stream the dataset. Streaming large datasets is usually practical.

True
is_dataset_tokenized bool

Whether the dataset is already tokenized.

True
context_size int

The context size to use when generating activations on which to train the SAE.

128
use_cached_activations bool

Whether to use cached activations. This is useful when doing sweeps over the same activations.

False
cached_activations_path str

The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".

None
from_pretrained_path str

The path to a pretrained SAE. We can finetune an existing SAE if needed.

None
n_batches_in_buffer int

The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.

20
training_tokens int

The number of training tokens.

2000000
store_batch_size_prompts int

The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.

32
seqpos_slice tuple[int | None, ...]

Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.

(None,)
device str

The device to use. Usually "cuda".

'cpu'
act_store_device str

The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.

'with_model'
seed int

The seed to use.

42
dtype str

The data type to use for the SAE and activations.

'float32'
prepend_bos bool

Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.

True
autocast bool

Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.

False
autocast_lm bool

Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.

False
compile_llm bool

Whether to compile the LLM using torch.compile.

False
llm_compilation_mode str

The compilation mode to use for the LLM if compile_llm is True.

None
compile_sae bool

Whether to compile the SAE using torch.compile.

False
sae_compilation_mode str

The compilation mode to use for the SAE if compile_sae is True.

None
train_batch_size_tokens int

The batch size for training, in tokens. This controls the batch size of the SAE training loop.

4096
adam_beta1 float

The beta1 parameter for the Adam optimizer.

0.0
adam_beta2 float

The beta2 parameter for the Adam optimizer.

0.999
lr float

The learning rate.

0.0003
lr_scheduler_name str

The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").

'constant'
lr_warm_up_steps int

The number of warm-up steps for the learning rate.

0
lr_end float

The end learning rate if using a scheduler like cosine annealing. Defaults to lr / 10.

None
lr_decay_steps int

The number of decay steps for the learning rate if using a scheduler with decay.

0
n_restart_cycles int

The number of restart cycles for the cosine annealing with warm restarts scheduler.

1
dead_feature_window int

The window size (in training steps) for detecting dead features.

1000
feature_sampling_window int

The window size (in training steps) for resampling features (e.g. dead features).

2000
dead_feature_threshold float

The threshold below which a feature's activation frequency is considered dead.

1e-08
n_eval_batches int

The number of batches to use for evaluation.

10
eval_batch_size_prompts int

The batch size for evaluation, in prompts. Useful if evals cause OOM.

None
logger LoggingConfig

Configuration for logging (e.g. W&B).

LoggingConfig()
n_checkpoints int

The number of checkpoints to save during training. 0 means no checkpoints.

0
checkpoint_path str

The path to save checkpoints. A unique ID will be appended to this path.

'checkpoints'
verbose bool

Whether to print verbose output.

True
model_kwargs dict[str, Any]

Keyword arguments for model.run_with_cache

dict_field(default={})
model_from_pretrained_kwargs dict[str, Any]

Additional keyword arguments to pass to the model's from_pretrained method.

dict_field(default=None)
sae_lens_version str

The version of the sae_lens library.

lambda: __version__()
sae_lens_training_version str

The version of the sae_lens training library.

lambda: __version__()
exclude_special_tokens bool | list[int]

Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.

False
Source code in sae_lens/config.py
@dataclass
class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
    """
    Configuration for training a sparse autoencoder on a language model.

    Args:
        sae (T_TRAINING_SAE_CONFIG): The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).
        model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
        model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
        hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
        hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
        hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
        hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
        dataset_path (str): A Hugging Face dataset path.
        dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
        streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
        is_dataset_tokenized (bool): Whether the dataset is already tokenized.
        context_size (int): The context size to use when generating activations on which to train the SAE.
        use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
        cached_activations_path (str, optional): The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".
        from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
        training_tokens (int): The number of training tokens.
        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
        seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
        device (str): The device to use. Usually "cuda".
        act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
        seed (int): The seed to use.
        dtype (str): The data type to use for the SAE and activations.
        prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
        autocast (bool): Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.
        autocast_lm (bool): Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.
        compile_llm (bool): Whether to compile the LLM using `torch.compile`.
        llm_compilation_mode (str, optional): The compilation mode to use for the LLM if `compile_llm` is True.
        compile_sae (bool): Whether to compile the SAE using `torch.compile`.
        sae_compilation_mode (str, optional): The compilation mode to use for the SAE if `compile_sae` is True.
        train_batch_size_tokens (int): The batch size for training, in tokens. This controls the batch size of the SAE training loop.
        adam_beta1 (float): The beta1 parameter for the Adam optimizer.
        adam_beta2 (float): The beta2 parameter for the Adam optimizer.
        lr (float): The learning rate.
        lr_scheduler_name (str): The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").
        lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
        lr_end (float, optional): The end learning rate if using a scheduler like cosine annealing. Defaults to `lr / 10`.
        lr_decay_steps (int): The number of decay steps for the learning rate if using a scheduler with decay.
        n_restart_cycles (int): The number of restart cycles for the cosine annealing with warm restarts scheduler.
        dead_feature_window (int): The window size (in training steps) for detecting dead features.
        feature_sampling_window (int): The window size (in training steps) for resampling features (e.g. dead features).
        dead_feature_threshold (float): The threshold below which a feature's activation frequency is considered dead.
        n_eval_batches (int): The number of batches to use for evaluation.
        eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
        logger (LoggingConfig): Configuration for logging (e.g. W&B).
        n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
        checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
        verbose (bool): Whether to print verbose output.
        model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
        model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
        sae_lens_version (str): The version of the sae_lens library.
        sae_lens_training_version (str): The version of the sae_lens training library.
        exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.
    """

    sae: T_TRAINING_SAE_CONFIG

    # Data Generating Function (Model + Training Distribution)
    model_name: str = "gelu-2l"
    model_class_name: str = "HookedTransformer"
    hook_name: str = "blocks.0.hook_mlp_out"
    hook_eval: str = "NOT_IN_USE"
    hook_layer: int = 0
    hook_head_index: int | None = None
    dataset_path: str = ""
    dataset_trust_remote_code: bool = True
    streaming: bool = True
    is_dataset_tokenized: bool = True
    context_size: int = 128
    use_cached_activations: bool = False
    cached_activations_path: str | None = (
        None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}"
    )

    # SAE Parameters
    from_pretrained_path: str | None = None

    # Activation Store Parameters
    n_batches_in_buffer: int = 20
    training_tokens: int = 2_000_000
    store_batch_size_prompts: int = 32
    seqpos_slice: tuple[int | None, ...] = (None,)

    # Misc
    device: str = "cpu"
    act_store_device: str = "with_model"  # will be set by post init if with_model
    seed: int = 42
    dtype: str = "float32"  # type: ignore #
    prepend_bos: bool = True

    # Performance - see compilation section of lm_runner.py for info
    autocast: bool = False  # autocast to autocast_dtype during training
    autocast_lm: bool = False  # autocast lm during activation fetching
    compile_llm: bool = False  # use torch.compile on the LLM
    llm_compilation_mode: str | None = None  # which torch.compile mode to use
    compile_sae: bool = False  # use torch.compile on the SAE
    sae_compilation_mode: str | None = None

    # Training Parameters

    ## Batch size
    train_batch_size_tokens: int = 4096

    ## Adam
    adam_beta1: float = 0.0
    adam_beta2: float = 0.999

    ## Learning Rate Schedule
    lr: float = 3e-4
    lr_scheduler_name: str = (
        "constant"  # constant, cosineannealing, cosineannealingwarmrestarts
    )
    lr_warm_up_steps: int = 0
    lr_end: float | None = None  # only used for cosine annealing, default is lr / 10
    lr_decay_steps: int = 0
    n_restart_cycles: int = 1  # used only for cosineannealingwarmrestarts

    # Resampling protocol args
    dead_feature_window: int = 1000  # unless this window is larger than feature_sampling_window
    feature_sampling_window: int = 2000
    dead_feature_threshold: float = 1e-8

    # Evals
    n_eval_batches: int = 10
    eval_batch_size_prompts: int | None = None  # useful if evals cause OOM

    logger: LoggingConfig = field(default_factory=LoggingConfig)

    # Misc
    n_checkpoints: int = 0
    checkpoint_path: str = "checkpoints"
    verbose: bool = True
    model_kwargs: dict[str, Any] = dict_field(default={})
    model_from_pretrained_kwargs: dict[str, Any] | None = dict_field(default=None)
    sae_lens_version: str = field(default_factory=lambda: __version__)
    sae_lens_training_version: str = field(default_factory=lambda: __version__)
    exclude_special_tokens: bool | list[int] = False

    def __post_init__(self):
        if self.use_cached_activations and self.cached_activations_path is None:
            self.cached_activations_path = _default_cached_activations_path(
                self.dataset_path,
                self.model_name,
                self.hook_name,
                self.hook_head_index,
            )
        self.tokens_per_buffer = (
            self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
        )

        if self.logger.run_name is None:
            self.logger.run_name = f"{self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"

        if self.model_from_pretrained_kwargs is None:
            if self.model_class_name == "HookedTransformer":
                self.model_from_pretrained_kwargs = {"center_writing_weights": False}
            else:
                self.model_from_pretrained_kwargs = {}

        if self.act_store_device == "with_model":
            self.act_store_device = self.device

        if self.lr_end is None:
            self.lr_end = self.lr / 10

        unique_id = self.logger.wandb_id
        if unique_id is None:
            unique_id = cast(
                Any, wandb
            ).util.generate_id()  # not sure why this type is erroring
        self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

        if self.verbose:
            logger.info(
                f"Run name: {self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
            )
            # Print out some useful info:
            n_tokens_per_buffer = (
                self.store_batch_size_prompts
                * self.context_size
                * self.n_batches_in_buffer
            )
            logger.info(
                f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10**6}"
            )
            n_contexts_per_buffer = (
                self.store_batch_size_prompts * self.n_batches_in_buffer
            )
            logger.info(
                f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10**6}"
            )

            total_training_steps = (
                self.training_tokens
            ) // self.train_batch_size_tokens
            logger.info(f"Total training steps: {total_training_steps}")

            total_wandb_updates = (
                total_training_steps // self.logger.wandb_log_frequency
            )
            logger.info(f"Total wandb updates: {total_wandb_updates}")

            # how many times will we sample dead neurons?
            # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
            n_feature_window_samples = (
                total_training_steps // self.feature_sampling_window
            )
            logger.info(
                f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size_tokens) / 10**6}"
            )
            logger.info(
                f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size_tokens) / 10**6}"
            )
            logger.info(
                f"We will reset the sparsity calculation {n_feature_window_samples} times."
            )
            # logger.info("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size_tokens)
            logger.info(
                f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
            )

        if self.context_size < 0:
            raise ValueError(
                f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
            )

        _validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)

        if isinstance(self.exclude_special_tokens, list) and not all(
            isinstance(x, int) for x in self.exclude_special_tokens
        ):
            raise ValueError("exclude_special_tokens list must contain only integers")

    @property
    def total_training_tokens(self) -> int:
        return self.training_tokens

    @property
    def total_training_steps(self) -> int:
        return self.total_training_tokens // self.train_batch_size_tokens

    def get_training_sae_cfg_dict(self) -> dict[str, Any]:
        return self.sae.to_dict()

    def to_dict(self) -> dict[str, Any]:
        # Make a shallow copy of config's dictionary
        d = dict(self.__dict__)

        d["logger"] = asdict(self.logger)
        d["sae"] = self.sae.to_dict()
        # Overwrite fields that might not be JSON-serializable
        d["dtype"] = str(self.dtype)
        d["device"] = str(self.device)
        d["act_store_device"] = str(self.act_store_device)
        return d

    def to_json(self, path: str) -> None:
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path))

        with open(path + "cfg.json", "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
        with open(path + "cfg.json") as f:
            cfg = json.load(f)

        # Ensure seqpos_slice is a tuple
        if "seqpos_slice" in cfg:
            if isinstance(cfg["seqpos_slice"], list):
                cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"])
            elif not isinstance(cfg["seqpos_slice"], tuple):
                cfg["seqpos_slice"] = (cfg["seqpos_slice"],)

        return cls(**cfg)
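
A minimal, illustrative construction using the fields listed above (the dataset path is a placeholder, and sensible values for d_in, d_sae, hook_name, etc. depend on the model and dataset):

from sae_lens import LanguageModelSAERunnerConfig
from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig

cfg = LanguageModelSAERunnerConfig(
    sae=JumpReLUTrainingSAEConfig(d_in=768, d_sae=768 * 16),
    model_name="gpt2",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    dataset_path="<your-tokenized-hf-dataset>",
    is_dataset_tokenized=True,
    context_size=128,
    training_tokens=2_000_000,
    train_batch_size_tokens=4096,
    device="cuda",
)
print(cfg.total_training_steps)  # training_tokens // train_batch_size_tokens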

PretokenizeRunner

Runner to pretokenize a dataset using a given tokenizer, and optionally upload to Huggingface.

Source code in sae_lens/pretokenize_runner.py
class PretokenizeRunner:
    """
    Runner to pretokenize a dataset using a given tokenizer, and optionally upload to Huggingface.
    """

    def __init__(self, cfg: PretokenizeRunnerConfig):
        self.cfg = cfg

    def run(self):
        """
        Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
        """
        dataset = load_dataset(
            self.cfg.dataset_path,
            name=self.cfg.dataset_name,
            data_dir=self.cfg.data_dir,
            data_files=self.cfg.data_files,
            split=self.cfg.split,
            streaming=self.cfg.streaming,
        )
        if isinstance(dataset, DatasetDict):
            raise ValueError(
                "Dataset has multiple splits. Must provide a 'split' param."
            )
        tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
        tokenizer.model_max_length = sys.maxsize
        tokenized_dataset = pretokenize_dataset(
            cast(Dataset, dataset), tokenizer, self.cfg
        )

        if self.cfg.save_path is not None:
            tokenized_dataset.save_to_disk(self.cfg.save_path)
            metadata = metadata_from_config(self.cfg)
            metadata_path = Path(self.cfg.save_path) / "sae_lens.json"
            with open(metadata_path, "w") as f:
                json.dump(metadata.__dict__, f, indent=2, ensure_ascii=False)

        if self.cfg.hf_repo_id is not None:
            push_to_hugging_face_hub(tokenized_dataset, self.cfg)

        return tokenized_dataset

run()

Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.

Source code in sae_lens/pretokenize_runner.py
def run(self):
    """
    Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
    """
    dataset = load_dataset(
        self.cfg.dataset_path,
        name=self.cfg.dataset_name,
        data_dir=self.cfg.data_dir,
        data_files=self.cfg.data_files,
        split=self.cfg.split,
        streaming=self.cfg.streaming,
    )
    if isinstance(dataset, DatasetDict):
        raise ValueError(
            "Dataset has multiple splits. Must provide a 'split' param."
        )
    tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
    tokenizer.model_max_length = sys.maxsize
    tokenized_dataset = pretokenize_dataset(
        cast(Dataset, dataset), tokenizer, self.cfg
    )

    if self.cfg.save_path is not None:
        tokenized_dataset.save_to_disk(self.cfg.save_path)
        metadata = metadata_from_config(self.cfg)
        metadata_path = Path(self.cfg.save_path) / "sae_lens.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata.__dict__, f, indent=2, ensure_ascii=False)

    if self.cfg.hf_repo_id is not None:
        push_to_hugging_face_hub(tokenized_dataset, self.cfg)

    return tokenized_dataset

PretokenizeRunnerConfig dataclass

Configuration class for pretokenizing a dataset.

Source code in sae_lens/config.py
@dataclass
class PretokenizeRunnerConfig:
    """
    Configuration class for pretokenizing a dataset.
    """

    tokenizer_name: str = "gpt2"
    dataset_path: str = ""
    dataset_name: str | None = None
    dataset_trust_remote_code: bool | None = None
    split: str | None = "train"
    data_files: list[str] | None = None
    data_dir: str | None = None
    num_proc: int = 4
    context_size: int = 128
    column_name: str = "text"
    shuffle: bool = True
    seed: int | None = None
    streaming: bool = False
    pretokenize_batch_size: int | None = 1000

    # special tokens
    begin_batch_token: int | Literal["bos", "eos", "sep"] | None = "bos"
    begin_sequence_token: int | Literal["bos", "eos", "sep"] | None = None
    sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos"

    # if saving locally, set save_path
    save_path: str | None = None

    # if saving to huggingface, set hf_repo_id
    hf_repo_id: str | None = None
    hf_num_shards: int = 64
    hf_revision: str = "main"
    hf_is_private_repo: bool = False
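
A minimal usage sketch combining this config with PretokenizeRunner above (assuming both classes are importable from the top-level sae_lens package; the dataset and save paths are placeholders):

from sae_lens import PretokenizeRunner, PretokenizeRunnerConfig

cfg = PretokenizeRunnerConfig(
    tokenizer_name="gpt2",
    dataset_path="<hf-dataset-path>",
    context_size=128,
    shuffle=True,
    save_path="./my-tokenized-dataset",
)
tokenized_dataset = PretokenizeRunner(cfg).run()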

SAE

Bases: HookedRootModule, Generic[T_SAE_CONFIG], ABC

Abstract base class for all SAE architectures.
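
Concrete architectures subclass this base and implement encode and decode; a schematic sketch, mirroring the JumpReLU implementation earlier on this page (the config class and its import path are illustrative assumptions):

from torch import Tensor
from sae_lens.saes.sae import SAE
from sae_lens.saes.standard_sae import StandardSAEConfig  # illustrative import path

class MyReLUSAE(SAE[StandardSAEConfig]):
    def encode(self, x: Tensor) -> Tensor:
        sae_in = self.process_sae_in(x)  # reshape, normalize, optionally subtract b_dec
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc)
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    def decode(self, feature_acts: Tensor) -> Tensor:
        sae_out_pre = self.hook_sae_recons(feature_acts @ self.W_dec + self.b_dec)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)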

Source code in sae_lens/saes/sae.py
class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
    """Abstract base class for all SAE architectures."""

    cfg: T_SAE_CONFIG
    dtype: torch.dtype
    device: torch.device
    use_error_term: bool

    # For type checking only - don't provide default values
    # These will be initialized by subclasses
    W_enc: nn.Parameter
    W_dec: nn.Parameter
    b_dec: nn.Parameter

    def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
        """Initialize the SAE."""
        super().__init__()

        self.cfg = cfg

        if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
            warnings.warn(
                "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                "\nFor optimal performance, load the model like so:\n"
                "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
                category=UserWarning,
                stacklevel=1,
            )

        self.dtype = DTYPE_MAP[cfg.dtype]
        self.device = torch.device(cfg.device)
        self.use_error_term = use_error_term

        # Set up activation function
        self.activation_fn = self.get_activation_fn()

        # Initialize weights
        self.initialize_weights()

        # Set up hooks
        self.hook_sae_input = HookPoint()
        self.hook_sae_acts_pre = HookPoint()
        self.hook_sae_acts_post = HookPoint()
        self.hook_sae_output = HookPoint()
        self.hook_sae_recons = HookPoint()
        self.hook_sae_error = HookPoint()

        # handle hook_z reshaping if needed.
        if self.cfg.reshape_activations == "hook_z":
            self.turn_on_forward_pass_hook_z_reshaping()
        else:
            self.turn_off_forward_pass_hook_z_reshaping()

        # Set up activation normalization
        self._setup_activation_normalization()

        self.setup()  # Required for HookedRootModule

    @torch.no_grad()
    def fold_activation_norm_scaling_factor(self, scaling_factor: float):
        self.W_enc.data *= scaling_factor  # type: ignore
        self.W_dec.data /= scaling_factor  # type: ignore
        self.b_dec.data /= scaling_factor  # type: ignore
        self.cfg.normalize_activations = "none"

    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Get the activation function specified in config."""
        return nn.ReLU()

    def _setup_activation_normalization(self):
        """Set up activation normalization functions based on config."""
        if self.cfg.normalize_activations == "constant_norm_rescale":

            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
                self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
                return x * self.x_norm_coeff

            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
                x = x / self.x_norm_coeff  # type: ignore
                del self.x_norm_coeff
                return x

            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out

        elif self.cfg.normalize_activations == "layer_norm":

            def run_time_activation_ln_in(
                x: torch.Tensor, eps: float = 1e-5
            ) -> torch.Tensor:
                mu = x.mean(dim=-1, keepdim=True)
                x = x - mu
                std = x.std(dim=-1, keepdim=True)
                x = x / (std + eps)
                self.ln_mu = mu
                self.ln_std = std
                return x

            def run_time_activation_ln_out(
                x: torch.Tensor,
                eps: float = 1e-5,  # noqa: ARG001
            ) -> torch.Tensor:
                return x * self.ln_std + self.ln_mu  # type: ignore

            self.run_time_activation_norm_fn_in = run_time_activation_ln_in
            self.run_time_activation_norm_fn_out = run_time_activation_ln_out
        else:
            self.run_time_activation_norm_fn_in = lambda x: x
            self.run_time_activation_norm_fn_out = lambda x: x

    def initialize_weights(self):
        """Initialize model weights."""
        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
        )

        w_dec_data = torch.empty(
            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
        )
        nn.init.kaiming_uniform_(w_dec_data)
        self.W_dec = nn.Parameter(w_dec_data)

        w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
        self.W_enc = nn.Parameter(w_enc_data)

    @abstractmethod
    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """Encode input tensor to feature space."""
        pass

    @abstractmethod
    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """Decode feature activations back to input space."""
        pass

    def turn_on_forward_pass_hook_z_reshaping(self):
        if (
            self.cfg.metadata.hook_name is not None
            and not self.cfg.metadata.hook_name.endswith("_z")
        ):
            raise ValueError("This method should only be called for hook_z SAEs.")

        # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")

        def reshape_fn_in(x: torch.Tensor):
            # print(f"reshape_fn_in input shape: {x.shape}")
            self.d_head = x.shape[-1]
            # print(f"Setting d_head to: {self.d_head}")
            self.reshape_fn_in = lambda x: einops.rearrange(
                x, "... n_heads d_head -> ... (n_heads d_head)"
            )
            return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")

        self.reshape_fn_in = reshape_fn_in
        self.reshape_fn_out = lambda x, d_head: einops.rearrange(
            x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
        )
        self.hook_z_reshaping_mode = True
        # print(f"hook_z reshaping turned on, self.d_head={getattr(self, 'd_head', None)}")

    def turn_off_forward_pass_hook_z_reshaping(self):
        self.reshape_fn_in = lambda x: x
        self.reshape_fn_out = lambda x, d_head: x  # noqa: ARG005
        self.d_head = None
        self.hook_z_reshaping_mode = False

    @overload
    def to(
        self: T_SAE,
        device: torch.device | str | None = ...,
        dtype: torch.dtype | None = ...,
        non_blocking: bool = ...,
    ) -> T_SAE: ...

    @overload
    def to(self: T_SAE, dtype: torch.dtype, non_blocking: bool = ...) -> T_SAE: ...

    @overload
    def to(self: T_SAE, tensor: torch.Tensor, non_blocking: bool = ...) -> T_SAE: ...

    def to(self: T_SAE, *args: Any, **kwargs: Any) -> T_SAE:  # type: ignore
        device_arg = None
        dtype_arg = None

        # Check args
        for arg in args:
            if isinstance(arg, (torch.device, str)):
                device_arg = arg
            elif isinstance(arg, torch.dtype):
                dtype_arg = arg
            elif isinstance(arg, torch.Tensor):
                device_arg = arg.device
                dtype_arg = arg.dtype

        # Check kwargs
        device_arg = kwargs.get("device", device_arg)
        dtype_arg = kwargs.get("dtype", dtype_arg)

        # Update device in config if provided
        if device_arg is not None:
            # Convert device to torch.device if it's a string
            device = (
                torch.device(device_arg) if isinstance(device_arg, str) else device_arg
            )

            # Update the cfg.device
            self.cfg.device = str(device)

            # Update the device property
            self.device = device

        # Update dtype in config if provided
        if dtype_arg is not None:
            # Update the cfg.dtype
            self.cfg.dtype = str(dtype_arg)

            # Update the dtype property
            self.dtype = dtype_arg

        return super().to(*args, **kwargs)

    def process_sae_in(
        self, sae_in: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_in"]:
        # print(f"Input shape to process_sae_in: {sae_in.shape}")
        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
        # print(f"self.b_dec shape: {self.b_dec.shape}")
        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")

        sae_in = sae_in.to(self.dtype)

        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
        sae_in = self.reshape_fn_in(sae_in)
        # print(f"Shape after reshape_fn_in: {sae_in.shape}")

        sae_in = self.hook_sae_input(sae_in)
        sae_in = self.run_time_activation_norm_fn_in(sae_in)

        # Here's where the error happens
        bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
        # print(f"Bias term shape: {bias_term.shape}")

        return sae_in - bias_term

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the SAE."""
        feature_acts = self.encode(x)
        sae_out = self.decode(feature_acts)

        if self.use_error_term:
            with torch.no_grad():
                # Recompute without hooks for true error term
                with _disable_hooks(self):
                    feature_acts_clean = self.encode(x)
                    x_reconstruct_clean = self.decode(feature_acts_clean)
                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
            sae_out = sae_out + sae_error

        return self.hook_sae_output(sae_out)

    # overwrite this in subclasses to modify the state_dict in-place before saving
    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
        pass

    # overwrite this in subclasses to modify the state_dict in-place after loading
    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
        pass

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """Fold decoder norms into encoder."""
        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T

        # Only update b_enc if it exists (standard/jumprelu architectures)
        if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
            self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()

    def get_name(self):
        """Generate a name for this SAE."""
        return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"

    def save_model(self, path: str | Path) -> tuple[Path, Path]:
        """Save model weights and config to disk."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Generate the weights
        state_dict = self.state_dict()  # Use internal SAE state dict
        self.process_state_dict_for_saving(state_dict)
        model_weights_path = path / SAE_WEIGHTS_FILENAME
        save_file(state_dict, model_weights_path)

        # Save the config
        config = self.cfg.to_dict()
        cfg_path = path / SAE_CFG_FILENAME
        with open(cfg_path, "w") as f:
            json.dump(config, f)

        return model_weights_path, cfg_path

    ## Initialization Methods
    @torch.no_grad()
    def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
        out = origin.detach().clone().to(dtype=self.dtype, device=self.device)
        self.b_dec.data = out

    @torch.no_grad()
    def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
        previous_b_dec = self.b_dec.clone().cpu()
        out = all_activations.mean(dim=0)

        previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
        distances = torch.norm(all_activations - out, dim=-1)

        logger.info("Reinitializing b_dec with mean of activations")
        logger.debug(
            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
        )
        logger.debug(f"New distances: {distances.median(0).values.mean().item()}")

        self.b_dec.data = out.to(self.dtype).to(self.device)

    # Class methods for loading models
    @classmethod
    @deprecated("Use load_from_disk instead")
    def load_from_pretrained(
        cls: Type[T_SAE],
        path: str | Path,
        device: str = "cpu",
        dtype: str | None = None,
    ) -> T_SAE:
        return cls.load_from_disk(path, device=device, dtype=dtype)

    @classmethod
    def load_from_disk(
        cls: Type[T_SAE],
        path: str | Path,
        device: str = "cpu",
        dtype: str | None = None,
        converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
    ) -> T_SAE:
        overrides = {"dtype": dtype} if dtype is not None else None
        cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
        cfg_dict = handle_config_defaulting(cfg_dict)
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            cfg_dict["architecture"]
        )
        sae_cfg = sae_config_cls.from_dict(cfg_dict)
        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
        sae = sae_cls(sae_cfg)
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict)
        return sae

    @classmethod
    def from_pretrained(
        cls: Type[T_SAE],
        release: str,
        sae_id: str,
        device: str = "cpu",
        force_download: bool = False,
        converter: PretrainedSaeHuggingfaceLoader | None = None,
    ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
        """
        Load a pretrained SAE from the Hugging Face model hub.

        Args:
            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
            sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
            device: The device to load the SAE on.
            force_download: If True, re-download the SAE from the Hugging Face Hub even if it is cached locally.
            converter: Optional custom loader used to fetch and convert the pretrained weights and config.

        Returns:
            A tuple of (sae, cfg_dict, log_sparsities). log_sparsities is the log sparsity tensor if it is present in the model directory on the Hugging Face model hub, and None otherwise.
        """

        # get sae directory
        sae_directory = get_pretrained_saes_directory()

        # Validate release and sae_id
        if release not in sae_directory:
            if "/" not in release:
                raise ValueError(
                    f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
                )
        elif sae_id not in sae_directory[release].saes_map:
            # Handle special cases like Gemma Scope
            if (
                "gemma-scope" in release
                and "canonical" not in release
                and f"{release}-canonical" in sae_directory
            ):
                canonical_ids = list(
                    sae_directory[release + "-canonical"].saes_map.keys()
                )
                # Shorten the lengthy string of valid IDs
                if len(canonical_ids) > 5:
                    str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
                else:
                    str_canonical_ids = str(canonical_ids)
                value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
            else:
                value_suffix = ""

            valid_ids = list(sae_directory[release].saes_map.keys())
            # Shorten the lengthy string of valid IDs
            if len(valid_ids) > 5:
                str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
            else:
                str_valid_ids = str(valid_ids)

            raise ValueError(
                f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
                + value_suffix
            )

        conversion_loader = (
            converter
            or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
        )
        repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
        config_overrides = get_config_overrides(release, sae_id)
        config_overrides["device"] = device

        # Load config and weights
        cfg_dict, state_dict, log_sparsities = conversion_loader(
            repo_id=repo_id,
            folder_name=folder_name,
            device=device,
            force_download=force_download,
            cfg_overrides=config_overrides,
        )
        cfg_dict = handle_config_defaulting(cfg_dict)

        # Create SAE with appropriate architecture
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            cfg_dict["architecture"]
        )
        sae_cfg = sae_config_cls.from_dict(cfg_dict)
        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
        sae = sae_cls(sae_cfg)
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict)

        # Apply normalization if needed
        if cfg_dict.get("normalize_activations") == "expected_average_only_in":
            norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
            if norm_scaling_factor is not None:
                sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
                cfg_dict["normalize_activations"] = "none"
            else:
                warnings.warn(
                    f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
                )

        return sae, cfg_dict, log_sparsities

    @classmethod
    def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
        """Create an SAE from a config dictionary."""
        sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            config_dict["architecture"]
        )
        return sae_cls(sae_config_cls.from_dict(config_dict))

    @classmethod
    def get_sae_class_for_architecture(
        cls: Type[T_SAE], architecture: str
    ) -> Type[T_SAE]:
        """Get the SAE class for a given architecture."""
        sae_cls, _ = get_sae_class(architecture)
        if not issubclass(sae_cls, cls):
            raise ValueError(
                f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
            )
        return sae_cls

    # in the future, this can be used to load different config classes for different architectures
    @classmethod
    def get_sae_config_class_for_architecture(
        cls,
        architecture: str,  # noqa: ARG003
    ) -> type[SAEConfig]:
        return SAEConfig

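The to() override in the class above keeps cfg.device and cfg.dtype in sync with the module's actual placement and precision. A minimal sketch, assuming a concrete subclass such as StandardSAE (documented further down this page):

import torch
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=768, d_sae=4096))

# Moving the module via .to() also updates the stored config,
# so a later save_model() records the right device/dtype.
sae = sae.to(device="cpu", dtype=torch.float32)
print(sae.cfg.device, sae.cfg.dtype)  # e.g. "cpu" "torch.float32"
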
__init__(cfg, use_error_term=False)

Initialize the SAE.

Source code in sae_lens/saes/sae.py
def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
    """Initialize the SAE."""
    super().__init__()

    self.cfg = cfg

    if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
        warnings.warn(
            "\nThis SAE has non-empty model_from_pretrained_kwargs. "
            "\nFor optimal performance, load the model like so:\n"
            "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
            category=UserWarning,
            stacklevel=1,
        )

    self.dtype = DTYPE_MAP[cfg.dtype]
    self.device = torch.device(cfg.device)
    self.use_error_term = use_error_term

    # Set up activation function
    self.activation_fn = self.get_activation_fn()

    # Initialize weights
    self.initialize_weights()

    # Set up hooks
    self.hook_sae_input = HookPoint()
    self.hook_sae_acts_pre = HookPoint()
    self.hook_sae_acts_post = HookPoint()
    self.hook_sae_output = HookPoint()
    self.hook_sae_recons = HookPoint()
    self.hook_sae_error = HookPoint()

    # handle hook_z reshaping if needed.
    if self.cfg.reshape_activations == "hook_z":
        self.turn_on_forward_pass_hook_z_reshaping()
    else:
        self.turn_off_forward_pass_hook_z_reshaping()

    # Set up activation normalization
    self._setup_activation_normalization()

    self.setup()  # Required for HookedRootModule

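Because an SAE is a HookedRootModule, the hook points registered in __init__ work with the usual TransformerLens hook API. A small sketch (assuming the standard run_with_hooks interface) that captures the post-activation features during a forward pass:

import torch
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=64, d_sae=256))
captured = {}

def save_acts(acts, hook):
    # acts has shape (..., d_sae); hook is the HookPoint that fired.
    captured["feature_acts"] = acts.detach()

x = torch.randn(4, 64)
sae.run_with_hooks(x, fwd_hooks=[("hook_sae_acts_post", save_acts)])
print(captured["feature_acts"].shape)  # torch.Size([4, 256])
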
decode(feature_acts) abstractmethod

Decode feature activations back to input space.

Source code in sae_lens/saes/sae.py
@abstractmethod
def decode(
    self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_in"]:
    """Decode feature activations back to input space."""
    pass

encode(x) abstractmethod

Encode input tensor to feature space.

Source code in sae_lens/saes/sae.py
@abstractmethod
def encode(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """Encode input tensor to feature space."""
    pass

fold_W_dec_norm()

Fold decoder norms into encoder.

Source code in sae_lens/saes/sae.py
@torch.no_grad()
def fold_W_dec_norm(self):
    """Fold decoder norms into encoder."""
    W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
    self.W_dec.data = self.W_dec.data / W_dec_norms
    self.W_enc.data = self.W_enc.data * W_dec_norms.T

    # Only update b_enc if it exists (standard/jumprelu architectures)
    if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
        self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()

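Folding rescales W_enc (and b_enc where present) by the decoder row norms, leaving the end-to-end reconstruction unchanged for scale-equivariant activations such as ReLU. A quick sanity check, assuming a StandardSAE:

import torch
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=64, d_sae=256))
x = torch.randn(8, 64)

out_before = sae(x)
sae.fold_W_dec_norm()
out_after = sae(x)

# Decoder rows are now unit norm and the reconstruction is numerically unchanged.
print(torch.allclose(sae.W_dec.norm(dim=-1), torch.ones(256), atol=1e-5))
print(torch.allclose(out_before, out_after, atol=1e-5))
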
forward(x)

Forward pass through the SAE.

Source code in sae_lens/saes/sae.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through the SAE."""
    feature_acts = self.encode(x)
    sae_out = self.decode(feature_acts)

    if self.use_error_term:
        with torch.no_grad():
            # Recompute without hooks for true error term
            with _disable_hooks(self):
                feature_acts_clean = self.encode(x)
                x_reconstruct_clean = self.decode(feature_acts_clean)
            sae_error = self.hook_sae_error(x - x_reconstruct_clean)
        sae_out = sae_out + sae_error

    return self.hook_sae_output(sae_out)

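With use_error_term=True the forward pass adds back the reconstruction error computed without hooks, so absent any hook-based intervention the SAE behaves as an identity on its input. A minimal sketch, again assuming a StandardSAE:

import torch
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=64, d_sae=256), use_error_term=True)
x = torch.randn(4, 64)

# sae_out = decode(encode(x)) + (x - decode(encode(x))) == x when no hooks modify activations.
print(torch.allclose(sae(x), x, atol=1e-5))
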
from_dict(config_dict) classmethod

Create an SAE from a config dictionary.

Source code in sae_lens/saes/sae.py
@classmethod
def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
    """Create an SAE from a config dictionary."""
    sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
    sae_config_cls = cls.get_sae_config_class_for_architecture(
        config_dict["architecture"]
    )
    return sae_cls(sae_config_cls.from_dict(config_dict))

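from_dict dispatches on the "architecture" key to pick both the SAE class and its config class. A minimal sketch (the field values are illustrative):

from sae_lens.saes.sae import SAE

sae = SAE.from_dict(
    {
        "architecture": "standard",
        "d_in": 768,
        "d_sae": 12288,
        "dtype": "float32",
        "device": "cpu",
    }
)
print(type(sae).__name__)  # StandardSAE
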
from_pretrained(release, sae_id, device='cpu', force_download=False, converter=None) classmethod

Load a pretrained SAE from the Hugging Face model hub.

Parameters:

- release (str, required): The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
- sae_id (str, required): The id of the SAE to load. This will be mapped to a path in the huggingface repo.
- device (str, default 'cpu'): The device to load the SAE on.
- force_download (bool, default False): If True, re-download the SAE from the Hugging Face Hub even if it is cached locally.
- converter (PretrainedSaeHuggingfaceLoader | None, default None): Optional custom loader used to fetch and convert the pretrained weights and config.
Source code in sae_lens/saes/sae.py
@classmethod
def from_pretrained(
    cls: Type[T_SAE],
    release: str,
    sae_id: str,
    device: str = "cpu",
    force_download: bool = False,
    converter: PretrainedSaeHuggingfaceLoader | None = None,
) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
    """
    Load a pretrained SAE from the Hugging Face model hub.

    Args:
        release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
        sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
        device: The device to load the SAE on.
        force_download: If True, re-download the SAE from the Hugging Face Hub even if it is cached locally.
        converter: Optional custom loader used to fetch and convert the pretrained weights and config.

    Returns:
        A tuple of (sae, cfg_dict, log_sparsities). log_sparsities is the log sparsity tensor if it is present in the model directory on the Hugging Face model hub, and None otherwise.
    """

    # get sae directory
    sae_directory = get_pretrained_saes_directory()

    # Validate release and sae_id
    if release not in sae_directory:
        if "/" not in release:
            raise ValueError(
                f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
            )
    elif sae_id not in sae_directory[release].saes_map:
        # Handle special cases like Gemma Scope
        if (
            "gemma-scope" in release
            and "canonical" not in release
            and f"{release}-canonical" in sae_directory
        ):
            canonical_ids = list(
                sae_directory[release + "-canonical"].saes_map.keys()
            )
            # Shorten the lengthy string of valid IDs
            if len(canonical_ids) > 5:
                str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
            else:
                str_canonical_ids = str(canonical_ids)
            value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
        else:
            value_suffix = ""

        valid_ids = list(sae_directory[release].saes_map.keys())
        # Shorten the lengthy string of valid IDs
        if len(valid_ids) > 5:
            str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
        else:
            str_valid_ids = str(valid_ids)

        raise ValueError(
            f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
            + value_suffix
        )

    conversion_loader = (
        converter
        or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
    )
    repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
    config_overrides = get_config_overrides(release, sae_id)
    config_overrides["device"] = device

    # Load config and weights
    cfg_dict, state_dict, log_sparsities = conversion_loader(
        repo_id=repo_id,
        folder_name=folder_name,
        device=device,
        force_download=force_download,
        cfg_overrides=config_overrides,
    )
    cfg_dict = handle_config_defaulting(cfg_dict)

    # Create SAE with appropriate architecture
    sae_config_cls = cls.get_sae_config_class_for_architecture(
        cfg_dict["architecture"]
    )
    sae_cfg = sae_config_cls.from_dict(cfg_dict)
    sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
    sae = sae_cls(sae_cfg)
    sae.process_state_dict_for_loading(state_dict)
    sae.load_state_dict(state_dict)

    # Apply normalization if needed
    if cfg_dict.get("normalize_activations") == "expected_average_only_in":
        norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
        if norm_scaling_factor is not None:
            sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
            cfg_dict["normalize_activations"] = "none"
        else:
            warnings.warn(
                f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
            )

    return sae, cfg_dict, log_sparsities

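Typical usage, downloading a pretrained SAE from the Hugging Face Hub. The release and ID below are illustrative; valid values come from the pretrained_saes.yaml directory returned by get_pretrained_saes_directory():

from sae_lens.saes.sae import SAE

sae, cfg_dict, log_sparsities = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cpu",
)
print(cfg_dict["architecture"], sae.cfg.d_sae)
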
get_activation_fn()

Get the activation function specified in config.

Source code in sae_lens/saes/sae.py
def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
    """Get the activation function specified in config."""
    return nn.ReLU()

get_name()

Generate a name for this SAE.

Source code in sae_lens/saes/sae.py
def get_name(self):
    """Generate a name for this SAE."""
    return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"

get_sae_class_for_architecture(architecture) classmethod

Get the SAE class for a given architecture.

Source code in sae_lens/saes/sae.py
@classmethod
def get_sae_class_for_architecture(
    cls: Type[T_SAE], architecture: str
) -> Type[T_SAE]:
    """Get the SAE class for a given architecture."""
    sae_cls, _ = get_sae_class(architecture)
    if not issubclass(sae_cls, cls):
        raise ValueError(
            f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
        )
    return sae_cls

initialize_weights()

Initialize model weights.

Source code in sae_lens/saes/sae.py
def initialize_weights(self):
    """Initialize model weights."""
    self.b_dec = nn.Parameter(
        torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
    )

    w_dec_data = torch.empty(
        self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
    )
    nn.init.kaiming_uniform_(w_dec_data)
    self.W_dec = nn.Parameter(w_dec_data)

    w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
    self.W_enc = nn.Parameter(w_enc_data)

save_model(path)

Save model weights and config to disk.

Source code in sae_lens/saes/sae.py
def save_model(self, path: str | Path) -> tuple[Path, Path]:
    """Save model weights and config to disk."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # Generate the weights
    state_dict = self.state_dict()  # Use internal SAE state dict
    self.process_state_dict_for_saving(state_dict)
    model_weights_path = path / SAE_WEIGHTS_FILENAME
    save_file(state_dict, model_weights_path)

    # Save the config
    config = self.cfg.to_dict()
    cfg_path = path / SAE_CFG_FILENAME
    with open(cfg_path, "w") as f:
        json.dump(config, f)

    return model_weights_path, cfg_path

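save_model writes the weights (safetensors) and the JSON config side by side; load_from_disk restores them. A minimal round trip, assuming a StandardSAE and that the default disk loader reads the files save_model produces:

import torch
from sae_lens.saes.sae import SAE
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=64, d_sae=256))
weights_path, cfg_path = sae.save_model("./my_sae")

restored = SAE.load_from_disk("./my_sae", device="cpu")
x = torch.randn(2, 64)
print(torch.allclose(sae(x), restored(x)))  # True
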
SAEConfig dataclass

Bases: ABC

Base configuration for SAE models.

Source code in sae_lens/saes/sae.py
@dataclass
class SAEConfig(ABC):
    """Base configuration for SAE models."""

    d_in: int
    d_sae: int
    dtype: str = "float32"
    device: str = "cpu"
    apply_b_dec_to_input: bool = True
    normalize_activations: Literal[
        "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
    ] = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
    reshape_activations: Literal["none", "hook_z"] = "none"
    metadata: SAEMetadata = field(default_factory=SAEMetadata)

    @classmethod
    @abstractmethod
    def architecture(cls) -> str: ...

    def to_dict(self) -> dict[str, Any]:
        res = {field.name: getattr(self, field.name) for field in fields(self)}
        res["metadata"] = asdict(self.metadata)
        res["architecture"] = self.architecture()
        return res

    @classmethod
    def from_dict(cls: type[T_SAE_CONFIG], config_dict: dict[str, Any]) -> T_SAE_CONFIG:
        cfg_class = get_sae_class(config_dict["architecture"])[1]
        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
        res = cfg_class(**filtered_config_dict)
        if "metadata" in config_dict:
            res.metadata = SAEMetadata(**config_dict["metadata"])
        if not isinstance(res, cls):
            raise ValueError(
                f"SAE config class {cls} does not match dict config class {type(res)}"
            )
        return res

    def __post_init__(self):
        if self.normalize_activations not in [
            "none",
            "expected_average_only_in",
            "constant_norm_rescale",
            "layer_norm",
        ]:
            raise ValueError(
                f"normalize_activations must be none, expected_average_only_in, constant_norm_rescale, or layer_norm. Got {self.normalize_activations}"
            )

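to_dict records the architecture and metadata alongside the dataclass fields, which is what from_dict later uses to dispatch. A small sketch using the StandardSAEConfig subclass documented below:

from sae_lens.saes.standard_sae import StandardSAEConfig

cfg = StandardSAEConfig(d_in=768, d_sae=12288)
d = cfg.to_dict()
print(d["architecture"])  # "standard"

restored = StandardSAEConfig.from_dict(d)
print(restored.d_in, restored.d_sae)  # 768 12288
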
SAETrainingRunner

Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.

Source code in sae_lens/sae_training_runner.py
class SAETrainingRunner:
    """
    Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
    """

    cfg: LanguageModelSAERunnerConfig[Any]
    model: HookedRootModule
    sae: TrainingSAE[Any]
    activations_store: ActivationsStore

    def __init__(
        self,
        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
        override_dataset: HfDataset | None = None,
        override_model: HookedRootModule | None = None,
        override_sae: TrainingSAE[Any] | None = None,
    ):
        if override_dataset is not None:
            logger.warning(
                f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
            )
        if override_model is not None:
            logger.warning(
                f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
            )

        self.cfg = cfg

        if override_model is None:
            self.model = load_model(
                self.cfg.model_class_name,
                self.cfg.model_name,
                device=self.cfg.device,
                model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
            )
        else:
            self.model = override_model

        self.activations_store = ActivationsStore.from_config(
            self.model,
            self.cfg,
            override_dataset=override_dataset,
        )

        if override_sae is None:
            if self.cfg.from_pretrained_path is not None:
                self.sae = TrainingSAE.load_from_disk(
                    self.cfg.from_pretrained_path, self.cfg.device
                )
            else:
                self.sae = TrainingSAE.from_dict(
                    TrainingSAEConfig.from_dict(
                        self.cfg.get_training_sae_cfg_dict(),
                    ).to_dict()
                )
                self._init_sae_group_b_decs()
        else:
            self.sae = override_sae

    def run(self):
        """
        Run the training of the SAE.
        """

        if self.cfg.logger.log_to_wandb:
            wandb.init(
                project=self.cfg.logger.wandb_project,
                entity=self.cfg.logger.wandb_entity,
                config=cast(Any, self.cfg),
                name=self.cfg.logger.run_name,
                id=self.cfg.logger.wandb_id,
            )

        trainer = SAETrainer(
            model=self.model,
            sae=self.sae,
            activation_store=self.activations_store,
            save_checkpoint_fn=self.save_checkpoint,
            cfg=self.cfg,
        )

        self._compile_if_needed()
        sae = self.run_trainer_with_interruption_handling(trainer)

        if self.cfg.logger.log_to_wandb:
            wandb.finish()

        return sae

    def _compile_if_needed(self):
        # Compile model and SAE
        #  torch.compile can provide significant speedups (10-20% in testing)
        # using max-autotune gives the best speedups but:
        # (a) increases VRAM usage,
        # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
        # (c) takes some time to compile
        # optimal settings seem to be:
        # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
        # (also pylance seems to really hate this)
        if self.cfg.compile_llm:
            self.model = torch.compile(
                self.model,
                mode=self.cfg.llm_compilation_mode,
            )  # type: ignore

        if self.cfg.compile_sae:
            backend = "aot_eager" if self.cfg.device == "mps" else "inductor"

            self.sae.training_forward_pass = torch.compile(  # type: ignore
                self.sae.training_forward_pass,
                mode=self.cfg.sae_compilation_mode,
                backend=backend,
            )  # type: ignore

    def run_trainer_with_interruption_handling(
        self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
    ):
        try:
            # signal handlers (if preempted)
            signal.signal(signal.SIGINT, interrupt_callback)
            signal.signal(signal.SIGTERM, interrupt_callback)

            # train SAE
            sae = trainer.fit()

        except (KeyboardInterrupt, InterruptedException):
            logger.warning("interrupted, saving progress")
            checkpoint_name = str(trainer.n_training_tokens)
            self.save_checkpoint(trainer, checkpoint_name=checkpoint_name)
            logger.info("done saving")
            raise

        return sae

    # TODO: move this into the SAE trainer or Training SAE class
    def _init_sae_group_b_decs(
        self,
    ) -> None:
        """
        extract all activations at a certain layer and use for sae b_dec initialization
        """

        if self.cfg.sae.b_dec_init_method == "geometric_median":
            self.activations_store.set_norm_scaling_factor_if_needed()
            layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
            # get geometric median of the activations if we're using those.
            median = compute_geometric_median(
                layer_acts,
                maxiter=100,
            ).median
            self.sae.initialize_b_dec_with_precalculated(median)
        elif self.cfg.sae.b_dec_init_method == "mean":
            self.activations_store.set_norm_scaling_factor_if_needed()
            layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
            self.sae.initialize_b_dec_with_mean(layer_acts)  # type: ignore

    @staticmethod
    def save_checkpoint(
        trainer: SAETrainer[TrainingSAE[Any], Any],
        checkpoint_name: str,
        wandb_aliases: list[str] | None = None,
    ) -> None:
        base_path = Path(trainer.cfg.checkpoint_path) / checkpoint_name
        base_path.mkdir(exist_ok=True, parents=True)

        trainer.activations_store.save(
            str(base_path / "activations_store_state.safetensors")
        )

        weights_path, cfg_path = trainer.sae.save_model(str(base_path))

        sparsity_path = base_path / SPARSITY_FILENAME
        save_file({"sparsity": trainer.log_feature_sparsity}, sparsity_path)

        runner_config = trainer.cfg.to_dict()
        with open(base_path / RUNNER_CFG_FILENAME, "w") as f:
            json.dump(runner_config, f)

        if trainer.cfg.logger.log_to_wandb:
            trainer.cfg.logger.log(
                trainer,
                weights_path,
                cfg_path,
                sparsity_path=sparsity_path,
                wandb_aliases=wandb_aliases,
            )

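Putting it together: the runner builds the model, activations store, and training SAE from a single config object, then trains. A minimal sketch, assuming cfg is a fully specified LanguageModelSAERunnerConfig constructed elsewhere (its fields vary between sae_lens versions and are not reproduced here):

from sae_lens.sae_training_runner import SAETrainingRunner

# `cfg` is assumed to describe the model, dataset, hook point, and SAE settings.
runner = SAETrainingRunner(cfg)

# Trains the SAE, saving checkpoints and optionally logging to wandb.
sae = runner.run()
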
run()

Run the training of the SAE.

Source code in sae_lens/sae_training_runner.py
def run(self):
    """
    Run the training of the SAE.
    """

    if self.cfg.logger.log_to_wandb:
        wandb.init(
            project=self.cfg.logger.wandb_project,
            entity=self.cfg.logger.wandb_entity,
            config=cast(Any, self.cfg),
            name=self.cfg.logger.run_name,
            id=self.cfg.logger.wandb_id,
        )

    trainer = SAETrainer(
        model=self.model,
        sae=self.sae,
        activation_store=self.activations_store,
        save_checkpoint_fn=self.save_checkpoint,
        cfg=self.cfg,
    )

    self._compile_if_needed()
    sae = self.run_trainer_with_interruption_handling(trainer)

    if self.cfg.logger.log_to_wandb:
        wandb.finish()

    return sae

StandardSAE

Bases: SAE[StandardSAEConfig]

StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE) using a simple linear encoder and decoder.

It implements the required abstract methods from BaseSAE:
  • initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
  • encode: computes the feature activations from an input.
  • decode: reconstructs the input from the feature activations.

The BaseSAE.forward() method automatically calls encode and decode, including any error-term processing if configured.

Source code in sae_lens/saes/standard_sae.py
class StandardSAE(SAE[StandardSAEConfig]):
    """
    StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
    using a simple linear encoder and decoder.

    It implements the required abstract methods from BaseSAE:
      - initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
      - encode: computes the feature activations from an input.
      - decode: reconstructs the input from the feature activations.

    The BaseSAE.forward() method automatically calls encode and decode,
    including any error-term processing if configured.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: StandardSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        # Initialize encoder weights and bias.
        super().initialize_weights()
        _init_weights_standard(self)

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Encode the input tensor into the feature space.
        For inference, no noise is added.
        """
        # Preprocess the SAE input (casting type, applying hooks, normalization)
        sae_in = self.process_sae_in(x)
        # Compute the pre-activation values
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        # Apply the activation function (e.g., ReLU, tanh-relu, depending on config)
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Decode the feature activations back to the input space.
        Now, if hook_z reshaping is turned on, we reverse the flattening.
        """
        # 1) linear transform
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        # 2) hook reconstruction
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        # 3) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        # 4) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
        return self.reshape_fn_out(sae_out_pre, self.d_head)

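Shape-wise, encode maps (..., d_in) activations to (..., d_sae) feature activations and decode maps them back; with the default ReLU activation the features are non-negative. A quick sketch:

import torch
from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=512, d_sae=4096))
x = torch.randn(2, 128, 512)  # (batch, seq, d_in)

feats = sae.encode(x)      # (2, 128, 4096)
recon = sae.decode(feats)  # (2, 128, 512)
print(feats.shape, recon.shape, bool((feats >= 0).all()))
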
decode(feature_acts)

Decode the feature activations back to the input space. Now, if hook_z reshaping is turned on, we reverse the flattening.

Source code in sae_lens/saes/standard_sae.py
def decode(
    self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_in"]:
    """
    Decode the feature activations back to the input space.
    Now, if hook_z reshaping is turned on, we reverse the flattening.
    """
    # 1) linear transform
    sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    # 2) hook reconstruction
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    # 3) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    # 4) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode(x)

Encode the input tensor into the feature space. For inference, no noise is added.

Source code in sae_lens/saes/standard_sae.py
def encode(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """
    Encode the input tensor into the feature space.
    For inference, no noise is added.
    """
    # Preprocess the SAE input (casting type, applying hooks, normalization)
    sae_in = self.process_sae_in(x)
    # Compute the pre-activation values
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
    # Apply the activation function (e.g., ReLU, tanh-relu, depending on config)
    return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

StandardSAEConfig dataclass

Bases: SAEConfig

Configuration class for a StandardSAE.

Source code in sae_lens/saes/standard_sae.py
@dataclass
class StandardSAEConfig(SAEConfig):
    """
    Configuration class for a StandardSAE.
    """

    @override
    @classmethod
    def architecture(cls) -> str:
        return "standard"

StandardTrainingSAE

Bases: TrainingSAE[StandardTrainingSAEConfig]

StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture. It implements:

- initialize_weights: basic weight initialization for encoder/decoder.
- encode: inference encoding (invokes encode_with_hidden_pre).
- decode: a simple linear decoder.
- encode_with_hidden_pre: computes pre-activations, adds noise when training, and then activates.
- calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.

Source code in sae_lens/saes/standard_sae.py
class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
    """
    StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
    It implements:
      - initialize_weights: basic weight initialization for encoder/decoder.
      - encode: inference encoding (invokes encode_with_hidden_pre).
      - decode: a simple linear decoder.
      - encode_with_hidden_pre: computes pre-activations, adds noise when training, and then activates.
      - calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.
    """

    b_enc: nn.Parameter

    def initialize_weights(self) -> None:
        super().initialize_weights()
        _init_weights_standard(self)

    @override
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        return {
            "l1": TrainCoefficientConfig(
                value=self.cfg.l1_coefficient,
                warm_up_steps=self.cfg.l1_warm_up_steps,
            ),
        }

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        # Process the input (including dtype conversion, hook call, and any activation normalization)
        sae_in = self.process_sae_in(x)
        # Compute the pre-activation (and allow for a hook if desired)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)  # type: ignore
        # Add noise during training for robustness (scaled by noise_scale from the configuration)
        if self.training and self.cfg.noise_scale > 0:
            hidden_pre_noised = (
                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
            )
        else:
            hidden_pre_noised = hidden_pre
        # Apply the activation function (and any post-activation hook)
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
        return feature_acts, hidden_pre_noised

    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        # The "standard" auxiliary loss is a sparsity penalty on the feature activations
        weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)

        # Compute the p-norm (set by cfg.lp_norm) over the feature dimension
        sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
        l1_loss = (step_input.coefficients["l1"] * sparsity).mean()

        return {"l1_loss": l1_loss}

    def log_histograms(self) -> dict[str, NDArray[np.generic]]:
        """Log histograms of the weights and biases."""
        b_e_dist = self.b_enc.detach().float().cpu().numpy()
        return {
            **super().log_histograms(),
            "weights/b_e": b_e_dist,
        }

    def to_inference_config_dict(self) -> dict[str, Any]:
        return filter_valid_dataclass_fields(
            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
        )

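The sparsity penalty in calculate_aux_loss weights each feature activation by the norm of its decoder row before taking the p-norm, so the loss cannot be reduced simply by shrinking W_dec while inflating activations. A standalone sketch of that computation (feature_acts and W_dec here are stand-ins for the SAE's own tensors):

import torch

feature_acts = torch.rand(8, 256)   # (batch, d_sae)
W_dec = torch.randn(256, 64)        # (d_sae, d_in)
l1_coefficient = 1.0
lp_norm = 1.0

weighted = feature_acts * W_dec.norm(dim=1)   # scale by decoder row norms
sparsity = weighted.norm(p=lp_norm, dim=-1)   # p-norm over the feature dimension
l1_loss = (l1_coefficient * sparsity).mean()
print(l1_loss)
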
log_histograms()

Log histograms of the weights and biases.

Source code in sae_lens/saes/standard_sae.py
def log_histograms(self) -> dict[str, NDArray[np.generic]]:
    """Log histograms of the weights and biases."""
    b_e_dist = self.b_enc.detach().float().cpu().numpy()
    return {
        **super().log_histograms(),
        "weights/b_e": b_e_dist,
    }

StandardTrainingSAEConfig dataclass

Bases: TrainingSAEConfig

Configuration class for training a StandardTrainingSAE.

Source code in sae_lens/saes/standard_sae.py
@dataclass
class StandardTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a StandardTrainingSAE.
    """

    l1_coefficient: float = 1.0
    lp_norm: float = 1.0
    l1_warm_up_steps: int = 0

    @override
    @classmethod
    def architecture(cls) -> str:
        return "standard"

TopKSAE

Bases: SAE[TopKSAEConfig]

An inference-only sparse autoencoder using a "topk" activation function. It uses linear encoder and decoder layers, applying the TopK activation to the hidden pre-activation in its encode step.

Source code in sae_lens/saes/topk_sae.py
class TopKSAE(SAE[TopKSAEConfig]):
    """
    An inference-only sparse autoencoder using a "topk" activation function.
    It uses linear encoder and decoder layers, applying the TopK activation
    to the hidden pre-activation in its encode step.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
        """
        Args:
            cfg: SAEConfig defining model size and behavior.
            use_error_term: Whether to apply the error-term approach in the forward pass.
        """
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        # Initialize encoder weights and bias.
        super().initialize_weights()
        _init_weights_topk(self)

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Converts input x into feature activations.
        Uses topk activation from the config (cfg.activation_fn == "topk")
        under the hood.
        """
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Reconstructs the input from topk feature activations.
        Applies optional finetuning scaling, hooking to recons, out normalization,
        and optional head reshaping.
        """
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @override
    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        return TopK(self.cfg.k)

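The TopK activation keeps only the k largest pre-activations per example and zeroes the rest, so sparsity is enforced structurally rather than through an L1 penalty. An illustrative sketch of the idea (not the library's TopK class itself):

import torch

def topk_activation(hidden_pre: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest values along the last dim, zero everything else."""
    topk = hidden_pre.topk(k, dim=-1)
    acts = torch.zeros_like(hidden_pre)
    acts.scatter_(-1, topk.indices, topk.values)
    return acts

hidden_pre = torch.randn(2, 16)
acts = topk_activation(hidden_pre, k=4)
print((acts != 0).sum(dim=-1))  # tensor([4, 4])
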
__init__(cfg, use_error_term=False)

Parameters:

- cfg (TopKSAEConfig, required): SAEConfig defining model size and behavior.
- use_error_term (bool, default False): Whether to apply the error-term approach in the forward pass.
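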
Source code in sae_lens/saes/topk_sae.py
def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
    """
    Args:
        cfg: SAEConfig defining model size and behavior.
        use_error_term: Whether to apply the error-term approach in the forward pass.
    """
    super().__init__(cfg, use_error_term)

decode(feature_acts)

Reconstructs the input from topk feature activations. Applies optional finetuning scaling, hooking to recons, out normalization, and optional head reshaping.

Source code in sae_lens/saes/topk_sae.py
def decode(
    self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_in"]:
    """
    Reconstructs the input from topk feature activations.
    Applies optional finetuning scaling, hooking to recons, out normalization,
    and optional head reshaping.
    """
    sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    return self.reshape_fn_out(sae_out_pre, self.d_head)

encode(x)

Converts input x into feature activations. Uses topk activation from the config (cfg.activation_fn == "topk") under the hood.

Source code in sae_lens/saes/topk_sae.py
def encode(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """
    Converts input x into feature activations.
    Uses topk activation from the config (cfg.activation_fn == "topk")
    under the hood.
    """
    sae_in = self.process_sae_in(x)
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
    # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
    return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

TopKSAEConfig dataclass

Bases: SAEConfig

Configuration class for a TopKSAE.

Source code in sae_lens/saes/topk_sae.py
@dataclass
class TopKSAEConfig(SAEConfig):
    """
    Configuration class for a TopKSAE.
    """

    k: int = 100

    @override
    @classmethod
    def architecture(cls) -> str:
        return "topk"

TopKTrainingSAE

Bases: TrainingSAE[TopKTrainingSAEConfig]

TopK variant with training functionality. Injects noise during training, optionally calculates a topk-related auxiliary loss, etc.

Source code in sae_lens/saes/topk_sae.py
class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
    """
    TopK variant with training functionality. Injects noise during training, optionally
    calculates a topk-related auxiliary loss, etc.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        super().initialize_weights()
        _init_weights_topk(self)

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        """
        Similar to the base training method: cast input, optionally add noise, then apply TopK.
        """
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

        # Inject noise if training
        if self.training and self.cfg.noise_scale > 0:
            hidden_pre_noised = (
                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
            )
        else:
            hidden_pre_noised = hidden_pre

        # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
        return feature_acts, hidden_pre_noised

    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        # Calculate the auxiliary loss for dead neurons
        topk_loss = self.calculate_topk_aux_loss(
            sae_in=step_input.sae_in,
            sae_out=sae_out,
            hidden_pre=hidden_pre,
            dead_neuron_mask=step_input.dead_neuron_mask,
        )
        return {"auxiliary_reconstruction_loss": topk_loss}

    @override
    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        return TopK(self.cfg.k)

    @override
    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
        return {}

    def calculate_topk_aux_loss(
        self,
        sae_in: torch.Tensor,
        sae_out: torch.Tensor,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """
        Calculate TopK auxiliary loss.

        This auxiliary loss encourages dead neurons to learn useful features by having
        them reconstruct the residual error from the live neurons. It's a key part of
        preventing neuron death in TopK SAEs.
        """
        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
        if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
            return sae_out.new_tensor(0.0)
        residual = (sae_in - sae_out).detach()

        # Heuristic from Appendix B.1 in the paper
        k_aux = sae_in.shape[-1] // 2

        # Reduce the scale of the loss if there are a small number of dead latents
        scale = min(num_dead / k_aux, 1.0)
        k_aux = min(k_aux, num_dead)

        auxk_acts = _calculate_topk_aux_acts(
            k_aux=k_aux,
            hidden_pre=hidden_pre,
            dead_neuron_mask=dead_neuron_mask,
        )

        # Encourage the top ~50% of dead latents to predict the residual of the
        # top k living latents
        recons = self.decode(auxk_acts)
        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
        return scale * auxk_loss

    def _calculate_topk_aux_acts(
        self,
        k_aux: int,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Helper method to calculate activations for the auxiliary loss.

        Args:
            k_aux: Number of top dead neurons to select
            hidden_pre: Pre-activation values from encoder
            dead_neuron_mask: Boolean mask indicating which neurons are dead

        Returns:
            Tensor with activations for only the top-k dead neurons, zeros elsewhere
        """
        # Don't include living latents in this loss (set them to -inf so they won't be selected)
        auxk_latents = torch.where(
            dead_neuron_mask[None],
            hidden_pre,
            torch.tensor(-float("inf"), device=hidden_pre.device),
        )

        # Find topk values among dead neurons
        auxk_topk = auxk_latents.topk(k_aux, dim=-1, sorted=False)

        # Create a tensor of zeros, then place the topk values at their proper indices
        auxk_acts = torch.zeros_like(hidden_pre)
        auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)

        return auxk_acts

    def to_inference_config_dict(self) -> dict[str, Any]:
        return filter_valid_dataclass_fields(
            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
        )

calculate_topk_aux_loss(sae_in, sae_out, hidden_pre, dead_neuron_mask)

Calculate TopK auxiliary loss.

This auxiliary loss encourages dead neurons to learn useful features by having them reconstruct the residual error from the live neurons. It's a key part of preventing neuron death in TopK SAEs.

Source code in sae_lens/saes/topk_sae.py
def calculate_topk_aux_loss(
    self,
    sae_in: torch.Tensor,
    sae_out: torch.Tensor,
    hidden_pre: torch.Tensor,
    dead_neuron_mask: torch.Tensor | None,
) -> torch.Tensor:
    """
    Calculate TopK auxiliary loss.

    This auxiliary loss encourages dead neurons to learn useful features by having
    them reconstruct the residual error from the live neurons. It's a key part of
    preventing neuron death in TopK SAEs.
    """
    # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
    # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
    if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
        return sae_out.new_tensor(0.0)
    residual = (sae_in - sae_out).detach()

    # Heuristic from Appendix B.1 in the paper
    k_aux = sae_in.shape[-1] // 2

    # Reduce the scale of the loss if there are a small number of dead latents
    scale = min(num_dead / k_aux, 1.0)
    k_aux = min(k_aux, num_dead)

    auxk_acts = _calculate_topk_aux_acts(
        k_aux=k_aux,
        hidden_pre=hidden_pre,
        dead_neuron_mask=dead_neuron_mask,
    )

    # Encourage the top ~50% of dead latents to predict the residual of the
    # top k living latents
    recons = self.decode(auxk_acts)
    auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
    return scale * auxk_loss

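The helper _calculate_topk_aux_acts (shown in the class source above) selects the k_aux most active dead latents by masking live latents to -inf before taking the topk. A toy illustration of that selection:

import torch

hidden_pre = torch.tensor([[0.9, 0.1, 0.5, 0.3]])
dead_neuron_mask = torch.tensor([False, True, True, False])
k_aux = 1

# Live latents are set to -inf so they can never be selected.
auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, torch.tensor(-float("inf")))
auxk_topk = auxk_latents.topk(k_aux, dim=-1, sorted=False)

auxk_acts = torch.zeros_like(hidden_pre)
auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
print(auxk_acts)  # tensor([[0.0000, 0.0000, 0.5000, 0.0000]])
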
encode_with_hidden_pre(x)

Similar to the base training method: cast input, optionally add noise, then apply TopK.

Source code in sae_lens/saes/topk_sae.py
def encode_with_hidden_pre(
    self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
    """
    Similar to the base training method: cast input, optionally add noise, then apply TopK.
    """
    sae_in = self.process_sae_in(x)
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

    # Inject noise if training
    if self.training and self.cfg.noise_scale > 0:
        hidden_pre_noised = (
            hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
        )
    else:
        hidden_pre_noised = hidden_pre

    # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
    feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
    return feature_acts, hidden_pre_noised

TopKTrainingSAEConfig dataclass

Bases: TrainingSAEConfig

Configuration class for training a TopKTrainingSAE.

Source code in sae_lens/saes/topk_sae.py
@dataclass
class TopKTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a TopKTrainingSAE.
    """

    k: int = 100

    @override
    @classmethod
    def architecture(cls) -> str:
        return "topk"

TrainingSAE

Bases: SAE[T_TRAINING_SAE_CONFIG], ABC

Abstract base class for training versions of SAEs.

Source code in sae_lens/saes/sae.py
class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
    """Abstract base class for training versions of SAEs."""

    def __init__(self, cfg: T_TRAINING_SAE_CONFIG, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

        # Turn off hook_z reshaping for training mode - the activation store
        # is expected to handle reshaping before passing data to the SAE
        self.turn_off_forward_pass_hook_z_reshaping()
        self.mse_loss_fn = self._get_mse_loss_fn()

    @abstractmethod
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...

    @abstractmethod
    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        """Encode with access to pre-activation values for training."""
        ...

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        For inference, just encode without returning hidden_pre.
        (training_forward_pass calls encode_with_hidden_pre).
        """
        feature_acts, _ = self.encode_with_hidden_pre(x)
        return feature_acts

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Decodes feature activations back into input space,
        applying optional finetuning scale, hooking, out normalization, etc.
        """
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @override
    def initialize_weights(self):
        super().initialize_weights()
        if self.cfg.decoder_init_norm is not None:
            with torch.no_grad():
                self.W_dec.data /= self.W_dec.norm(dim=-1, keepdim=True)
                self.W_dec.data *= self.cfg.decoder_init_norm
            self.W_enc.data = self.W_dec.data.T.clone().detach().contiguous()

    @abstractmethod
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Calculate architecture-specific auxiliary loss terms."""
        ...

    def training_forward_pass(
        self,
        step_input: TrainStepInput,
    ) -> TrainStepOutput:
        """Forward pass during training."""
        feature_acts, hidden_pre = self.encode_with_hidden_pre(step_input.sae_in)
        sae_out = self.decode(feature_acts)

        # Calculate MSE loss
        per_item_mse_loss = self.mse_loss_fn(sae_out, step_input.sae_in)
        mse_loss = per_item_mse_loss.sum(dim=-1).mean()

        # Calculate architecture-specific auxiliary losses
        aux_losses = self.calculate_aux_loss(
            step_input=step_input,
            feature_acts=feature_acts,
            hidden_pre=hidden_pre,
            sae_out=sae_out,
        )

        # Total loss is MSE plus all auxiliary losses
        total_loss = mse_loss

        # Create losses dictionary with mse_loss
        losses = {"mse_loss": mse_loss}

        # Add architecture-specific losses to the dictionary
        # Make sure aux_losses is a dictionary with string keys and tensor values
        if isinstance(aux_losses, dict):
            losses.update(aux_losses)

        # Sum all losses for total_loss
        if isinstance(aux_losses, dict):
            for loss_value in aux_losses.values():
                total_loss = total_loss + loss_value
        else:
            # Handle case where aux_losses is a tensor
            total_loss = total_loss + aux_losses

        return TrainStepOutput(
            sae_in=step_input.sae_in,
            sae_out=sae_out,
            feature_acts=feature_acts,
            hidden_pre=hidden_pre,
            loss=total_loss,
            losses=losses,
        )

    def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
        """Save inference version of model weights and config to disk."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Generate the weights
        state_dict = self.state_dict()  # Use internal SAE state dict
        self.process_state_dict_for_saving_inference(state_dict)
        model_weights_path = path / SAE_WEIGHTS_FILENAME
        save_file(state_dict, model_weights_path)

        # Save the config
        config = self.to_inference_config_dict()
        cfg_path = path / SAE_CFG_FILENAME
        with open(cfg_path, "w") as f:
            json.dump(config, f)

        return model_weights_path, cfg_path

    @abstractmethod
    def to_inference_config_dict(self) -> dict[str, Any]:
        """Convert the config into an inference SAE config dict."""
        ...

    def process_state_dict_for_saving_inference(
        self, state_dict: dict[str, Any]
    ) -> None:
        """
        Process the state dict for saving the inference model.
        This is a hook that can be overridden to change how the state dict is processed for the inference model.
        """
        return self.process_state_dict_for_saving(state_dict)

    def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Get the MSE loss function based on config."""

        def standard_mse_loss_fn(
            preds: torch.Tensor, target: torch.Tensor
        ) -> torch.Tensor:
            return torch.nn.functional.mse_loss(preds, target, reduction="none")

        def batch_norm_mse_loss_fn(
            preds: torch.Tensor, target: torch.Tensor
        ) -> torch.Tensor:
            target_centered = target - target.mean(dim=0, keepdim=True)
            normalization = target_centered.norm(dim=-1, keepdim=True)
            return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
                normalization + 1e-6
            )

        if self.cfg.mse_loss_normalization == "dense_batch":
            return batch_norm_mse_loss_fn
        return standard_mse_loss_fn

    @torch.no_grad()
    def remove_gradient_parallel_to_decoder_directions(self) -> None:
        """Remove gradient components parallel to decoder directions."""
        # Implement the original logic since this may not be in the base class
        assert self.W_dec.grad is not None

        parallel_component = einops.einsum(
            self.W_dec.grad,
            self.W_dec.data,
            "d_sae d_in, d_sae d_in -> d_sae",
        )
        self.W_dec.grad -= einops.einsum(
            parallel_component,
            self.W_dec.data,
            "d_sae, d_sae d_in -> d_sae d_in",
        )

    @torch.no_grad()
    def log_histograms(self) -> dict[str, NDArray[Any]]:
        """Log histograms of the weights and biases."""
        W_dec_norm_dist = self.W_dec.detach().float().norm(dim=1).cpu().numpy()
        return {
            "weights/W_dec_norms": W_dec_norm_dist,
        }

    @classmethod
    def get_sae_class_for_architecture(
        cls: Type[T_TRAINING_SAE], architecture: str
    ) -> Type[T_TRAINING_SAE]:
        """Get the SAE class for a given architecture."""
        sae_cls, _ = get_sae_training_class(architecture)
        if not issubclass(sae_cls, cls):
            raise ValueError(
                f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
            )
        return sae_cls

    # in the future, this can be used to load different config classes for different architectures
    @classmethod
    def get_sae_config_class_for_architecture(
        cls,
        architecture: str,  # noqa: ARG003
    ) -> type[TrainingSAEConfig]:
        return get_sae_training_class(architecture)[1]

calculate_aux_loss(step_input, feature_acts, hidden_pre, sae_out) abstractmethod

Calculate architecture-specific auxiliary loss terms.

Source code in sae_lens/saes/sae.py
@abstractmethod
def calculate_aux_loss(
    self,
    step_input: TrainStepInput,
    feature_acts: torch.Tensor,
    hidden_pre: torch.Tensor,
    sae_out: torch.Tensor,
) -> torch.Tensor | dict[str, torch.Tensor]:
    """Calculate architecture-specific auxiliary loss terms."""
    ...
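
As a concrete illustration, a minimal (hypothetical) subclass might return an L1 sparsity penalty as its only auxiliary term. `L1TrainingSAE` and `l1_coefficient` are illustrative names, not part of sae_lens, and the other abstract methods are omitted for brevity:

class L1TrainingSAE(TrainingSAE):
    l1_coefficient: float = 1e-3  # illustrative attribute, not part of the base class

    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        # mean over the batch of the per-item L1 norm of feature activations
        l1 = feature_acts.abs().sum(dim=-1).mean()
        return {"l1_loss": self.l1_coefficient * l1}

Returning a dict (rather than a bare tensor) means each term shows up under its own key in TrainStepOutput.losses while still being summed into the total loss.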

decode(feature_acts)

Decodes feature activations back into input space, applying optional finetuning scale, hooking, out normalization, etc.

Source code in sae_lens/saes/sae.py
def decode(
    self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_in"]:
    """
    Decodes feature activations back into input space,
    applying optional finetuning scale, hooking, out normalization, etc.
    """
    sae_out_pre = feature_acts @ self.W_dec + self.b_dec
    sae_out_pre = self.hook_sae_recons(sae_out_pre)
    sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
    return self.reshape_fn_out(sae_out_pre, self.d_head)
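
A minimal sketch of decoding a single feature, assuming `sae` is an already-initialized SAE and feature index 123 is arbitrary:

import torch

one_hot = torch.zeros(sae.cfg.d_sae, device=sae.W_dec.device, dtype=sae.W_dec.dtype)
one_hot[123] = 1.0
direction = sae.decode(one_hot)  # W_dec[123] + b_dec, before any output normalization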

encode(x)

For inference, just encode without returning hidden_pre (training_forward_pass uses encode_with_hidden_pre instead).

Source code in sae_lens/saes/sae.py
def encode(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """
    For inference, just encode without returning hidden_pre.
    (training_forward_pass calls encode_with_hidden_pre).
    """
    feature_acts, _ = self.encode_with_hidden_pre(x)
    return feature_acts
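
A minimal inference sketch, assuming `sae` is initialized and `acts` is a (batch, d_in) tensor of model activations taken from the hooked site:

import torch

with torch.no_grad():
    feature_acts = sae.encode(acts)            # (batch, d_sae)
    reconstruction = sae.decode(feature_acts)  # (batch, d_in)
    mse = (reconstruction - acts).pow(2).sum(dim=-1).mean()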

encode_with_hidden_pre(x) abstractmethod

Encode with access to pre-activation values for training.

Source code in sae_lens/saes/sae.py
@abstractmethod
def encode_with_hidden_pre(
    self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
    """Encode with access to pre-activation values for training."""
    ...
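
A hypothetical implementation for a plain ReLU SAE is sketched below; the concrete subclasses shipped with sae_lens may differ (for example by adding training noise or input normalization), and subtracting `b_dec` from the input is an assumption here:

def encode_with_hidden_pre(
    self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
    sae_in = x - self.b_dec                        # assumed pre-encoder bias subtraction
    hidden_pre = sae_in @ self.W_enc + self.b_enc  # pre-activation values
    feature_acts = torch.relu(hidden_pre)
    return feature_acts, hidden_pre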

get_sae_class_for_architecture(architecture) classmethod

Get the SAE class for a given architecture.

Source code in sae_lens/saes/sae.py
@classmethod
def get_sae_class_for_architecture(
    cls: Type[T_TRAINING_SAE], architecture: str
) -> Type[T_TRAINING_SAE]:
    """Get the SAE class for a given architecture."""
    sae_cls, _ = get_sae_training_class(architecture)
    if not issubclass(sae_cls, cls):
        raise ValueError(
            f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
        )
    return sae_cls
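
Usage sketch, assuming the "standard" architecture is registered:

sae_cls = TrainingSAE.get_sae_class_for_architecture("standard")
cfg_cls = TrainingSAE.get_sae_config_class_for_architecture("standard")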

log_histograms()

Log histograms of the weights and biases.

Source code in sae_lens/saes/sae.py
@torch.no_grad()
def log_histograms(self) -> dict[str, NDArray[Any]]:
    """Log histograms of the weights and biases."""
    W_dec_norm_dist = self.W_dec.detach().float().norm(dim=1).cpu().numpy()
    return {
        "weights/W_dec_norms": W_dec_norm_dist,
    }
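
The returned arrays can be handed to any logger. The snippet below is a sketch using Weights & Biases and assumes wandb is installed and wandb.init() has already been called:

import wandb

hists = sae.log_histograms()
wandb.log({name: wandb.Histogram(values) for name, values in hists.items()})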

process_state_dict_for_saving_inference(state_dict)

Process the state dict for saving the inference model. This is a hook that can be overridden to change how the state dict is processed for the inference model.

Source code in sae_lens/saes/sae.py
def process_state_dict_for_saving_inference(
    self, state_dict: dict[str, Any]
) -> None:
    """
    Process the state dict for saving the inference model.
    This is a hook that can be overridden to change how the state dict is processed for the inference model.
    """
    return self.process_state_dict_for_saving(state_dict)
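
A hypothetical override that drops a training-only buffer before the inference weights are written; `feature_act_ema` is an illustrative key, not a real sae_lens parameter name:

def process_state_dict_for_saving_inference(self, state_dict: dict[str, Any]) -> None:
    state_dict.pop("feature_act_ema", None)  # hypothetical training-only buffer
    return super().process_state_dict_for_saving_inference(state_dict)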

remove_gradient_parallel_to_decoder_directions()

Remove gradient components parallel to decoder directions.

Source code in sae_lens/saes/sae.py
@torch.no_grad()
def remove_gradient_parallel_to_decoder_directions(self) -> None:
    """Remove gradient components parallel to decoder directions."""
    # Implement the original logic since this may not be in the base class
    assert self.W_dec.grad is not None

    parallel_component = einops.einsum(
        self.W_dec.grad,
        self.W_dec.data,
        "d_sae d_in, d_sae d_in -> d_sae",
    )
    self.W_dec.grad -= einops.einsum(
        parallel_component,
        self.W_dec.data,
        "d_sae, d_sae d_in -> d_sae d_in",
    )
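
A minimal numerical sketch of what the projection does: after the call, each feature's gradient row has no component along its decoder row, exactly so when decoder rows have unit norm. The random gradient here stands in for one produced by loss.backward():

import torch
import einops

sae.W_dec.grad = torch.randn_like(sae.W_dec)
sae.remove_gradient_parallel_to_decoder_directions()
dots = einops.einsum(
    sae.W_dec.grad, sae.W_dec.data, "d_sae d_in, d_sae d_in -> d_sae"
)
print(dots.abs().max())  # ~0 when each decoder row has unit norm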

save_inference_model(path)

Save inference version of model weights and config to disk.

Source code in sae_lens/saes/sae.py
def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
    """Save inference version of model weights and config to disk."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # Generate the weights
    state_dict = self.state_dict()  # Use internal SAE state dict
    self.process_state_dict_for_saving_inference(state_dict)
    model_weights_path = path / SAE_WEIGHTS_FILENAME
    save_file(state_dict, model_weights_path)

    # Save the config
    config = self.to_inference_config_dict()
    cfg_path = path / SAE_CFG_FILENAME
    with open(cfg_path, "w") as f:
        json.dump(config, f)

    return model_weights_path, cfg_path
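
Usage sketch, assuming `sae` is a trained training-SAE instance:

weights_path, cfg_path = sae.save_inference_model("checkpoints/my_sae")
print(weights_path.name, cfg_path.name)  # safetensors weights and JSON config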

to_inference_config_dict() abstractmethod

Convert the config into an inference SAE config dict.

Source code in sae_lens/saes/sae.py
@abstractmethod
def to_inference_config_dict(self) -> dict[str, Any]:
    """Convert the config into an inference SAE config dict."""
    ...

training_forward_pass(step_input)

Forward pass during training.

Source code in sae_lens/saes/sae.py
def training_forward_pass(
    self,
    step_input: TrainStepInput,
) -> TrainStepOutput:
    """Forward pass during training."""
    feature_acts, hidden_pre = self.encode_with_hidden_pre(step_input.sae_in)
    sae_out = self.decode(feature_acts)

    # Calculate MSE loss
    per_item_mse_loss = self.mse_loss_fn(sae_out, step_input.sae_in)
    mse_loss = per_item_mse_loss.sum(dim=-1).mean()

    # Calculate architecture-specific auxiliary losses
    aux_losses = self.calculate_aux_loss(
        step_input=step_input,
        feature_acts=feature_acts,
        hidden_pre=hidden_pre,
        sae_out=sae_out,
    )

    # Total loss is MSE plus all auxiliary losses
    total_loss = mse_loss

    # Create losses dictionary with mse_loss
    losses = {"mse_loss": mse_loss}

    # Add architecture-specific losses to the dictionary
    # Make sure aux_losses is a dictionary with string keys and tensor values
    if isinstance(aux_losses, dict):
        losses.update(aux_losses)

    # Sum all losses for total_loss
    if isinstance(aux_losses, dict):
        for loss_value in aux_losses.values():
            total_loss = total_loss + loss_value
    else:
        # Handle case where aux_losses is a tensor
        total_loss = total_loss + aux_losses

    return TrainStepOutput(
        sae_in=step_input.sae_in,
        sae_out=sae_out,
        feature_acts=feature_acts,
        hidden_pre=hidden_pre,
        loss=total_loss,
        losses=losses,
    )
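
A sketch of a single optimization step. `step_input` is assumed to have been built upstream (e.g. by the trainer) and carries at least the activation batch `sae_in`; `optimizer` is any torch optimizer over sae.parameters():

output = sae.training_forward_pass(step_input)
output.loss.backward()                                # MSE plus all auxiliary losses
sae.remove_gradient_parallel_to_decoder_directions()  # optional decoder constraint
optimizer.step()
optimizer.zero_grad()
metrics = {name: value.item() for name, value in output.losses.items()}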

TrainingSAEConfig dataclass

Bases: SAEConfig, ABC

Source code in sae_lens/saes/sae.py
@dataclass(kw_only=True)
class TrainingSAEConfig(SAEConfig, ABC):
    noise_scale: float = 0.0
    mse_loss_normalization: str | None = None
    b_dec_init_method: Literal["zeros", "geometric_median", "mean"] = "zeros"
    # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
    # 0.1 corresponds to the "heuristic" initialization, use None to disable
    decoder_init_norm: float | None = 0.1

    @classmethod
    @abstractmethod
    def architecture(cls) -> str: ...

    @classmethod
    def from_sae_runner_config(
        cls: type[T_TRAINING_SAE_CONFIG],
        cfg: "LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]",
    ) -> T_TRAINING_SAE_CONFIG:
        metadata = SAEMetadata(
            model_name=cfg.model_name,
            hook_name=cfg.hook_name,
            hook_layer=cfg.hook_layer,
            hook_head_index=cfg.hook_head_index,
            context_size=cfg.context_size,
            prepend_bos=cfg.prepend_bos,
            seqpos_slice=cfg.seqpos_slice,
            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
        )
        if not isinstance(cfg.sae, cls):
            raise ValueError(
                f"SAE config class {cls} does not match SAE runner config class {type(cfg.sae)}"
            )
        return replace(cfg.sae, metadata=metadata)

    @classmethod
    def from_dict(
        cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
    ) -> T_TRAINING_SAE_CONFIG:
        # remove any keys that are not in the dataclass
        # since we sometimes enhance the config with the whole LM runner config
        valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
        cfg_class = cls
        if "architecture" in config_dict:
            cfg_class = get_sae_training_class(config_dict["architecture"])[1]
        if not issubclass(cfg_class, cls):
            raise ValueError(
                f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
            )
        if "metadata" in config_dict:
            valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
        return cfg_class(**valid_config_dict)

    def to_dict(self) -> dict[str, Any]:
        return {
            **super().to_dict(),
            **asdict(self),
            "architecture": self.architecture(),
        }

    # this needs to exist so we can initialize the parent sae cfg without the training specific
    # parameters. Maybe there's a cleaner way to do this
    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
        """
        Creates a dictionary containing attributes corresponding to the fields
        defined in the base SAEConfig class.
        """
        base_config_field_names = {f.name for f in fields(SAEConfig)}
        result_dict = {
            field_name: getattr(self, field_name)
            for field_name in base_config_field_names
        }
        result_dict["architecture"] = self.architecture()
        return result_dict
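
A sketch of round-tripping a config through a plain dict, assuming `sae` holds a concrete training SAE config: because to_dict() includes the "architecture" key, from_dict() dispatches to the registered concrete config class.

cfg_dict = sae.cfg.to_dict()
restored = TrainingSAEConfig.from_dict(cfg_dict)
assert type(restored) is type(sae.cfg)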

get_base_sae_cfg_dict()

Creates a dictionary containing attributes corresponding to the fields defined in the base SAEConfig class.

Source code in sae_lens/saes/sae.py
def get_base_sae_cfg_dict(self) -> dict[str, Any]:
    """
    Creates a dictionary containing attributes corresponding to the fields
    defined in the base SAEConfig class.
    """
    base_config_field_names = {f.name for f in fields(SAEConfig)}
    result_dict = {
        field_name: getattr(self, field_name)
        for field_name in base_config_field_names
    }
    result_dict["architecture"] = self.architecture()
    return result_dict
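
Usage sketch, assuming `training_cfg` is a concrete TrainingSAEConfig instance; the result contains only the base SAEConfig fields (d_in, d_sae, ...) plus the "architecture" key, with training-only fields stripped:

base_fields = training_cfg.get_base_sae_cfg_dict()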