저는 개인 Re-ID 시스템을 구축하려고하는데 모델 훈련에 샴 아키텍처를 사용합니다. callbacks.ModelCheckpoint를 사용하여 각 시대에 모델을 저장합니다. 저장된 모델을로드하는 중에 오류가 발생했습니다.
훈련을 위해 VGG16 사전 훈련 된 모델을 사용합니다.
input_shape = (160,60,3)
conv_base = VGG16(weights='imagenet',
include_top=False,
input_shape=(160, 60, 3))
output = conv_base.layers[-5].output
x=Flatten()(output)
x=Dense(512,activation='relu')(x)
out=Dense(512,activation='relu')(x)
conv_base = Model(conv_base.input, output=out)
for layer in conv_base.layers[:-11]:
layer.trainable = False
샴 모델 생성 :
# We have 2 inputs, 1 for each picture
left_input = Input((160,60,3))
right_input = Input((160,60,3))
# We will use 2 instances of 1 network for this task
convnet = Sequential([
InputLayer(input_shape=(160, 60, 3)),
conv_base
])
# Connect each 'leg' of the network to each input
# Remember, they have the same weights
encoded_l = convnet(left_input)
encoded_r = convnet(right_input)
# Getting the L1 Distance between the 2 encodings
L1_layer = Lambda(lambda tensor:K.abs(tensor[0] - tensor[1]))
# Add the distance function to the network
L1_distance = L1_layer([encoded_l, encoded_r])
prediction = Dense(1,activation='sigmoid')(L1_distance)
siamese_net = Model(inputs=[left_input,right_input],outputs=prediction)
#optimizer = Adam(0.00006, decay=2.5e-4)
sgd = optimizers.RMSprop(lr=1e-4)
#//TODO: get layerwise learning rates and momentum annealing scheme described in paperworking
siamese_net.compile(loss="binary_crossentropy", optimizer=sgd, metrics=['accuracy'])
기차 네트워크 :
checkpoint = ModelCheckpoint('drive/My Drive/thesis/new change parametr/model/model-{epoch:03d}.h5', verbose=1, save_weights_only=False,monitor='val_loss', mode='auto')
newmodel=siamese_net.fit([left_train,right_train], targets,
batch_size=64,
epochs=2,
verbose=1,shuffle=True, validation_data=([valid_left,valid_right],valid_targets),callbacks=[checkpoint])
모델은 각 Epoch에 저장되지만로드 할 때 다음 오류가 발생합니다.
loaded_model= load_model('drive/My Drive/thesis/new change parametr/model/model-001.h5')
오류:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-6-8de2283b355f> in <module>()
1
----> 2 loaded_model= load_model('drive/My Drive/thesis/new change parametr/model/model-001.h5')
3 print('Load succesfuly')
4
5 #siamese_net.load_weights('drive/My Drive/thesis/new change parametr/weight/model-{epoch:03d}.h5')
7 frames
/usr/local/lib/python3.6/dist-packages/keras/engine/saving.py in preprocess_weights_for_loading(layer, weights, original_keras_version, original_backend, reshape)
939 str(weights[0].size) + '. ')
940 weights[0] = np.reshape(weights[0], layer_weights_shape)
--> 941 elif layer_weights_shape != weights[0].shape:
942 weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
943 if layer.__class__.__name__ == 'ConvLSTM2D':
IndexError: list index out of range
내 코드는 Google Colaboratory에서 실행됩니다. 나는 온라인으로 검색했는데 아마도 샴 아키텍처를 사용했기 때문일 것입니다. 어떤 도움을 주시면 감사하겠습니다!
다음과 같이 네트워크를 생성 할 때 저장된 모델을로드하는 동안 오류가 발생했습니다.
input_shape = (160,60,3)
conv_base = VGG16(weights='imagenet',
include_top=False,
input_shape=(160, 60, 3))
output = conv_base.layers[-5].output
x=Flatten()(output)
x=Dense(512,activation='relu')(x)
out=Dense(512,activation='relu')(x)
conv_base = Model(conv_base.input, output=out)
for layer in conv_base.layers[:-11]:
layer.trainable = False
# We have 2 inputs, 1 for each picture
left_input = Input((160,60,3))
right_input = Input((160,60,3))
# We will use 2 instances of 1 network for this task
convnet = Sequential([
InputLayer(input_shape=(160, 60, 3)),
conv_base
])
이 문제는 생성 모델을 변경하여 해결되었습니다.
# We have 2 inputs, 1 for each picture
left_input = Input((160,60,3))
right_input = Input((160,60,3))
conv_base = VGG16(weights='imagenet',
include_top=False,
input_shape=(160, 60, 3))
output = conv_base.layers[-5].output
x=Flatten()(output)
x=Dense(512,activation='relu')(x)
out=Dense(512,activation='relu')(x)
for layer in conv_base.layers[:-11]:
layer.trainable = False
convnet = Model(conv_base.input, output=out)
그때:
loaded_model= load_model('drive/My Drive/thesis/new change parametr/model/model-001.h5')
print('Load successfully')
이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.
침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제
몇 마디 만하겠습니다