Search notes:

torchvision.datasets

All data sets found in torchvision.datasets inherit from torch.utils.data.Dataset, so they can be passed to an instance of torch.utils.data.DataLoader.
Every TorchVision data set accepts two arguments, transform and target_transform, which can be used to modify the sample data and the target data respectively.
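A minimal sketch of that idea (it already uses the FashionMNIST data set analysed below): because the data set is a torch.utils.data.Dataset, it can be wrapped in a DataLoader that yields shuffled mini-batches.
import torch
import torchvision

dataset = torchvision.datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

images, labels = next(iter(loader))  # one mini-batch of 64 samples
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])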

Analysis of the data set format

torchvision.datasets.FashionMNIST downloads the data set if it is not already found in the directory indicated by the root parameter:
import torchvision

raw_data = torchvision.datasets.FashionMNIST(
    root      ='data',
    train     = True,
    download  = True
)
Inspecting raw_data, we see that it is a torchvision.datasets.mnist.FashionMNIST object that can be indexed like a list.
Each element is a tuple with two items.
The first item is a PIL.Image.Image object (the data to be classified); the second item is an int (the label or target).
print(type(raw_data      ))  # <class 'torchvision.datasets.mnist.FashionMNIST'>
print(type(raw_data[0]   ))  # <class 'tuple'>
print(len (raw_data[0]   ))  # 2
print(type(raw_data[0][0]))  # <class 'PIL.Image.Image'>
print(type(raw_data[0][1]))  # <class 'int'>
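For a closer look at a single sample (a sketch building on raw_data above; the data set's classes attribute maps the integer label to its class name):
image, label = raw_data[0]
print(image.size)               # (28, 28) – PIL reports (width, height)
print(image.mode)               # 'L' (8-bit grayscale)
print(label)                    # e.g. 9
print(raw_data.classes[label])  # e.g. 'Ankle boot'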
In order to work with the data, it is typically transformed using the transform parameter. Passing torchvision.transforms.ToTensor() as the transform converts the PIL images to torch.Tensor objects:
transformed_data = torchvision.datasets.FashionMNIST(
    root      ='data',
    train     = True,
    download  = True,
    transform = torchvision.transforms.ToTensor(),
)

print(type(transformed_data      ))  # <class 'torchvision.datasets.mnist.FashionMNIST'>
print(type(transformed_data[0]   ))  # <class 'tuple'>
print(len (transformed_data[0]   ))  # 2
print(type(transformed_data[0][0]))  # <class 'torch.Tensor'>
print(type(transformed_data[0][1]))  # <class 'int'>
The shape of these image tensors is [1, 28, 28], i.e. one channel of 28 × 28 pixels:
image_index = 123
print(transformed_data[image_index][0].shape)  # torch.Size([1, 28, 28])
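Besides changing the type, ToTensor also scales the pixel values: the resulting tensor has dtype torch.float32 with values in the range [0.0, 1.0]. A quick check (reusing transformed_data and image_index from above):
image_tensor = transformed_data[image_index][0]
print(image_tensor.dtype)                                     # torch.float32
print(image_tensor.min().item(), image_tensor.max().item())  # both within [0.0, 1.0]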
With matplotlib, such an image can be plotted like so:
import matplotlib.pyplot as plt

plt.figure(figsize=(1, 1))
plt.imshow(
    transformed_data[image_index][0].squeeze(),  # squeeze() removes dimensions with size 1
    cmap='gray_r',
)
plt.axis('off')
plt.show()
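A variation on the same idea (a sketch, not part of the original snippet): plot several samples in a grid together with their class names, again using the data set's classes attribute.
fig, axes = plt.subplots(3, 3, figsize=(4, 4))
for ax, idx in zip(axes.flat, range(9)):
    image, label = transformed_data[idx]
    ax.imshow(image.squeeze(), cmap='gray_r')
    ax.set_title(transformed_data.classes[label], fontsize=6)
    ax.axis('off')
plt.tight_layout()
plt.show()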
