torch.Tensor: expand()

import torch

t = torch.tensor([ 1, 2, 3, 4 ])
print(t)
#
#  tensor([1, 2, 3, 4])


#
#  Expand t to 3 rows. -1 leaves the size of the
#  existing dimension (4) unchanged.
#
u = t.expand(3, -1)

#
#  expand() returns a view of t rather than a copy,
#  so changes to t are reflected in u.
#
t[1  ] = 9

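#
#  Because u is a view, writing through u changes t as
#  well. All rows of u refer to the same four elements,
#  so the change shows up in every row of u. (The PyTorch
#  documentation advises cloning an expanded tensor before
#  applying in-place operations to it.)
#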
u[2,3] = 8

print(u)
#
#  tensor([[1, 9, 3, 8],
#          [1, 9, 3, 8],
#          [1, 9, 3, 8]])

print(t)
#
#  tensor([1, 9, 3, 8])
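
The sharing seen above can be checked directly: the expanded dimension has stride 0, so every row of u reads the same memory as t. The following is a minimal sketch, assuming a recent PyTorch version. The variable c is introduced here only for illustration and shows that clone() yields an independent copy that can be written to safely.

```python
import torch

t = torch.tensor([1, 2, 3, 4])
u = t.expand(3, -1)

#
#  The expanded dimension has stride 0: stepping to the
#  next row does not move in memory, and no data is copied.
#
print(u.shape)
#
#  torch.Size([3, 4])

print(u.stride())
#
#  (0, 1)

print(u.data_ptr() == t.data_ptr())
#
#  True

#
#  clone() allocates separate memory for all 12 elements,
#  so a write to c affects neither t nor the other rows.
#
c = u.clone()
c[2, 3] = 8

print(c)
#
#  tensor([[1, 2, 3, 4],
#          [1, 2, 3, 4],
#          [1, 2, 3, 8]])

print(t)
#
#  tensor([1, 2, 3, 4])
```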

See also

view()