Changed parallel to nn
superposed/llama/superposed_model.py
CHANGED
@@ -199,39 +199,31 @@ class Attention(nn.Module):
         """
         super().__init__()
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        model_parallel_size = fs_init.get_model_parallel_world_size()
+        model_parallel_size = 1
         self.n_local_heads = args.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.dim // args.n_heads

-        self.wq = ColumnParallelLinear(
+        self.wq = nn.Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wk = ColumnParallelLinear(
+        self.wk = nn.Linear(
             args.dim,
             self.n_kv_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
+            bias=False
         )
-        self.wv = ColumnParallelLinear(
+        self.wv = nn.Linear(
             args.dim,
             self.n_kv_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
+            bias=False
         )
-        self.wo = RowParallelLinear(
+        self.wo = nn.Linear(
             args.n_heads * self.head_dim,
             args.dim,
-            bias=False,
-            input_is_parallel=True,
-            init_method=lambda x: x,
+            bias=False
         )

         self.cache_k = torch.zeros(
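Note: with `model_parallel_size` hard-coded to 1, `n_local_heads` equals `n_heads`, so a plain `nn.Linear` reproduces the shapes the column/row-parallel projections had on a single device. A minimal shape check as a sketch, using made-up dimensions rather than the repo's actual `ModelArgs`:

```python
import torch
import torch.nn as nn

# Illustrative dimensions only; the real values come from ModelArgs.
dim, n_heads = 512, 8
head_dim = dim // n_heads

# With model_parallel_size = 1 every head is "local", so plain nn.Linear
# layers give the same parameter and output shapes as the parallel versions.
wq = nn.Linear(dim, n_heads * head_dim, bias=False)
wo = nn.Linear(n_heads * head_dim, dim, bias=False)

x = torch.randn(2, 16, dim)   # (batch, seq_len, dim)
q = wq(x)                     # (2, 16, n_heads * head_dim)
out = wo(q)                   # (2, 16, dim)
assert q.shape == (2, 16, n_heads * head_dim)
assert out.shape == (2, 16, dim)
```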
@@ -336,14 +328,14 @@ class FeedForward(nn.Module):
         hidden_dim = int(ffn_dim_multiplier * hidden_dim)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

-        self.w1 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w1 = nn.Linear(
+            dim, hidden_dim, bias=False
         )
-        self.w2 = RowParallelLinear(
-            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        self.w2 = nn.Linear(
+            hidden_dim, dim, bias=False
         )
-        self.w3 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w3 = nn.Linear(
+            dim, hidden_dim, bias=False
         )

     def forward(self, x):
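Note: the three bias-free projections form LLaMA's SwiGLU feed-forward. A sketch of how they are typically combined in the forward pass (the class name `SwiGLUFeedForward` and the sizes below are invented for this note, not taken from the repo):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """Minimal stand-in for the FeedForward block above."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # Same three bias-free projections as in the diff.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: gate with SiLU(w1(x)), scale by w3(x), project back with w2.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

ffn = SwiGLUFeedForward(dim=512, hidden_dim=2048)
print(ffn(torch.randn(2, 8, 512)).shape)  # torch.Size([2, 8, 512])
```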
@@ -435,12 +427,12 @@ class SuperposedTransformer(nn.Module):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers

-        self.tok_embeddings = ParallelEmbedding(
-            params.vocab_size, params.dim, init_method=lambda x: x
+        self.tok_embeddings = nn.Embedding(
+            params.vocab_size, params.dim
         )

-        self.tok_mixing_embeddings = ColumnParallelLinear(
-            params.vocab_size, params.dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.tok_mixing_embeddings = nn.Linear(
+            params.vocab_size, params.dim, bias=False
         ) # dims here are formality (what matters is below)
         self.tok_mixing_embeddings.weight = nn.Parameter(self.tok_embeddings.weight.T)

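Note: the weight tying on the last unchanged line works because `nn.Embedding(vocab_size, dim).weight` is stored as `(vocab_size, dim)` while `nn.Linear` stores its weight as `(out_features, in_features)`; assigning the transposed embedding table turns the mixing layer into a weighted sum of token embeddings, which is why the declared in/out features are "a formality". A self-contained sketch with made-up sizes (the tie is done with an in-place copy here to keep the example simple):

```python
import torch
import torch.nn as nn

vocab_size, dim = 1000, 64                            # made-up sizes for illustration

tok_embeddings = nn.Embedding(vocab_size, dim)        # weight shape: (vocab_size, dim)
tok_mixing = nn.Linear(vocab_size, dim, bias=False)   # weight shape: (dim, vocab_size)

# Tie the mixing layer to the transposed embedding table, as in the diff.
with torch.no_grad():
    tok_mixing.weight.copy_(tok_embeddings.weight.T)

# Per-position mixing weights over the vocabulary (e.g. token probabilities).
mix = torch.softmax(torch.randn(2, 5, vocab_size), dim=-1)

mixed = tok_mixing(mix)                    # (2, 5, dim)
reference = mix @ tok_embeddings.weight    # weighted sum of embedding rows
assert torch.allclose(mixed, reference, atol=1e-5)
```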
@@ -449,8 +441,8 @@ class SuperposedTransformer(nn.Module):
             self.layers.append(MixedTransformerBlock(layer_id, params))

         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = ColumnParallelLinear(
-            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+        self.output = nn.Linear(
+            params.dim, params.vocab_size, bias=False
         )

         self.freqs_cis = precompute_freqs_cis(
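Note: the unembedding head is now a plain linear projection to vocabulary logits. Because only the layer class changed, the parameter should still be registered as `output.weight` with shape `(vocab_size, dim)`, so a consolidated single-shard checkpoint should still map onto it. A small sketch with illustrative sizes (not the repo's config):

```python
import torch
import torch.nn as nn

dim, vocab_size = 512, 32000            # illustrative sizes, not the real config
output = nn.Linear(dim, vocab_size, bias=False)

h = torch.randn(2, 8, dim)              # normalized hidden states after self.norm
logits = output(h)                      # (2, 8, vocab_size): one score per vocabulary token
print(logits.shape)

# The weight keeps the name "weight" and shape (vocab_size, dim), matching what a
# consolidated checkpoint would store under "output.weight" for a single shard.
print({k: tuple(v.shape) for k, v in output.state_dict().items()})
```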