PyTorch example: Hello World with Fashion MNIST

This PyTorch «Hello World» example is largely copied from this Microsoft Learn page.
I have slightly changed some of the code so that I could better understand what was going on, added a few assert statements (mostly to verify the dimensions of the tensors) and commented the code a bit.

Import libraries

For this example, we need the torch and torchvision libraries:
import torch
import torchvision

Download training and test data

We download the Fashion MNIST data set with torchvision:
training_data = torchvision.datasets.FashionMNIST(
    root      ='data',
    train     = True,
    download  = True,
    transform = torchvision.transforms.ToTensor(),
)

test_data = torchvision.datasets.FashionMNIST(
    root      ='data',
    train     = False,
    download  = True,
    transform = torchvision.transforms.ToTensor(),
)
See also this analysis of torchvision data sets to understand the format of training_data and test_data.

Determine the device

If we happen to have a CUDA device at our disposal, we want to use it; otherwise, we fall back to the CPU:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {device}')
We store the value of the available device in the global variable device so that we can move data to this device if needed.
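What «moving data to the device» looks like can be sketched with a throwaway tensor. The names t and t_on_device are only used for this illustration.

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

t = torch.zeros(2, 3)        # tensors are created on the CPU by default
t_on_device = t.to(device)   # returns a tensor on the target device (t itself if it is already there)

print(t.device.type, '->', t_on_device.device.type)
```

On a machine without CUDA, both tensors live on the CPU and the call to to() changes nothing.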

Define the model for the neural network

We define the neural network model. It inherits from torch.nn.Module:
class MNISTFashionNN(torch.nn.Module):
    def __init__(self):
        super(MNISTFashionNN, self).__init__()
      # create a function, self.flatten, that will convert a 28x28 input
      # matrix (the image pixels) to a 784 element vector
        self.flatten = torch.nn.Flatten()
        self.layers  = torch.nn.Sequential(
                                                          #   Create a model with
          #                  IN  , OUT ,  Activation      #  
            torch.nn.Linear(28*28, 512), torch.nn.ReLU(), # an input layer with 28x28 or 784 inputs (features/pixel) and 512 outputs
            torch.nn.Linear(  512, 512), torch.nn.ReLU(), # a hidden layer with 512 inputs and 512 outputs
            torch.nn.Linear(  512,  10), torch.nn.ReLU()  # and an output layer with 512 inputs and 10 outputs
        )

    def forward(self, x):
        assert x[0].shape == (1, 28, 28) or \
               x.shape    == (1, 28, 28), f'wrong assumption about x.shape, it is {x[0].shape}/{x.shape}'
        xflat   = self.flatten(x)
        assert xflat[0].shape == (784, ), f'wrong assumption about xflat.shape, it is {xflat[0].shape}'
        logits = self.layers(xflat)
        assert logits[0].shape == (10, ), f'wrong assumption about logits.shape, it is {logits[0].shape}'
        return logits
We also put the model on the device we determined in the previous step:
model = MNISTFashionNN().to(device)
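To see how the shapes change on the way through such a network, the following sketch rebuilds the same stack of layers outside of a class and feeds it a batch of random values. The batch is made up and only stands in for real images.

```python
import torch

flatten = torch.nn.Flatten()
layers = torch.nn.Sequential(
    torch.nn.Linear(28*28, 512), torch.nn.ReLU(),
    torch.nn.Linear(  512, 512), torch.nn.ReLU(),
    torch.nn.Linear(  512,  10), torch.nn.ReLU(),
)

batch = torch.rand(64, 1, 28, 28)     # hypothetical batch of 64 random "images"
flat = flatten(batch)                 # Flatten keeps dim 0 (the batch) and merges the rest
logits = layers(flat)

# Each Linear layer has IN*OUT weights plus OUT biases.
n_params = sum(p.numel() for p in layers.parameters())
print(flat.shape, logits.shape, n_params)
```

The parameter count follows from the layer sizes: 784*512+512 + 512*512+512 + 512*10+10 = 669706.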

Loss function and optimizer

In order to be able to train the model, we need a loss function and an optimizer.
The learning rate is a so-called hyperparameter that influences the speed with which the model learns, but also its accuracy.
learning_rate = 0.001

loss_fn   = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
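The following sketch shows, on made-up numbers, what these two objects compute. The first part compares CrossEntropyLoss with the value calculated by hand, the second part performs a single SGD step on one parameter. The logits, the parameter w and the learning rate of 0.1 are chosen for illustration only.

```python
import math
import torch

# One sample with three made-up class scores (logits); the true class is 0.
logits = torch.tensor([[2.0, 1.0, 0.1]])
target = torch.tensor([0])
loss = torch.nn.CrossEntropyLoss()(logits, target)

# The same value by hand: -log(softmax(logits)[target])
exps = [math.exp(v) for v in (2.0, 1.0, 0.1)]
manual = -math.log(exps[0] / sum(exps))

# One SGD step on a single parameter w with the loss w**2 (gradient 2*w).
w = torch.tensor([3.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)
optimizer.zero_grad()
(w ** 2).sum().backward()     # w.grad is now 2 * 3.0 = 6.0
optimizer.step()              # w <- 3.0 - 0.1 * 6.0 = 2.4

print(loss.item(), manual, w.item())
```

A larger learning rate moves w further per step, which is why this hyperparameter affects both the speed and the stability of training.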

train() and test() functions


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    assert size == 60000, f'Expected size to be 60000, but size was {size}'
    for batch_nr, (X, y) in enumerate(dataloader):
      # Iterate over each batch.
      # dataloader provides the sample data (X) and target values (y)
      # Except for the last batch, X and y contain 'batch size' elements.
      # X[0] is a tensor of 28x28 values (i. e. the image to be classified)
      # y[0] is the class label (target value) of the image.
        assert X[0].shape == (1, 28, 28),  'wrong assumption about X[0].shape'
        assert y[0].shape == (         ),  'wrong assumption about y[0].shape'

      # Move the batch to the device that holds the model.
        X, y = X.to(device), y.to(device)
      # Use the model to predict the label for each image in the batch
      # The model returns a list with 10 floats. The higher the float
      # the more likely the model believes the image to correspond to the
      # indicated label.
        pred = model(X)
        assert pred[0].shape == (10, ),  f'wrong assumption about pred[0].shape, it is {pred[0].shape}'
      # Calculate loss. The smaller the loss, the more accurate the predictions.
      # The goal of model training is to minimize the amount of the loss.
        loss = loss_fn(pred, y)
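        assert loss.shape    == (    ),  f'wrong assumption about loss.shape, it is {loss.shape}'
      # Backpropagation
      # optimizer.zero_grad() resets the gradients left over from the previous batch,
      # loss.backward() computes the gradients of the loss with respect to the parameters,
      # optimizer.step() updates the parameters of the model based on the gradients
      # computed by the backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()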

        if batch_nr % 100 == 0:
           loss, current = loss.item(), batch_nr * len(X)
           print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test(dataloader, model):
  # Test the model.
  # Some of the code is similar to that of train()
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
          # argmax(d) returns the index of the element with the maximum value
          # in dimension d.
          # correct counts the number of correct predictions.
            correct   += (pred.argmax(1) == y).type(torch.float).sum().item()
  # loss_fn returns the mean loss of a batch, so we average over the number of batches.
    test_loss /= len(dataloader)
    correct   /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
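The way test() counts correct predictions with argmax can be followed on a tiny, made-up example with 4 samples and 3 classes instead of 10:

```python
import torch

# Made-up logits for 4 samples and 3 classes, plus the true labels.
pred = torch.tensor([[0.1, 2.0, 0.3],
                     [1.5, 0.2, 0.1],
                     [0.0, 0.1, 3.0],
                     [0.9, 0.8, 0.7]])
y = torch.tensor([1, 0, 2, 2])

predicted_labels = pred.argmax(1)                 # index of the largest value in each row
hits = (predicted_labels == y).type(torch.float)  # 1.0 where the prediction is right
accuracy = hits.sum().item() / len(y)

print(predicted_labels.tolist(), accuracy)        # [1, 0, 2, 0] 0.75
```

Three of the four predictions match the labels, hence the accuracy of 75%.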

Data loaders

In order to provide training and test samples to the train() and test() functions, we create two data loaders (torch.utils.data.DataLoader). These provide the samples in batches of the chosen size to the train() and test() functions defined in the previous step.
batch_size = 64

train_dataloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size)
test_dataloader  = torch.utils.data.DataLoader(test_data    , batch_size=batch_size)
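How a data loader splits a data set into batches can be tried out without downloading anything. The sketch below uses a hypothetical stand-in data set of 150 random samples, so that the last batch is visibly smaller than the others.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data set: 150 random "images" with random labels 0..9.
images = torch.rand(150, 1, 28, 28)
labels = torch.randint(0, 10, (150,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=64)
batch_sizes = [len(X) for X, y in loader]

print(len(loader), batch_sizes)   # 3 [64, 64, 22]
```

By the same arithmetic, the 60000 training images yield 938 batches of size 64, the last of which contains only 32 images.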

Train the model

We train the model for 15 epochs:
epochs = 15

for ep in range(epochs):
    print(f'Epoch {ep+1}\n-------------------------------')
    train(train_dataloader, model, loss_fn, optimizer)
    test (test_dataloader , model)
print('End of training')

Using the model

We're finally ready to use the model:
classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

test_item = 4


x, y = test_data[test_item][0], test_data[test_item][1]
x = x.to(device)

with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
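The values returned by the model are raw scores, not probabilities. If probabilities are wanted, one possible approach (not part of the original example) is to apply softmax to the scores. The scores in the following sketch are made up and do not come from the trained model.

```python
import torch

classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
           "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# Made-up model output for one image (shape 1x10), not a real prediction.
pred = torch.tensor([[0.0, 0.0, 0.5, 0.0, 1.0, 0.0, 4.0, 0.0, 0.0, 0.0]])

probs = torch.nn.functional.softmax(pred[0], dim=0)   # non-negative values that sum to 1
best = probs.argmax(0).item()                         # same index as pred[0].argmax(0)

print(f'{classes[best]}: {probs[best].item():.1%}')
```

Softmax does not change which class has the highest score, so the predicted label is the same with or without it.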


