In this short tutorial, we will learn how to access the Fashion-MNIST dataset. Fashion-MNIST is an image dataset that is handy for learning Deep Learning and for trying out Deep Learning algorithms in practice. PyTorch's torchvision module makes it easy to access the Fashion-MNIST dataset.
The Fashion-MNIST dataset is made of Zalando's article images. It contains a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28×28 grayscale image with a label from one of 10 classes, such as T-shirt/top, Trouser, and Bag.
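Those numbers make it easy to estimate how much memory the images take. The sketch below assumes one unsigned byte per pixel (how the raw files store them) and 4-byte float32 values after the ToTensor() conversion used later; the constant names are only for illustration.

```python
# Back-of-the-envelope size of the Fashion-MNIST images.
# Assumption: 1 byte (uint8) per pixel in the raw data, 4 bytes (float32) after ToTensor().
IMAGE_HEIGHT, IMAGE_WIDTH = 28, 28
NUM_TRAIN, NUM_TEST = 60_000, 10_000

pixels_per_image = IMAGE_HEIGHT * IMAGE_WIDTH   # 784 pixels
train_bytes = NUM_TRAIN * pixels_per_image      # 47,040,000 bytes
test_bytes = NUM_TEST * pixels_per_image        # 7,840,000 bytes

# If the whole training set were stacked into one float32 tensor.
train_float32_bytes = train_bytes * 4           # 188,160,000 bytes

print(f"pixels per image: {pixels_per_image}")
print(f"train images as uint8: {train_bytes / 1e6:.2f} MB")
print(f"test images as uint8: {test_bytes / 1e6:.2f} MB")
print(f"train images as float32: {train_float32_bytes / 1e6:.2f} MB")
```

Because ToTensor() is applied lazily, one image at a time when a sample is requested, the float32 figure is only reached if you stack the entire training set into a single tensor.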
In this post, we will learn how to get the Fashion-MNIST data using torchvision in PyTorch. We will also learn how to take a peek at some of the sample images using Matplotlib in Python.
First, let us load NumPy and Matplotlib in Python.
import numpy as np
import matplotlib.pyplot as plt
Next, we will load the relevant packages and modules from PyTorch. PyTorch's DataLoader and torchvision's datasets module make it easy to access datasets. The torchvision package is PyTorch's package for all things vision: it contains popular image datasets, model architectures, and common image transformations for computer vision.
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
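Every torchvision dataset, FashionMNIST included, is a map-style torch.utils.data.Dataset: it implements __len__() and __getitem__(), which is what DataLoader uses to draw samples. The hypothetical RandomImages class below is a minimal sketch of that protocol. It serves random 1×28×28 tensors with integer labels 0 to 9, mimicking the (image, label) pairs FashionMNIST returns after ToTensor(), so it runs without downloading anything.

```python
import torch
from torch.utils.data import Dataset


class RandomImages(Dataset):
    """Hypothetical stand-in dataset: random 1x28x28 'images' with labels 0-9."""

    def __init__(self, n_samples=100, n_classes=10, seed=0):
        # A seeded generator makes the fake data reproducible.
        generator = torch.Generator().manual_seed(seed)
        self.images = torch.rand(n_samples, 1, 28, 28, generator=generator)
        self.labels = torch.randint(0, n_classes, (n_samples,), generator=generator)

    def __len__(self):
        # Number of samples; samplers use this to generate indices.
        return len(self.labels)

    def __getitem__(self, index):
        # One (image, label) pair: a (1, 28, 28) float tensor and a plain int,
        # like FashionMNIST yields when transform=ToTensor().
        return self.images[index], int(self.labels[index])


dataset = RandomImages()
image, label = dataset[0]
print(len(dataset), image.shape, label)
```

Any object with these two methods can be handed to DataLoader in place of training_data, which is also a convenient way to test a training loop offline.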
Torchvision's datasets module includes numerous popular image datasets useful for a variety of machine learning tasks, such as image classification, image captioning, and video classification.
Here are a few of the built-in datasets in torchvision: MNIST, CIFAR10, CIFAR100, and ImageNet.
We can look at the list of available datasets by calling the dir() function on datasets.
dir(datasets)
['CIFAR10', 'CIFAR100', 'CLEVRClassification', 'CREStereo', 'Caltech101', 'Caltech256', 'CarlaStereo', 'CelebA', 'Cityscapes', 'CocoCaptions', 'CocoDetection', 'Country211', 'DTD', 'DatasetFolder', 'EMNIST', 'ETH3DStereo', 'EuroSAT', 'FER2013', 'FGVCAircraft', 'FakeData', 'FallingThingsStereo', 'FashionMNIST', 'Flickr30k', 'Flickr8k', .. .. .. .. ]
Download Fashion-MNIST dataset locally
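We can use the datasets.FashionMNIST class to download the dataset and store it locally on your computer. Here we download the training split (train=True; passing train=False gives the 10,000-image test split) and convert each image to a PyTorch tensor with the ToTensor() transform.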
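training_data = datasets.FashionMNIST(
    root="data",           # directory where the files are stored
    train=True,            # training split (60,000 images)
    download=True,         # download only if the files are not already in root
    transform=ToTensor()   # PIL image -> float tensor in [0, 1], shape (1, 28, 28)
)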
Note that downloading the dataset requires an internet connection the first time. On later runs, torchvision finds the files under root and skips the download.
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 26421880/26421880 [00:09<00:00, 2773741.58it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
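The files in data/FashionMNIST/raw use the simple IDX binary format: a 4-byte magic number (two zero bytes, a data-type code where 0x08 means unsigned byte, and the number of dimensions), then one big-endian 32-bit size per dimension, then the raw values. torchvision parses these files for you; the sketch below, with hypothetical helper names read_idx and read_idx_gz, only illustrates what is inside them and handles unsigned-byte data only.

```python
import gzip
import struct

import numpy as np


def read_idx(data: bytes) -> np.ndarray:
    """Parse IDX bytes holding unsigned-byte data into a NumPy array (hypothetical helper)."""
    # Magic number: two zero bytes, a type code (0x08 = unsigned byte), number of dimensions.
    zero1, zero2, type_code, ndim = struct.unpack(">BBBB", data[:4])
    if zero1 != 0 or zero2 != 0 or type_code != 0x08:
        raise ValueError("not an unsigned-byte IDX file")
    # One big-endian 32-bit unsigned size per dimension.
    shape = struct.unpack(f">{ndim}I", data[4 : 4 + 4 * ndim])
    offset = 4 + 4 * ndim
    # The remaining bytes are the values in row-major order.
    return np.frombuffer(data, dtype=np.uint8, offset=offset).reshape(shape)


def read_idx_gz(path: str) -> np.ndarray:
    """Read a gzip-compressed IDX file such as train-images-idx3-ubyte.gz (hypothetical helper)."""
    with gzip.open(path, "rb") as f:
        return read_idx(f.read())
```

For example, read_idx_gz("data/FashionMNIST/raw/train-images-idx3-ubyte.gz") should return a uint8 array of shape (60000, 28, 28), while a labels file (idx1-ubyte) yields a 1-D array.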
Printing the variable that holds the data shows a summary of the downloaded FashionMNIST dataset.
training_data

Dataset FashionMNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()
We can use DataLoader to load the data in batches. In the example below, we load a batch of 16 samples from the dataset. We get the images and their labels by creating an iterator over the loader with iter() and taking its first batch with next().
train_loader = DataLoader(training_data, batch_size=16)
images, labels = next(iter(train_loader))
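The loader above does not shuffle, so next(iter(train_loader)) always returns the same first 16 images. The sketch below uses a small TensorDataset of random tensors as a stand-in for training_data (an assumption, to avoid the download) to show shuffling and how the last batch behaves when the dataset size is not a multiple of batch_size.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for training_data: 100 random 1x28x28 images with labels 0-9.
stand_in = TensorDataset(torch.rand(100, 1, 28, 28), torch.randint(0, 10, (100,)))

# shuffle=True draws a fresh random order each time the loader is iterated.
loader = DataLoader(stand_in, batch_size=16, shuffle=True)
batch_sizes = [batch_images.shape[0] for batch_images, batch_labels in loader]
# 100 = 6 * 16 + 4, so the final batch holds only 4 images.
print(len(loader), batch_sizes)

# drop_last=True discards that incomplete final batch instead.
loader_full_only = DataLoader(stand_in, batch_size=16, shuffle=True, drop_last=True)
print(len(loader_full_only))
```

With the real training set, 60,000 / 16 = 3,750, so every batch is full and drop_last makes no difference.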
We can check the shape of the sampled images. The first element is the batch size, the second is the number of channels (1, since the images are grayscale), and the last two are the height and width of each image.
images.shape

torch.Size([16, 1, 28, 28])

Visualizing the Fashion-MNIST data with Matplotlib
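Matplotlib's imshow() accepts a 2-D (H, W) array for a grayscale image, or an (H, W, 3) or (H, W, 4) array for color, whereas PyTorch batches are laid out as (N, C, H, W). The sketch below uses a random stand-in batch (an assumption, so it runs without the dataset) to show the two reshapes that come up most often: squeezing out the single channel and permuting to channels-last.

```python
import torch

# Stand-in batch with the same layout as images above: (N, C, H, W).
batch = torch.rand(16, 1, 28, 28)

# One grayscale image: drop the channel axis to get (H, W), which imshow accepts.
single = batch[0].squeeze(0)

# Whole batch in channels-last order (N, H, W, C), the layout imshow expects
# for color images; with C == 1 you still need to squeeze before plotting.
channels_last = batch.permute(0, 2, 3, 1)

print(single.shape, channels_last.shape)
```

An (H, W, 1) array is rejected by imshow(), which is why grayscale images are squeezed rather than only permuted.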
Now that we have a batch of images and labels in variables, we can loop through them and display each image with its label using Matplotlib in Python.
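fig = plt.figure(figsize=(9, 9))
for i in range(16):
    ax = fig.add_subplot(4, 4, i + 1)
    # Drop the channel axis: (1, 28, 28) -> (28, 28) for imshow
    ax.imshow(np.squeeze(images[i].numpy()), cmap="gray")
    # Title each panel with its integer class label (0-9)
    ax.set_title(labels[i].item())
plt.tight_layout()
plt.savefig('Pytorch_FashionMNIST_data_images_gray.png')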
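The plot titles above are integer labels. Fashion-MNIST's labels 0 to 9 correspond to the class names below, in the same order torchvision exposes as training_data.classes; the helper label_names is a hypothetical convenience function.

```python
# Class names in label order 0-9 (the same list as training_data.classes in torchvision).
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_names(labels):
    """Map integer labels (ints or 0-d tensors) to their class names (hypothetical helper)."""
    return [FASHION_MNIST_CLASSES[int(label)] for label in labels]


print(label_names([9, 0, 0, 3]))  # ['Ankle boot', 'T-shirt/top', 'T-shirt/top', 'Dress']
```

Inside the plotting loop, ax.set_title(training_data.classes[labels[i].item()]) gives the same names without hard-coding the list.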