Tensorflow 2.0:在多输入场景中构造tf.data.Dataset输出的最佳方法

ElPapi42

我在Tensorflow上构建GAN进行图像去模糊,这是DeblurGANv2的实现。我将GAN设置为具有两个输入,一批模糊的图像和一批清晰的图像。按照这一行,我将输入设计为带有两个Key的Python字典['sharp', 'blur'],每个Key具有一个张量形状[batch_size, 512, 512, 3],这使得将模糊的图像批处理容易地馈送到生成器,然后馈入生成器和清晰图像的输出变得容易了。批判者。

根据最后的要求,我创建了一个tf.data.Dataset输出准确的输出,一个包含两个张量的字典,每个张量都有其批处理尺寸。这与我的GAN实施相得益彰,一切运行顺利。

因此请记住,我的输入不是张量,而是没有批处理维的python dict,这与稍后解释我的问题有关。

最近,我决定使用Tensorflow分配策略增加对分布式培训的支持。Tensorflow的此功能允许将训练分布在多个设备上,包括在多台机器上。某些实现具有一项功能,例如MirroredStrategy,采用输入张量,将其张成相等的部分,并将每个切片馈送到不同的设备,这意味着,如果批处理大小为16和4个GPU,则每个GPU将结束一个本地批处理的4个数据点,在此之后,汇总结果和与我的问题无关的其他内容有些神奇。

正如您已经注意到的那样,对于将张量作为输入或至少具有外部批处理维度的某种输入作为分布策略至关重要,而我拥有的是Python dict,内部字典中具有输入的批处理维度张量值。这是一个很大的问题,我当前的实现与分布式培训不兼容。

我一直在寻找解决方法,但是我无法很好地解决这个问题,也许只是将输入变成张量shape=[batch_size, 2, 512, 512, 3]并切成薄片?不确定现在才想到这大声笑。无论如何,我都觉得这很模棱两可,我无法区分这两个输入,至少不能清楚地说明字典键。编辑:此解决方案的问题是,使我的数据集转换非常昂贵,因此使数据集吞吐速度变慢,考虑到这是图像加载管道,这是重点。

也许我对分布式策略的工作方式的解释不是最严格的,如果我看不到有什么可以纠正我的想法的话。

PD:这不是错误问题或代码错误,主要是“系统设计查询”,希望这里不是违法的

松散的

您可以尝试通过以下方式映射函数,而不是使用字典作为GAN的输入,

def load_image(fileA,fileB):
    imageA = tf.io.read_file(fileA)
    imageA = tf.image.decode_jpeg(imageA, channels=3)

    imageB = tf.io.read_file(fileB)
    imageB = tf.image.decode_jpeg(imageB)
    return imageA,imageB

trainA = glob.glob('blur/*.jpg')
trainB = glob.glob('sharp/*.jpg')
train_dataset = tf.data.Dataset.from_tensor_slices((trainA,trainB))
train_dataset = train_dataset.map(load_image).batch(batch_size)

#for mirrored strategy

dist_dataset = mirrored_strategy.experimental_distribute_dataset(train_dataset)

您可以通过传递两个图像来迭代数据集并更新网络。
我希望这有帮助 !

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

如何在Tensorflow tf.data.Dataset中使用cv2图像增强功能?

来自分类Dev

我如何在Tensorflow 2 LSTM培训中屏蔽多输出?

来自分类Dev

Tensorflow 2中的fit方法中使用Dataset和ndarray有什么区别?

来自分类Dev

在TensorFlow中展平包含向量的2D张量的最佳方法?

来自分类Dev

Tensorflow:tf.contrib.data 中 dataset.map() 的类型不兼容

来自分类Dev

如何使用 Tensorflow 中的其他示例转换扩展 tf.data.Dataset

来自分类Dev

tensorflow.data.Dataset Repeat(count = None)方法如何工作

来自分类Dev

tensorflow中conv2d的顺序输出是什么?

来自分类Dev

如何在TensorFlow回归中指定2个或更多输出标签

来自分类Dev

Tensorflow:如何从CPU tf.data.Dataset(from_generator)预取GPU上的数据

来自分类Dev

使用地图功能时,Tensorflow tf.data.Dataset错误 KeyError

来自分类Dev

