Tensorflow 中的不平衡二元分类

妈妈福库

我正在尝试使用 tensorflow (V.1.1.0) 在输出层使用单个神经元执行二元分类。下面的代码段对应于我目前使用的损失函数和优化器(灵感来自这里的答案)。

ratio=.034 #minority/population ratio
learning_rate=0.001
class_weight=tf.constant([[ratio,1.0-ratio]],name='unbalanced_ratio') #weight vector, (lab_feed is one_hot labels)
weight_per_label=tf.transpose(tf.matmul(lab_feed,tf.transpose(class_weight)),name='weights_per_label')
xent=tf.multiply(weight_per_label,tf.nn.sigmoid_cross_entropy_with_logits(labels=lab_feed,logits=output),name='loss')
loss=tf.reduce_mean(xent)
optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate,name='GradientDescent').minimize(loss)

然而,我的问题是,由于某种原因,所有实例在历元进展后都被归类为同一类。我必须在中间停止训练还是损失函数有问题?

在此处输入图片说明

P-Gn

您误用了 sigmoid 交叉熵,就好像它是 softmax 交叉熵一样。

Sigmoid 交叉熵适用于二元分类——你的问题是二元分类,所以没关系。但是,对于每个二元分类任务,网络的输出应该只有一个通道——在你的例子中,你有一个二元分类任务,所以你的网络应该只有一个输出通道。

要平衡 sigmoid 交叉熵,您需要平衡交叉熵的每个单独部分,即来自正的部分和来自负的部分。这不能像您所做的那样在输出上完成,因为输出已经是正负部分的总和。

希望 tensorflow 中有一个函数可以做到这一点,tf.nn.weighted_cross_entropy_with_logits. 它的使用类似于tf.nn.sigmoid_cross_entropy带有一个对应于正类权重的附加参数。

您目前正在做的是在两个不同的通道上使用两个二元分类器,并且只将负样本发送到第一个,将正样本发送到第二个。这不能产生有用的东西。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

使用TensorFlow训练不平衡数据

来自分类Dev

编辑 TensorFlow 源以修复不平衡的数据

来自分类Dev

处理二进制分类中的类不平衡

来自分类Dev

初始训练是否适用于不平衡的数据集?(Tensorflow)

来自分类Dev

类别不平衡的catboost分类器?

来自分类Dev

如何在相当平衡的二元分类中修复很高的假阴性率?

来自分类Dev

二进制分类情况下数据集不平衡的问题

来自分类Dev

Keras中类不平衡多类分类器的损失函数

来自分类Dev

如何处理 Scikit.learn 管道中不平衡的 xgboost 多类分类?

来自分类Dev

iOS 8中的通话不平衡

来自分类Dev

Chrome中不平衡的CSS列

来自分类Dev

Matlab中不平衡的Anova

来自分类Dev

DEoptim中的堆栈不平衡

来自分类Dev

在sklearn中使用RandomForestClassifier进行不平衡分类

来自分类Dev

TensorFlow 中的简单分类

来自分类Dev

如何使用 tensorflow softmax_cross_entropy_with_logits 缩放和重新归一化输出以解决类不平衡问题

来自分类Dev

从不平衡二叉树中随机选择一个节点

来自分类Dev

打印不平衡的二叉树

来自分类Dev

如何从不平衡三元转换为平衡三元?

来自分类Dev

Tensorflow 中离散标签的分类

来自分类Dev

多个二元分类器组合

来自分类Dev

客户旅程的二元分类

来自分类Dev

二元分类模型不准确

来自分类Dev

KMeans的不平衡因子?

来自分类Dev

PInvoke使堆栈不平衡

来自分类Dev

如何使图像不平衡?

来自分类Dev

使用libSVM的SVM中的数据不平衡

来自分类Dev

根据不平衡数据在ggplot中创建重叠直方图

来自分类Dev

在Spark MLlib中处理不平衡的数据集