import gradio as gr
import numpy as np
from PIL import Image
import os
import random
from collections import Counter, defaultdict
from app.model import predict
from app.preprocess import preprocess_audio
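
# Helper contracts assumed from their usage below (defined under app/):
#   preprocess_audio(audio_path) -> list of spectrogram images (one per audio chunk)
#   predict(img) -> (label, confidence, probs)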

# Dataset paths (download manually from Hugging Face or place them in the Space files)
AUDIO_DATASET_DIR = "General_Audio_Dataset"
IMAGE_DATASET_DIR = "Mel_Spectrogram_Images_for_Audio_Classification"
# Get file lists safely
audio_files = [
    os.path.join(AUDIO_DATASET_DIR, f)
    for f in os.listdir(AUDIO_DATASET_DIR)
    if f.lower().endswith((".wav", ".mp3"))
] if os.path.exists(AUDIO_DATASET_DIR) else []

image_files = [
    os.path.join(IMAGE_DATASET_DIR, f)
    for f in os.listdir(IMAGE_DATASET_DIR)
    if f.lower().endswith(".png")
] if os.path.exists(IMAGE_DATASET_DIR) else []
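
# Both lists stay empty when the dataset folders are missing, so the
# "Pick Random ..." checkboxes simply fall back to whatever the user uploads.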

def safe_load_image(img):
    # Accept either a NumPy array or a PIL image and normalize it to RGBA
    if img is None:
        return None
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    return img.convert("RGBA")

# Process a single spectrogram image
def process_image_input(img):
    img = safe_load_image(img)
    label, confidence, probs = predict(img)
    return label, round(confidence, 3), probs

# Process raw audio: classify every spectrogram chunk, then aggregate
def process_audio_input(audio_path):
    imgs = preprocess_audio(audio_path)
    if not imgs:
        raise ValueError("preprocess_audio returned no spectrogram chunks")

    all_preds, all_confs, all_probs = [], [], []
    for img in imgs:
        label, conf, probs = predict(img)
        all_preds.append(label)
        all_confs.append(conf)
        all_probs.append(probs)

    # Majority vote across chunks; ties are broken by the summed confidence
    # of each tied label
    counter = Counter(all_preds)
    max_count = max(counter.values())
    candidates = [k for k, v in counter.items() if v == max_count]
    if len(candidates) == 1:
        final_label = candidates[0]
    else:
        conf_sums = defaultdict(float)
        for i, lbl in enumerate(all_preds):
            if lbl in candidates:
                conf_sums[lbl] += all_confs[i]
        final_label = max(conf_sums, key=conf_sums.get)

    # Report the mean confidence over the chunks that match the final label
    final_conf = float(np.mean([all_confs[i] for i, lbl in enumerate(all_preds) if lbl == final_label]))
    return final_label, round(final_conf, 3), all_preds, [round(c, 3) for c in all_confs]
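
# Example of the aggregation above (hypothetical values): chunk labels
# ["dog", "dog", "car"] with confidences [0.9, 0.8, 0.95] give the majority
# label "dog" and a reported confidence of mean(0.9, 0.8) = 0.85.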

# Main classifier: a spectrogram image takes priority over raw audio
def classify(audio_path, image, random_audio=False, random_image=False):
    # Pick a random audio file if selected
    if random_audio and audio_files:
        audio_path = random.choice(audio_files)

    # Pick a random spectrogram image if selected
    if random_image and image_files:
        img_path = random.choice(image_files)
        image = Image.open(img_path).convert("RGBA")

    # If a spectrogram image was provided (or randomly picked)
    if image is not None:
        label, conf, probs = process_image_input(image)
        return {
            "Final Label": label,
            "Confidence": conf,
            "Details": probs
        }, label

    # If a raw audio file was provided (or randomly picked)
    if audio_path is not None:
        label, conf, all_preds, all_confs = process_audio_input(audio_path)
        return {
            "Final Label": label,
            "Confidence": conf,
            "All Chunk Labels": all_preds,
            "All Chunk Confidences": all_confs
        }, label

    return {"Error": "Please upload an audio file OR a spectrogram image."}, ""

# Gradio Interface
interface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio (WAV/MP3)"),
        gr.Image(type="pil", label="Upload Spectrogram Image (PNG RGBA Supported)"),
        gr.Checkbox(label="Pick Random Audio from Dataset"),
        gr.Checkbox(label="Pick Random Image from Dataset"),
    ],
    outputs=[
        gr.JSON(label="Prediction Results"),
        gr.Textbox(label="Final Label", interactive=False)
    ],
    title="General Audio Classifier (Audio + Spectrogram Support)",
    description=(
        "Upload a raw audio file OR a spectrogram image.\n"
        "You can also select random samples from the local datasets.\n"
        "The output shows a JSON with all details and a separate field for the final label."
    ),
)
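
# Optional quick check before launching (a sketch that assumes the sample
# datasets above are present; uncomment to run it locally):
# if image_files:
#     print(classify(None, Image.open(random.choice(image_files)).convert("RGBA")))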
interface.launch()