import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Olmo2ForCausalLM
class SAE(nn.Module):
    def __init__(self, input_size, hidden_size, init_scale=0.1):
        super().__init__()
        # Store dimensions
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Linear encoder/decoder pair (init_scale is currently unused)
        self.encode = nn.Linear(input_size, hidden_size, bias=True)
        self.decode = nn.Linear(hidden_size, input_size, bias=True)
        with torch.no_grad():
            # Random directions for the decoder columns
            decoder_weights = torch.randn(input_size, hidden_size)
            # Normalize columns to unit norm
            decoder_weights = decoder_weights / torch.linalg.vector_norm(decoder_weights, dim=0, keepdim=True)
            # Scale by random values in [0.05, 1.0)
            scales = torch.rand(hidden_size) * 0.95 + 0.05
            decoder_weights = decoder_weights * scales
            self.decode.weight.data = decoder_weights
            # Initialize the encoder as the transpose of the decoder
            self.encode.weight.data = decoder_weights.T.contiguous()
            self.encode.bias.data.zero_()  # zero biases in place
            self.decode.bias.data.zero_()
        self.constrain_weights()
    @property
    def device(self):
        """Return the device the model parameters are on."""
        return next(self.parameters()).device

    def constrain_weights(self):
        """Constrain the decoder weight columns to unit norm."""
        with torch.no_grad():
            decoder_norm = torch.linalg.vector_norm(self.decode.weight, dim=0, keepdim=True)
            self.decode.weight.data = self.decode.weight.data / decoder_norm

    def forward(self, x):
        features = F.relu(self.encode(x))
        reconstruction = self.decode(features)
        return reconstruction, features

    def get_decoder_norms(self):
        # Returns a 1-D tensor of shape (hidden_size,) on the right device/dtype
        return torch.linalg.vector_norm(self.decode.weight, dim=0)

    @property
    def W_dec(self):
        """Return decoder weights for easier access during analysis."""
        return self.decode.weight

    def compute_loss(self, x, recon, feats, lambda_):
        # Reconstruction term: sum over the feature dim, mean over the batch
        recon_mse = (recon - x).pow(2).sum(-1).mean()
        # Sparsity term: L1 on feature activations, weighted by current decoder-column norms
        sparsity = (feats.abs() * self.get_decoder_norms()).sum(1).mean()
        return recon_mse + lambda_ * sparsity
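
# Example (illustrative sketch): a single optimization step using
# compute_loss. The names `acts` (a (batch, input_size) activation tensor)
# and the lambda_ default below are assumptions for demonstration.
def sae_training_step(sae, acts, optimizer, lambda_=5.0):
    """Run one SAE training step on a batch of activations; return the loss."""
    recon, feats = sae(acts)
    loss = sae.compute_loss(acts, recon, feats, lambda_)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Optionally call sae.constrain_weights() here to renormalize the decoder;
    # note this makes the decoder-norm weighting in compute_loss a no-op,
    # since every column norm becomes 1.
    return loss.item()
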
class SteerableOlmo2ForCausalLM(Olmo2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.steering_layer = None
        self.sae = None
        self.steering_features = {}
        self.steering_hook = None
        self.sae_max = None

    def set_sae_and_layer(self, sae, layer):
        """Attach an SAE and register the steering hook on the given layer."""
        self.sae = sae
        self.steering_layer = layer
        self._register_steering_hook()

    def set_sae_max(self, sae_max):
        """Store per-feature maximum activations, indexable by feature id."""
        self.sae_max = sae_max

    def set_steering(self, feature_idx, value, *, as_multiple_of_max=False):
        """Clamp a feature to `value` (optionally as a multiple of its observed max)."""
        if as_multiple_of_max and self.sae_max is not None:
            value = float(value) * float(self.sae_max[feature_idx])
        self.steering_features[feature_idx] = value

    def clear_steering(self):
        self.steering_features = {}
    @torch.no_grad()
    def _steering_hook_fn(self, module, input, output):
        if not self.steering_features or self.sae is None:
            return output
        hidden_states = output[0]
        # Match SAE.forward: features are the ReLU of the encoder pre-activations
        feats = F.relu(self.sae.encode(hidden_states))
        recon = self.sae.decode(feats)
        # Preserve the SAE's reconstruction error so only the clamped
        # features change the residual stream
        error = hidden_states - recon
        feats_steered = feats.clone()
        for idx, clamp_value in self.steering_features.items():
            feats_steered[..., idx] = clamp_value
        recon_steered = self.sae.decode(feats_steered)
        hidden_steered = recon_steered + error
        return (hidden_steered,) + output[1:]
    def _register_steering_hook(self):
        # Replace any existing hook before registering a new one
        if self.steering_hook is not None:
            self.steering_hook.remove()
            self.steering_hook = None
        if self.steering_layer is not None:
            target_layer = self.model.layers[self.steering_layer]
            self.steering_hook = target_layer.register_forward_hook(self._steering_hook_fn)

    def remove_steering_hook(self):
        if self.steering_hook is not None:
            self.steering_hook.remove()
            self.steering_hook = None
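
# Example (illustrative sketch): end-to-end steering. The model id, SAE
# checkpoint path, layer index, feature index, and clamp value are all
# assumptions chosen for demonstration, not fixed by the classes above.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_id = "allenai/OLMo-2-1124-7B"  # example OLMo 2 checkpoint
    model = SteerableOlmo2ForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    sae = SAE(input_size=model.config.hidden_size, hidden_size=16384)
    sae.load_state_dict(torch.load("sae.pt"))  # hypothetical trained SAE
    sae.to(model.device).eval()

    # Hook layer 12 and clamp feature 123 to 8.0 (arbitrary example values)
    model.set_sae_and_layer(sae, layer=12)
    model.set_steering(feature_idx=123, value=8.0)

    inputs = tokenizer("The weather today is", return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

    model.clear_steering()
    model.remove_steering_hook()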