When the transformer model was introduced in the “Attention Is All You Need” paper, it changed how NLP tasks are handled. I explained it in my previous post, Transformer model.
In 2021, another paper, “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale”, applied the transformer architecture to computer vision tasks and introduced the Vision Transformer (ViT).
Instead of the convolutional filters used in CNNs, the model splits the input image into 16×16 patches and treats them like tokens. It also borrows the NLP idea of positional encoding, so each patch carries a reference to where it sits in the original image.
Let’s take an input image of 256×256 with 3 channels.
It gets divided into 16×16 patches, giving 16 rows and 16 columns of patches, i.e. 256 patches in total.
Every patch has 3 channels, so each patch flattens to 16×16×3 = 768 values.
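A quick sanity check on those numbers (a throwaway sketch; the constant names are only for illustration):

IMG_SIZE, PATCH_SIZE, CHANNELS = 256, 16, 3     # illustrative constants
patches_per_side = IMG_SIZE // PATCH_SIZE       # 256 / 16 = 16
num_patches = patches_per_side ** 2             # 16 * 16 = 256 patches
patch_dim = PATCH_SIZE * PATCH_SIZE * CHANNELS  # 16 * 16 * 3 = 768 values per patch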
Divide the input image into patches. The snippets below assume the usual TensorFlow/Keras imports, and x is a batch of images with shape (batch, 256, 256, 3):

import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense, Embedding, LayerNormalization, MultiHeadAttention, Add, Flatten
from tensorflow.keras.models import Model

patches = tf.image.extract_patches(
    images=x,
    sizes=[1, 16, 16, 1],
    strides=[1, 16, 16, 1],
    rates=[1, 1, 1, 1],
    padding='VALID')
Reshape the 16×16 grid of patches into a sequence of 256 flattened patch vectors of length 768:

patches = tf.reshape(patches, (tf.shape(patches)[0], 256, patches.shape[-1]))  # (batch, 256, 768)
Now you have 256 inputs (the pink ovals in the figure), each of size 768.
Create the embeddings for positional reference:

positional_embedding = Embedding(256, 768)
embedding_input = tf.range(start=0, limit=256, delta=1)  # positions 0 ... 255
pos_embedding = positional_embedding(embedding_input)
Each position now has a 768-dimensional embedding, matching the size of the linearly projected patches; the positional embeddings are the purple ovals marked 1, 2, 3, … in the image.
Now we join the projected patches and the positional embeddings:

linear_projection = Dense(768)
# (batch, 256, 768) + (256, 768): the positional embeddings broadcast across the batch
output = linear_projection(patches) + pos_embedding
All of the above can be written as a Keras layer:
class PatchCreator(Layer):
    def __init__(self):
        super(PatchCreator, self).__init__(name='patch_creator')
        self.linear_projection = Dense(768)
        self.positional_embedding = Embedding(256, 768)

    def call(self, input_images):
        patches = tf.image.extract_patches(
            images=input_images,
            sizes=[1, 16, 16, 1],
            strides=[1, 16, 16, 1],
            rates=[1, 1, 1, 1],
            padding='VALID')
        # (batch, 16, 16, 768) -> (batch, 256, 768)
        patches = tf.reshape(patches, (tf.shape(patches)[0], 256, patches.shape[-1]))
        embedding_input = tf.range(start=0, limit=256, delta=1)  # positions 0 ... 255
        pos_embedding = self.positional_embedding(embedding_input)
        output = self.linear_projection(patches) + pos_embedding
        return output
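As a quick check (a hedged sketch using a dummy batch of four random 256×256×3 images), the layer should map the batch to a sequence of 256 patch embeddings of size 768:

dummy_images = tf.random.uniform((4, 256, 256, 3))  # hypothetical test batch
patch_creator = PatchCreator()
print(patch_creator(dummy_images).shape)            # expected: (4, 256, 768)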
Now our input is ready to go into the gray box (the Transformer encoder).
The transformer encoder is built from several layers:

# First yellow Norm layer
layer_norm_1 = LayerNormalization()
# Green Multi-Head attention
multi_head_att = MultiHeadAttention(num_heads=HEAD_COUNT, key_dim=768)
# Second yellow Norm layer
layer_norm_2 = LayerNormalization()
The blue MLP (multi-layer perceptron) is made of 2 dense layers with GELU non-linearity:

dense_1 = Dense(768, activation=tf.nn.gelu)
dense_2 = Dense(768, activation=tf.nn.gelu)
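As an aside, the MLP block in the original ViT paper uses a wider hidden layer (typically 4× the embedding size) and applies GELU only after the first dense layer; a sketch of that variant, if you want to stay closer to the paper:

dense_1 = Dense(4 * 768, activation=tf.nn.gelu)  # wider hidden layer, as in the paper
dense_2 = Dense(768)                             # project back to the embedding size, no GELU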
Putting the calls in sequence for the transformer encoder:

# First yellow box
x_1 = layer_norm_1(inputPatches)
# Green box
# Keras MultiHeadAttention expects (query, value, key); we pass the same tensor as
# query and value, and the key defaults to the value when it is omitted
x_1 = multi_head_att(x_1, x_1)
# First white + circle (skip connection)
x_1 = Add()([x_1, inputPatches])
# Second yellow box
x_2 = layer_norm_2(x_1)
# Blue MLP box contains 2 dense layers
x_2 = dense_1(x_2)
x_2 = dense_2(x_2)
# Second white + circle (skip connection)
output = Add()([x_2, x_1])
All of the above can be written as a Keras layer:
class TransformerEncoder(Layer):
    def __init__(self, HEAD_COUNT):
        super(TransformerEncoder, self).__init__(name='transformer_encoder')
        self.layer_norm_1 = LayerNormalization()
        self.multi_head_att = MultiHeadAttention(num_heads=HEAD_COUNT, key_dim=768)
        self.layer_norm_2 = LayerNormalization()
        self.dense_1 = Dense(768, activation=tf.nn.gelu)
        self.dense_2 = Dense(768, activation=tf.nn.gelu)

    def call(self, inputPatches):
        x_1 = self.layer_norm_1(inputPatches)
        x_1 = self.multi_head_att(x_1, x_1)
        x_1 = Add()([x_1, inputPatches])
        x_2 = self.layer_norm_2(x_1)
        x_2 = self.dense_1(x_2)
        x_2 = self.dense_2(x_2)
        output = Add()([x_2, x_1])
        return output
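Since the encoder block only mixes information across patches, its output keeps the same shape as its input; a quick sketch to verify (assuming 4 heads and a dummy patch tensor):

encoder = TransformerEncoder(HEAD_COUNT=4)
dummy_patches = tf.random.uniform((4, 256, 768))  # hypothetical (batch, patches, embedding)
print(encoder(dummy_patches).shape)               # expected: (4, 256, 768)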
Now we join the PatchCreator and the TransformerEncoder into a Vision Transformer model.
Start with the PatchCreator:
patch_creator = PatchCreator()
We need LAYER_COUNT (L in the paper) stacked transformer encoders:
transformer_encoders = [TransformerEncoder(HEAD_COUNT) for _ in range(LAYER_COUNT)]
The final orange MLP head contains 2 dense layers with GELU non-linearity:
mlp_1 = Dense(DENSE_COUNT, activation=tf.nn.gelu)
mlp_2 = Dense(DENSE_COUNT, activation=tf.nn.gelu)
The last classification layer generates the class probabilities:
fc_classifier = Dense(CLASSES_COUNT, activation='softmax')
Calling the layers in order:

x = patch_creator(input_images)  # input_images: a batch of 256×256×3 images
for i in range(LAYER_COUNT):
    x = transformer_encoders[i](x)
x = Flatten()(x)
# MLP head
x = mlp_1(x)
x = mlp_2(x)
# Final classification
x = fc_classifier(x)
All of the above can be written as a Keras model:
class VisionTransformer(Model):
    def __init__(self, HEAD_COUNT, LAYER_COUNT, DENSE_COUNT, CLASSES_COUNT):
        super(VisionTransformer, self).__init__(name='vision_transformer')
        self.LAYER_COUNT = LAYER_COUNT
        self.patch_creator = PatchCreator()
        self.transformer_encoders = [TransformerEncoder(HEAD_COUNT) for _ in range(LAYER_COUNT)]
        self.mlp_1 = Dense(DENSE_COUNT, activation=tf.nn.gelu)
        self.mlp_2 = Dense(DENSE_COUNT, activation=tf.nn.gelu)
        self.fc_classifier = Dense(CLASSES_COUNT, activation='softmax')

    def call(self, input_images, training=True):
        x = self.patch_creator(input_images)
        for i in range(self.LAYER_COUNT):
            x = self.transformer_encoders[i](x)
        x = Flatten()(x)
        x = self.mlp_1(x)
        x = self.mlp_2(x)
        x = self.fc_classifier(x)
        return x
Now we can instantiate the Vision Transformer:

vit = VisionTransformer(
    HEAD_COUNT=4, LAYER_COUNT=2,
    DENSE_COUNT=128, CLASSES_COUNT=3)
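Because this is a subclassed model, its weights are only created on the first call. A quick way to sanity-check the architecture and get it ready for training (a hedged sketch; the optimizer, loss and metrics are my own choices, assuming integer class labels):

# run one dummy batch through the model so every layer gets built
_ = vit(tf.zeros((1, 256, 256, 3)))
vit.summary()

# compile before training; these settings are assumptions, not from the original post
vit.compile(optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'])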
And the model can then be trained like this:

history = vit.fit(training_dataset, validation_data=validation_dataset, epochs=10, verbose=1)
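Here training_dataset and validation_dataset are assumed to be tf.data pipelines of (image, label) batches; one possible way to build them (a sketch with hypothetical directory paths, scaling pixel values to [0, 1]):

# hypothetical directory layout: one sub-folder per class
training_dataset = tf.keras.utils.image_dataset_from_directory(
    'data/train', image_size=(256, 256), batch_size=32)
validation_dataset = tf.keras.utils.image_dataset_from_directory(
    'data/val', image_size=(256, 256), batch_size=32)

# scale pixel values from [0, 255] to [0, 1]
normalize = lambda images, labels: (images / 255.0, labels)
training_dataset = training_dataset.map(normalize)
validation_dataset = validation_dataset.map(normalize)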
Some notes about using the transformer model in computer vision:
As described in the paper, this model beats ResNet only when the training dataset is very large.
This is because the model sees the image only as a sequence of patches and has to learn the spatial relationships between them from data, whereas a ResNet-style CNN has the built-in locality of convolutions over the whole image and therefore reaches good accuracy with less data.
With large training sets, however, the accuracy of this transformer model exceeds that of ResNet.
-Cheers Amit Tomar