Sparse Autoencoders (SAEs) trained on all 18 layers of google/functiongemma-270m-it. Each SAE operates on the output of `self_attn.o_proj` (the output projection of self-attention). Each checkpoint contains:
```python
{
    "model_name": "google/functiongemma-270m-it",
    "layer_idx": int,
    "d_in": 640,
    "d_sae": 4096,
    "W_enc": torch.Tensor,  # (640, 4096)
    "b_enc": torch.Tensor,  # (4096,)
    "W_dec": torch.Tensor,  # (4096, 640)
    "b_dec": torch.Tensor,  # (640,)
    "history": {
        "loss": [...],
        "mse": [...],
        "l0": [...]
    }
}
```
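As a quick sanity check, here is a minimal sketch that downloads one checkpoint and verifies the stored shapes (the filename pattern is taken from the loading example below):

```python
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download("mindchain/functiongemma-270m-sae", "sae_layer_00.pt")
ckpt = torch.load(ckpt_path, map_location="cpu")

# Encoder maps d_in -> d_sae, decoder maps d_sae -> d_in
assert ckpt["W_enc"].shape == (ckpt["d_in"], ckpt["d_sae"])  # (640, 4096)
assert ckpt["W_dec"].shape == (ckpt["d_sae"], ckpt["d_in"])  # (4096, 640)
assert ckpt["b_enc"].shape == (ckpt["d_sae"],)
assert ckpt["b_dec"].shape == (ckpt["d_in"],)
```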
Load the SAE checkpoint for a specific layer:

```python
import torch
from huggingface_hub import hf_hub_download

layer_idx = 0
ckpt_path = hf_hub_download(
    "mindchain/functiongemma-270m-sae",
    f"sae_layer_{layer_idx:02d}.pt"
)
sae = torch.load(ckpt_path, map_location="cpu")
```
Wrap the weights in a small module for inference. Note that despite the class name, inference here applies a plain ReLU; no separate JumpReLU threshold parameter is stored in the checkpoint.

```python
class JumpReLUSAE(torch.nn.Module):
    def __init__(self, W_enc, b_enc, W_dec, b_dec):
        super().__init__()
        self.W_enc = torch.nn.Parameter(W_enc)
        self.b_enc = torch.nn.Parameter(b_enc)
        self.W_dec = torch.nn.Parameter(W_dec)
        self.b_dec = torch.nn.Parameter(b_dec)

    def forward(self, x):
        # x: (batch, seq, d_in) activations from the hooked layer
        batch, seq, d_in = x.shape
        x_flat = x.view(-1, d_in)
        pre_act = x_flat @ self.W_enc + self.b_enc  # (batch*seq, d_sae)
        features = torch.relu(pre_act)              # sparse feature activations
        recon = features @ self.W_dec + self.b_dec  # (batch*seq, d_in)
        return recon.view(batch, seq, d_in), features.view(batch, seq, -1)

sae_model = JumpReLUSAE(
    sae["W_enc"], sae["b_enc"],
    sae["W_dec"], sae["b_dec"]
)
```
Get activations from FunctionGemma and encode them:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,
    device_map="cuda"
)
tokenizer = AutoTokenizer.from_pretrained("google/functiongemma-270m-it")
inputs = tokenizer("What's the weather?", return_tensors="pt").to(model.device)

# Forward hook to capture the o_proj output; for a Linear module,
# `out` is the output tensor itself, so don't index into it.
acts = []
def hook(module, inp, out):
    acts.append(out.detach().float())

handle = model.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
with torch.no_grad():
    _ = model(**inputs)
handle.remove()

# Run the captured activations through the SAE
recon, features = sae_model(acts[0])
print(f"Active features: {(features > 0).sum().item()}")  # total across all tokens
```
Per-layer training statistics (L0 is the number of active, i.e. nonzero, features):

| Layer | Final Loss | Final MSE | L0 |
|---|---|---|---|
| 0 | 3.4457 | 3.1244 | 1225 |
| 1 | 2.0052 | 1.9042 | 1386 |
| 2 | 0.1182 | 0.0759 | 1546 |
| 3 | 0.1182 | 0.0758 | 3096 |
| 4 | 0.0361 | 0.0170 | 1635 |
| 5 | 0.0414 | 0.0351 | 399 |
| 6 | 0.0318 | 0.0138 | 1807 |
| 7 | 0.0877 | 0.0661 | 1120 |
| 8 | 0.0733 | 0.0445 | 1379 |
| 9 | 0.0561 | 0.0317 | 1569 |
| 10 | 0.0997 | 0.0852 | 591 |
| 11 | 0.0252 | 0.0097 | 3658 |
| 12 | 0.0565 | 0.0395 | 962 |
| 13 | 0.0924 | 0.0619 | 1403 |
| 14 | 0.2711 | 0.2504 | 709 |
| 15 | 0.1501 | 0.1062 | 1576 |
| 16 | 0.1670 | 0.1426 | 870 |
| 17 | 0.0385 | 0.0218 | 1470 |
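These figures should match the final entries of each checkpoint's `history` lists; a minimal sketch to recompute them, assuming the same filename pattern as above:

```python
import torch
from huggingface_hub import hf_hub_download

for layer_idx in range(18):
    path = hf_hub_download(
        "mindchain/functiongemma-270m-sae",
        f"sae_layer_{layer_idx:02d}.pt"
    )
    hist = torch.load(path, map_location="cpu")["history"]
    print(f"layer {layer_idx:2d}: loss={hist['loss'][-1]:.4f} "
          f"mse={hist['mse'][-1]:.4f} l0={hist['l0'][-1]:.0f}")
```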
License: Apache 2.0