Spaces: Running on Zero
| #!/usr/bin/env python | |
| import pathlib | |
| import sys | |
| import gradio as gr | |
| import numpy as np | |
| import PIL.Image | |
| import spaces | |
| import torch | |
| import torchvision.transforms as T # noqa: N812 | |
| from huggingface_hub import hf_hub_download | |
| sys.path.insert(0, "CelebAMask-HQ/face_parsing") | |
| from unet import unet # pyright: ignore[reportMissingImports] | |
| from utils import generate_label # pyright: ignore[reportMissingImports] | |
# Text shown at the top of the Gradio page.
TITLE = "CelebAMask-HQ Face Parsing"
DESCRIPTION = "This is an unofficial demo for the model provided in https://github.com/switchablenorms/CelebAMask-HQ."

# Use the first CUDA device when torch reports one, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing applied to every input image before it reaches the network:
# resize to 512x512 with nearest-neighbour sampling, convert to a tensor, then
# subtract 0.5 and divide by 0.5 on each of three channels.
transform = T.Compose(
    [
        T.Resize(size=(512, 512), interpolation=PIL.Image.NEAREST),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Fetch the checkpoint from the Hugging Face Hub and load it on CPU first; the
# model is moved to `device` only after the weights are in place.
path = hf_hub_download(repo_id="public-data/CelebAMask-HQ-Face-Parsing", filename="models/model.pth")
# NOTE(review): torch.load unpickles the file. If the checkpoint is a plain
# state dict, consider passing weights_only=True — confirm before changing.
state_dict = torch.load(path, map_location="cpu")

model = unet()
model.load_state_dict(state_dict)
model.eval()  # inference behaviour for layers that distinguish train/eval
model.to(device)
def predict(image: PIL.Image.Image) -> tuple[np.ndarray, np.ndarray]:
    """Run face parsing on one image and build a blended preview.

    Args:
        image: Input picture. It is converted to RGB before processing, so
            RGBA, grayscale and palette images are accepted as well.

    Returns:
        A pair ``(labels, blended)`` of ``uint8`` arrays. ``labels`` is the
        result of ``generate_label`` for the first batch item, scaled by 255
        and with its first axis moved last. ``blended`` is the 50/50 average
        of the input resized to 512x512 and ``labels``.
    """
    # NOTE(review): `spaces` is imported by this file but not used in the
    # visible code. If this app runs on ZeroGPU, confirm whether this function
    # is meant to be decorated with `@spaces.GPU`.

    # The module-level transform normalizes three channels and the blend below
    # adds the input to the label image, so force a three-channel RGB image.
    image = image.convert("RGB")

    # Pure inference: disable autograd tracking for the forward pass and the
    # label conversion.
    with torch.inference_mode():
        data = transform(image).unsqueeze(0).to(device)  # add a batch axis
        out = model(data)
        out = generate_label(out, 512)
        # First batch item; move axis 0 to the end (presumably channels-first
        # to channels-last — confirm against generate_label).
        out = out[0].cpu().numpy().transpose(1, 2, 0)

    out = np.clip(np.round(out * 255), 0, 255).astype(np.uint8)

    # Blend the resized input and the label image with equal weights.
    resized = np.asarray(image.resize((512, 512))).astype(float)
    res = resized * 0.5 + out.astype(float) * 0.5
    res = np.clip(np.round(res), 0, 255).astype(np.uint8)
    return out, res
# Example inputs offered in the UI: every *.jpg directly inside ./images,
# in sorted path order.
examples = sorted(pathlib.Path("images").glob("*.jpg"))

# Build the components up front; `predict` returns two images, which are
# matched positionally to the two output components.
_input_image = gr.Image(label="Input", type="pil")
_output_images = [
    gr.Image(label="Predicted Labels"),
    gr.Image(label="Masked"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_input_image,
    outputs=_output_images,
    examples=examples,
    title=TITLE,
    description=DESCRIPTION,
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()