keras:准确度为98%,但NN始终预测相同。可能是什么原因?

伊娃·安德烈斯(Eva Andres)

我们在训练DL模型以预测贷款得分(分类为0、1或3)时遇到以下问题。

这些步骤是:

步骤1:创建新列“得分”(输出)

conditions = [
(df2['Credit Score'] >= 0) & (df2['Credit Score'] < 1000),
(df2['Credit Score'] >= 1000) & (df2['Credit Score'] < 6000),
(df2['Credit Score'] >= 6000) & (df2['Credit Score'] <= 7000)]
choices = [0,1,2]
df2['Scoring'] = np.select(conditions, choices) 

步骤2:准备训练

array = df2.values
X = np.vstack((array[:,2:3].T, array[:,5:15].T)).T
Y = array[:,15:]
N = Y.shape[0]
T = np.zeros((N, np.max(Y)+1))
for i in range(N):
  T[i,Y[i]] = 1

x_train, x_test, y_train, y_test = train_test_split(X, T, test_size=0.2, random_state=42)

步骤3:拓扑

model = Sequential()

model.add(Dense(80, input_shape=(11,), activation='tanh'))
model.add(Dropout(0.2))
model.add(Dense(80, activation='tanh'))
model.add(Dropout(0.1))
model.add(Dense(40, activation='relu'))
model.add(Dense(3, activation='softmax'))

epochs =200
learning_rate = 0.00001
decay_rate = learning_rate / epochs
momentum = 0.002
sgd = SGD(lr=learning_rate, momentum=momentum, decay=decay_rate, nesterov=False)
ad = Adamax(lr=learning_rate)

步骤4:训练

 epochs = 200 
 batch_size = 16 

 history = model.fit(x_train, y_train, validation_data=(x_test, y_test), nb_epoch=epochs, 
 batch_size=batch_size,validation_split=0.1) 
 print ('fit done!')

指标

365/365 [==============================]-0s 60us / sample-损耗:0.0963-acc:0.9808测试装置损失:0.096准确性:0.981

准确性

第五步:预测

text1 = [1358,1555,1,3,1741,8,0,1596,1518,0,0] #scoring 0 
text2 = [1454,1601,3,11,1763,10,0,685,1044,0,0] #scoring 1 
text3 = [1209,1437,3,11,199,18,1,761,1333,1,0] #scoring 2

tmp = np.vstack(text1).T
textA = tmp.reshape(1,-1)

tmp = np.vstack(text2).T
textB = tmp.reshape(1,-1)

tmp = np.vstack(text3).T
print(tmp)
textC = tmp.reshape(1,-1)

p = model.predict(textA)
t = p[0]
print(textA,np.argmax(t))


p = model.predict(textB)
t = p[0]
print(textB,np.argmax(t))

p = model.predict(textC)
t = p[0]
print(textC,np.argmax(t))

问题:预测中的输出始终相同!!!

[9.9205679e-01 3.8634153e-04 7.5568780e-03] [[1358 1555 1 3 1741 8 0 1596 1518 0 0]] 0 ---得分0

[0.9862417 0.00205712 0.01170125] [[1454 1601 3 11 1763 10 0 685 1044 0 0]] 0 ---得分0

[9.9251783e-01 2.5733517e-04 7.2247880e-03] [[1209 1437 3 11 199 18 1 761 1333 1 0]] 0 ----得分0

发生这种现象的原因是什么?

提前致谢!

路易斯·达席尔瓦

您的数据集非常不平衡。观察它的一种好方法是:如果总是预测0可使您达到98%的准确度,那么说某物属于另一个类是非常冒险的(或者必须非常明显)。NN可能发现的使少数派类别与多数派类别(0)不同的每个模式都必须非常独特,因为即使重叠很小,不预测0的成本也很高。

考虑以下示例:您有一个包含两个类A和B的数据集,这两个类都遵循正态分布。A类的均值为1和std 1,B的均值为3和std 0.1。您有1,000,000个类别0的样本和20,000个类别1的样本,因此始终预测A会为您提供98%的准确性。B类的所有样本将位于2.743和3.257之间,置信度为99%。在这些值之间,A类预计将有29,300个样本,因此,预测B类观测到的任何观测值的成本将在29,300个A样本中产生误差,但是将所有事物预测为A的成本将仅在20,000个B样本中产生误差。 。

该示例以图形方式显示如下:

import numpy as np
import matplotlib.pyplot as plt

# Get A and B
A = np.random.normal(1, 1, 1000000)
B = np.random.normal(3, 0.1, 20000)

# Count the number of observations in A for each B
B.sort()
a = A[np.logical_and(A >= B.min(), A <= B.max())]
a = [(a<i).sum() for i in B]

# Plot results
plt.plot(B, np.arange(B.shape[0]), label='Class B')
plt.plot(B, a, label='Class A')
plt.ylabel('Count of samples')
plt.xlabel('Values')
plt.legend()
plt.show()

不平衡的例子

请参阅有关平衡数据集的这篇文章:https : //www.kdnuggets.com/2017/06/7-techniques-handle-imbalanced-data.html

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

hMatrix中断文件,可能是什么原因?

来自分类Dev

