embeddinggemma-300m-tag-classification

This is a pretrained backbone model (google/embeddinggemma-300m) for tag classification with the TAG module's contrastive-learning tooling. The backbone itself has not been fine-tuned.

Model Description

This model uses the google/embeddinggemma-300m backbone directly, without fine-tuning. It is intended for zero-shot tag classification. The query text and each candidate tag are embedded with the same pretrained model, and the tags are ranked by how semantically similar their embeddings are to the query embedding.
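The zero-shot setup comes down to comparing vectors. The sketch below assumes cosine similarity as the similarity measure, because this card does not state which measure the module uses. It also uses tiny hand-made 3-dimensional vectors in place of real EmbeddingGemma embeddings. `cosine_similarity` and `score_tags` are hypothetical helpers, not part of the TAG module.

```python
import math
from typing import Dict, List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cannot compare a zero vector")
    return dot / (norm_a * norm_b)


def score_tags(query_emb: Sequence[float], tag_embs: List[Sequence[float]],
               tags: List[str]) -> Dict[str, float]:
    """Hypothetical helper: map each candidate tag to its similarity with the query."""
    return {tag: cosine_similarity(query_emb, emb) for tag, emb in zip(tags, tag_embs)}


# Toy 3-d vectors standing in for real EmbeddingGemma embeddings (assumption).
query = [0.9, 0.1, 0.0]
tags = ["python", "flask", "cooking"]
tag_vectors = [[1.0, 0.0, 0.0], [0.7, 0.7, 0.0], [0.0, 0.0, 1.0]]

scores = score_tags(query, tag_vectors, tags)
# Print tags from most to least similar to the query.
for tag, score in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{tag:10s} {score:.4f}")
```

Cosine similarity ignores vector length and only compares direction, which is why it is a common default for comparing text embeddings.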

Usage

See the README.md for detailed usage examples using our module abstractions.

Model Architecture

  • Backbone: google/embeddinggemma-300m
  • Type: Pretrained backbone (no fine-tuning)
  • Embedding Dimension: 768 (the embedding_dim used in the example below)
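The usage example configures the backbone with pooling_mode="mean". Mean pooling usually means averaging the token embeddings from the model's last hidden layer while skipping padding positions, which turns every text into one fixed-size vector. The sketch below assumes this is what SharedTextBackbone does internally, which the card does not confirm. `mean_pool` is a hypothetical helper that works on plain lists and uses 2-dimensional toy vectors instead of 768-dimensional ones.

```python
from typing import List


def mean_pool(token_embs: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose mask entry is 1; padded tokens (mask 0) are ignored."""
    if len(token_embs) != len(attention_mask):
        raise ValueError("one mask entry is needed per token")
    kept = [emb for emb, m in zip(token_embs, attention_mask) if m == 1]
    if not kept:
        raise ValueError("attention mask selects no tokens")
    dim = len(kept[0])
    # Component-wise average over the kept (non-padding) tokens.
    return [sum(emb[i] for emb in kept) / len(kept) for i in range(dim)]


# Three real tokens followed by one padding token (toy 2-d embeddings).
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [100.0, 100.0]]
mask = [1, 1, 1, 0]
print(mean_pool(tokens, mask))  # → [3.0, 4.0]
```

Masking matters here. Without it, the padding vector [100.0, 100.0] would dominate the average.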

Usage Example

"""
Example: Using EmbeddingGemma-300m for Tag Classification

This example shows how to use the pretrained EmbeddingGemma-300m backbone
for zero-shot tag classification using our module abstractions.

Installation:
    pip install git+https://github.com/Pieces/TAG-module.git@main
    # Or: pip install -e .
"""

import torch
from tags_model.models.backbone import SharedTextBackbone
from playground.validate_from_checkpoint import compute_ranked_tags

# Load the pretrained backbone
print("Loading EmbeddingGemma-300m...")
backbone = SharedTextBackbone(
    model_name="google/embeddinggemma-300m",
    embedding_dim=768,
    freeze_backbone=True,
    pooling_mode="mean",
)
backbone.eval()
print("✓ Model loaded!")

# Example query
query_text = "How to implement OAuth2 authentication in a Python Flask API?"

# Candidate tags to rank
candidate_tags = [
    "python", "flask", "oauth2", "authentication", "api",
    "security", "web-development", "jwt", "rest-api", "backend"
]

print(f"\nQuery: {query_text}")
print(f"Candidate tags: {candidate_tags}\n")

# Encode query and tags
with torch.inference_mode():
    query_emb = backbone.encode_texts([query_text], max_length=2048, return_dict=False)[0]
    tag_embs = backbone.encode_texts(candidate_tags, max_length=2048, return_dict=False)

print(f"Query embedding shape: {query_emb.shape}")
print(f"Tag embeddings shape: {tag_embs.shape}")

# Rank tags by similarity
ranked_tags = compute_ranked_tags(
    query_emb=query_emb,
    pos_embs=torch.empty(0, 768),  # No positives for zero-shot
    neg_embs=torch.empty(0, 768),  # No negatives for zero-shot
    general_embs=tag_embs,
    positive_tags=[],
    negative_tags=[],
    general_tags=candidate_tags,
)

# Display top-ranked tags
print("\n" + "="*60)
print("Top Ranked Tags:")
print("="*60)
for tag, rank, label, score in ranked_tags[:5]:
    print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")

print("\n" + "="*60)
print("Example complete!")
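compute_ranked_tags comes from the repository's playground package, and its exact return format is not documented here. The display loop above unpacks (tag, rank, label, score) tuples. For readers without that package, the sketch below is a hypothetical stand-in that covers only the zero-shot case. It sorts precomputed similarity scores, assigns 1-based ranks, and returns (tag, rank, score) tuples. It omits the label field because this card does not explain what that field means.

```python
from typing import Dict, List, Tuple


def rank_tags(scores: Dict[str, float], top_k: int = 5) -> List[Tuple[str, int, float]]:
    """Hypothetical stand-in: sort tags by descending score and assign 1-based ranks."""
    ordered = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    return [(tag, rank, score) for rank, (tag, score) in enumerate(ordered[:top_k], start=1)]


# Made-up scores for illustration only; they are not EmbeddingGemma outputs.
example_scores = {"python": 0.62, "flask": 0.58, "cooking": 0.05, "oauth2": 0.66}
for tag, rank, score in rank_tags(example_scores, top_k=3):
    print(f"{rank:2d}. {tag:20s} (score: {score:.4f})")
```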

Running the Example

# Install the repository first
pip install git+https://github.com/Pieces/TAG-module.git@main
# Or, for local development, from the root of a cloned copy of the repository:
pip install -e .

# Run the example (the script above, saved as embeddinggemma_example.py)
python embeddinggemma_example.py

Citation

If you use this model, please cite:

@software{tag_module,
  title = {TAG Module: Persona-Conditioned Contrastive Learning for Tag Classification},
  author = {Your Name},
  year = {2025},
  url = {https://github.com/yourusername/tag-module}
}

License

This model uses the google/embeddinggemma-300m backbone unchanged. Refer to the license of that original model for its usage terms.
