下面注释掉的五行应该起作用,但是无效。预测分数与我的预期分数不尽相同,当我执行plt.imshow(img)时,它将显示错误的图像。这是我在Colab中的笔记本的链接。
x, y = next(valid_generator)
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)
image_url = 'https://mysite_example/share/court3.jpg'
image_url = tf.keras.utils.get_file('Court', origin=image_url )
#img = keras.preprocessing.image.load_img( image_url, target_size=( 224, 224 ) )
#img_array = keras.preprocessing.image.img_to_array(img)
#img_array = tf.expand_dims(img_array, 0)
#prediction_scores = model.predict(np.expand_dims(img_array, axis=0))
#plt.imshow(img)
# Expand the validation image to (1, 224, 224, 3) before predicting the label
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + get_class_string_from_index(true_index))
print("Predicted label: " + get_class_string_from_index(predicted_index)
tf.keras.utils.get_file
仅当文件尚未缓存时,该方法才将文件从url下载到本地缓存。因此,如果所有URL都使用相同的缓存名称(代码中的“ Court”?),则只会看到第一个文件。
同样,在训练时,您还有一个预处理步骤,通过用除以将所有像素归一化255
。您还必须在推理过程中应用相同的预处理步骤。
工作代码:
_, axis = plt.subplots(1,3)
for i, image_url in enumerate(['https://squashvideo.site/share/court3.jpg',
'https://i.pinimg.com/originals/0f/c2/9b/0fc29b35532f8e2fb998f5605212ab27.jpg',
'https://thumbs.dreamstime.com/b/squash-court-photo-empty-30346175.jpg']):
image_url = tf.keras.utils.get_file('Court', origin=image_url )
img = tf.keras.preprocessing.image.load_img(image_url, target_size=( 224, 224 ) )
os.remove(image_url) # Remove the cached file
axis[i].imshow(img)
img_array = keras.preprocessing.image.img_to_array(img)
prediction_scores = model.predict(np.expand_dims(img_array, axis=0)/255)
axis[i].title.set_text(np.argmax(prediction_scores, axis=1))
如您所见,预测是完美的,最后一张图片属于第0类(空壁球场),第二张图片属于第1类(在壁球场玩的球员)
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句