Spaces:
Sleeping
Sleeping
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import gradio as gr | |
| import os | |
| # U^2-Net model definition | |
class U2NET(torch.nn.Module):
    """Simplified U^2-Net-style encoder/decoder that predicts a soft matte.

    The encoder has five stages of two 3x3 conv + ReLU layers; stages 2-5
    begin with a 2x2 max-pool, so the deepest feature map is 1/16 of the
    input resolution. The decoder upsamples with stride-2 transposed
    convolutions and adds the matching encoder feature map at every level.

    Args:
        out_ch: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, out_ch=1):
        super().__init__()
        # Layer order inside every Sequential is kept stable so that
        # state_dict keys (e.g. ``stage2.1.weight``) do not change.
        self.stage1 = self._conv_stage(3, 64, pool=False)
        self.stage2 = self._conv_stage(64, 128)
        self.stage3 = self._conv_stage(128, 256)
        self.stage4 = self._conv_stage(256, 512)
        self.stage5 = self._conv_stage(512, 512)
        self.up5 = torch.nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.up4 = torch.nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv_final = torch.nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _conv_stage(in_ch, out_ch, pool=True):
        """Build ``[MaxPool] -> Conv -> ReLU -> Conv -> ReLU`` as a Sequential."""
        layers = [torch.nn.MaxPool2d(2, 2)] if pool else []
        layers += [
            torch.nn.Conv2d(in_ch, out_ch, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_ch, out_ch, 3, padding=1),
            torch.nn.ReLU(),
        ]
        return torch.nn.Sequential(*layers)

    @staticmethod
    def _fuse(upsampled, skip):
        """Add *skip* to *upsampled*, resizing *upsampled* first if needed.

        Max-pooling floors odd sizes while the transposed convolution exactly
        doubles, so the two maps can differ by one pixel when the input size
        is not a multiple of 16.
        """
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = F.interpolate(
                upsampled,
                size=skip.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
        return upsampled + skip

    def forward(self, x):
        """Return a sigmoid-activated map with the same spatial size as *x*.

        Args:
            x: Float tensor of shape ``(N, 3, H, W)``. ``H`` and ``W`` must be
                at least 16 so that four poolings leave a non-empty map.

        Returns:
            Tensor of shape ``(N, out_ch, H, W)`` with values in ``[0, 1]``.
        """
        # Encoder
        x1 = self.stage1(x)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)
        x5 = self.stage5(x4)
        # Decoder with additive skip connections
        u5 = self.up5(x5)
        u4 = self.up4(self._fuse(u5, x4))
        u3 = self.up3(self._fuse(u4, x3))
        u2 = self.up2(self._fuse(u3, x2))
        return torch.sigmoid(self.conv_final(self._fuse(u2, x1)))
def load_model(weights_path=None):
    """Build a :class:`U2NET` in eval mode.

    Args:
        weights_path: Optional path to a ``state_dict`` checkpoint. When
            omitted, every ``Conv2d`` weight is re-initialised with Kaiming
            normal init, which gives a runnable but untrained demo network.

    Returns:
        The model, switched to evaluation mode.

    Raises:
        FileNotFoundError: If *weights_path* is given but does not exist.
        RuntimeError: If the checkpoint keys or shapes do not match the model.
    """
    model = U2NET()
    if weights_path is None:
        # Dummy initialisation for the demo (no trained weights available).
        for m in model.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
    else:
        # NOTE: torch.load unpickles the file; only pass checkpoints that
        # come from a trusted source.
        state = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state)
    return model.eval()


# Instantiate once at import time so every request reuses the same network.
model = load_model()
def refine_edges(image, threshold=0.5):
    """Cut the foreground out of *image* using a thresholded U^2-Net matte.

    Args:
        image: Input ``PIL.Image.Image`` in any mode, or ``None`` when no
            image was supplied.
        threshold: Cut-off as a fraction in ``[0, 1]``. Matte pixels strictly
            above ``int(threshold * 255)`` become fully opaque, all others
            fully transparent.

    Returns:
        ``(rgba_image, matte_image)`` as PIL images at the input resolution,
        or ``(None, None)`` when *image* is ``None``.
    """
    if image is None:
        # Nothing uploaded yet: leave both outputs empty instead of crashing.
        return None, None

    # Normalise every PIL mode (L, RGBA, P, ...) to 3-channel RGB once, and
    # reuse that array for both the network input and the final composite.
    rgb = np.array(image.convert("RGB"))
    height, width = rgb.shape[:2]

    # Preprocess: the network is run at a fixed 320x320 resolution.
    net_input = cv2.resize(rgb, (320, 320))
    tensor = torch.from_numpy(net_input).permute(2, 0, 1).float().unsqueeze(0) / 255.0

    # Inference, then resize the matte back to the original resolution.
    with torch.no_grad():
        matte = model(tensor)
        matte = F.interpolate(
            matte, size=(height, width), mode="bilinear", align_corners=False
        )

    # Index instead of squeeze() so 1-pixel-wide/tall images stay 2-D.
    matte = (matte[0, 0].numpy() * 255).astype(np.uint8)
    _, matte = cv2.threshold(matte, int(threshold * 255), 255, cv2.THRESH_BINARY)

    # Create transparent result: RGB plus the binary matte as alpha channel.
    rgba = np.dstack([rgb, matte])
    return Image.fromarray(rgba), Image.fromarray(matte)
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## ✂️ Professional Edge Refiner (U^2-Net)")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            # refine_edges computes int(threshold * 255), so the slider must
            # emit a fraction in [0, 1]; a 0-100 range would push the cut-off
            # far above 255 and make the whole output transparent.
            threshold = gr.Slider(0, 1, 0.5, step=0.01, label="Edge Threshold")
            process_btn = gr.Button("Refine Edges", variant="primary")
        with gr.Column():
            output_img = gr.Image(type="pil", label="Refined Image")
            matte_img = gr.Image(type="pil", label="Alpha Matte")
    # Event listeners have to be registered inside the Blocks context.
    process_btn.click(
        fn=refine_edges,
        inputs=[input_img, threshold],
        outputs=[output_img, matte_img],
    )

if __name__ == "__main__":
    demo.launch()