# Spaces:
# Sleeping
# Sleeping
# File size: 4,176 Bytes
# 3cd9d3e 2dcab43 1dca306 7dc2911 2dcab43 b67f566 7dc2911 b67f566 1dca306 b67f566 1dca306 b67f566 1dca306 b67f566 ebfc1e7 b67f566 1dca306 b67f566 1dca306 b67f566 2dcab43 b67f566 2dcab43 1dca306 b67f566 2dcab43 b67f566 2dcab43 b67f566 626b6d0 2dcab43 1dca306 b67f566 2dcab43 1dca306 b67f566 2dcab43 b67f566 2dcab43 1dca306 b67f566 2dcab43 7dc2911 2dcab43 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import os
# U^2-Net model definition
class U2NET(torch.nn.Module):
    """Simplified encoder-decoder network producing a single-channel matte.

    Despite the name, this is a plain U-Net-style model rather than the nested
    RSU blocks of the original U^2-Net: five encoder stages (each an optional
    2x max-pool followed by two 3x3 conv + ReLU), a transposed-conv decoder,
    additive skip connections, and a final 1x1 conv + sigmoid.

    Args:
        out_ch: Number of output channels (1 for an alpha matte).

    Input is an ``(N, 3, H, W)`` tensor; output is ``(N, out_ch, H, W)`` with
    values in [0, 1]. H and W do not need to be multiples of 16: decoder maps
    are resized to their skip tensor's size before being added.
    """

    def __init__(self, out_ch=1):
        super().__init__()
        # Layer order inside each Sequential is kept identical to the original
        # definition so saved state_dict keys (e.g. "stage2.1.weight") still match.
        self.stage1 = self._stage(3, 64, pool=False)
        self.stage2 = self._stage(64, 128, pool=True)
        self.stage3 = self._stage(128, 256, pool=True)
        self.stage4 = self._stage(256, 512, pool=True)
        self.stage5 = self._stage(512, 512, pool=True)
        self.up5 = torch.nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.up4 = torch.nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv_final = torch.nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _stage(in_ch, out_ch, pool):
        """Build one encoder stage: optional 2x max-pool, then two 3x3 conv + ReLU."""
        layers = [torch.nn.MaxPool2d(2, 2)] if pool else []
        layers += [
            torch.nn.Conv2d(in_ch, out_ch, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_ch, out_ch, 3, padding=1),
            torch.nn.ReLU(),
        ]
        return torch.nn.Sequential(*layers)

    @staticmethod
    def _fit(x, ref):
        """Resize *x* spatially to *ref*'s H x W when they differ (odd input sizes)."""
        if x.shape[-2:] != ref.shape[-2:]:
            x = F.interpolate(x, size=ref.shape[-2:], mode="bilinear", align_corners=False)
        return x

    def forward(self, x):
        # Encoder: each stage after the first halves the spatial size.
        x1 = self.stage1(x)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)
        x5 = self.stage5(x4)
        # Decoder with additive skip connections. Max-pool floors odd sizes, so
        # the 2x upsampled map can be one pixel short of its skip tensor.
        u5 = self.up5(x5)
        u4 = self.up4(self._fit(u5, x4) + x4)
        u3 = self.up3(self._fit(u4, x3) + x3)
        u2 = self.up2(self._fit(u3, x2) + x2)
        return torch.sigmoid(self.conv_final(self._fit(u2, x1) + x1))
def load_model(weights_path=None):
    """Build the U2NET model and return it in inference (``eval``) mode.

    Args:
        weights_path: Optional path to a saved U2NET ``state_dict``. When
            omitted (the default), convolution weights are Kaiming-initialised
            instead, so predictions are NOT meaningful -- demo only.

    Returns:
        The ``U2NET`` instance after ``.eval()``.
    """
    net = U2NET()
    if weights_path is not None:
        # weights_only=True refuses arbitrary pickled objects in the checkpoint.
        state = torch.load(weights_path, map_location="cpu", weights_only=True)
        net.load_state_dict(state)
    else:
        # No trained weights supplied: dummy initialisation for the demo.
        for m in net.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
    return net.eval()


model = load_model()
def refine_edges(image, threshold=0.5):
    """Refine edges using U^2-Net: predict a binary matte and apply it as alpha.

    Args:
        image: Input PIL image. Any mode is accepted; it is converted to RGB
            first, so grayscale, palette and RGBA inputs work. ``None`` (nothing
            uploaded) returns ``(None, None)``.
        threshold: Binarisation cut-off for the predicted matte, as a fraction
            in [0, 1]. Values greater than 1 are interpreted as percentages
            (0-100, the range of the UI slider). The result is clamped to [0, 1].

    Returns:
        ``(rgba_image, matte_image)``: the RGB input with the matte as its alpha
        channel, and the single-channel 0/255 matte, both at the input's size.
    """
    if image is None:
        return None, None
    # Normalise to 3-channel RGB once so the model input and the final RGBA
    # composite are built from the same pixels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    rgb = np.asarray(image)
    height, width = rgb.shape[:2]

    # The Gradio slider reports 0-100 while the math below needs 0-1.
    threshold = float(threshold)
    if threshold > 1.0:
        threshold /= 100.0
    threshold = min(max(threshold, 0.0), 1.0)

    # Preprocess: fixed 320x320 network input, scaled to [0, 1], NCHW layout.
    resized = cv2.resize(rgb, (320, 320))
    tensor = torch.from_numpy(resized).permute(2, 0, 1).float().unsqueeze(0) / 255.0

    # Inference
    with torch.no_grad():
        matte = model(tensor)

    # Post-process: back to the original resolution, then binarise.
    matte = F.interpolate(matte, size=(height, width), mode="bilinear", align_corners=False)
    # Index explicitly: squeeze() would also drop a height/width of 1.
    matte = (matte[0, 0].numpy() * 255).astype(np.uint8)
    _, matte = cv2.threshold(matte, int(threshold * 255), 255, cv2.THRESH_BINARY)

    # Create transparent result: RGB channels plus the matte as alpha.
    rgba = np.dstack([rgb, matte])
    return Image.fromarray(rgba), Image.fromarray(matte)
# Gradio Interface
def _refine_from_ui(image, threshold_pct):
    """Adapt the Gradio inputs to ``refine_edges``.

    The slider reports the threshold in percent (0-100), whereas
    ``refine_edges`` expects a fraction in [0, 1]. ``None`` means no image has
    been uploaded yet, so both outputs are cleared instead of raising.
    """
    if image is None:
        return None, None
    return refine_edges(image, threshold_pct / 100.0)


with gr.Blocks() as demo:
    gr.Markdown("## ✂️ Professional Edge Refiner (U^2-Net)")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            threshold = gr.Slider(0, 100, 50, label="Edge Threshold")
            process_btn = gr.Button("Refine Edges", variant="primary")
        with gr.Column():
            output_img = gr.Image(type="pil", label="Refined Image")
            matte_img = gr.Image(type="pil", label="Alpha Matte")
    # Listeners must be registered inside the Blocks context.
    process_btn.click(
        fn=_refine_from_ui,
        inputs=[input_img, threshold],
        outputs=[output_img, matte_img],
    )
# Launch the Gradio app only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()