TransferLearningTF

Pre-trained models and transfer learning

Training CNNs can take a lot of time, and a lot of data is required for that task. However, much of the time is spent learning the best low-level filters that the network uses to extract patterns from images. A natural question arises: can we take a neural network trained on one dataset and adapt it to classify different images, without the full training process?

This approach is called transfer learning, because we transfer some knowledge from one neural network model to another. In transfer learning, we typically start with a pre-trained model that has been trained on some large image dataset, such as ImageNet. Such models are already good at extracting features from generic images, and in many cases simply building a classifier on top of those extracted features yields a good result.

[1]

Cats vs. Dogs Dataset

In this unit, we will solve the real-life problem of classifying images of cats and dogs. We will use the Kaggle Cats vs. Dogs dataset, which can also be downloaded from Microsoft.

Let's download this dataset and extract it into the data directory (this process may take some time!):
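The download-and-extract cell is not shown here; it presumably does something along these lines. The helper name and the placeholder URL below are illustrative, not taken from the notebook:

```python
import os
import urllib.request
import zipfile

def download_and_extract(url, zip_path, out_dir):
    """Fetch the archive once (skipping the download if it is already
    on disk) and unpack it under out_dir."""
    os.makedirs(out_dir, exist_ok=True)
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(out_dir)

# Placeholder URL -- substitute the actual link from the Kaggle or
# Microsoft download page:
# download_and_extract("https://example.com/kagglecatsanddogs.zip",
#                      "data/kagglecatsanddogs.zip", "data")
```

After extraction the images end up under data/PetImages/Cat and data/PetImages/Dog, which is the layout the rest of the notebook relies on.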

[2]

Unfortunately, some image files in the dataset are corrupt, so we need to do a quick pass to detect and remove them. To keep this tutorial uncluttered, we moved the dataset verification code into a separate module.
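The verification module itself is not shown. A simplified stdlib-only stand-in that checks the JPEG start/end markers might look like the sketch below; the real module evidently decodes each file (the PIL warning in the output suggests it uses Pillow), which is more robust than this byte-level check:

```python
import os

def looks_like_valid_jpeg(path):
    """Cheap sanity check: a well-formed JPEG starts with FF D8
    and ends with FF D9."""
    with open(path, "rb") as f:
        data = f.read()
    return len(data) > 4 and data[:2] == b"\xff\xd8" and data[-2:] == b"\xff\xd9"

def clean_image_dir(path):
    """Report and drop files that fail the check, mirroring what the
    verification module does."""
    removed = []
    for root, _, files in os.walk(path):
        for name in files:
            fp = os.path.join(root, name)
            if not looks_like_valid_jpeg(fp):
                print("Corrupt image or wrong format:", fp)
                os.remove(fp)
                removed.append(fp)
    return removed
```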

[3]
Corrupt image or wrong format: data/PetImages/Cat/12235.jpg
Corrupt image or wrong format: data/PetImages/Cat/2663.jpg
Corrupt image or wrong format: data/PetImages/Cat/4929.jpg
Corrupt image or wrong format: data/PetImages/Cat/8183.jpg
Corrupt image or wrong format: data/PetImages/Cat/11083.jpg
Corrupt image or wrong format: data/PetImages/Cat/6435.jpg
Corrupt image or wrong format: data/PetImages/Cat/6491.jpg
Corrupt image or wrong format: data/PetImages/Cat/7968.jpg
Corrupt image or wrong format: data/PetImages/Cat/6768.jpg
Corrupt image or wrong format: data/PetImages/Cat/11397.jpg
Corrupt image or wrong format: data/PetImages/Cat/8295.jpg
Corrupt image or wrong format: data/PetImages/Cat/4874.jpg
Corrupt image or wrong format: data/PetImages/Cat/23.jpg
Corrupt image or wrong format: data/PetImages/Cat/11864.jpg
Corrupt image or wrong format: data/PetImages/Cat/3491.jpg
Corrupt image or wrong format: data/PetImages/Cat/11729.jpg
Corrupt image or wrong format: data/PetImages/Cat/3197.jpg
Corrupt image or wrong format: data/PetImages/Cat/10874.jpg
Corrupt image or wrong format: data/PetImages/Cat/6376.jpg
Corrupt image or wrong format: data/PetImages/Cat/9361.jpg
Corrupt image or wrong format: data/PetImages/Cat/9328.jpg
Corrupt image or wrong format: data/PetImages/Cat/910.jpg
Corrupt image or wrong format: data/PetImages/Cat/2021.jpg
Corrupt image or wrong format: data/PetImages/Cat/666.jpg
Corrupt image or wrong format: data/PetImages/Cat/3710.jpg
Corrupt image or wrong format: data/PetImages/Cat/9171.jpg
Corrupt image or wrong format: data/PetImages/Cat/6906.jpg
Corrupt image or wrong format: data/PetImages/Cat/11095.jpg
Corrupt image or wrong format: data/PetImages/Cat/660.jpg
Corrupt image or wrong format: data/PetImages/Cat/8415.jpg
Corrupt image or wrong format: data/PetImages/Cat/3161.jpg
Corrupt image or wrong format: data/PetImages/Cat/850.jpg
Corrupt image or wrong format: data/PetImages/Cat/7003.jpg
Corrupt image or wrong format: data/PetImages/Cat/1267.jpg
Corrupt image or wrong format: data/PetImages/Cat/7642.jpg
Corrupt image or wrong format: data/PetImages/Cat/8553.jpg
Corrupt image or wrong format: data/PetImages/Cat/6900.jpg
Corrupt image or wrong format: data/PetImages/Cat/4334.jpg
Corrupt image or wrong format: data/PetImages/Cat/10404.jpg
Corrupt image or wrong format: data/PetImages/Cat/6980.jpg
Corrupt image or wrong format: data/PetImages/Cat/936.jpg
Corrupt image or wrong format: data/PetImages/Cat/9619.jpg
Corrupt image or wrong format: data/PetImages/Cat/3300.jpg
Corrupt image or wrong format: data/PetImages/Cat/7647.jpg
Corrupt image or wrong format: data/PetImages/Cat/2742.jpg
Corrupt image or wrong format: data/PetImages/Cat/9565.jpg
Corrupt image or wrong format: data/PetImages/Cat/1757.jpg
Corrupt image or wrong format: data/PetImages/Cat/11874.jpg
Corrupt image or wrong format: data/PetImages/Cat/1936.jpg
Corrupt image or wrong format: data/PetImages/Cat/5077.jpg
Corrupt image or wrong format: data/PetImages/Cat/7845.jpg
Corrupt image or wrong format: data/PetImages/Cat/8832.jpg
Corrupt image or wrong format: data/PetImages/Cat/11935.jpg
Corrupt image or wrong format: data/PetImages/Cat/9208.jpg
Corrupt image or wrong format: data/PetImages/Cat/4322.jpg
Corrupt image or wrong format: data/PetImages/Cat/7978.jpg
Corrupt image or wrong format: data/PetImages/Cat/10820.jpg
Corrupt image or wrong format: data/PetImages/Cat/391.jpg
Corrupt image or wrong format: data/PetImages/Cat/10073.jpg
Corrupt image or wrong format: data/PetImages/Cat/8958.jpg
Corrupt image or wrong format: data/PetImages/Cat/8470.jpg
Corrupt image or wrong format: data/PetImages/Cat/445.jpg
Corrupt image or wrong format: data/PetImages/Cat/4821.jpg
Corrupt image or wrong format: data/PetImages/Cat/11565.jpg
Corrupt image or wrong format: data/PetImages/Cat/4351.jpg
Corrupt image or wrong format: data/PetImages/Cat/4833.jpg
Corrupt image or wrong format: data/PetImages/Cat/140.jpg
Corrupt image or wrong format: data/PetImages/Cat/3153.jpg
Corrupt image or wrong format: data/PetImages/Cat/5370.jpg
Corrupt image or wrong format: data/PetImages/Cat/5819.jpg
Corrupt image or wrong format: data/PetImages/Cat/3649.jpg
Corrupt image or wrong format: data/PetImages/Cat/4293.jpg
Corrupt image or wrong format: data/PetImages/Cat/3967.jpg
Corrupt image or wrong format: data/PetImages/Cat/1937.jpg
Corrupt image or wrong format: data/PetImages/Cat/10125.jpg
Corrupt image or wrong format: data/PetImages/Cat/1386.jpg
Corrupt image or wrong format: data/PetImages/Cat/4750.jpg
Corrupt image or wrong format: data/PetImages/Cat/6029.jpg
Corrupt image or wrong format: data/PetImages/Cat/5614.jpg
Corrupt image or wrong format: data/PetImages/Cat/2569.jpg
Corrupt image or wrong format: data/PetImages/Cat/9778.jpg
Corrupt image or wrong format: data/PetImages/Cat/6486.jpg
Corrupt image or wrong format: data/PetImages/Cat/11210.jpg
Corrupt image or wrong format: data/PetImages/Cat/7502.jpg
Corrupt image or wrong format: data/PetImages/Cat/2189.jpg
Corrupt image or wrong format: data/PetImages/Cat/1151.jpg
Corrupt image or wrong format: data/PetImages/Cat/4629.jpg
Corrupt image or wrong format: data/PetImages/Cat/12080.jpg
Corrupt image or wrong format: data/PetImages/Cat/10501.jpg
Corrupt image or wrong format: data/PetImages/Cat/1914.jpg
Corrupt image or wrong format: data/PetImages/Cat/11086.jpg
Corrupt image or wrong format: data/PetImages/Cat/4000.jpg
Corrupt image or wrong format: data/PetImages/Cat/9100.jpg
Corrupt image or wrong format: data/PetImages/Cat/5553.jpg
Corrupt image or wrong format: data/PetImages/Cat/12269.jpg
Corrupt image or wrong format: data/PetImages/Dog/5243.jpg
Corrupt image or wrong format: data/PetImages/Dog/3288.jpg
Corrupt image or wrong format: data/PetImages/Dog/11912.jpg
Corrupt image or wrong format: data/PetImages/Dog/3320.jpg
Corrupt image or wrong format: data/PetImages/Dog/11166.jpg
Corrupt image or wrong format: data/PetImages/Dog/8194.jpg
Corrupt image or wrong format: data/PetImages/Dog/2353.jpg
Corrupt image or wrong format: data/PetImages/Dog/9026.jpg
Corrupt image or wrong format: data/PetImages/Dog/8557.jpg
Corrupt image or wrong format: data/PetImages/Dog/7969.jpg
Corrupt image or wrong format: data/PetImages/Dog/8693.jpg
Corrupt image or wrong format: data/PetImages/Dog/4367.jpg
Corrupt image or wrong format: data/PetImages/Dog/1900.jpg
Corrupt image or wrong format: data/PetImages/Dog/5730.jpg
Corrupt image or wrong format: data/PetImages/Dog/9961.jpg
Corrupt image or wrong format: data/PetImages/Dog/8126.jpg
Corrupt image or wrong format: data/PetImages/Dog/8521.jpg
Corrupt image or wrong format: data/PetImages/Dog/11560.jpg
Corrupt image or wrong format: data/PetImages/Dog/11233.jpg
Corrupt image or wrong format: data/PetImages/Dog/11675.jpg
Corrupt image or wrong format: data/PetImages/Dog/8889.jpg
Corrupt image or wrong format: data/PetImages/Dog/9414.jpg
Corrupt image or wrong format: data/PetImages/Dog/3588.jpg
Corrupt image or wrong format: data/PetImages/Dog/10353.jpg
Corrupt image or wrong format: data/PetImages/Dog/10733.jpg
Corrupt image or wrong format: data/PetImages/Dog/6238.jpg
Corrupt image or wrong format: data/PetImages/Dog/5618.jpg
Corrupt image or wrong format: data/PetImages/Dog/7743.jpg
Corrupt image or wrong format: data/PetImages/Dog/10747.jpg
Corrupt image or wrong format: data/PetImages/Dog/6503.jpg
Corrupt image or wrong format: data/PetImages/Dog/11410.jpg
Corrupt image or wrong format: data/PetImages/Dog/5955.jpg
Corrupt image or wrong format: data/PetImages/Dog/2915.jpg
Corrupt image or wrong format: data/PetImages/Dog/3136.jpg
Corrupt image or wrong format: data/PetImages/Dog/4086.jpg
Corrupt image or wrong format: data/PetImages/Dog/2905.jpg
Corrupt image or wrong format: data/PetImages/Dog/2494.jpg
Corrupt image or wrong format: data/PetImages/Dog/3546.jpg
Corrupt image or wrong format: data/PetImages/Dog/5547.jpg
Corrupt image or wrong format: data/PetImages/Dog/9851.jpg
Corrupt image or wrong format: data/PetImages/Dog/9188.jpg
Corrupt image or wrong format: data/PetImages/Dog/4203.jpg
Corrupt image or wrong format: data/PetImages/Dog/8715.jpg
Corrupt image or wrong format: data/PetImages/Dog/4640.jpg
Corrupt image or wrong format: data/PetImages/Dog/1866.jpg
Corrupt image or wrong format: data/PetImages/Dog/7718.jpg
Corrupt image or wrong format: data/PetImages/Dog/7514.jpg
Corrupt image or wrong format: data/PetImages/Dog/1017.jpg
Corrupt image or wrong format: data/PetImages/Dog/12102.jpg
Corrupt image or wrong format: data/PetImages/Dog/296.jpg
Corrupt image or wrong format: data/PetImages/Dog/6718.jpg
Corrupt image or wrong format: data/PetImages/Dog/11253.jpg
Corrupt image or wrong format: data/PetImages/Dog/10173.jpg
Corrupt image or wrong format: data/PetImages/Dog/11849.jpg
Corrupt image or wrong format: data/PetImages/Dog/10383.jpg
Corrupt image or wrong format: data/PetImages/Dog/10678.jpg
Corrupt image or wrong format: data/PetImages/Dog/5790.jpg
Corrupt image or wrong format: data/PetImages/Dog/6318.jpg
Corrupt image or wrong format: data/PetImages/Dog/1884.jpg
Corrupt image or wrong format: data/PetImages/Dog/4654.jpg
Corrupt image or wrong format: data/PetImages/Dog/9145.jpg
Corrupt image or wrong format: data/PetImages/Dog/2479.jpg
Corrupt image or wrong format: data/PetImages/Dog/12289.jpg
Corrupt image or wrong format: data/PetImages/Dog/2688.jpg
Corrupt image or wrong format: data/PetImages/Dog/4134.jpg
Corrupt image or wrong format: data/PetImages/Dog/10637.jpg
/anaconda/envs/py38_tensorflow/lib/python3.8/site-packages/PIL/TiffImagePlugin.py:793: UserWarning: Truncated File Read
  warnings.warn(str(msg))
Corrupt image or wrong format: data/PetImages/Dog/543.jpg
Corrupt image or wrong format: data/PetImages/Dog/8730.jpg
Corrupt image or wrong format: data/PetImages/Dog/12114.jpg
Corrupt image or wrong format: data/PetImages/Dog/522.jpg
Corrupt image or wrong format: data/PetImages/Dog/10863.jpg
Corrupt image or wrong format: data/PetImages/Dog/10401.jpg
Corrupt image or wrong format: data/PetImages/Dog/7739.jpg
Corrupt image or wrong format: data/PetImages/Dog/561.jpg
Corrupt image or wrong format: data/PetImages/Dog/1308.jpg
Corrupt image or wrong format: data/PetImages/Dog/9643.jpg
Corrupt image or wrong format: data/PetImages/Dog/3155.jpg
Corrupt image or wrong format: data/PetImages/Dog/5736.jpg
Corrupt image or wrong format: data/PetImages/Dog/4257.jpg
Corrupt image or wrong format: data/PetImages/Dog/11853.jpg
Corrupt image or wrong format: data/PetImages/Dog/6855.jpg
Corrupt image or wrong format: data/PetImages/Dog/1259.jpg
Corrupt image or wrong format: data/PetImages/Dog/7652.jpg
Corrupt image or wrong format: data/PetImages/Dog/573.jpg
Corrupt image or wrong format: data/PetImages/Dog/11702.jpg
Corrupt image or wrong format: data/PetImages/Dog/414.jpg
Corrupt image or wrong format: data/PetImages/Dog/4301.jpg
Corrupt image or wrong format: data/PetImages/Dog/9500.jpg
Corrupt image or wrong format: data/PetImages/Dog/10726.jpg
Corrupt image or wrong format: data/PetImages/Dog/9367.jpg
Corrupt image or wrong format: data/PetImages/Dog/2877.jpg
Corrupt image or wrong format: data/PetImages/Dog/10907.jpg
Corrupt image or wrong format: data/PetImages/Dog/10972.jpg
Corrupt image or wrong format: data/PetImages/Dog/7311.jpg
Corrupt image or wrong format: data/PetImages/Dog/10797.jpg
Corrupt image or wrong format: data/PetImages/Dog/2317.jpg
Corrupt image or wrong format: data/PetImages/Dog/7128.jpg
Corrupt image or wrong format: data/PetImages/Dog/6500.jpg
Corrupt image or wrong format: data/PetImages/Dog/11285.jpg
Corrupt image or wrong format: data/PetImages/Dog/6430.jpg
Corrupt image or wrong format: data/PetImages/Dog/6032.jpg
Corrupt image or wrong format: data/PetImages/Dog/10969.jpg
Corrupt image or wrong format: data/PetImages/Dog/8364.jpg
Corrupt image or wrong format: data/PetImages/Dog/11692.jpg
Corrupt image or wrong format: data/PetImages/Dog/3038.jpg
Corrupt image or wrong format: data/PetImages/Dog/5604.jpg
Corrupt image or wrong format: data/PetImages/Dog/565.jpg
Corrupt image or wrong format: data/PetImages/Dog/9640.jpg
Corrupt image or wrong format: data/PetImages/Dog/7459.jpg
Corrupt image or wrong format: data/PetImages/Dog/6305.jpg
Corrupt image or wrong format: data/PetImages/Dog/6555.jpg
Corrupt image or wrong format: data/PetImages/Dog/11590.jpg
Corrupt image or wrong format: data/PetImages/Dog/6059.jpg
Corrupt image or wrong format: data/PetImages/Dog/3927.jpg
Corrupt image or wrong format: data/PetImages/Dog/10705.jpg
Corrupt image or wrong format: data/PetImages/Dog/2384.jpg
Corrupt image or wrong format: data/PetImages/Dog/3885.jpg
Corrupt image or wrong format: data/PetImages/Dog/663.jpg
Corrupt image or wrong format: data/PetImages/Dog/9043.jpg
Corrupt image or wrong format: data/PetImages/Dog/8563.jpg
Corrupt image or wrong format: data/PetImages/Dog/9556.jpg
Corrupt image or wrong format: data/PetImages/Dog/9967.jpg
Corrupt image or wrong format: data/PetImages/Dog/10158.jpg
Corrupt image or wrong format: data/PetImages/Dog/5263.jpg
Corrupt image or wrong format: data/PetImages/Dog/8641.jpg
Corrupt image or wrong format: data/PetImages/Dog/7369.jpg
Corrupt image or wrong format: data/PetImages/Dog/4924.jpg
Corrupt image or wrong format: data/PetImages/Dog/10351.jpg
Corrupt image or wrong format: data/PetImages/Dog/6213.jpg
Corrupt image or wrong format: data/PetImages/Dog/5104.jpg
Corrupt image or wrong format: data/PetImages/Dog/1356.jpg
Corrupt image or wrong format: data/PetImages/Dog/7133.jpg
Corrupt image or wrong format: data/PetImages/Dog/7112.jpg
Corrupt image or wrong format: data/PetImages/Dog/1168.jpg
Corrupt image or wrong format: data/PetImages/Dog/719.jpg
Corrupt image or wrong format: data/PetImages/Dog/50.jpg

Loading the Dataset

In previous examples, we were loading datasets that are built into Keras. Now we are about to deal with our own dataset, which we need to load from a directory of images.

In real life, image datasets can be pretty large, and one cannot rely on all the data fitting into memory. Thus, datasets are often represented as generators that return data in minibatches suitable for training.

To deal with image classification, Keras includes the special function image_dataset_from_directory, which can load images from subdirectories corresponding to different classes. This function also takes care of resizing images, and it can split the dataset into train and test subsets:
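Judging from the output below (an 80/20 split, 224x224 images, batches of 64), the hidden cell presumably looks something like this. The function name load_datasets and the seed value are illustrative; in older TensorFlow versions the loader lives under tf.keras.preprocessing instead of tf.keras.utils:

```python
import tensorflow as tf

IMAGE_SIZE = (224, 224)
BATCH_SIZE = 64

def load_datasets(data_dir, seed=13):
    """Split an image folder into train/validation subsets.
    Both calls must use the same seed so the split is consistent."""
    kwargs = dict(validation_split=0.2, seed=seed,
                  image_size=IMAGE_SIZE, batch_size=BATCH_SIZE)
    ds_train = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="training", **kwargs)
    ds_test = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="validation", **kwargs)
    return ds_train, ds_test

# ds_train, ds_test = load_datasets("data/PetImages")
```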

[4]
Found 24769 files belonging to 2 classes.
Using 19816 files for training.
Found 24769 files belonging to 2 classes.
Using 4953 files for validation.

It is important to set the same seed value in both calls, because the seed determines how images are split between the train and test datasets.

The dataset automatically picks up class names from the directory names, and you can access them if needed by calling:

[6]
['Cat', 'Dog']

The datasets we have obtained can be passed directly to the fit function to train the model. They contain corresponding images and labels, which can be iterated over using the following construction:

[7]
Training batch shape: features=(64, 224, 224, 3), labels=(64,)
Output

Note: All images in the dataset are represented as floating-point tensors with values in the range 0-255. Before passing them to the neural network, we need to scale those values into the 0-1 range. When plotting images, we either need to do the same, or convert values to the int type (which we do in the code above), in order to tell matplotlib that we want to plot the original unscaled image.
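The two conversions described in the note can be sketched as follows (the batch here is synthetic, standing in for one batch from the dataset):

```python
import numpy as np

# A fake batch in the same layout the dataset yields:
# float values in the 0-255 range
batch = np.random.uniform(0, 255, size=(4, 64, 64, 3)).astype(np.float32)

scaled = batch / 255.0               # 0-1 range, for the neural network
plottable = batch.astype(np.uint8)   # int type, so matplotlib shows it unscaled
```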

Pre-trained models

For many image classification tasks, pre-trained neural network models are available. Many of them can be found in the keras.applications namespace, and even more models can be found on the Internet. Let's see how the simplest VGG-16 model can be loaded and used:

[8]
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5
553467904/553467096 [==============================] - 6s 0us/step
Most probable class = [208]
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
40960/35363 [==================================] - 0s 0us/step
[[('n02099712', 'Labrador_retriever', 0.5340957),
  ('n02100236', 'German_short-haired_pointer', 0.0939442),
  ('n02092339', 'Weimaraner', 0.08160535),
  ('n02099849', 'Chesapeake_Bay_retriever', 0.057179328),
  ('n02109047', 'Great_Dane', 0.03733857)]]

There are a couple of important things here:

  • Before passing an input to any pre-trained network, it has to be pre-processed in a specific way. This is done by calling the corresponding preprocess_input function, which receives a batch of images and returns their processed form. In the case of VGG-16, images are normalized and a pre-defined average value for each channel is subtracted, because VGG-16 was originally trained with this pre-processing.
  • The neural network is applied to the input batch, and we receive as the result a batch of 1000-element tensors that give the probability of each class. We can find the most probable class number by calling argmax on this tensor.
  • The obtained result is an ImageNet class number. To make sense of it, we can use the decode_predictions function, which returns the top n classes together with their names.
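Putting the three steps together, the inference pipeline looks roughly like this (a sketch; the helper name classify_batch is not from the notebook, and instantiating VGG16() with its default arguments downloads the ImageNet weights):

```python
import numpy as np
from tensorflow.keras.applications.vgg16 import (
    VGG16, preprocess_input, decode_predictions)

def classify_batch(images):
    """images: batch of shape (N, 224, 224, 3) with pixel values 0-255."""
    vgg = VGG16()                                    # downloads ImageNet weights
    x = preprocess_input(images.astype(np.float32))  # per-channel mean subtraction
    probs = vgg.predict(x)                  # (N, 1000) class probabilities
    top = np.argmax(probs, axis=1)          # most probable ImageNet class index
    return top, decode_predictions(probs, top=5)
```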

Let's also see the architecture of the VGG-16 network:

[9]
Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000   
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

GPU computations

Deep neural networks such as VGG-16, as well as more modern architectures, require quite a lot of computational power to run, so it makes sense to use GPU acceleration if it is available. Luckily, Keras automatically runs computations on the GPU when one is available. We can check whether TensorFlow is able to use the GPU with the following code:
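The check itself is a one-liner; an empty list in the output means TensorFlow will fall back to the CPU:

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
print(gpus)  # e.g. [PhysicalDevice(name='/physical_device:GPU:0', ...)]
```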

[10]
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

Extracting VGG features

If we want to use VGG-16 to extract features from our images, we need the model without its final classification layers. We can instantiate the VGG-16 model without the top layers using this code:
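The instantiation presumably uses the include_top=False flag of keras.applications.VGG16; the wrapper function below is illustrative:

```python
from tensorflow.keras.applications import VGG16

def make_feature_extractor(weights="imagenet"):
    """VGG-16 without its dense classification head; for 224x224 inputs
    the output feature map has shape (7, 7, 512)."""
    return VGG16(include_top=False, weights=weights,
                 input_shape=(224, 224, 3))
```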

[15]
Shape after applying VGG-16: (7, 7, 512)
<matplotlib.image.AxesImage at 0x7fafcc685ac0>
Output

The feature tensor has dimensions 7x7x512, but in order to visualize it we had to reshape it to 2D form.

Now let's see whether those features can be used to classify images. Let's manually take a portion of the images (50 minibatches, in our case) and pre-compute their feature vectors. We can use the TensorFlow dataset API to do that: the map function takes a dataset and applies a given lambda function to transform it. We use this mechanism to construct new datasets, ds_features_train and ds_features_test, that contain VGG-extracted features instead of the original images.
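The take/map combination can be sketched as below; the function name precompute_features is illustrative, and the feature extractor is assumed to be the headless VGG-16 from the previous step:

```python
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import preprocess_input

NUM_BATCHES = 50  # limit the experiment to 50 minibatches, as in the text

def precompute_features(ds, feature_extractor):
    """Map each (images, labels) batch to (features, labels),
    applying VGG pre-processing before the extractor."""
    return ds.take(NUM_BATCHES).map(
        lambda x, y: (feature_extractor(preprocess_input(x)), y))

# ds_features_train = precompute_features(ds_train, vgg)
# ds_features_test = precompute_features(ds_test, vgg)
```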

[41]
(64, 7, 7, 512) (64,)

We used the construction .take(50) to limit the dataset size and speed up the demonstration. You can, of course, perform this experiment on the full dataset.

Now that we have a dataset with extracted features, we can train a simple dense classifier to distinguish between cats and dogs. This network takes a feature tensor of shape (7,7,512) and produces one output that corresponds either to a dog or to a cat. Because this is binary classification, we use the sigmoid activation function and binary_crossentropy loss.
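A minimal classifier with exactly this interface would be (the variable name top_classifier is illustrative; note the parameter count, 7*7*512 + 1 = 25,089, matches the model summaries later in the lesson):

```python
import tensorflow as tf

top_classifier = tf.keras.models.Sequential([
    tf.keras.Input(shape=(7, 7, 512)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
top_classifier.compile(loss='binary_crossentropy',
                       optimizer='adam', metrics=['acc'])
```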

[44]
50/50 [==============================] - 1896s 38s/step - loss: 1.4845 - acc: 0.9144 - val_loss: 0.7220 - val_acc: 0.9516

The result is great: we can distinguish between a cat and a dog with almost 95% accuracy! However, we have only tested this approach on a subset of the images, because manual feature extraction takes a lot of time.

Transfer learning using one VGG network

We can also avoid manually pre-computing the features by using the original VGG-16 network as a whole during training, adding the feature extractor to our network as its first layer.

The beauty of the Keras architecture is that the VGG-16 model we defined above can also be used as a layer in another neural network! We just need to construct a network with a dense classifier on top of it, and then train the whole network using back propagation.

[6]
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg16 (Functional)           (None, 7, 7, 512)         14714688  
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
dense (Dense)                (None, 1)                 25089     
=================================================================
Total params: 14,739,777
Trainable params: 25,089
Non-trainable params: 14,714,688
_________________________________________________________________

This model looks like an end-to-end classification network, which takes an image and returns the class. However, the tricky part is that we want VGG-16 to act as a feature extractor and not be re-trained. Thus, we need to freeze the weights of the convolutional feature extractor. We can access the first layer of the network by calling model.layers[0], and we just need to set its trainable property to False.

Note: Freezing the feature extractor weights is needed because otherwise the untrained classifier layer can destroy the original pre-trained weights of the convolutional extractor.
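The whole construction can be sketched as below. This is an offline sketch: weights=None is used here only to avoid the ImageNet download, whereas the notebook would use weights="imagenet"; the parameter counts match the summary above either way:

```python
import tensorflow as tf
from tensorflow.keras.applications import VGG16

# weights=None keeps this sketch offline; the notebook uses weights="imagenet"
vgg = VGG16(include_top=False, weights=None, input_shape=(224, 224, 3))
model = tf.keras.models.Sequential([
    vgg,
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.layers[0].trainable = False   # freeze the convolutional feature extractor
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
```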

You may notice that while the total number of parameters in our network is around 15 million, we are only training 25k of them; all parameters of the convolutional filters are pre-trained. That is good, because we are able to fine-tune a smaller number of parameters with a smaller number of examples.

We will now train our network and see how well we can do. Expect a rather long running time, and do not worry if the execution seems frozen for a while.

[7]
310/310 [==============================] - 265s 716ms/step - loss: 0.9917 - acc: 0.9512 - val_loss: 0.8156 - val_acc: 0.9671

It looks like we have obtained reasonably accurate cats vs. dogs classifier!

Saving and Loading the Model

Once we have trained the model, we can save its architecture and trained weights to a file for future use:
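The save/load round trip uses model.save and tf.keras.models.load_model. The sketch below uses a tiny stand-in model and an HDF5 file so it runs anywhere; the notebook itself saves the trained model in the TensorFlow SavedModel format (data/cats_dogs.tf), as the assets message below shows:

```python
import os
import tempfile
import tensorflow as tf

# A tiny stand-in model; in the notebook this would be the trained
# cats-vs-dogs transfer model
demo = tf.keras.models.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
path = os.path.join(tempfile.mkdtemp(), 'cats_dogs_demo.h5')
demo.save(path)                              # architecture + trained weights
restored = tf.keras.models.load_model(path)  # ready to use, no re-training
```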

[8]
INFO:tensorflow:Assets written to: data/cats_dogs.tf/assets

We can then load the model from the file at any time. You may find this useful in case the next experiment destroys the model: you would not have to re-train from scratch.

[9]

Fine-tuning transfer learning

In the previous section, we trained the final classifier layer to classify images in our own dataset. However, we did not re-train the feature extractor, so our model relied on the features the model had learned on ImageNet data. If your objects visually differ from typical ImageNet images, this combination of features might not work well. In that case, it makes sense to train the convolutional layers as well.

To do that, we can unfreeze the convolutional filter parameters that we have previously frozen.

Note: It is important that you freeze the parameters first and train for several epochs in order to stabilize the weights in the classification layer. If you immediately start training the end-to-end network with unfrozen parameters, large errors are likely to destroy the pre-trained weights in the convolutional layers.

Our convolutional VGG-16 model is located inside the first layer, and it itself consists of many layers. We can have a look at its structure:

[10]
Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
=================================================================
Total params: 14,714,688
Trainable params: 0
Non-trainable params: 14,714,688
_________________________________________________________________

We can unfreeze all layers of the convolutional base:

[ ]

However, unfreezing all of them at once is not the best idea. We can first unfreeze just the few final convolutional layers, because they contain the higher-level patterns that are most relevant for our images. For example, to begin with, we can freeze all layers except the last 4:
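The selective unfreezing presumably loops over the layers of the convolutional base; a self-contained sketch (again with weights=None only to stay offline, where the notebook uses the ImageNet weights):

```python
from tensorflow.keras.applications import VGG16

# weights=None keeps this sketch offline; the notebook uses weights="imagenet"
vgg = VGG16(include_top=False, weights=None, input_shape=(224, 224, 3))

vgg.trainable = True
for layer in vgg.layers[:-4]:   # freeze everything before the last 4 layers
    layer.trainable = False
# Only block5_conv1..3 (and the parameter-free block5_pool) remain trainable
```

The three block5 convolutions hold 3 * 2,359,808 = 7,079,424 parameters; together with the 25,089 parameters of the dense head this gives the 7,104,513 trainable parameters in the summary below.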

[11]
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg16 (Functional)           (None, 7, 7, 512)         14714688  
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
dense (Dense)                (None, 1)                 25089     
=================================================================
Total params: 14,739,777
Trainable params: 7,104,513
Non-trainable params: 7,635,264
_________________________________________________________________

Observe that the number of trainable parameters has increased significantly, but it is still only about half of all parameters.

After unfreezing, we can run a few more epochs of training (in our example, we will do just one). You can also select a lower learning rate to minimize the impact on the pre-trained weights. However, even with a low learning rate, you can expect accuracy to drop at the beginning of training, before it eventually reaches a slightly higher level than with fixed weights.

Note: This training runs much more slowly, because we need to propagate gradients back through many layers of the network!
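The recompile-and-train step can be sketched as follows. The learning rate `1e-5` and the dataset names `train_ds`/`val_ds` are illustrative assumptions, not values taken from the notebook, and `weights=None` again stands in for `weights='imagenet'`:

```python
import tensorflow as tf

# Rebuild the transfer model: partially-frozen VGG16 base + small classifier
vgg = tf.keras.applications.VGG16(
    include_top=False, weights=None, input_shape=(224, 224, 3))
for layer in vgg.layers[:-4]:
    layer.trainable = False

model = tf.keras.models.Sequential([
    vgg,
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

# Recompile with a reduced learning rate so fine-tuning does not
# destroy the pre-trained weights
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='binary_crossentropy',
    metrics=['acc'])

# model.fit(train_ds, validation_data=val_ds, epochs=1)  # hypothetical datasets
```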

[12]
310/310 [==============================] - 201s 645ms/step - loss: 0.5270 - acc: 0.9776 - val_loss: 1.4132 - val_acc: 0.9653

We are likely to achieve higher training accuracy, because we are now using a more powerful network with more parameters, but validation accuracy will not increase by as much.

Feel free to unfreeze a few more layers of the network and train more, to see if you are able to achieve higher accuracy!

Other computer vision models

VGG-16 is one of the simplest computer vision architectures, and Keras provides many more pre-trained networks. Among the most frequently used are the ResNet architectures, developed by Microsoft, and Inception, developed by Google. For example, let's explore the architecture of the simplest ResNet model, ResNet-50 (ResNet is a family of models of different depths; you can experiment with ResNet-152 if you want to see what a really deep model looks like):
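Loading ResNet-50 follows the same pattern as VGG16. As before, `weights=None` is used here only to skip the ImageNet download; in practice you would pass `weights='imagenet'`:

```python
import tensorflow as tf

# Load the ResNet-50 convolutional base (weights=None skips the download
# in this sketch; use weights='imagenet' for real transfer learning)
resnet = tf.keras.applications.ResNet50(
    include_top=False, weights=None, input_shape=(224, 224, 3))

# Freeze the base for feature extraction, then inspect it
resnet.trainable = False
resnet.summary()
```

Like VGG16, the convolutional base ends in a 7x7x2048 feature map, which you can flatten or pool before attaching your own classifier head.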

[16]
Model: "resnet50"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_3[0][0]                    
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
conv1_bn (BatchNormalization)   (None, 112, 112, 64) 256         conv1_conv[0][0]                 
__________________________________________________________________________________________________
conv1_relu (Activation)         (None, 112, 112, 64) 0           conv1_bn[0][0]                   
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, 114, 114, 64) 0           conv1_relu[0][0]                 
__________________________________________________________________________________________________
pool1_pool (MaxPooling2D)       (None, 56, 56, 64)   0           pool1_pad[0][0]                  
__________________________________________________________________________________________________
conv2_block1_1_conv (Conv2D)    (None, 56, 56, 64)   4160        pool1_pool[0][0]                 
__________________________________________________________________________________________________
conv2_block1_1_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block1_1_conv[0][0]        
__________________________________________________________________________________________________
conv2_block1_1_relu (Activation (None, 56, 56, 64)   0           conv2_block1_1_bn[0][0]          
__________________________________________________________________________________________________
conv2_block1_2_conv (Conv2D)    (None, 56, 56, 64)   36928       conv2_block1_1_relu[0][0]        
__________________________________________________________________________________________________
conv2_block1_2_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block1_2_conv[0][0]        
__________________________________________________________________________________________________
conv2_block1_2_relu (Activation (None, 56, 56, 64)   0           conv2_block1_2_bn[0][0]          
__________________________________________________________________________________________________
conv2_block1_0_conv (Conv2D)    (None, 56, 56, 256)  16640       pool1_pool[0][0]                 
__________________________________________________________________________________________________
conv2_block1_3_conv (Conv2D)    (None, 56, 56, 256)  16640       conv2_block1_2_relu[0][0]        
__________________________________________________________________________________________________
conv2_block1_0_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block1_0_conv[0][0]        
__________________________________________________________________________________________________
conv2_block1_3_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block1_3_conv[0][0]        
__________________________________________________________________________________________________
conv2_block1_add (Add)          (None, 56, 56, 256)  0           conv2_block1_0_bn[0][0]          
                                                                 conv2_block1_3_bn[0][0]          
__________________________________________________________________________________________________
conv2_block1_out (Activation)   (None, 56, 56, 256)  0           conv2_block1_add[0][0]           
__________________________________________________________________________________________________
conv2_block2_1_conv (Conv2D)    (None, 56, 56, 64)   16448       conv2_block1_out[0][0]           
__________________________________________________________________________________________________
conv2_block2_1_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block2_1_conv[0][0]        
__________________________________________________________________________________________________
conv2_block2_1_relu (Activation (None, 56, 56, 64)   0           conv2_block2_1_bn[0][0]          
__________________________________________________________________________________________________
conv2_block2_2_conv (Conv2D)    (None, 56, 56, 64)   36928       conv2_block2_1_relu[0][0]        
__________________________________________________________________________________________________
conv2_block2_2_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block2_2_conv[0][0]        
__________________________________________________________________________________________________
conv2_block2_2_relu (Activation (None, 56, 56, 64)   0           conv2_block2_2_bn[0][0]          
__________________________________________________________________________________________________
conv2_block2_3_conv (Conv2D)    (None, 56, 56, 256)  16640       conv2_block2_2_relu[0][0]        
__________________________________________________________________________________________________
conv2_block2_3_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block2_3_conv[0][0]        
__________________________________________________________________________________________________
conv2_block2_add (Add)          (None, 56, 56, 256)  0           conv2_block1_out[0][0]           
                                                                 conv2_block2_3_bn[0][0]          
__________________________________________________________________________________________________
conv2_block2_out (Activation)   (None, 56, 56, 256)  0           conv2_block2_add[0][0]           
__________________________________________________________________________________________________
conv2_block3_1_conv (Conv2D)    (None, 56, 56, 64)   16448       conv2_block2_out[0][0]           
__________________________________________________________________________________________________
conv2_block3_1_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block3_1_conv[0][0]        
__________________________________________________________________________________________________
conv2_block3_1_relu (Activation (None, 56, 56, 64)   0           conv2_block3_1_bn[0][0]          
__________________________________________________________________________________________________
conv2_block3_2_conv (Conv2D)    (None, 56, 56, 64)   36928       conv2_block3_1_relu[0][0]        
__________________________________________________________________________________________________
conv2_block3_2_bn (BatchNormali (None, 56, 56, 64)   256         conv2_block3_2_conv[0][0]        
__________________________________________________________________________________________________
conv2_block3_2_relu (Activation (None, 56, 56, 64)   0           conv2_block3_2_bn[0][0]          
__________________________________________________________________________________________________
conv2_block3_3_conv (Conv2D)    (None, 56, 56, 256)  16640       conv2_block3_2_relu[0][0]        
__________________________________________________________________________________________________
conv2_block3_3_bn (BatchNormali (None, 56, 56, 256)  1024        conv2_block3_3_conv[0][0]        
__________________________________________________________________________________________________
conv2_block3_add (Add)          (None, 56, 56, 256)  0           conv2_block2_out[0][0]           
                                                                 conv2_block3_3_bn[0][0]          
__________________________________________________________________________________________________
conv2_block3_out (Activation)   (None, 56, 56, 256)  0           conv2_block3_add[0][0]           
__________________________________________________________________________________________________
conv3_block1_1_conv (Conv2D)    (None, 28, 28, 128)  32896       conv2_block3_out[0][0]           
__________________________________________________________________________________________________
conv3_block1_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block1_1_conv[0][0]        
__________________________________________________________________________________________________
conv3_block1_1_relu (Activation (None, 28, 28, 128)  0           conv3_block1_1_bn[0][0]          
__________________________________________________________________________________________________
conv3_block1_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block1_1_relu[0][0]        
__________________________________________________________________________________________________
conv3_block1_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block1_2_conv[0][0]        
__________________________________________________________________________________________________
conv3_block1_2_relu (Activation (None, 28, 28, 128)  0           conv3_block1_2_bn[0][0]          
__________________________________________________________________________________________________
conv3_block1_0_conv (Conv2D)    (None, 28, 28, 512)  131584      conv2_block3_out[0][0]           
__________________________________________________________________________________________________
conv3_block1_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block1_2_relu[0][0]        
__________________________________________________________________________________________________
conv3_block1_0_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block1_0_conv[0][0]        
__________________________________________________________________________________________________
conv3_block1_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block1_3_conv[0][0]        
__________________________________________________________________________________________________
conv3_block1_add (Add)          (None, 28, 28, 512)  0           conv3_block1_0_bn[0][0]          
                                                                 conv3_block1_3_bn[0][0]          
__________________________________________________________________________________________________
conv3_block1_out (Activation)   (None, 28, 28, 512)  0           conv3_block1_add[0][0]           
__________________________________________________________________________________________________
conv3_block2_1_conv (Conv2D)    (None, 28, 28, 128)  65664       conv3_block1_out[0][0]           
__________________________________________________________________________________________________
conv3_block2_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block2_1_conv[0][0]        
__________________________________________________________________________________________________
conv3_block2_1_relu (Activation (None, 28, 28, 128)  0           conv3_block2_1_bn[0][0]          
__________________________________________________________________________________________________
conv3_block2_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block2_1_relu[0][0]        
__________________________________________________________________________________________________
conv3_block2_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block2_2_conv[0][0]        
__________________________________________________________________________________________________
conv3_block2_2_relu (Activation (None, 28, 28, 128)  0           conv3_block2_2_bn[0][0]          
__________________________________________________________________________________________________
conv3_block2_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block2_2_relu[0][0]        
__________________________________________________________________________________________________
conv3_block2_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block2_3_conv[0][0]        
__________________________________________________________________________________________________
conv3_block2_add (Add)          (None, 28, 28, 512)  0           conv3_block1_out[0][0]           
                                                                 conv3_block2_3_bn[0][0]          
__________________________________________________________________________________________________
conv3_block2_out (Activation)   (None, 28, 28, 512)  0           conv3_block2_add[0][0]           
__________________________________________________________________________________________________
conv3_block3_1_conv (Conv2D)    (None, 28, 28, 128)  65664       conv3_block2_out[0][0]           
__________________________________________________________________________________________________
conv3_block3_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block3_1_conv[0][0]        
__________________________________________________________________________________________________
conv3_block3_1_relu (Activation (None, 28, 28, 128)  0           conv3_block3_1_bn[0][0]          
__________________________________________________________________________________________________
conv3_block3_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block3_1_relu[0][0]        
__________________________________________________________________________________________________
conv3_block3_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block3_2_conv[0][0]        
__________________________________________________________________________________________________
conv3_block3_2_relu (Activation (None, 28, 28, 128)  0           conv3_block3_2_bn[0][0]          
__________________________________________________________________________________________________
conv3_block3_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block3_2_relu[0][0]        
__________________________________________________________________________________________________
conv3_block3_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block3_3_conv[0][0]        
__________________________________________________________________________________________________
conv3_block3_add (Add)          (None, 28, 28, 512)  0           conv3_block2_out[0][0]           
                                                                 conv3_block3_3_bn[0][0]          
__________________________________________________________________________________________________
conv3_block3_out (Activation)   (None, 28, 28, 512)  0           conv3_block3_add[0][0]           
__________________________________________________________________________________________________
conv3_block4_1_conv (Conv2D)    (None, 28, 28, 128)  65664       conv3_block3_out[0][0]           
__________________________________________________________________________________________________
conv3_block4_1_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block4_1_conv[0][0]        
__________________________________________________________________________________________________
conv3_block4_1_relu (Activation (None, 28, 28, 128)  0           conv3_block4_1_bn[0][0]          
__________________________________________________________________________________________________
conv3_block4_2_conv (Conv2D)    (None, 28, 28, 128)  147584      conv3_block4_1_relu[0][0]        
__________________________________________________________________________________________________
conv3_block4_2_bn (BatchNormali (None, 28, 28, 128)  512         conv3_block4_2_conv[0][0]        
__________________________________________________________________________________________________
conv3_block4_2_relu (Activation (None, 28, 28, 128)  0           conv3_block4_2_bn[0][0]          
__________________________________________________________________________________________________
conv3_block4_3_conv (Conv2D)    (None, 28, 28, 512)  66048       conv3_block4_2_relu[0][0]        
__________________________________________________________________________________________________
conv3_block4_3_bn (BatchNormali (None, 28, 28, 512)  2048        conv3_block4_3_conv[0][0]        
__________________________________________________________________________________________________
conv3_block4_add (Add)          (None, 28, 28, 512)  0           conv3_block3_out[0][0]           
                                                                 conv3_block4_3_bn[0][0]          
__________________________________________________________________________________________________
conv3_block4_out (Activation)   (None, 28, 28, 512)  0           conv3_block4_add[0][0]           
__________________________________________________________________________________________________
conv4_block1_1_conv (Conv2D)    (None, 14, 14, 256)  131328      conv3_block4_out[0][0]           
__________________________________________________________________________________________________
conv4_block1_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block1_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block1_1_relu (Activation (None, 14, 14, 256)  0           conv4_block1_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block1_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block1_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block1_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block1_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block1_2_relu (Activation (None, 14, 14, 256)  0           conv4_block1_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block1_0_conv (Conv2D)    (None, 14, 14, 1024) 525312      conv3_block4_out[0][0]           
__________________________________________________________________________________________________
conv4_block1_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block1_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block1_0_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block1_0_conv[0][0]        
__________________________________________________________________________________________________
conv4_block1_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block1_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block1_add (Add)          (None, 14, 14, 1024) 0           conv4_block1_0_bn[0][0]          
                                                                 conv4_block1_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block1_out (Activation)   (None, 14, 14, 1024) 0           conv4_block1_add[0][0]           
__________________________________________________________________________________________________
conv4_block2_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block1_out[0][0]           
__________________________________________________________________________________________________
conv4_block2_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block2_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block2_1_relu (Activation (None, 14, 14, 256)  0           conv4_block2_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block2_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block2_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block2_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block2_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block2_2_relu (Activation (None, 14, 14, 256)  0           conv4_block2_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block2_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block2_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block2_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block2_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block2_add (Add)          (None, 14, 14, 1024) 0           conv4_block1_out[0][0]           
                                                                 conv4_block2_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block2_out (Activation)   (None, 14, 14, 1024) 0           conv4_block2_add[0][0]           
__________________________________________________________________________________________________
conv4_block3_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block2_out[0][0]           
__________________________________________________________________________________________________
conv4_block3_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block3_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block3_1_relu (Activation (None, 14, 14, 256)  0           conv4_block3_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block3_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block3_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block3_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block3_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block3_2_relu (Activation (None, 14, 14, 256)  0           conv4_block3_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block3_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block3_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block3_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block3_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block3_add (Add)          (None, 14, 14, 1024) 0           conv4_block2_out[0][0]           
                                                                 conv4_block3_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block3_out (Activation)   (None, 14, 14, 1024) 0           conv4_block3_add[0][0]           
__________________________________________________________________________________________________
conv4_block4_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block3_out[0][0]           
__________________________________________________________________________________________________
conv4_block4_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block4_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block4_1_relu (Activation (None, 14, 14, 256)  0           conv4_block4_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block4_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block4_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block4_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block4_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block4_2_relu (Activation (None, 14, 14, 256)  0           conv4_block4_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block4_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block4_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block4_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block4_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block4_add (Add)          (None, 14, 14, 1024) 0           conv4_block3_out[0][0]           
                                                                 conv4_block4_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block4_out (Activation)   (None, 14, 14, 1024) 0           conv4_block4_add[0][0]           
__________________________________________________________________________________________________
conv4_block5_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block4_out[0][0]           
__________________________________________________________________________________________________
conv4_block5_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block5_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block5_1_relu (Activation (None, 14, 14, 256)  0           conv4_block5_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block5_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block5_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block5_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block5_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block5_2_relu (Activation (None, 14, 14, 256)  0           conv4_block5_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block5_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block5_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block5_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block5_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block5_add (Add)          (None, 14, 14, 1024) 0           conv4_block4_out[0][0]           
                                                                 conv4_block5_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block5_out (Activation)   (None, 14, 14, 1024) 0           conv4_block5_add[0][0]           
__________________________________________________________________________________________________
conv4_block6_1_conv (Conv2D)    (None, 14, 14, 256)  262400      conv4_block5_out[0][0]           
__________________________________________________________________________________________________
conv4_block6_1_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block6_1_conv[0][0]        
__________________________________________________________________________________________________
conv4_block6_1_relu (Activation (None, 14, 14, 256)  0           conv4_block6_1_bn[0][0]          
__________________________________________________________________________________________________
conv4_block6_2_conv (Conv2D)    (None, 14, 14, 256)  590080      conv4_block6_1_relu[0][0]        
__________________________________________________________________________________________________
conv4_block6_2_bn (BatchNormali (None, 14, 14, 256)  1024        conv4_block6_2_conv[0][0]        
__________________________________________________________________________________________________
conv4_block6_2_relu (Activation (None, 14, 14, 256)  0           conv4_block6_2_bn[0][0]          
__________________________________________________________________________________________________
conv4_block6_3_conv (Conv2D)    (None, 14, 14, 1024) 263168      conv4_block6_2_relu[0][0]        
__________________________________________________________________________________________________
conv4_block6_3_bn (BatchNormali (None, 14, 14, 1024) 4096        conv4_block6_3_conv[0][0]        
__________________________________________________________________________________________________
conv4_block6_add (Add)          (None, 14, 14, 1024) 0           conv4_block5_out[0][0]           
                                                                 conv4_block6_3_bn[0][0]          
__________________________________________________________________________________________________
conv4_block6_out (Activation)   (None, 14, 14, 1024) 0           conv4_block6_add[0][0]           
__________________________________________________________________________________________________
conv5_block1_1_conv (Conv2D)    (None, 7, 7, 512)    524800      conv4_block6_out[0][0]           
__________________________________________________________________________________________________
conv5_block1_1_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block1_1_conv[0][0]        
__________________________________________________________________________________________________
conv5_block1_1_relu (Activation (None, 7, 7, 512)    0           conv5_block1_1_bn[0][0]          
__________________________________________________________________________________________________
conv5_block1_2_conv (Conv2D)    (None, 7, 7, 512)    2359808     conv5_block1_1_relu[0][0]        
__________________________________________________________________________________________________
conv5_block1_2_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block1_2_conv[0][0]        
__________________________________________________________________________________________________
conv5_block1_2_relu (Activation (None, 7, 7, 512)    0           conv5_block1_2_bn[0][0]          
__________________________________________________________________________________________________
conv5_block1_0_conv (Conv2D)    (None, 7, 7, 2048)   2099200     conv4_block6_out[0][0]           
__________________________________________________________________________________________________
conv5_block1_3_conv (Conv2D)    (None, 7, 7, 2048)   1050624     conv5_block1_2_relu[0][0]        
__________________________________________________________________________________________________
conv5_block1_0_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block1_0_conv[0][0]        
__________________________________________________________________________________________________
conv5_block1_3_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block1_3_conv[0][0]        
__________________________________________________________________________________________________
conv5_block1_add (Add)          (None, 7, 7, 2048)   0           conv5_block1_0_bn[0][0]          
                                                                 conv5_block1_3_bn[0][0]          
__________________________________________________________________________________________________
conv5_block1_out (Activation)   (None, 7, 7, 2048)   0           conv5_block1_add[0][0]           
__________________________________________________________________________________________________
conv5_block2_1_conv (Conv2D)    (None, 7, 7, 512)    1049088     conv5_block1_out[0][0]           
__________________________________________________________________________________________________
conv5_block2_1_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block2_1_conv[0][0]        
__________________________________________________________________________________________________
conv5_block2_1_relu (Activation (None, 7, 7, 512)    0           conv5_block2_1_bn[0][0]          
__________________________________________________________________________________________________
conv5_block2_2_conv (Conv2D)    (None, 7, 7, 512)    2359808     conv5_block2_1_relu[0][0]        
__________________________________________________________________________________________________
conv5_block2_2_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block2_2_conv[0][0]        
__________________________________________________________________________________________________
conv5_block2_2_relu (Activation (None, 7, 7, 512)    0           conv5_block2_2_bn[0][0]          
__________________________________________________________________________________________________
conv5_block2_3_conv (Conv2D)    (None, 7, 7, 2048)   1050624     conv5_block2_2_relu[0][0]        
__________________________________________________________________________________________________
conv5_block2_3_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block2_3_conv[0][0]        
__________________________________________________________________________________________________
conv5_block2_add (Add)          (None, 7, 7, 2048)   0           conv5_block1_out[0][0]           
                                                                 conv5_block2_3_bn[0][0]          
__________________________________________________________________________________________________
conv5_block2_out (Activation)   (None, 7, 7, 2048)   0           conv5_block2_add[0][0]           
__________________________________________________________________________________________________
conv5_block3_1_conv (Conv2D)    (None, 7, 7, 512)    1049088     conv5_block2_out[0][0]           
__________________________________________________________________________________________________
conv5_block3_1_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block3_1_conv[0][0]        
__________________________________________________________________________________________________
conv5_block3_1_relu (Activation (None, 7, 7, 512)    0           conv5_block3_1_bn[0][0]          
__________________________________________________________________________________________________
conv5_block3_2_conv (Conv2D)    (None, 7, 7, 512)    2359808     conv5_block3_1_relu[0][0]        
__________________________________________________________________________________________________
conv5_block3_2_bn (BatchNormali (None, 7, 7, 512)    2048        conv5_block3_2_conv[0][0]        
__________________________________________________________________________________________________
conv5_block3_2_relu (Activation (None, 7, 7, 512)    0           conv5_block3_2_bn[0][0]          
__________________________________________________________________________________________________
conv5_block3_3_conv (Conv2D)    (None, 7, 7, 2048)   1050624     conv5_block3_2_relu[0][0]        
__________________________________________________________________________________________________
conv5_block3_3_bn (BatchNormali (None, 7, 7, 2048)   8192        conv5_block3_3_conv[0][0]        
__________________________________________________________________________________________________
conv5_block3_add (Add)          (None, 7, 7, 2048)   0           conv5_block2_out[0][0]           
                                                                 conv5_block3_3_bn[0][0]          
__________________________________________________________________________________________________
conv5_block3_out (Activation)   (None, 7, 7, 2048)   0           conv5_block3_add[0][0]           
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 2048)         0           conv5_block3_out[0][0]           
__________________________________________________________________________________________________
predictions (Dense)             (None, 1000)         2049000     avg_pool[0][0]                   
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________

As you can see, the model contains the same familiar building blocks: convolutional layers, pooling layers, and a final dense classifier. We can use this model in exactly the same manner as we have been using VGG-16 for transfer learning. You can try experimenting with the code above, using different ResNet models as the base model, and see how the accuracy changes.
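As a minimal sketch, swapping ResNet50 in as the frozen feature extractor and adding a small binary head looks like this. Note the hedge in the code: we pass `weights=None` only so the sketch builds offline; for actual transfer learning you would pass `weights='imagenet'` to get the pre-trained filters:

```python
import tensorflow as tf
from tensorflow import keras

# Load ResNet50 as a feature extractor, dropping its 1000-class
# ImageNet head (include_top=False). In a real run pass
# weights='imagenet' to download the pre-trained weights; weights=None
# is used here only so the sketch builds without a network connection.
base = keras.applications.ResNet50(
    weights=None, include_top=False, input_shape=(224, 224, 3))
base.trainable = False  # freeze the convolutional base

# New classifier head for the binary cats-vs-dogs task
model = keras.models.Sequential([
    base,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
```

Because the base is frozen, only the final dense layer (2048 weights plus a bias) is trained, which is why transfer learning converges so much faster than training the full 25M-parameter network.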

Batch Normalization

This network contains yet another type of layer: Batch Normalization. The idea of batch normalization is to bring the values that flow through the neural network into the right range. Neural networks usually work best when all values are in the range [-1,1] or [0,1], which is why we scale/normalize our input data accordingly. However, during training of a deep network, values can drift significantly out of this range, which makes training problematic. A batch normalization layer computes the mean and standard deviation of all values in the current minibatch, and uses them to normalize the signal before passing it on to the next layer. This significantly improves the stability of deep networks.
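The core computation can be sketched in a few lines of NumPy (a real BatchNormalization layer additionally applies a learned scale γ and shift β, omitted here for clarity):

```python
import numpy as np

# A mini-batch of 4 activations for one feature channel,
# deliberately spread far outside the [-1, 1] range
x = np.array([0.5, 120.0, -30.0, 64.0])

# Batch normalization: subtract the batch mean and divide by the
# batch standard deviation (epsilon guards against division by zero)
eps = 1e-5
mean, var = x.mean(), x.var()
x_hat = (x - mean) / np.sqrt(var + eps)

# The normalized batch now has (approximately) zero mean and unit variance
print(x_hat.round(3))
```

Because the statistics are recomputed for every minibatch, the layer keeps activations well-scaled throughout training no matter how the preceding weights change.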

Takeaway

Using transfer learning, we were able to quickly put together a classifier for our custom object classification task and achieve high accuracy. However, this example was not completely fair, because the original VGG-16 network was pre-trained on a dataset that included cats and dogs, so we were largely reusing patterns already present in the network. You can expect lower accuracy on more exotic domain-specific objects, such as parts on a production line in a plant, or different tree leaves.

You can see that the more complex tasks we are solving now require more computational power and cannot easily be solved on a CPU. In the next unit, we will try a more lightweight implementation that trains the same model using fewer compute resources, at the cost of only slightly lower accuracy.