BICORP committed
Commit 8215be9 · verified · 1 Parent(s): 613e8f2

Delete inference.py

Files changed (1)
  1. inference.py +0 -70
inference.py DELETED
@@ -1,70 +0,0 @@
- import torch
- import torch.nn as nn
- from safetensors.torch import load_file
- from transformers import BertTokenizer
-
- class Gemma3ForConditionalGeneration(nn.Module):
-     def __init__(self, vocab_size, embedding_dim=1344, num_heads=64, num_layers=48):
-         super(Gemma3ForConditionalGeneration, self).__init__()
-         self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
-         self.transformer_layers = nn.ModuleList([
-             nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads) for _ in range(num_layers)
-         ])
-         self.output_layer = nn.Linear(embedding_dim, vocab_size)
-
-     def forward(self, input_ids):
-         text_embeddings = self.token_embeddings(input_ids)
-         for layer in self.transformer_layers:
-             text_embeddings = layer(text_embeddings)
-         return self.output_layer(text_embeddings)
-
- def load_model(model_path, vocab_size):
-     model = Gemma3ForConditionalGeneration(vocab_size=vocab_size)
-     model_weights = load_file(model_path)
-     model.load_state_dict(model_weights, strict=False)
-     model.eval()
-     return model
-
- def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0):
-     input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)  # Move to GPU
-     generated_ids = input_ids
-
-     for _ in range(max_length):
-         with torch.no_grad():
-             outputs = model(generated_ids)
-             next_token_logits = outputs[:, -1, :]  # Get the logits for the last token
-
-         # Apply temperature
-         next_token_logits = next_token_logits / temperature
-
-         # Use softmax to get probabilities
-         probabilities = torch.softmax(next_token_logits, dim=-1)
-
-         # Sample from the distribution
-         next_token = torch.multinomial(probabilities, num_samples=1)  # Sample a token
-
-         # Reshape next_token to match dimensions for concatenation
-         next_token = next_token.unsqueeze(0)  # Add batch dimension
-         next_token = next_token.squeeze(2)  # Remove the last dimension
-
-         generated_ids = torch.cat((generated_ids, next_token), dim=1)  # Append the predicted token
-
-     generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-     return generated_text
-
- if __name__ == "__main__":
-     # Check if GPU is available and set device accordingly
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-     vocab_size = 262208 // 4
-     model_path = './model.safetensors'  # Replace with your model path
-     model = load_model(model_path, vocab_size).to(device)  # Move model to GPU
-
-     # Load the default tokenizer
-     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-
-     prompt = "Hello! Can you say me some info that you know?"
-
-     # Generate text based on the prompt with a specified temperature
-     generated_text = generate_text(model, tokenizer, prompt, temperature=0.7)  # Adjust temperature as needed
-     print("Generated Text:", generated_text)