计算分类值的精确度

我们直接看下面一段代码:

def accuracy(out, yb):
    preds = torch.argmax(out, dim=1)
    return (preds == yb).float().mean()

这是一个计算分类数据精确度的函数,这段代码简直是神一样的代码,让人惊叹到下巴都掉到火星上去了!实现这么复杂的任务,居然能够如此优雅!

假设函数的参数out = [[a00, a01, a02], [a10, a11, a12]],其中0维度中每一项的数据对应这一个实体在第1维度的3个可能的权值。第1维度中,我们判定那个值最大,则认可该值对应的项目为所选项目。yb = [b0, b1]表示期望值。例如,a00, a01, a02中,如果a02最大,那么我们认为这一项对应的值为2, 如果b0也等于2,那么我们就说out的第一项判断是准确的。否则,则为错误。我们把准确的记为1, 错误的记为0, 然后统计这些1和0的平均值,并认为这个平均值为out中的准确度。如果是我来解决这个问题,我的代码可能就是这个样子了:

def accuracy(out, yb):
    preds = torch.argmax(out, dim=1)
    total = 0
    def f(i):
        if(preds[i] == yb[i]):
          total += 1
    for i in range(preds.shape[0]):
        if(preds[i] == yb[i]):
          total += 1
    return total / preds.shape[0]

我感觉自己永远需要铭记一点,任何时候我打算用for来遍历一个数组时,我都应该想想有没有可能直接使用数组运算来代替!数组思维!数组思维!目前的我在数据处理的问题上,太缺乏数组思维了!

相比而言,最上面那段代码太完美了!

使用(preds == yb)运算,得到一个bool型的数组,(preds==yb).float()将数组转化为float类型,然后(preds==yb).float().mean()对所有数据求平均值!完美!完美!完美!这段代码中对数组的操作应用得太成熟了!

那么,我应该怎么样思考,才能写出那么优雅的代码呢?以下是我的一些想法:

前面我也提到"我们把准确的记为1, 错误的记为0, 然后统计这些1和0的平均值,并认为这个平均值为out中的准确度。",那么在这里,我就应该想办法要去构建一个描述可能正确与否的一个数组。这要这么想,如果我又熟悉数组级别的条件判断应用,我也应该能写出preds == yb来!

总结起来,我与别人差距至少有以下2个点:

  1. 缺乏构建数组的思维
  2. 对数组的条件判断运算操作缺乏认知

发表评论

邮箱地址不会被公开。 必填项已用*标注