Performing classification using prototypical networks
Now, we will see how to use prototypical networks to perform a classification task. We use the Omniglot dataset for classification. This dataset comprises 1,623 handwritten characters from 50 different alphabets, and each character has 20 different examples written by different people. Since we want our network to learn from only a few data points per class, we train it in the same way: we sample five examples from each class and use them as our support set. We learn the embeddings of the support set using a sequence of four convolution blocks as our encoder and build the class prototype. Similarly, we sample five examples from each class for our query set, learn the query set embeddings, and predict the query set class by comparing the Euclidean distance between the query set embeddings and the class prototypes. Let's better understand this by going through it step by step.
First, we import the required libraries:

import os
import glob
from PIL import Image
import numpy as np
import tensorflow as tf
Now, let's explore the data and see what we have. As we know, we have characters from different alphabets, and each character has twenty different variants written by different people. Let's plot and check some of them.
Let's plot one character from the Japanese alphabet:
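A minimal plotting sketch using matplotlib is shown below; the file path is only a hypothetical example, so point it at any Japanese character image in your local copy of the dataset:

import matplotlib.pyplot as plt

#hypothetical path -- replace with an actual character image from your download
sample_image = Image.open('data/data/Japanese_(katakana)/character01/0601_01.png')
plt.imshow(np.array(sample_image), cmap='gray')
plt.show()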
Now that we understand what is in our dataset, let's load it:
root_dir = 'data/'
The splitting details are stored in the /data/omniglot/splits/train.txt file, which has the language name, the character number, and the rotation information; the images are in /data/omniglot/data/:
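A hedged sketch of reading the split file and allocating the training array follows; the train_split_path, no_of_classes, img_width, img_height, and channels names, as well as the 28 x 28 image size, are assumptions introduced here:

#read the class list from the split file (assumed to sit under root_dir)
train_split_path = os.path.join(root_dir, 'splits', 'train.txt')
with open(train_split_path, 'r') as train_split:
    train_classes = [line.rstrip() for line in train_split.readlines()]

#allocate the training array: 20 examples per class, resized to 28 x 28 (assumed size)
no_of_classes = len(train_classes)
img_width = 28
img_height = 28
channels = 1
train_dataset = np.zeros([no_of_classes, 20, img_height, img_width], dtype=np.float32)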
Now, we read all of the images, convert them into a NumPy array, and store them in our train_dataset array along with their labels, that is, train_dataset = [label, values]:
for label, name in enumerate(train_classes):
    alphabet, character, rotation = name.split('/')
    rotation = float(rotation[3:])
    img_dir = os.path.join(root_dir, 'data', alphabet, character)
    img_files = sorted(glob.glob(os.path.join(img_dir, '*.png')))

    for index, img_file in enumerate(img_files):
        values = 1. - np.array(Image.open(img_file).rotate(rotation).resize((img_width, img_height)), np.float32, copy=False)
        train_dataset[label, index] = values
The shape of the training data would be as follows:
train_dataset.shape
(4112, 20, 28, 28)
Now that we have loaded our training data, we need to create embeddings for them. We generate the embeddings using a convolution operation, as our input values are images. So, we define a convolutional block with 64 filters, batch normalization, and ReLU as the activation function, followed by a max pooling operation:
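A minimal sketch of such a block in TensorFlow 1.x is given below; the 3 x 3 kernel, the batch normalization parameters, and the 2 x 2 pooling window are assumptions rather than values taken from the text:

def convolution_block(inputs, out_channels, name='conv'):
    #3 x 3 convolution (assumed kernel size), batch norm, ReLU, then 2 x 2 max pooling
    conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
    conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
    conv = tf.nn.relu(conv)
    conv = tf.contrib.layers.max_pool2d(conv, 2)
    return conv

We then stack four of these blocks and flatten the output to build our embedding function: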
def get_embeddings(support_set, h_dim, z_dim, reuse=False):
    #the same encoder weights are reused for both the support set and the query set
    with tf.variable_scope('encoder', reuse=reuse):
        net = convolution_block(support_set, h_dim)
        net = convolution_block(net, h_dim)
        net = convolution_block(net, h_dim)
        net = convolution_block(net, z_dim)
        net = tf.contrib.layers.flatten(net)
        return net
Remember, we don't feed our whole dataset to the network at once; since we are doing few-shot learning, in each episode we sample a few data points from each class as a support set and train the network in an episodic fashion, as sketched after the variable definitions below.
Now, we define some of the important variables; we consider a 50-way five-shot learning scenario:
#number of classes num_way = 50
#number of examples per class in a support set num_shot = 5
#number of query points for query set num_query = 5
#number of examples num_examples = 20
#number of filters in the hidden convolution blocks and in the final embedding block
h_dim = 64
z_dim = 64
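As mentioned above, training proceeds episode by episode. The following is a hedged sketch of how a single episode could be sampled from train_dataset; the episodic_classes, support, and query names are introduced here purely for illustration:

#randomly pick num_way classes for this episode
episodic_classes = np.random.permutation(no_of_classes)[:num_way]

support = np.zeros([num_way, num_shot, img_height, img_width], dtype=np.float32)
query = np.zeros([num_way, num_query, img_height, img_width], dtype=np.float32)

for index, class_ in enumerate(episodic_classes):
    #pick num_shot support points and num_query query points per class
    selected = np.random.permutation(num_examples)[:num_shot + num_query]
    support[index] = train_dataset[class_, selected[:num_shot]]
    query[index] = train_dataset[class_, selected[num_shot:]]

#add the channel axis expected by the placeholders defined next
support = np.expand_dims(support, axis=-1)
query = np.expand_dims(query, axis=-1)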
Next, we initialize placeholders for our support and query sets:
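A minimal sketch using TensorFlow 1.x placeholders, assuming the img_height, img_width, and channels values defined while loading the data:

support_set = tf.placeholder(tf.float32, [None, None, img_height, img_width, channels])
query_set = tf.placeholder(tf.float32, [None, None, img_height, img_width, channels])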
We get the number of classes, the number of data points in the support set, and the number of data points in the query set from the shapes of these placeholders:
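A hedged sketch is shown below. It also builds the class prototypes and the query set embeddings that the next step relies on, reusing the get_embeddings encoder defined earlier; the reshape logic and tensor names are assumptions consistent with that sketch:

#shapes of the support and query placeholders
support_set_shape = tf.shape(support_set)
query_set_shape = tf.shape(query_set)

num_classes, num_support_points = support_set_shape[0], support_set_shape[1]
num_query_points = query_set_shape[1]

#embed the support set and average the embeddings per class to get the class prototypes
support_set_embeddings = get_embeddings(tf.reshape(support_set, [num_classes * num_support_points, img_height, img_width, channels]), h_dim, z_dim)
embedding_dimension = tf.shape(support_set_embeddings)[-1]
class_prototype = tf.reduce_mean(tf.reshape(support_set_embeddings, [num_classes, num_support_points, embedding_dimension]), axis=1)

#embed the query set with the same encoder weights
query_set_embeddings = get_embeddings(tf.reshape(query_set, [num_classes * num_query_points, img_height, img_width, channels]), h_dim, z_dim, reuse=True)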
Now that we have the class prototype and query set embeddings, we define a distance function that gives us the distance between the class prototypes and query set embeddings:
def euclidean_distance(a, b):
    N, D = tf.shape(a)[0], tf.shape(a)[1]
    M = tf.shape(b)[0]
    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))
    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))
    return tf.reduce_mean(tf.square(a - b), axis=2)
We calculate the distance between the class prototype and query set embeddings:
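Assuming the class_prototype and query_set_embeddings tensors from the sketch above:

distance = euclidean_distance(class_prototype, query_set_embeddings)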