How to resize an image with PyTorch

In this post, we will learn how to resize an image using PyTorch. PyTorch, through its torchvision package, offers numerous useful functions to manipulate or transform images. With torchvision’s Resize transform, we can resize images.

We will see a simple example of resizing a single image using the v2 transforms in PyTorch’s torchvision. First, let us load NumPy and Matplotlib.

import numpy as np
import matplotlib.pyplot as plt 

Let us load the PyTorch-specific packages and modules that we will use to manipulate an image.

import torch
import torchvision
from torchvision.io import read_image
from torchvision.transforms import v2

Load an image as a tensor object

We can load an image as a tensor object using the read_image() function available in torchvision.io.

image = read_image("Resize_an_image_with_PyTorch.png")

We can see that the loaded image is a tensor of size 4x3024x4032, that is, 4 channels, 3024 rows, and 4032 columns.

print(f"{type(image) = }, {image.shape = }")

type(image) = <class 'torch.Tensor'>, image.shape = torch.Size([4, 3024, 4032])

Display an image stored as a tensor object

In Python, we can use Matplotlib’s imshow() to visualize an image. When we try to display our image with imshow(), we get a TypeError because of its invalid shape.

plt.imshow(image)

TypeError: Invalid shape (4, 3024, 4032) for image data

Matplotlib’s imshow() expects a color image to have dimensions (M, N, 4): an image with RGBA values, where the first two dimensions (M, N) define the rows and columns of the image and the last one holds the channels. However, as we saw earlier, the tensor object stores the channels first, as (4, M, N). We can change the dimension order with the permute() function as shown below.


image.permute(1, 2, 0).shape

torch.Size([3024, 4032, 4])
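To make the axis reordering concrete, here is a minimal sketch on a tiny synthetic tensor instead of the real photo. The names chw, hwc, and same_pixel are made up for illustration. It shows that permute(1, 2, 0) only moves the channel axis to the end and leaves the pixel values untouched.

```python
import torch

# Synthetic stand-in for a loaded image: 4 channels (RGBA), 2 rows, 3 columns.
chw = torch.arange(4 * 2 * 3, dtype=torch.uint8).reshape(4, 2, 3)

# Move the channel axis to the end: (C, H, W) -> (H, W, C).
hwc = chw.permute(1, 2, 0)

print(chw.shape)  # torch.Size([4, 2, 3])
print(hwc.shape)  # torch.Size([2, 3, 4])

# permute() returns a view of the same data, so nothing is copied or changed:
# the value at (channel c, row h, column w) is now at (row h, column w, channel c).
same_pixel = chw[2, 1, 0] == hwc[1, 0, 2]
print(bool(same_pixel))  # True
```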

Now we can call imshow() with the dimension order it expects to display the image. In the example below, we display the original image before resizing it.

fig = plt.figure(figsize=(12, 9))
plt.imshow(image.permute(1, 2, 0))
plt.tight_layout()
plt.savefig('Display_image_from_tensor.png')
Displaying an image stored as a tensor object with Python

Resizing an image with Resize()

We can use torchvision’s Resize transform to resize an image. Here we specify the new dimensions we want, as (height, width), using the “size” argument and create a Resize object.

transform = v2.Resize(size=(400, 300))

We have used the default options other than specifying the dimensions we want.

transform

Resize(size=[400, 300], interpolation=InterpolationMode.BILINEAR, antialias=True)

Now we apply the transform object to the image to resize it. We store the resized image in a variable.

image_resized = transform(image)

The resized image can be displayed with the imshow() function as shown before. You can easily notice the difference in dimensions between the original and the resized image.

plt.imshow(image_resized.permute(1, 2, 0))
Resize image with PyTorch
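Since “size” is read as (height, width), resizing the 3024x4032 original to (400, 300) also changes its proportions. If that is not wanted, one option is to compute a target size with the same aspect ratio first. The helper below, scaled_size(), is a hypothetical function written for this sketch and not part of torchvision.

```python
def scaled_size(height, width, target_long_edge):
    """Return (height, width) scaled so the longer edge equals target_long_edge."""
    scale = target_long_edge / max(height, width)
    return (round(height * scale), round(width * scale))


# The original image has 3024 rows and 4032 columns.
new_size = scaled_size(3024, 4032, 400)
print(new_size)  # (300, 400)
```

The resulting tuple can then be passed as the “size” argument of v2.Resize.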