Tensorflow 2.0:相当于numpy。take_along_axis

朱塞佩·安哥拉(Giuseppe Angora)

这是我的问题:我实现了一个简单的函数,该函数返回组织为矩阵的信号峰值。

@tf.function
def get_peaks(X, X_err):
    prominence = 0.9
    # X shape (B, N, 1)
    max_pooled = tf.nn.pool(X, window_shape=(20, ), pooling_type='MAX', padding='SAME') 
    maxima = tf.equal(X, max_pooled) #shape (1, N, 1)
    maxima = tf.cast(maxima, tf.float32)
    peaks = tf.squeeze(X * maxima) #shape (1, N, 1) ==> shape (N,)
    peaks_err = X_err * tf.squeeze(maxima)
    peaks_idxs, idxs = tf.math.top_k(peaks, k=2)
    return peaks_idxs, idxs 

如您所见,输入具有shape (B, N, 1),即批处理样本,每个样本都是N个元素的一维向量。返回的idxs也是正确的peaks_idxs,它们的形状为(B,2),即批次中每个样品的两个最大值的位置(和峰值)。

问题是我也想拿peak_err对应的idxs随着numpy我将使用:

np.take_along_axis(peaks_err, idxs, axis=1)

实际上返回正确的shape矩阵(B, 2)如何使用tf做同样的事情?我实际上已经尝试使用tf.gather

tf.gather(peaks_err, idxs, axis=1)

但它不起作用,对于形状(B,B,2)和很多零,结果不正确。你知道我该怎么解决吗?谢谢!

朱塞佩·安哥拉(Giuseppe Angora)

我解决了添加三行的问题:

@tf.function
def get_local_maxima3(XC, SXC):
    prominence = 0.9
    # x shape (1, N, 1)
    max_pooled = tf.nn.pool(XC, window_shape=(20, ), pooling_type='MAX', padding='SAME') 
    maxima = tf.equal(XC, max_pooled) #shape (1, N, 1)
    maxima = tf.cast(maxima, tf.float32)
    peaks = tf.squeeze(XC * maxima) #shape (1, N, 1) ==> shape (N,)
    peaks_err = SXC * tf.squeeze(maxima)
    #maxima = tf.where(tf.greater(peaks, prominence)) # shape (N,)
    peaks, idxs = tf.math.top_k(peaks, k=2)

    idxs_shape = tf.shape(idxs)
    grid = tf.meshgrid(*(tf.range(idxs_shape[i]) for i in range(idxs.shape.ndims)), indexing='ij')
    index_full = tf.stack(grid[:-1] + [idxs], axis=-1)
    peaks_err = tf.gather_nd(peaks_err, index_full)
    return peaks, peaks_err

有用!如果您找到/有一个更聪明/更快的解决方案,我将不胜感激。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

tensorflow 相当于 pytorch ReplicationPad2d

来自分类Dev

Tensorflow:相当于tf批次的Numpy和重塑

来自分类Dev

Tensorflow,相当于Theano的pydotprint?

来自分类Dev

相当于Tensorflow线性层的PyTorch

来自分类Dev

相当于Tensorflow损失函数的PyTorch

来自分类Dev

Tensorflow相当于Theano的dimshuffle

来自分类Dev

相当于np.in1d的TensorFlow

来自分类Dev

ng2:相当于require

来自分类Dev

2D numpy数组搜索(相当于Matlab的相交“行”选项)

来自分类常见问题

相当于$ document.ready()的Angular2

来自分类Dev

Angular2中的工厂相当于什么?

来自分类Dev

相当于MATLAB ind2sub的Python

来自分类Dev

H2相当于Oracle的用户

来自分类Dev

roxygen2相当于python

来自分类Dev

相当于GtkSourceView2的Python3模块

来自分类Dev

有相当于word2vec的图像吗?

来自分类Dev

CameraX相当于Camera2的CaptureRequest

来自分类Dev

H2相当于Oracle的用户

来自分类Dev

相当于tcp_retries2的Aix

来自分类Dev

相当于Python 2中心的Python 3

来自分类Dev

Angular2中的工厂相当于什么?

来自分类Dev

相当于 ionic 2 上的 style.css

来自分类Dev

相当于 Mage::helper('core')-> 的 magento 2 是什么?

来自分类Dev

Tensorflow numpy 到 tensorflow

来自分类Dev

Tensorflow 2:嵌套TensorArray

来自分类Dev

numpy相当于熊猫

来自分类Dev

Scala:相当于numpy.where()[0]

来自分类Dev

HK2相当于@Provides的Guice for Jersey 2

来自分类Dev

Tensorflow:numpy.take 的模拟?