Using VGG for Image Classification with PyTorch

     In this tutorial, we'll learn how to use a pre-trained VGG model for image classification in PyTorch. We'll go through the steps of loading a pre-trained model, preprocessing an image, using the model to predict its class label, and displaying the result. The tutorial covers:
  1. Introduction to VGG networks
  2. Load a Pre-Trained VGG16 Model
  3. Define Image Preprocessing
  4. Load ImageNet Class Labels
  5. Load and Preprocess the Image
  6. Make a Prediction
  7. Conclusion
  8. Full code listing

 

Introduction to VGG networks

    VGG (Visual Geometry Group) networks are a family of deep convolutional neural networks that were introduced by the Visual Geometry Group from the University of Oxford in the paper titled "Very Deep Convolutional Networks for Large-Scale Image Recognition" in 2014. VGG models are famous for their simplicity and effectiveness, making them a popular choice in the field of computer vision.

    Key Characteristics of VGG Networks

  • Deep Architecture: VGG16 and VGG19 have 16 and 19 weight layers (convolutional plus fully connected), respectively. This depth enables the model to recognize complex patterns in data.

  • Small Convolutional Filters: VGG uses 3x3 filters, unlike earlier models that used larger ones (e.g., 7x7). Stacked small filters cover a larger area: two 3x3 layers see a 5x5 region and three see a 7x7 region, with fewer weights and more non-linearities than a single large filter.

  • Uniform Design: The architecture is consistent, with the same 3x3 filters and 2x2 max-pooling throughout. This uniformity makes the network easy to understand and modify.

  • Fully Connected Layers: VGG ends with three fully connected layers for making predictions, with the final layer producing one score per class.

  • Large Model Size: A downside is the large number of parameters (about 138 million in VGG16), making the model resource-intensive.
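
The trade-off behind the small-filter design can be checked with a little arithmetic. The sketch below compares three stacked 3x3 layers with a single 7x7 layer; the helper names and the channel width of 256 are chosen only for illustration, and biases are ignored.

```python
def stacked_receptive_field(num_layers: int, kernel: int = 3) -> int:
    """Receptive field of `num_layers` stacked stride-1 convolutions."""
    return 1 + num_layers * (kernel - 1)


def conv_weights(kernel: int, channels: int, num_layers: int = 1) -> int:
    """Weight count (biases ignored) with `channels` inputs and outputs per layer."""
    return num_layers * kernel * kernel * channels * channels


channels = 256  # hypothetical channel width, used only for this comparison

print(stacked_receptive_field(3))               # 7, the same area as one 7x7 filter
print(conv_weights(3, channels, num_layers=3))  # 1769472
print(conv_weights(7, channels))                # 3211264
```

Under these assumptions the stacked version reaches the same 7x7 receptive field with roughly 45% fewer weights.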

   Limitations

    VGG networks have several limitations. Their primary issue is the large number of parameters, with models like VGG16 containing around 138 million, leading to high computational and memory costs. This size can make training and prediction slow, especially on resource-constrained devices. Additionally, VGG doesn’t include skip connections, which makes its deeper versions (like VGG16, VGG19) more likely to face vanishing gradient issues.
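
To see where the roughly 138 million parameters come from, the count can be reproduced by hand. This is a sketch that assumes the standard VGG16 layout (3x3 convolutions with biases, a 7x7x512 feature map feeding the classifier, and 1000 output classes); the function name is made up for this example.

```python
# Output channels of each VGG16 convolution; "M" marks a 2x2 max-pooling layer.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]


def count_vgg16_parameters(num_classes: int = 1000) -> dict:
    """Count weights and biases of the convolutional and fully connected parts."""
    conv_params, in_channels = 0, 3  # the input is an RGB image
    for item in VGG16_CFG:
        if item == "M":
            continue  # pooling layers have no parameters
        conv_params += 3 * 3 * in_channels * item + item
        in_channels = item

    fc_sizes = [512 * 7 * 7, 4096, 4096, num_classes]
    fc_params = sum(a * b + b for a, b in zip(fc_sizes, fc_sizes[1:]))
    return {"conv": conv_params, "fc": fc_params, "total": conv_params + fc_params}


print(count_vgg16_parameters())
# {'conv': 14714688, 'fc': 123642856, 'total': 138357544}
```

Almost 90% of the parameters sit in the three fully connected layers, which is why the model is large even though its convolutional part is comparatively compact.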

  

Load a Pre-Trained VGG16 Model

    Before starting, make sure you have the following Python libraries installed:

  • torch (PyTorch)
  • torchvision (for pre-trained models and transformations)
  • Pillow (the maintained fork of PIL, the Python Imaging Library, for handling image files)
  • matplotlib (for displaying images)
  • requests (for downloading class labels)

    You can install these libraries using pip.

 
pip install torch torchvision pillow matplotlib requests 
 

    PyTorch provides a variety of pre-trained models via the torchvision library. In this tutorial, we use the VGG16 model, which has been pre-trained on the ImageNet dataset. We’ll load the model and set it to evaluation mode (which disables certain layers like dropout that are used only during training).

 
import torch
from torchvision import models

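# Load the pre-trained VGG16 model with ImageNet weights.
# torchvision 0.13+ uses the weights argument; pretrained=True is deprecated.
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

# Set the model to evaluation mode (this disables dropout; plain VGG16 has no batch normalization layers)
model.eval()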
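
The effect of evaluation mode can be seen on a single dropout layer, without loading VGG16 at all. This is a small illustrative sketch; the tensor shape and dropout probability are arbitrary choices.

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

dropout.train()
train_out = dropout(x)  # random elements are zeroed, the rest are scaled by 1 / (1 - p)

dropout.eval()
eval_out = dropout(x)   # in evaluation mode the layer passes its input through unchanged

print(train_out)
print(eval_out)
```

Calling model.eval() on VGG16 switches the dropout layers inside its classifier to this pass-through behavior, which makes predictions deterministic.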
 

    

Define Image Preprocessing

    To use the VGG16 model, the input image needs to be preprocessed in the same way the model was trained. For VGG16, this includes resizing, center-cropping, and normalizing the image. We’ll use torchvision.transforms to define the following transformations:

  1. Resize the image so that its shorter side is 256 pixels, keeping the aspect ratio.
  2. Center-crop the image to 224x224 pixels (VGG16's input size).
  3. Convert the image to a tensor.
  4. Normalize the image with the same mean and standard deviation used in ImageNet training.
 
from torchvision import transforms

# Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize(256), # Resize so the shorter side is 256 pixels (aspect ratio is kept)
    transforms.CenterCrop(224), # Crop the center 224x224 pixels
    transforms.ToTensor(), # Convert the image to a tensor with values in [0, 1]
    # Normalize with ImageNet mean and std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
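
To make the geometry and arithmetic of these steps concrete, here is a plain-Python sketch of what happens to an image's size and to a single pixel. The helper functions are written for this example and are not part of torchvision; the library's own rounding may differ by a pixel in edge cases.

```python
def resize_shorter_side(width, height, target=256):
    """Size after scaling the shorter side to `target`, keeping the aspect ratio."""
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target


def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


def normalize_pixel(rgb, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Scale 0-255 values to [0, 1], then standardize each channel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))


width, height = resize_shorter_side(640, 480)
print((width, height))                 # (341, 256)
print(center_crop_box(width, height))  # (58, 16, 282, 240)
print(normalize_pixel((255, 255, 255)))
```

A 640x480 photo therefore becomes 341x256 after resizing, and the crop keeps the central 224x224 region, so content near the left and right edges is discarded.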

 

Load ImageNet Class Labels

    The model outputs a tensor of raw scores corresponding to ImageNet class labels. We need to download these labels to interpret the output. We'll fetch the class labels from PyTorch's GitHub repository using the requests library and convert them into a Python list.

 
import requests

# URL to fetch the ImageNet class labels
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"

# Send a GET request to download the class labels (the timeout avoids hanging on a stalled connection)
response = requests.get(url, timeout=10)
response.raise_for_status() # Raise an exception if the request failed

# Convert the response text directly into a list of class labels
class_labels = [line.strip() for line in response.text.splitlines()]

# Print the first 10 class labels as a quick check
print(class_labels[:10])
 

The first ten class labels are printed:

 
['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen', 'ostrich']

 

Load and Preprocess the Image

    Next, we’ll load a sample image, apply the transformations, and prepare it for the model. The image is loaded with Pillow and converted to RGB so that it always has the three channels the model expects.

 
from PIL import Image

# Path to the local image file
image_path = "/test/vgg/images/IMG_1036.JPG" # Replace with your local image path

# Load and preprocess the image
img = Image.open(image_path).convert("RGB") # Open the file and force 3 channels (handles grayscale or RGBA input)
img_t = transform(img) # Apply the transformations
img_t = img_t.unsqueeze(0) # Add batch dimension (required by the model)
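
The batch dimension is easiest to understand by looking at tensor shapes. The following sketch uses random tensors as stand-ins for preprocessed images, so the values are meaningless and only the shapes matter.

```python
import torch

# Stand-ins for two preprocessed images: 3 channels, 224x224 pixels
image_a = torch.rand(3, 224, 224)
image_b = torch.rand(3, 224, 224)

single = image_a.unsqueeze(0)            # one image as a batch of size 1
batch = torch.stack([image_a, image_b])  # several images combined into one batch

print(single.shape)  # torch.Size([1, 3, 224, 224])
print(batch.shape)   # torch.Size([2, 3, 224, 224])
```

The model always expects input of shape (batch, channels, height, width), so a single image needs unsqueeze(0), while several images of the same size can be stacked and classified in one forward pass.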

 

Make a Prediction

    Now that the image is ready, we can pass it through the VGG16 model to get predictions. The output will be a tensor of raw scores for each class. We’ll use the following steps:

  1. Perform a forward pass through the network.
  2. Get the predicted class index using torch.max().
  3. Convert the predicted scores to probabilities using softmax.
  4. Map the predicted index to the corresponding class label.
 
import torch
import torch.nn as nn

# Forward pass through the network (gradients are not needed for inference)
with torch.no_grad():
    output = model(img_t)

# The output is a tensor of raw scores for each class
# Get the predicted class index
_, predicted = torch.max(output, 1)

# Convert the predicted scores to probabilities using softmax
probabilities = nn.Softmax(dim=1)(output)

# Get the predicted class label and its probability
predicted_class_label = class_labels[predicted.item()]
predicted_probability = probabilities[0, predicted].item()

# Print the result
print(f"Predicted: {predicted_class_label}, Probability: {predicted_probability:.4f}")
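
A single label can hide close runners-up, so it is often useful to look at the few most likely classes. The sketch below shows the idea with torch.topk on made-up scores for a six-class toy problem; the logits and label names are hypothetical and are not real VGG16 output.

```python
import torch

# Hypothetical raw scores for one image and six classes
logits = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0, 0.0]])
toy_labels = ["a", "b", "c", "d", "e", "f"]

# Convert scores to probabilities, then keep the three largest
probs = torch.softmax(logits, dim=1)
top_probs, top_idx = torch.topk(probs, k=3, dim=1)

for prob, idx in zip(top_probs[0], top_idx[0]):
    print(f"{toy_labels[idx.item()]}: {prob.item():.4f}")
```

With the real model, the same two calls can be applied to output with k=5, using class_labels in place of the toy list.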

Finally, we’ll display the input image alongside its predicted class label and probability using matplotlib.

 
import matplotlib.pyplot as plt

# Display the image with the predicted class and probability
plt.imshow(img)
plt.title(f'Predicted: {predicted_class_label}, Probability: {predicted_probability:.4f}')
plt.axis('off') # Hide axes for a cleaner display
plt.show()




Conclusion

    This tutorial showed how to use a pre-trained VGG16 model in PyTorch to classify an image. You learned about:

  • The VGG model architecture.
  • Loading the VGG16 model.
  • Preprocessing an image with the correct transformations.
  • Making predictions and interpreting the results using class labels.

 The complete code for this tutorial is listed below.

 

Full code listing

 
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import requests

# Load the pre-trained VGG16 model (torchvision 0.13+; pretrained=True is deprecated)
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
model.eval()

# Define the transformation for the input image
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Download ImageNet class labels
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
response = requests.get(url, timeout=10)
response.raise_for_status()
class_labels = [line.strip() for line in response.text.splitlines()]

# Load and preprocess the image
image_path = "/path/to/your/image.jpg" # Replace with your image path
img = Image.open(image_path).convert("RGB") # Force 3 channels
img_t = transform(img)
img_t = img_t.unsqueeze(0) # Add batch dimension

# Forward pass through the network (no gradients needed for inference)
with torch.no_grad():
    output = model(img_t)

# Get the predicted class index
_, predicted = torch.max(output, 1)

# Convert the predicted scores to probabilities
probabilities = nn.Softmax(dim=1)(output)

# Get predicted class label and probability
predicted_class_label = class_labels[predicted.item()]
predicted_probability = probabilities[0, predicted].item()

# Display the image with predicted class and probability
plt.imshow(img)
plt.title(f'Predicted: {predicted_class_label}, Probability: {predicted_probability:.4f}')
plt.axis('off')
plt.show()
 

 

 

