TensorFlow中的tf.matmul不广播

阿莱西奥(Alessio B)

我一直在努力解决一个问题。它与tf.matmul()广播有关,并且没有广播。

我在https://github.com/tensorflow/tensorflow/issues/216上发现了类似的问题,但是tf.batch_matmul()对于我的情况而言,这似乎并不是一个解决方案。

我需要将输入数据编码为4D张量:X = tf.placeholder(tf.float32, shape=(None, None, None, 100))第一个维度是批处理的大小,第二个维度是批处理中的条目数。您可以将每个条目想象成由许多对象(三维)组成。最后,每个对象由100个浮点值的向量描述。

请注意,我在第二和第三个维度中使用了“无”,因为实际尺寸可能会在每批中发生变化。但是,为简单起见,让我们用实际数字塑造张量:X = tf.placeholder(tf.float32, shape=(5, 10, 4, 100))

这些是我计算的步骤:

  1. 计算100个浮点值的每个向量的函数(例如,线性函数)W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1)) Y = tf.matmul(X, W) 问题tf.matmul()使用tf.batch_matmul()Y的预期形状进行广播或成功:(5,10,4,50)

  2. 对批次的每个条目(在每个条目的对象上)应用平均池:Y_avg = tf.reduce_mean(Y, 2)Y_avg的预期形状:(5,10,50)

我希望那tf.matmul()会支持广播。然后我发现tf.batch_matmul(),但看起来仍然不适用于我的情况(例如W需要至少具有3个维度,不清楚原因)。

顺便说一句,上面我使用了一个简单的线性函数(其权重存储在W中)。但是在我的模型中,我有一个深层的网络。因此,我遇到的更普遍的问题是为张量的每个切片自动计算一个函数。这就是为什么我希望那tf.matmul()会产生广播行为(如果是这样,也许tf.batch_matmul()甚至没有必要)。

期待您的学习!亚历西奥

您可以通过重塑X为shape来实现这一点[n, d],其中d,单个计算“实例”的维数(在您的示例中n100),在多维对象中为这些实例的数量(5*10*4=200在您的示例中)。重塑后,您可以使用tf.matmul然后重塑回所需的形状。前三个尺寸可以变化,这一点有点棘手,但是您可以tf.shape用来确定运行时的实际形状。最后,您可以执行计算的第二步,tf.reduce_mean在相应的维度上应该很简单总而言之,它看起来像这样:

X = tf.placeholder(tf.float32, shape=(None, None, None, 100))
W = tf.Variable(tf.truncated_normal([100, 50], stddev=0.1))
X_ = tf.reshape(X, [-1, 100])
Y_ = tf.matmul(X_, W)
X_shape = tf.gather(tf.shape(X), [0,1,2]) # Extract the first three dimensions
target_shape = tf.concat(0, [X_shape, [50]])
Y = tf.reshape(Y_, target_shape)
Y_avg = tf.reduce_mean(Y, 2)

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

TensorFlow中的tf.matmul不广播

来自分类Dev

TensorFlow tf.sparse_tensor_dense_matmul

来自分类Dev

MainActivity 中的 BroadcastReceiver 不接收广播意图

来自分类Dev

Tensorflow:tf.contrib.data 中 dataset.map() 的类型不兼容

来自分类Dev

Tensorflow - 转换特征的 tf.matmul 和作为批处理 matmul 的向量

来自分类Dev

Tensorflow中的广播动态尺寸

来自分类Dev

在Tensorflow 2.2中将tf.metrics.MeanIoU()与SparseCategoricalCrossEntropy损失一起使用时,尺寸不匹配错误

来自分类Dev

tensorflow:使用 tf.matmul 将稀疏矩阵与密集矩阵相乘时出错

来自分类Dev

什么是tensorflow.matmul?

来自分类Dev

严重的LSTM(Keras,TensorFlow)ValueError:形状不匹配:无法将对象广播为单个形状

来自分类Dev

AngularJS 1.5中的组件不监听来自$ scope的广播事件

来自分类Dev

BroadcastReceiver不接收广播

来自分类Dev

BroadcastReceiver不接收广播

来自分类Dev

tf.matmul无法正常工作

来自分类Dev

TensorFlow中的tf.nn.sigmoid实现

来自分类Dev

在Tensorflow中检查广播兼容性-C ++ API

来自分类Dev

Tensorflow - 使用 tf.losses.hinge_loss 导致形状不兼容错误

来自分类Dev

tensorflow batch_matmul如何工作?

来自分类Dev

使用Tensorflow执行Matmul时发生ValueError

来自分类Dev

GPU上的Tensorflow Matmul计算比CPU慢

来自分类Dev

Keras / TF的不兼容形状错误中的未知值

来自分类Dev

Wifi扫描结果不广播

来自分类Dev

python中的tensorflow版本控制不匹配

来自分类Dev

Tensorflow尺寸在CNN中不兼容

来自分类Dev

Tensorflow 中的 Logits 和标签不匹配

来自分类Dev

张量流中“MatMul”的ValueError

来自分类Dev

使用tf广播发布位置

来自分类Dev

Tensorflow 2 API:不推荐使用名称tf.get_default_graph。请改用tf.compat.v1.get_default_graph

来自分类常见问题

如何在Tensorflow中从tf.keras导入keras?

Related 相关文章

  1. 1

    TensorFlow中的tf.matmul不广播

  2. 2

    TensorFlow tf.sparse_tensor_dense_matmul

  3. 3

    MainActivity 中的 BroadcastReceiver 不接收广播意图

  4. 4

    Tensorflow:tf.contrib.data 中 dataset.map() 的类型不兼容

  5. 5

    Tensorflow - 转换特征的 tf.matmul 和作为批处理 matmul 的向量

  6. 6

    Tensorflow中的广播动态尺寸

  7. 7

    在Tensorflow 2.2中将tf.metrics.MeanIoU()与SparseCategoricalCrossEntropy损失一起使用时,尺寸不匹配错误

  8. 8

    tensorflow:使用 tf.matmul 将稀疏矩阵与密集矩阵相乘时出错

  9. 9

    什么是tensorflow.matmul?

  10. 10

    严重的LSTM(Keras,TensorFlow)ValueError:形状不匹配:无法将对象广播为单个形状

  11. 11

    AngularJS 1.5中的组件不监听来自$ scope的广播事件

  12. 12

    BroadcastReceiver不接收广播

  13. 13

    BroadcastReceiver不接收广播

  14. 14

    tf.matmul无法正常工作

  15. 15

    TensorFlow中的tf.nn.sigmoid实现

  16. 16

    在Tensorflow中检查广播兼容性-C ++ API

  17. 17

    Tensorflow - 使用 tf.losses.hinge_loss 导致形状不兼容错误

  18. 18

    tensorflow batch_matmul如何工作?

  19. 19

    使用Tensorflow执行Matmul时发生ValueError

  20. 20

    GPU上的Tensorflow Matmul计算比CPU慢

  21. 21

    Keras / TF的不兼容形状错误中的未知值

  22. 22

    Wifi扫描结果不广播

  23. 23

    python中的tensorflow版本控制不匹配

  24. 24

    Tensorflow尺寸在CNN中不兼容

  25. 25

    Tensorflow 中的 Logits 和标签不匹配

  26. 26

    张量流中“MatMul”的ValueError

  27. 27

    使用tf广播发布位置

  28. 28

    Tensorflow 2 API:不推荐使用名称tf.get_default_graph。请改用tf.compat.v1.get_default_graph

  29. 29

    如何在Tensorflow中从tf.keras导入keras?

热门标签

归档