Add usage example script

README.md
- **Type**: Pretrained backbone (no fine-tuning)
- **Embedding Dimension**: Varies by backbone model
## Usage Example

```python
"""
Example: Using EmbeddingGemma-300m for Tag Classification

This example shows how to use the pretrained EmbeddingGemma-300m backbone
for zero-shot tag classification using our module abstractions.

Installation:
    pip install git+https://github.com/Pieces/TAG-module.git@main
    # Or: pip install -e .
"""

import torch
from tags_model.models.backbone import SharedTextBackbone
from playground.validate_from_checkpoint import compute_ranked_tags

# Load the pretrained backbone
print("Loading EmbeddingGemma-300m...")
backbone = SharedTextBackbone(
    model_name="google/embeddinggemma-300m",
    embedding_dim=768,
    freeze_backbone=True,
    pooling_mode="mean",
)
backbone.eval()
print("✓ Model loaded!")

# Example query
query_text = "How to implement OAuth2 authentication in a Python Flask API?"

# Candidate tags to rank
candidate_tags = [
    "python", "flask", "oauth2", "authentication", "api",
    "security", "web-development", "jwt", "rest-api", "backend",
]

print(f"\nQuery: {query_text}")
print(f"Candidate tags: {candidate_tags}\n")

# Encode query and tags
with torch.inference_mode():
    query_emb = backbone.encode_texts([query_text], max_length=2048, return_dict=False)[0]
    tag_embs = backbone.encode_texts(candidate_tags, max_length=2048, return_dict=False)

print(f"Query embedding shape: {query_emb.shape}")
print(f"Tag embeddings shape: {tag_embs.shape}")

# Rank tags by similarity
ranked_tags = compute_ranked_tags(
    query_emb=query_emb,
    pos_embs=torch.empty(0, 768),  # No positives for zero-shot
    neg_embs=torch.empty(0, 768),  # No negatives for zero-shot
    general_embs=tag_embs,
    positive_tags=[],
    negative_tags=[],
    general_tags=candidate_tags,
)

# Display the top-ranked tags
print("\n" + "=" * 60)
print("Top Ranked Tags:")
print("=" * 60)
for tag, rank, label, score in ranked_tags[:5]:
    print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")

print("\n" + "=" * 60)
print("Example complete!")
```
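The `pooling_mode="mean"` setting above refers to masked mean pooling over token embeddings. The backbone handles this internally; as a rough standalone sketch of what that operation does (the `mean_pool` helper here is hypothetical, not part of the module):

```python
import torch

def mean_pool(token_embs: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Masked mean pooling: average token embeddings, ignoring padding.

    token_embs: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask.unsqueeze(-1).to(token_embs.dtype)  # (B, S, 1)
    summed = (token_embs * mask).sum(dim=1)                   # (B, D)
    counts = mask.sum(dim=1).clamp(min=1e-9)                  # (B, 1), avoid div-by-zero
    return summed / counts

# Toy check: the padded third token contributes nothing to the average
x = torch.tensor([[[2.0, 4.0], [6.0, 8.0], [9.0, 9.0]]])  # last token is padding
m = torch.tensor([[1, 1, 0]])
print(mean_pool(x, m))  # mean of the first two tokens -> [[4.0, 6.0]]
```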
### Running the Example

```bash
# Install the repository first
pip install git+https://github.com/Pieces/TAG-module.git@main
# Or for local development:
pip install -e .

# Run the example
python embeddinggemma_example.py
```
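If the `compute_ranked_tags` helper is not available, the zero-shot ranking it performs over the general tag list amounts to sorting tags by cosine similarity to the query embedding. A minimal standalone sketch with plain PyTorch (the `rank_tags_by_similarity` function is hypothetical; random tensors stand in for real model outputs):

```python
import torch
import torch.nn.functional as F

def rank_tags_by_similarity(query_emb: torch.Tensor,
                            tag_embs: torch.Tensor,
                            tags: list[str]) -> list[tuple[str, float]]:
    """Rank tags by cosine similarity between the query and each tag embedding."""
    # L2-normalize so the dot product equals cosine similarity
    q = F.normalize(query_emb.unsqueeze(0), dim=-1)  # (1, D)
    t = F.normalize(tag_embs, dim=-1)                # (N, D)
    scores = (q @ t.T).squeeze(0)                    # (N,)
    order = torch.argsort(scores, descending=True)
    return [(tags[i], scores[i].item()) for i in order.tolist()]

# Toy demo with random stand-in embeddings (a real run would use backbone outputs)
torch.manual_seed(0)
tags = ["python", "flask", "oauth2"]
query = torch.randn(768)
tag_matrix = torch.randn(3, 768)
ranked = rank_tags_by_similarity(query, tag_matrix, tags)
for tag, score in ranked:
    print(f"{tag}: {score:.4f}")
```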
## Citation

If you use this model, please cite: