Keras Data Augmentation Example in Python

Data augmentation is a useful deep learning technique for improving how well a model generalizes. It creates additional training samples from the existing ones by randomly transforming the images (rotating, shifting, zooming, flipping and so on) within limits that preserve the key characteristics of the target object. Training on these varied copies helps the model generalize and reduces overfitting.
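To make the idea concrete, here is a minimal NumPy sketch of two such label-preserving transforms: a random horizontal flip and a random horizontal shift whose vacated columns are filled by repeating the edge pixels. It is only an illustration, not how Keras implements augmentation. The function `random_flip_and_shift` and its `max_shift` parameter are hypothetical names chosen for this example.

```python
import numpy as np

def random_flip_and_shift(img, max_shift, rng):
    """Return a randomly flipped and shifted copy of an (H, W, C) image.

    Hypothetical helper for illustration only; Keras's ImageDataGenerator
    uses its own affine-transform code internally.
    """
    out = img
    # Flip left-right with probability 0.5.
    if rng.random() < 0.5:
        out = out[:, ::-1, :]
    # Draw an integer shift in [-max_shift, max_shift] columns.
    dx = int(rng.integers(-max_shift, max_shift + 1))
    if dx > 0:
        # Shift right: repeat the first column on the left, drop columns on the right.
        out = np.pad(out, ((0, 0), (dx, 0), (0, 0)), mode="edge")[:, :img.shape[1], :]
    elif dx < 0:
        # Shift left: repeat the last column on the right, drop columns on the left.
        out = np.pad(out, ((0, 0), (0, -dx), (0, 0)), mode="edge")[:, -dx:, :]
    return out

rng = np.random.default_rng(0)
img = np.arange(2 * 5 * 3, dtype=np.float32).reshape(2, 5, 3)
aug = random_flip_and_shift(img, max_shift=1, rng=rng)
print(aug.shape)  # same size as the input; the label would stay unchanged
```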

Keras provides the ImageDataGenerator class for augmenting image data. In this tutorial, we'll briefly learn how to create augmented images with ImageDataGenerator in Python. The tutorial covers:

  1. Loading the image
  2. Defining the ImageDataGenerator
  3. Generating images
  4. Source code listing
We'll start by loading the required functions and libraries.

from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import array_to_img
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

 
Loading the image

First, we'll load a sample image to use in this tutorial. You can use any image you have for this example. load_img() resizes it to target_size, given as (height, width).

path = "/Pictures/rabbit1.jpg"
image = load_img(path, target_size=(200, 250))
plt.imshow(image)
plt.show()
 
 
Then, convert it into a NumPy array with the img_to_array() function and reshape it to add a leading batch dimension.

img_arr = img_to_array(image)
print(img_arr.shape)
 
(200, 250, 3)
 
img_arr = img_arr.reshape((1,)+img_arr.shape)
print(img_arr.shape)
 
(1, 200, 250, 3)
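ImageDataGenerator.flow() expects a batch of images with shape (batch, height, width, channels), which is why a leading axis of size 1 is added. The NumPy sketch below uses a zero array as a stand-in for img_arr and shows that the reshape above gives the same result as np.expand_dims(..., axis=0) or indexing with np.newaxis.

```python
import numpy as np

# Stand-in for the (200, 250, 3) array returned by img_to_array().
img_arr = np.zeros((200, 250, 3), dtype=np.float32)

batch_a = img_arr.reshape((1,) + img_arr.shape)  # approach used in this tutorial
batch_b = np.expand_dims(img_arr, axis=0)        # same result, axis made explicit
batch_c = img_arr[np.newaxis, ...]               # same result via indexing

print(batch_a.shape, batch_b.shape, batch_c.shape)
```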


Defining the ImageDataGenerator

Next, we'll define the image generator with the ImageDataGenerator class. Its arguments set the transformations to apply to the image; see the Keras documentation for details on each option.
  • rotation_range sets the range, in degrees, for random rotations
  • width_shift_range and height_shift_range shift the image horizontally and vertically (a float below 1 is a fraction of the image width or height)
  • shear_range sets the shear intensity (shear angle in degrees)
  • zoom_range sets the range for random zooming
  • fill_mode sets how newly created pixels (for example, areas uncovered by a shift or rotation) are filled
  • horizontal_flip randomly flips images horizontally
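According to the Keras documentation, each numeric option defines an interval from which a random value is drawn for every generated image. The hypothetical helper below (not part of Keras) only computes those intervals for the 200x250 image and the settings used in this tutorial. It assumes float shift ranges below 1, interpreted as fractions of the image size, and a single-float zoom_range, interpreted as [1 - zoom_range, 1 + zoom_range]; Keras draws the horizontal and vertical zoom factors independently from that interval.

```python
def sampling_ranges(height, width, rotation_range=0, width_shift_range=0.0,
                    height_shift_range=0.0, shear_range=0.0, zoom_range=0.0):
    """Hypothetical helper: the interval each random parameter is drawn from.

    Assumes shift ranges are floats < 1 (fractions of the image size)
    and zoom_range is a single float.
    """
    return {
        "rotation_deg": (-rotation_range, rotation_range),
        "shift_x_px": (-width_shift_range * width, width_shift_range * width),
        "shift_y_px": (-height_shift_range * height, height_shift_range * height),
        "shear_deg": (-shear_range, shear_range),
        "zoom_factor": (1 - zoom_range, 1 + zoom_range),
    }

ranges = sampling_ranges(200, 250, rotation_range=20, width_shift_range=0.1,
                         height_shift_range=0.1, shear_range=0.1, zoom_range=0.2)
for name, (lo, hi) in ranges.items():
    print(f"{name}: [{lo:g}, {hi:g}]")
```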

datagen = ImageDataGenerator(rotation_range=20,
                             width_shift_range=0.1, 
                             height_shift_range=0.1, 
                             shear_range=0.1, 
                             zoom_range=0.2, 
                             horizontal_flip=True)
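The generator above does not set fill_mode, so it uses the Keras default, 'nearest'. The Keras documentation illustrates the modes on a pixel row "abcd" (for instance 'nearest' gives aaaa|abcd|dddd and 'reflect' gives dcba|abcd|dcba). As a rough analogy only, NumPy's np.pad modes reproduce those patterns on a 1-D row; Keras itself fills pixels while applying its transform, not by padding.

```python
import numpy as np

row = np.array([1, 2, 3, 4])  # stands in for the pixel row "abcd"

# Rough NumPy analogues of the Keras fill_mode values (an analogy, not Keras code).
analogues = {
    "constant": {"mode": "constant", "constant_values": 0},  # Keras cval defaults to 0
    "nearest": {"mode": "edge"},       # repeat the edge pixel
    "reflect": {"mode": "symmetric"},  # mirror, repeating the edge pixel
    "wrap": {"mode": "wrap"},          # tile the row
}
padded = {name: np.pad(row, 4, **kw).tolist() for name, kw in analogues.items()}
for name, values in padded.items():
    print(f"{name:>8}: {values}")
```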
 
 
Generating images
 
Next, we'll generate images with the datagen object. Here, we'll create 9 augmented images from the original one.
 
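n = 9
imgs = []
# flow() yields batches endlessly, so stop once n images are collected
for batch in datagen.flow(img_arr, batch_size=1):
    # batch has shape (1, 200, 250, 3); convert its only image to a PIL image
    imgs.append(array_to_img(batch[0], scale=True))
    if len(imgs) == n:
        break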
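Because the iterator returned by flow() keeps producing batches indefinitely, the loop above needs an explicit break. An alternative is itertools.islice, which stops after n items. The sketch below uses a hypothetical stand-in generator, fake_flow(), so it runs without Keras; with a real generator, itertools.islice(datagen.flow(img_arr, batch_size=1), n) should behave the same way, assuming the Keras iterator supports Python's standard iteration protocol.

```python
import itertools
import numpy as np

def fake_flow(batch, rng):
    """Hypothetical stand-in for datagen.flow(): yields randomly flipped
    copies of `batch` forever, as the real iterator keeps yielding batches."""
    while True:
        if rng.random() < 0.5:
            yield batch[:, :, ::-1, :]  # flip along the width axis
        else:
            yield batch.copy()

rng = np.random.default_rng(42)
img_batch = np.zeros((1, 200, 250, 3), dtype=np.float32)

n = 9
# islice stops after n batches, so no explicit break is needed.
imgs = [b[0] for b in itertools.islice(fake_flow(img_batch, rng), n)]
print(len(imgs), imgs[0].shape)
```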

Finally, we'll plot the augmented images in a 3x3 grid.
 
plt.subplots_adjust(wspace=0, hspace=0)  # no gaps between the panels
for i in range(n):
    plt.subplot(3, 3, i + 1)
    # hide the tick labels around each image
    plt.tick_params(labelbottom=False, labelleft=False)
    plt.imshow(imgs[i])

plt.show()
 



By changing the parameter values of ImageDataGenerator, you can produce different variations and increase the amount of training data.
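Each image can contribute several transformed copies while its label stays the same. The hypothetical helper expand_with_flips() below sketches this with NumPy, using only random horizontal flips as a simplified stand-in for the full set of ImageDataGenerator transformations.

```python
import numpy as np

def expand_with_flips(x, y, copies, rng):
    """Hypothetical helper: append `copies` randomly flipped versions of
    every image in x (shape (N, H, W, C)), repeating each label unchanged."""
    xs, ys = [x], [y]
    for _ in range(copies):
        flip = rng.random(len(x)) < 0.5  # one coin toss per image
        aug = np.where(flip[:, None, None, None], x[:, :, ::-1, :], x)
        xs.append(aug)
        ys.append(y)  # labels do not change under a flip
    return np.concatenate(xs), np.concatenate(ys)

rng = np.random.default_rng(0)
x = rng.random((4, 8, 8, 3)).astype(np.float32)
y = np.array([0, 1, 1, 0])
x_big, y_big = expand_with_flips(x, y, copies=2, rng=rng)
print(x_big.shape, y_big.shape)  # 4 originals + 2 * 4 copies = 12 samples
```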

In this tutorial, we've briefly learned how to generate augmented data with ImageDataGenerator in Python.


Source code listing

from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import array_to_img
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

path = "/Pictures/rabbit1.jpg"
image = load_img(path, target_size=(200, 250))
plt.imshow(image)
plt.show()

img_arr = img_to_array(image)
print(img_arr.shape) 
 
img_arr = img_arr.reshape((1,)+img_arr.shape)
print(img_arr.shape)

datagen = ImageDataGenerator(rotation_range=20,
                             width_shift_range=0.1,
                             height_shift_range=0.1,
                             shear_range=0.1,
                             zoom_range=0.2,
                             horizontal_flip=True)

n = 9
imgs = []
# flow() yields batches endlessly, so stop once n images are collected
for batch in datagen.flow(img_arr, batch_size=1):
    imgs.append(array_to_img(batch[0], scale=True))
    if len(imgs) == n:
        break

plt.subplots_adjust(wspace=0, hspace=0)  # no gaps between the panels
for i in range(n):
    plt.subplot(3, 3, i + 1)
    plt.tick_params(labelbottom=False, labelleft=False)
    plt.imshow(imgs[i])

plt.show()
 

