File size: 2,796 Bytes
7d1df75
 
 
 
 
 
c9e3c19
1302388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9e3c19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d1df75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9e3c19
 
 
900716e
 
 
 
 
c9e3c19
 
 
 
 
 
 
 
 
 
 
 
7d1df75
c9e3c19
7d1df75
 
 
 
 
 
 
155887e
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import functools
import os

import torch
import gradio as gr
from PIL import Image
from torchvision.transforms import transforms

# UI language switch: every locale other than exactly "zh_CN.UTF-8"
# (including an unset LANG) gets the English interface.
EN_US = os.getenv("LANG") != "zh_CN.UTF-8"

# Download a snapshot of the model repository into ./__pycache__.
# The Chinese UI pulls it from ModelScope, the English UI from the
# Hugging Face Hub; both use the same repository id.
if not EN_US:
    import modelscope

    MODEL_DIR = modelscope.snapshot_download(
        "Genius-Society/HEp2", cache_dir="./__pycache__"
    )
else:
    import huggingface_hub

    MODEL_DIR = huggingface_hub.snapshot_download(
        "Genius-Society/HEp2", cache_dir="./__pycache__"
    )

# Chinese UI text (the canonical keys used throughout the file) mapped to
# its English rendering; consulted by _L when the English UI is active.
ZH2EN = {
    "上传细胞图像": "Upload a cell picture",
    "状态栏": "Status",
    "图片名": "Picture name",
    "识别结果": "Recognition result",
    "请上传 PNG 格式的 HEp2 细胞图片": "It is recommended to upload HEp2 cell images in PNG format.",
}


def _L(zh_txt: str):
    """Localize a UI string.

    Args:
        zh_txt: The Chinese source text, which doubles as the lookup key.

    Returns:
        ``zh_txt`` unchanged for the Chinese UI; otherwise its English
        translation from ``ZH2EN`` (a missing key raises ``KeyError``).
    """
    if not EN_US:
        return zh_txt
    return ZH2EN[zh_txt]


# English class label -> Chinese display name.
# The dict's insertion order defines CLASSES, and CLASSES is indexed by the
# model's predicted class index, so the keys must stay in the order the
# classifier was trained with.
# NOTE(review): the order matches ASCII-sorted labels (as torchvision's
# ImageFolder would assign) — confirm against the training setup.
TRANSLATE = {
    "Centromere": "着丝粒",
    "Golgi": "高尔基体",
    "Homogeneous": "同质",
    # NuMem = nuclear membrane; previously mistranslated as "记忆体" (memory).
    "NuMem": "核膜",
    "Nucleolar": "核仁",
    "Speckled": "斑核",
}
CLASSES = list(TRANSLATE.keys())


def embeding(img_path: str):
    """Load an image file and turn it into the model's input tensor.

    The image is converted to RGB, resized so its shorter side is 224,
    center-cropped to 224x224, converted to a (3, 224, 224) float tensor and
    normalized per channel with mean 0.5 / std 0.5 (values in [-1, 1]).

    Args:
        img_path: Path to the image file.

    Returns:
        The preprocessed (3, 224, 224) tensor (no batch dimension).
    """
    # Deterministic preprocessing only: the former RandomAffine(5) applied a
    # random rotation at inference time, so the same upload could yield
    # different predictions.
    compose = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # Image.open keeps the file handle open until closed; convert() returns
    # a fully loaded copy, so the source file can be released immediately.
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")
    return compose(img)


@functools.lru_cache(maxsize=1)
def _load_model() -> torch.nn.Module:
    """Load the pickled classifier from MODEL_DIR once and reuse it."""
    # NOTE(review): weights_only=False unpickles arbitrary objects; this is
    # only acceptable because save.pt comes from the project's own model repo.
    model = torch.load(
        f"{MODEL_DIR}/save.pt",
        map_location=torch.device("cpu"),
        weights_only=False,
    )
    # Switch any dropout / batch-norm layers to inference behaviour.
    model.eval()
    return model


def infer(target: str):
    """Classify a single HEp-2 cell image for the Gradio interface.

    Args:
        target: Filesystem path of the uploaded image (presumably ``None`` or
            empty when nothing was uploaded — handled as an error).

    Returns:
        ``(status, filename, result)``: ``status`` is ``"Success"`` or the
        text of the exception raised; ``filename`` is the base name of
        ``target``; ``result`` is the predicted class label (English when
        ``EN_US`` is true, otherwise its Chinese name from ``TRANSLATE``).
        ``filename``/``result`` stay ``None`` if an error occurs first.
    """
    status = "Success"
    filename = result = None
    try:
        # Validate the input before paying for any model work.
        if not target:
            raise ValueError(
                "Please upload a cell picture" if EN_US else "请上传细胞图片"
            )

        model = _load_model()  # cached after the first successful load
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            logits: torch.Tensor = model(embeding(target).unsqueeze(0))
        class_name = CLASSES[int(logits.argmax(dim=1).item())]
        filename = os.path.basename(target)
        result = class_name if EN_US else TRANSLATE[class_name]

    except Exception as e:  # UI boundary: surface any failure in the status box
        status = f"{e}"

    return status, filename, result


if __name__ == "__main__":
    # One bundled sample image per class, taken from the downloaded repo.
    example_imgs = [f"{MODEL_DIR}/examples/{cls}.png" for cls in CLASSES]

    image_input = gr.Image(type="filepath", label=_L("上传细胞图像"))
    # Output boxes in the same order as infer's (status, filename, result).
    output_boxes = [
        gr.Textbox(label=_L(zh_label), show_copy_button=True)
        for zh_label in ("状态栏", "图片名", "识别结果")
    ]

    app = gr.Interface(
        fn=infer,
        inputs=image_input,
        outputs=output_boxes,
        title=_L("请上传 PNG 格式的 HEp2 细胞图片"),
        examples=example_imgs,
        flagging_mode="never",
        cache_examples=False,
    )
    app.launch(ssr_mode=False)