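1. torch.tensor removes grad tracking

Building a tensor from a list of tensors with torch.tensor (or the legacy torch.Tensor constructor) copies only the values. The result is detached from the autograd graph. torch.stack, by contrast, keeps the graph.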

```python
import torch

tensor_list = [
    torch.tensor([1.0], requires_grad=True),
    torch.tensor([2.0], requires_grad=True),
    torch.tensor([3.0], requires_grad=True),
]

# torch.stack builds the new tensor inside the autograd graph
r1 = torch.stack(tensor_list)
# torch.tensor / torch.Tensor copy only the values, so the graph is lost
r2 = torch.tensor(tensor_list).reshape(3, 1)
r3 = torch.Tensor(tensor_list).reshape(3, 1)

sr1 = torch.sum(r1)
sr2 = torch.sum(r2)
sr3 = torch.sum(r3)

print(sr1)  # tensor(6., grad_fn=<SumBackward0>)
print(sr2)  # tensor(6.)
print(sr3)  # tensor(6.)
```
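The missing grad_fn matters as soon as backward() is called. The following is a minimal sketch, assuming a recent PyTorch version. It reuses the same three leaf tensors and shows that gradients reach them only through torch.stack, while the torch.tensor copy cannot be backpropagated at all. The names stacked, copied, grads and backward_failed are illustrative only.

```python
import torch

# Same leaf tensors as in the example above
tensor_list = [
    torch.tensor([1.0], requires_grad=True),
    torch.tensor([2.0], requires_grad=True),
    torch.tensor([3.0], requires_grad=True),
]

# torch.stack stays in the graph, so d(sum)/d(leaf) = 1 for every leaf
stacked = torch.stack(tensor_list)  # shape (3, 1)
stacked.sum().backward()
grads = [t.grad for t in tensor_list]
print(grads)  # [tensor([1.]), tensor([1.]), tensor([1.])]

# torch.tensor copies only the values: the result is a fresh tensor
# with requires_grad=False and no grad_fn
copied = torch.tensor(tensor_list).reshape(3, 1)
print(copied.requires_grad, copied.grad_fn)  # False None

# Calling backward() on something derived from it raises a RuntimeError
backward_failed = False
try:
    copied.sum().backward()
except RuntimeError as err:
    backward_failed = True
    print("backward failed:", err)
```

If a flat shape (3,) is preferred over (3, 1), torch.cat(tensor_list) also keeps the graph.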
