计算矩阵的log_softmax

这个运算在人工智能网络的损失函数中能够经常看的到。假设给定一个数组[x0, x1, x2],其本质的运算是要对每一个元素得到一个值 y = log(e^xi/(e^x0 + e^x1 + e^x2))。

那么基于这个思想,按常规写法可能会是这样:

a = torch.tensor([x0, x1, x2])
b = torch.tensor(list(map(lambda x: log(e^x/(e^x0 + e^x1 + e^x2)), a)))

然后为了对其优化一下,把每一次都需要运算e^x0 + e^x1 + e^x2提取出来,可能会是这样:

a = torch.tensor([x0, x1, x2])
sum = e^x0 + e^x1 + e^x2
b = torch.tensor(list(map(lambda x: log(e^x/sum), a)))

可能对于我这种还不擅长直接应用numpy, tensor这些数组的人,走到这里可能就到尽头了。但是,但是,但是!如果能够更合理的应用numpy,或者tensor数组提供的便利,以上代码还是可以更进一步优化的:

a = torch.tensor([x0, x1, x2])
b = a - a.exp().sum(-1).log()

这里先利用对数公式,将上面的log(e^x/sum) 转化为 x - log(sum)。然后这里就变成了数组之间的线性组合了,于是乎就能直接做数组的运算。如果a是2维数组,则上述运算可以这么写:

a = torch.tensor([[x0, x1, x2], [y0, y1, y2]])
b = a - a.exp().sum(-1).log().unsqueeze(-1)

从最开始我们想要运算每一个值的 y = log(e^xi/(e^x0 + e^x1 + e^x2)),到最后这一行代码解决数组整列的运算。简直优雅到了一定程度!我认为这段优秀的代码给我们提供了2个关键的思考方式:

  1. 对于基于矩阵的运算,需要想办法利用矩阵提供的迭代式运算便利。
  2. 要想合理的应用矩阵提供的迭代式运算便利,这需要把那个被迭代的值尽可能从公式中独立出来,让它只和其他的部分做加减乘除的运算。

在上述例子中, y = log(e^xi/(e^x0 + e^x1 + e^x2))的xi是唯一迭代式的值,我们需要把它独立出来,将公式转化为y = xi - log(e^x0 + e^x1 + e^x2)。

上面那段优雅的代码里,我们除了能够学习到上面2条思考方式以外,还能得到一个数组话思考的方式。那就是使用unsqueeze增加维度来做数组运算。数组和常数是可以做加减乘除的。常数的运算会被应用在数组的最后一个轴上的每一个运算上。但是数组之间要做加减乘除运算则需要满足一些基本条件。我在这篇文献中详细描述了不同维度数组之间做运算的条件。

发表评论

邮箱地址不会被公开。 必填项已用*标注