私argmax
は次のように定義されているPyTorchの機能に取り組んでいます:
torch.argmax(input, dim=None, keepdim=False)
例を考えてみましょう
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))
ここで、列ベクトルを検索する代わりにdim = 1を使用すると、関数は次のように行ベクトルを検索します。
print(a) :
tensor([[-1.7739, 0.8073, 0.0472, -0.4084],
[ 0.6378, 0.6575, -1.2970, -0.0625],
[ 1.7970, -1.3463, 0.9011, -0.8704],
[ 1.5639, 0.7123, 0.0385, 1.8410]])
print(torch.argmax(a, dim=1))
tensor([1, 1, 0, 3])
私の仮定が進む限り、dim = 0は行を表し、dim = 1は列を表します。
または引数がPyTorchでどのように機能するかを正しく理解する時が来ました。axis
dim
次の例は、上の図を理解したら意味があります。
| v dim-0 ---> -----> dim-1 ------> -----> --------> dim-1 | [[-1.7739, 0.8073, 0.0472, -0.4084], v [ 0.6378, 0.6575, -1.2970, -0.0625], | [ 1.7970, -1.3463, 0.9011, -0.8704], v [ 1.5639, 0.7123, 0.0385, 1.8410]] | v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])
注:( 「dimension」のdim
略)は、NumPyの「axis」に相当するトーチです。
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加