Understanding Array Dimensions in NumPy
I find the use of the word "dimensions" in NumPy kind of confusing. I was approaching it purely in terms of coordinates on a plane (like (x) for 1D or (x, y) for 2D). This mental model works fine in geometry, but it's slightly misleading when working with array libraries like NumPy, PyTorch, and TensorFlow.
My Original Thinking
I used to think:
- 1D → single coordinate (x)
- 2D → pairs (x, y)
- 3D → triplets (x, y, z)
But in NumPy and similar libraries, "dimension" has a slightly different meaning.
NumPy's Definition of Dimension:
- Dimension = number of indices needed to access a single element (this is arr.ndim, which always equals len(arr.shape)).
Examples:
- 1D array (shape: (3,)):
    import numpy as np
    arr = np.array([1, 2, 3])
    arr[1]  # single index → 2
- 2D array (shape: (2, 3)):
    arr = np.array([[1, 2, 3], [4, 5, 6]])
    arr[1, 2]  # two indices → 6
- 3D array (shape: (2, 2, 2)):
    arr = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    arr[1, 0, 1]  # three indices → 6
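To check the "number of indices" idea directly, here is a small sketch that prints shape and ndim for the three arrays above. The variable names (vector, matrix, cube) are only for illustration. It also shows what happens when you supply fewer indices than the array has dimensions: you get a smaller array back instead of a single element.

```python
import numpy as np

vector = np.array([1, 2, 3])
matrix = np.array([[1, 2, 3], [4, 5, 6]])
cube = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

for name, arr in [("vector", vector), ("matrix", matrix), ("cube", cube)]:
    # ndim is always the length of the shape tuple
    print(name, "shape:", arr.shape, "ndim:", arr.ndim)

# Fewer indices than ndim returns a sub-array, not a single element
row = matrix[1]          # one index on a 2D array: shape (3,)
element = matrix[1, 2]   # two indices on a 2D array: a single value
```

So a "3D" array here is not a point (x, y, z); it is a container that needs three indices to reach one number.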
Commonly Used Shape-Manipulation Operations
Some frequently used functions for adding, combining, and removing dimensions:
- Unsqueeze (np.expand_dims in NumPy):
    arr = np.array([1, 2, 3])
    arr = np.expand_dims(arr, axis=0)  # Shape (1, 3)
  PyTorch has a similar function called unsqueeze.
- Stack (np.stack): combines multiple arrays along a new dimension:
    a = np.array([1, 2, 3])
    b = np.array([4, 5, 6])
    stacked = np.stack((a, b), axis=0)  # Shape (2, 3)
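- Unstack (np.split or indexing): splitting or indexing along an axis. Integer indexing drops the axis, while np.split keeps it (each piece of this array would have shape (1, 3)):
    stacked = np.array([[1, 2, 3], [4, 5, 6]])
    first_row = stacked[0, :]  # Shape (3,)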
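The axis argument is what decides where the new dimension lands, so it helps to see the same operations with different axes side by side. This is a minimal sketch using only np.expand_dims, np.stack, and np.split; the variable names are illustrative.

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

# expand_dims inserts a size-1 dimension at the given position
as_row = np.expand_dims(a, axis=0)       # (3,) -> (1, 3)
as_column = np.expand_dims(a, axis=1)    # (3,) -> (3, 1)

# stack creates a new dimension and places the inputs along it
stacked_rows = np.stack((a, b), axis=0)  # two (3,) arrays -> (2, 3)
stacked_cols = np.stack((a, b), axis=1)  # two (3,) arrays -> (3, 2)

# split keeps the axis: every piece still has 2 dimensions
pieces = np.split(stacked_rows, 2, axis=0)  # list of two (1, 3) arrays

# integer indexing drops the axis
first_row = stacked_rows[0]              # (3,)
```

If you want split-style pieces without the leftover size-1 axis, indexing each piece (or iterating over the array) is the simpler route.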
Heuristics for Predicting Resulting Dimensions
Some simple rules to quickly guess the dimension after an operation:
- Element-wise operations (like addition or multiplication):
  - Result dimensions match the shape after broadcasting.
  - Rule: Compare shapes right-to-left. Two dimensions are compatible when they are equal or one of them is 1. Dimensions of size 1 are stretched to match, and missing leading dimensions are treated as size 1.
    arr1 shape: (3, 1)
    arr2 shape: (1, 4)
    result shape: (3, 4)
- Matrix multiplication (@ or np.dot):
  - Result shape: (m, n) @ (n, p) → (m, p)
    arr1 shape: (2, 3)
    arr2 shape: (3, 4)
    result shape: (2, 4)
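Both heuristics can be checked in a few lines. The sketch below writes the right-to-left broadcasting rule as a small pure-Python function and compares its prediction with what NumPy actually produces. Note that broadcast_shape is a hypothetical helper written for this post, not a NumPy function, and it only handles two shapes at a time.

```python
import numpy as np
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: predict the shape of an element-wise result."""
    result = []
    # Walk both shapes right-to-left; a missing dimension counts as size 1
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    for dim_a, dim_b in pairs:
        if dim_a == 1:
            result.append(dim_b)      # size 1 stretches to the other size
        elif dim_b == 1 or dim_a == dim_b:
            result.append(dim_a)
        else:
            raise ValueError(f"incompatible dimensions: {dim_a} vs {dim_b}")
    return tuple(reversed(result))


# Element-wise: (3, 1) + (1, 4) -> (3, 4)
arr1 = np.ones((3, 1))
arr2 = np.ones((1, 4))
predicted = broadcast_shape(arr1.shape, arr2.shape)
elementwise = arr1 + arr2

# Matrix multiplication: (2, 3) @ (3, 4) -> (2, 4)
left = np.ones((2, 3))
right = np.ones((3, 4))
product = left @ right

print(predicted, elementwise.shape, product.shape)
```

Shapes that fail the rule, such as (3,) and (4,), make the helper raise ValueError, which mirrors the error NumPy raises when you try to add such arrays.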
The key takeaway: NumPy dimensions = number of indices required, not just coordinate pairs. Understanding this made broadcasting finally click!