I'm trying to train an autoencoder in Google Colab using ImageDataGenerator. I run it with this code:
from keras.preprocessing.image import ImageDataGenerator

batch_size = 128
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)
training_generator = train_datagen.flow_from_directory(train_dir,
                                                       target_size=(105, 105),
                                                       color_mode='grayscale',
                                                       batch_size=batch_size,
                                                       class_mode=None,
                                                       subset='training')
validation_generator = train_datagen.flow_from_directory(train_dir,
                                                         target_size=(105, 105),
                                                         color_mode='grayscale',
                                                         batch_size=batch_size,
                                                         class_mode=None,
                                                         subset='validation')
history = autoencoder.fit_generator(generator=training_generator,
                                    epochs=5,
                                    steps_per_epoch=training_generator.samples // batch_size,
                                    validation_data=validation_generator,
                                    validation_steps=validation_generator.samples // batch_size,
                                    use_multiprocessing=False)
It runs all the way to the last step of the first epoch and then raises this error:
/usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/utils.py:173: UserWarning: Using ".tiff" files with multiple bands will cause distortion. Please verify your output.
warnings.warn('Using ".tiff" files with multiple bands '
Found 1375004 images belonging to 1 classes.
Found 343750 images belonging to 1 classes.
Epoch 1/5
10741/10742 [============================>.] - ETA: 0s - loss: 0.0052Epoch 1/5
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-25-c39bd284b251> in <module>()
23 validation_data=validation_generator,
24 validation_steps = validation_generator.samples // batch_size,
---> 25 use_multiprocessing=False)
5 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in __call__(self, *args, **kwargs)
1470 ret = tf_session.TF_SessionRunCallable(self._session._session,
1471 self._handle, args,
-> 1472 run_metadata_ptr)
1473 if run_metadata:
1474 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: You must feed a value for placeholder tensor 'conv2d_transpose_1_target' with dtype float and shape [?,?,?,?]
[[{{node conv2d_transpose_1_target}}]]
[[loss/mul/_83]]
(1) Invalid argument: You must feed a value for placeholder tensor 'conv2d_transpose_1_target' with dtype float and shape [?,?,?,?]
[[{{node conv2d_transpose_1_target}}]]
0 successful operations.
0 derived errors ignored.
I have tried K.clear_session(), but it didn't work. I still don't know what I'm doing wrong, even though there is already a thread explaining this problem. I'm new to Keras, so any help and advice would be a blessing!
Here is the autoencoder code and its summary:
def create_model():
    model = Sequential()
    model.add(Conv2D(64, kernel_size=(11, 11), strides=2, activation='relu', padding='valid', input_shape=(105, 105, 1)))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid'))
    model.add(Conv2D(128, kernel_size=(1, 1), strides=1, activation='relu', padding='valid'))
    model.add(Conv2DTranspose(64, kernel_size=(1, 1), strides=1, activation='relu', padding='valid'))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2DTranspose(1, kernel_size=(11, 11), strides=2, activation='relu', padding='valid'))
    adam = Adam(lr=0.01)
    model.compile(optimizer=adam, loss='mean_squared_error')
    return model

with tpu_strategy.scope():  # creating the model in the TPUStrategy scope means we will train the model on the TPU
    autoencoder = create_model()
autoencoder.summary()
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_4 (Conv2D) (None, 48, 48, 64) 7808
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 24, 24, 64) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 24, 24, 128) 8320
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 24, 24, 64) 8256
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 48, 48, 64) 0
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 105, 105, 1) 7745
=================================================================
Total params: 32,129
Trainable params: 32,129
Non-trainable params: 0
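As a sanity check that the decoder really mirrors the encoder back to the 105x105 input, the output sizes in the summary can be reproduced with the standard 'valid'-padding size formulas. This is a plain-Python sketch, not Keras code; `conv_out` and `deconv_out` are helper names introduced here for illustration.

```python
def conv_out(n, k, s):
    # Output size of a 'valid' convolution/pooling: floor((n - k) / s) + 1
    return (n - k) // s + 1

def deconv_out(n, k, s):
    # Output size of a 'valid' transposed convolution: (n - 1) * s + k
    return (n - 1) * s + k

n = 105
n = conv_out(n, 11, 2)    # Conv2D 11x11, stride 2
assert n == 48            # matches conv2d_4: (None, 48, 48, 64)
n = conv_out(n, 2, 2)     # MaxPooling2D 2x2, stride 2
assert n == 24            # matches max_pooling2d_2: (None, 24, 24, 64)
n = deconv_out(n, 1, 1)   # Conv2DTranspose 1x1, stride 1 (size unchanged)
n = n * 2                 # UpSampling2D 2x2
assert n == 48            # matches up_sampling2d_2: (None, 48, 48, 64)
n = deconv_out(n, 11, 2)  # Conv2DTranspose 11x11, stride 2
assert n == 105           # output matches the 105x105 input, as required
```

So the shapes line up; the error is not a shape mismatch but a missing target tensor, as the answer below explains.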
I found the answer: with class_mode=None the generator yields only input batches and no targets, so the model's target placeholder (conv2d_transpose_1_target) is never fed. Changing to class_mode='input' works, because the generator then yields each batch as both input and target.

Alternatively, if you want to keep class_mode=None, you can fix the generators by wrapping them with this function:

def fixed_generator(generator):
    for batch in generator:
        yield (batch, batch)

Here is the reference.
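The wrapper can be tested without Keras at all, since it just pairs each batch with itself. This is a minimal sketch; `dummy_batches` is a hypothetical stand-in for the Keras generator, which with class_mode=None yields raw batch arrays only.

```python
import numpy as np

def fixed_generator(generator):
    # Wrap a generator that yields only inputs so that it yields
    # (inputs, targets) pairs, with the batch as its own target.
    for batch in generator:
        yield (batch, batch)

# Hypothetical stand-in for flow_from_directory(..., class_mode=None):
# yields grayscale batches of shape (batch_size, 105, 105, 1).
def dummy_batches():
    while True:
        yield np.zeros((128, 105, 105, 1), dtype="float32")

wrapped = fixed_generator(dummy_batches())
x, y = next(wrapped)
assert x is y  # each batch doubles as its own reconstruction target
```

You would then pass fixed_generator(training_generator) and fixed_generator(validation_generator) to fit_generator, while still computing steps_per_epoch from training_generator.samples, since the wrapper has no samples attribute of its own.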