Fashion-MNIST data from PyTorch

Visualizing Fashion MNIST data images

In this short tutorial, we will learn how to access the Fashion-MNIST dataset. Fashion-MNIST is an image dataset commonly used for learning and testing deep learning algorithms, and PyTorch’s torchvision module makes it easy to download.

The Fashion-MNIST dataset is built from Zalando’s article images. It contains a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28×28 grayscale image with a label from one of 10 classes, such as T-shirt/top, Trouser, and Bag.

Fashion-MNIST Data

In this post, we will learn how to get the Fashion-MNIST data using torchvision in PyTorch. We will also learn how to peek at some of the sample images using Matplotlib in Python.

First, let us load NumPy and Matplotlib in Python.

import numpy as np
import matplotlib.pyplot as plt 

Next, we will load the relevant modules from PyTorch. The datasets module gives us access to the data, and DataLoader lets us iterate over a dataset in batches. Torchvision is PyTorch’s package for all things vision: it contains popular image datasets, model architectures, and common image transformations for computer vision.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

Torchvision datasets include numerous popular image datasets useful for a variety of machine learning tasks, such as image classification, image captioning, and video classification.

Here are a few of the built-in datasets in torchvision: MNIST, CIFAR10, CIFAR100, and ImageNet.

We can take a look at the list of available datasets by calling the dir() function on datasets.

dir(datasets)
['CIFAR10',
 'CIFAR100',
 'CLEVRClassification',
 'CREStereo',
 'Caltech101',
 'Caltech256',
 'CarlaStereo',
 'CelebA',
 'Cityscapes',
 'CocoCaptions',
 'CocoDetection',
 'Country211',
 'DTD',
 'DatasetFolder',
 'EMNIST',
 'ETH3DStereo',
 'EuroSAT',
 'FER2013',
 'FGVCAircraft',
 'FakeData',
 'FallingThingsStereo',
 'FashionMNIST',
 'Flickr30k',
 'Flickr8k',
..
..
..
..
]
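
The dir() output also contains private names and helper objects mixed in with the dataset classes. As a rough sketch, a small helper can keep only the public classes of a module. The helper name public_classes is hypothetical and not part of torchvision; it relies only on Python's standard inspect module.

```python
import inspect


def public_classes(module):
    """Return the sorted names of public classes found on a module."""
    return sorted(
        name
        for name, obj in vars(module).items()
        if inspect.isclass(obj) and not name.startswith("_")
    )


# Usage with torchvision (assumes torchvision is installed):
#   from torchvision import datasets
#   print(public_classes(datasets))
```

The result should include names such as FashionMNIST, along with base classes like DatasetFolder that are not datasets themselves.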

Download Fashion-MNIST dataset locally

We can use datasets.FashionMNIST() to download the dataset and store it locally on our computer. Here we download the training split; the ToTensor() transform converts each image to a PyTorch tensor when it is accessed.

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
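
To get a feel for what the ToTensor() transform does to each image, here is a small NumPy stand-in. It is not torchvision's implementation, only an illustration of the documented behavior: pixel values in [0, 255] are scaled to floats in [0.0, 1.0] and a leading channel axis is added. The function name to_tensor_like and the synthetic image are made up for this sketch.

```python
import numpy as np


def to_tensor_like(image_uint8):
    """Mimic ToTensor's scaling and layout for one grayscale image."""
    # Scale 0..255 integers to 0.0..1.0 floats
    scaled = image_uint8.astype(np.float32) / 255.0
    # Add a channel axis in front: (H, W) -> (1, H, W)
    return scaled[np.newaxis, :, :]


# Synthetic 28x28 image covering every value from 0 to 255
fake_image = (np.arange(28 * 28).reshape(28, 28) % 256).astype(np.uint8)

converted = to_tensor_like(fake_image)
print(converted.shape, converted.dtype, converted.min(), converted.max())
```

The real transform returns a torch tensor rather than a NumPy array, but the shape and value range follow the same pattern.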

Note that downloading the dataset requires an internet connection.

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 26421880/26421880 [00:09<00:00, 2773741.58it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Evaluating the variable that holds the data prints a summary of the downloaded FashionMNIST dataset.

training_data

Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()

We can use DataLoader to load the data in batches. In the example below, each batch holds 16 samples. We can pull the first batch of images and their labels using the iter() and next() functions.

train_loader = DataLoader(training_data, batch_size=16)
images, labels = next(iter(train_loader))
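
To see how DataLoader splits a dataset into batches without downloading anything, the sketch below uses a synthetic stand-in. The TensorDataset of 100 blank images, the random labels, and shuffle=True are assumptions made for this illustration and are not part of the original example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in: 100 blank "images" shaped like Fashion-MNIST samples
fake_images = torch.zeros(100, 1, 28, 28)
fake_labels = torch.randint(0, 10, (100,))
fake_data = TensorDataset(fake_images, fake_labels)

# shuffle=True draws the samples in a new random order on every pass
loader = DataLoader(fake_data, batch_size=16, shuffle=True)

# Looping over the loader yields one (images, labels) pair per batch
batch_sizes = [images.shape[0] for images, labels in loader]
print(len(loader), batch_sizes)
```

Because 100 is not a multiple of 16, the last batch is smaller than the others. Without shuffle=True, as in the example above, the first batch always contains the first 16 samples of the dataset.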

We can check the shape of the sampled batch. The first element is the batch size, the second is the number of channels (1 for grayscale), and the last two are the height and width of each image.

images.shape
torch.Size([16, 1, 28, 28])
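
The four dimensions can be unpacked by name, which also shows why a single image has to lose its channel axis before plotting. The sketch below uses a NumPy array of zeros as a stand-in for the real batch, so the variable names and values are illustrative only.

```python
import numpy as np

# Stand-in with the same shape as the batch: (batch, channels, height, width)
batch = np.zeros((16, 1, 28, 28), dtype=np.float32)
batch_size, channels, height, width = batch.shape

first_image = batch[0]               # shape (1, 28, 28)
plottable = np.squeeze(first_image)  # shape (28, 28), a plain 2D image
print(first_image.shape, plottable.shape)
```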

Visualizing the Fashion-MNIST data with Matplotlib

Now that we have the data we needed in a variable, we can loop through and display the images with their labels using Matplotlib in Python.

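fig = plt.figure(figsize=(9, 9))
# Draw the 16 images of the batch on a 4x4 grid
for i in range(16):
    ax = fig.add_subplot(4, 4, i + 1)
    # np.squeeze drops the channel axis: (1, 28, 28) -> (28, 28)
    ax.imshow(np.squeeze(images[i]), cmap="gray")
    # Use the integer class label as the title
    ax.set_title(labels[i].item())
plt.tight_layout()
plt.savefig('Pytorch_FashionMNIST_data_images_gray.png')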
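
The titles in the plot are integer labels from 0 to 9. A small lookup table makes them readable. The class names below follow the label order given in the Fashion-MNIST README; the helper name label_name is made up for this sketch.

```python
# Class names in label order (0 to 9), as listed in the Fashion-MNIST README
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_name(label):
    """Map an integer label (or a one-element tensor) to its class name."""
    return FASHION_MNIST_CLASSES[int(label)]


print(label_name(0), label_name(9))
```

In the plotting loop, ax.set_title(label_name(labels[i])) would then show the class name instead of the number. Recent torchvision versions should also expose the same list as training_data.classes.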