andy pai's tils

Understanding PyTorch Transforms: ToTensor() and Normalize()

Today I was reviewing code for a convolutional neural network and wanted to better understand what transforms.ToTensor() and transforms.Normalize() actually do to the image data. These are typically used as follows:

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

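The transform can then be applied to a dataset. Note that MNIST images have a single channel, so for MNIST the Normalize call needs one-element tuples, e.g. Normalize((0.5,), (0.5,)); the three-element form is for RGB images.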

train_dataset = torchvision.datasets.MNIST(root='./data',
                                           transform=transform)  # here

These transforms are part of the torchvision.transforms module, which provides a variety of common image transformations for preprocessing data.
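Compose itself does very little: it calls each transform in order and feeds the output of one into the next. A minimal pure-Python sketch of that idea follows. The compose helper and the two lambdas are hypothetical stand-ins acting on a single pixel value, not torchvision's implementation:

```python
def compose(steps):
    """Return a function that applies each callable in `steps`, in order."""
    def apply(x):
        for step in steps:
            x = step(x)
        return x
    return apply

# Toy stand-ins for ToTensor() and Normalize(), acting on one pixel value.
scale = lambda v: v / 255.0          # [0, 255] -> [0, 1]
center = lambda v: (v - 0.5) / 0.5   # [0, 1]   -> [-1, 1]

pipeline = compose([scale, center])
print(pipeline(255))  # 1.0
print(pipeline(0))    # -1.0
```

The order matters: the centering step assumes its input is already in [0, 1], which is why ToTensor() comes before Normalize() in the pipeline.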


I've saved a test image of the digit 9 from the MNIST dataset.


The following code downloads it and converts it to a NumPy array with values ranging from 0 to 255:

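from io import BytesIO

import numpy as np
import requests
from PIL import Image

# URL of the image
url = ''

# Download the image and open it with PIL
response = requests.get(url)
img = Image.open(BytesIO(response.content))

# Convert the image to grayscale
img = img.convert('L')

# Convert the image to a numpy array
img_array = np.array(img) # 28x28 matrix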

Plot to verify:

import matplotlib.pyplot as plt

# Plotting the numpy array
plt.imshow(img_array, cmap='gray')
[Figure: the MNIST digit 9 plotted with matplotlib]


The ToTensor() transform converts a PIL Image or NumPy ndarray (representing an image) into a PyTorch tensor.

Here's what it does under the hood:

  1. Changes Data Type: Converts the image data from its original data type (e.g., uint8 for PIL Images) to torch.float32.

  2. Scales Values: Divides uint8 pixel values by 255 so that the range [0, 255] becomes [0.0, 1.0].

  3. Reorders Dimensions: Rearranges the dimensions from (Height, Width, Channels), the layout of a PIL Image or NumPy image array, to (Channels, Height, Width), the layout PyTorch's convolutional layers expect.
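The steps listed above, together with the scaling to [0, 1], can be approximated in plain NumPy. This is only a sketch for a uint8 (H, W, C) array; to_tensor_sketch is a hypothetical helper, it returns a NumPy array rather than a torch tensor, and it is not torchvision's actual code:

```python
import numpy as np

def to_tensor_sketch(image_hwc):
    """Mimic ToTensor() for a uint8 (H, W, C) array; returns float32 (C, H, W)."""
    chw = image_hwc.transpose(2, 0, 1)       # (H, W, C) -> (C, H, W)
    return chw.astype(np.float32) / 255.0    # uint8 [0, 255] -> float32 [0.0, 1.0]

# A 1x2 RGB image: two pixels side by side.
pixels = np.array([[[0, 128, 255], [255, 0, 64]]], dtype=np.uint8)

result = to_tensor_sketch(pixels)
print(result.shape, result.dtype)  # (3, 1, 2) float32
```

The transpose only changes which axis comes first; no pixel values are moved between channels.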

Illustrative Calculation

Let's consider a 2x2 RGB image represented as a NumPy ndarray:

import numpy as np

image_np = np.array([
  [[255, 0, 0], [0, 255, 0]], # Red, Green
  [[0, 0, 255], [255, 255, 255]]  # Blue, White
], dtype=np.uint8)  # Ensure the dtype is uint8

Applying ToTensor() converts the pixel values to the range [0, 1] and changes the dimension order to (C, H, W):

from torchvision import transforms

to_tensor = transforms.ToTensor()
image_tensor = to_tensor(image_np)
print(image_tensor)



tensor([[[1., 0.],
         [0., 1.]],

        [[0., 1.],
         [0., 1.]],

        [[0., 0.],
         [1., 1.]]])
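A 2-D grayscale array such as the 28x28 img_array from earlier has no channel axis. ToTensor() adds one, so the result has shape (1, 28, 28). The NumPy sketch below imitates that behaviour; the gray array is a made-up stand-in for the MNIST digit, not the real image:

```python
import numpy as np

# Hypothetical stand-in for img_array: a black 28x28 image with a white square.
gray = np.zeros((28, 28), dtype=np.uint8)
gray[10:18, 10:18] = 255

# Add a leading channel axis, then scale to [0, 1] as ToTensor() would.
chw = gray[np.newaxis, :, :].astype(np.float32) / 255.0

print(chw.shape)              # (1, 28, 28)
print(chw.min(), chw.max())   # 0.0 1.0
```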


The Normalize() transform normalizes a tensor by subtracting the mean and dividing by the standard deviation for each channel:

# Normalize((mean_R, mean_G, mean_B), (std_R, std_G, std_B))
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
normalized_image = normalize(image_tensor)
print(normalized_image)



tensor([[[ 1., -1.],
         [-1.,  1.]],

        [[-1.,  1.],
         [-1.,  1.]],

        [[-1., -1.],
         [ 1.,  1.]]])

Illustrative Calculation

Using the image_tensor from the previous example:

  1. Subtract the Mean: Subtract the specified mean (0.5 in this case) from each channel:

    • Red Channel: 1.0 - 0.5 = 0.5, 0.0 - 0.5 = -0.5, etc.
    • Green Channel: 0.0 - 0.5 = -0.5, 1.0 - 0.5 = 0.5, etc.
    • Blue Channel: 0.0 - 0.5 = -0.5, 1.0 - 0.5 = 0.5, etc.
  2. Divide by Standard Deviation: Divide the results by the specified standard deviation (0.5 in this case):

    • Red Channel: 0.5 / 0.5 = 1.0, -0.5 / 0.5 = -1.0, etc.
    • Green Channel: -0.5 / 0.5 = -1.0, 0.5 / 0.5 = 1.0, etc.
    • Blue Channel: -0.5 / 0.5 = -1.0, 0.5 / 0.5 = 1.0, etc.
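The two steps above amount to one broadcasted expression, (x - mean) / std, applied per channel. Here is a NumPy sketch of that arithmetic; normalize_sketch is a hypothetical helper that mirrors the calculation, not torchvision's implementation:

```python
import numpy as np

def normalize_sketch(chw, mean, std):
    """Apply (x - mean) / std per channel to a (C, H, W) array."""
    # Reshape to (C, 1, 1) so each channel's statistic broadcasts over H and W.
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (chw - mean) / std

# The same values as image_tensor from the ToTensor() example.
image_chw = np.array([
    [[1., 0.], [0., 1.]],   # red channel
    [[0., 1.], [0., 1.]],   # green channel
    [[0., 0.], [1., 1.]],   # blue channel
], dtype=np.float32)

out = normalize_sketch(image_chw, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(out.min(), out.max())  # -1.0 1.0
```

Because each channel gets its own mean and standard deviation, passing different values per channel (for example, statistics measured on a dataset) works the same way.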

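With a mean of 0.5 and a standard deviation of 0.5, every input value x in [0, 1] maps to (x - 0.5) / 0.5, so the resulting tensor normalized_image lies in the range [-1, 1] and is centered on 0. Other mean and standard deviation values give a different output range.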

Why Normalize?

Normalization is a crucial preprocessing step for several reasons:
