import torch
import torch.nn as nn
import logging
import time

from utils import factorization

LOG = logging.getLogger(__name__)

class FixableDropout(nn.Module):
    """Dropout that caches one mask per input shape so the same mask can be replayed until resample() is called."""

    def __init__(self, p: float):
        super().__init__()
        self.p = p
        self.mask_cache = {}
        self.seed = 0

    def resample(self, seed=None):
        # Invalidate all cached masks; fresh masks are drawn lazily on the next forward pass.
        if seed is None:
            seed = int(time.time() * 1e6)
        self.mask_cache = {}
        self.seed = seed

    def forward(self, x):
        if self.training:
            if x.shape not in self.mask_cache:
                generator = torch.Generator(x.device).manual_seed(self.seed)
                self.mask_cache[x.shape] = torch.bernoulli(
                    torch.full_like(x, 1 - self.p), generator=generator
                ).bool()
            # Inverted dropout: scale by 1 / (1 - p) so expected activations match eval mode.
            x = (self.mask_cache[x.shape] * x) / (1 - self.p)
        return x

    def extra_repr(self) -> str:
        return f"p={self.p}"

class ActMLP(nn.Module):
    def __init__(self, hidden_dim, n_hidden):
        super().__init__()
        self.mlp = MLP(1, 1, hidden_dim, n_hidden, init="id")

    def forward(self, x):
        # Apply the scalar MLP elementwise: flatten to shape (-1, 1), then restore the input shape.
        return self.mlp(x.view(-1, 1)).view(x.shape)

class LightIDMLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
    ):
        super().__init__()
        LOG.info(f"Building LightIDMLP {[indim] + [rank] + [indim]}")
        self.layer1 = nn.Linear(indim, rank)
        self.layer2 = nn.Linear(rank, indim)
        # Zero-init the second layer (and drop its bias) so the block starts as the identity map.
        self.layer2.weight.data[:] = 0
        self.layer2.bias = None

    def forward(self, x):
        h = self.layer1(x).relu()
        return x + self.layer2(h)

class IDMLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
        n_modes: int = None,
    ):
        super().__init__()
        LOG.info(f"Building IDMLP ({init}) {[indim] * (n_hidden + 2)}")
        self.layers = nn.ModuleList(
            [
                LRLinear(indim, indim, rank=rank, relu=idx < n_hidden, init=init, n_modes=n_modes)
                for idx in range(n_hidden + 1)
            ]
        )

    def forward(self, x, mode=None):
        for layer in self.layers:
            x = layer(x, mode=mode)
        return x

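# Illustrative sketch (not part of the original module): with init="id", each LRLinear starts
# as the identity map, so a freshly built IDMLP should return its input unchanged. The helper
# name _demo_idmlp is hypothetical and is never called at import time.
def _demo_idmlp():
    net = IDMLP(indim=16, outdim=16, hidden_dim=None, n_hidden=2, init="id", rank=4)
    x = torch.randn(5, 16)
    # u is zero-initialized in every layer, so each layer computes x + 0.
    assert torch.allclose(net(x), x)
    print("IDMLP is the identity at init; output shape:", net(x).shape)
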
class LatentIDMLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = None,
        act: str = None,
        rank: int = None,
    ):
        super().__init__()
        LOG.info(f"Building Latent IDMLP ({init}) {[indim] * (n_hidden + 2)}")

        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(indim, rank))
        for _ in range(n_hidden - 1):
            self.layers.append(nn.Linear(rank, rank))
        self.layers.append(nn.Linear(rank, outdim))

        for layer in self.layers[:-1]:
            nn.init.xavier_normal_(layer.weight.data)

        if init == "id":
            # Zero-init the output layer so the residual connection below makes the block an identity map at init.
            self.layers[-1].weight.data.zero_()
            self.layers[-1].bias.data.zero_()

        self.init = init

    def forward(self, x):
        out = x
        for layer in self.layers[:-1]:
            out = layer(out).relu()

        out = self.layers[-1](out)
        if self.init == "id":
            return out + x
        else:
            return out

class KLinear(nn.Module):
    def __init__(self, inf, outf, pfrac=0.05, symmetric=True, zero_init: bool = True):
        super().__init__()
        self.inf = inf

        in_fact = factorization(inf)
        out_fact = factorization(outf)
        total_params = 0
        self.a, self.b = nn.ParameterList(), nn.ParameterList()
        # Add Kronecker factor pairs until the parameter budget (pfrac of a dense layer) is exhausted.
        for (i1, i2), (o1, o2) in zip(reversed(in_fact), reversed(out_fact)):
            new_params = (o1 * i1 + o2 * i2) * (2 if symmetric else 1)
            if (total_params + new_params) / (inf * outf) > pfrac and len(self.a) > 0:
                break
            total_params += new_params

            self.a.append(nn.Parameter(torch.empty(o1, i1)))
            if symmetric:
                self.a.append(nn.Parameter(torch.empty(o2, i2)))
            self.b.append(nn.Parameter(torch.empty(o2, i2)))
            if symmetric:
                self.b.append(nn.Parameter(torch.empty(o1, i1)))

            assert self.a[-1].kron(self.b[-1]).shape == (outf, inf)

        for factor in self.a:
            nn.init.kaiming_normal_(factor.data)
        for factor in self.b:
            if zero_init:
                factor.data.zero_()
            else:
                nn.init.kaiming_normal_(factor.data)

        print(f"Created ({symmetric}) k-layer using {total_params/(outf*inf):.3f} params, {len(self.a)} comps")
        self.bias = nn.Parameter(torch.zeros(outf))

    def forward(self, x):
        assert x.shape[-1] == self.inf, f"Expected input with {self.inf} dimensions, got {x.shape}"
        # Reconstruct the (outf, inf) weight as a scaled sum of Kronecker products of the factor pairs.
        w = sum([a.kron(b) for a, b in zip(self.a, self.b)]) / (2 * len(self.a) ** 0.5)
        y = x @ w.T  # (batch, outf); right-multiplying keeps the bias broadcast over the batch correct
        if self.bias is not None:
            y = y + self.bias
        return y

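# Illustrative sketch (not part of the original module): KLinear builds its weight from Kronecker
# factors and keeps roughly pfrac of a dense layer's parameters. It assumes utils.factorization is
# importable; the helper name _demo_klinear is hypothetical and is never called at import time.
def _demo_klinear():
    layer = KLinear(inf=64, outf=64, pfrac=0.05)
    x = torch.randn(8, 64)
    y = layer(x)
    print("KLinear output shape:", y.shape)  # expected (8, 64)
    n_params = sum(p.numel() for p in layer.parameters())
    print(f"trainable parameters: {n_params} vs dense {64 * 64 + 64}")
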
class LRLinear(nn.Module):
    def __init__(self, inf, outf, rank: int = None, relu=False, init="id", n_modes=None):
        super().__init__()

        mid_dim = min(rank, inf)
        if init == "id":
            # Zero-init u so u @ v is the zero map and the residual connection starts as the identity.
            self.u = nn.Parameter(torch.zeros(outf, mid_dim))
            self.v = nn.Parameter(torch.randn(mid_dim, inf))
        elif init == "xavier":
            self.u = nn.Parameter(torch.empty(outf, mid_dim))
            self.v = nn.Parameter(torch.empty(mid_dim, inf))
            nn.init.xavier_uniform_(self.u.data, gain=nn.init.calculate_gain("relu"))
            nn.init.xavier_uniform_(self.v.data, gain=1.0)
        else:
            raise ValueError(f"Unrecognized initialization {init}")

        if n_modes is not None:
            # Per-mode scale and shift applied to the pre-activation (initialized to a no-op).
            self.mode_shift = nn.Embedding(n_modes, outf)
            self.mode_shift.weight.data.zero_()
            self.mode_scale = nn.Embedding(n_modes, outf)
            self.mode_scale.weight.data.fill_(1)
        self.n_modes = n_modes  # set unconditionally so forward() can check it safely

        self.bias = nn.Parameter(torch.zeros(outf))
        self.inf = inf
        self.init = init

    def forward(self, x, mode=None):
        if mode is not None:
            assert self.n_modes is not None, "Linear got a mode but wasn't initialized for it"
            assert mode < self.n_modes, f"Input mode {mode} outside of range {self.n_modes}"
        assert x.shape[-1] == self.inf, f"Input wrong dim ({x.shape}, {self.inf})"

        pre_act = (self.u @ (self.v @ x.T)).T
        if self.bias is not None:
            pre_act += self.bias

        if mode is not None:
            if not isinstance(mode, torch.Tensor):
                mode = torch.tensor(mode).to(x.device)
            scale, shift = self.mode_scale(mode), self.mode_shift(mode)
            pre_act = pre_act * scale + shift

        # need clamp instead of relu so gradient at 0 isn't 0
        acts = pre_act.clamp(min=0)
        if self.init == "id":
            return acts + x
        else:
            return acts

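# Illustrative sketch (not part of the original module): the n_modes embeddings apply a per-mode
# scale and shift to the pre-activation, so different mode indices can produce different outputs
# from the same weights. The helper name _demo_lrlinear_modes is hypothetical and never called.
def _demo_lrlinear_modes():
    layer = LRLinear(32, 32, rank=4, init="xavier", n_modes=2)
    # Perturb mode 1's shift embedding (all modes start as scale=1, shift=0, i.e. a no-op).
    with torch.no_grad():
        layer.mode_shift.weight[1] += 1.0
    x = torch.randn(4, 32)
    y0, y1 = layer(x, mode=0), layer(x, mode=1)
    print("modes produce different outputs:", not torch.allclose(y0, y1))
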
class MLP(nn.Module):
    def __init__(
        self,
        indim: int,
        outdim: int,
        hidden_dim: int,
        n_hidden: int,
        init: str = "xavier_uniform",
        act: str = "relu",
        rank: int = None,
    ):
        super().__init__()
        self.init = init

        if act == "relu":
            self.act = nn.ReLU()
        elif act == "learned":
            self.act = ActMLP(10, 1)
        else:
            raise ValueError(f"Unrecognized activation function '{act}'")

        if hidden_dim is None:
            hidden_dim = outdim * 2

        if init.startswith("id") and outdim != indim:
            LOG.info(f"Overwriting outdim ({outdim}) to be indim ({indim})")
            outdim = indim

        if init == "id":
            old_hidden_dim = hidden_dim
            if hidden_dim < indim * 2:
                hidden_dim = indim * 2

            if hidden_dim % indim != 0:
                # Round hidden_dim up to the next multiple of indim so the identity blocks below tile exactly.
                hidden_dim += indim - hidden_dim % indim

            if old_hidden_dim != hidden_dim:
                LOG.info(
                    f"Overwriting hidden dim ({old_hidden_dim}) to be {hidden_dim}"
                )

        if init == "id_alpha":
            self.alpha = nn.Parameter(torch.zeros(1, outdim))

        dims = [indim] + [hidden_dim] * n_hidden + [outdim]
        LOG.info(f"Building ({init}) MLP: {dims} (rank {rank})")

        layers = []
        for idx, (ind, outd) in enumerate(zip(dims[:-1], dims[1:])):
            if rank is None:
                layers.append(nn.Linear(ind, outd))
            else:
                layers.append(LRLinear(ind, outd, rank=rank))
            if idx < n_hidden:
                layers.append(self.act)

        if rank is None:
            if init == "id":
                if n_hidden > 0:
                    # First layer stacks +I / -I blocks; the last layer averages them back so the MLP is the identity at init.
                    layers[0].weight.data = torch.eye(indim).repeat(
                        hidden_dim // indim, 1
                    )
                    layers[0].weight.data[hidden_dim // 2:] *= -1
                    layers[-1].weight.data = torch.eye(outdim).repeat(
                        1, hidden_dim // outdim
                    )
                    layers[-1].weight.data[:, hidden_dim // 2:] *= -1
                    layers[-1].weight.data /= (hidden_dim // indim) / 2.0

            for layer in layers:
                if isinstance(layer, nn.Linear):
                    if init == "ortho":
                        nn.init.orthogonal_(layer.weight)
                    elif init == "id":
                        if layer.weight.shape[0] == layer.weight.shape[1]:
                            layer.weight.data = torch.eye(hidden_dim)
                    else:
                        gain = 3 ** 0.5 if (layer is layers[-1]) else 1.0
                        nn.init.xavier_uniform_(layer.weight, gain=gain)

                    layer.bias.data[:] = 0

        layers[-1].bias = None
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        if self.init == "id_alpha":
            return x + self.alpha * self.mlp(x)
        else:
            return self.mlp(x)

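# Illustrative sketch (not part of the original module): passing rank swaps nn.Linear for LRLinear
# inside MLP. Those low-rank layers carry a residual connection, so equal input/hidden/output widths
# are assumed here. The helper name _demo_low_rank_mlp is hypothetical and is never called.
def _demo_low_rank_mlp():
    net = MLP(64, 64, 64, n_hidden=2, rank=8)
    x = torch.randn(10, 64)
    print("low-rank MLP output shape:", net(x).shape)  # expected (10, 64)
    print("parameter count:", sum(p.numel() for p in net.parameters()))
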
if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    m0 = MLP(1000, 1000, 1500, 3)
    m1 = MLP(1000, 1000, 1500, 3, init="id")
    m2 = MLP(1000, 1000, 1500, 3, init="id_alpha")
    m3 = MLP(1000, 1000, 1500, 3, init="ortho", act="learned")

    x = 0.01 * torch.randn(999, 1000)
    y0 = m0(x)
    y1 = m1(x)
    y2 = m2(x)
    y3 = m3(x)

    print("y0", (y0 - x).abs().max())
    print("y1", (y1 - x).abs().max())
    print("y2", (y2 - x).abs().max())
    print("y3", (y3 - x).abs().max())

    assert not torch.allclose(y0, x)
    assert torch.allclose(y1, x)
    assert torch.allclose(y2, x)
    assert not torch.allclose(y3, x)

    import pdb; pdb.set_trace()  # fmt: skip