import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import os

# U^2-Net model definition
class U2NET(torch.nn.Module):
    def __init__(self, out_ch=1):
        super(U2NET, self).__init__()
        # Simplified U-Net-style encoder/decoder standing in for the full
        # U^2-Net (which uses nested RSU blocks)
        self.stage1 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, padding=1),
            torch.nn.ReLU()
        )
        self.stage2 = torch.nn.Sequential(
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, 3, padding=1),
            torch.nn.ReLU()
        )
        self.stage3 = torch.nn.Sequential(
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(128, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 3, padding=1),
            torch.nn.ReLU()
        )
        self.stage4 = torch.nn.Sequential(
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(256, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU()
        )
        self.stage5 = torch.nn.Sequential(
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, 3, padding=1),
            torch.nn.ReLU()
        )
        self.up5 = torch.nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.up4 = torch.nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv_final = torch.nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        # Encoder
        x1 = self.stage1(x)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)
        x5 = self.stage5(x4)
        
        # Decoder with skip connections
        u5 = self.up5(x5)
        u4 = self.up4(u5 + x4)
        u3 = self.up3(u4 + x3)
        u2 = self.up2(u3 + x2)
        
        return torch.sigmoid(self.conv_final(u2 + x1))
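
# Shape sanity check (a minimal sketch, not wired into the app): with a
# 320x320 input the encoder halves the spatial size four times (320 -> 20)
# and the decoder restores it, so the matte matches the network input:
#
#   x = torch.randn(1, 3, 320, 320)
#   assert U2NET()(x).shape == (1, 1, 320, 320)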

def load_model():
    model = U2NET()
    # Load pre-trained weights (dummy initialization for demo)
    # In production, you would load actual trained weights here
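    # A hedged sketch of that: if a checkpoint saved from this same simplified
    # model exists on disk, prefer it over the random init below. The filename
    # "u2net_demo.pth" is hypothetical, not part of this demo.
    weights_path = "u2net_demo.pth"
    if os.path.exists(weights_path):
        model.load_state_dict(torch.load(weights_path, map_location="cpu"))
        return model.eval()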
    for m in model.modules():
        if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
            torch.nn.init.kaiming_normal_(m.weight)
    return model.eval()

model = load_model()

def refine_edges(image, threshold=0.5):
    """Refine edges using U^2-Net"""
    # Preprocess
    img = np.array(image)
    if len(img.shape) == 2:
        img = np.stack([img]*3, axis=-1)
    elif img.shape[2] == 4:
        img = img[..., :3]
    
    img = cv2.resize(img, (320, 320))
    tensor = torch.from_numpy(img).permute(2,0,1).float().unsqueeze(0) / 255.0
    
    # Inference
    with torch.no_grad():
        matte = model(tensor)
    
    # Post-process
    matte = F.interpolate(matte, size=image.size[::-1], mode='bilinear', align_corners=False)
    matte = (matte.squeeze().numpy() * 255).astype(np.uint8)
    _, matte = cv2.threshold(matte, int(threshold*255), 255, cv2.THRESH_BINARY)
    
    # Create transparent result
    rgba = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2RGBA)
    rgba[..., 3] = matte
    
    return Image.fromarray(rgba), Image.fromarray(matte)
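
# Quick standalone check of refine_edges (a minimal sketch; "sample.jpg" and
# the output filenames are hypothetical):
#
#   rgba, matte = refine_edges(Image.open("sample.jpg"), threshold=0.5)
#   rgba.save("cutout.png")
#   matte.save("matte.png")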

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## ✂️ Professional Edge Refiner (U^2-Net)")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            threshold = gr.Slider(0, 1, value=0.5, step=0.01, label="Edge Threshold")
            process_btn = gr.Button("Refine Edges", variant="primary")
        with gr.Column():
            output_img = gr.Image(type="pil", label="Refined Image")
            matte_img = gr.Image(type="pil", label="Alpha Matte")
    
    process_btn.click(
        fn=refine_edges,
        inputs=[input_img, threshold],
        outputs=[output_img, matte_img]
    )

if __name__ == "__main__":
    demo.launch()