理解张量点乘 numpy.tensordot

Chenxiao Ma | March 7, 2018

首先理解 Scalar, Vector, Matrix 与 Tensor

  • 常数 scalar 是零维数组,常数在内存中所占的空间(以下称为“大小”)是固定的。
  • 向量 vector 是一维数组,一个向量的大小可以用一个整数表示。
  • 矩阵 matrix 是二维数组,一个矩阵的大小可以用一个有序二元整数对表示, 即(行数,列数)
  • 张量 tensor 是任意维数组(并非无穷维)。一个张量的大小是一个向量, 向量的第 n 个元素描述了张量在第 n 维上的大小,可以用三维张量的长宽高来理解。

由此可见,常数、向量和矩阵都是特殊的张量。 不过我们在交流中通常会使用外延最小的概念,也就是能不说张量就不说张量。

常数、向量和矩阵两两之间的点乘都有为人熟知的定义。 numpy.dot 可以用于进行这些计算。

>>> import numpy as np
>>> a = 1
>>> b = 2
>>> c = np.array([3, 4])
>>> d = np.array([['a', 'b'], ['c', 'd']], dtype=object)
>>> np.dot(a, b)  # 数乘 numpy 建议大家使用 `multiply`
2
>>> np.dot(b, c)
array([6, 8])
>>> np.dot(c, d)
array(['aaacccc', 'bbbdddd'], dtype=object)
>>> e = np.array([[1, 2], [3, 4]])
>>> np.dot(d, e)  # 矩阵乘 numpy 建议大家使用 `matmul`
array([['abbb', 'aabbbb'],
       ['cddd', 'ccdddd']], dtype=object)
>>> np.dot(e, d)
array([['acc', 'bdd'],
       ['aaacccc', 'bbbdddd']], dtype=object)

那么 numpy.tensordot 是在做什么呢?根据官方文档:

Given two tensors (arrays of dimension greater than or equal to one), a and b, and an array_like object containing two array_like objects, (a_axes, b_axes), sum the products of a's and b's elements (components) over the axes specified by a_axes and b_axes. The third argument can be a single non-negative integer_like scalar, N; if it is such, then the last N dimensions of a and the first N dimensions of b are summed over.

翻译一下:

给定两个张量(维度大于等于 1 的数组),ab,以及一个包含两个数组的数组, (a_axes, b_axes),把 ab 的元素的乘积沿着 a_axesb_axes 加和。 如果第三个参数是一个常数 N,那么就沿着 a 的最后 N 个 轴和 b 的前 N 个 轴加和。

那到底啥叫把元素的乘积沿着轴加和呢?官方的示例也很难懂,不如直接阅读源码。

首先看到整个函数的最后四行:

at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)
res = dot(at, bt)
return res.reshape(olda + oldb)

也就是说,其实 tensordot 无非是将两个操作数整理成了两个矩阵, 然后调用 dot 进行了一般的矩阵点乘,再把结果整理成了正确大小的张量。

那么这两个矩阵的大小是多少呢?继续阅读源码:

# nda 是 a 的维数,axes_a 是 a 中要被沿着加和的轴,notin 则是余下的轴
notin = [k for k in range(nda) if k not in axes_a]
# 把要加和的轴连在余下的轴后面
newaxes_a = notin + axes_a
# as_ 是 a.shape,axes_a 是要被加和的轴,所以 N2 是要被加和的那些轴方向上的大小的乘积
N2 = 1
for axis in axes_a:
    N2 *= as_[axis]
# 既然 N2 是新矩阵的列数,新矩阵的行数自然是 a.shape 中剩余元素的乘积
newshape_a = (int(multiply.reduce([as_[ax] for ax in notin])), N2)
# 剩余的轴方向上的大小保留在 olda 数组中
olda = [as_[axis] for axis in notin]

举一个例子,比如 a 的形状是 (5, 4, 2, 3),要加和的轴是后两轴, 那么 N2 = 2 * 3 = 6,最后得到的新矩阵的大小就是 (20, 6)

对于另一个操作数 btensordot 的处理是完全一致的, 只不过把 N2 放在了行数的位置。因为只有这样才能跟 a 做矩阵乘法。 由此也可以看到,a.shape 中与 axes_a 对应位置的元素的乘积 必须和 b.shape 中与 axes_b 对应位置元素的乘积是一样的。

>>> import numpy as np
>>> a = np.ones([5, 4, 2, 3])
>>> b = np.ones([3, 2, 6])
>>> np.tensordot(a, b, 2).shape
(5, 4, 6)
>>> np.tensordot(a, b, (2, 1)).shape
(5, 4, 3, 3, 6)
>>> np.tensordot(a, b, (3, 0)).shape
(5, 4, 2, 2, 6)
>>> np.tensordot(a, b, ((2, 3), (1, 0))).shape
(5, 4, 6)
>>> np.tensordot(a, b, ((-2, -1), (1, 0))).shape
(5, 4, 6)

然而,即使这两个乘积一样,也就是说两个矩阵的 N2 一样,可以进行 dot 运算, tensordot 还是会报错:

>>> np.tensordot(a, b, ((2, 3), (2))).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.5/dist-packages/numpy-1.14.1-py3.5-linux-x86_64.egg/numpy/core/numeric.py", line 1283, in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

再看出问题的代码:

if not equal:
    raise ValueError("shape-mismatch for sum")

前面对于 equal 的计算有好多好多行,无非是为了确认, a.shape 中与 axes_a 对应位置的元素和 b.shape 中与 axes_b 对应位置元素 一一对应相等。(当然 axes_aaxes_b 也必须等长。)

此时再回过头看官方文档中的例子就好理解了:

>>> a = np.arange(60.).reshape(3,4,5)
>>> b = np.arange(24.).reshape(4,3,2)
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
>>> c.shape
(5, 2)
>>> c
array([[ 4400.,  4730.],
       [ 4532.,  4874.],
       [ 4664.,  5018.],
       [ 4796.,  5162.],
       [ 4928.,  5306.]])
>>> # A slower but equivalent way of computing the same...
>>> d = np.zeros((5,2))
>>> for i in range(5):
...   for j in range(2):
...     for k in range(3):
...       for n in range(4):
...         d[i,j] += a[k,n,i] * b[n,k,j]

也就是说,tensordot 要求参数满足上面所说的条件,是为了保证能进行这种循环运算。 即使两个 N2 相等,可以经过重新排布之后得到大小正确的张量, tensordot 也只允许用户进行这种能用循环运算表示的点乘, 避免用户(无意中)进行了没有物理意义的运算。

最后补充一点,dot 也允许用户输入两个大于二维的张量, 此时的运算是以 a 的倒数第二个轴和 b 的倒数第一个轴作为累加轴进行 tensordot。 有人知道为什么这么规定吗?

>>> a = np.ones([5, 4, 2, 3])
>>> c = np.ones([2, 3, 6])
>>> np.dot(a, c).shape
(5, 4, 2, 2, 6)
>>> np.tensordot(a, c, (-1, -2)).shape
(5, 4, 2, 2, 6)