scikit-learn中的StratifiedKFold vs KFold

CR7

我使用此代码进行测试KFoldStratifiedKFold

import numpy as np
from sklearn.model_selection import KFold,StratifiedKFold

X = np.array([
    [1,2,3,4],
    [11,12,13,14],
    [21,22,23,24],
    [31,32,33,34],
    [41,42,43,44],
    [51,52,53,54],
    [61,62,63,64],
    [71,72,73,74]
])

y = np.array([0,0,0,0,1,1,1,1])

sfolder = StratifiedKFold(n_splits=4,random_state=0,shuffle=False)
floder = KFold(n_splits=4,random_state=0,shuffle=False)

for train, test in sfolder.split(X,y):
    print('Train: %s | test: %s' % (train, test))
print("StratifiedKFold done")

for train, test in floder.split(X,y):
    print('Train: %s | test: %s' % (train, test))
print("KFold done")

我发现StratifiedKFold可以保留标签的比例,但是KFold不能。

Train: [1 2 3 5 6 7] | test: [0 4]
Train: [0 2 3 4 6 7] | test: [1 5]
Train: [0 1 3 4 5 7] | test: [2 6]
Train: [0 1 2 4 5 6] | test: [3 7]
StratifiedKFold done
Train: [2 3 4 5 6 7] | test: [0 1]
Train: [0 1 4 5 6 7] | test: [2 3]
Train: [0 1 2 3 6 7] | test: [4 5]
Train: [0 1 2 3 4 5] | test: [6 7]
KFold done

似乎StratifiedKFold更好,所以KFold不应该使用?

什么时候使用KFold代替StratifiedKFold

杰伊·佩拉契(Jay Peerachai)

我认为您应该问“何时使用StratifiedKFold而不是KFold? ”。

您需要先了解“ KFold ”和“ Stratified ”是什么。

KFold是一个交叉验证器,可将数据集分为k个折叠。

分层是为了确保每个数据集具有给定标签的观察结果比例相同。

因此,这意味着StratifiedKFoldKFold的改进版本

因此,此问题的答案是,在处理具有不平衡类分布的分类任务时,我们应首选StratifiedKFold而不是KFold


例如

假设有一个包含16个数据点和不平衡类分布的数据集。在数据集中,有12个数据点属于A类,其余(即4个)属于B类。B类与A类的比率为1/3。如果我们使用StratifiedKFold并设置k = 4,则训练集将包括A类的3个数据点和B类的9个数据点,而测试集将包括A类的3个数据点和B类的1个数据点。

如我们所见,数据集的类分布由StratifiedKFold保存在拆分中,KFold没有考虑到这一点。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

Scikit-Learn:如何检索KFold CV的预测概率?

来自分类Dev

scikit-learn StratifiedKFold实现

来自分类Dev

python中的KFold到底能做什么?

来自分类Dev

如何在TfidfVectorizer中应用Kfold?

来自分类Dev

Scikit Learn中的交叉验证

来自分类Dev

Scikit Learn中的距离指标

来自分类Dev

StratifiedKFold vs StratifiedShuffleSplit vs StratifiedKFold + Shuffle

来自分类Dev

StratifiedKFold vs StratifiedShuffleSplit vs StratifiedKFold + Shuffle

来自分类Dev

了解Scikit Learn中的Birch集群设置

来自分类Dev

scikit-learn中的成本敏感分析

来自分类Dev

在Scikit Learn中控制Logistic回归的阈值

来自分类Dev

scikit-learn中的“ verbose”参数

来自分类Dev

scikit-learn中LogisticRegression的GridSearchCV

来自分类Dev

删除scikit Learn中的特定功能

来自分类Dev

在scikit-learn中运行Randomforest的MemoryError

来自分类Dev

Scikit-Learn KDE中的PDF估计

来自分类Dev

Scikit-Learn中的分类数据转换

来自分类Dev

在scikit-learn中运行Randomforest的MemoryError

来自分类Dev

了解Scikit Learn中的Birch集群设置

来自分类Dev

scikit-learn中的叶排序

来自分类Dev

Gridsearch 和 Kfold 中的默认 CV 有什么区别

来自分类Dev

使用 KFold 拆分来拟合模型返回“不在索引中”

来自分类Dev

.arff文件与scikit-learn?

来自分类Dev

Python scikit-learn-TypeError

来自分类Dev

输出 Scikit Learn OLS 报告

来自分类Dev

scikit-learn:最近的邻居

来自分类Dev

Scikit-learn 导入约定

来自分类Dev

scikit-learn 中 NMF 中的自定义矩阵

来自分类Dev

scikit中的RBM预测