MINST / app.py
ved1beta
MINST ready
fd35fa1
raw
history blame
1.49 kB
import gradio as gr
import numpy as np
import pickle
from PIL import Image
# Load the trained network parameters from 'model.pkl' (a relative path, so it
# is resolved against the current working directory).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ever point this at a model.pkl you produced or trust.
with open('model.pkl', 'rb') as f:
    model_params = pickle.load(f)

# Pull the four parameter arrays out of the dict, in the same lookup order
# (W1, b1, W2, b2); a missing key raises KeyError at import time.
W1, b1, W2, b2 = (model_params[key] for key in ('W1', 'b1', 'W2', 'b2'))
def ReLu(Z):
    """Rectified linear unit, applied elementwise.

    Args:
        Z: Scalar or array of pre-activations.

    Returns:
        ``np.maximum(Z, 0)``: negative entries are clamped to zero, the rest
        pass through unchanged.
    """
    floor = 0
    rectified = np.maximum(Z, floor)
    return rectified
def softmax(Z):
    """Softmax over axis 0, i.e. independently for each column of ``Z``.

    Each column is treated as one vector of logits. This matches the axis the
    previous implementation reduced over: the builtin ``sum`` iterated the
    rows, which sums along axis 0.

    Args:
        Z: 1-D array of logits, or 2-D array whose columns are logit vectors.

    Returns:
        Array of the same shape as ``Z`` whose columns are non-negative and
        sum to 1.
    """
    # Subtracting the per-column max leaves the result mathematically
    # unchanged. It keeps np.exp from overflowing to inf for large logits,
    # which would otherwise turn the division into inf/inf = nan.
    shifted = Z - np.max(Z, axis=0, keepdims=True)
    exp_z = np.exp(shifted)
    # keepdims=True keeps the column sums broadcastable against exp_z for
    # both 1-D and 2-D inputs; np.sum also avoids a Python-level row loop.
    return exp_z / np.sum(exp_z, axis=0, keepdims=True)
def forward_prop(W1, b1, W2, b2, X):
    """Run one forward pass of the two-layer network.

    Layer 1 is affine + ReLu; layer 2 is affine + softmax.

    Args:
        W1, b1: First-layer weight matrix and bias.
        W2, b2: Second-layer weight matrix and bias.
        X: Input with samples laid out as columns (``W1.dot(X)``).

    Returns:
        Tuple ``(Z1, Z2, A1, A2)``. Note the order: both pre-activations come
        first, then both activations. ``A2`` holds the output probabilities.
    """
    hidden_pre = W1.dot(X) + b1
    hidden_act = ReLu(hidden_pre)

    output_pre = W2.dot(hidden_act) + b2
    output_act = softmax(output_pre)

    return hidden_pre, output_pre, hidden_act, output_act
def get_predictions(A2):
    """Return the index of the largest entry in each column of ``A2``.

    For a 2-D ``A2`` (one column per sample) this yields one predicted class
    index per column; for a 1-D ``A2`` it yields a single index.
    """
    best_rows = np.argmax(A2, axis=0)
    return best_rows
def preprocess_image(image):
    """Turn an input image into a (784, 1) column vector in [0, 1].

    Steps: grayscale via ``convert('L')``, resize to 28x28, flatten, scale by
    1/255, then transpose so the single sample is one column.

    Args:
        image: A PIL-style image object (must support ``convert`` and
            ``resize``).

    Returns:
        Float array of shape (784, 1).

    NOTE(review): no colour inversion is applied. If the model was trained on
    light-digit-on-dark images, dark-on-light uploads may need inverting —
    confirm against the training data.
    """
    side = 28
    gray = image.convert('L')
    small = gray.resize((side, side))
    row = np.array(small).reshape(1, side * side) / 255.0
    # Column vector (one sample per column) so W1.dot(X) works downstream.
    return row.T
def predict_digit(image):
    """Predict the digit shown in ``image``.

    Args:
        image: PIL image supplied by the Gradio ``gr.Image(type="pil")``
            input, or ``None`` when the user submits without uploading.

    Returns:
        The predicted class index as a plain ``int``.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to the
            user instead of a generic error.
    """
    # Without this guard a missing upload fails deep inside preprocessing
    # with "'NoneType' object has no attribute 'convert'".
    if image is None:
        raise gr.Error("Please upload an image of a digit first.")

    X = preprocess_image(image)
    # Only the output activations (A2, last element) are needed.
    _, _, _, A2 = forward_prop(W1, b1, W2, b2, X)
    # X holds a single sample column, so there is exactly one prediction.
    prediction = get_predictions(A2)
    return int(prediction[0])
# Gradio UI: a PIL image goes in; predict_digit's result is shown as a label.
iface = gr.Interface(
    title="Handwritten Digit Recognition",
    description="Upload an image of a handwritten digit (0-9) and the model will predict which digit it is.",
    fn=predict_digit,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=1),
)
# Start the web app. This runs at module level, so it also runs on import.
iface.launch()