antreaspiece committed on
Commit da892a4 · verified · 1 Parent(s): ab0c839

Add usage example script

Files changed (1): README.md +86 -0
README.md CHANGED
@@ -27,6 +27,92 @@ See the README.md for detailed usage examples using our module abstractions.
  - **Type**: Pretrained backbone (no fine-tuning)
  - **Embedding Dimension**: Varies by backbone model
 
+ ## Usage Example
+
+ ```python
+ """
+ Example: Using EmbeddingGemma-300m for Tag Classification
+
+ This example shows how to use the pretrained EmbeddingGemma-300m backbone
+ for zero-shot tag classification using our module abstractions.
+
+ Installation:
+     pip install git+https://github.com/Pieces/TAG-module.git@main
+     # Or: pip install -e .
+ """
+
+ import torch
+ from tags_model.models.backbone import SharedTextBackbone
+ from playground.validate_from_checkpoint import compute_ranked_tags
+
+ # Load the pretrained backbone
+ print("Loading EmbeddingGemma-300m...")
+ backbone = SharedTextBackbone(
+     model_name="google/embeddinggemma-300m",
+     embedding_dim=768,
+     freeze_backbone=True,
+     pooling_mode="mean",
+ )
+ backbone.eval()
+ print("✓ Model loaded!")
+
+ # Example query
+ query_text = "How to implement OAuth2 authentication in a Python Flask API?"
+
+ # Candidate tags to rank
+ candidate_tags = [
+     "python", "flask", "oauth2", "authentication", "api",
+     "security", "web-development", "jwt", "rest-api", "backend",
+ ]
+
+ print(f"\nQuery: {query_text}")
+ print(f"Candidate tags: {candidate_tags}\n")
+
+ # Encode the query and the candidate tags
+ with torch.inference_mode():
+     query_emb = backbone.encode_texts([query_text], max_length=2048, return_dict=False)[0]
+     tag_embs = backbone.encode_texts(candidate_tags, max_length=2048, return_dict=False)
+
+ print(f"Query embedding shape: {query_emb.shape}")
+ print(f"Tag embeddings shape: {tag_embs.shape}")
+
+ # Rank tags by similarity
+ ranked_tags = compute_ranked_tags(
+     query_emb=query_emb,
+     pos_embs=torch.empty(0, 768),  # No positives for zero-shot
+     neg_embs=torch.empty(0, 768),  # No negatives for zero-shot
+     general_embs=tag_embs,
+     positive_tags=[],
+     negative_tags=[],
+     general_tags=candidate_tags,
+ )
+
+ # Display the top-ranked tags
+ print("\n" + "=" * 60)
+ print("Top Ranked Tags:")
+ print("=" * 60)
+ for tag, rank, label, score in ranked_tags[:5]:
+     print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")
+
+ print("\n" + "=" * 60)
+ print("Example complete!")
+ ```
+
+ ### Running the Example
+
+ ```bash
+ # Install the repository first
+ pip install git+https://github.com/Pieces/TAG-module.git@main
+ # Or for local development:
+ pip install -e .
+
+ # Run the example
+ python embeddinggemma_example.py
+ ```
+
  ## Citation
 
  If you use this model, please cite: