Analysis of the data set format
With download=True, torchvision.datasets.FashionMNIST downloads the data set if it is not already found in the directory given by the root parameter:
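import torchvision

raw_data = torchvision.datasets.FashionMNIST(
    root='data',    # directory where the data set is stored (or downloaded to)
    train=True,     # True: the 60,000 training images; False: the 10,000 test images
    download=True,  # download the files if they are not found under root
)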
Inspecting raw_data, we see that it is a torchvision.datasets.mnist.FashionMNIST object that can be indexed like a list. Each element is a tuple with two elements: the first is a PIL.Image.Image object (the image to be classified), the second is an int (the label or target).
print(type(raw_data))        # <class 'torchvision.datasets.mnist.FashionMNIST'>
print(type(raw_data[0]))     # <class 'tuple'>
print(len(raw_data[0]))      # 2
print(type(raw_data[0][0]))  # <class 'PIL.Image.Image'>
print(type(raw_data[0][1]))  # <class 'int'>
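Because each element is an (image, label) tuple, it can be unpacked directly, and the integer label can be turned into a readable class name. The following is a minimal sketch; the helper describe_sample is hypothetical, and the hard-coded list follows the Fashion-MNIST label order (torchvision exposes the same list as torchvision.datasets.FashionMNIST.classes).

```python
# Fashion-MNIST class names, indexed by the integer target 0..9.
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def describe_sample(sample):
    """Hypothetical helper: unpack an (image, label) tuple and name the label."""
    image, label = sample  # first element: the image, second element: the int target
    return image, FASHION_MNIST_CLASSES[label]
```

With the data set loaded above, it would be called as image, name = describe_sample(raw_data[0]).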
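Passing transform=torchvision.transforms.ToTensor() converts each image into a torch.Tensor (with float values between 0.0 and 1.0); the label remains an int:
transformed_data = torchvision.datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),  # PIL image -> torch.Tensor
)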
print(type(transformed_data))        # <class 'torchvision.datasets.mnist.FashionMNIST'>
print(type(transformed_data[0]))     # <class 'tuple'>
print(len(transformed_data[0]))      # 2
print(type(transformed_data[0][0]))  # <class 'torch.Tensor'>
print(type(transformed_data[0][1]))  # <class 'int'>
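For an 8-bit grayscale image, ToTensor divides the pixel values by 255 and puts the channel axis first. The NumPy sketch below approximates that conversion; the function name to_tensor_like is hypothetical, and a synthetic uint8 array stands in for a real PIL image. With the real data, to_tensor_like(numpy.asarray(raw_data[0][0])) should agree with transformed_data[0][0].numpy() (compare with numpy.allclose).

```python
import numpy as np


def to_tensor_like(pixels: np.ndarray) -> np.ndarray:
    """Approximate ToTensor for a grayscale image:
    (H, W) uint8 in [0, 255] -> (1, H, W) float32 in [0.0, 1.0]."""
    scaled = pixels.astype(np.float32) / 255.0  # scale to [0.0, 1.0]
    return scaled[np.newaxis, :, :]             # add the leading channel axis (C = 1)


# Synthetic 28 x 28 "image" whose values cycle through 0..255.
pixels = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28)
array = to_tensor_like(pixels)
print(array.shape, array.dtype)  # (1, 28, 28) float32
```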
Each image tensor has the shape (1, 28, 28), that is, one (grayscale) channel of 28 by 28 pixels:
image_index = 123
print(transformed_data[image_index][0].shape) # torch.Size([1, 28, 28])
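The (1, 28, 28) layout is channels-first, (C, H, W). This NumPy sketch uses a zero-filled array as a stand-in for transformed_data[image_index][0] and shows how the size-1 channel axis is removed before plotting, and how stacking samples adds a leading batch axis. PyTorch's DataLoader stacks image tensors the same way by default, producing batches of shape (N, 1, 28, 28).

```python
import numpy as np

image = np.zeros((1, 28, 28), dtype=np.float32)  # stand-in for one image tensor

# A grayscale image for plotting needs shape (H, W): drop the size-1 channel axis.
squeezed = image.squeeze()
print(squeezed.shape)  # (28, 28)

# Selecting the single channel explicitly gives the same (H, W) array.
print(image[0].shape)  # (28, 28)

# Stacking four samples adds a batch axis: (N, C, H, W).
batch = np.stack([image] * 4)
print(batch.shape)  # (4, 1, 28, 28)

# Caveat: squeeze() removes every size-1 axis, including a batch axis of size 1.
single_batch = image[np.newaxis]  # shape (1, 1, 28, 28)
print(single_batch.squeeze().shape)  # (28, 28)
```

Because of that caveat, image[0] can be the more explicit choice when only the channel axis should be dropped.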
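The image can be displayed with matplotlib:
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(1, 1))  # a small figure, 1 x 1 inch
plt.imshow(
    transformed_data[image_index][0].squeeze(),  # squeeze() removes dimensions of size 1: (1, 28, 28) -> (28, 28)
    cmap='gray_r',  # reversed grayscale: high values are drawn dark on white
)
plt.axis('off')  # hide the axes and tick labels
plt.show()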
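To look at several items at once, the single-image plot can be extended to a grid. This is a minimal sketch, not part of torchvision: the function show_grid and its ncols parameter are hypothetical, and the images are assumed to be NumPy arrays of shape (1, 28, 28) or (28, 28), such as transformed_data[i][0].numpy(). The Agg backend is selected only so the sketch runs without a display, and synthetic arrays replace the downloaded data.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove to show figures on screen
import matplotlib.pyplot as plt
import numpy as np


def show_grid(images, labels, ncols=5):
    """Hypothetical helper: draw images in a grid, each titled with its integer label."""
    nrows = (len(images) + ncols - 1) // ncols  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, 1.2 * nrows), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")  # also hides the unused cells in the last row
    for ax, image, label in zip(axes.flat, images, labels):
        ax.imshow(np.squeeze(image), cmap="gray_r")  # (1, 28, 28) -> (28, 28)
        ax.set_title(str(label), fontsize=8)
    return fig


# Synthetic stand-in data so the sketch runs without downloading anything.
demo_images = [np.linspace(0.0, 1.0, 28 * 28).reshape(1, 28, 28) for _ in range(7)]
fig = show_grid(demo_images, list(range(7)))
```

With the real data, show_grid([transformed_data[i][0].numpy() for i in range(10)], [transformed_data[i][1] for i in range(10)]) would draw the first ten training images; with the Agg backend, save the result with fig.savefig(...) instead of calling plt.show().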