MNIST Image Classification with PyTorch

    In this tutorial, we'll learn how to build a convolutional neural network (CNN) using PyTorch to classify handwritten digits from the MNIST dataset. The MNIST dataset consists of 28x28 pixel grayscale images of handwritten digits (0-9), and the task is to correctly identify which digit is represented in each image. The tutorial covers:

  1. Preparing data
  2. Model definition
  3. Model training
  4. Model evaluation
  5. Prediction
  6. Conclusion

    This tutorial assumes the following prerequisites:

  1. Basic understanding of the Python programming language.
  2. Familiarity with deep learning concepts, especially convolutional neural networks (CNNs).
  3. Installation of the PyTorch and torchvision libraries, plus matplotlib for the visualization step. You can install them via pip:
 
pip install torch torchvision matplotlib
 

 

Preparing data

    We start by importing the necessary libraries.

 
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import os
     

    The first step is to download and prepare the MNIST dataset. PyTorch provides a convenient way to do this through the torchvision.datasets module. We also define a transform that converts each image to a tensor and normalizes its pixel values.

    To train the model quickly, I use only 40 percent of the MNIST data. If you prefer to train on the full dataset, you can skip the subsetting part of the code. We randomly select 40% of the samples to create subsets of both the training and test sets, and then define the data loaders.

 
# Define a transform that converts images to tensors and normalizes them to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download and load the MNIST training data, then keep a random 40% subset.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
subset_indices = torch.randperm(len(trainset))[:int(0.4 * len(trainset))]
trainset = Subset(trainset, subset_indices)

# Download and load the MNIST test data, then keep a random 40% subset.
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
subset_indices_test = torch.randperm(len(testset))[:int(0.4 * len(testset))]
testset = Subset(testset, subset_indices_test)

# Define the data loaders
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testloader = DataLoader(testset, batch_size=32, shuffle=False)
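The subsetting logic can be tried without downloading MNIST. This sketch swaps in a hypothetical TensorDataset of 1,000 random 28x28 tensors and, as an optional addition not in the original code, passes a seeded torch.Generator to torch.randperm so the same 40% subset is drawn on every run:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for MNIST: 1,000 random 1x28x28 "images" with labels 0-9.
images = torch.rand(1000, 1, 28, 28)
labels = torch.randint(0, 10, (1000,))
full = TensorDataset(images, labels)

# Seeded generator (optional) so the same random 40% is selected on every run.
g = torch.Generator().manual_seed(0)
idx = torch.randperm(len(full), generator=g)[: int(0.4 * len(full))]
subset = Subset(full, idx.tolist())

loader = DataLoader(subset, batch_size=32, shuffle=True)
print(len(subset))  # number of samples kept
print(len(loader))  # number of mini-batches (the last one may be smaller than 32)
```

Applied to MNIST's 60,000 training and 10,000 test images, the same arithmetic keeps 24,000 and 4,000 samples, which matches the 4000 test images reported in the results; at batch size 32 the training loader yields 750 mini-batches per epoch.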


Model definition

    Next, we'll define our CNN model architecture. For this task, we'll use a simple CNN consisting of convolutional layers, activation functions, a pooling layer, dropout, and fully connected layers. In the forward method, we apply two convolutions (conv1 and conv2), each followed by a ReLU activation, perform max pooling with a 2x2 kernel, apply dropout regularization, and flatten the output to a 1D tensor per image. Two fully connected layers (fc1 and fc2, with ReLU and dropout in between) then map the features to 10 output values, one per digit. These are raw scores (logits) rather than probabilities; nn.CrossEntropyLoss applies log-softmax to them internally during training.

 
# Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # 1 input channel -> 32 feature maps, 3x3 kernel, stride 1
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # 32 -> 64 feature maps
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)       # 9216 = 64 channels * 12 * 12 after pooling
        self.fc2 = nn.Linear(128, 10)         # one output per digit class

    def forward(self, x):
        x = self.conv1(x)            # (N, 32, 26, 26)
        x = nn.ReLU()(x)
        x = self.conv2(x)            # (N, 64, 24, 24)
        x = nn.ReLU()(x)
        x = nn.MaxPool2d(2)(x)       # (N, 64, 12, 12)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)      # (N, 9216)
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.dropout2(x)
        x = self.fc2(x)              # (N, 10) logits
        return x
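The 9216 input features of fc1 follow from the layer shapes: each unpadded 3x3 convolution trims 2 pixels from each spatial dimension (28 to 26 to 24), and 2x2 max pooling halves 24 to 12, leaving 64 x 12 x 12 = 9216 values per image. One way to confirm this is to push a dummy tensor through layers configured like conv1, conv2, and the pooling step:

```python
import torch
import torch.nn as nn

# Rebuild only the convolutional front end of the CNN to inspect shapes.
conv1 = nn.Conv2d(1, 32, 3, 1)   # 28x28 -> 26x26 (3x3 kernel, no padding)
conv2 = nn.Conv2d(32, 64, 3, 1)  # 26x26 -> 24x24
pool = nn.MaxPool2d(2)           # 24x24 -> 12x12

x = torch.zeros(1, 1, 28, 28)    # one dummy MNIST-sized grayscale image
x = pool(torch.relu(conv2(torch.relu(conv1(x)))))
print(x.shape)                   # torch.Size([1, 64, 12, 12])

flat = torch.flatten(x, 1)       # keep the batch dimension, flatten the rest
print(flat.shape)                # torch.Size([1, 9216])
```

If you change the kernel size, add padding, or feed a different image size, rerunning a check like this tells you the new in_features value for fc1.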
 

 

Model training

    Now, let's train the model using the training dataset. We'll define the loss function, optimizer, and then iterate over the dataset to update the model parameters.

 
# Initialize the model, loss function, and optimizer
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
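Because the model returns raw logits, nn.CrossEntropyLoss is the appropriate loss: it combines log-softmax and negative log-likelihood in a single step. The check below uses made-up logits for a single three-class sample to show that the two formulations agree, and that softmax turns the logits into probabilities summing to 1:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])  # made-up raw scores: 1 sample, 3 classes
target = torch.tensor([0])                 # index of the true class

# CrossEntropyLoss on logits ...
loss = nn.CrossEntropyLoss()(logits, target)
# ... equals negative log-likelihood applied to log-softmax outputs.
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

probs = F.softmax(logits, dim=1)           # probabilities, only needed for display
print(loss.item(), manual.item())
print(probs)
```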

We save the trained weights to the file named in model_name. Training runs only if that file does not already exist; otherwise it is skipped, and the saved weights are loaded in the evaluation step.

# File path for saving and loading the trained weights
model_name = './mnist_cnn.pth'

# Train the model only if no saved weights exist yet
if not os.path.exists(model_name):
    for epoch in range(5):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 100))
                running_loss = 0.0

    print('Finished Training')
    torch.save(model.state_dict(), model_name)
    print("Model saved successfully.")
 

The training output looks like this (exact loss values will vary between runs):


[1, 100] loss: 0.861
[1, 200] loss: 0.324
[1, 300] loss: 0.229 
....
[5, 300] loss: 0.058
[5, 400] loss: 0.057
[5, 500] loss: 0.068
[5, 600] loss: 0.071
[5, 700] loss: 0.059
Finished Training
Model saved successfully.
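The save-then-reload pattern relies on state_dict(), which holds only the learned tensors, so the loading code must first rebuild the same architecture. The sketch below demonstrates the round trip with a small nn.Linear standing in for the CNN and a hypothetical file name in a temporary directory:

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # small stand-in for the CNN
path = os.path.join(tempfile.mkdtemp(), "demo_weights.pth")  # hypothetical file name

# Same guard as the tutorial: only save when no file exists yet.
if not os.path.exists(path):
    torch.save(model.state_dict(), path)  # persists parameters only, not the class

restored = nn.Linear(4, 2)                # same architecture, fresh random weights
restored.load_state_dict(torch.load(path))
restored.eval()

# Every saved tensor should match the original model's tensors exactly.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)
```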

 

Model evaluation and prediction

    Now let's evaluate the trained model on the test dataset to measure its performance. We load the saved model weights from the file, set the model to evaluation mode (which disables dropout), turn off gradient computation for efficiency, and iterate over the test dataset. Finally, we compute and print the model's accuracy on the test data.

 
# Load the saved model weights from the file
model = CNN()
model.load_state_dict(torch.load(model_name))
model.eval()  # switch dropout layers to inference behavior

# Testing the model
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        # Determine the predicted class for each image (index of the largest logit)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the network on the {total} test images: {100 * correct / total}")
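The accuracy computation reduces to taking the index of the largest logit in each row and counting matches with the labels. Here it is on four made-up logit rows, with argmax shown as an equivalent of the index returned by torch.max(outputs, 1):

```python
import torch

# Made-up logits for 4 samples and 3 classes.
outputs = torch.tensor([[2.0, 0.1, -1.0],
                        [0.2, 1.5, 0.3],
                        [0.0, 0.1, 3.0],
                        [1.0, 2.0, 0.5]])
labels = torch.tensor([0, 1, 2, 0])

_, predicted = torch.max(outputs, 1)         # index of the largest logit per row
correct = (predicted == labels).sum().item()  # number of matching predictions
accuracy = 100 * correct / labels.size(0)

print(predicted.tolist(), correct, accuracy)  # [0, 1, 2, 1] 3 75.0
```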

Next, we display the first eight images of a test batch together with their ground-truth and predicted labels, which gives a quick visual check of the model's predictions.

 
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5  # unnormalize from [-1, 1] back to [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C)
    plt.show()


# Get the first batch of test images (the test loader is not shuffled)
dataiter = iter(testloader)
images, labels = next(dataiter)

# Print the ground-truth labels of the first 8 images
print('GroundTruth: ', ' '.join('%5s' % labels[j].item() for j in range(8)))

# Predict the class for the images
with torch.no_grad():
    outputs = model(images)
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % predicted[j].item() for j in range(8)))

# Show the first 8 images in a grid
imshow(torchvision.utils.make_grid(images[:8]))
 

The output looks like this:

 
Accuracy of the network on the 4000 test images: 98.8
GroundTruth: 2 0 5 0 3 8 5 6
Predicted: 2 0 5 0 3 8 5 6 
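To classify a single image instead of a batch, add a batch dimension with unsqueeze(0) before calling the model, and apply softmax if you want probabilities rather than logits. The sketch below uses a hypothetical untrained stand-in (a flatten plus one linear layer) in place of the trained CNN, so its predictions are meaningless; only the shapes and calls carry over:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in classifier; in the tutorial this would be the trained CNN.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.eval()

image = torch.rand(1, 28, 28)  # one image as (C, H, W), like a single dataset item

with torch.no_grad():
    logits = model(image.unsqueeze(0))   # add batch dimension -> input (1, 1, 28, 28)
    probs = torch.softmax(logits, dim=1)  # convert the 10 logits to probabilities

pred = probs.argmax(dim=1).item()  # predicted digit as a Python int
confidence = probs[0, pred].item()
print(pred, confidence)
```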


Conclusion

    In this tutorial, we learned how to build a CNN model using PyTorch for image classification on the MNIST dataset. We defined the model architecture, trained it on the training dataset, and evaluated its performance on the test dataset. This tutorial serves as a basic introduction to deep learning with PyTorch and can be extended to more complex tasks and datasets.

 

Source code listing

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset


# Define transforms to normalize the data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download and load the MNIST training data, then keep a random 40% subset.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
subset_indices = torch.randperm(len(trainset))[:int(0.4 * len(trainset))]
trainset = Subset(trainset, subset_indices)

# Download and load the MNIST test data, then keep a random 40% subset.
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
subset_indices_test = torch.randperm(len(testset))[:int(0.4 * len(testset))]
testset = Subset(testset, subset_indices_test)

# Define the data loaders
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testloader = DataLoader(testset, batch_size=32, shuffle=False)

# Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.ReLU()(x)
        x = self.conv2(x)
        x = nn.ReLU()(x)
        x = nn.MaxPool2d(2)(x)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

# Initialize the model, loss function, and optimizer
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# File path for saving and loading the trained weights
model_name = './mnist_cnn.pth'

# Train the model only if no saved weights exist yet
if not os.path.exists(model_name):
    for epoch in range(5):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 100))
                running_loss = 0.0

    print('Finished Training')
    torch.save(model.state_dict(), model_name)
    print("Model saved successfully.")

# Load the model
model = CNN()
model.load_state_dict(torch.load(model_name))
model.eval()

# Testing the model
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        # Determine the predicted class for each image
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the network on the {total} test images: {100 * correct / total}")

# Predicting test images with the loaded model
def imshow(img):
    img = img / 2 + 0.5  # unnormalize from [-1, 1] back to [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C)
    plt.show()


# Get the first batch of test images (the test loader is not shuffled)
dataiter = iter(testloader)
images, labels = next(dataiter)

# Print the ground-truth labels of the first 8 images
print('GroundTruth: ', ' '.join('%5s' % labels[j].item() for j in range(8)))

# Predict the class for the images
with torch.no_grad():
    outputs = model(images)
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % predicted[j].item() for j in range(8)))

# Show the first 8 images in a grid
imshow(torchvision.utils.make_grid(images[:8]))
