
API

ActivationsStore

Class for streaming tokens and generating and storing activations while training SAEs.

Source code in sae_lens/training/activations_store.py
class ActivationsStore:
    """
    Class for streaming tokens and generating and storing activations
    while training SAEs.
    """

    model: HookedRootModule
    dataset: HfDataset
    cached_activations_path: str | None
    tokens_column: Literal["tokens", "input_ids", "text"]
    hook_name: str
    hook_layer: int
    hook_head_index: int | None
    _dataloader: Iterator[Any] | None = None
    _storage_buffer: torch.Tensor | None = None
    device: torch.device

    @classmethod
    def from_config(
        cls,
        model: HookedRootModule,
        cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig,
        dataset: HfDataset | None = None,
    ) -> "ActivationsStore":
        cached_activations_path = cfg.cached_activations_path
        # set cached_activations_path to None if we're not using cached activations
        if (
            isinstance(cfg, LanguageModelSAERunnerConfig)
            and not cfg.use_cached_activations
        ):
            cached_activations_path = None
        return cls(
            model=model,
            dataset=dataset or cfg.dataset_path,
            streaming=cfg.streaming,
            hook_name=cfg.hook_name,
            hook_layer=cfg.hook_layer,
            hook_head_index=cfg.hook_head_index,
            context_size=cfg.context_size,
            d_in=cfg.d_in,
            n_batches_in_buffer=cfg.n_batches_in_buffer,
            total_training_tokens=cfg.training_tokens,
            store_batch_size_prompts=cfg.store_batch_size_prompts,
            train_batch_size_tokens=cfg.train_batch_size_tokens,
            prepend_bos=cfg.prepend_bos,
            normalize_activations=cfg.normalize_activations,
            device=torch.device(cfg.act_store_device),
            dtype=cfg.dtype,
            cached_activations_path=cached_activations_path,
            model_kwargs=cfg.model_kwargs,
            autocast_lm=cfg.autocast_lm,
            dataset_trust_remote_code=cfg.dataset_trust_remote_code,
        )

    @classmethod
    def from_sae(
        cls,
        model: HookedRootModule,
        sae: SAE,
        streaming: bool = True,
        store_batch_size_prompts: int = 8,
        n_batches_in_buffer: int = 8,
        train_batch_size_tokens: int = 4096,
        total_tokens: int = 10**9,
        device: str = "cpu",
    ) -> "ActivationsStore":

        return cls(
            model=model,
            dataset=sae.cfg.dataset_path,
            d_in=sae.cfg.d_in,
            hook_name=sae.cfg.hook_name,
            hook_layer=sae.cfg.hook_layer,
            hook_head_index=sae.cfg.hook_head_index,
            context_size=sae.cfg.context_size,
            prepend_bos=sae.cfg.prepend_bos,
            streaming=streaming,
            store_batch_size_prompts=store_batch_size_prompts,
            train_batch_size_tokens=train_batch_size_tokens,
            n_batches_in_buffer=n_batches_in_buffer,
            total_training_tokens=total_tokens,
            normalize_activations=sae.cfg.normalize_activations,
            dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
            dtype=sae.cfg.dtype,
            device=torch.device(device),
        )

    def __init__(
        self,
        model: HookedRootModule,
        dataset: HfDataset | str,
        streaming: bool,
        hook_name: str,
        hook_layer: int,
        hook_head_index: int | None,
        context_size: int,
        d_in: int,
        n_batches_in_buffer: int,
        total_training_tokens: int,
        store_batch_size_prompts: int,
        train_batch_size_tokens: int,
        prepend_bos: bool,
        normalize_activations: str,
        device: torch.device,
        dtype: str,
        cached_activations_path: str | None = None,
        model_kwargs: dict[str, Any] | None = None,
        autocast_lm: bool = False,
        dataset_trust_remote_code: bool | None = None,
    ):
        self.model = model
        if model_kwargs is None:
            model_kwargs = {}
        self.model_kwargs = model_kwargs
        self.dataset = (
            load_dataset(
                dataset,
                split="train",
                streaming=streaming,
                trust_remote_code=dataset_trust_remote_code,  # type: ignore
            )
            if isinstance(dataset, str)
            else dataset
        )
        self.hook_name = hook_name
        self.hook_layer = hook_layer
        self.hook_head_index = hook_head_index
        self.context_size = context_size
        self.d_in = d_in
        self.n_batches_in_buffer = n_batches_in_buffer
        self.total_training_tokens = total_training_tokens
        self.store_batch_size_prompts = store_batch_size_prompts
        self.train_batch_size_tokens = train_batch_size_tokens
        self.prepend_bos = prepend_bos
        self.normalize_activations = normalize_activations
        self.device = torch.device(device)
        self.dtype = DTYPE_MAP[dtype]
        self.cached_activations_path = cached_activations_path
        self.autocast_lm = autocast_lm

        self.n_dataset_processed = 0
        self.iterable_dataset = iter(self.dataset)

        self.estimated_norm_scaling_factor = 1.0

        # Check if dataset is tokenized
        dataset_sample = next(self.iterable_dataset)

        # check if it's tokenized
        if "tokens" in dataset_sample.keys():
            self.is_dataset_tokenized = True
            self.tokens_column = "tokens"
        elif "input_ids" in dataset_sample.keys():
            self.is_dataset_tokenized = True
            self.tokens_column = "input_ids"
        elif "text" in dataset_sample.keys():
            self.is_dataset_tokenized = False
            self.tokens_column = "text"
        else:
            raise ValueError(
                "Dataset must have a 'tokens', 'input_ids', or 'text' column."
            )
        self.iterable_dataset = iter(self.dataset)  # Reset iterator after checking

        self.check_cached_activations_against_config()

        # TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF)

    def check_cached_activations_against_config(self):

        if self.cached_activations_path is not None:  # EDIT: load from multi-layer acts
            assert self.cached_activations_path is not None  # keep pyright happy
            # Sanity check: does the cache directory exist?
            assert os.path.exists(
                self.cached_activations_path
            ), f"Cache directory {self.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."

            self.next_cache_idx = 0  # which file to open next
            self.next_idx_within_buffer = 0  # where to start reading from in that file

            # Check that we have enough data on disk
            first_buffer = self.load_buffer(
                f"{self.cached_activations_path}/{self.next_cache_idx}.safetensors"
            )

            buffer_size_on_disk = first_buffer.shape[0]
            n_buffers_on_disk = len(os.listdir(self.cached_activations_path))

            # Note: we're assuming all files have the same number of tokens
            # (which seems reasonable imo since that's what our script does)
            n_activations_on_disk = buffer_size_on_disk * n_buffers_on_disk
            assert (
                n_activations_on_disk >= self.total_training_tokens
            ), f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but total_training_tokens is {self.total_training_tokens/1e6:.1f}M."

    def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
        return activations * self.estimated_norm_scaling_factor

    def unscale(self, activations: torch.Tensor) -> torch.Tensor:
        return activations / self.estimated_norm_scaling_factor

    def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
        return (self.d_in**0.5) / activations.norm(dim=-1).mean()

    @torch.no_grad()
    def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):

        norms_per_batch = []
        for _ in tqdm(
            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
        ):
            acts = self.next_batch()
            norms_per_batch.append(acts.norm(dim=-1).mean().item())
        mean_norm = np.mean(norms_per_batch)
        scaling_factor = np.sqrt(self.d_in) / mean_norm

        return scaling_factor

    @property
    def storage_buffer(self) -> torch.Tensor:
        if self._storage_buffer is None:
            self._storage_buffer = self.get_buffer(self.n_batches_in_buffer // 2)

        return self._storage_buffer

    @property
    def dataloader(self) -> Iterator[Any]:
        if self._dataloader is None:
            self._dataloader = self.get_data_loader()
        return self._dataloader

    def get_batch_tokens(self, batch_size: int | None = None):
        """
        Streams a batch of tokens from a dataset.
        """
        if not batch_size:
            batch_size = self.store_batch_size_prompts
        context_size = self.context_size
        device = self.device

        batch_tokens = torch.zeros(
            size=(0, context_size), device=device, dtype=torch.long, requires_grad=False
        )

        current_batch = []
        current_length = 0

        while batch_tokens.shape[0] < batch_size:
            tokens = self._get_next_dataset_tokens()
            token_len = tokens.shape[0]

            # TODO: Fix this so that we are limiting how many tokens we get from the same context.
            assert self.model.tokenizer is not None  # keep pyright happy
            while token_len > 0 and batch_tokens.shape[0] < batch_size:
                # Space left in the current batch
                space_left = context_size - current_length

                # If the current tokens fit entirely into the remaining space
                if token_len <= space_left:
                    current_batch.append(tokens[:token_len])
                    current_length += token_len
                    break

                else:
                    # Take as much as will fit
                    current_batch.append(tokens[:space_left])

                    # Remove used part, add BOS
                    tokens = tokens[space_left:]
                    token_len -= space_left

                    # only add BOS if it's not already the first token
                    if self.prepend_bos:
                        bos_token_id_tensor = torch.tensor(
                            [self.model.tokenizer.bos_token_id],
                            device=tokens.device,
                            dtype=torch.long,
                        )
                        if tokens[0] != bos_token_id_tensor:
                            tokens = torch.cat(
                                (
                                    bos_token_id_tensor,
                                    tokens,
                                ),
                                dim=0,
                            )
                            token_len += 1
                    current_length = context_size

                # If a batch is full, concatenate and move to next batch
                if current_length == context_size:
                    full_batch = torch.cat(current_batch, dim=0)
                    batch_tokens = torch.cat(
                        (batch_tokens, full_batch.unsqueeze(0)), dim=0
                    )
                    current_batch = []
                    current_length = 0

            # pbar.n = batch_tokens.shape[0]
            # pbar.refresh()
        return batch_tokens[:batch_size].to(self.model.W_E.device)

    @torch.no_grad()
    def get_activations(self, batch_tokens: torch.Tensor):
        """
        Returns activations of shape (batches, context, num_layers, d_in)

        d_in may result from a concatenated head dimension.
        """

        # Setup autocast if using
        if self.autocast_lm:
            autocast_if_enabled = torch.autocast(
                device_type="cuda",
                dtype=torch.bfloat16,
                enabled=self.autocast_lm,
            )
        else:
            autocast_if_enabled = contextlib.nullcontext()

        with autocast_if_enabled:
            layerwise_activations = self.model.run_with_cache(
                batch_tokens,
                names_filter=[self.hook_name],
                stop_at_layer=self.hook_layer + 1,
                prepend_bos=self.prepend_bos,
                **self.model_kwargs,
            )[1]

        n_batches, n_context = batch_tokens.shape

        stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))

        if self.hook_head_index is not None:
            stacked_activations[:, :, 0] = layerwise_activations[self.hook_name][
                :, :, self.hook_head_index
            ]
        elif (
            layerwise_activations[self.hook_name].ndim > 3
        ):  # if we have a head dimension
            stacked_activations[:, :, 0] = layerwise_activations[self.hook_name].view(
                n_batches, n_context, -1
            )
        else:
            stacked_activations[:, :, 0] = layerwise_activations[self.hook_name]

        return stacked_activations

    @torch.no_grad()
    def get_buffer(self, n_batches_in_buffer: int) -> torch.Tensor:
        context_size = self.context_size
        batch_size = self.store_batch_size_prompts
        d_in = self.d_in
        total_size = batch_size * n_batches_in_buffer
        num_layers = 1

        if self.cached_activations_path is not None:
            # Load the activations from disk
            buffer_size = total_size * context_size
            # Initialize an empty tensor with an additional dimension for layers
            new_buffer = torch.zeros(
                (buffer_size, num_layers, d_in),
                dtype=self.dtype,  # type: ignore
                device=self.device,
            )
            n_tokens_filled = 0

            # Assume activations for different layers are stored separately and need to be combined
            while n_tokens_filled < buffer_size:
                if not os.path.exists(
                    f"{self.cached_activations_path}/{self.next_cache_idx}.safetensors"
                ):
                    print(
                        "\n\nWarning: Ran out of cached activation files earlier than expected."
                    )
                    print(
                        f"Expected to have {buffer_size} activations, but only found {n_tokens_filled}."
                    )
                    if buffer_size % self.total_training_tokens != 0:
                        print(
                            "This might just be a rounding error — your batch_size * n_batches_in_buffer * context_size is not divisible by your total_training_tokens"
                        )
                    print(f"Returning a buffer of size {n_tokens_filled} instead.")
                    print("\n\n")
                    new_buffer = new_buffer[:n_tokens_filled, ...]
                    return new_buffer

                activations = self.load_buffer(
                    f"{self.cached_activations_path}/{self.next_cache_idx}.safetensors"
                )
                taking_subset_of_file = False
                if n_tokens_filled + activations.shape[0] > buffer_size:
                    activations = activations[: buffer_size - n_tokens_filled, ...]
                    taking_subset_of_file = True

                new_buffer[
                    n_tokens_filled : n_tokens_filled + activations.shape[0], ...
                ] = activations

                if taking_subset_of_file:
                    self.next_idx_within_buffer = activations.shape[0]
                else:
                    self.next_cache_idx += 1
                    self.next_idx_within_buffer = 0

                n_tokens_filled += activations.shape[0]

            return new_buffer

        refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
        # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
        new_buffer = torch.zeros(
            (total_size, context_size, num_layers, d_in),
            dtype=self.dtype,  # type: ignore
            device=self.device,
        )

        for refill_batch_idx_start in refill_iterator:
            # move batch toks to gpu for model
            refill_batch_tokens = self.get_batch_tokens().to(self.model.cfg.device)
            refill_activations = self.get_activations(refill_batch_tokens)
            # move acts back to cpu
            refill_activations.to(self.device)
            new_buffer[
                refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
            ] = refill_activations

            # pbar.update(1)

        new_buffer = new_buffer.reshape(-1, num_layers, d_in)
        new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]

        # every buffer should be normalized:
        if self.normalize_activations == "expected_average_only_in":
            new_buffer = self.apply_norm_scaling_factor(new_buffer)

        return new_buffer

    def save_buffer(self, buffer: torch.Tensor, path: str):
        """
        Used by cached activations runner to save a buffer to disk.
        For reuse by later workflows.
        """
        save_file({"activations": buffer}, path)

    def load_buffer(self, path: str) -> torch.Tensor:

        with safe_open(path, framework="pt", device=str(self.device)) as f:  # type: ignore
            buffer = f.get_tensor("activations")
        return buffer

    def get_data_loader(
        self,
    ) -> Iterator[Any]:
        """
        Return a torch.utils.dataloader which you can get batches from.

        Should automatically refill the buffer when it gets to n % full.
        (better mixing if you refill and shuffle regularly).

        """

        batch_size = self.train_batch_size_tokens

        # 1. # create new buffer by mixing stored and new buffer
        mixing_buffer = torch.cat(
            [self.get_buffer(self.n_batches_in_buffer // 2), self.storage_buffer],
            dim=0,
        )

        mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

        # 2.  put 50 % in storage
        self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]

        # 3. put other 50 % in a dataloader
        dataloader = iter(
            DataLoader(
                # TODO: seems like a typing bug?
                cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
                batch_size=batch_size,
                shuffle=True,
            )
        )

        return dataloader

    def next_batch(self):
        """
        Get the next batch from the current DataLoader.
        If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
        """
        try:
            # Try to get the next batch
            return next(self.dataloader)
        except StopIteration:
            # If the DataLoader is exhausted, create a new one
            self._dataloader = self.get_data_loader()
            return next(self.dataloader)

    def state_dict(self) -> dict[str, torch.Tensor]:
        result = {
            "n_dataset_processed": torch.tensor(self.n_dataset_processed),
        }
        if self._storage_buffer is not None:  # first time might be None
            result["storage_buffer"] = self._storage_buffer
        return result

    def save(self, file_path: str):
        save_file(self.state_dict(), file_path)

    def _get_next_dataset_tokens(self) -> torch.Tensor:
        device = self.device
        if not self.is_dataset_tokenized:
            s = next(self.iterable_dataset)[self.tokens_column]
            tokens = (
                self.model.to_tokens(
                    s,
                    truncate=False,
                    move_to_device=True,
                    prepend_bos=self.prepend_bos,
                )
                .squeeze(0)
                .to(device)
            )
            assert (
                len(tokens.shape) == 1
            ), f"tokens.shape should be 1D but was {tokens.shape}"
        else:
            tokens = torch.tensor(
                next(self.iterable_dataset)[self.tokens_column],
                dtype=torch.long,
                device=device,
                requires_grad=False,
            )
            if (
                not self.prepend_bos
                and tokens[0] == self.model.tokenizer.bos_token_id  # type: ignore
            ):
                tokens = tokens[1:]
        self.n_dataset_processed += 1
        return tokens
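
Example: a minimal usage sketch (not from the library docs) showing how an ActivationsStore might be built from a trained SAE and its base model. The model name, SAE release, and SAE id are placeholders, and the SAE.from_pretrained call and the root-level imports are assumptions about this version of sae_lens; adjust to your setup.

import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE, ActivationsStore  # assumed root exports; otherwise import ActivationsStore from sae_lens.training.activations_store

model = HookedTransformer.from_pretrained("gpt2")          # placeholder base model
sae, _, _ = SAE.from_pretrained(                           # assumed loader returning (sae, cfg_dict, sparsity)
    release="gpt2-small-res-jb", sae_id="blocks.8.hook_resid_pre"
)

store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    store_batch_size_prompts=8,
    n_batches_in_buffer=8,
    train_batch_size_tokens=4096,
    total_tokens=10_000_000,
    device="cpu",
)

acts = store.next_batch()   # shape: (train_batch_size_tokens, 1, d_in)
print(acts.shape)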

get_activations(batch_tokens)

Returns activations of shape (batches, context, num_layers, d_in)

d_in may result from a concatenated head dimension.

Source code in sae_lens/training/activations_store.py
@torch.no_grad()
def get_activations(self, batch_tokens: torch.Tensor):
    """
    Returns activations of shape (batches, context, num_layers, d_in)

    d_in may result from a concatenated head dimension.
    """

    # Setup autocast if using
    if self.autocast_lm:
        autocast_if_enabled = torch.autocast(
            device_type="cuda",
            dtype=torch.bfloat16,
            enabled=self.autocast_lm,
        )
    else:
        autocast_if_enabled = contextlib.nullcontext()

    with autocast_if_enabled:
        layerwise_activations = self.model.run_with_cache(
            batch_tokens,
            names_filter=[self.hook_name],
            stop_at_layer=self.hook_layer + 1,
            prepend_bos=self.prepend_bos,
            **self.model_kwargs,
        )[1]

    n_batches, n_context = batch_tokens.shape

    stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))

    if self.hook_head_index is not None:
        stacked_activations[:, :, 0] = layerwise_activations[self.hook_name][
            :, :, self.hook_head_index
        ]
    elif (
        layerwise_activations[self.hook_name].ndim > 3
    ):  # if we have a head dimension
        stacked_activations[:, :, 0] = layerwise_activations[self.hook_name].view(
            n_batches, n_context, -1
        )
    else:
        stacked_activations[:, :, 0] = layerwise_activations[self.hook_name]

    return stacked_activations
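
For illustration, a short sketch (continuing the example above, where store is an ActivationsStore) of how the token and activation shapes relate:

batch_tokens = store.get_batch_tokens(batch_size=4)   # (4, context_size), dtype long
acts = store.get_activations(batch_tokens)            # (4, context_size, 1, d_in)
assert acts.shape == (4, store.context_size, 1, store.d_in)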

get_batch_tokens(batch_size=None)

Streams a batch of tokens from a dataset.

Source code in sae_lens/training/activations_store.py
def get_batch_tokens(self, batch_size: int | None = None):
    """
    Streams a batch of tokens from a dataset.
    """
    if not batch_size:
        batch_size = self.store_batch_size_prompts
    context_size = self.context_size
    device = self.device

    batch_tokens = torch.zeros(
        size=(0, context_size), device=device, dtype=torch.long, requires_grad=False
    )

    current_batch = []
    current_length = 0

    while batch_tokens.shape[0] < batch_size:
        tokens = self._get_next_dataset_tokens()
        token_len = tokens.shape[0]

        # TODO: Fix this so that we are limiting how many tokens we get from the same context.
        assert self.model.tokenizer is not None  # keep pyright happy
        while token_len > 0 and batch_tokens.shape[0] < batch_size:
            # Space left in the current batch
            space_left = context_size - current_length

            # If the current tokens fit entirely into the remaining space
            if token_len <= space_left:
                current_batch.append(tokens[:token_len])
                current_length += token_len
                break

            else:
                # Take as much as will fit
                current_batch.append(tokens[:space_left])

                # Remove used part, add BOS
                tokens = tokens[space_left:]
                token_len -= space_left

                # only add BOS if it's not already the first token
                if self.prepend_bos:
                    bos_token_id_tensor = torch.tensor(
                        [self.model.tokenizer.bos_token_id],
                        device=tokens.device,
                        dtype=torch.long,
                    )
                    if tokens[0] != bos_token_id_tensor:
                        tokens = torch.cat(
                            (
                                bos_token_id_tensor,
                                tokens,
                            ),
                            dim=0,
                        )
                        token_len += 1
                current_length = context_size

            # If a batch is full, concatenate and move to next batch
            if current_length == context_size:
                full_batch = torch.cat(current_batch, dim=0)
                batch_tokens = torch.cat(
                    (batch_tokens, full_batch.unsqueeze(0)), dim=0
                )
                current_batch = []
                current_length = 0

        # pbar.n = batch_tokens.shape[0]
        # pbar.refresh()
    return batch_tokens[:batch_size].to(self.model.W_E.device)

get_data_loader()

Return a torch.utils.dataloader which you can get batches from.

Should automatically refill the buffer when it gets to n % full. (better mixing if you refill and shuffle regularly).

Source code in sae_lens/training/activations_store.py
def get_data_loader(
    self,
) -> Iterator[Any]:
    """
    Return a torch.utils.dataloader which you can get batches from.

    Should automatically refill the buffer when it gets to n % full.
    (better mixing if you refill and shuffle regularly).

    """

    batch_size = self.train_batch_size_tokens

    # 1. # create new buffer by mixing stored and new buffer
    mixing_buffer = torch.cat(
        [self.get_buffer(self.n_batches_in_buffer // 2), self.storage_buffer],
        dim=0,
    )

    mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

    # 2.  put 50 % in storage
    self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]

    # 3. put other 50 % in a dataloader
    dataloader = iter(
        DataLoader(
            # TODO: seems like a typing bug?
            cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
            batch_size=batch_size,
            shuffle=True,
        )
    )

    return dataloader

next_batch()

Get the next batch from the current DataLoader. If the DataLoader is exhausted, refill the buffer and create a new DataLoader.

Source code in sae_lens/training/activations_store.py
def next_batch(self):
    """
    Get the next batch from the current DataLoader.
    If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
    """
    try:
        # Try to get the next batch
        return next(self.dataloader)
    except StopIteration:
        # If the DataLoader is exhausted, create a new one
        self._dataloader = self.get_data_loader()
        return next(self.dataloader)
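
A rough training-loop sketch of consuming batches (continuing the example above). train_step is a hypothetical helper standing in for an SAE forward pass, loss, and optimizer step; it is not part of sae_lens.

n_steps = store.total_training_tokens // store.train_batch_size_tokens
for step in range(n_steps):
    acts = store.next_batch()   # (train_batch_size_tokens, 1, d_in); refills the buffer as needed
    acts = acts.squeeze(1)      # -> (train_batch_size_tokens, d_in)
    loss = train_step(acts)     # hypothetical: SAE forward + loss + optimizer step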

save_buffer(buffer, path)

Used by cached activations runner to save a buffer to disk. For reuse by later workflows.

Source code in sae_lens/training/activations_store.py
def save_buffer(self, buffer: torch.Tensor, path: str):
    """
    Used by cached activations runner to save a buffer to disk.
    For reuse by later workflows.
    """
    save_file({"activations": buffer}, path)
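
A round-trip sketch (continuing the example above; the path is a placeholder and its directory is assumed to exist): buffers are written and read back as safetensors files containing a single "activations" tensor.

import torch

buffer = store.get_buffer(4)                              # (n_tokens, 1, d_in) on store.device
store.save_buffer(buffer, "activations/0.safetensors")    # directory must already exist
reloaded = store.load_buffer("activations/0.safetensors")
assert torch.equal(buffer.cpu(), reloaded.cpu())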

CacheActivationsRunner

Source code in sae_lens/cache_activations_runner.py
class CacheActivationsRunner:

    def __init__(self, cfg: CacheActivationsRunnerConfig):
        self.cfg = cfg
        self.model = load_model(
            model_class_name=cfg.model_class_name,
            model_name=cfg.model_name,
            device=cfg.device,
            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs,
        )
        self.activations_store = ActivationsStore.from_config(
            self.model,
            cfg,
        )

        self.file_extension = "safetensors"

    def __str__(self):
        """
        Print the number of tokens to be cached.
        Print the number of buffers, and the number of tokens per buffer.
        Print the disk space required to store the activations.

        """

        bytes_per_token = (
            self.cfg.d_in * self.cfg.dtype.itemsize
            if isinstance(self.cfg.dtype, torch.dtype)
            else DTYPE_MAP[self.cfg.dtype].itemsize
        )
        tokens_in_buffer = (
            self.cfg.n_batches_in_buffer
            * self.cfg.store_batch_size_prompts
            * self.cfg.context_size
        )
        total_training_tokens = self.cfg.training_tokens
        total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

        return (
            f"Activation Cache Runner:\n"
            f"Total training tokens: {total_training_tokens}\n"
            f"Number of buffers: {math.ceil(total_training_tokens / tokens_in_buffer)}\n"
            f"Tokens per buffer: {tokens_in_buffer}\n"
            f"Disk space required: {total_disk_space_gb:.2f} GB\n"
            f"Configuration:\n"
            f"{self.cfg}"
        )

    @torch.no_grad()
    def run(self):

        new_cached_activations_path = self.cfg.new_cached_activations_path

        # if the activations directory exists and has files in it, raise an exception
        assert new_cached_activations_path is not None
        if os.path.exists(new_cached_activations_path):
            if len(os.listdir(new_cached_activations_path)) > 0:
                raise Exception(
                    f"Activations directory ({new_cached_activations_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files."
                )
        else:
            os.makedirs(new_cached_activations_path)

        print(f"Started caching {self.cfg.training_tokens} activations")
        tokens_per_buffer = (
            self.cfg.store_batch_size_prompts
            * self.cfg.context_size
            * self.cfg.n_batches_in_buffer
        )

        n_buffers = math.ceil(self.cfg.training_tokens / tokens_per_buffer)

        for i in tqdm(range(n_buffers), desc="Caching activations"):
            buffer = self.activations_store.get_buffer(self.cfg.n_batches_in_buffer)
            self.activations_store.save_buffer(
                buffer, f"{new_cached_activations_path}/{i}.safetensors"
            )

            del buffer

            if i % self.cfg.shuffle_every_n_buffers == 0 and i > 0:
                # Shuffle the buffers on disk

                # Do random pairwise shuffling between the last shuffle_every_n_buffers buffers
                for _ in range(self.cfg.n_shuffles_with_last_section):
                    self.shuffle_activations_pairwise(
                        new_cached_activations_path,
                        buffer_idx_range=(i - self.cfg.shuffle_every_n_buffers, i),
                    )

                # Do more random pairwise shuffling between all the buffers
                for _ in range(self.cfg.n_shuffles_in_entire_dir):
                    self.shuffle_activations_pairwise(
                        new_cached_activations_path,
                        buffer_idx_range=(0, i),
                    )

        # More final shuffling (mostly in case we didn't end on an i divisible by shuffle_every_n_buffers)
        if n_buffers > 1:
            for _ in tqdm(range(self.cfg.n_shuffles_final), desc="Final shuffling"):
                self.shuffle_activations_pairwise(
                    new_cached_activations_path,
                    buffer_idx_range=(0, n_buffers),
                )

    @torch.no_grad()
    def shuffle_activations_pairwise(
        self, datapath: str, buffer_idx_range: Tuple[int, int]
    ):
        """
        Shuffles two buffers on disk.
        """
        assert (
            buffer_idx_range[0] < buffer_idx_range[1] - 1
        ), "buffer_idx_range[0] must be smaller than buffer_idx_range[1] by at least 1"

        buffer_idx1 = torch.randint(
            buffer_idx_range[0], buffer_idx_range[1], (1,)
        ).item()
        buffer_idx2 = torch.randint(
            buffer_idx_range[0], buffer_idx_range[1], (1,)
        ).item()
        while buffer_idx1 == buffer_idx2:  # Make sure they're not the same
            buffer_idx2 = torch.randint(
                buffer_idx_range[0], buffer_idx_range[1], (1,)
            ).item()

        buffer1 = self.activations_store.load_buffer(
            f"{datapath}/{buffer_idx1}.{self.file_extension}"
        )
        buffer2 = self.activations_store.load_buffer(
            f"{datapath}/{buffer_idx2}.{self.file_extension}"
        )
        joint_buffer = torch.cat([buffer1, buffer2])

        # Shuffle them
        joint_buffer = joint_buffer[torch.randperm(joint_buffer.shape[0])]
        shuffled_buffer1 = joint_buffer[: buffer1.shape[0]]
        shuffled_buffer2 = joint_buffer[buffer1.shape[0] :]

        # Save them back
        self.activations_store.save_buffer(
            shuffled_buffer1, f"{datapath}/{buffer_idx1}.{self.file_extension}"
        )
        self.activations_store.save_buffer(
            shuffled_buffer2, f"{datapath}/{buffer_idx2}.{self.file_extension}"
        )
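
A minimal sketch of caching activations to disk. The model name, hook point, dataset, and output path are placeholders, and the root-level imports are an assumption; only config fields shown in CacheActivationsRunnerConfig below are used.

from sae_lens import CacheActivationsRunner, CacheActivationsRunnerConfig  # assumed root exports

cfg = CacheActivationsRunnerConfig(
    model_name="gpt2",                       # placeholder model
    hook_name="blocks.8.hook_resid_pre",     # placeholder hook point
    hook_layer=8,
    d_in=768,
    dataset_path="NeelNanda/c4-tokenized-2b",
    context_size=128,
    training_tokens=2_000_000,
    new_cached_activations_path="activations/gpt2_blocks.8.hook_resid_pre",  # placeholder path
    device="cuda",
)

runner = CacheActivationsRunner(cfg)
print(runner)   # __str__ reports token count, number of buffers, and estimated disk usage
runner.run()    # writes 0.safetensors, 1.safetensors, ... and shuffles buffers pairwise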

__str__()

Print the number of tokens to be cached. Print the number of buffers, and the number of tokens per buffer. Print the disk space required to store the activations.

Source code in sae_lens/cache_activations_runner.py
def __str__(self):
    """
    Print the number of tokens to be cached.
    Print the number of buffers, and the number of tokens per buffer.
    Print the disk space required to store the activations.

    """

    bytes_per_token = (
        self.cfg.d_in * self.cfg.dtype.itemsize
        if isinstance(self.cfg.dtype, torch.dtype)
        else DTYPE_MAP[self.cfg.dtype].itemsize
    )
    tokens_in_buffer = (
        self.cfg.n_batches_in_buffer
        * self.cfg.store_batch_size_prompts
        * self.cfg.context_size
    )
    total_training_tokens = self.cfg.training_tokens
    total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

    return (
        f"Activation Cache Runner:\n"
        f"Total training tokens: {total_training_tokens}\n"
        f"Number of buffers: {math.ceil(total_training_tokens / tokens_in_buffer)}\n"
        f"Tokens per buffer: {tokens_in_buffer}\n"
        f"Disk space required: {total_disk_space_gb:.2f} GB\n"
        f"Configuration:\n"
        f"{self.cfg}"
    )
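
As a worked example of the size estimate above (values are illustrative): with d_in = 768, float32 activations (4 bytes each), and 2,000,000 training tokens:

bytes_per_token = 768 * 4                       # d_in * dtype.itemsize for float32
total_gb = 2_000_000 * bytes_per_token / 10**9  # training_tokens * bytes_per_token
print(f"{total_gb:.2f} GB")                     # 6.14 GB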

shuffle_activations_pairwise(datapath, buffer_idx_range)

Shuffles two buffers on disk.

Source code in sae_lens/cache_activations_runner.py
@torch.no_grad()
def shuffle_activations_pairwise(
    self, datapath: str, buffer_idx_range: Tuple[int, int]
):
    """
    Shuffles two buffers on disk.
    """
    assert (
        buffer_idx_range[0] < buffer_idx_range[1] - 1
    ), "buffer_idx_range[0] must be smaller than buffer_idx_range[1] by at least 1"

    buffer_idx1 = torch.randint(
        buffer_idx_range[0], buffer_idx_range[1], (1,)
    ).item()
    buffer_idx2 = torch.randint(
        buffer_idx_range[0], buffer_idx_range[1], (1,)
    ).item()
    while buffer_idx1 == buffer_idx2:  # Make sure they're not the same
        buffer_idx2 = torch.randint(
            buffer_idx_range[0], buffer_idx_range[1], (1,)
        ).item()

    buffer1 = self.activations_store.load_buffer(
        f"{datapath}/{buffer_idx1}.{self.file_extension}"
    )
    buffer2 = self.activations_store.load_buffer(
        f"{datapath}/{buffer_idx2}.{self.file_extension}"
    )
    joint_buffer = torch.cat([buffer1, buffer2])

    # Shuffle them
    joint_buffer = joint_buffer[torch.randperm(joint_buffer.shape[0])]
    shuffled_buffer1 = joint_buffer[: buffer1.shape[0]]
    shuffled_buffer2 = joint_buffer[buffer1.shape[0] :]

    # Save them back
    self.activations_store.save_buffer(
        shuffled_buffer1, f"{datapath}/{buffer_idx1}.{self.file_extension}"
    )
    self.activations_store.save_buffer(
        shuffled_buffer2, f"{datapath}/{buffer_idx2}.{self.file_extension}"
    )

CacheActivationsRunnerConfig dataclass

Configuration for caching activations of an LLM.

Source code in sae_lens/config.py
@dataclass
class CacheActivationsRunnerConfig:
    """
    Configuration for caching activations of an LLM.
    """

    # Data Generating Function (Model + Training Distribution)
    model_name: str = "gelu-2l"
    model_class_name: str = "HookedTransformer"
    hook_name: str = "blocks.{layer}.hook_mlp_out"
    hook_layer: int = 0
    hook_head_index: Optional[int] = None
    dataset_path: str = "NeelNanda/c4-tokenized-2b"
    dataset_trust_remote_code: bool | None = None
    streaming: bool = True
    is_dataset_tokenized: bool = True
    context_size: int = 128
    new_cached_activations_path: Optional[str] = (
        None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}"
    )
    # don't specify this since you don't want to load from disk with the cache runner.
    cached_activations_path: Optional[str] = None
    # SAE Parameters
    d_in: int = 512

    # Activation Store Parameters
    n_batches_in_buffer: int = 20
    training_tokens: int = 2_000_000
    store_batch_size_prompts: int = 32
    train_batch_size_tokens: int = 4096
    normalize_activations: str = "none"  # should always be none for activation caching

    # Misc
    device: str = "cpu"
    act_store_device: str = "with_model"  # will be set by post init if with_model
    seed: int = 42
    dtype: str = "float32"
    prepend_bos: bool = True
    autocast_lm: bool = False  # autocast lm during activation fetching

    # Activation caching stuff
    shuffle_every_n_buffers: int = 10
    n_shuffles_with_last_section: int = 10
    n_shuffles_in_entire_dir: int = 10
    n_shuffles_final: int = 100
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Autofill cached_activations_path unless the user overrode it
        if self.new_cached_activations_path is None:
            self.new_cached_activations_path = _default_cached_activations_path(
                self.dataset_path,
                self.model_name,
                self.hook_name,
                self.hook_head_index,
            )

        if self.act_store_device == "with_model":
            self.act_store_device = self.device

HookedSAETransformer

Bases: HookedTransformer

Source code in sae_lens/analysis/hooked_sae_transformer.py
class HookedSAETransformer(HookedTransformer):

    def __init__(
        self,
        *model_args: Any,
        **model_kwargs: Any,
    ):
        """Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

        Note that if you want to load the model from pretrained weights, you should use
        :meth:`from_pretrained` instead.

        Args:
            *model_args: Positional arguments for HookedTransformer initialization
            **model_kwargs: Keyword arguments for HookedTransformer initialization
        """
        super().__init__(*model_args, **model_kwargs)
        self.acts_to_saes: Dict[str, SAE] = {}

    def add_sae(self, sae: SAE):
        """Attaches an SAE to the model

        WARNING: This sae will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.

        Args:
            sae: SparseAutoencoderBase. The SAE to attach to the model
        """
        act_name = sae.cfg.hook_name
        if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
            logging.warning(
                f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
            )
            return

        self.acts_to_saes[act_name] = sae
        set_deep_attr(self, act_name, sae)
        self.setup()

    def _reset_sae(self, act_name: str, prev_sae: Optional[SAE] = None):
        """Resets an SAE that was attached to the model

        By default will remove the SAE from that hook_point.
        If prev_sae is provided, will replace the current SAE with the provided one.
        This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes)

        Args:
            act_name: str. The hook_name of the SAE to reset
            prev_sae: Optional[HookedSAE]. The SAE to replace the current one with. If None, will just remove the SAE from this hook point. Defaults to None
        """
        if act_name not in self.acts_to_saes:
            logging.warning(
                f"No SAE is attached to {act_name}. There's nothing to reset."
            )
            return

        if prev_sae:
            set_deep_attr(self, act_name, prev_sae)
            self.acts_to_saes[act_name] = prev_sae
        else:
            set_deep_attr(self, act_name, HookPoint())
            del self.acts_to_saes[act_name]

    def reset_saes(
        self,
        act_names: Optional[Union[str, List[str]]] = None,
        prev_saes: Optional[List[Union[SAE, None]]] = None,
    ):
        """Reset the SAEs attached to the model

        If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.
        Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).

        Args:
            act_names (Optional[Union[str, List[str]]): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
            prev_saes (Optional[List[Union[HookedSAE, None]]]): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.
        """
        if isinstance(act_names, str):
            act_names = [act_names]
        elif act_names is None:
            act_names = list(self.acts_to_saes.keys())

        if prev_saes:
            assert len(act_names) == len(
                prev_saes
            ), "act_names and prev_saes must have the same length"
        else:
            prev_saes = [None] * len(act_names)  # type: ignore

        for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore
            self._reset_sae(act_name, prev_sae)

        self.setup()

    def run_with_saes(
        self,
        *model_args: Any,
        saes: Union[SAE, List[SAE]] = [],
        reset_saes_end: bool = True,
        **model_kwargs: Any,
    ) -> Union[
        None,
        Float[torch.Tensor, "batch pos d_vocab"],
        Loss,
        Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
    ]:
        """Wrapper around HookedTransformer forward pass.

        Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
            reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
            **model_kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(saes=saes, reset_saes_end=reset_saes_end):
            return self(*model_args, **model_kwargs)

    def run_with_cache_with_saes(
        self,
        *model_args: Any,
        saes: Union[SAE, List[SAE]] = [],
        reset_saes_end: bool = True,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs: Any,
    ) -> Tuple[
        Union[
            None,
            Float[torch.Tensor, "batch pos d_vocab"],
            Loss,
            Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
        ],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """Wrapper around 'run_with_cache' in HookedTransformer.

        Attaches given SAEs before running the model with cache and then removes them.
        By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
            reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
            return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of
                useful HookedTransformer specific methods, otherwise it will return a dictionary of
                activations as in HookedRootModule.
            remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            **kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(saes=saes, reset_saes_end=reset_saes_end):
            return self.run_with_cache(  # type: ignore
                *model_args,
                return_cache_object=return_cache_object,  # type: ignore
                remove_batch_dim=remove_batch_dim,
                **kwargs,
            )

    def run_with_hooks_with_saes(
        self,
        *model_args: Any,
        saes: Union[SAE, List[SAE]] = [],
        reset_saes_end: bool = True,
        fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],  # type: ignore
        bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],  # type: ignore
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """Wrapper around 'run_with_hooks' in HookedTransformer.

        Attaches the given SAEs to the model before running the model with hooks and then removes them.
        By default, will reset all SAEs to original state after.

        Args:
            *model_args: Positional arguments for the model forward pass
            act_names: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
            reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)
            fwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of forward hooks to apply
            bwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of backward hooks to apply
            reset_hooks_end: (bool) Whether to reset the hooks at the end of the forward pass (default: True)
            clear_contexts: (bool) Whether to clear the contexts at the end of the forward pass (default: False)
            **model_kwargs: Keyword arguments for the model forward pass
        """
        with self.saes(saes=saes, reset_saes_end=reset_saes_end):
            return self.run_with_hooks(
                *model_args,
                fwd_hooks=fwd_hooks,
                bwd_hooks=bwd_hooks,
                reset_hooks_end=reset_hooks_end,
                clear_contexts=clear_contexts,
                **model_kwargs,
            )

    @contextmanager
    def saes(
        self,
        saes: Union[SAE, List[SAE]] = [],
        reset_saes_end: bool = True,
    ):
        """
        A context manager for adding temporary SAEs to the model.
        See HookedTransformer.hooks for a similar context manager for hooks.
        By default will keep track of previously attached SAEs, and restore them when the context manager exits.

        Example:

        .. code-block:: python

            from transformer_lens import HookedSAETransformer, HookedSAE, HookedSAEConfig

            model = HookedSAETransformer.from_pretrained('gpt2-small')
            sae_cfg = HookedSAEConfig(...)
            sae = HookedSAE(sae_cfg)
            with model.saes(saes=[sae]):
                spliced_logits = model(text)


        Args:
            saes (Union[HookedSAE, List[HookedSAE]]): SAEs to be attached.
            reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
        """
        act_names_to_reset = []
        prev_saes = []
        if isinstance(saes, SAE):
            saes = [saes]
        try:
            for sae in saes:
                act_names_to_reset.append(sae.cfg.hook_name)
                prev_saes.append(self.acts_to_saes.get(sae.cfg.hook_name, None))
                self.add_sae(sae)
            yield self
        finally:
            if reset_saes_end:
                self.reset_saes(act_names_to_reset, prev_saes)

__init__(*model_args, **model_kwargs)

Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

Note that if you want to load the model from pretrained weights, you should use from_pretrained instead.

Parameters:

    *model_args (Any): Positional arguments for HookedTransformer initialization. Default: ().
    **model_kwargs (Any): Keyword arguments for HookedTransformer initialization. Default: {}.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def __init__(
    self,
    *model_args: Any,
    **model_kwargs: Any,
):
    """Model initialization. Just HookedTransformer init, but adds a dictionary to keep track of attached SAEs.

    Note that if you want to load the model from pretrained weights, you should use
    :meth:`from_pretrained` instead.

    Args:
        *model_args: Positional arguments for HookedTransformer initialization
        **model_kwargs: Keyword arguments for HookedTransformer initialization
    """
    super().__init__(*model_args, **model_kwargs)
    self.acts_to_saes: Dict[str, SAE] = {}

add_sae(sae)

Attaches an SAE to the model

WARNING: This SAE will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.

Parameters:

    sae (SAE): The SAE to attach to the model. Required.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def add_sae(self, sae: SAE):
    """Attaches an SAE to the model

    WARNING: This sae will be permanently attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.

    Args:
        sae: SparseAutoencoderBase. The SAE to attach to the model
    """
    act_name = sae.cfg.hook_name
    if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
        logging.warning(
            f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
        )
        return

    self.acts_to_saes[act_name] = sae
    set_deep_attr(self, act_name, sae)
    self.setup()

reset_saes(act_names=None, prev_saes=None)

Reset the SAEs attached to the model

If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model. Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).

Parameters:

    act_names (Optional[Union[str, List[str]]]): The act_names of the SAEs to reset. If None, all SAEs attached to the model are reset. Default: None.
    prev_saes (Optional[List[Union[SAE, None]]]): List of SAEs to replace the current ones with. If None, the SAEs are simply removed. Default: None.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def reset_saes(
    self,
    act_names: Optional[Union[str, List[str]]] = None,
    prev_saes: Optional[List[Union[SAE, None]]] = None,
):
    """Reset the SAEs attached to the model

    If act_names are provided will just reset SAEs attached to those hooks. Otherwise will reset all SAEs attached to the model.
    Optionally can provide a list of prev_saes to reset to. This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes).

    Args:
        act_names (Optional[Union[str, List[str]]): The act_names of the SAEs to reset. If None, will reset all SAEs attached to the model. Defaults to None.
        prev_saes (Optional[List[Union[HookedSAE, None]]]): List of SAEs to replace the current ones with. If None, will just remove the SAEs. Defaults to None.
    """
    if isinstance(act_names, str):
        act_names = [act_names]
    elif act_names is None:
        act_names = list(self.acts_to_saes.keys())

    if prev_saes:
        assert len(act_names) == len(
            prev_saes
        ), "act_names and prev_saes must have the same length"
    else:
        prev_saes = [None] * len(act_names)  # type: ignore

    for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore
        self._reset_sae(act_name, prev_sae)

    self.setup()
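
A sketch of permanent attachment and cleanup, assuming sae is a loaded SAE whose hook_name exists in the model (see add_sae above):

model = HookedSAETransformer.from_pretrained("gpt2-small")
model.add_sae(sae)                        # splices the SAE in at sae.cfg.hook_name
logits_with_sae = model("Hello world")    # forward passes now run through the SAE
model.reset_saes()                        # removes all attached SAEs, restoring the original hook points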

run_with_cache_with_saes(*model_args, saes=[], reset_saes_end=True, return_cache_object=True, remove_batch_dim=False, **kwargs)

Wrapper around 'run_with_cache' in HookedTransformer.

Attaches given SAEs before running the model with cache and then removes them. By default, will reset all SAEs to original state after.

Parameters:

    *model_args (Any): Positional arguments for the model forward pass. Default: ().
    saes (Union[SAE, List[SAE]]): The SAEs to be attached for this forward pass. Default: [].
    reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default: True.
    return_cache_object (bool): If True, returns an ActivationCache object with useful HookedTransformer-specific methods; otherwise returns a dictionary of activations as in HookedRootModule. Default: True.
    remove_batch_dim (bool): Whether to remove the batch dimension (only works for batch_size==1). Default: False.
    **kwargs (Any): Keyword arguments for the model forward pass. Default: {}.
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_cache_with_saes(
    self,
    *model_args: Any,
    saes: Union[SAE, List[SAE]] = [],
    reset_saes_end: bool = True,
    return_cache_object: bool = True,
    remove_batch_dim: bool = False,
    **kwargs: Any,
) -> Tuple[
    Union[
        None,
        Float[torch.Tensor, "batch pos d_vocab"],
        Loss,
        Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
    ],
    Union[ActivationCache, Dict[str, torch.Tensor]],
]:
    """Wrapper around 'run_with_cache' in HookedTransformer.

    Attaches given SAEs before running the model with cache and then removes them.
    By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
        reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
        return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of
            useful HookedTransformer specific methods, otherwise it will return a dictionary of
            activations as in HookedRootModule.
        remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
        **kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(saes=saes, reset_saes_end=reset_saes_end):
        return self.run_with_cache(  # type: ignore
            *model_args,
            return_cache_object=return_cache_object,  # type: ignore
            remove_batch_dim=remove_batch_dim,
            **kwargs,
        )
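
A short sketch of reading SAE feature activations out of the cache, continuing the model and sae objects from the reset_saes sketch above. The cache key assumes the attached SAE's internal hook points (see the SAE class below) are registered under its cfg.hook_name:

logits, cache = model.run_with_cache_with_saes(
    "The quick brown fox",
    saes=[sae],
)

# Feature activations of the attached SAE, shape [batch, pos, d_sae]
feature_acts = cache[f"{sae.cfg.hook_name}.hook_sae_acts_post"]
print(feature_acts.shape)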

run_with_hooks_with_saes(*model_args, saes=[], reset_saes_end=True, fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts=False, **model_kwargs)

Wrapper around 'run_with_hooks' in HookedTransformer.

Attaches the given SAEs to the model before running the model with hooks and then removes them. By default, will reset all SAEs to original state after.

Parameters:

Name Type Description Default
*model_args Any

Positional arguments for the model forward pass

()
saes Union[SAE, List[SAE]]

(Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass

[]
reset_saes_end bool

(bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)

True
fwd_hooks List[Tuple[Union[str, Callable], Callable]]

(List[Tuple[Union[str, Callable], Callable]]) List of forward hooks to apply

[]
bwd_hooks List[Tuple[Union[str, Callable], Callable]]

(List[Tuple[Union[str, Callable], Callable]]) List of backward hooks to apply

[]
reset_hooks_end bool

(bool) Whether to reset the hooks at the end of the forward pass (default: True)

True
clear_contexts bool

(bool) Whether to clear the contexts at the end of the forward pass (default: False)

False
**model_kwargs Any

Keyword arguments for the model forward pass

{}
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_hooks_with_saes(
    self,
    *model_args: Any,
    saes: Union[SAE, List[SAE]] = [],
    reset_saes_end: bool = True,
    fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],  # type: ignore
    bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],  # type: ignore
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    **model_kwargs: Any,
):
    """Wrapper around 'run_with_hooks' in HookedTransformer.

    Attaches the given SAEs to the model before running the model with hooks and then removes them.
    By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
        reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. (default: True)
        fwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of forward hooks to apply
        bwd_hooks: (List[Tuple[Union[str, Callable], Callable]]) List of backward hooks to apply
        reset_hooks_end: (bool) Whether to reset the hooks at the end of the forward pass (default: True)
        clear_contexts: (bool) Whether to clear the contexts at the end of the forward pass (default: False)
        **model_kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(saes=saes, reset_saes_end=reset_saes_end):
        return self.run_with_hooks(
            *model_args,
            fwd_hooks=fwd_hooks,
            bwd_hooks=bwd_hooks,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
            **model_kwargs,
        )
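
A sketch combining an attached SAE with a forward hook, here ablating one arbitrarily chosen SAE feature during the forward pass (continuing the model and sae objects from above):

import torch

def ablate_feature_7(acts: torch.Tensor, hook) -> torch.Tensor:
    # illustrative intervention on the SAE's post-activation features
    acts[..., 7] = 0.0
    return acts

logits = model.run_with_hooks_with_saes(
    "The quick brown fox",
    saes=[sae],
    fwd_hooks=[(f"{sae.cfg.hook_name}.hook_sae_acts_post", ablate_feature_7)],
)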

run_with_saes(*model_args, saes=[], reset_saes_end=True, **model_kwargs)

Wrapper around HookedTransformer forward pass.

Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

Parameters:

Name Type Description Default
*model_args Any

Positional arguments for the model forward pass

()
saes Union[SAE, List[SAE]]

(Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass

[]
reset_saes_end bool

If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.

True
**model_kwargs Any

Keyword arguments for the model forward pass

{}
Source code in sae_lens/analysis/hooked_sae_transformer.py
def run_with_saes(
    self,
    *model_args: Any,
    saes: Union[SAE, List[SAE]] = [],
    reset_saes_end: bool = True,
    **model_kwargs: Any,
) -> Union[
    None,
    Float[torch.Tensor, "batch pos d_vocab"],
    Loss,
    Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
]:
    """Wrapper around HookedTransformer forward pass.

    Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.

    Args:
        *model_args: Positional arguments for the model forward pass
        saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
        reset_saes_end (bool): If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
        **model_kwargs: Keyword arguments for the model forward pass
    """
    with self.saes(saes=saes, reset_saes_end=reset_saes_end):
        return self(*model_args, **model_kwargs)
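
A sketch comparing a normal forward pass with one where the SAE is spliced in (continuing the objects from above). Because reset_saes_end defaults to True, the model is left unchanged afterwards:

import torch

clean_logits = model("The quick brown fox")
spliced_logits = model.run_with_saes("The quick brown fox", saes=[sae])

# The two generally differ by the SAE's reconstruction error
print(torch.allclose(clean_logits, spliced_logits))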

saes(saes=[], reset_saes_end=True)

A context manager for adding temporary SAEs to the model. See HookedTransformer.hooks for a similar context manager for hooks. By default will keep track of previously attached SAEs, and restore them when the context manager exits.

Example:

.. code-block:: python

from transformer_lens import HookedSAETransformer, HookedSAE, HookedSAEConfig

model = HookedSAETransformer.from_pretrained('gpt2-small')
sae_cfg = HookedSAEConfig(...)
sae = HookedSAE(sae_cfg)
with model.saes(saes=[sae]):
    spliced_logits = model(text)

Parameters:

Name Type Description Default
saes Union[HookedSAE, List[HookedSAE]]

SAEs to be attached.

[]
reset_saes_end bool

If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.

True
Source code in sae_lens/analysis/hooked_sae_transformer.py
@contextmanager
def saes(
    self,
    saes: Union[SAE, List[SAE]] = [],
    reset_saes_end: bool = True,
):
    """
    A context manager for adding temporary SAEs to the model.
    See HookedTransformer.hooks for a similar context manager for hooks.
    By default will keep track of previously attached SAEs, and restore them when the context manager exits.

    Example:

    .. code-block:: python

        from transformer_lens import HookedSAETransformer, HookedSAE, HookedSAEConfig

        model = HookedSAETransformer.from_pretrained('gpt2-small')
        sae_cfg = HookedSAEConfig(...)
        sae = HookedSAE(sae_cfg)
        with model.saes(saes=[sae]):
            spliced_logits = model(text)


    Args:
        saes (Union[HookedSAE, List[HookedSAE]]): SAEs to be attached.
        reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
    """
    act_names_to_reset = []
    prev_saes = []
    if isinstance(saes, SAE):
        saes = [saes]
    try:
        for sae in saes:
            act_names_to_reset.append(sae.cfg.hook_name)
            prev_saes.append(self.acts_to_saes.get(sae.cfg.hook_name, None))
            self.add_sae(sae)
        yield self
    finally:
        if reset_saes_end:
            self.reset_saes(act_names_to_reset, prev_saes)

LanguageModelSAERunnerConfig dataclass

Configuration for training a sparse autoencoder on a language model.

Parameters:

Name Type Description Default
model_name str

The name of the model to use. This should be the name of the model in the Hugging Face model hub.

'gelu-2l'
model_class_name str

The name of the class of the model to use. This should be either HookedTransformer or HookedMamba.

'HookedTransformer'
hook_name str

The name of the hook to use. This should be a valid TransformerLens hook.

'blocks.0.hook_mlp_out'
hook_eval str

NOT CURRENTLY IN USE. The name of the hook to use for evaluation.

'NOT_IN_USE'
hook_layer int

The index of the layer to hook. Used to stop forward passes early and speed up processing.

0
hook_head_index int

When the hook is for an activation with a head index, we can specify a specific head to use here.

None
dataset_path str

A Hugging Face dataset path.

'NeelNanda/c4-tokenized-2b'
dataset_trust_remote_code bool

Whether to trust remote code when loading datasets from Huggingface.

True
streaming bool

Whether to stream the dataset. Streaming large datasets is usually practical.

True
is_dataset_tokenized bool

NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized.

True
context_size int

The context size to use when generating activations on which to train the SAE.

128
use_cached_activations bool

Whether to use cached activations. This is useful when doing sweeps over the same activations.

False
cached_activations_path str

The path to the cached activations.

None
d_in int

The input dimension of the SAE.

512
d_sae int

The output dimension of the SAE. If None, defaults to d_in * expansion_factor.

None
b_dec_init_method str

The method to use to initialize the decoder bias. Zeros is likely fine.

'geometric_median'
expansion_factor int

The expansion factor. Larger is better but more computationally expensive.

4
activation_fn str

The activation function to use. Relu is standard.

'relu'
normalize_sae_decoder bool

Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.

True
noise_scale float

Using noise to induce sparsity is supported but not recommended.

0.0
from_pretrained_path str

The path to a pretrained SAE. We can finetune an existing SAE if needed.

None
apply_b_dec_to_input bool

Whether to apply the decoder bias to the input. Not currently advised.

True
decoder_orthogonal_init bool

Whether to use orthogonal initialization for the decoder. Not currently advised.

False
decoder_heuristic_init bool

Whether to use heuristic initialization for the decoder. See Anthropic April Update.

False
init_encoder_as_decoder_transpose bool

Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.

False
n_batches_in_buffer int

The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.

20
training_tokens int

The number of training tokens.

2000000
finetuning_tokens int

The number of finetuning tokens. See https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes

0
store_batch_size_prompts int

The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.

32
train_batch_size_tokens int

The batch size for training. This controls the batch size of the SAE Training loop.

4096
normalize_activations str

Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).

'none'
device str

The device to use. Usually cuda.

'cpu'
act_store_device str

The device to use for the activation store. CPU is advised in order to save vram.

'with_model'
seed int

The seed to use.

42
dtype str

The data type to use.

'float32'
prepend_bos bool

Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.

True
autocast bool

Whether to use autocast during training. Saves vram.

False
autocast_lm bool

Whether to use autocast during activation fetching.

False
compile_llm bool

Whether to compile the LLM.

False
llm_compilation_mode str

The compilation mode to use for the LLM.

None
compile_sae bool

Whether to compile the SAE.

False
sae_compilation_mode str

The compilation mode to use for the SAE.

None
train_batch_size_tokens int

The batch size for training.

4096
adam_beta1 float

The beta1 parameter for Adam.

0
adam_beta2 float

The beta2 parameter for Adam.

0.999
mse_loss_normalization str

The normalization to use for the MSE loss.

None
l1_coefficient float

The L1 coefficient.

0.001
lp_norm float

The Lp norm.

1
scale_sparsity_penalty_by_decoder_norm bool

Whether to scale the sparsity penalty by the decoder norm.

False
l1_warm_up_steps int

The number of warm-up steps for the L1 loss.

0
lr float

The learning rate.

0.0003
lr_scheduler_name str

The name of the learning rate scheduler to use.

'constant'
lr_warm_up_steps int

The number of warm-up steps for the learning rate.

0
lr_end float

The end learning rate for the cosine annealing scheduler.

None
lr_decay_steps int

The number of decay steps for the learning rate.

0
n_restart_cycles int

The number of restart cycles for the cosine annealing warm restarts scheduler.

1
finetuning_method str

The method to use for finetuning.

None
use_ghost_grads bool

Whether to use ghost gradients.

False
feature_sampling_window int

The feature sampling window.

2000
dead_feature_window int

The dead feature window.

1000
dead_feature_threshold float

The dead feature threshold.

1e-08
n_eval_batches int

The number of evaluation batches.

10
eval_batch_size_prompts int

The batch size for evaluation.

None
log_to_wandb bool

Whether to log to Weights & Biases.

True
log_activations_store_to_wandb bool

NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.

False
log_optimizer_state_to_wandb bool

NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.

False
wandb_project str

The Weights & Biases project to log to.

'mats_sae_training_language_model'
wandb_id str

The Weights & Biases ID.

None
run_name str

The name of the run.

None
wandb_entity str

The Weights & Biases entity.

None
wandb_log_frequency int

The frequency to log to Weights & Biases.

10
eval_every_n_wandb_logs int

The frequency to evaluate.

100
resume bool

Whether to resume training.

False
n_checkpoints int

The number of checkpoints.

0
checkpoint_path str

The path to save checkpoints.

'checkpoints'
verbose bool

Whether to print verbose output.

True
model_kwargs dict[str, Any]

Additional keyword arguments for the model.

dict()
model_from_pretrained_kwargs dict[str, Any]

Additional keyword arguments for the model from pretrained.

dict()
Source code in sae_lens/config.py
@dataclass
class LanguageModelSAERunnerConfig:
    """
    Configuration for training a sparse autoencoder on a language model.

    Args:
        model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
        model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
        hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
        hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
        hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
        hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
        dataset_path (str): A Hugging Face dataset path.
        dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
        streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
        is_dataset_tokenized (bool): NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized.
        context_size (int): The context size to use when generating activations on which to train the SAE.
        use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
        cached_activations_path (str, optional): The path to the cached activations.
        d_in (int): The input dimension of the SAE.
        d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
        b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
        expansion_factor (int): The expansion factor. Larger is better but more computationally expensive.
        activation_fn (str): The activation function to use. Relu is standard.
        normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
        noise_scale (float): Using noise to induce sparsity is supported but not recommended.
        from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
        apply_b_dec_to_input (bool): Whether to apply the decoder bias to the input. Not currently advised.
        decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
        decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
        init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.
        training_tokens (int): The number of training tokens.
        finetuning_tokens (int): The number of finetuning tokens. See [here](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes)
        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
        train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
        normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
        device (str): The device to use. Usually cuda.
        act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
        seed (int): The seed to use.
        dtype (str): The data type to use.
        prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
        autocast (bool): Whether to use autocast during training. Saves vram.
        autocast_lm (bool): Whether to use autocast during activation fetching.
        compile_llm (bool): Whether to compile the LLM.
        llm_compilation_mode (str): The compilation mode to use for the LLM.
        compile_sae (bool): Whether to compile the SAE.
        sae_compilation_mode (str): The compilation mode to use for the SAE.
        train_batch_size_tokens (int): The batch size for training.
        adam_beta1 (float): The beta1 parameter for Adam.
        adam_beta2 (float): The beta2 parameter for Adam.
        mse_loss_normalization (str): The normalization to use for the MSE loss.
        l1_coefficient (float): The L1 coefficient.
        lp_norm (float): The Lp norm.
        scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
        l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
        lr (float): The learning rate.
        lr_scheduler_name (str): The name of the learning rate scheduler to use.
        lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
        lr_end (float): The end learning rate for the cosine annealing scheduler.
        lr_decay_steps (int): The number of decay steps for the learning rate.
        n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
        finetuning_method (str): The method to use for finetuning.
        use_ghost_grads (bool): Whether to use ghost gradients.
        feature_sampling_window (int): The feature sampling window.
        dead_feature_window (int): The dead feature window.
        dead_feature_threshold (float): The dead feature threshold.
        n_eval_batches (int): The number of evaluation batches.
        eval_batch_size_prompts (int): The batch size for evaluation.
        log_to_wandb (bool): Whether to log to Weights & Biases.
        log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
        log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
        wandb_project (str): The Weights & Biases project to log to.
        wandb_id (str): The Weights & Biases ID.
        run_name (str): The name of the run.
        wandb_entity (str): The Weights & Biases entity.
        wandb_log_frequency (int): The frequency to log to Weights & Biases.
        eval_every_n_wandb_logs (int): The frequency to evaluate.
        resume (bool): Whether to resume training.
        n_checkpoints (int): The number of checkpoints.
        checkpoint_path (str): The path to save checkpoints.
        verbose (bool): Whether to print verbose output.
        model_kwargs (dict[str, Any]): Additional keyword arguments for the model.
        model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments for the model from pretrained.
    """

    # Data Generating Function (Model + Training Distribution)
    model_name: str = "gelu-2l"
    model_class_name: str = "HookedTransformer"
    hook_name: str = "blocks.0.hook_mlp_out"
    hook_eval: str = "NOT_IN_USE"
    hook_layer: int = 0
    hook_head_index: Optional[int] = None
    dataset_path: str = "NeelNanda/c4-tokenized-2b"
    dataset_trust_remote_code: bool = True
    streaming: bool = True
    is_dataset_tokenized: bool = True
    context_size: int = 128
    use_cached_activations: bool = False
    cached_activations_path: Optional[str] = (
        None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}"
    )

    # SAE Parameters
    architecture: Literal["standard", "gated"] = "standard"
    d_in: int = 512
    d_sae: Optional[int] = None
    b_dec_init_method: str = "geometric_median"
    expansion_factor: int = 4
    activation_fn: str = "relu"  # relu, tanh-relu
    normalize_sae_decoder: bool = True
    noise_scale: float = 0.0
    from_pretrained_path: Optional[str] = None
    apply_b_dec_to_input: bool = True
    decoder_orthogonal_init: bool = False
    decoder_heuristic_init: bool = False
    init_encoder_as_decoder_transpose: bool = False

    # Activation Store Parameters
    n_batches_in_buffer: int = 20
    training_tokens: int = 2_000_000
    finetuning_tokens: int = 0
    store_batch_size_prompts: int = 32
    train_batch_size_tokens: int = 4096
    normalize_activations: str = (
        "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
    )

    # Misc
    device: str = "cpu"
    act_store_device: str = "with_model"  # will be set by post init if with_model
    seed: int = 42
    dtype: str = "float32"  # type: ignore #
    prepend_bos: bool = True

    # Performance - see compilation section of lm_runner.py for info
    autocast: bool = False  # autocast to autocast_dtype during training
    autocast_lm: bool = False  # autocast lm during activation fetching
    compile_llm: bool = False  # use torch.compile on the LLM
    llm_compilation_mode: str | None = None  # which torch.compile mode to use
    compile_sae: bool = False  # use torch.compile on the SAE
    sae_compilation_mode: str | None = None

    # Training Parameters

    ## Batch size
    train_batch_size_tokens: int = 4096

    ## Adam
    adam_beta1: float = 0
    adam_beta2: float = 0.999

    ## Loss Function
    mse_loss_normalization: Optional[str] = None
    l1_coefficient: float = 1e-3
    lp_norm: float = 1
    scale_sparsity_penalty_by_decoder_norm: bool = False
    l1_warm_up_steps: int = 0

    ## Learning Rate Schedule
    lr: float = 3e-4
    lr_scheduler_name: str = (
        "constant"  # constant, cosineannealing, cosineannealingwarmrestarts
    )
    lr_warm_up_steps: int = 0
    lr_end: Optional[float] = None  # only used for cosine annealing, default is lr / 10
    lr_decay_steps: int = 0
    n_restart_cycles: int = 1  # used only for cosineannealingwarmrestarts

    ## FineTuning
    finetuning_method: Optional[str] = None  # scale, decoder or unrotated_decoder

    # Resampling protocol args
    use_ghost_grads: bool = False  # want to change this to true on some timeline.
    feature_sampling_window: int = 2000
    dead_feature_window: int = 1000  # should be smaller than feature_sampling_window

    dead_feature_threshold: float = 1e-8

    # Evals
    n_eval_batches: int = 10
    eval_batch_size_prompts: int | None = None  # useful if evals cause OOM

    # WANDB
    log_to_wandb: bool = True
    log_activations_store_to_wandb: bool = False
    log_optimizer_state_to_wandb: bool = False
    wandb_project: str = "mats_sae_training_language_model"
    wandb_id: Optional[str] = None
    run_name: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_log_frequency: int = 10
    eval_every_n_wandb_logs: int = 100  # at the default wandb_log_frequency of 10, this evaluates every 1000 steps

    # Misc
    resume: bool = False
    n_checkpoints: int = 0
    checkpoint_path: str = "checkpoints"
    verbose: bool = True
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
    sae_lens_version: str = field(default_factory=lambda: __version__)
    sae_lens_training_version: str = field(default_factory=lambda: __version__)

    def __post_init__(self):

        if self.resume:
            raise ValueError(
                "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
                + "If you want to load an SAE with resume=True in the config, please manually set resume=False in that config."
            )

        if self.use_cached_activations and self.cached_activations_path is None:
            self.cached_activations_path = _default_cached_activations_path(
                self.dataset_path,
                self.model_name,
                self.hook_name,
                self.hook_head_index,
            )

        if not isinstance(self.expansion_factor, list):
            self.d_sae = self.d_in * self.expansion_factor
        self.tokens_per_buffer = (
            self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
        )

        if self.run_name is None:
            self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"

        if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
            raise ValueError(
                f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
            )

        if self.normalize_sae_decoder and self.decoder_heuristic_init:
            raise ValueError(
                "You can't normalize the decoder and use heuristic initialization."
            )

        if self.normalize_sae_decoder and self.scale_sparsity_penalty_by_decoder_norm:
            raise ValueError(
                "Weighting loss by decoder norm makes no sense if you are normalizing the decoder weight norms to 1"
            )

        # if we use decoder fine tuning, we can't be applying b_dec to the input
        if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
            raise ValueError(
                "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
            )

        if self.normalize_activations not in [
            "none",
            "expected_average_only_in",
            "constant_norm_rescale",
        ]:
            raise ValueError(
                f"normalize_activations must be none, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
            )

        if self.act_store_device == "with_model":
            self.act_store_device = self.device

        if self.lr_end is None:
            self.lr_end = self.lr / 10

        unique_id = self.wandb_id
        if unique_id is None:
            unique_id = cast(
                Any, wandb
            ).util.generate_id()  # not sure why this type is erroring
        self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

        if self.verbose:
            print(
                f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
            )
            # Print out some useful info:
            n_tokens_per_buffer = (
                self.store_batch_size_prompts
                * self.context_size
                * self.n_batches_in_buffer
            )
            print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 ** 6}")
            n_contexts_per_buffer = (
                self.store_batch_size_prompts * self.n_batches_in_buffer
            )
            print(
                f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 ** 6}"
            )

            total_training_steps = (
                self.training_tokens + self.finetuning_tokens
            ) // self.train_batch_size_tokens
            print(f"Total training steps: {total_training_steps}")

            total_wandb_updates = total_training_steps // self.wandb_log_frequency
            print(f"Total wandb updates: {total_wandb_updates}")

            # how many times will we sample dead neurons?
            # assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
            n_feature_window_samples = (
                total_training_steps // self.feature_sampling_window
            )
            print(
                f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size_tokens) / 10 ** 6}"
            )
            print(
                f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size_tokens) / 10 ** 6}"
            )
            print(
                f"We will reset the sparsity calculation {n_feature_window_samples} times."
            )
            # print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size_tokens)
            print(
                f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
            )

        if self.use_ghost_grads:
            print("Using Ghost Grads.")

    @property
    def total_training_tokens(self) -> int:
        return self.training_tokens + self.finetuning_tokens

    @property
    def total_training_steps(self) -> int:
        return self.total_training_tokens // self.train_batch_size_tokens

    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
        return {
            # TEMP
            "architecture": self.architecture,
            "d_in": self.d_in,
            "d_sae": self.d_sae,
            "dtype": self.dtype,
            "device": self.device,
            "model_name": self.model_name,
            "hook_name": self.hook_name,
            "hook_layer": self.hook_layer,
            "hook_head_index": self.hook_head_index,
            "activation_fn_str": self.activation_fn,
            "apply_b_dec_to_input": self.apply_b_dec_to_input,
            "context_size": self.context_size,
            "prepend_bos": self.prepend_bos,
            "dataset_path": self.dataset_path,
            "dataset_trust_remote_code": self.dataset_trust_remote_code,
            "finetuning_scaling_factor": self.finetuning_method is not None,
            "sae_lens_training_version": self.sae_lens_training_version,
            "normalize_activations": self.normalize_activations,
        }

    def get_training_sae_cfg_dict(self) -> dict[str, Any]:
        return {
            **self.get_base_sae_cfg_dict(),
            "l1_coefficient": self.l1_coefficient,
            "lp_norm": self.lp_norm,
            "use_ghost_grads": self.use_ghost_grads,
            "normalize_sae_decoder": self.normalize_sae_decoder,
            "noise_scale": self.noise_scale,
            "decoder_orthogonal_init": self.decoder_orthogonal_init,
            "mse_loss_normalization": self.mse_loss_normalization,
            "decoder_heuristic_init": self.decoder_heuristic_init,
            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
            "normalize_activations": self.normalize_activations,
        }

    def to_dict(self) -> dict[str, Any]:

        cfg_dict = {
            **self.__dict__,
            # some args may not be serializable by default
            "dtype": str(self.dtype),
            "device": str(self.device),
            "act_store_device": str(self.act_store_device),
        }

        return cfg_dict

    def to_json(self, path: str) -> None:

        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path))

        with open(path + "cfg.json", "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
        with open(path + "cfg.json", "r") as f:
            cfg = json.load(f)
        return cls(**cfg)
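
A sketch of constructing a training config. All field values are illustrative and should be adapted to the model and dataset at hand; in particular, d_in must match the width of the hooked activation:

from sae_lens import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    model_name="tiny-stories-1L-21M",      # illustrative model
    hook_name="blocks.0.hook_mlp_out",
    hook_layer=0,
    d_in=1024,                             # width of the hooked activation
    expansion_factor=16,                   # d_sae = d_in * expansion_factor
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # illustrative pretokenized dataset
    context_size=128,
    training_tokens=10_000_000,
    train_batch_size_tokens=4096,
    lr=3e-4,
    l1_coefficient=5e-3,
    log_to_wandb=False,
    device="cuda",
)

print(cfg.d_sae)                 # 16384, derived in __post_init__
print(cfg.total_training_steps)  # (training_tokens + finetuning_tokens) // train_batch_size_tokens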

PretokenizeRunner

Runner to pretokenize a dataset using a given tokenizer, and optionally upload to Huggingface.

Source code in sae_lens/pretokenize_runner.py
class PretokenizeRunner:
    """
    Runner to pretokenize a dataset using a given tokenizer, and optionally upload to Huggingface.
    """

    def __init__(self, cfg: PretokenizeRunnerConfig):
        self.cfg = cfg

    def run(self):
        """
        Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
        """
        dataset = load_dataset(
            self.cfg.dataset_path,
            data_dir=self.cfg.data_dir,
            data_files=self.cfg.data_files,
            split=self.cfg.split,
            streaming=self.cfg.streaming,
        )
        if isinstance(dataset, DatasetDict):
            raise ValueError(
                "Dataset has multiple splits. Must provide a 'split' param."
            )
        tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
        tokenizer.model_max_length = sys.maxsize
        tokenized_dataset = pretokenize_dataset(
            cast(Dataset, dataset), tokenizer, self.cfg
        )

        if self.cfg.save_path is not None:
            tokenized_dataset.save_to_disk(self.cfg.save_path)
            metadata = metadata_from_config(self.cfg)
            metadata_path = Path(self.cfg.save_path) / "sae_lens.json"
            with open(metadata_path, "w") as f:
                json.dump(metadata.__dict__, f, indent=2, ensure_ascii=False)

        if self.cfg.hf_repo_id is not None:
            push_to_hugging_face_hub(tokenized_dataset, self.cfg)

        return tokenized_dataset

run()

Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.

Source code in sae_lens/pretokenize_runner.py
def run(self):
    """
    Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
    """
    dataset = load_dataset(
        self.cfg.dataset_path,
        data_dir=self.cfg.data_dir,
        data_files=self.cfg.data_files,
        split=self.cfg.split,
        streaming=self.cfg.streaming,
    )
    if isinstance(dataset, DatasetDict):
        raise ValueError(
            "Dataset has multiple splits. Must provide a 'split' param."
        )
    tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
    tokenizer.model_max_length = sys.maxsize
    tokenized_dataset = pretokenize_dataset(
        cast(Dataset, dataset), tokenizer, self.cfg
    )

    if self.cfg.save_path is not None:
        tokenized_dataset.save_to_disk(self.cfg.save_path)
        metadata = metadata_from_config(self.cfg)
        metadata_path = Path(self.cfg.save_path) / "sae_lens.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata.__dict__, f, indent=2, ensure_ascii=False)

    if self.cfg.hf_repo_id is not None:
        push_to_hugging_face_hub(tokenized_dataset, self.cfg)

    return tokenized_dataset
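
A sketch of pretokenizing a raw-text dataset and saving it to disk. Only config fields referenced in run() above are set, the dataset and save paths are illustrative, and it is assumed that PretokenizeRunnerConfig is importable from the top-level sae_lens package alongside the runner:

from sae_lens import PretokenizeRunner, PretokenizeRunnerConfig

cfg = PretokenizeRunnerConfig(
    tokenizer_name="gpt2",
    dataset_path="roneneldan/TinyStories",  # illustrative raw-text dataset
    split="train",
    streaming=False,
    save_path="./tinystories-tokenized",    # save to disk alongside a sae_lens.json metadata file
    hf_repo_id=None,                        # set to "user/repo" to also push to the Hugging Face Hub
)

tokenized_dataset = PretokenizeRunner(cfg).run()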

SAE

Bases: HookedRootModule

Core Sparse Autoencoder (SAE) class used for inference. For training, see TrainingSAE.

Source code in sae_lens/sae.py
class SAE(HookedRootModule):
    """
    Core Sparse Autoencoder (SAE) class used for inference. For training, see `TrainingSAE`.
    """

    cfg: SAEConfig
    dtype: torch.dtype
    device: torch.device

    # analysis
    use_error_term: bool

    def __init__(
        self,
        cfg: SAEConfig,
        use_error_term: bool = False,
    ):
        super().__init__()

        self.cfg = cfg
        self.activation_fn = get_activation_fn(cfg.activation_fn_str)
        self.dtype = DTYPE_MAP[cfg.dtype]
        self.device = torch.device(cfg.device)
        self.use_error_term = use_error_term

        if self.cfg.architecture == "standard":
            self.initialize_weights_basic()
            self.encode_fn = self.encode
        elif self.cfg.architecture == "gated":
            self.initialize_weights_gated()
            self.encode_fn = self.encode_gated

        # handle presence / absence of scaling factor.
        if self.cfg.finetuning_scaling_factor:
            self.apply_finetuning_scaling_factor = (
                lambda x: x * self.finetuning_scaling_factor
            )
        else:
            self.apply_finetuning_scaling_factor = lambda x: x

        # set up hooks
        self.hook_sae_input = HookPoint()
        self.hook_sae_acts_pre = HookPoint()
        self.hook_sae_acts_post = HookPoint()
        self.hook_sae_output = HookPoint()
        self.hook_sae_recons = HookPoint()
        self.hook_sae_error = HookPoint()

        # handle hook_z reshaping if needed.
        # this is very cursed and should be refactored. it exists so that we can reshape out
        # the z activations for hook_z SAEs, but we don't know d_head if the forward pass is split
        # into separate encode and decode calls.
        # this will cause errors if we call decode before encode.
        if self.cfg.hook_name.endswith("_z"):
            self.turn_on_forward_pass_hook_z_reshaping()
        else:
            # need to default the reshape fns
            self.turn_off_forward_pass_hook_z_reshaping()

        # handle run time activation normalization if needed:
        if self.cfg.normalize_activations == "constant_norm_rescale":

            #  we need to scale the norm of the input and store the scaling factor
            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
                self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
                x = x * self.x_norm_coeff
                return x

            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
                x = x / self.x_norm_coeff
                del self.x_norm_coeff  # prevents reusing
                return x

            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out

        else:
            self.run_time_activation_norm_fn_in = lambda x: x
            self.run_time_activation_norm_fn_out = lambda x: x

        self.setup()  # Required for `HookedRootModule`s

    def initialize_weights_basic(self):

        # no config changes encoder bias init for now.
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        # Start with the default init strategy:
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
                )
            )
        )

        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
                )
            )
        )

        # methods which change b_dec as a function of the dataset are implemented after init.
        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
        )

        # scaling factor for fine-tuning (not to be used in initial training)
        # TODO: Make this optional and not included with all SAEs by default (but maintain backwards compatibility)
        if self.cfg.finetuning_scaling_factor:
            self.finetuning_scaling_factor = nn.Parameter(
                torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
            )

    def initialize_weights_gated(self):
        # Initialize the weights and biases for the gated encoder
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
                )
            )
        )

        self.b_gate = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        self.r_mag = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        self.b_mag = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
                )
            )
        )

        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
        )

    # Basic Forward Pass Functionality.
    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        feature_acts = self.encode_fn(x)
        sae_out = self.decode(feature_acts)

        # TEMP
        if self.use_error_term and self.cfg.architecture != "gated":
            with torch.no_grad():
                # Recompute everything without hooks to get true error term
                # Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct
                # This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail
                # NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since this would mean intervening on features would mean ablating features still results in perfect reconstruction.

                # move x to correct dtype
                x = x.to(self.dtype)

                # handle hook z reshaping if needed.
                sae_in = self.reshape_fn_in(x)  # type: ignore

                # handle run time activation normalization if needed
                sae_in = self.run_time_activation_norm_fn_in(sae_in)

                # apply b_dec_to_input if using that method.
                sae_in_cent = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)

                # "... d_in, d_in d_sae -> ... d_sae",
                hidden_pre = sae_in_cent @ self.W_enc + self.b_enc
                feature_acts = self.activation_fn(hidden_pre)
                x_reconstruct_clean = self.reshape_fn_out(
                    self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec
                    + self.b_dec,
                    d_head=self.d_head,
                )

                sae_out = self.run_time_activation_norm_fn_out(sae_out)
                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
            return self.hook_sae_output(sae_out + sae_error)

        # TODO: Add tests
        elif self.use_error_term and self.cfg.architecture == "gated":
            with torch.no_grad():
                x = x.to(self.dtype)
                sae_in = self.reshape_fn_in(x)  # type: ignore
                gating_pre_activation = sae_in @ self.W_enc + self.b_gate
                active_features = (gating_pre_activation > 0).float()

                # Magnitude path with weight sharing
                magnitude_pre_activation = self.hook_sae_acts_pre(
                    sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
                )
                feature_magnitudes = self.hook_sae_acts_post(
                    self.activation_fn(magnitude_pre_activation)
                )
                feature_acts_clean = active_features * feature_magnitudes
                x_reconstruct_clean = self.reshape_fn_out(
                    self.apply_finetuning_scaling_factor(feature_acts_clean)
                    @ self.W_dec
                    + self.b_dec,
                    d_head=self.d_head,
                )

                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
            return self.hook_sae_output(sae_out + sae_error)

        return self.hook_sae_output(sae_out)

    def encode_gated(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        x = x.to(self.dtype)
        x = self.reshape_fn_in(x)
        sae_in = self.hook_sae_input(x - self.b_dec * self.cfg.apply_b_dec_to_input)

        # Gating path
        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
        active_features = (gating_pre_activation > 0).float()

        # Magnitude path with weight sharing
        magnitude_pre_activation = self.hook_sae_acts_pre(
            sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
        )
        feature_magnitudes = self.hook_sae_acts_post(
            self.activation_fn(magnitude_pre_activation)
        )

        return active_features * feature_magnitudes

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Calculate SAE features from inputs
        """

        # move x to correct dtype
        x = x.to(self.dtype)

        # handle hook z reshaping if needed.
        x = self.reshape_fn_in(x)  # type: ignore

        # handle run time activation normalization if needed
        x = self.run_time_activation_norm_fn_in(x)

        # apply b_dec_to_input if using that method.
        sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))

        # "... d_in, d_in d_sae -> ... d_sae",
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))

        return feature_acts

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """Decodes SAE feature activation tensor into a reconstructed input activation tensor."""
        # "... d_sae, d_sae d_in -> ... d_in",
        sae_out = self.hook_sae_recons(
            self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
        )

        # handle run time activation normalization if needed
        # will fail if you call this twice without calling encode in between.
        sae_out = self.run_time_activation_norm_fn_out(sae_out)

        # handle hook z reshaping if needed.
        sae_out = self.reshape_fn_out(sae_out, self.d_head)  # type: ignore

        return sae_out

    @torch.no_grad()
    def fold_W_dec_norm(self):
        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T
        if self.cfg.architecture == "gated":
            self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
            self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
            self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
        else:
            self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()

    @torch.no_grad()
    def fold_activation_norm_scaling_factor(
        self, activation_norm_scaling_factor: float
    ):
        self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor

    def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):

        if not os.path.exists(path):
            os.mkdir(path)

        # generate the weights
        save_file(self.state_dict(), f"{path}/{SAE_WEIGHTS_PATH}")

        # save the config
        config = self.cfg.to_dict()

        with open(f"{path}/{SAE_CFG_PATH}", "w") as f:
            json.dump(config, f)

        if sparsity is not None:
            sparsity_in_dict = {"sparsity": sparsity}
            save_file(sparsity_in_dict, f"{path}/{SPARSITY_PATH}")  # type: ignore

    @classmethod
    def load_from_pretrained(
        cls, path: str, device: str = "cpu", dtype: str = "float32"
    ) -> "SAE":

        config_path = os.path.join(path, "cfg.json")
        weight_path = os.path.join(path, "sae_weights.safetensors")

        cfg_dict, state_dict, _ = load_pretrained_sae_lens_sae_components(
            config_path, weight_path, device, dtype
        )

        sae_cfg = SAEConfig.from_dict(cfg_dict)

        sae = cls(sae_cfg)
        sae.load_state_dict(state_dict)

        return sae

    @classmethod
    def from_pretrained(
        cls,
        release: str,
        sae_id: str,
        device: str = "cpu",
    ) -> Tuple["SAE", dict[str, Any], Optional[torch.Tensor]]:
        """

        Load a pretrained SAE from the Hugging Face model hub.

        Args:
            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
            sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
            device: The device to load the SAE on.
            return_sparsity_if_present: If True, will return the log sparsity tensor if it is present in the model directory in the Hugging Face model hub.
        """

        # get sae directory
        sae_directory = get_pretrained_saes_directory()

        # get the repo id and path to the SAE
        if release not in sae_directory:
            raise ValueError(
                f"Release {release} not found in pretrained SAEs directory."
            )
        if sae_id not in sae_directory[release].saes_map:
            raise ValueError(f"ID {sae_id} not found in release {release}.")
        sae_info = sae_directory[release]
        hf_repo_id = sae_info.repo_id
        hf_path = sae_info.saes_map[sae_id]

        conversion_loader_name = sae_info.conversion_func or "sae_lens"
        if conversion_loader_name not in NAMED_PRETRAINED_SAE_LOADERS:
            raise ValueError(
                f"Conversion func {conversion_loader_name} not found in NAMED_PRETRAINED_SAE_LOADERS."
            )
        conversion_loader = NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]

        cfg_dict, state_dict, log_sparsities = conversion_loader(
            repo_id=hf_repo_id,
            folder_name=hf_path,
            device=device,
            force_download=False,
        )

        sae = cls(SAEConfig.from_dict(cfg_dict))
        sae.load_state_dict(state_dict)

        return sae, cfg_dict, log_sparsities

    def get_name(self):
        sae_name = f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"
        return sae_name

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "SAE":
        return cls(SAEConfig.from_dict(config_dict))

    def turn_on_forward_pass_hook_z_reshaping(self):

        assert self.cfg.hook_name.endswith(
            "_z"
        ), "This method should only be called for hook_z SAEs."

        def reshape_fn_in(x: torch.Tensor):
            self.d_head = x.shape[-1]  # type: ignore
            self.reshape_fn_in = lambda x: einops.rearrange(
                x, "... n_heads d_head -> ... (n_heads d_head)"
            )
            return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")

        self.reshape_fn_in = reshape_fn_in

        self.reshape_fn_out = lambda x, d_head: einops.rearrange(
            x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
        )
        self.hook_z_reshaping_mode = True

    def turn_off_forward_pass_hook_z_reshaping(self):
        self.reshape_fn_in = lambda x: x
        self.reshape_fn_out = lambda x, d_head: x
        self.d_head = None
        self.hook_z_reshaping_mode = False
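
A sketch of loading a pretrained SAE and reconstructing a cached activation with it. The release name and SAE id are placeholders, and run_with_cache is inherited from HookedTransformer:

from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gpt2")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",       # placeholder release name
    sae_id="blocks.8.hook_resid_pre",  # placeholder SAE id
    device="cpu",
)

_, cache = model.run_with_cache("Hello world")
acts = cache[sae.cfg.hook_name]   # [batch, pos, d_in]
recon = sae(acts)                 # forward pass = encode followed by decode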

decode(feature_acts)

Decodes SAE feature activation tensor into a reconstructed input activation tensor.

Source code in sae_lens/sae.py
def decode(
    self, feature_acts: Float[torch.Tensor, "... d_sae"]
) -> Float[torch.Tensor, "... d_in"]:
    """Decodes SAE feature activation tensor into a reconstructed input activation tensor."""
    # "... d_sae, d_sae d_in -> ... d_in",
    sae_out = self.hook_sae_recons(
        self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
    )

    # handle run time activation normalization if needed
    # will fail if you call this twice without calling encode in between.
    sae_out = self.run_time_activation_norm_fn_out(sae_out)

    # handle hook z reshaping if needed.
    sae_out = self.reshape_fn_out(sae_out, self.d_head)  # type: ignore

    return sae_out

encode(x)

Calculate SAE features from inputs

Source code in sae_lens/sae.py
def encode(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """
    Calculate SAE features from inputs
    """

    # move x to correct dtype
    x = x.to(self.dtype)

    # handle hook z reshaping if needed.
    x = self.reshape_fn_in(x)  # type: ignore

    # handle run time activation normalization if needed
    x = self.run_time_activation_norm_fn_in(x)

    # apply b_dec_to_input if using that method.
    sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))

    # "... d_in, d_in d_sae -> ... d_sae",
    hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
    feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    return feature_acts
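
Together, encode and decode form the usual reconstruction path. A minimal sketch, assuming an already-constructed SAE instance bound to the name sae; the batch size and the use of random activations are illustrative only.

import torch

# Hypothetical batch of input activations with shape (batch, d_in)
acts = torch.randn(8, sae.cfg.d_in)

feature_acts = sae.encode(acts)            # (batch, d_sae) feature activations
reconstruction = sae.decode(feature_acts)  # (batch, d_in) reconstructed input

# Reconstruction error, summed over d_in and averaged over the batch
mse = (reconstruction - acts).pow(2).sum(dim=-1).mean()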

from_pretrained(release, sae_id, device='cpu') classmethod

Load a pretrained SAE from the Hugging Face model hub.

Parameters:

Name     Type  Description                                                                                             Default
release  str   The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.  required
sae_id   str   The id of the SAE to load. This will be mapped to a path in the huggingface repo.                       required
device   str   The device to load the SAE on.                                                                          'cpu'
Source code in sae_lens/sae.py
@classmethod
def from_pretrained(
    cls,
    release: str,
    sae_id: str,
    device: str = "cpu",
) -> Tuple["SAE", dict[str, Any], Optional[torch.Tensor]]:
    """

    Load a pretrained SAE from the Hugging Face model hub.

    Args:
        release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
        id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
        device: The device to load the SAE on.
        return_sparsity_if_present: If True, will return the log sparsity tensor if it is present in the model directory in the Hugging Face model hub.
    """

    # get sae directory
    sae_directory = get_pretrained_saes_directory()

    # get the repo id and path to the SAE
    if release not in sae_directory:
        raise ValueError(
            f"Release {release} not found in pretrained SAEs directory."
        )
    if sae_id not in sae_directory[release].saes_map:
        raise ValueError(f"ID {sae_id} not found in release {release}.")
    sae_info = sae_directory[release]
    hf_repo_id = sae_info.repo_id
    hf_path = sae_info.saes_map[sae_id]

    conversion_loader_name = sae_info.conversion_func or "sae_lens"
    if conversion_loader_name not in NAMED_PRETRAINED_SAE_LOADERS:
        raise ValueError(
            f"Conversion func {conversion_loader_name} not found in NAMED_PRETRAINED_SAE_LOADERS."
        )
    conversion_loader = NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]

    cfg_dict, state_dict, log_sparsities = conversion_loader(
        repo_id=hf_repo_id,
        folder_name=hf_path,
        device=device,
        force_download=False,
    )

    sae = cls(SAEConfig.from_dict(cfg_dict))
    sae.load_state_dict(state_dict)

    return sae, cfg_dict, log_sparsities
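
A typical call looks like the following; the release and sae_id values are illustrative placeholders, and valid entries can be listed via get_pretrained_saes_directory().

from sae_lens import SAE

# release / sae_id below are examples; substitute entries from pretrained_saes.yaml
sae, cfg_dict, log_sparsities = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cpu",
)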

SAETrainingRunner

Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.

Source code in sae_lens/sae_training_runner.py
class SAETrainingRunner:
    """
    Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
    """

    cfg: LanguageModelSAERunnerConfig
    model: HookedRootModule
    sae: TrainingSAE
    activations_store: ActivationsStore

    def __init__(self, cfg: LanguageModelSAERunnerConfig):
        self.cfg = cfg

        self.model = load_model(
            self.cfg.model_class_name,
            self.cfg.model_name,
            device=self.cfg.device,
            model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
        )

        self.activations_store = ActivationsStore.from_config(
            self.model,
            self.cfg,
        )

        if self.cfg.from_pretrained_path is not None:
            self.sae = TrainingSAE.load_from_pretrained(
                self.cfg.from_pretrained_path, self.cfg.device
            )
        else:
            self.sae = TrainingSAE(
                TrainingSAEConfig.from_dict(
                    self.cfg.get_training_sae_cfg_dict(),
                )
            )
            self._init_sae_group_b_decs()

    def run(self):
        """
        Run the training of the SAE.
        """

        if self.cfg.log_to_wandb:
            wandb.init(
                project=self.cfg.wandb_project,
                config=cast(Any, self.cfg),
                name=self.cfg.run_name,
                id=self.cfg.wandb_id,
            )

        trainer = SAETrainer(
            model=self.model,
            sae=self.sae,
            activation_store=self.activations_store,
            save_checkpoint_fn=self.save_checkpoint,
            cfg=self.cfg,
        )

        self._compile_if_needed()
        sae = self.run_trainer_with_interruption_handling(trainer)

        if self.cfg.log_to_wandb:
            wandb.finish()

        return sae

    def _compile_if_needed(self):

        # Compile model and SAE
        #  torch.compile can provide significant speedups (10-20% in testing)
        # using max-autotune gives the best speedups but:
        # (a) increases VRAM usage,
        # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
        # (c) takes some time to compile
        # optimal settings seem to be:
        # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
        # (also pylance seems to really hate this)
        if self.cfg.compile_llm:
            self.model = torch.compile(
                self.model,
                mode=self.cfg.llm_compilation_mode,
            )  # type: ignore

        if self.cfg.compile_sae:
            if self.cfg.device == "mps":
                backend = "aot_eager"
            else:
                backend = "inductor"

            self.sae.training_forward_pass = torch.compile(  # type: ignore
                self.sae.training_forward_pass,
                mode=self.cfg.sae_compilation_mode,
                backend=backend,
            )  # type: ignore

    def run_trainer_with_interruption_handling(self, trainer: SAETrainer):
        try:
            # signal handlers (if preempted)
            signal.signal(signal.SIGINT, interrupt_callback)
            signal.signal(signal.SIGTERM, interrupt_callback)

            # train SAE
            sae = trainer.fit()

        except (KeyboardInterrupt, InterruptedException):
            print("interrupted, saving progress")
            checkpoint_name = trainer.n_training_tokens
            self.save_checkpoint(trainer, checkpoint_name=checkpoint_name)
            print("done saving")
            raise

        return sae

    # TODO: move this into the SAE trainer or Training SAE class
    def _init_sae_group_b_decs(
        self,
    ) -> None:
        """
        extract all activations at a certain layer and use for sae b_dec initialization
        """

        if self.cfg.b_dec_init_method == "geometric_median":
            layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
            # get geometric median of the activations if we're using those.
            median = compute_geometric_median(
                layer_acts,
                maxiter=100,
            ).median
            self.sae.initialize_b_dec_with_precalculated(median)  # type: ignore
        elif self.cfg.b_dec_init_method == "mean":
            layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
            self.sae.initialize_b_dec_with_mean(layer_acts)  # type: ignore

    def save_checkpoint(
        self,
        trainer: SAETrainer,
        checkpoint_name: int | str,
        wandb_aliases: list[str] | None = None,
    ) -> str:

        checkpoint_path = f"{trainer.cfg.checkpoint_path}/{checkpoint_name}"

        os.makedirs(checkpoint_path, exist_ok=True)

        path = f"{checkpoint_path}"
        os.makedirs(path, exist_ok=True)

        if self.sae.cfg.normalize_sae_decoder:
            self.sae.set_decoder_norm_to_unit_norm()
        self.sae.save_model(path)

        # Overwrite the cfg file with the trainer cfg, which is a superset of the
        # original cfg and records more info about SAEs trained in SAE Lens.
        config = trainer.cfg.to_dict()
        with open(f"{path}/cfg.json", "w") as f:
            json.dump(config, f)

        log_feature_sparsities = {"sparsity": trainer.log_feature_sparsity}

        log_feature_sparsity_path = f"{path}/{SPARSITY_PATH}"
        save_file(log_feature_sparsities, log_feature_sparsity_path)

        if trainer.cfg.log_to_wandb and os.path.exists(log_feature_sparsity_path):
            model_artifact = wandb.Artifact(
                f"{self.sae.get_name()}",
                type="model",
                metadata=dict(trainer.cfg.__dict__),
            )

            model_artifact.add_file(f"{path}/{SAE_WEIGHTS_PATH}")
            model_artifact.add_file(f"{path}/{SAE_CFG_PATH}")

            wandb.log_artifact(model_artifact, aliases=wandb_aliases)

            sparsity_artifact = wandb.Artifact(
                f"{self.sae.get_name()}_log_feature_sparsity",
                type="log_feature_sparsity",
                metadata=dict(trainer.cfg.__dict__),
            )
            sparsity_artifact.add_file(log_feature_sparsity_path)
            wandb.log_artifact(sparsity_artifact)

        return checkpoint_path

run()

Run the training of the SAE.

Source code in sae_lens/sae_training_runner.py
def run(self):
    """
    Run the training of the SAE.
    """

    if self.cfg.log_to_wandb:
        wandb.init(
            project=self.cfg.wandb_project,
            config=cast(Any, self.cfg),
            name=self.cfg.run_name,
            id=self.cfg.wandb_id,
        )

    trainer = SAETrainer(
        model=self.model,
        sae=self.sae,
        activation_store=self.activations_store,
        save_checkpoint_fn=self.save_checkpoint,
        cfg=self.cfg,
    )

    self._compile_if_needed()
    sae = self.run_trainer_with_interruption_handling(trainer)

    if self.cfg.log_to_wandb:
        wandb.finish()

    return sae
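
End to end, training is kicked off by constructing the runner from a LanguageModelSAERunnerConfig and calling run(). A minimal sketch; the config fields shown are a small illustrative subset with placeholder values, not a complete or recommended configuration.

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# Illustrative subset of fields; see LanguageModelSAERunnerConfig for the full set.
cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    d_in=768,
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
    log_to_wandb=False,
    device="cpu",
)

sae = SAETrainingRunner(cfg).run()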

TrainingSAE

Bases: SAE

A SAE used for training. This class provides a training_forward_pass method which calculates losses used for training.

Source code in sae_lens/training/training_sae.py
class TrainingSAE(SAE):
    """
    A SAE used for training. This class provides a `training_forward_pass` method which calculates
    losses used for training.
    """

    cfg: TrainingSAEConfig
    use_error_term: bool
    dtype: torch.dtype
    device: torch.device

    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):

        base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict())
        super().__init__(base_sae_cfg)
        self.cfg = cfg  # type: ignore

        self.encode_with_hidden_pre_fn = (
            self.encode_with_hidden_pre
            if cfg.architecture != "gated"
            else self.encode_with_hidden_pre_gated
        )

        self.check_cfg_compatibility()

        self.use_error_term = use_error_term

        self.initialize_weights_complex()

        # The training SAE will assume that the activation store handles
        # reshaping.
        self.turn_off_forward_pass_hook_z_reshaping()

        self.mse_loss_fn = self._get_mse_loss_fn()

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE":
        return cls(TrainingSAEConfig.from_dict(config_dict))

    def check_cfg_compatibility(self):
        if self.cfg.architecture == "gated":
            assert (
                self.cfg.use_ghost_grads is False
            ), "Gated SAEs do not support ghost grads"
            assert self.use_error_term is False, "Gated SAEs do not support error terms"

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Calculate SAE features from inputs
        """
        feature_acts, _ = self.encode_with_hidden_pre_fn(x)
        return feature_acts

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:

        # move x to correct dtype
        x = x.to(self.dtype)

        # handle hook z reshaping if needed.
        x = self.reshape_fn_in(x)  # type: ignore

        # apply b_dec_to_input if using that method.
        sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))

        # handle run time activation normalization if needed
        x = self.run_time_activation_norm_fn_in(x)

        # "... d_in, d_in d_sae -> ... d_sae",
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        hidden_pre_noised = hidden_pre + (
            torch.randn_like(hidden_pre) * self.cfg.noise_scale * self.training
        )
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))

        return feature_acts, hidden_pre_noised

    def encode_with_hidden_pre_gated(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:

        # move x to correct dtype
        x = x.to(self.dtype)

        # handle hook z reshaping if needed.
        x = self.reshape_fn_in(x)  # type: ignore

        # apply b_dec_to_input if using that method.
        sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))

        # Gating path with Heaviside step function
        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
        active_features = (gating_pre_activation > 0).float()

        # Magnitude path with weight sharing
        magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
        # magnitude_pre_activation_noised = magnitude_pre_activation + (
        #     torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale * self.training
        # )
        feature_magnitudes = self.activation_fn(
            magnitude_pre_activation
        )  # magnitude_pre_activation_noised)

        # Return both the gated feature activations and the magnitude pre-activations
        return (
            active_features * feature_magnitudes,
            magnitude_pre_activation,
        )  # magnitude_pre_activation_noised

    def forward(
        self,
        x: Float[torch.Tensor, "... d_in"],
    ) -> Float[torch.Tensor, "... d_in"]:

        feature_acts, _ = self.encode_with_hidden_pre_fn(x)
        sae_out = self.decode(feature_acts)

        return sae_out

    def training_forward_pass(
        self,
        sae_in: torch.Tensor,
        current_l1_coefficient: float,
        dead_neuron_mask: Optional[torch.Tensor] = None,
    ) -> TrainStepOutput:

        # do a forward pass to get SAE out, but we also need the
        # hidden pre.
        feature_acts, _ = self.encode_with_hidden_pre_fn(sae_in)
        sae_out = self.decode(feature_acts)

        # MSE LOSS
        per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in)
        mse_loss = per_item_mse_loss.sum(dim=-1).mean()

        # GHOST GRADS
        if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None:

            # first half of second forward pass
            _, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
            ghost_grad_loss = self.calculate_ghost_grad_loss(
                x=sae_in,
                sae_out=sae_out,
                per_item_mse_loss=per_item_mse_loss,
                hidden_pre=hidden_pre,
                dead_neuron_mask=dead_neuron_mask,
            )
        else:
            ghost_grad_loss = 0.0

        if self.cfg.architecture == "gated":
            # Gated SAE Loss Calculation

            # Shared variables
            sae_in_centered = (
                self.reshape_fn_in(sae_in) - self.b_dec * self.cfg.apply_b_dec_to_input
            )
            pi_gate = sae_in_centered @ self.W_enc + self.b_gate
            pi_gate_act = torch.relu(pi_gate)

            # SFN sparsity loss - summed over the feature dimension and averaged over the batch
            l1_loss = (
                current_l1_coefficient
                * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
            )

            # Auxiliary reconstruction loss - summed over the feature dimension and averaged over the batch
            via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
            aux_reconstruction_loss = torch.sum(
                (via_gate_reconstruction - sae_in) ** 2, dim=-1
            ).mean()

            loss = mse_loss + l1_loss + aux_reconstruction_loss
        else:
            # default SAE sparsity loss
            weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
            sparsity = weighted_feature_acts.norm(
                p=self.cfg.lp_norm, dim=-1
            )  # sum over the feature dimension

            l1_loss = (current_l1_coefficient * sparsity).mean()
            loss = mse_loss + l1_loss + ghost_grad_loss

            aux_reconstruction_loss = torch.tensor(0.0)

        return TrainStepOutput(
            sae_in=sae_in,
            sae_out=sae_out,
            feature_acts=feature_acts,
            loss=loss,
            mse_loss=mse_loss.item(),
            l1_loss=l1_loss.item(),
            ghost_grad_loss=(
                ghost_grad_loss.item()
                if isinstance(ghost_grad_loss, torch.Tensor)
                else ghost_grad_loss
            ),
            auxiliary_reconstruction_loss=aux_reconstruction_loss.item(),
        )

    def calculate_ghost_grad_loss(
        self,
        x: torch.Tensor,
        sae_out: torch.Tensor,
        per_item_mse_loss: torch.Tensor,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor,
    ) -> torch.Tensor:

        # 1.
        residual = x - sae_out
        l2_norm_residual = torch.norm(residual, dim=-1)

        # 2.
        # ghost grads use an exponential activation function, ignoring whatever
        # the activation function is in the SAE. The forward pass uses the dead neurons only.
        feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_neuron_mask])
        ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask, :]
        l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
        norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
        ghost_out = ghost_out * norm_scaling_factor[:, None].detach()

        # 3. Rescale so the ghost grad loss is comparable in magnitude to the
        # original loss; this is needed because the ghost grads are only
        # calculated for the dead neurons.
        # There have been methodological improvements that are not implemented here yet
        # see here: https://www.lesswrong.com/posts/C5KAZQib3bzzpeyrg/full-post-progress-update-1-from-the-gdm-mech-interp-team#Improving_ghost_grads
        per_item_mse_loss_ghost_resid = self.mse_loss_fn(ghost_out, residual.detach())
        mse_rescaling_factor = (
            per_item_mse_loss / (per_item_mse_loss_ghost_resid + 1e-6)
        ).detach()
        per_item_mse_loss_ghost_resid = (
            mse_rescaling_factor * per_item_mse_loss_ghost_resid
        )

        return per_item_mse_loss_ghost_resid.mean()

    @torch.no_grad()
    def _get_mse_loss_fn(self) -> Any:

        def standard_mse_loss_fn(
            preds: torch.Tensor, target: torch.Tensor
        ) -> torch.Tensor:
            return torch.nn.functional.mse_loss(preds, target, reduction="none")

        def batch_norm_mse_loss_fn(
            preds: torch.Tensor, target: torch.Tensor
        ) -> torch.Tensor:
            target_centered = target - target.mean(dim=0, keepdim=True)
            normalization = target_centered.norm(dim=-1, keepdim=True)
            return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
                normalization + 1e-6
            )

        if self.cfg.mse_loss_normalization == "dense_batch":
            return batch_norm_mse_loss_fn
        else:
            return standard_mse_loss_fn

    @classmethod
    def load_from_pretrained(
        cls,
        path: str,
        device: str = "cpu",
        dtype: str = "float32",
    ) -> "TrainingSAE":

        config_path = os.path.join(path, "cfg.json")
        weight_path = os.path.join(path, "sae_weights.safetensors")

        cfg_dict, state_dict, _ = load_pretrained_sae_lens_sae_components(
            config_path, weight_path, device, dtype
        )

        sae_cfg = TrainingSAEConfig.from_dict(cfg_dict)

        sae = cls(sae_cfg)
        sae.load_state_dict(state_dict)

        return sae

    def initialize_weights_complex(self):
        """ """

        if self.cfg.decoder_orthogonal_init:
            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T

        elif self.cfg.decoder_heuristic_init:
            self.W_dec = nn.Parameter(
                torch.rand(
                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
                )
            )
            self.initialize_decoder_norm_constant_norm()

        elif self.cfg.normalize_sae_decoder:
            self.set_decoder_norm_to_unit_norm()

        # Then we initialize the encoder weights (either as the transpose of decoder or not)
        if self.cfg.init_encoder_as_decoder_transpose:
            self.W_enc.data = self.W_dec.data.T.clone().contiguous()
        else:
            self.W_enc = nn.Parameter(
                torch.nn.init.kaiming_uniform_(
                    torch.empty(
                        self.cfg.d_in,
                        self.cfg.d_sae,
                        dtype=self.dtype,
                        device=self.device,
                    )
                )
            )

        if self.cfg.normalize_sae_decoder:
            with torch.no_grad():
                # Anthropic normalizes the decoder to have unit-norm columns
                self.set_decoder_norm_to_unit_norm()

    ## Initialization Methods
    @torch.no_grad()
    def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
        out = torch.tensor(origin, dtype=self.dtype, device=self.device)
        self.b_dec.data = out

    @torch.no_grad()
    def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
        previous_b_dec = self.b_dec.clone().cpu()
        out = all_activations.mean(dim=0)

        previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
        distances = torch.norm(all_activations - out, dim=-1)

        print("Reinitializing b_dec with mean of activations")
        print(
            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
        )
        print(f"New distances: {distances.median(0).values.mean().item()}")

        self.b_dec.data = out.to(self.dtype).to(self.device)

    ## Training Utils
    @torch.no_grad()
    def set_decoder_norm_to_unit_norm(self):
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)

    @torch.no_grad()
    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
        """
        A heuristic procedure inspired by:
        https://transformer-circuits.pub/2024/april-update/index.html#training-saes
        """
        # TODO: Parameterise this as a function of m and n

        # ensure W_dec norms at unit norm
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
        self.W_dec.data *= norm  # will break tests but do this for now.

    @torch.no_grad()
    def remove_gradient_parallel_to_decoder_directions(self):
        """
        Update grads so that they remove the parallel component
            (d_sae, d_in) shape
        """
        assert self.W_dec.grad is not None  # keep pyright happy

        parallel_component = einops.einsum(
            self.W_dec.grad,
            self.W_dec.data,
            "d_sae d_in, d_sae d_in -> d_sae",
        )
        self.W_dec.grad -= einops.einsum(
            parallel_component,
            self.W_dec.data,
            "d_sae, d_sae d_in -> d_sae d_in",
        )

encode(x)

Calculate SAE features from inputs

Source code in sae_lens/training/training_sae.py
def encode(
    self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
    """
    Calculate SAE features from inputs
    """
    feature_acts, _ = self.encode_with_hidden_pre_fn(x)
    return feature_acts

initialize_decoder_norm_constant_norm(norm=0.1)

A heuristic procedure inspired by: https://transformer-circuits.pub/2024/april-update/index.html#training-saes

Source code in sae_lens/training/training_sae.py
@torch.no_grad()
def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
    """
    A heuristic procedure inspired by:
    https://transformer-circuits.pub/2024/april-update/index.html#training-saes
    """
    # TODO: Parameterise this as a function of m and n

    # ensure W_dec norms at unit norm
    self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
    self.W_dec.data *= norm  # will break tests but do this for now.

initialize_weights_complex()

Source code in sae_lens/training/training_sae.py
def initialize_weights_complex(self):
    """ """

    if self.cfg.decoder_orthogonal_init:
        self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T

    elif self.cfg.decoder_heuristic_init:
        self.W_dec = nn.Parameter(
            torch.rand(
                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
            )
        )
        self.initialize_decoder_norm_constant_norm()

    elif self.cfg.normalize_sae_decoder:
        self.set_decoder_norm_to_unit_norm()

    # Then we initialize the encoder weights (either as the transpose of decoder or not)
    if self.cfg.init_encoder_as_decoder_transpose:
        self.W_enc.data = self.W_dec.data.T.clone().contiguous()
    else:
        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.cfg.d_in,
                    self.cfg.d_sae,
                    dtype=self.dtype,
                    device=self.device,
                )
            )
        )

    if self.cfg.normalize_sae_decoder:
        with torch.no_grad():
            # Anthropic normalizes the decoder to have unit-norm columns
            self.set_decoder_norm_to_unit_norm()

remove_gradient_parallel_to_decoder_directions()

Update grads so that they remove the parallel component (d_sae, d_in) shape

Source code in sae_lens/training/training_sae.py
@torch.no_grad()
def remove_gradient_parallel_to_decoder_directions(self):
    """
    Update grads so that they remove the parallel component
        (d_sae, d_in) shape
    """
    assert self.W_dec.grad is not None  # keep pyright happy

    parallel_component = einops.einsum(
        self.W_dec.grad,
        self.W_dec.data,
        "d_sae d_in, d_sae d_in -> d_sae",
    )
    self.W_dec.grad -= einops.einsum(
        parallel_component,
        self.W_dec.data,
        "d_sae, d_sae d_in -> d_sae d_in",
    )
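
Removing the parallel component means each dictionary element's gradient no longer changes the norm of its decoder direction, which is what keeps a unit-norm decoder constraint intact across optimizer steps. A quick standalone check of that property; the shapes here are illustrative.

import einops
import torch

d_sae, d_in = 16, 8
W_dec = torch.nn.Parameter(torch.randn(d_sae, d_in))
W_dec.data /= W_dec.data.norm(dim=1, keepdim=True)  # unit-norm decoder rows

# Fake gradient, then project out the component parallel to each decoder row
W_dec.grad = torch.randn(d_sae, d_in)
parallel_component = einops.einsum(
    W_dec.grad, W_dec.data, "d_sae d_in, d_sae d_in -> d_sae"
)
W_dec.grad -= einops.einsum(
    parallel_component, W_dec.data, "d_sae, d_sae d_in -> d_sae d_in"
)

# Each row's gradient is now orthogonal to its decoder direction
dots = (W_dec.grad * W_dec.data).sum(dim=1)
assert torch.allclose(dots, torch.zeros(d_sae), atol=1e-5)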