在Tensorflow 2中找不到Tensorflow模块,在哪里可以找到新方法的文档?

来自分类Dev

tensorflow 2:使用隐藏层输出的损失

来自分类Dev

在TensorFlow中批量访问单个梯度的最佳方法是什么?

来自分类Dev

在tensorflow中仅一次读取数据的最佳方法是?

来自分类Dev

Tensorflow 2.0:我可以更改Tf.data.Dataset上的设置-特别是`repeat()`功能吗?

来自分类Dev

Tensorflow image_dataset_from_directory用于输入数据集和输出数据集

来自分类Dev

tensorflow.keras.dataset.minst.load_data()返回的解释

来自分类Dev

TensorFlow.Data.Dataset与DatasetV1Adapter相同吗?

来自分类Dev

Tensorflow 2中的tf.contrib.layers.fully_connected()吗?

来自分类Dev

在tf 2.x中以图形模式运行TensorFlow op

来自分类Dev

Tensorflow 2中tf.variable的条件赋值

来自分类Dev

tensorflow tf.nn.conv2d 中的特征数

来自分类Dev

TensorFlow word2vec 教程输入

来自分类Dev

在TensorFlow 2.0中使用tf.Dataset进行训练

来自分类Dev

TensorFlow 2-tf.keras:如何使用tf.data API和TFRecords训练像MTCNN这样的tf.keras多任务网络

来自分类Dev

重塑输出或tf.gather(),tensorflow

来自分类Dev

重塑输出或tf.gather(),tensorflow

来自分类Dev

为什么您必须重塑Keras / Tensorflow 2中的输入?

Related 相关文章

  1. 1

    如何在Tensorflow tf.data.Dataset中使用cv2图像增强功能?

  2. 2

    我如何在Tensorflow 2 LSTM培训中屏蔽多输出?

  3. 3

    Tensorflow 2中的fit方法中使用Dataset和ndarray有什么区别?

  4. 4

    在TensorFlow中展平包含向量的2D张量的最佳方法?

  5. 5

    Tensorflow:tf.contrib.data 中 dataset.map() 的类型不兼容

  6. 6

    如何使用 Tensorflow 中的其他示例转换扩展 tf.data.Dataset

  7. 7

    tensorflow.data.Dataset Repeat(count = None)方法如何工作

  8. 8

    tensorflow中conv2d的顺序输出是什么?

  9. 9

    如何在TensorFlow回归中指定2个或更多输出标签

  10. 10

    Tensorflow:如何从CPU tf.data.Dataset(from_generator)预取GPU上的数据

  11. 11

    使用地图功能时,Tensorflow tf.data.Dataset错误 KeyError

  12. 12

    在Tensorflow 2中找不到Tensorflow模块,在哪里可以找到新方法的文档?

  13. 13

    tensorflow 2:使用隐藏层输出的损失

  14. 14

    在TensorFlow中批量访问单个梯度的最佳方法是什么?

  15. 15

    在tensorflow中仅一次读取数据的最佳方法是?

  16. 16

    Tensorflow 2.0:我可以更改Tf.data.Dataset上的设置-特别是`repeat()`功能吗?

  17. 17

    Tensorflow image_dataset_from_directory用于输入数据集和输出数据集

  18. 18

    tensorflow.keras.dataset.minst.load_data()返回的解释

  19. 19

    TensorFlow.Data.Dataset与DatasetV1Adapter相同吗?

  20. 20

    Tensorflow 2中的tf.contrib.layers.fully_connected()吗?

  21. 21

    在tf 2.x中以图形模式运行TensorFlow op

  22. 22

    Tensorflow 2中tf.variable的条件赋值

  23. 23

    tensorflow tf.nn.conv2d 中的特征数

  24. 24

    TensorFlow word2vec 教程输入

  25. 25

    在TensorFlow 2.0中使用tf.Dataset进行训练

  26. 26

    TensorFlow 2-tf.keras:如何使用tf.data API和TFRecords训练像MTCNN这样的tf.keras多任务网络

  27. 27

    重塑输出或tf.gather(),tensorflow

  28. 28

    重塑输出或tf.gather(),tensorflow

  29. 29

    为什么您必须重塑Keras / Tensorflow 2中的输入?

热门标签

归档