# Web-page header residue, kept as a comment so the file parses:
# author AIOmarRehan, commit 8608fd9 "Update app.py" (verified).
import random
import gradio as gr
from PIL import Image
from model import predict
from datasets import load_dataset
# Training split of the animals dataset, loaded once when this module is
# imported; random_example() draws its samples from it.
dataset = load_dataset(path="AIOmarRehan/AnimalsDataset", split="train")
def classify_image(img: Image.Image):
    """Run the model on one image and format its output for the UI widgets.

    Args:
        img: Image coming from the ``input_img`` widget. It is None when
            nothing has been uploaded or loaded yet.

    Returns:
        A ``(label, confidence, probabilities)`` tuple:
        - the label returned by ``predict``;
        - its confidence rounded to 3 decimals;
        - a dict with every value of ``predict``'s mapping rounded to
          3 decimals (presumably class name -> probability; confirm against
          model.predict).

        When ``img`` is None the model is not called and
        ``("No image uploaded", 0, {})`` is returned instead.
    """
    if img is not None:
        predicted_label, score, class_probs = predict(img)
        rounded_score = round(score, 3)
        rounded_probs = {name: round(p, 3) for name, p in class_probs.items()}
        return predicted_label, rounded_score, rounded_probs

    # Nothing to classify: return placeholder values for the three outputs.
    return "No image uploaded", 0, {}
# --- Random dataset sample -------------------------------------------------
def random_example():
    """Pick one random sample from ``dataset`` for the UI.

    Returns:
        ``(image, image, label_name)``.

        The same RGB-converted image object is returned twice. One goes to
        ``input_img`` so the Predict button can classify it; the other fills
        the ``rand_display`` preview. ``label_name`` is the sample's label id
        decoded to its string name through the dataset's ``label`` feature.
    """
    sample = random.choice(dataset)
    rgb_image = sample["image"].convert("RGB")
    label_name = dataset.features["label"].int2str(sample["label"])
    return (rgb_image, rgb_image, label_name)
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Animal Image Classifier with Random Dataset Samples")

    # NOTE(review): the pasted source lost its indentation, so exactly which
    # widgets sit inside this Row is a reconstruction; confirm against the
    # upstream app.py.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload an image")
        rand_img = gr.Button("Random Dataset Image")
        pred_btn = gr.Button("Predict")

    # Widgets filled by classify_image.
    output_label = gr.Label(label="Predicted Class")
    output_conf = gr.Number(label="Confidence")
    output_probs = gr.JSON(label="All Probabilities")

    # Widgets filled by random_example: a preview and the dataset's label.
    rand_display = gr.Image(type="pil", label="Random Dataset Sample")
    rand_label = gr.Textbox(label="Sample Label")

    # Classify whatever image currently sits in input_img (uploaded or random).
    pred_btn.click(
        fn=classify_image,
        inputs=[input_img],
        outputs=[output_label, output_conf, output_probs],
    )

    # Load a random dataset sample into the input widget, preview, and label box.
    rand_img.click(
        fn=random_example,
        outputs=[input_img, rand_display, rand_label],
    )

if __name__ == "__main__":
    demo.launch()