|
|
import random |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
from model import predict |
|
|
from datasets import load_dataset |
|
|
|
|
|
# Training split of the animals dataset, loaded at module import so the
# "Random Dataset Image" button can sample from it without reloading.
dataset = load_dataset(path="AIOmarRehan/AnimalsDataset", split="train")
|
|
|
|
|
def classify_image(img: "Image.Image | None"):
    """Run the classifier on an uploaded image and format the result for the UI.

    Args:
        img: The PIL image from the upload component, or ``None`` when
            "Predict" is clicked with no image provided.

    Returns:
        A ``(label, confidence, probabilities)`` tuple for the Label, Number
        and JSON output components. ``confidence`` and every probability are
        rounded to 3 decimals as built-in ``float`` values. When ``img`` is
        ``None``, returns ``("No image uploaded", 0, {})`` without calling
        the model.
    """
    if img is None:
        return "No image uploaded", 0, {}

    label, confidence, probs = predict(img)

    # Coerce to float before rounding. If predict hands back NumPy/tensor
    # scalars, round() would keep that scalar type, and such values are not
    # serializable by the stdlib json encoder. For plain floats this is a no-op.
    return (
        label,
        round(float(confidence), 3),
        {k: round(float(v), 3) for k, v in probs.items()},
    )
|
|
|
|
|
|
|
|
def random_example():
    """Draw one random sample from the loaded training split.

    Returns:
        A 3-tuple ``(image, image, label_name)``. The sample image is converted
        to RGB and returned twice, so one click can fill two image components.
        ``label_name`` is the sample's integer label mapped to its string form
        with the dataset's ``label`` feature (``int2str``).
    """
    sample = random.choice(dataset)
    rgb_image = sample["image"].convert("RGB")

    label_feature = dataset.features["label"]
    class_name = label_feature.int2str(sample["label"])

    return rgb_image, rgb_image, class_name
|
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## Animal Image Classifier with Random Dataset Samples")

    # Image input, plus a button that fills it with a random dataset sample.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload an image")
        rand_img = gr.Button("Random Dataset Image")

    pred_btn = gr.Button("Predict")

    # Prediction outputs, listed in the order classify_image returns them.
    output_label = gr.Label(label="Predicted Class")
    output_conf = gr.Number(label="Confidence")
    output_probs = gr.JSON(label="All Probabilities")

    # Preview of the last sampled dataset item and its label.
    rand_display = gr.Image(type="pil", label="Random Dataset Sample")
    rand_label = gr.Textbox(label="Sample Label")

    pred_btn.click(
        fn=classify_image,
        inputs=input_img,
        outputs=[output_label, output_conf, output_probs],
    )

    # random_example returns (image, image, label): the same image goes both
    # into the upload box and into the preview pane.
    rand_img.click(
        fn=random_example,
        outputs=[input_img, rand_display, rand_label],
    )


if __name__ == "__main__":
    demo.launch()