Today I was reviewing code for a convolutional neural network and wanted to better understand what transforms.ToTensor() and transforms.Normalize() were actually doing to the image data. These are typically used as follows:
import torchvision
from torchvision import transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
The transform can then be applied to a dataset (one caveat: MNIST images are single-channel, so the matching normalization is Normalize((0.5,), (0.5,)), with one value per channel, rather than the three-channel version above):
train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transform,  # here
                                           download=True)
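As a quick sanity check, we can pull one sample out of the dataset; the transform is applied when the sample is accessed. This is a minimal sketch, using the single-channel Normalize((0.5,), (0.5,)) appropriate for MNIST:

import torchvision
from torchvision import transforms

mnist_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])  # one value per channel for grayscale

train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=mnist_transform,
                                           download=True)

image, label = train_dataset[0]  # the transform runs here, on access
print(image.shape)  # torch.Size([1, 28, 28])
print(image.dtype)  # torch.float32
print(image.min().item(), image.max().item())  # -1.0 1.0 after normalization
print(label)  # the integer class label for this digit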
These transforms are part of the torchvision.transforms module, which provides a variety of common image transformations for preprocessing data.
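For example, a training pipeline might chain several of these together. The sketch below is illustrative; the specific transforms and sizes are arbitrary choices, not taken from the code under review:

from torchvision import transforms

augment = transforms.Compose([
    transforms.Resize((32, 32)),             # resize the PIL image to 32x32
    transforms.RandomHorizontalFlip(p=0.5),  # augmentation: flip half the time
    transforms.ToTensor(),                   # PIL image -> float32 tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])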
I've saved a test image of the digit 9 from the MNIST dataset.
The code to convert it to a NumPy array with values ranging from 0 to 255:

import requests
from io import BytesIO

import numpy as np
from PIL import Image

# URL of the image
url = 'https://static.andypai.me/mnist9.jpg'

# Download the image
response = requests.get(url)
img = Image.open(BytesIO(response.content))

# Convert the image to grayscale
img = img.convert('L')

# Convert the image to a numpy array
img_array = np.array(img)  # 28x28 matrix
Plot to verify:
import matplotlib.pyplot as plt
# Plotting the numpy array
plt.imshow(img_array, cmap='gray')
plt.axis('off')
plt.show()
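A quick check of the raw array confirms the dtype and value range before any transforms are applied:

print(img_array.shape)                   # (28, 28)
print(img_array.dtype)                   # uint8
print(img_array.min(), img_array.max())  # values within 0..255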
The ToTensor() transform converts a PIL Image or NumPy ndarray (representing an image) into a PyTorch tensor.
Here's what it does under the hood:
Changes Data Type: Converts the image data from the original data type (e.g., uint8 for PIL Images) to torch.float32.
Scales Values: Divides uint8 pixel values by 255, mapping them from [0, 255] to [0.0, 1.0].
Reshapes Dimensions: Rearranges the dimensions from (Height, Width, Channels) in a PIL image to (Channels, Height, Width), which is the standard format for PyTorch tensors.
Let's consider a 2x2 RGB image represented as a NumPy ndarray:

import numpy as np

# Shape (H, W, C) = (2, 2, 3)
image_np = np.array([
    [[255, 0, 0], [0, 255, 0]],      # Red, Green
    [[0, 0, 255], [255, 255, 255]]   # Blue, White
], dtype=np.uint8)  # Ensure the dtype is uint8
Applying ToTensor() converts the pixel values to the range [0, 1] and changes the dimension order to (C, H, W):
from torchvision import transforms
to_tensor = transforms.ToTensor()
image_tensor = to_tensor(image_np)
print(image_tensor)
Output:
tensor([[[1., 0.],
         [0., 1.]],

        [[0., 1.],
         [0., 1.]],

        [[0., 0.],
         [1., 1.]]])
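The same result can be reproduced by hand, which makes the scaling and axis reordering explicit. This is a sketch of equivalent operations, not how torchvision implements ToTensor() internally:

import torch

# uint8 (H, W, C) -> float32 (C, H, W) in [0, 1]
manual = torch.from_numpy(image_np).permute(2, 0, 1).float() / 255.0
print(torch.allclose(manual, image_tensor))  # True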
The Normalize() transform normalizes a tensor channel-wise: for each channel c, it computes output[c] = (input[c] - mean[c]) / std[c], i.e., it subtracts the mean and divides by the standard deviation:
# Normalize((mean_R, mean_G, mean_B), (std_R, std_G, std_B))
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
normalized_image = normalize(image_tensor)
print(normalized_image)
Output:
tensor([[[ 1., -1.],
         [-1.,  1.]],

        [[-1.,  1.],
         [-1.,  1.]],

        [[-1., -1.],
         [ 1.,  1.]]])
Using the image_tensor from the previous example:

Subtract the Mean: Subtract the specified mean (0.5 in this case) from each channel: 1.0 - 0.5 = 0.5, 0.0 - 0.5 = -0.5, and so on for each of the three channels.
Divide by Standard Deviation: Divide the results by the specified standard deviation (0.5 in this case): 0.5 / 0.5 = 1.0, -0.5 / 0.5 = -1.0, and so on.

The resulting tensor normalized_image contains values centered around 0, typically in the range [-1, 1].
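The same arithmetic can be verified directly against the transform's output (a small sketch):

import torch

manual_normalized = (image_tensor - 0.5) / 0.5
print(torch.allclose(manual_normalized, normalized_image))  # True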
Normalization is a crucial preprocessing step for several reasons:
Improved Numerical Stability: By scaling the features to a similar range, normalization helps prevent issues with numerical stability during training, especially for models with multiple layers.
Faster Convergence: Normalization can help the optimization algorithm converge faster to a good solution, as the gradients are better behaved.
Better Generalization: Normalization can improve the model's ability to generalize to new data by reducing the influence of features with large scales.
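In practice, rather than the generic 0.5 values used above, the per-channel mean and standard deviation are often computed from the training data itself. Here is a rough sketch of that computation for MNIST; the values it prints, roughly 0.1307 and 0.3081, are the ones commonly used for MNIST in PyTorch examples:

import torch
import torchvision
from torchvision import transforms

# Load the dataset with only ToTensor(), i.e., values in [0, 1]
dataset = torchvision.datasets.MNIST(root='./data',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Stack every image into one tensor and compute the statistics
images = torch.stack([img for img, _ in dataset])  # shape (60000, 1, 28, 28)
mean = images.mean()
std = images.std()
print(mean.item(), std.item())  # ~0.1307, ~0.3081

normalize = transforms.Normalize((mean.item(),), (std.item(),))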