Transfer Learning for Image Classification

Image classification is one of the areas of deep learning that has developed most rapidly over the last decade. However, with limited computational resources and training data, many companies find it difficult to train a good image classification model. Transfer learning is a technique that helps overcome this barrier.

What is Transfer Learning?

Transfer learning involves taking a pre-trained model, extracting the output of one of its layers, and using that output as the input to a series of dense layers. The pre-trained model is usually trained by institutions or companies that have much larger computational and financial resources. Popular pre-trained models for image recognition tasks include VGG, Inception and ResNet.

Using this newly formed model, we can then set the parameters within the pre-trained model to be non-trainable while only optimizing the parameters of the subsequent dense layers during training.

To illustrate the value of transfer learning, I will compare a simple convolutional neural network (CNN) against a model that uses transfer learning in the following examples.

Implementation

The following sections focus on an implementation in Python using TensorFlow's Keras API.

Dataset

Before I go into the comparison, I would like to introduce you to the Fashion MNIST dataset. It consists of 28x28 grayscale images of apparel, each belonging to one of 10 classes. Fashion MNIST was created as a benchmark for image classifiers, making it well suited to the task that we are trying to accomplish.

Note that you will have to download the images as PNG files for the following examples. Please refer to this repository for the steps to obtain the dataset.

Import libraries

Let’s start off by importing the necessary libraries.

import os
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt

Defining data generators

train_dir = 'data/train'
validation_dir = 'data/test'

train_datagen = ImageDataGenerator(rescale=1./255,
                                   rotation_range=40,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   zoom_range=0.2,
                                   shear_range=0.2,
                                   horizontal_flip=True,
                                   fill_mode='nearest')  # with data augmentation for train set

valid_datagen = ImageDataGenerator(rescale=1./255)  # no augmentation for validation set

train_generator = train_datagen.flow_from_directory(train_dir,
                                                    batch_size=100,
                                                    class_mode='categorical',
                                                    target_size=(150, 150))


validation_generator = valid_datagen.flow_from_directory(validation_dir,
                                                         batch_size=100,
                                                         class_mode='categorical',
                                                         target_size=(150, 150))
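
`flow_from_directory` infers the class of each image from the name of its parent folder, so the PNG files need to sit in one sub-folder per class under `data/train` and `data/test`. Below is a minimal sketch of such a layout. It assumes the sub-folders are named after the numeric label (0 to 9), and `image_path` is a hypothetical helper introduced only for illustration; the repository mentioned above may use different folder or file names.

```python
import os

# Human-readable names for the 10 Fashion MNIST labels (list index = label).
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

def image_path(root, label, index):
    """Hypothetical helper: path where image number `index` of class `label` would live."""
    if not 0 <= label < len(CLASS_NAMES):
        raise ValueError(f"label must be in 0..9, got {label}")
    # One sub-folder per class; flow_from_directory uses the folder name as the class.
    return os.path.join(root, str(label), f"{index}.png")

print(image_path("data/train", 9, 42), "->", CLASS_NAMES[9])
```

Note that the generators above resize every image to 150x150 and load it with 3 colour channels by default, which is what the models in this post expect.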

Define a callback to stop training after certain performance is achieved

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # With metrics=['accuracy'], TF 2 reports 'accuracy' and 'val_accuracy' (not 'acc')
        if logs.get('accuracy', 0) > 0.99 and logs.get('val_accuracy', 0) > 0.99:
            print("\nCancelling training as model has reached 99% accuracy and 99% validation accuracy!")
            self.model.stop_training = True
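
The stopping rule itself is just a check on the `logs` dictionary that Keras passes to `on_epoch_end`. Here is a minimal standalone sketch of that check, using a hypothetical helper `should_stop`. It assumes the metric keys are `accuracy` and `val_accuracy`, which are the names TensorFlow 2 produces when the model is compiled with `metrics=['accuracy']`; older versions report `acc` and `val_acc` instead.

```python
def should_stop(logs, threshold=0.99):
    """Return True when both training and validation accuracy exceed `threshold`."""
    logs = logs or {}
    # A missing key is treated as 0.0 so the comparison never sees None.
    acc = logs.get("accuracy", 0.0)
    val_acc = logs.get("val_accuracy", 0.0)
    return acc > threshold and val_acc > threshold

print(should_stop({"accuracy": 0.995, "val_accuracy": 0.992}))  # True
print(should_stop({"accuracy": 0.995, "val_accuracy": 0.90}))   # False
print(should_stop({"accuracy": 0.995}))                         # False
```

Requiring both metrics to pass the threshold avoids stopping early on a model that has only memorised the training set.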

Define function to plot result

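def plot_result(history):
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    epochs = range(len(acc))

    # Accuracy curves
    plt.plot(epochs, acc, 'r', label='Training accuracy')
    plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
    plt.title('Training and validation accuracy')
    plt.legend(loc=0)

    # Loss curves on a separate figure (previously an empty figure was opened here)
    plt.figure()
    plt.plot(epochs, loss, 'r', label='Training loss')
    plt.plot(epochs, val_loss, 'b', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend(loc=0)
    plt.show()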

Simple Convolutional Neural Network

For the simple CNN model, I will use 3 convolutional layers followed by a single dense layer and the output layer.

Defining CNN model

img_width, img_height = 150, 150

if K.image_data_format() == 'channels_first':
    input_shape = (3, img_width, img_height)
else:
    input_shape = (img_width, img_height, 3)

inputs = layers.Input(shape=input_shape)

x = layers.Conv2D(32, (2, 2), activation='relu')(inputs)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)

x = layers.Conv2D(32, (2, 2), activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)

x = layers.Conv2D(64, (2, 2), activation='relu')(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)

x = layers.Flatten()(x)

x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)

outputs = layers.Dense(10, activation='softmax')(x)

model = Model(inputs=inputs, outputs=outputs)

model.compile(optimizer=RMSprop(learning_rate=0.0001),  # `lr` is a deprecated alias
              loss='categorical_crossentropy',
              metrics=['accuracy'])

Model summary

model.summary()
plot_model(model, to_file='cnn_model.png', show_shapes=True, show_layer_names=True)

Fig. 1. CNN Model Architecture
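
The shapes and parameter counts reported by `model.summary()` can be cross-checked by hand. The sketch below redoes that arithmetic in plain Python for the layers defined above, assuming the Keras defaults of `padding='valid'` and stride 1 for `Conv2D`, and a stride equal to the pool size for `MaxPooling2D`. The helper names are made up for this illustration.

```python
def conv_out(size, kernel):
    """Output width/height of a 'valid' convolution with stride 1."""
    return size - kernel + 1

def pool_out(size, pool):
    """Output width/height of max pooling with stride equal to the pool size."""
    return size // pool

def conv_params(kernel, in_channels, filters):
    """Kernel weights plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters

size, channels, total = 150, 3, 0
for filters in (32, 32, 64):
    total += conv_params(2, channels, filters)
    size = pool_out(conv_out(size, 2), 2)
    channels = filters
    print(f"after conv + pool with {filters} filters: {size}x{size}x{channels}")

flat = size * size * channels
total += flat * 1024 + 1024   # Dense(1024): weights + biases
total += 1024 * 10 + 10       # Dense(10) output layer; Dropout adds no parameters
print("flattened features:", flat)
print("total parameters:", total)
```

The feature map shrinks from 150x150 to 17x17, and nearly all of the roughly 19 million parameters sit in the first dense layer rather than in the convolutions.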

Fitting the model

callbacks = myCallback()
# model.fit accepts generators directly; fit_generator is deprecated since TF 2.1
history = model.fit(train_generator,
                    validation_data=validation_generator,
                    steps_per_epoch=100,
                    epochs=10,
                    validation_steps=100,
                    verbose=2,
                    callbacks=[callbacks])
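
It helps to work out how much data these settings actually cover. The sketch below assumes that `data/train` and `data/test` hold the standard Fashion MNIST split of 60,000 and 10,000 images; adjust the numbers if your folders differ.

```python
import math

batch_size = 100
steps_per_epoch = 100
validation_steps = 100
train_images, test_images = 60_000, 10_000  # assumed standard Fashion MNIST split

images_per_epoch = steps_per_epoch * batch_size         # augmented images seen per epoch
full_pass_steps = math.ceil(train_images / batch_size)  # steps needed to see every image once
validation_images = validation_steps * batch_size       # images evaluated after each epoch

print("training images per epoch:", images_per_epoch)
print("steps for a full pass:", full_pass_steps)
print("validation images per epoch:", validation_images)
```

Under these assumptions each epoch trains on 10,000 augmented images, one sixth of the training set, while validation covers the whole test set. Raising `steps_per_epoch` to 600 would use every training image once per epoch at the cost of longer training.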

Plotting model training and validation result

plot_result(history)

Fig. 2. CNN Model Result

Using a simple CNN model, we are able to achieve a validation accuracy of 0.7 after 10 epochs. Can we do better using transfer learning?

Transfer Learning using Inception v3

Inception is a convolutional neural network architecture introduced by Google, which achieved top results in the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) 2014. Inception v3, used here, is a later refinement of that architecture.

Fig. 3. Inception V3 - (Image source: here)

Download model weights, import model, load weights into model

# Download Inception v3 weights to local machine
!wget --no-check-certificate \
    https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5 \
    -O /tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5

# Import Inception v3 Model
from tensorflow.keras.applications.inception_v3 import InceptionV3
pre_trained_model = InceptionV3(input_shape=(150, 150, 3), include_top=False, weights=None)

# Load Inception v3 weights into model
local_weights_file = '/tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'
pre_trained_model.load_weights(local_weights_file)

Set layers to be non-trainable for pre-trained model

for layer in pre_trained_model.layers:
    layer.trainable = False

Model summary of Inception v3

pre_trained_model.summary()
plot_model(pre_trained_model, to_file='inception_v3_model.png', show_shapes=False, show_layer_names=True)

Obtain the output of the mixed7 layer, the last pre-trained layer we will use

last_layer = pre_trained_model.get_layer('mixed7')
last_output = last_layer.output

Adding dense layers after pre-trained model

x = layers.Flatten()(last_output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(10, activation='softmax')(x)

model = Model(pre_trained_model.input, x)

model.compile(optimizer=RMSprop(learning_rate=0.0001),  # `lr` is a deprecated alias
              loss='categorical_crossentropy',
              metrics=['accuracy'])
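
Because the Inception layers are frozen, only the new dense layers are trained. The sketch below gives a rough count of those trainable parameters. It assumes that `mixed7` outputs a 7x7x768 feature map for a 150x150 input, which is worth confirming against the output of `pre_trained_model.summary()`.

```python
# Assumed output shape of 'mixed7' for a 150x150x3 input; check model.summary().
height, width, channels = 7, 7, 768

flat = height * width * channels            # Flatten()
dense_params = flat * 1024 + 1024           # Dense(1024): weights + biases
output_params = 1024 * 10 + 10              # Dense(10): weights + biases
trainable = dense_params + output_params    # Dropout adds no parameters

print("flattened features:", flat)
print("trainable parameters in the new head:", trainable)
```

If the assumption holds, the trainable-parameter count shown by `model.summary()` should match this figure, with all Inception weights listed as non-trainable.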

Model summary of Inception v3 with dense layers

model.summary()
plot_model(model, to_file='inception_v3_with_dense_layers_model.png', show_shapes=False, show_layer_names=True)

Fig. 4. Inception v3 with Dense Layers Model Architecture

Fitting the model

callbacks = myCallback()
# model.fit accepts generators directly; fit_generator is deprecated since TF 2.1
history = model.fit(train_generator,
                    validation_data=validation_generator,
                    steps_per_epoch=100,
                    epochs=10,
                    validation_steps=100,
                    verbose=2,
                    callbacks=[callbacks])

Plotting model training and validation result

plot_result(history)

Fig. 5. Inception v3 Model Result

As you can see, using Inception v3 for transfer learning, we are able to obtain a validation accuracy of 0.8 after 10 epochs. This is a relative improvement of about 14% over the previous CNN model.
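
The 14% figure is a relative improvement: the gain in validation accuracy divided by the baseline. A quick check of that arithmetic, using the two rounded accuracies reported in this post:

```python
cnn_val_acc = 0.7        # validation accuracy reported for the simple CNN
transfer_val_acc = 0.8   # validation accuracy reported with Inception v3

absolute_gain = transfer_val_acc - cnn_val_acc   # gain in accuracy points
relative_gain = absolute_gain / cnn_val_acc      # gain relative to the baseline

print(f"absolute gain: {absolute_gain:.2f}")
print(f"relative gain: {relative_gain:.1%}")
```

In absolute terms the gain is 0.1, or 10 percentage points of validation accuracy.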

Remarks

In this simple example, we can see how transfer learning is able to outperform a simple CNN model on the Fashion MNIST dataset. In real life, images are often more difficult to classify. Therefore, being able to leverage a pre-trained model is a big step forward for the community!

You can check out this Jupyter notebook for the code above.

Reference

[1] Laurence Moroney et al. Coursera: Convolutional Neural Networks in TensorFlow

[2] Christian Szegedy et al. Going Deeper with Convolutions


Thank you for reading! See you in the next post!

Written on July 8, 2020