What are you wearing? Building a CNN model to predict articles of clothing.

Matthew MacFarquhar
Jan 2, 2023


In this article, we will explore the architecture of a CNN, talk about what convolution does, and build a model that classifies Fashion MNIST images into 10 classes with about 90% accuracy.

What is Convolution?

So what is this special process that CNNs get their name from? When we convolve over an image, we slide a small matrix of weights, called a kernel (or filter), across it. At each position, we multiply the kernel element-wise with the patch of the image underneath it and sum the products to get one output value. The kernel weights are learned during training. In the example above, we have a 5×5 image and a 3×3 kernel. Sliding the kernel across the image and computing one value per position produces the output feature map. In this case the output is 3×3, since the kernel can only move twice to the right (3 column positions in total) and twice down (3 row positions in total). In general, the formula is:

$$\text{output\_dim} = \left\lfloor \frac{\text{input\_dim} - \text{kernel\_dim}}{\text{stride}} \right\rfloor + 1$$

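Stride is how far the kernel moves at each step. By default the stride is 1, so the kernel shifts one column (or row) at a time. When the stride does not divide (input_dim - kernel_dim) evenly, the division rounds down: a 3×3 kernel with stride 2 on a 28×28 image gives (28 - 3) / 2 = 12.5, rounded down to 12, plus 1, for an output of 13. The formula also assumes no padding, which matches the Keras Conv2D default, padding='valid'.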
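To make the sliding-window description concrete, here is a minimal NumPy sketch (not how TensorFlow actually implements it). The helper name conv2d_valid, the toy 5×5 image, and the all-ones kernel are illustrative assumptions; in a CNN the kernel weights would be learned. Like deep-learning libraries, it does not flip the kernel, so strictly speaking it computes a cross-correlation.

```python
import numpy as np


def conv2d_valid(image: np.ndarray, kernel: np.ndarray, stride: int = 1) -> np.ndarray:
    """Hypothetical helper: slide `kernel` over `image` with no padding.

    Each output value is the sum of the element-wise product of the kernel
    and the image patch beneath it.
    """
    ih, iw = image.shape
    kh, kw = kernel.shape
    # Output size from the formula above, using integer (floor) division.
    oh = (ih - kh) // stride + 1
    ow = (iw - kw) // stride + 1
    out = np.zeros((oh, ow))
    for r in range(oh):
        for c in range(ow):
            top, left = r * stride, c * stride
            patch = image[top:top + kh, left:left + kw]
            out[r, c] = np.sum(patch * kernel)
    return out


image = np.arange(25, dtype=float).reshape(5, 5)  # toy 5x5 "image"
kernel = np.ones((3, 3))                          # toy 3x3 kernel (learned in a real CNN)

print(conv2d_valid(image, kernel).shape)            # (3, 3)
print(conv2d_valid(image, kernel, stride=2).shape)  # (2, 2)
```

With stride 2 the same 3×3 kernel fits only twice in each direction, matching (5 - 3) // 2 + 1 = 2.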

Why use convolution?

So why use convolution instead of a fully connected layer, where every pixel is connected to every node?

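The main reason convolution is so useful is that it drastically reduces the number of parameters, which in turn makes the model easier to train. A kernel's weight count depends only on its size and the number of channels, not on the size of the image, whereas a fully connected layer needs a separate weight for every pixel at every node. This matters for images, which become very long vectors when flattened, so we apply a few rounds of convolution to shrink the representation before we flatten it.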
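To put numbers on this, here is a small sketch using the standard parameter-count formulas, with biases included as Keras Dense and Conv2D layers have by default. The function names are hypothetical, and the Dense(512) layer applied directly to the flattened pixels is an illustrative comparison, not part of the model built below.

```python
def dense_params(n_inputs: int, n_units: int) -> int:
    """Fully connected layer: one weight per (input, unit) pair plus one bias per unit."""
    return n_inputs * n_units + n_units


def conv2d_params(kernel_h: int, kernel_w: int, in_channels: int, filters: int) -> int:
    """Conv layer: one kernel_h x kernel_w x in_channels kernel plus one bias per filter.

    The image's height and width do not appear anywhere in this formula.
    """
    return kernel_h * kernel_w * in_channels * filters + filters


# Dense(512) on a flattened 28x28 image vs. Conv2D(32, (3, 3)) on the same image.
print(dense_params(28 * 28, 512))  # 401920
print(conv2d_params(3, 3, 1, 32))  # 320

# Grow the image to 280x280: the dense layer gets roughly 100x bigger, the conv layer does not.
print(dense_params(280 * 280, 512))  # 40141312
print(conv2d_params(3, 3, 1, 32))    # 320
```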

Another advantage of using convolution is that the kernel weights are shared across the entire image. For classification, we don’t care where the target appears, only that it is detected. A fully connected model for cat vs. dog classification would have to learn to detect cats and dogs in the top right, the bottom left, the center, and so on, separately. By sharing one filter’s weights and sliding it over the entire image, we get a general feature detector that activates wherever its target appears.
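A toy NumPy sketch of weight sharing, under made-up assumptions: a hand-drawn 3×3 "plus" pattern and a detector kernel set equal to it (a trained CNN would learn its weights). Because the same weights are applied at every position, the peak of the feature map follows the pattern around the image while its height stays the same, so taking the maximum over the map reports a detection regardless of location. The helpers feature_map and place are hypothetical; sliding_window_view requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def feature_map(image: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Stride-1, no-padding convolution: dot product of the kernel with every window."""
    windows = sliding_window_view(image, kernel.shape)  # shape (H-kh+1, W-kw+1, kh, kw)
    return np.einsum("ijkl,kl->ij", windows, kernel)


def place(pattern: np.ndarray, row: int, col: int, size: int = 10) -> np.ndarray:
    """Hypothetical helper: put `pattern` on an otherwise blank size x size image."""
    img = np.zeros((size, size))
    h, w = pattern.shape
    img[row:row + h, col:col + w] = pattern
    return img


# A toy "plus sign" feature, and one shared set of detector weights that matches it.
plus = np.array([[0, 1, 0],
                 [1, 1, 1],
                 [0, 1, 0]], dtype=float)
detector = plus.copy()

for row, col in [(0, 0), (7, 7), (3, 5)]:
    fmap = feature_map(place(plus, row, col), detector)
    peak = tuple(int(i) for i in np.unravel_index(fmap.argmax(), fmap.shape))
    # The peak moves with the pattern, but its strength is the same everywhere.
    print(f"pattern at {(row, col)} -> peak at {peak}, response {fmap.max()}")
```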

Let’s Build!

Now let’s build a CNN to classify clothing items in the Fashion MNIST dataset. Here is the Google Colab link if you want to follow along (make sure to use a GPU runtime for faster training).

First, we grab the imports we need.

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout
from tensorflow.keras.models import Model

Next, we download the data and normalize it, scaling the pixel values from the range 0 to 255 down to 0 to 1.

fashion_mnist = tf.keras.datasets.fashion_mnist

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
print("x_train.shape:", x_train.shape)

The images are 28 × 28 grayscale, and there are 60,000 of them in the training set (the test set has another 10,000).

x_train.shape: (60000, 28, 28)
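One detail worth knowing: dividing a uint8 array by 255.0 produces float64. The sketch below, on a toy array rather than the real data, shows that and a hypothetical normalize helper that casts to float32 first, which halves the memory. This is an optional variation, not what the code above does; Keras layers compute in float32 by default anyway.

```python
import numpy as np


def normalize(images: np.ndarray, dtype=np.float32) -> np.ndarray:
    """Scale uint8 pixels in [0, 255] to floats in [0.0, 1.0]."""
    return images.astype(dtype) / 255.0


# Toy stand-in for a batch of uint8 images.
fake = np.array([[[0, 128, 255]]], dtype=np.uint8)

print((fake / 255.0).dtype)   # float64: what x_train / 255.0 produces
print(normalize(fake).dtype)  # float32

# Memory for the full 60,000 x 28 x 28 training set at each precision:
n_values = 60_000 * 28 * 28
print(n_values * 8, n_values * 4)  # 376320000 188160000 bytes
```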

Conv2D expects each image to have a channel dimension (height, width, channels), so we add a trailing axis of size 1 for the single grayscale channel.

x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print(x_train.shape)

(60000, 28, 28, 1)
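Keras’s Conv2D defaults to the channels-last layout, (batch, height, width, channels). The sketch below, on a toy zero array rather than the real data, shows that np.expand_dims is one of several equivalent ways to add that trailing axis.

```python
import numpy as np

# Toy stand-in for x_train: 5 grayscale images of 28x28.
batch = np.zeros((5, 28, 28), dtype=np.float32)

a = np.expand_dims(batch, -1)        # the approach used above
b = batch[..., np.newaxis]           # same result via indexing
c = batch.reshape(*batch.shape, 1)   # same result via reshape

print(a.shape)     # (5, 28, 28, 1)
print(a[0].shape)  # (28, 28, 1): the per-image shape passed to Input() below
```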

We get the number of classes in our dataset by counting the unique values in y_train.

K = len(set(y_train))
print("number of classes:", K)  # number of classes: 10
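As an alternative to set(), np.unique returns the sorted distinct labels. The sketch below adds a check that is an extra safeguard, not part of the original code: sparse_categorical_crossentropy (used when compiling below) expects integer class indices from 0 to K-1, which Fashion MNIST’s labels 0 to 9 satisfy. count_classes is a hypothetical helper, and the labels are toy values standing in for y_train.

```python
import numpy as np


def count_classes(labels: np.ndarray) -> int:
    """Number of distinct labels, checking they are exactly the integers 0..K-1."""
    classes = np.unique(labels)  # sorted unique values
    k = len(classes)
    if not np.array_equal(classes, np.arange(k)):
        raise ValueError(f"expected labels 0..{k - 1}, got {classes}")
    return k


labels = np.arange(10).repeat(3)  # toy labels standing in for y_train
print("number of classes:", count_classes(labels))  # number of classes: 10
```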

Now we will build our model.

i = Input(shape=x_train[0].shape)
x = Conv2D(32, (3,3), strides=2, activation='relu')(i)
x = Conv2D(64, (3,3), strides=2, activation='relu')(x)
x = Conv2D(128, (3,3), strides=2, activation='relu')(x)
x = Flatten()(x)
x = Dropout(0.2)(x)
x = Dense(512, activation='relu')(x)
x = Dropout(0.2)(x)
x = Dense(K, activation='softmax')(x)

model = Model(i,x)
model.summary()

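We have three convolution layers with 32, 64 and 128 filters respectively, each using a 3×3 kernel with stride 2, which roughly halves the spatial size at every layer (28 → 13 → 6 → 2). We then flatten the output and add two Dense layers with 512 and 10 nodes (one per class, with a softmax that turns the outputs into probabilities). We also add dropout before each Dense layer to help with regularization.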
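As a sanity check on the architecture, the sketch below applies the output-size formula from earlier to each layer and counts parameters with the standard weights-plus-biases formulas (batch dimension omitted). It skips the Dropout layers, which have no parameters and do not change the shape. trace_model is a hypothetical helper whose defaults mirror the model above; its numbers should line up with the Output Shape and Param # columns of model.summary().

```python
def conv_out(n: int, k: int, s: int) -> int:
    """Output size of an unpadded ('valid') convolution: floor((n - k) / s) + 1."""
    return (n - k) // s + 1


def trace_model(size=28, channels=1, filters=(32, 64, 128), k=3, s=2,
                hidden=512, classes=10):
    """Hypothetical helper: (layer, output shape, parameter count) for each layer."""
    rows = []
    for f in filters:
        size = conv_out(size, k, s)
        rows.append((f"Conv2D({f})", (size, size, f), k * k * channels * f + f))
        channels = f
    flat = size * size * channels
    rows.append(("Flatten", (flat,), 0))
    rows.append((f"Dense({hidden})", (hidden,), flat * hidden + hidden))
    rows.append((f"Dense({classes})", (classes,), hidden * classes + classes))
    return rows


rows = trace_model()
for name, shape, params in rows:
    print(f"{name:<12} {str(shape):<14} {params:>8,}")
print("total params:", sum(p for _, _, p in rows))  # total params: 360458
```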

Now let’s train the model for 10 epochs.

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
r = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10)
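sparse_categorical_crossentropy is the loss to use when the labels are integers, as they are here, rather than one-hot vectors. Below is a minimal NumPy sketch of what it computes, ignoring the numerical safeguards Keras adds (such as keeping probabilities away from 0); the function names mirror the Keras loss but are our own, and the logits are made-up toy values. A handy reference point: a model that spreads probability evenly over 10 classes has a loss of log(10) ≈ 2.30.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def sparse_categorical_crossentropy(labels: np.ndarray, probs: np.ndarray) -> float:
    """Mean of -log(probability the model gave the true class).

    `labels` are integer class ids (like y_train), so no one-hot encoding is needed.
    """
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(picked)))


# A model that spreads probability evenly over 10 classes scores log(10).
uniform = np.full((4, 10), 0.1)
print(sparse_categorical_crossentropy(np.array([0, 3, 7, 9]), uniform))  # ~2.3026

# Toy logits for 2 samples over 3 classes.
probs = softmax(np.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]))
print(sparse_categorical_crossentropy(np.array([0, 1]), probs))
```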

We can then plot the loss and accuracy and see that we get around 90% on the validation set (which, in this setup, is the test set).

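# Loss per epoch (training vs. validation)
plt.plot(r.history['loss'], label='loss')
plt.plot(r.history['val_loss'], label='val_loss')
plt.legend()
plt.show()

# Accuracy per epoch, drawn on a separate figure
plt.plot(r.history['accuracy'], label='acc')
plt.plot(r.history['val_accuracy'], label='val_acc')
plt.legend()
plt.show()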
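To read the best epoch off as a number instead of eyeballing the plot, a small helper over r.history works, since History.history maps each metric name to a list with one value per epoch. best_epoch is a hypothetical name. Because the validation data here is the test set, choosing an epoch this way makes the reported test accuracy slightly optimistic.

```python
def best_epoch(history: dict, metric: str = "val_accuracy", mode: str = "max") -> tuple:
    """Hypothetical helper: (1-based epoch, value) where `metric` was best.

    Use mode="max" for accuracy-style metrics and mode="min" for losses.
    """
    values = history[metric]
    pick = max if mode == "max" else min
    idx = pick(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]
```

Usage after training: `epoch, acc = best_epoch(r.history)`, or `best_epoch(r.history, "val_loss", mode="min")` for the lowest validation loss.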

Lastly, if you want to use your model later, you can save it, compress it, and download it from Colab.

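from google.colab import files  # Colab-only helper for browser downloads

# Save in TensorFlow's SavedModel format (works with TF 2.x / Keras 2).
# Under Keras 3 (TF 2.16+), model.save() needs a .keras or .h5 path;
# use model.export('saved_model/my_model') there to get a SavedModel.
model.save('saved_model/my_model')
# Compress the directory; the leading ! runs a shell command in Colab.
!tar -czvf model.tar.gz saved_model/my_model/
files.download('model.tar.gz')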
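On your own machine, the downloaded archive has to be unpacked before the model can be loaded. Here is a minimal standard-library sketch, assuming the archive layout produced by the tar command above (paths starting with saved_model/my_model/); extract_model is a hypothetical helper. It uses Python's "data" extraction filter when available, which rejects unsafe paths in the archive.

```python
import tarfile
from pathlib import Path


def extract_model(archive: str = "model.tar.gz", dest: str = ".") -> Path:
    """Hypothetical helper: unpack the downloaded archive, return the SavedModel dir."""
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons support extraction filters; "data" blocks unsafe paths.
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)
    return Path(dest) / "saved_model" / "my_model"
```

With TensorFlow 2.x and Keras 2 (the setup this article used), passing `str(extract_model())` to `tf.keras.models.load_model` should then restore the model.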

Written by Matthew MacFarquhar

I am a software engineer working for Amazon living in SF/NYC.