我修改了scikit-learn的BernoulliRBM类,以使用softmax可见单位的组。在此过程中,我添加了一个额外的Numpy数组visible_config
作为类属性,该类在构造函数中使用以下方法进行了初始化:
self.visible_config = np.cumsum(np.concatenate((np.asarray([0]),
visible_config), axis=0))
其中visible_config
的Numpy数组作为输入传递给构造函数。当我直接使用该fit()
函数训练模型时,代码可以正确运行。但是,当我使用该GridSearchCV
结构时,出现以下错误
Cannot clone object SoftmaxRBM(batch_size=100, learning_rate=0.01, n_components=100, n_iter=100,
random_state=0, verbose=True, visible_config=[ 0 21 42 63]), as the constructor does not seem to set parameter visible_config
在类实例与sklearn.base.clone创建的副本之间的相等性检查中,这似乎是一个问题,因为visible_config
没有正确地对其进行复制。我不确定如何解决此问题。它说,在文档中sklearn.base.clone
使用deepcopy()
,所以不应该visible_config
也被复制?有人可以解释一下我可以在这里尝试吗?谢谢!
没有看到您的代码,很难准确地指出出了什么问题,但是您违反了此处的scikit-learn API约定。估计器中的构造函数应仅将属性设置为用户作为参数传递的值。所有计算都应在中进行fit
,如果fit
需要存储计算结果,则应在带有下划线(_
)的属性中进行存储。这种约定是使clone
和诸如GridSearchCV
工作之类的元估计量产生的原因。
(*)如果您在主代码库中看到一个估算器违反了此规则:那将是一个错误,欢迎使用补丁程序。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句