Digit Classifier
- 4 min read

Digit Classifier

On this page
Introduction

I wanted to learn TensorFlow/Keras, so here is a project that uses it to build a Convolutional Neural Network (CNN) model that can classify digits drawn on a canvas. Here is the final product: https://cnn.mohammedx.tech/

Requirements

pip install numpy matplotlib tensorflow scikit-learn streamlit streamlit-drawable-canvas pillow pandas

Don't forget to use a virtual env.

from numpy import mean, std
from matplotlib import pyplot as plt
from sklearn.model_selection import KFold
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from tensorflow.keras.optimizers import SGD

def load_dataset():
    """Load MNIST and shape it for a single-channel CNN.

    Returns:
        A ``(trainX, trainY, testX, testY)`` tuple. Images are reshaped to
        ``(n, 28, 28, 1)`` and labels are passed through ``to_categorical``
        (one-hot encoding).
    """
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    # Append a trailing channel axis so each image is (28, 28, 1).
    image_shape = (28, 28, 1)
    train_images = train_images.reshape((len(train_images),) + image_shape)
    test_images = test_images.reshape((len(test_images),) + image_shape)

    return (
        train_images,
        to_categorical(train_labels),
        test_images,
        to_categorical(test_labels),
    )

def prep_pixels(train, test):
    """Rescale two image arrays to float32 by dividing every value by 255.

    Args:
        train: Training images (array exposing ``astype``).
        test: Test images (array exposing ``astype``).

    Returns:
        ``(train_norm, test_norm)``: float32 copies of the inputs divided by
        255.0. The inputs themselves are not modified.
    """
    train_norm, test_norm = (
        pixels.astype('float32') / 255.0 for pixels in (train, test)
    )
    return train_norm, test_norm

def define_model():
    """Build and compile the baseline CNN.

    Architecture, in order: one 3x3 convolution with 32 filters (ReLU),
    2x2 max pooling, flatten, a 100-unit ReLU dense layer, and a 10-unit
    softmax output. Compiled with SGD (learning rate 0.01, momentum 0.9),
    categorical cross-entropy loss, and accuracy as the tracked metric.

    Returns:
        The compiled ``Sequential`` model.
    """
    layer_stack = [
        Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', input_shape=(28, 28, 1)),
        MaxPooling2D((2, 2)),
        Flatten(),
        Dense(100, activation='relu', kernel_initializer='he_uniform'),
        Dense(10, activation='softmax'),
    ]

    network = Sequential()
    for layer in layer_stack:
        network.add(layer)

    network.compile(
        optimizer=SGD(learning_rate=0.01, momentum=0.9),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network

def evaluate_model(dataX, dataY, n_folds=5, epochs=10, batch_size=32):
    """Estimate model accuracy with k-fold cross-validation.

    A fresh model is created with ``define_model()`` for every fold, trained
    on that fold's training indices and evaluated on its held-out indices.
    The held-out accuracy of each fold is printed as a percentage.

    Args:
        dataX: Input images; must support indexing with an index array.
        dataY: Labels aligned with ``dataX``.
        n_folds: Number of cross-validation folds.
        epochs: Training epochs per fold.
        batch_size: Mini-batch size used while fitting.

    Returns:
        ``(scores, histories)``: a list with the held-out accuracy of each
        fold and a list with the object returned by ``model.fit`` per fold.
    """
    scores, histories = [], []
    # Fixed random_state keeps the shuffled fold assignment reproducible.
    kfold = KFold(n_folds, shuffle=True, random_state=1)
    for train_ix, test_ix in kfold.split(dataX):
        model = define_model()
        trainX, trainY = dataX[train_ix], dataY[train_ix]
        testX, testY = dataX[test_ix], dataY[test_ix]
        history = model.fit(
            trainX,
            trainY,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(testX, testY),
            verbose=0,
        )
        _, acc = model.evaluate(testX, testY, verbose=0)
        print(f'> {acc * 100.0:.3f}')
        scores.append(acc)
        histories.append(history)
    return scores, histories

def summarize_diagnostics(histories):
    """Plot the learning curves of every fold on two stacked axes.

    The top axis shows cross-entropy loss, the bottom one classification
    accuracy. Training curves are blue and validation curves are orange;
    all folds are overlaid on the same axes.

    Args:
        histories: Iterable of objects whose ``history`` dict holds the
            ``loss``, ``val_loss``, ``accuracy`` and ``val_accuracy`` series.
    """
    # Create both axes once instead of re-selecting them on every iteration.
    fig, (loss_ax, acc_ax) = plt.subplots(2, 1)
    loss_ax.set_title('Cross Entropy Loss')
    acc_ax.set_title('Classification Accuracy')

    for fold, history in enumerate(histories):
        curves = history.history
        # Label only the first fold so the legend has one entry per colour.
        train_label = 'train' if fold == 0 else None
        val_label = 'validation' if fold == 0 else None
        loss_ax.plot(curves['loss'], color='blue', label=train_label)
        loss_ax.plot(curves['val_loss'], color='orange', label=val_label)
        acc_ax.plot(curves['accuracy'], color='blue', label=train_label)
        acc_ax.plot(curves['val_accuracy'], color='orange', label=val_label)

    if histories:
        loss_ax.legend()
        acc_ax.legend()
    # Keeps the lower title from colliding with the upper axis' tick labels.
    fig.tight_layout()
    plt.show()

def summarize_performance(scores):
    """Print mean/std accuracy (as percentages) and show a box plot.

    Args:
        scores: Sequence of per-fold accuracy values.
    """
    mean_pct = mean(scores) * 100
    std_pct = std(scores) * 100
    fold_count = len(scores)
    print(f'Accuracy: mean={mean_pct:.3f} std={std_pct:.3f}, n={fold_count:d}')

    # Distribution of the accuracy across folds.
    plt.boxplot(scores)
    plt.show()

def run_test_harness():
    """Run the whole pipeline: load, scale, cross-validate, then report."""
    images, labels, holdout_images, _holdout_labels = load_dataset()
    # Only the training split is cross-validated; the test split is scaled
    # alongside it but not used further here.
    images, holdout_images = prep_pixels(images, holdout_images)

    fold_scores, fold_histories = evaluate_model(images, labels)

    summarize_diagnostics(fold_histories)
    summarize_performance(fold_scores)

# Entry point: runs the full cross-validation pipeline as soon as this script executes.
run_test_harness()

1- Data Loading and Preprocessing

load_dataset()

  • Loads the MNIST dataset (handwritten digits 0–9).
  • Applies one-hot encoding to labels for training.

prep_pixels(train, test)

  • Converts images from uint8 to float32.
  • Normalizes pixel values from [0, 255] to [0.0, 1.0].

define_model()

  • Builds a simple CNN:
    • Conv2D: Detects patterns in images.
    • MaxPooling2D: Reduces spatial dimensions.
    • Flatten: Flattens 2D to 1D.
    • Dense: Fully connected layers to classify digits.
  • Compiles with SGD optimizer and categorical cross-entropy loss.

2- Evaluation with K-Fold Cross-Validation

evaluate_model(dataX, dataY, n_folds=5)

  • Splits the training data into 5 folds.
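  • Trains on four folds and validates on the held-out fold, rotating through all five, which gives a more reliable accuracy estimate than a single train/validation split.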
  • Tracks accuracy and learning history.

summarize_diagnostics

  • Plots the training (blue) and validation (orange) loss and accuracy curves of every fold, which makes over- or under-fitting easy to spot.

summarize_performance(scores)

  • Displays average model accuracy and variation.
  • Plots a boxplot of performance across folds.

3- Model Training and Saving

run_test_harness()

  • Manages the full training and evaluation flow.
  • Calls:
    • load_dataset()
    • prep_pixels()
    • evaluate_model()
    • summarize_diagnostics()
    • summarize_performance()

save_model()

  • Trains the model once on all training data (outside cross-validation).
  • Saves the trained model as final_model.h5.

4- Making Predictions

import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.datasets import mnist

# Loaded models keyed by file path, so repeated predictions do not re-read
# the model file from disk.
_MODEL_CACHE = {}


def _get_model(model_path):
    """Return the model stored at ``model_path``, loading it at most once."""
    if model_path not in _MODEL_CACHE:
        _MODEL_CACHE[model_path] = load_model(model_path)
    return _MODEL_CACHE[model_path]


def classify_digit(image, model_path='final_model.h5'):
    """Predict the digit shown in a single 28x28 grayscale image.

    Args:
        image: Array with 784 values in the 0-255 range; it is reshaped to
            ``(1, 28, 28, 1)`` and scaled to [0, 1] before prediction.
        model_path: Path of the saved Keras model to use.

    Returns:
        The index of the highest-scoring class for the image.
    """
    model = _get_model(model_path)
    batch = image.reshape(1, 28, 28, 1).astype('float32') / 255.0
    prediction = model.predict(batch)
    # argmax over the class axis; [0] selects the only sample in the batch.
    return np.argmax(prediction, axis=1)[0]

# Demo: classify the first image of the MNIST test split.
(trainX, trainY), (testX, testY) = mnist.load_data()
sample_image = testX[0]
digit_class = classify_digit(sample_image)
print(f"Predicted class: {digit_class}")

classify_digit(image)

  • Loads the saved model.
  • Accepts a 28x28 grayscale image.
  • Normalizes and reshapes it.
  • Predicts the digit using argmax of softmax probabilities.

5- Build the Front-End

import streamlit as st
from streamlit_drawable_canvas import st_canvas
from tensorflow.keras.models import load_model
import numpy as np
from PIL import Image
import pandas as pd

@st.cache_resource
def load_mnist_model():
    """Load ``final_model.h5``, reporting any failure in the Streamlit UI.

    Returns:
        The loaded model, or ``None`` when loading raised an exception (the
        error is shown with ``st.error`` instead of being propagated).
    """
    try:
        loaded = load_model('final_model.h5')
    except Exception as e:
        # UI boundary: surface the problem on the page rather than crashing.
        st.error(f"Error loading model: {e}")
        loaded = None
    return loaded

# Load the classifier up front; `model` is None when loading failed.
model = load_mnist_model()

# Page header.
st.title("MNIST Digit Classifier")
st.markdown("Draw a digit on the canvas below and see the model predict the digit!")

# Sidebar controls: brush/background colours and a draw-vs-transform toggle.
st.sidebar.header("Configuration")
b_color = st.sidebar.color_picker("Brush color", "#000000")
bg_color = st.sidebar.color_picker("Background color", "#FFFFFF")
drawing_mode = st.sidebar.checkbox("Drawing mode?", True)

# 280x280 drawing surface (10x the 28x28 model input) with a thick brush.
canvas_result = st_canvas(
    stroke_width=20,
    stroke_color=b_color,
    background_color=bg_color,
    height=280,
    width=280,
    # Unchecking "Drawing mode?" switches the canvas to 'transform' mode.
    drawing_mode='freedraw' if drawing_mode else 'transform',
    key="canvas"
)

def preprocess_image(image_data):
    """Convert the canvas pixel array into a single-sample model input.

    Steps: interpret the array as an RGBA image, convert to 8-bit grayscale,
    resize to 28x28, invert the intensities (``255 - value``), scale to
    [0, 1] and reshape to a batch of one.

    Args:
        image_data: Array of RGBA pixel values taken from the canvas.

    Returns:
        float32 array of shape ``(1, 28, 28, 1)``.
    """
    rgba = Image.fromarray(image_data.astype('uint8'), 'RGBA')
    small_gray = rgba.convert('L').resize((28, 28))

    # Invert on the uint8 pixels so the result matches a per-pixel 255 - x.
    inverted = 255 - np.array(small_gray)

    scaled = inverted.astype('float32') / 255.0
    return scaled.reshape(1, 28, 28, 1)

# The canvas reports its drawn shapes under json_data["objects"]; an empty
# list means nothing has been drawn yet, so there is nothing to classify.
drawn_objects = (canvas_result.json_data or {}).get("objects")

if model is None:
    # load_mnist_model() already displayed the loading error; do not ask the
    # user to draw when no prediction can be made.
    st.warning("The model is not available, so no prediction can be made.")
elif canvas_result.image_data is not None and drawn_objects:
    img = preprocess_image(canvas_result.image_data)
    prediction = model.predict(img)
    pred_digit = np.argmax(prediction)
    # Row 0 holds the class probabilities of the single submitted image.
    probabilities = prediction[0]
    st.write(f"Predicted digit: **{pred_digit}**")
    prob_df = pd.DataFrame(probabilities, index=range(10), columns=["Probability"])
    st.bar_chart(prob_df)
else:
    st.write("Please draw a digit on the canvas.")

@st.cache_resource + load_mnist_model()

  • Loads the trained model once using caching for performance.

st_canvas

  • Canvas where the user can draw a digit.

preprocess_image(image_data)

  • Converts canvas image to grayscale.
  • Resizes to 28x28.
  • Inverts pixel colors (black-on-white → white-on-black, matching the MNIST training images).
  • Normalizes and reshapes to CNN input format.

Live Prediction and Visualization

  • Once the user draws:
    • The image is processed and passed to the model.
    • Prediction is shown alongside a bar chart of probabilities (using st.bar_chart()).

Run the App

streamlit run app.py