| from util import UIDataset, Vocabulary |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader |
| from model import * |
| from torchvision import transforms |
| from PIL import Image |
|
|
| |
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# script can also start on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the model and restore its weights. map_location remaps the stored
# tensors onto the selected device, which allows a checkpoint saved on a GPU
# to be loaded on a CPU-only host.
net = Pix2Code()
net.load_state_dict(torch.load('./pix2code.weights', map_location=device))
# eval() switches the module to inference mode.
net.to(device).eval()

# Vocabulary used to convert between tokens and the vectors fed to the model.
vocab = Vocabulary('voc.pkl')

# Image preprocessing: resize to 256x256, convert to a float tensor in [0, 1],
# then shift/scale each of the three channels to [-1, 1] ((x - 0.5) / 0.5).
# NOTE(review): presumably mirrors the preprocessing used at training time;
# confirm against the training script.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
|
|
| |
def generate_gui(image):
    """Greedily decode the token sequence the model predicts for an image.

    Args:
        image: Screenshot to translate, either a ``PIL.Image.Image`` or a
            ``numpy.ndarray`` of pixel data (the form Gradio's image input
            delivers by default). ``None`` produces an empty result.

    Returns:
        str: The predicted tokens concatenated without any separator.
        Decoding stops at the first ``<END>`` token or after 200 steps,
        whichever comes first.
    """
    if image is None:
        return ''

    # torchvision's Resize accepts PIL images or tensors but not NumPy arrays,
    # so array input has to be wrapped before it enters the pipeline.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Normalize is configured with three channel statistics; force RGB so
    # grayscale or RGBA uploads do not fail on a channel-count mismatch.
    image = image.convert('RGB')

    # Follow whichever device the model lives on instead of assuming CUDA.
    device = next(net.parameters()).device

    # Add a leading batch dimension of size one.
    image = transform(image).unsqueeze(0).to(device)

    # Seed the decoder context with a blank token followed by <START>.
    context = torch.tensor(
        [vocab.to_vec(' '), vocab.to_vec('<START>')]
    ).unsqueeze(0).float().to(device)

    code = []

    # Pure inference: skip autograd bookkeeping for every decoding step.
    with torch.no_grad():
        for _ in range(200):
            # Take the arg-max over dim 2 and keep only the prediction for the
            # last position of the current context.
            # NOTE(review): assumes the model returns one score vector per
            # context position - confirm against Pix2Code.forward.
            index = torch.argmax(net(image, context), 2).squeeze()[-1:].squeeze()
            token = vocab.to_vocab(int(index))

            if token == '<END>':
                break

            code.append(token)

            # Append the chosen token so the next step is conditioned on it.
            next_step = (
                torch.tensor([vocab.to_vec(token)]).unsqueeze(0).float().to(device)
            )
            context = torch.cat([context, next_step], dim=1)

    return ''.join(code)
|
|
| import gradio as gr |
|
|
| |
# Ask Gradio for PIL images. The component's default is a NumPy array, which
# the torchvision Resize step applied inside generate_gui does not accept.
image_input = gr.inputs.Image(type='pil')

# The generated code is shown as plain text.
text_output = gr.outputs.Textbox()

# Wire the generator function to a single-image-in / text-out web interface.
iface = gr.Interface(
    fn=generate_gui,
    inputs=image_input,
    outputs=text_output,
    title='Pix2Code',
    description='Gerador de código GUI a partir de imagens',
    theme='default',
)

# Start the local web server for the demo.
iface.launch()
|
|