torch.nn.Flatten

import torch

# Build a float tensor of shape (2, 2, 3) from nested lists
t = torch.tensor(
    [
       [
          [  1.1 ,  2.2  ,  3.3 ],
          [  4.4 ,  5.5  ,  6.6 ]
       ],
       [
          [  7.7 ,  8.8  ,  9.9 ],
          [ 10.1 , 11.11 , 12.12]
       ]
    ])

print(t.shape) # torch.Size([2, 2, 3])

f = torch.nn.Flatten()
u = f(t)

print(u.shape) # torch.Size([2, 6])

print(u)
#  tensor([[ 1.1000,  2.2000,  3.3000,  4.4000,  5.5000,  6.6000],
#          [ 7.7000,  8.8000,  9.9000, 10.1000, 11.1100, 12.1200]])
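
The first dimension survives in the output above because torch.nn.Flatten defaults to start_dim=1 and end_dim=-1, so only the dimensions after the first (typically the batch dimension) are merged. The following is a small sketch of that behaviour; the variable names and the use of torch.arange to build the input are illustrative assumptions, not part of the original note.

```python
import torch

# Illustrative input with the same shape as t above: (2, 2, 3)
x = torch.arange(12.0).reshape(2, 2, 3)

# Default: start_dim=1, end_dim=-1, so the first dimension is kept
per_row = torch.nn.Flatten()(x)
print(per_row.shape)

# start_dim=0 merges every dimension into one
everything = torch.nn.Flatten(start_dim=0)(x)
print(everything.shape)

# The functional form with start_dim=1 gives the same result as the default module
same = torch.flatten(x, start_dim=1)
print(torch.equal(per_row, same))
```

Passing start_dim=0 is the variant to reach for when there is no batch dimension to preserve.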

See also

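Flattening a tensor is commonly done inside a torch.nn.Sequential model, for example to turn multi-dimensional input into the two-dimensional (batch, features) shape that torch.nn.Linear layers are usually fed.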
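
As a minimal sketch of how torch.nn.Flatten might sit inside torch.nn.Sequential: the toy model below is hypothetical, and its layer sizes were chosen only to match the (2, 2, 3) tensor used in this note.

```python
import torch

# Hypothetical toy model; layer sizes are illustrative assumptions
model = torch.nn.Sequential(
    torch.nn.Flatten(),      # (N, 2, 3) -> (N, 6)
    torch.nn.Linear(6, 4),   # (N, 6) -> (N, 4)
    torch.nn.ReLU(),
    torch.nn.Linear(4, 1),   # (N, 4) -> (N, 1)
)

batch = torch.zeros(2, 2, 3)  # same shape as t above
out = model(batch)
print(out.shape)
```

Without the Flatten layer, the first Linear layer would reject this input, because it expects the last dimension to have 6 features rather than 3.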
