Tensorflow回调作为CTC的自定义指标

乔尼·范·普布鲁克

为了在训练模型(使用TensorFlow版本2.1.0编写)时产生更多度量,例如字符错误率(CER)和字错误率(WER),我创建了一个回调传递给我的fit函数模型。它可以在一个时代结束时生成CER和WER。这是我的第二选择,因为我想为此创建自定义指标,但是您只能将Keras Backend功能用于自定义指标。是否有人对如何将下面的回调转换为自定义指标有任何建议(可以在验证和/或培训数据的培训期间进行计算)?

我遇到的一些障碍是:

  • 无法将K.ctc_decode结果转换为稀疏张量
  • 您如何使用Keras后端计算距离,例如editdistance?
class Metrics(tf.keras.callbacks.Callback):
    def __init__(self, valid_data, steps):
        """
        valid_data is a TFRecordDataset with batches of 100 elements per batch, shuffled and repeated infinitely. 
        steps defines the amount of batches per epoch
        """
        super(Metrics, self).__init__()
        self.valid_data = valid_data
        self.steps = steps

    def on_train_begin(self, logs={}):
        self.cer = []
        self.wer = []

    def on_epoch_end(self, epoch, logs={}):

        imgs = []
        labels = []
        for idx, (img, label) in enumerate(self.valid_data.as_numpy_iterator()):
            if idx >= self.steps:
                break
            imgs.append(img)
            labels.extend(label)

        imgs = np.array(imgs)
        labels = np.array(labels)

        out = self.model.predict((batch for batch in imgs))        
        input_length = len(max(out, key=len))

        out = np.asarray(out)
        out_len = np.asarray([input_length for _ in range(len(out))])

        decode, log = K.ctc_decode(out,
                                    out_len,
                                    greedy=True)

        decode = [[[int(p) for p in x if p != -1] for x in y] for y in decode][0]

        for (pred, lab) in zip(decode, labels):

            dist = editdistance.eval(pred, lab)
            self.cer.append(dist / (max(len(pred), len(lab))))
            self.wer.append(not np.array_equal(pred, lab))


        print("Mean CER: {}".format(np.mean([self.cer], axis=1)[0]))
        print("Mean WER: {}".format(np.mean([self.wer], axis=1)[0]))
乔尼·范·普布鲁克

在TF 2.3.1中解决,但也应适用于2.x的早期版本。

一些说明:

  • 缺乏有关如何正确实施Tensorflow自定义指标的信息。该问题暗示要使用回调来实现指标。结果(由于对度量进行了明确的额外计算on_epoch_end),因此需要更长的时间,或者我相信。将其实现为的子类tensorflow.keras.metrics.Metric似乎是正确的方法,并且在时期进行时会产生结果(如果详细设置正确)。
  • 使用tf.edit_distance(使用稀疏张量)可以很容易地计算出CER的编辑距离,随后可以使用某些tf逻辑将其用于计算WER。
  • las,我还没有找到如何在一个度量标准中同时实现CER和WER的方法(因为它有很多重复的代码),如果有人知道如何实现,请与我联系。
  • 可以将自定义指标简单地添加到TF模型的编译中: self.model.compile(optimizer=opt, loss=loss, metrics=[CERMetric(), WERMetric()])
