我在Tensorflow上构建GAN进行图像去模糊,这是DeblurGANv2的实现。我将GAN设置为具有两个输入,一批模糊的图像和一批清晰的图像。按照这一行,我将输入设计为带有两个Key的Python字典['sharp', 'blur']
,每个Key具有一个张量形状[batch_size, 512, 512, 3]
,这使得将模糊的图像批处理容易地馈送到生成器,然后馈入生成器和清晰图像的输出变得容易了。批判者。
根据最后的要求,我创建了一个tf.data.Dataset
输出准确的输出,一个包含两个张量的字典,每个张量都有其批处理尺寸。这与我的GAN实施相得益彰,一切运行顺利。
因此请记住,我的输入不是张量,而是没有批处理维的python dict,这与稍后解释我的问题有关。
最近,我决定使用Tensorflow分配策略增加对分布式培训的支持。Tensorflow的此功能允许将训练分布在多个设备上,包括在多台机器上。某些实现具有一项功能,例如MirroredStrategy
,采用输入张量,将其张成相等的部分,并将每个切片馈送到不同的设备,这意味着,如果批处理大小为16和4个GPU,则每个GPU将结束一个本地批处理的4个数据点,在此之后,汇总结果和与我的问题无关的其他内容有些神奇。
正如您已经注意到的那样,对于将张量作为输入或至少具有外部批处理维度的某种输入作为分布策略至关重要,而我拥有的是Python dict,内部字典中具有输入的批处理维度张量值。这是一个很大的问题,我当前的实现与分布式培训不兼容。
我一直在寻找解决方法,但是我无法很好地解决这个问题,也许只是将输入变成张量shape=[batch_size, 2, 512, 512, 3]
并切成薄片?不确定现在才想到这大声笑。无论如何,我都觉得这很模棱两可,我无法区分这两个输入,至少不能清楚地说明字典键。编辑:此解决方案的问题是,使我的数据集转换非常昂贵,因此使数据集吞吐速度变慢,考虑到这是图像加载管道,这是重点。
也许我对分布式策略的工作方式的解释不是最严格的,如果我看不到有什么可以纠正我的想法的话。
PD:这不是错误问题或代码错误,主要是“系统设计查询”,希望这里不是违法的
您可以尝试通过以下方式映射函数,而不是使用字典作为GAN的输入,
def load_image(fileA,fileB):
imageA = tf.io.read_file(fileA)
imageA = tf.image.decode_jpeg(imageA, channels=3)
imageB = tf.io.read_file(fileB)
imageB = tf.image.decode_jpeg(imageB)
return imageA,imageB
trainA = glob.glob('blur/*.jpg')
trainB = glob.glob('sharp/*.jpg')
train_dataset = tf.data.Dataset.from_tensor_slices((trainA,trainB))
train_dataset = train_dataset.map(load_image).batch(batch_size)
#for mirrored strategy
dist_dataset = mirrored_strategy.experimental_distribute_dataset(train_dataset)
您可以通过传递两个图像来迭代数据集并更新网络。
我希望这有帮助 !
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句