Data not persistent in scikit-learn transformers

Bob

I'd like to pass additional data to a transformer in scikit-learn:

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
import numpy as np

class myTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, my_np_array):
        self.data = my_np_array
        print(self.data)

    def transform(self, X):
        return X

    def fit(self, X, y=None):
        return self

data = np.random.rand(20,20)
data2 = np.random.rand(6,6)
y = np.array([1, 2, 3, 1, 2, 3, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 3, 3, 3, 3])

pipe = Pipeline(steps=[('myt', myTransformer(data2)), ('randforest', RandomForestClassifier())])
params = {"randforest__n_estimators": [100, 1000]}
estimators = GridSearchCV(pipe, param_grid=params, verbose=True)
estimators.fit(data, y)

However, when the transformer is used in a scikit-learn pipeline, the array seems to disappear.

I'm getting None from the print inside the __init__ method. How do I fix it?

lejlot

This happens because sklearn handles estimators in a very specific way. For operations such as grid searching it does not reuse your object; it creates new instances of your class and passes parameters to their constructor. It does this with its own clone function (defined in sklearn/base.py), which takes your estimator's class, reads its parameters via get_params, and passes them to the constructor of your class:

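# Core of sklearn.base.clone: rebuild the estimator from its constructor parameters
klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
for name, param in new_object_params.items():
    # Parameters are cloned recursively; non-estimators such as arrays are deep-copied
    new_object_params[name] = clone(param, safe=False)
new_object = klass(**new_object_params)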
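To see this step on its own, here is a small sketch (assuming Python 3 and a recent scikit-learn) that clones a fitted built-in estimator. The copy is a new, unfitted object carrying only the parameters that get_params reported, which is why anything get_params does not report is lost.

```python
import numpy as np
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier

X = np.random.rand(20, 20)
y = np.array([1, 2, 3, 1, 2, 3, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 3, 3, 3, 3])

rf = RandomForestClassifier(n_estimators=10, random_state=0)
rf.fit(X, y)

# clone() rebuilds the estimator from get_params(deep=False) only
rf_clone = clone(rf)

print(rf_clone is rf)                    # False: a brand-new instance
print(rf_clone.n_estimators)             # 10: constructor parameters are copied
print(hasattr(rf_clone, "estimators_"))  # False: fitted state is not carried over
```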

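The default BaseEstimator.get_params finds each parameter value by looking up an attribute with the same name as the corresponding __init__ argument. Your __init__ stores the array as self.data, so there is no attribute called my_np_array. Older scikit-learn versions returned None for such a parameter, so the clone was constructed as myTransformer(my_np_array=None), which is the None you see printed (scikit-learn 0.24 and later raise an AttributeError instead).

To support cloning, get_params(deep=False) has to return a dictionary mapping each constructor argument name to its current value, because that dictionary is what gets passed back to the constructor. One way is to override get_params: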

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV

class myTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, my_np_array):
        self.data = my_np_array
        print(self.data)

    def transform(self, X):
        return X

    def fit(self, X, y=None):
        return self

    def get_params(self, deep=True):
        # clone() passes this dict back to __init__, so the key must match the
        # constructor argument name; deep is unused (there are no nested estimators)
        return {'my_np_array': self.data}

With this override, the pipeline and grid search from the question work as expected.
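Overriding get_params is not the only option. scikit-learn's developer guide asks custom estimators to store each __init__ argument, unchanged, in an attribute with the same name; the inherited get_params and set_params then work without any override. (With only get_params overridden, the inherited set_params would write to an attribute named my_np_array that the override never reads, so set_params(my_np_array=...) would have no visible effect.) A minimal sketch of that variant, assuming Python 3 and a recent scikit-learn; the print in __init__ is dropped, and the grid values are reduced from the question's [100, 1000] only to keep it fast:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline


# Mixins go to the left of BaseEstimator, as scikit-learn's docs recommend
class myTransformer(TransformerMixin, BaseEstimator):
    def __init__(self, my_np_array):
        # Same name as the parameter, so the inherited get_params() finds it
        self.my_np_array = my_np_array

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        # Pass X through unchanged, as in the question
        return X


data = np.random.rand(20, 20)
data2 = np.random.rand(6, 6)
y = np.array([1, 2, 3, 1, 2, 3, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 3, 3, 3, 3])

# clone() now rebuilds the transformer with a copy of the array
cloned = clone(myTransformer(data2))
print(np.array_equal(cloned.my_np_array, data2))  # True

# The inherited set_params() stays in sync with get_params()
t = myTransformer(data2)
t.set_params(my_np_array=data)
print(t.get_params()["my_np_array"] is data)  # True

pipe = Pipeline(steps=[("myt", myTransformer(data2)),
                       ("randforest", RandomForestClassifier(random_state=0))])
params = {"randforest__n_estimators": [10, 50]}
search = GridSearchCV(pipe, param_grid=params)
search.fit(data, y)

# The refitted best pipeline still carries the 6x6 array
print(search.best_estimator_.named_steps["myt"].my_np_array.shape)  # (6, 6)
```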
