定义numpy索引数组

克里斯珀

我对numpy索引感到困惑。假设我有一个三维数组,例如:

test_arr = np.arange(3*2*3).reshape(3,2,3)
test_arr
array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]],

       [[12, 13, 14],
        [15, 16, 17]]])

我想通过沿维度1的布尔数组对此进行索引:

dim1_idx = np.array([True, False])
test_arr[:, dim1_idx, :]

这给了我

array([[[ 0,  1,  2]],

       [[ 6,  7,  8]],

       [[12, 13, 14]]])

到目前为止一切都很好。

我的问题是,有没有一种方法可以预先定义此布尔值索引数组-就像(这不起作用):

all_dim_idx = dim1_idx[np.newaxis, :, np.newaxis]
test_arr[all_dim_idx]

我意识到这样做的原因不是因为它不能以使all_dim_idx数组适合test_arr的方式进行广播。我可以使用np.tile或np.reshape来使索引数组适合更大的数组,但是(以及随后不能推广到其他数组形状)我只是觉得可能有更好的方法。谁能启发我?

提前致谢!

hpaulj
In [600]: test_arr = np.arange(3*2*3).reshape(3,2,3)                            
In [601]: dim1_idx = np.array([True, False])                                    

定义一个索引元组:

In [602]: idx = (slice(None), dim1_idx, slice(None))                            
In [603]: test_arr[idx]                                                         
Out[603]: 
array([[[ 0,  1,  2]],

       [[ 6,  7,  8]],

       [[12, 13, 14]]])

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章