当我argmax
在不常见的3维数组上应用时,遇到了一个问题,就像下面的代码所示:
import numpy as np
a = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 5, -5, 2],
[9, 6, 2, 8],
[3, 7, 9]
]
])
print('a, axis=0\n', np.argmax(a, axis=0))
print('a, axis=1\n', np.argmax(a, axis=1))
# print('a, axis=2\n', np.argmax(a, axis=2)) if this activated, Erros appears:numpy.AxisError: axis 2 is out of bounds for array of dimension 2.
b = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 5, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 9]
]
])
print('b, axis=0\n', np.argmax(b, axis=0))
print('b, axis=1\n', np.argmax(b, axis=1))
print('b, axis=2\n', np.argmax(b, axis=2))
结果:
a, axis=0
[0 1 1]
a, axis=1
[1 1]
b, axis=0
[[0 0 0 0]
[0 1 0 0]
[1 0 1 1]]
b, axis=1
[[1 2 0 1]
[1 2 2 2]]
b, axis=2
[[1 0 1]
[1 0 2]]
对于array的结果b
,我完全理解它的计算规则,但是对于array a
,根据错误,就我而言,它不是真正的3D数组,但是为什么函数argmax
仍然可以得到其结果呢?函数如何获得结果?
B
是Argmax知道如何导航的3d数组,没关系,但是A
由于A是的2d数组而不同lists
。
当你说
print('a, axis=0\n', np.argmax(a, axis=0))
Argmax看你A
的
[
[list11,list12,list13],
[list21,list22,list23]
]
所以它尝试做的就是从每个
max(list11, list21) , max(list12,list22) , max(list13, list23)
这里的问题是,它依赖于<,>,=
为列表实现的运算符,实现如下
比较使用字典顺序:首先比较前两个项目,如果它们不同,则确定比较的结果;如果它们相等,则比较下两个项目,依此类推,直到用尽任何一个序列。
例
[1,2,3] > [2,3,1] # False
[1111,2,3] > [2,3,1] # True
[1,2,3] > [1,3,1] # False comparing second elements
为了获得更多的直觉,让我们尝试对A
数组进行调整
a = np.array([
[
[1, 5, 5, 2],
[19, -6, 2, 18],
[-3, 7, -9, 1]
],
[
[-1, 5, -5, 2],
[9, 6, 2, 111118],
(-5555, 7, 9)
]
])
print('a, axis=0\n', np.argmax(a, axis=0))
尝试将元组与列表进行比较将引发错误,因为没有实现此比较的运算符:
TypeError: '>' not supported between instances of 'tuple' and 'list'
因此,在您的情况下,它会比较每个列表中的第一个元素,仅此而已,这会使结果看起来很尴尬,可能对您来说是错误的。
如果必须这样做,则可能需要重写列表的那些运算符才能获得所需的结果,换句话说,实现列表的自定义比较运算符以正确地比较那些列表。
您可以阅读有关python中运算符重载的更多信息,这是一个以以下示例开头的基本示例:https : //www.programiz.com/python-programming/operator-overloading
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句