torch.mean()

Signature (from the official PyTorch documentation):

torch.mean(input, dim, keepdim=False, out=None)

Parameters:

  • input (Tensor) – the input tensor
  • dim (int or tuple of ints) – the dimension or dimensions to reduce
  • keepdim (bool, optional) – whether the output tensor keeps each reduced dimension as a size-1 dimension; defaults to False
  • out (Tensor, optional) – the output tensor
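The sketch below shows each parameter on its own. It uses a fixed `arange` input instead of random values so the results are predictable. The arguments are passed by keyword, and `keepdim=False` is compared with `keepdim=True`. The names `x`, `row_means`, and `buf` are illustrative only.

```python
import torch

# Deterministic 2x4 input: [[0, 1, 2, 3], [4, 5, 6, 7]] as float32.
# mean() needs a floating-point (or complex) tensor, hence the dtype.
x = torch.arange(8, dtype=torch.float32).reshape(2, 4)

# Reduce dim 1 (average each row); by default the reduced dim is dropped.
row_means = torch.mean(x, dim=1)
print(row_means, row_means.shape)   # tensor([1.5000, 5.5000]) torch.Size([2])

# Same reduction, but keep dim 1 as a size-1 dimension.
row_means_kept = torch.mean(x, dim=1, keepdim=True)
print(row_means_kept.shape)         # torch.Size([2, 1])

# Write the result into a preallocated tensor of the right shape via out=.
buf = torch.empty(2)
torch.mean(x, dim=1, out=buf)
print(buf)                          # tensor([1.5000, 5.5000])
```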

Example:

import torch

a = torch.randn(2, 4)  # random input, so your values will differ

print(a)

print(a.shape)

>>tensor([[-0.8934,  1.9322, -0.0035, -0.3122],
        [ 0.0951, -0.3979,  0.1591,  1.4001]])

torch.Size([2, 4])


b = torch.mean(a, 0, True)  # average over dim 0 (down each column), keep it as size 1

print(b)

print(b.shape)

>>tensor([[-0.3992,  0.7672,  0.0778,  0.5440]])

torch.Size([1, 4])
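Reducing `dim=0` averages down each column. The result should therefore match the column sums divided by the number of rows. Here is a quick sanity check of that reading. It assumes a random input like the one above, with `torch.manual_seed` added only for reproducibility, and `manual` is an illustrative name.

```python
import torch

torch.manual_seed(0)
a = torch.randn(2, 4)

# Mean over dim 0, keeping the reduced dimension: shape (1, 4).
b = torch.mean(a, 0, True)

# Equivalent manual computation: column sums divided by the number of rows.
manual = a.sum(dim=0, keepdim=True) / a.shape[0]

print(torch.allclose(b, manual, atol=1e-6))   # True
print(torch.mean(a, 0).shape)                 # torch.Size([4]) with keepdim left False
```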


c = torch.mean(a, 1, True)  # average over dim 1 (across each row), keep it as size 1

print(c)

print(c.shape)

>>tensor([[0.1808],
        [0.3141]])

torch.Size([2, 1])
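One practical reason to pass `keepdim=True`, as with `c`, is broadcasting. The `(2, 1)` result lines up with the original `(2, 4)` tensor, so it can be subtracted directly, for example to center each row. Below is a minimal sketch. The seed and the name `centered` are illustrative additions.

```python
import torch

torch.manual_seed(0)
a = torch.randn(2, 4)

# Row means with shape (2, 1); keepdim=True keeps them aligned with the rows of a.
c = torch.mean(a, 1, True)

# Broadcasting (2, 4) - (2, 1) subtracts each row's mean from that row.
centered = a - c

# Every row of the centered tensor now averages to approximately zero.
print(centered.mean(dim=1))

# Without keepdim the result has shape (2,). Broadcasting would match it against
# the last dimension of a (size 4), so a - torch.mean(a, 1) raises an error.
print(torch.mean(a, 1).shape)   # torch.Size([2])
```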


d = torch.mean(a, (0, 1), True)  # a tuple of dims reduces both at once

print(d)

print(d.shape)

>>tensor([[0.2474]])

torch.Size([1, 1])
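Reducing over every dimension with a tuple gives the same value as calling `mean()` with no `dim`. The only difference is the shape. Because each row has the same length, that value also equals the mean of the row means. Here is a short check under those assumptions, again with an added seed and illustrative names.

```python
import torch

torch.manual_seed(0)
a = torch.randn(2, 4)

# Reducing over both dimensions with keepdim=True gives shape (1, 1).
d = torch.mean(a, (0, 1), True)

# Calling mean() with no dim averages every element and returns a 0-dim tensor.
overall = a.mean()

print(d.shape, overall.shape)                            # torch.Size([1, 1]) torch.Size([])
print(torch.allclose(d.squeeze(), overall, atol=1e-6))   # True

# Every row has 4 elements, so averaging the row means gives the same value.
print(torch.allclose(a.mean(dim=1).mean(), overall, atol=1e-6))   # True
```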