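This Space's app code defines two pieces: a sparse autoencoder (`SAE`) whose loss combines a reconstruction term with a decoder-norm-weighted L1 sparsity penalty, and a `SteerableOlmo2ForCausalLM` wrapper that clamps selected SAE features inside a forward hook so that generation can be steered.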
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Olmo2ForCausalLM


class SAE(nn.Module):
    def __init__(self, input_size, hidden_size, init_scale=0.1):  # init_scale is unused below
        super().__init__()
        # Store dimensions
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Encoder and decoder start as tied transposes of each other
        self.encode = nn.Linear(input_size, hidden_size, bias=True)
        self.decode = nn.Linear(hidden_size, input_size, bias=True)
        with torch.no_grad():
            # Random direction for each dictionary element
            decoder_weights = torch.randn(input_size, hidden_size)
            # Normalize columns to unit norm
            decoder_weights = decoder_weights / torch.linalg.vector_norm(
                decoder_weights, dim=0, keepdim=True
            )
            # Scale columns by random values between 0.05 and 1.0; after
            # constrain_weights() below renormalizes the decoder, these
            # scales survive only in the encoder rows
            scales = torch.rand(hidden_size) * 0.95 + 0.05
            decoder_weights = decoder_weights * scales
            self.decode.weight.data = decoder_weights
            self.encode.weight.data = decoder_weights.T.contiguous()
            self.encode.bias.data.zero_()  # zero the biases in place
            self.decode.bias.data.zero_()
        self.constrain_weights()

    def device(self):
        """Return the device the model parameters are on."""
        return next(self.parameters()).device

    def constrain_weights(self):
        """Constrain the decoder weight columns to unit norm."""
        with torch.no_grad():
            decoder_norm = torch.linalg.vector_norm(self.decode.weight, dim=0, keepdim=True)
            self.decode.weight.data = self.decode.weight.data / decoder_norm

    def forward(self, x):
        features = F.relu(self.encode(x))
        reconstruction = self.decode(features)
        return reconstruction, features

    def get_decoder_norms(self):
        # Returns a 1-D tensor of shape (hidden_size,) on the right device/dtype
        return torch.linalg.vector_norm(self.decode.weight, dim=0)

    def W_dec(self):
        """Return the decoder weights for easier access during analysis."""
        return self.decode.weight

    def compute_loss(self, x, recon, feats, lambda_):
        # Reconstruction term: squared error summed over the feature dim, mean over batch
        recon_mse = (recon - x).pow(2).sum(-1).mean()
        # Sparsity term: L1 on feature activations, weighted by current decoder-column norms
        sparsity = (feats.abs() * self.get_decoder_norms()).sum(-1).mean()
        return recon_mse + lambda_ * sparsity
```
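Since `feats` comes out of a ReLU and is non-negative, `compute_loss` implements

$$
\mathcal{L}(x) = \lVert x - \hat{x} \rVert_2^2 \;+\; \lambda \sum_i f_i(x)\, \lVert W_{\mathrm{dec},:,i} \rVert_2,
$$

averaged over the batch. A minimal training-step sketch follows; the dimensions, hyperparameters, and `activation_batches` iterable are illustrative placeholders, not values from the Space:

```python
# Hypothetical setup: dims and hyperparameters are stand-ins.
sae = SAE(input_size=4096, hidden_size=16384)
opt = torch.optim.Adam(sae.parameters(), lr=3e-4)

for acts in activation_batches:  # assumed iterable of (batch, input_size) tensors
    acts = acts.to(sae.device())
    recon, feats = sae(acts)
    loss = sae.compute_loss(acts, recon, feats, lambda_=5.0)
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Note that the sketch does not call `constrain_weights()` inside the loop: because the sparsity penalty is weighted by live decoder norms, renormalizing after every step would reduce it to a plain L1. A plausible reading is that `constrain_weights()` is meant for initialization and post-hoc analysis rather than the training loop.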
```python
class SteerableOlmo2ForCausalLM(Olmo2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.steering_layer = None
        self.sae = None
        self.steering_features = {}
        self.steering_hook = None
        self.sae_max = None

    def set_sae_and_layer(self, sae, layer):
        self.sae = sae
        self.steering_layer = layer
        self._register_steering_hook()

    def set_sae_max(self, sae_max):
        self.sae_max = sae_max

    def set_steering(self, feature_idx, value, *, as_multiple_of_max=False):
        if as_multiple_of_max:
            if self.sae_max is None:
                raise ValueError("call set_sae_max() before steering by multiples of the max")
            value = float(value) * float(self.sae_max[feature_idx])
        self.steering_features[feature_idx] = value

    def clear_steering(self):
        self.steering_features = {}

    def _steering_hook_fn(self, module, input, output):
        if not self.steering_features or self.sae is None:
            return output
        hidden_states = output[0]
        # Match SAE.forward: features are the post-ReLU activations
        feats = F.relu(self.sae.encode(hidden_states))
        recon = self.sae.decode(feats)
        # Keep the SAE's reconstruction error so unsteered content passes through unchanged
        error = hidden_states - recon
        feats_steered = feats.clone()
        for idx, clamp_value in self.steering_features.items():
            feats_steered[..., idx] = clamp_value
        recon_steered = self.sae.decode(feats_steered)
        hidden_steered = recon_steered + error
        return (hidden_steered,) + output[1:]

    def _register_steering_hook(self):
        if self.steering_hook is not None:
            self.steering_hook.remove()
            self.steering_hook = None
        if self.steering_layer is not None:
            target_layer = self.model.layers[self.steering_layer]
            self.steering_hook = target_layer.register_forward_hook(self._steering_hook_fn)

    def remove_steering_hook(self):
        if self.steering_hook is not None:
            self.steering_hook.remove()
            self.steering_hook = None
```
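A sketch of how the wrapper might be driven at inference time. The checkpoint name, SAE weights path, layer index, feature index, and clamp value are hypothetical placeholders, not values from the Space:

```python
from transformers import AutoTokenizer

# Hypothetical checkpoint and settings; substitute the Space's actual values.
model_name = "allenai/OLMo-2-1124-7B"
tok = AutoTokenizer.from_pretrained(model_name)
model = SteerableOlmo2ForCausalLM.from_pretrained(model_name)

sae = SAE(input_size=model.config.hidden_size, hidden_size=16384)
# sae.load_state_dict(torch.load("sae.pt"))  # load trained SAE weights (path is a placeholder)
model.set_sae_and_layer(sae, layer=12)  # hook the output of decoder layer 12

model.set_steering(feature_idx=123, value=8.0)  # clamp feature 123 to 8.0 at every position
out = model.generate(**tok("The weather today", return_tensors="pt"), max_new_tokens=40)
print(tok.decode(out[0], skip_special_tokens=True))

model.clear_steering()  # subsequent generations run unsteered
```

Because the hook adds the SAE's reconstruction error back in, a forward pass with no features clamped reproduces the original hidden states exactly; only the clamped feature directions change.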