class CERMetric(tf.keras.metrics.Metric):
    """
    A custom Keras metric to compute the Character Error Rate
    """
    def __init__(self, name='CER_metric', **kwargs):
        super(CERMetric, self).__init__(name=name, **kwargs)
        self.cer_accumulator = self.add_weight(name="total_cer", initializer="zeros")
        self.counter = self.add_weight(name="cer_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        input_shape = K.shape(y_pred)
        input_length = tf.ones(shape=input_shape[0]) * K.cast(input_shape[1], 'float32')

        decode, log = K.ctc_decode(y_pred,
                                    input_length,
                                    greedy=True)

        decode = K.ctc_label_dense_to_sparse(decode[0], K.cast(input_length, 'int32'))
        y_true_sparse = K.ctc_label_dense_to_sparse(y_true, K.cast(input_length, 'int32'))

        decode = tf.sparse.retain(decode, tf.not_equal(decode.values, -1))
        distance = tf.edit_distance(decode, y_true_sparse, normalize=True)

        self.cer_accumulator.assign_add(tf.reduce_sum(distance))
        self.counter.assign_add(len(y_true))

    def result(self):
        return tf.math.divide_no_nan(self.cer_accumulator, self.counter)

    def reset_states(self):
        self.cer_accumulator.assign(0.0)
        self.counter.assign(0.0)


class WERMetric(tf.keras.metrics.Metric):
    """
    A custom Keras metric to compute the Word Error Rate
    """
    def __init__(self, name='WER_metric', **kwargs):
        super(WERMetric, self).__init__(name=name, **kwargs)
        self.wer_accumulator = self.add_weight(name="total_wer", initializer="zeros")
        self.counter = self.add_weight(name="wer_count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        input_shape = K.shape(y_pred)
        input_length = tf.ones(shape=input_shape[0]) * K.cast(input_shape[1], 'float32')

        decode, log = K.ctc_decode(y_pred,
                                    input_length,
                                    greedy=True)

        decode = K.ctc_label_dense_to_sparse(decode[0], K.cast(input_length, 'int32'))
        y_true_sparse = K.ctc_label_dense_to_sparse(y_true, K.cast(input_length, 'int32'))

        decode = tf.sparse.retain(decode, tf.not_equal(decode.values, -1))
        distance = tf.edit_distance(decode, y_true_sparse, normalize=True)
        
        correct_words_amount = tf.reduce_sum(tf.cast(tf.not_equal(distance, 0), tf.float32))

        self.wer_accumulator.assign_add(correct_words_amount)
        self.counter.assign_add(len(y_true))

    def result(self):
        return tf.math.divide_no_nan(self.wer_accumulator, self.counter)

    def reset_states(self):
        self.wer_accumulator.assign(0.0)
        self.counter.assign(0.0)

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

自定义代表回调

来自分类Dev

JavaScript:自定义回调函数

来自分类Dev

Nightwatch自定义命令回调

来自分类Dev

提供自定义组件的回调

来自分类Dev

Javascript 自定义回调函数

来自分类Dev

使用自定义数据作为函数参数调用lua回调

来自分类Dev

在Tensorflow Keras中创建自定义指标类

来自分类Dev

使用Tensorflow 2.1的Keras模型的自定义指标

来自分类Dev

在ruby方法上定义自定义回调

来自分类Dev

在ruby方法上定义自定义回调

来自分类Dev

自定义指标混乱

来自分类Dev

自定义指标功能

来自分类Dev

Mockito InvalidUseOfMatchersException,当尝试使用自定义回调作为参数对方法进行单元测试时

来自分类Dev

如何在从自定义回调获取的指标上使用tf.summary.scalar()生成的Tensorboard中的单个图形而不是2中绘制数据?

来自分类Dev

如何从回调访问Polymer自定义元素

来自分类Dev

自定义事件与ReactJS中的回调

来自分类Dev

Passportjs自定义回调-'密码错误'消息

来自分类Dev

如何创建执行回调的树枝自定义标签?

来自分类Dev

Express.js中的自定义回调获取

来自分类Dev

自定义javascript函数/ jQuery回调的正确语法

来自分类Dev

在翻新回调中接收自定义参数

来自分类Dev

NSMutableSet与自定义isEqual:和哈希回调

来自分类Dev

通过自定义回调函数的Woocommerce订单

来自分类Dev

自定义挂钩/回调/宏方法

来自分类Dev

在回调中反应自定义钩子

来自分类Dev

在回调函数中使用自定义钩子

来自分类Dev

Express.js中的自定义回调获取

来自分类Dev

Ajax可自定义的错误回调函数

来自分类Dev

jQuery函数自定义加扰回调