CNN - 训练准确度为 1.0,验证准确度为 1.0。预测返回 0.0

来自分类Dev

JMeter:尽管用户负载增加了,但吞吐量相同可能是什么原因?

来自分类Dev

GameCenter通知横幅有时会“压扁”-可能是什么原因造成的?

来自分类Dev

删除项目引发以下异常。可能是什么原因造成的?

来自分类Dev

Joblib因无法解释的错误而崩溃,可能是什么原因导致的?

来自分类Dev

Spring JPA无法理解此JSON。可能是什么原因?

来自分类Dev

Python Multiprocessing模块意外输出。可能是什么原因造成的?

来自分类Dev

与腻子相比,ssh / scp命令的连接速度太慢。可能是什么原因造成的?

来自分类Dev

我认为此代码容易受到攻击。可能是什么原因?

来自分类Dev

只能使用https访问网站,这可能是什么原因造成的?

来自分类Dev

FFMPEG的持续时间规范无效-这可能是什么原因?

来自分类Dev

下面显示的拖影和重影可能是什么原因?

来自分类Dev

我的笔记本电脑超级慢。可能是什么原因?

来自分类Dev

卡桑德拉(Cassandra)腐败的sstables可能是什么原因

来自分类Dev

.html,.xml和.txt文件损坏:可能是什么原因?

来自分类Dev

错误“第 1 行中的 MySQL 错误”可能是什么原因

来自分类Dev

代理不适用于 PhantomJS。这可能是什么原因?

来自分类Dev

使用 ifstream 对象设置 failbit,可能是什么原因?

来自分类Dev

遵守时班级没有显示任何内容......可能是什么原因?

来自分类Dev

Keras:使用model.train_on_batch()和model.fit()获得不同的精度。可能是什么原因以及如何解决?

来自分类Dev

为什么带有空白上传表单的表单要花这么长时间?可能是什么原因?

来自分类Dev

网页加载后跳到底部,不知道为什么,这可能是什么原因?

来自分类Dev

为什么重新加载后引导模式弹出而不是窗口加载。可能是什么原因?

来自分类Dev

改变 keras 指标=准确度

来自分类Dev

清除cakephp tmp / cache只能解决一个保存调用的问题。可能是什么原因?

来自分类Dev

“ Prism-ES2错误:GL_VERSION(major.minor)= 1.4”可能是什么原因?

来自分类Dev

java.net.SocketException:权限被拒绝:连接。可能是什么原因,以及如何避免它?

来自分类Dev

Node.js-Socketio无法与https服务器一起使用,可能是什么原因?

Related 相关文章

  1. 1

    hMatrix中断文件,可能是什么原因?

  2. 2

    CNN - 训练准确度为 1.0,验证准确度为 1.0。预测返回 0.0

  3. 3

    JMeter:尽管用户负载增加了,但吞吐量相同可能是什么原因?

  4. 4

    GameCenter通知横幅有时会“压扁”-可能是什么原因造成的?

  5. 5

    删除项目引发以下异常。可能是什么原因造成的?

  6. 6

    Joblib因无法解释的错误而崩溃,可能是什么原因导致的?

  7. 7

    Spring JPA无法理解此JSON。可能是什么原因?

  8. 8

    Python Multiprocessing模块意外输出。可能是什么原因造成的?

  9. 9

    与腻子相比,ssh / scp命令的连接速度太慢。可能是什么原因造成的?

  10. 10

    我认为此代码容易受到攻击。可能是什么原因?

  11. 11

    只能使用https访问网站,这可能是什么原因造成的?

  12. 12

    FFMPEG的持续时间规范无效-这可能是什么原因?

  13. 13

    下面显示的拖影和重影可能是什么原因?

  14. 14

    我的笔记本电脑超级慢。可能是什么原因?

  15. 15

    卡桑德拉(Cassandra)腐败的sstables可能是什么原因

  16. 16

    .html,.xml和.txt文件损坏:可能是什么原因?

  17. 17

    错误“第 1 行中的 MySQL 错误”可能是什么原因

  18. 18

    代理不适用于 PhantomJS。这可能是什么原因?

  19. 19

    使用 ifstream 对象设置 failbit,可能是什么原因?

  20. 20

    遵守时班级没有显示任何内容......可能是什么原因?

  21. 21

    Keras:使用model.train_on_batch()和model.fit()获得不同的精度。可能是什么原因以及如何解决?

  22. 22

    为什么带有空白上传表单的表单要花这么长时间?可能是什么原因?

  23. 23

    网页加载后跳到底部,不知道为什么,这可能是什么原因?

  24. 24

    为什么重新加载后引导模式弹出而不是窗口加载。可能是什么原因?

  25. 25

    改变 keras 指标=准确度

  26. 26

    清除cakephp tmp / cache只能解决一个保存调用的问题。可能是什么原因?

  27. 27

    “ Prism-ES2错误:GL_VERSION(major.minor)= 1.4”可能是什么原因?

  28. 28

    java.net.SocketException:权限被拒绝:连接。可能是什么原因,以及如何避免它?

  29. 29

    Node.js-Socketio无法与https服务器一起使用,可能是什么原因?

热门标签

归档