Numpy中axis参数的理解

要点:

  1. Numpy 中 axis=i 中的 i 代表的轴不固定,不同维度编排不同,可以把它看成是种由外向内的编排方式
  2. Numpy 中 axis=i 中的 i 表示数组索引的第 i+1 层
  3. axis=i 即沿着 i 轴方向进行相应操作
  4. axis=i 即对索引的第 i+1 层元素进行相应操作

道生一,一生二,二生三,三生万物。

——《道德经》

先从一维开始。

1
2
3
4
5
6
In [1]: import numpy as np

In [2]: d1 = np.array([1,2,3])

In [3]: d1
Out[3]: array([1, 2, 3])

因为只有一个维度,该数组的排列方向就是 axis=0 的方向。

numpy中axis的理解

在一维的基础上,增加一个维度,变成一个二维数组。

1
2
3
4
5
6
In [4]: d2 = np.array([[1,2,3],[4,5,6]])

In [5]: d2
Out[5]: 
array([[1, 2, 3],
       [4, 5, 6]])

对于二维数组,Numpy 把新增加的这个维度编在前面即 axis=0,而把原来的低维度编在后面即 axis=1。对初学者而言,困惑就在这里,我们在小学学习一维坐标系时,它的名称为「x轴」;学习到二维坐标系时,原来的x轴还叫 x,只是增加了一个「y轴」;之后三维是「z轴」。而 Numpy 这是要重新编号的,这也就是要点1:

1.Numpy 中 axis=i 中的 i 代表的轴不固定,不同维度编排不同,可以把它看成是种由外向内的编排方式

numpy中axis的理解2

在二维的基础上,增加一个维度,变成一个三维数组。

1
2
3
4
5
6
7
8
9
In [6]: d3 = np.array([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]])

In [7]: d3
Out[7]: 
array([[[ 1,  2,  3],
        [ 4,  5,  6]],

       [[ 7,  8,  9],
        [10, 11, 12]]])

根据要点1,新来的轴优先,设为 axis=0,于是原来的 axis=0 变成了 axis=1 了,axis=1 的变成了 axis=2。类比我们学习的空间坐标系,第三维相当于与「xy平面」垂直的「z轴」,我们在其方向上堆砌二维数据。(参考资料3的第三维的数据堆砌反了,导致后面的数据排列有问题)

numpy中axis的理解3

对于三维以上的数组同理。总之,axis 的编号是由外向内的。

为什么 axis 这样编排?可能它本身就是从数组索引层级命名的。

对于任意维度 m 的数组A,其 m 维度里的第一个数组是 A[0], 第二个是A[1]等等,它们都是最高维度「m维」内的数组,即数组索引的第一个值对应的正是最高维度的轴的元素(也就是上面说的相对低维轴的新加的轴);A[0]里的第一个数组是A[0][0],第二个是A[0][1]等等,即数组索引的第二个值对应的是次高维度的轴的元素,依此类推。于是我们有了要点2:

2.Numpy 中 axis=i 中的 i 表示数组索引的第 i+1 层

numpy中axis的理解4

例如上面的例子中,三维数组d3的第一层索引对应的就是 axis=0,d3[1]取的就是其第二个元素,是个二维数组。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
In [7]: d3
Out[7]: 
array([[[ 1,  2,  3],
        [ 4,  5,  6]],

       [[ 7,  8,  9],
        [10, 11, 12]]])

In [8]: d3[0]
Out[8]: 
array([[ 7,  8,  9],
       [10, 11, 12]])

于是,我们可以指定不同的轴对数据进行操作,在空间几何角度也就是要点3:

3.axis=i 即沿着 i 轴方向进行相应操作

对于二维数组,设 axis=0 就是沿着竖直方向进行相应操作,axis=1 是沿着水平方向进行操作。

1
2
3
4
5
6
7
In [9]: d2
Out[9]: 
array([[1, 2, 3],
       [4, 5, 6]])

In [10]: d2.sum(axis=0)
Out[10]: array([5, 7, 9])

numpy中axis的理解5

对于三维数组,设 axis=0 就是沿着垂直方向进行相应操作。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
In [11]: d3
Out[11]: 
array([[[ 1,  2,  3],
        [ 4,  5,  6]],

       [[ 7,  8,  9],
        [10, 11, 12]]])

In [12]: d3.sum(axis=0)
Out[12]: 
array([[ 8, 10, 12],
       [14, 16, 18]])

numpy中axis的理解6

设 axis=1 就是沿着竖直方向进行相应操作。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
In [13]: d3
Out[13]: 
array([[[ 1,  2,  3],
        [ 4,  5,  6]],

       [[ 7,  8,  9],
        [10, 11, 12]]])

In [14]: d3.sum(axis=1)
Out[14]: 
array([[ 5,  7,  9],
       [17, 19, 21]])

numpy中axis的理解7

注意距离原点更近的数据层排在前([ 5, 7, 9]),越上面的排在越后([17, 19, 21]),组合为[[ 5, 7, 9],[17, 19, 21]],与计算结果一致。

axis=2 同理。

指定不同的轴对数据进行操作,在解析几何角度也就是要点4:

4.axis=i 即对索引的第 i+1 层元素进行相应操作

对于三维数组,axis=0 就是对第一层索引元素进行对应操作,例如sum函数就是将第一层的所有元素的值对应相加。

1

axis=1 就是对第二层索引元素进行对应操作。

2

axis=2 就是对第三层索引元素进行对应操作。

3


参考资料:

  1. Python之NumPy(axis=0 与axis=1)区分 - caiqingfei - 博客园 (cnblogs.com)
  2. (axis=0/1/2...)的透彻理解 - 知乎 (zhihu.com)
  3. NumPy中的维度(dimension)、轴(axis)、秩(rank)的含义 - 知乎 (zhihu.com)