# edgerefinement1 / app.py
# Janeka's picture
# Update app.py
# b67f566 verified
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import os
# U^2-Net model definition
class U2NET(torch.nn.Module):
    """Simplified U^2-Net-style encoder/decoder that predicts a soft mask.

    Five convolutional encoder stages (every stage after the first starts
    with a 2x2 max-pool) are followed by four transposed-convolution
    upsampling steps. Each upsampled map is summed with the encoder output of
    the matching stage before the next step, and a final 1x1 convolution
    followed by a sigmoid produces values in [0, 1].

    Inputs whose height/width are not multiples of 16 are supported: when an
    upsampled map and its skip connection differ in spatial size, the
    upsampled map is bilinearly resized to the skip's size before the sum.
    Each side still needs at least 16 pixels so that four poolings leave a
    non-empty feature map.

    Args:
        out_ch: Number of channels produced by the final 1x1 convolution.
    """

    def __init__(self, out_ch=1):
        super().__init__()
        # Attribute names and the layer order inside each Sequential are kept
        # stable so that state_dict keys (e.g. "stage2.1.weight") do not move.
        self.stage1 = self._conv_stage(3, 64, downsample=False)
        self.stage2 = self._conv_stage(64, 128)
        self.stage3 = self._conv_stage(128, 256)
        self.stage4 = self._conv_stage(256, 512)
        self.stage5 = self._conv_stage(512, 512)

        # Each transposed convolution doubles the spatial size and maps to
        # the channel count of the encoder stage it is summed with.
        self.up5 = torch.nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.up4 = torch.nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv_final = torch.nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _conv_stage(in_ch, out_ch, downsample=True):
        """Build [MaxPool ->] Conv3x3 -> ReLU -> Conv3x3 -> ReLU."""
        layers = [torch.nn.MaxPool2d(2, 2)] if downsample else []
        layers += [
            torch.nn.Conv2d(in_ch, out_ch, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_ch, out_ch, 3, padding=1),
            torch.nn.ReLU(),
        ]
        return torch.nn.Sequential(*layers)

    @staticmethod
    def _fuse(upsampled, skip):
        """Sum an upsampled map with its skip connection, aligning sizes first."""
        target = tuple(skip.shape[-2:])
        if tuple(upsampled.shape[-2:]) != target:
            # Max-pooling floors odd sizes, so doubling can land one pixel
            # short of the encoder map; resize instead of failing on the add.
            upsampled = torch.nn.functional.interpolate(
                upsampled, size=target, mode="bilinear", align_corners=False
            )
        return upsampled + skip

    def forward(self, x):
        """Return the sigmoid mask for a batch ``x`` of 3-channel images.

        The output has ``out_ch`` channels and the same height/width as ``x``.
        """
        # Encoder
        x1 = self.stage1(x)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)
        x5 = self.stage5(x4)

        # Decoder with additive skip connections
        u5 = self.up5(x5)
        u4 = self.up4(self._fuse(u5, x4))
        u3 = self.up3(self._fuse(u4, x3))
        u2 = self.up2(self._fuse(u3, x2))
        return torch.sigmoid(self.conv_final(self._fuse(u2, x1)))
def load_model():
    """Build a ``U2NET`` in eval mode with re-initialised convolution weights.

    No trained checkpoint is read. The weight of every ``torch.nn.Conv2d``
    inside the network is re-drawn with Kaiming-normal initialisation as a
    stand-in for real weights.

    Returns:
        The ``U2NET`` instance, switched to evaluation mode.
    """
    net = U2NET()
    # Placeholder initialisation for the demo; a production build would load
    # actual trained weights at this point instead.
    conv_layers = [
        layer for layer in net.modules() if isinstance(layer, torch.nn.Conv2d)
    ]
    for conv in conv_layers:
        torch.nn.init.kaiming_normal_(conv.weight)
    net.eval()
    return net


# Single shared network instance, created at import time and used for inference.
model = load_model()
def refine_edges(image, threshold=0.5):
    """Predict a matte for ``image`` and apply it as a binary alpha channel.

    The image is resized to 320x320, run through the module-level ``model``,
    and the predicted matte is resized back to the input resolution and
    binarised.

    Args:
        image: Input ``PIL.Image.Image`` (any mode), or ``None`` when no image
            was supplied.
        threshold: Cut-off used to binarise the matte. Values in [0, 1] are
            fractions of full intensity; values greater than 1 are treated as
            percentages (50 means 0.5). The result is clamped to [0, 1].

    Returns:
        ``(rgba_image, matte_image)`` as PIL images: the RGB input with the
        binary matte as its alpha channel, and the matte itself. Returns
        ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    # Accept both a 0-1 fraction and a 0-100 percentage; a raw percentage
    # multiplied by 255 would exceed the uint8 range and blank the matte.
    fraction = float(threshold)
    if fraction > 1.0:
        fraction /= 100.0
    fraction = min(max(fraction, 0.0), 1.0)

    # Convert once to 3-channel RGB so grayscale, palette and RGBA inputs take
    # the same path for both the network input and the final composite.
    rgb = np.array(image.convert("RGB"))
    height, width = rgb.shape[:2]

    # Preprocess: fixed network resolution, CHW float tensor scaled to [0, 1].
    resized = cv2.resize(rgb, (320, 320))
    tensor = torch.from_numpy(resized).permute(2, 0, 1).float().unsqueeze(0) / 255.0

    # Inference
    with torch.no_grad():
        matte = model(tensor)

    # Post-process: back to the input resolution, then binarise.
    matte = F.interpolate(
        matte, size=(height, width), mode="bilinear", align_corners=False
    )
    # Index rather than squeeze() so a 1-pixel-wide or -high image stays 2-D.
    matte = (matte[0, 0].numpy() * 255).astype(np.uint8)
    _, matte = cv2.threshold(matte, int(fraction * 255), 255, cv2.THRESH_BINARY)

    # Create transparent result: RGB plus the matte as the alpha channel.
    rgba = np.dstack([rgb, matte])
    return Image.fromarray(rgba), Image.fromarray(matte)
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## ✂️ Professional Edge Refiner (U^2-Net)")
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            threshold = gr.Slider(0, 100, 50, label="Edge Threshold")
            process_btn = gr.Button("Refine Edges", variant="primary")
        # Right column: composited result and the raw matte.
        with gr.Column():
            output_img = gr.Image(type="pil", label="Refined Image")
            matte_img = gr.Image(type="pil", label="Alpha Matte")

    def _on_refine(image, percent):
        """Convert the 0-100 slider value to the 0-1 fraction refine_edges uses."""
        # Passing the raw percentage would make the cut-off exceed 255 and
        # produce an always-empty matte.
        return refine_edges(image, percent / 100.0)

    process_btn.click(
        fn=_on_refine,
        inputs=[input_img, threshold],
        outputs=[output_img, matte_img],
    )

if __name__ == "__main__":
    demo.launch()