如何使用Sk-learn提高SVM分类器的速度

Mayur kulkarni

我正在尝试构建垃圾邮件分类器,并且我已经从Internet上收集了多个数据集(例如,针对垃圾邮件/火腿邮件的SpamAssassin数据库)并构建了以下内容:

import os
import numpy
from pandas import DataFrame
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline
from sklearn.cross_validation import KFold
from sklearn.metrics import confusion_matrix, f1_score
from sklearn import svm

NEWLINE = '\n'

HAM = 'ham'
SPAM = 'spam'

SOURCES = [
    ('C:/data/spam', SPAM),
    ('C:/data/easy_ham', HAM),
    # ('C:/data/hard_ham', HAM), Commented out, since they take too long
    # ('C:/data/beck-s', HAM),
    # ('C:/data/farmer-d', HAM),
    # ('C:/data/kaminski-v', HAM),
    # ('C:/data/kitchen-l', HAM),
    # ('C:/data/lokay-m', HAM),
    # ('C:/data/williams-w3', HAM),
    # ('C:/data/BG', SPAM),
    # ('C:/data/GP', SPAM),
    # ('C:/data/SH', SPAM)
]

SKIP_FILES = {'cmds'}


def read_files(path):
    for root, dir_names, file_names in os.walk(path):
        for path in dir_names:
            read_files(os.path.join(root, path))
        for file_name in file_names:
            if file_name not in SKIP_FILES:
                file_path = os.path.join(root, file_name)
                if os.path.isfile(file_path):
                    past_header, lines = False, []
                    f = open(file_path, encoding="latin-1")
                    for line in f:
                        if past_header:
                            lines.append(line)
                        elif line == NEWLINE:
                            past_header = True
                    f.close()
                    content = NEWLINE.join(lines)
                    yield file_path, content


def build_data_frame(path, classification):
    rows = []
    index = []
    for file_name, text in read_files(path):
        rows.append({'text': text, 'class': classification})
        index.append(file_name)

    data_frame = DataFrame(rows, index=index)
    return data_frame


data = DataFrame({'text': [], 'class': []})
for path, classification in SOURCES:
    data = data.append(build_data_frame(path, classification))

data = data.reindex(numpy.random.permutation(data.index))

pipeline = Pipeline([
    ('count_vectorizer', CountVectorizer(ngram_range=(1, 2))),
    ('classifier', svm.SVC(gamma=0.001, C=100))
])

k_fold = KFold(n=len(data), n_folds=6)
scores = []
confusion = numpy.array([[0, 0], [0, 0]])
for train_indices, test_indices in k_fold:
    train_text = data.iloc[train_indices]['text'].values
    train_y = data.iloc[train_indices]['class'].values.astype(str)

    test_text = data.iloc[test_indices]['text'].values
    test_y = data.iloc[test_indices]['class'].values.astype(str)

    pipeline.fit(train_text, train_y)
    predictions = pipeline.predict(test_text)

    confusion += confusion_matrix(test_y, predictions)
    score = f1_score(test_y, predictions, pos_label=SPAM)
    scores.append(score)

print('Total emails classified:', len(data))
print('Support Vector Machine Output : ')
print('Score:' + str((sum(scores) / len(scores))*100) + '%')
print('Confusion matrix:')
print(confusion)

我注释掉的行是邮件的集合,即使我注释掉大多数数据集并选择邮件数量最少的行,它仍然运行极慢(〜15分钟),准确度约为91%。如何提高速度和准确性?

大卫·老鼠

您正在使用内核SVM。这有两个问题。

内核SVM的运行时间复杂度:执行内核SVM的第一步是构建相似度矩阵,该矩阵成为功能集。对于30,000个文档,相似度矩阵中的元素数变为90,000,000。随着矩阵的增长,这个增长很快,因为矩阵增长了文本集中文档数量的平方。可以RBFSampler在scikit-learn中使用using来解决此问题,但是由于下一个原因,您可能不想使用它。

维度:您正在使用术语和二元数作为您的功能集。这是一个非常高维的数据集。在高维空间中使用RBF内核,即使很小的差异(噪声)也可能对相似性结果产生重大影响。维度诅咒这很可能就是您的RBF内核比线性内核产生更差结果的原因。

随机梯度下降:可以使用SGD代替标准SVM,并且通过良好的参数调整,它可能会产生相似甚至更好的结果。缺点是SGD具有更多关于学习率和学习率计划的参数。另外,对于几张通行证,SGD也不理想。在那种情况下,其他算法(例如遵循正规化领导者(FTRL))会做得更好。但是,Scikit-learn无法实现FTRL。使用SGDClassifierloss="modified_huber"效果往往更好。

现在我们已经解决了问题,有几种方法可以提高性能:

