# Example: torch.nn.Flatten applied to a 3-D tensor.
# With its default arguments, Flatten keeps dimension 0 and merges all
# remaining dimensions, so a (2, 2, 3) input becomes (2, 6).
t = torch.tensor(
    [
        [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]],
        [[7.7, 8.8, 9.9], [10.1, 11.11, 12.12]],
    ]
)
print(t.shape)  # torch.Size([2, 2, 3])

f = torch.nn.Flatten()
u = f(t)
print(u.shape)  # torch.Size([2, 6])
print(u)
# tensor([[ 1.1000,  2.2000,  3.3000,  4.4000,  5.5000,  6.6000],
#         [ 7.7000,  8.8000,  9.9000, 10.1000, 11.1100, 12.1200]])
torch.nn.Sequential