Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import timm | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as T | |
| from huggingface_hub import hf_hub_download | |
| # ========== Model Definition ========== | |
class MobileViTSegmentation(nn.Module):
    """Single-channel segmentation network: timm encoder + lightweight upsampling decoder.

    The encoder is a timm model created with ``features_only=True``; only its
    deepest (last) feature map is decoded. The decoder applies three
    (3x3 conv, 2x bilinear upsample) stages with 128/64/32 output channels,
    then a 1x1 conv to a single channel and a sigmoid, so it emits per-pixel
    probabilities in (0, 1). ``forward`` finally resizes that map to the
    spatial size of the input.

    NOTE(review): the decoder has no non-linearity between its convolutions.
    It is intentionally left unchanged: ``nn.Sequential`` names parameters by
    index (``decoder.0``, ``decoder.2``, ...), so inserting or reordering
    layers would break ``load_state_dict`` for the trained checkpoint.

    Args:
        encoder_name: timm model name used as the backbone.
        pretrained: whether timm should load its own pretrained encoder weights.
    """

    def __init__(self, encoder_name='mobilevit_s', pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(encoder_name, features_only=True, pretrained=pretrained)
        self.encoder_channels = self.backbone.feature_info.channels()

        # conv -> 2x upsample, three times; the resulting layer order (and hence
        # the Sequential indices) must match the saved checkpoint exactly.
        layers = []
        in_channels = self.encoder_channels[-1]
        for out_channels in (128, 64, 32):
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.Upsample(scale_factor=2, mode='bilinear'),
            ]
            in_channels = out_channels
        # Project to one channel and squash to a probability.
        layers += [nn.Conv2d(in_channels, 1, kernel_size=1), nn.Sigmoid()]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode the deepest backbone feature map and resize it to x's height/width."""
        deepest = self.backbone(x)[-1]
        probs = self.decoder(deepest)
        height, width = x.shape[2], x.shape[3]
        return nn.functional.interpolate(probs, size=(height, width), mode='bilinear', align_corners=False)
| # ========== Load Model ========== | |
@st.cache_resource(show_spinner=False)
def load_model(
    repo_id="svsaurav95/ToothSegmentation",
    filename="mobilevit_teeth_segmentation.pth",
    cache_dir="/tmp/huggingface",
):
    """Download the trained checkpoint from the Hugging Face Hub and build the model.

    Cached with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction; without caching, the checkpoint would be
    re-read and the network rebuilt on every rerun. The returned model object is
    shared across reruns and sessions, so callers must treat it as read-only
    (it is only used for inference in eval mode). ``show_spinner=False`` keeps
    the first call from emitting any Streamlit element.

    Args:
        repo_id: Hub repository that holds the checkpoint.
        filename: checkpoint file name inside the repository.
        cache_dir: download cache directory ("/tmp" is writable in HF Spaces).

    Returns:
        A ``MobileViTSegmentation`` with the checkpoint weights loaded on CPU,
        switched to eval mode.
    """
    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        cache_dir=cache_dir,
    )
    # pretrained=False: all weights (encoder included) come from the checkpoint.
    model = MobileViTSegmentation(pretrained=False)
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot execute arbitrary code (needs torch >= 1.13).
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Segmentation model used by the UI below.
model = load_model()

# ========== Image Preprocessing ==========
# Resize to a fixed 256x256 input (aspect ratio is not preserved), then convert
# the PIL image to a float CHW tensor scaled to [0, 1]. No normalization is applied.
_preprocess_steps = [
    T.Resize((256, 256)),
    T.ToTensor(),
]
transform = T.Compose(_preprocess_steps)
| # ========== UI ========== | |
from PIL import ImageOps

# Probability above which a pixel is labelled as tooth.
MASK_THRESHOLD = 0.7
# Tint colour (RGB) and opacity (0-255) drawn over the predicted tooth area.
OVERLAY_RGB = (0, 0, 255)
OVERLAY_ALPHA = 100


def _tint_masked_region(image, mask, rgb=OVERLAY_RGB, alpha=OVERLAY_ALPHA):
    """Blend a translucent *rgb* tint over the pixels of *image* selected by *mask*.

    Args:
        image: PIL image (any mode convertible to RGBA).
        mask: PIL "L" image of the same size; 255 = fully tinted, 0 = untouched,
            intermediate values give a proportionally weaker tint.
        rgb: tint colour.
        alpha: tint opacity, 0 (invisible) to 255 (opaque).

    Returns:
        RGBA image of the same size as *image*.
    """
    base = image.convert("RGBA")
    # Alpha-composite the tint *onto* the photo so the teeth stay visible under it.
    # (Selecting the raw semi-transparent tint layer in place of the photo would
    # erase those pixels and let the page background show through instead.)
    tinted = Image.alpha_composite(base, Image.new("RGBA", base.size, (*rgb, alpha)))
    return Image.composite(tinted, base, mask)


st.set_page_config(page_title="Tooth Segmentation", layout="wide")
st.title("Tooth Segmentation using MobileViT")
uploaded_file = st.file_uploader("Upload a mouth image with visible teeth", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Apply the EXIF orientation tag so phone photos are segmented and shown upright.
        image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except OSError:  # covers PIL.UnidentifiedImageError and truncated files
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    input_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        prob_map = model(input_tensor)[0, 0].numpy()

    # Threshold at the model's input resolution, then resize to the original size.
    # Pillow's default resampling filter may leave intermediate values along the
    # mask edges, which the tint helper renders as a softer edge.
    binary_mask = (prob_map > MASK_THRESHOLD).astype(np.uint8) * 255
    mask = Image.fromarray(binary_mask).resize(image.size)

    final = _tint_masked_region(image, mask)

    # Side-by-side display
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Original Image", use_container_width=True)
    with col2:
        st.image(final, caption="Tooth Area Segmentation", use_container_width=True)