tf-idf权重:使用tf-idf,较常见的单词的权重较小。这使分类器可以更好地表示更有意义的稀有词。这可以通过切换CountVectorizerTfidfVectorizer来实现

参数调整:对于线性SVM,没有gamma参数,但是C可以使用参数极大地改善结果。在的情况下SGDClassifier,也可以调整alpha和学习率参数。

集成:在多个子样本上运行模型并取平均结果通常会比单次运行生成更可靠的模型。这可以在scikit-learn中使用来完成BaggingClassifier同时结合不同的方法可以产生明显更好的结果。如果使用的方法大不相同,请考虑将堆叠模型与树模型(RandomForestClassifier或GradientBoostingClassifier)一起用作最后阶段。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

使用SVM分类器和多种算法提高准确率

来自分类Dev

在Android中使用经过训练的Scikit-learn svm分类器

来自分类Dev

SK了解如何获取LinearSVC分类器的决策概率

来自分类Dev

如何使用SIFT和SVM实现常规图像分类器

来自分类Dev

如何使用python多次训练SVM分类器?

来自分类Dev

如何提高Android模拟器速度

来自分类Dev

如何提高模拟器的速度?

来自分类Dev

如何提高Android模拟器速度

来自分类Dev

如何提高Android模拟器的速度?

来自分类Dev

如何通过使用计时器提高游戏速度?

来自分类Dev

提高 kNN 分类器的性能(速度)

来自分类Dev

特征长度如何取决于SVM分类器中的预测

来自分类Dev

如何在python中从sklearn训练多次SVM分类器?

来自分类Dev

如何在MATLAB中找到SVM分类器的分数?

来自分类Dev

如何提高SVM的性能?

来自分类Dev

如何使用scikit-learn执行集成(多分类器)分类?

来自分类Dev

使用SVM进行分类

来自分类Dev

python:如何在scikit学习分类器(SVM)等中使用POS(词性)功能

来自分类Dev

如何在Matlab SVM分类器中使用crossval()函数的输出创建混淆矩阵?

来自分类Dev

python:如何在scikit学习分类器(SVM)等中使用POS(词性)功能

来自分类Dev

tensorflow tf.contrib.learn.SVM 如何重新加载训练好的模型并使用预测对新数据进行分类

来自分类Dev

如何提高LoginAuthentication的速度?

来自分类Dev

如何提高查询速度?

来自分类Dev

如何提高查询速度?

来自分类Dev

使用 python 多处理运行 sk-learn model.predict

来自分类Dev

如何使用RallyAPIForJava提高速度

来自分类Dev

如何提高使用MySQL的PHP脚本的速度?

来自分类Dev

使用CloudFront如何提高上传速度?

来自分类Dev

如何使用RallyAPIForJava提高速度

Related 相关文章

  1. 1

    使用SVM分类器和多种算法提高准确率

  2. 2

    在Android中使用经过训练的Scikit-learn svm分类器

  3. 3

    SK了解如何获取LinearSVC分类器的决策概率

  4. 4

    如何使用SIFT和SVM实现常规图像分类器

  5. 5

    如何使用python多次训练SVM分类器?

  6. 6

    如何提高Android模拟器速度

  7. 7

    如何提高模拟器的速度?

  8. 8

    如何提高Android模拟器速度

  9. 9

    如何提高Android模拟器的速度?

  10. 10

    如何通过使用计时器提高游戏速度?

  11. 11

    提高 kNN 分类器的性能(速度)

  12. 12

    特征长度如何取决于SVM分类器中的预测

  13. 13

    如何在python中从sklearn训练多次SVM分类器?

  14. 14

    如何在MATLAB中找到SVM分类器的分数?

  15. 15

    如何提高SVM的性能?

  16. 16

    如何使用scikit-learn执行集成(多分类器)分类?

  17. 17

    使用SVM进行分类

  18. 18

    python:如何在scikit学习分类器(SVM)等中使用POS(词性)功能

  19. 19

    如何在Matlab SVM分类器中使用crossval()函数的输出创建混淆矩阵?

  20. 20

    python:如何在scikit学习分类器(SVM)等中使用POS(词性)功能

  21. 21

    tensorflow tf.contrib.learn.SVM 如何重新加载训练好的模型并使用预测对新数据进行分类

  22. 22

    如何提高LoginAuthentication的速度?

  23. 23

    如何提高查询速度?

  24. 24

    如何提高查询速度?

  25. 25

    使用 python 多处理运行 sk-learn model.predict

  26. 26

    如何使用RallyAPIForJava提高速度

  27. 27

    如何提高使用MySQL的PHP脚本的速度?

  28. 28

    使用CloudFront如何提高上传速度?

  29. 29

    如何使用RallyAPIForJava提高速度

热门标签

归档