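I currently have a custom TensorFlow model, model_A, that I want to use as the output of another model. I don't want to cross-validate the entire new model, but I have the following problem:
If I define
model_copy = model_A
and then train it:
model_copy.fit(X_train, y_train)
then the original model is affected as well. That is, after training model_copy, model_A.predict(X_train) no longer returns what it returned before training. How can I avoid this?
My implementation details:
After determining the best parameters with a randomized grid search, I do the following: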
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

# fullyConnected_Dense is my custom layer class (definition not shown)
best_params = model1.best_params_
input_shape_model1 = tf.keras.Input(shape=(best_params['input_dim'],))
layer = fullyConnected_Dense(best_params['height'])(input_shape_model1)
for i in range(best_params['depth']):
    # Activation
    layer = tf.nn.relu(layer)
    layer = fullyConnected_Dense(best_params['height'])(layer)
out_model1 = fullyConnected_Dense(1)(layer)
best_deep_readout_map = tf.keras.Model(input_shape_model1, out_model1)
opt_model1 = Adam(learning_rate=best_params['learning_rate'])
best_deep_readout_map.compile(optimizer=opt_model1, loss="mae", metrics=["mse", "mae", "mape"])
best_deep_readout_map.fit(X_concat, z_concat, epochs=best_params['epochs'],
                          batch_size=best_params['batch_size'])
After this, I try to save the model:
best_deep_readout_map.save('~/Desktop/models/full_max/')
and I get this error:
AttributeError Traceback (most recent call last)
<ipython-input-208-c4e56870b2c0> in <module>
----> 1 best_deep_readout_map.save('~/Desktop/models/full_max/')
~/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
1006 """
1007 save.save_model(self, filepath, overwrite, include_optimizer, save_format,
-> 1008 signatures, options)
1009
1010 def save_weights(self, filepath, overwrite=True, save_format=None):
~/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
113 else:
114 saved_model_save.save(model, filepath, overwrite, include_optimizer,
--> 115 signatures, options)
116
117
~/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options)
76 # we use the default replica context here.
77 with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
---> 78 save_lib.save(model, filepath, signatures, options)
79
80 if not include_optimizer:
~/.local/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py in save(obj, export_dir, signatures, options)
897 # Note we run this twice since, while constructing the view the first time
898 # there can be side effects of creating variables.
--> 899 _ = _SaveableView(checkpoint_graph_view)
900 saveable_view = _SaveableView(checkpoint_graph_view)
901
~/.local/lib/python3.7/site-packages/tensorflow_core/python/saved_model/save.py in __init__(self, checkpoint_view)
163 self.checkpoint_view = checkpoint_view
164 trackable_objects, node_ids, slot_variables = (
--> 165 self.checkpoint_view.objects_ids_and_slot_variables())
166 self.nodes = trackable_objects
167 self.node_ids = node_ids
~/.local/lib/python3.7/site-packages/tensorflow_core/python/training/tracking/graph_view.py in objects_ids_and_slot_variables(self)
416 object_names = object_identity.ObjectIdentityDictionary()
417 for obj, path in path_to_root.items():
--> 418 object_names[obj] = _object_prefix_from_path(path)
419 node_ids = object_identity.ObjectIdentityDictionary()
420 for node_id, node in enumerate(trackable_objects):
~/.local/lib/python3.7/site-packages/tensorflow_core/python/training/tracking/graph_view.py in _object_prefix_from_path(path_to_root)
62 return "/".join(
63 (_escape_local_name(trackable.name)
---> 64 for trackable in path_to_root))
65
66
~/.local/lib/python3.7/site-packages/tensorflow_core/python/training/tracking/graph_view.py in <genexpr>(.0)
62 return "/".join(
63 (_escape_local_name(trackable.name)
---> 64 for trackable in path_to_root))
65
66
~/.local/lib/python3.7/site-packages/tensorflow_core/python/training/tracking/graph_view.py in _escape_local_name(name)
55 # edges traversed to reach the variable, so we escape forward slashes in
56 # names.
---> 57 return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR)
58 .replace(r"/", _ESCAPE_CHAR + "S"))
59
AttributeError: 'NoneType' object has no attribute 'replace'
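Training model_copy affects the original model because of how Python assignment works: model_copy = model_A does not copy anything. It only binds a second name to the same object, so both names refer to the same model in memory.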
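To make the aliasing concrete, here is a minimal sketch in plain Python. TinyModel is a hypothetical stand-in for a model, not part of TensorFlow, and copy.deepcopy appears only to show what an independent object looks like; whether deepcopy works on a real Keras model depends on the version, so treat it as an illustration rather than a recommendation.

```python
import copy


class TinyModel:
    """Hypothetical stand-in for a model: one weight, updated by fit()."""

    def __init__(self, weight=0.0):
        self.weight = weight

    def fit(self, new_weight):
        # "Training" here just overwrites the weight.
        self.weight = new_weight


model_a = TinyModel(weight=1.0)

alias = model_a                       # second name, same object
independent = copy.deepcopy(model_a)  # separate object with its own state

alias.fit(5.0)        # also changes model_a, because alias is model_a
independent.fit(9.0)  # leaves model_a untouched

print(alias is model_a)        # True
print(independent is model_a)  # False
print(model_a.weight)          # 5.0
```

The same thing happens with model_copy = model_A: fit() on either name updates the one shared set of weights.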
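If you want to be sure you have distinct models, whatever you plan to do with them, write a function that builds and returns a fresh instance of the model each time it is called.
For example: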
def retrieve_template_model():
    ...
    ...
    ...
    return model

model_1 = retrieve_template_model()
model_1.fit()
model_2 = retrieve_template_model()
# Here model_1 and model_2 are different objects, and model_1 still has its trained weights
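A filled-in version of that template might look like the sketch below. The architecture, the default sizes and the random toy data are assumptions made up for illustration; only the pattern (one builder function, one fresh model per call) comes from the answer.

```python
import numpy as np
import tensorflow as tf


def retrieve_template_model(input_dim=4, height=8):
    """Build and compile a fresh, untrained model on every call."""
    inputs = tf.keras.Input(shape=(input_dim,))
    hidden = tf.keras.layers.Dense(height, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(1)(hidden)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mae")
    return model


# Toy data, made up for this sketch.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(32, 4)).astype("float32")
y_train = rng.normal(size=(32, 1)).astype("float32")

model_1 = retrieve_template_model()
model_2 = retrieve_template_model()

before = model_1.get_weights()                      # NumPy arrays
model_2.fit(X_train, y_train, epochs=1, verbose=0)  # trains model_2 only
after = model_1.get_weights()

# Training model_2 should leave model_1's weights exactly as they were.
unchanged = all(np.array_equal(b, a) for b, a in zip(before, after))
print(unchanged)
```

Because each call creates new layers and new variables, the two models share nothing, unlike the model_copy = model_A assignment in the question.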
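However, I would recommend using model.save(), which saves the weights and the architecture together, and tf.keras.models.load_model() to load the model back, so that you do not have to retrain it every time.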
If you want to transfer the weights from one model to another, you can use:
model_2.set_weights(model_1.get_weights())
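As a rough end-to-end sketch of that weight transfer, the snippet below trains model_1, builds model_2 with tf.keras.models.clone_model, copies the weights across, and then trains only the copy. The small network and random data are assumptions for illustration. clone_model is swapped in here as one way to get a second model with the same architecture; it is not something the question used, and with custom layers it generally needs them to implement get_config.

```python
import numpy as np
import tensorflow as tf

# Small stand-in model; sizes are arbitrary choices for this sketch.
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(1)(hidden)
model_1 = tf.keras.Model(inputs, outputs)
model_1.compile(optimizer="adam", loss="mae")

# Toy data, made up for this sketch.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(32, 4)).astype("float32")
y_train = rng.normal(size=(32, 1)).astype("float32")
model_1.fit(X_train, y_train, epochs=1, verbose=0)

# clone_model rebuilds the architecture with freshly initialised weights,
# so the trained weights have to be copied over explicitly.
model_2 = tf.keras.models.clone_model(model_1)
model_2.set_weights(model_1.get_weights())
model_2.compile(optimizer="adam", loss="mae")

reference = model_1.predict(X_train, verbose=0)
same_start = np.allclose(reference, model_2.predict(X_train, verbose=0))

# Further training of the copy should not move the original.
model_2.fit(X_train, y_train, epochs=1, verbose=0)
original_untouched = np.allclose(reference, model_1.predict(X_train, verbose=0))

print(same_start, original_untouched)
```

get_weights() returns NumPy arrays and set_weights() copies their values into the other model's variables, so after the transfer the two models start from the same state but no longer share anything.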