我从头开始为我的tensorflow模型创建了一个数据集。我正在使用TensorFlow 2.4.0。为了加速,我决定将数据存储为.tfrecord文件类型,现在我要检查数据是否存储在.tfrecord文件中。我写了一个代码在.tfrecord文件中打印一个图像,但是出现以下错误:
imageRaw = imageFeautre['image/width'].numpy()
TypeError: 'TakeDataset' object is not subscriptable
我将自己定位于官方tensorflow教程(https://www.tensorflow.org/tutorials/load_data/tfrecord#write_the_tfrecord_file)
我可以加载数据集并读取它,内容正确。我无法在其中打印一张图像,我只想在其中打印一张图像,但我在网络上找不到解决方案。
这是我的代码:
import tensorflow as tf
import numpy as np
import IPython.display as display
tf.compat.v1.enable_eager_execution()
tfrecordPath='/home/adem/PycharmProjects/dcganAlgorithmus/dataHandler/preparedData/train.tfrecord'
rfrecordDataSet=tf.data.TFRecordDataset(tfrecordPath)
imageFeatureDescription ={
'image/width:':tf.io.FixedLenFeature([],tf.int64),
'image/height':tf.io.FixedLenFeature([], tf.int64),
'image/xmin':tf.io.FixedLenFeature([], tf.int64),
'image/ymin':tf.io.FixedLenFeature([],tf.int64),
'image/xmax':tf.io.FixedLenFeature([],tf.int64),
'image/ymin':tf.io.FixedLenFeature([],tf.int64),
}
def _parse_image_function(example_proto):
# Parse the input tf.train.Example proto using the dictionary above.
return tf.io.parse_single_example(example_proto, imageFeatureDescription)
ParsedImageDataset = rfrecordDataSet.map(_parse_image_function)
imageFeautre=ParsedImageDataset.take(1)
imageRaw = imageFeautre['image/width'].numpy()
display.display(display.Image(data=imageRaw))
您提到了打印图像,但是您的示例显示了提取宽度。这是同时显示两者的示例。
feature_description = {
'image/width': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'image/encoded': tf.io.FixedLenFeature([], tf.string, default_value=''),
...
}
tfrecord_file = 'myfile.tfrecord'
raw_dataset = tf.data.TFRecordDataset(tfrecord_file)
for raw_record in raw_dataset.take(num_records_to_plot):
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
record = tf.io.parse_single_example(raw_record, feature_description)
width = record['image/width'].numpy()
image = record['image/encoded']
# Convert image from raw bytes to numpy array
image_decoded = tf.image.decode_image(image)
image_decoded_np = image_decoded.numpy()
....
您可能还需要确保存储有效宽度。这是我创建TFRecord的方法:
from PIL import Image
im = Image.open(file_path)
image_w, image_h = im.size
with tf.io.gfile.GFile(file_path, 'rb') as fid:
encoded_jpg = fid.read()
tf_example = tf.train.Example(features=tf.train.Features(feature={
'image/width': dataset_util.int64_feature(image_w),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
...
}
writer.write(tf_example.SerializeToString())
``
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句