Tensorflow 2.0:如何从MapDataset(从TFRecord读取后)转换为可以输入到model.fit的某些结构

阿尔贝托A

我已经将训练和验证数据存储在两个单独的TFRecord文件中,在其中存储了4个值:信号A(float32形状(150,)),信号B(float32形状(150)),标签(标量int64), id(字符串)。我的阅读解析功能是:

def _parse_data_function(sample_proto):

    raw_signal_description = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'id': tf.io.FixedLenFeature([], tf.string),
    }

    for key, item in SIGNALS.items():
        raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)

    # Parse the input tf.Example proto using the dictionary above.
    return tf.io.parse_single_example(sample_proto, raw_signal_description)

SIGNALS字典映射信号名称在哪里->信号形状然后,我读取了原始数据集:

training_raw = tf.data.TFRecordDataset(<path to training>), compression_type='GZIP')
val_raw = tf.data.TFRecordDataset(<path to validation>), compression_type='GZIP')

并使用map解析值:

training_data = training_raw.map(_parse_data_function)
val_data = val_raw.map(_parse_data_function)

显示training_data的标题val_data,我得到:

<MapDataset shapes: {Signal A: (150,), Signal B: (150,), id: (), label: ()}, types: {Signal A: tf.float32, Signal B: tf.float32, id: tf.string, label: tf.int64}>

这几乎与预期的一样。我还检查了一些值的一致性,它们似乎是正确的。

现在,我要解决的问题是:如何从MapDataset(具有类似于结构的字典)到可以作为模型输入的对象?

我模型的输入是信号对(信号A,标签),尽管将来我也会使用信号B。

对我而言,最简单的方法似乎是在所需元素上创建生成器。就像是:

def data_generator(mapdataset):
    for sample in mapdataset:
        yield (sample['Signal A'], sample['label'])

但是,使用这种方法时,我失去了一些数据集的便利,例如批处理,并且也不清楚如何对的validation_data参数使用相同的方法model.fit理想情况下,我只会在映射表示形式和数据集表示形式之间进行转换,在该表示形式上对信号A张量和标签对进行迭代。

编辑:我的最终产品应该是类似于以下内容的标题:<TensorSliceDataset shapes: ((150,), ()), types: (tf.float32, tf.int64)>但不一定TensorSliceDataset

x乘0

您只需在解析函数中执行此操作即可。例如:

def _parse_data_function(sample_proto):

    raw_signal_description = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'id': tf.io.FixedLenFeature([], tf.string),
    }

    for key, item in SIGNALS.items():
        raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)

    # Parse the input tf.Example proto using the dictionary above.
    parsed = tf.io.parse_single_example(sample_proto, raw_signal_description)

    return parsed['Signal A'], parsed['label']

如果您在上使用map此函数TFRecordDataset,则将有一个元组 数据集,(signal_a, label)而不是词典数据集。您应该可以model.fit直接将其放入

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

Related 相关文章

热门标签

归档