torch.mean()
torch.mean(input, dim, keepdim=False, out=None)
Returns the mean of the elements of input along the given dimension(s) dim.
Parameters:
- input (Tensor) – the input tensor
- dim (int or tuple of ints) – the dimension or dimensions to reduce
- keepdim (bool, optional) – whether the output tensor keeps dim as a size-1 dimension (True) or drops it (False). Default: False
- out (Tensor, optional) – the output tensor
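A small sketch of how the parameters above affect the result shape. It uses torch.ones only so the values are predictable; the tensor x and its shape are illustrative choices, not part of the original example.

```python
import torch

x = torch.ones(2, 3, 4)  # illustrative float tensor; only the shapes matter here

# keepdim=False (default): the reduced dimension is removed
print(torch.mean(x, dim=1).shape)                # torch.Size([2, 4])

# keepdim=True: the reduced dimension is kept with size 1
print(torch.mean(x, dim=1, keepdim=True).shape)  # torch.Size([2, 1, 4])

# dim can be a tuple of ints to reduce several dimensions at once
print(torch.mean(x, dim=(0, 2)).shape)           # torch.Size([3])

# out= writes the result into an existing tensor of the matching shape
out = torch.empty(2, 4)
torch.mean(x, dim=1, out=out)
print(out)
```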
Example:
import torch

a = torch.randn(2, 4)  # random values, so your output will differ
print(a)
print(a.shape)
>>tensor([[-0.8934,  1.9322, -0.0035, -0.3122],
          [ 0.0951, -0.3979,  0.1591,  1.4001]])
torch.Size([2, 4])
b = torch.mean(a, 0, True)  # mean over dim 0 (down each column), keepdim=True
print(b)
print(b.shape)
>>tensor([[-0.3992, 0.7672, 0.0778, 0.5440]])
torch.Size([1, 4])
c = torch.mean(a, 1, True)  # mean over dim 1 (across each row), keepdim=True
print(c)
print(c.shape)
>>tensor([[0.1808],
          [0.3141]])
torch.Size([2, 1])
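A common reason to pass keepdim=True, as in the row-mean example for c above, is that the [2, 1] result broadcasts back against the [2, 4] input. The sketch below centers each row; the seed and the names row_means, centered and broadcast_failed are illustrative assumptions.

```python
import torch

torch.manual_seed(0)  # fixed seed only so reruns are repeatable
a = torch.randn(2, 4)

row_means = torch.mean(a, 1, True)  # shape [2, 1], keepdim=True
centered = a - row_means            # [2, 4] - [2, 1] broadcasts across columns
print(centered.mean(dim=1))         # each row now averages to (about) zero

# Without keepdim the result has shape [2]; broadcasting it against [2, 4]
# compares the last dimensions (4 vs 2) and fails.
broadcast_failed = False
try:
    a - torch.mean(a, 1)
except RuntimeError:
    broadcast_failed = True
print("keepdim=False broadcast failed:", broadcast_failed)
```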
d = torch.mean(a, (0, 1), True)  # tuple dim: mean over all elements, keepdim=True
print(d)
print(d.shape)
>>tensor([[0.2474]])
torch.Size([1, 1])
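The printed means can be checked by hand. This plain-Python sketch (no torch needed) recomputes them from the printed values of a; the helper mean_2d is hypothetical and only handles 2-D nested lists. Because the printed inputs are rounded to four decimals, the results agree with the printed means only to about 1e-4.

```python
# Printed (rounded) values of `a` from the example above
a = [[-0.8934, 1.9322, -0.0035, -0.3122],
     [0.0951, -0.3979, 0.1591, 1.4001]]

def mean_2d(rows, dim, keepdim=False):
    """Hypothetical helper: mean of a 2-D nested list along dim 0 or 1."""
    if dim == 0:
        means = [sum(col) / len(col) for col in zip(*rows)]  # one mean per column
        return [means] if keepdim else means                 # [1, n] or [n]
    if dim == 1:
        means = [sum(row) / len(row) for row in rows]        # one mean per row
        return [[m] for m in means] if keepdim else means    # [m, 1] or [m]
    raise ValueError("dim must be 0 or 1 for a 2-D input")

b = mean_2d(a, 0, keepdim=True)
c = mean_2d(a, 1, keepdim=True)
overall = sum(sum(row) for row in a) / (len(a) * len(a[0]))  # dim=(0, 1)

print(b)
print(c)
print(overall)
```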