In this post, we will learn how to resize an image using PyTorch. PyTorch's torchvision library offers numerous useful functions to manipulate or transform images. With torchvision's Resize() transform, we can resize images.
We will see a simple example of resizing a single image using the transforms v2 API in torchvision. First, let us load NumPy and Matplotlib.
import numpy as np
import matplotlib.pyplot as plt
Let us load PyTorch specific packages and modules that we will use to manipulate an image.
import torch
import torchvision
from torchvision.io import read_image
from torchvision.transforms import v2
Load an image as a tensor object
We can load an image as a tensor object using the read_image() function available in torchvision.io.
image = read_image("Resize_an_image_with_PyTorch.png")
We can see that the loaded image is a tensor of shape 4x3024x4032: 4 channels (RGBA), 3024 rows (height) and 4032 columns (width).
print(f"{type(image) = }, {image.shape = }")
type(image) = <class 'torch.Tensor'>, image.shape = torch.Size([4, 3024, 4032])
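To make the shape concrete, here is a small sketch that unpacks the three dimensions and drops the alpha channel by slicing. It uses a tiny synthetic tensor as a stand-in for the loaded image (an assumption, so that it runs without the PNG file), and the names rgba and rgb are hypothetical.

```python
import torch

# Synthetic stand-in for a loaded RGBA image (assumption: the real
# tensor would come from read_image() and have dtype uint8 as well).
rgba = torch.randint(0, 256, (4, 6, 8), dtype=torch.uint8)

# torchvision stores images channels-first: (channels, height, width).
channels, height, width = rgba.shape
print(f"{channels = }, {height = }, {width = }")

# Keep only the first three channels (R, G, B) and discard alpha.
rgb = rgba[:3]
print(f"{rgb.shape = }, {rgb.dtype = }")
```

Slicing like this is only one way to get an RGB tensor; whether discarding alpha is appropriate depends on the image.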
Display an image stored as a tensor object
In Python, we can use Matplotlib's imshow() to visualize an image. When we try to display our image with imshow(), we get a TypeError due to its invalid shape.
plt.imshow(image)
TypeError: Invalid shape (4, 3024, 4032) for image data
Matplotlib's imshow() expects the channel dimension to come last. For an image like ours, that means a shape of (M, N, 4): an image with RGBA values, where the first two dimensions (M, N) define the rows and columns of the image. Out-of-range RGB(A) values are clipped. However, as we saw earlier, the tensor object stores the channel dimension first. We can change the dimension order with the permute() function as shown below.
image.permute(1, 2, 0).shape
torch.Size([3024, 4032, 4])
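The reordering can be checked on a tiny tensor where the values are easy to follow. This is a minimal sketch on synthetic data (an assumption), with the hypothetical names chw and hwc for the channels-first and channels-last layouts.

```python
import torch

# Synthetic stand-in for a decoded RGBA image: 4 channels, 6 rows, 8 columns.
chw = torch.zeros((4, 6, 8), dtype=torch.uint8)
chw[0] = 255  # fill the first (red) channel

# permute(1, 2, 0) reorders the axes to (height, width, channels).
# It returns a view of the same data rather than a copy.
hwc = chw.permute(1, 2, 0)

print(chw.shape)   # torch.Size([4, 6, 8])
print(hwc.shape)   # torch.Size([6, 8, 4])
print(hwc[0, 0])   # the four channel values of the top-left pixel
```

After the permute, indexing a single pixel returns its channel values together, which is the layout imshow() works with.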
Now that the dimensions are in the order imshow() expects, we can display the image. In the example below we display the original image before resizing it.
fig = plt.figure(figsize=(12, 9))
plt.imshow(image.permute(1, 2, 0))
plt.tight_layout()
plt.savefig('Display_image_from_tensor.png')
Resizing an image with the Resize() transform
We can use torchvision's Resize() transform to resize an image. Here we specify the new dimensions we want, as (height, width), using the “size” argument and create a Resize object.
transform = v2.Resize(size=(400, 300))
We have used the default options other than specifying the dimensions we want.
transform
Resize(size=[400, 300], interpolation=InterpolationMode.BILINEAR, antialias=True)
Now we apply the transform object to the image to resize it. We store the resized image in a variable.
image_resized = transform(image)
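The resized image can be displayed with imshow() as shown before. You can easily notice the difference in dimensions between the original and resized image. Note that the aspect ratio has changed too: the original is wider than it is tall, while the (400, 300) result is taller than it is wide.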
plt.imshow(image_resized.permute(1, 2, 0))