为什么在加载模型时Tensorflow多类图像预测不起作用?

剧本

我目前正在尝试学习机器学习技术,并想用张量流重新创建一个简单的图像识别算法。因此,我制作了两个Python文件:一个用于训练,一个用于预测。

在Ubuntu 18.04上进行了测试使用的Python版本:3.7使用的Numpy版本:1.18.1使用的Tensorflow版本:1.14和2.1.0(以下输出来自版本1.14)

我的图片来自http://www.cs.columbia.edu/CAVE/databases/pubfig/download/#dev该集合包含来自60个人的约3000张裁剪后的面孔图像。

train_model.py:

#!/usr/bin/env python

import concurrent.futures
import pandas as pd
import urllib
import pathlib
import hashlib
import os
import json
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

people = pd.read_csv("dev_people.txt")

image_generator = ImageDataGenerator(rescale=1./255, validation_split=0.2, rotation_range=45, zoom_range=0.2)

IMG_HEIGHT = 128
IMG_WIDTH = 128
LEARNING_RATE = 0.0001
BATCH_SIZE = 32
NUM_TRAIN = 100
STEPS_PER_EPOCH = round(NUM_TRAIN) // BATCH_SIZE
VAL_STEPS = 20
NUM_EPOCHS = 3

train_data = image_generator.flow_from_directory(batch_size=BATCH_SIZE,
    directory="persons-cropped",
    shuffle=True,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode="categorical",
    subset="training")

labels = train_data.class_indices
labels = {v: k for k, v in labels.items()}

with open("labels.json", "w") as labels_file:
    labels_file.write(json.dumps(labels))

validation_data = image_generator.flow_from_directory(batch_size=BATCH_SIZE,
    directory="persons-cropped",
    shuffle=True,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode="categorical",
    subset="validation")

base_model = tf.keras.applications.MobileNetV2(
    input_shape=(IMG_WIDTH, IMG_HEIGHT, 3),
    include_top=False,
    weights="imagenet"
)

base_model.trainable = False

maxpool_layer = tf.keras.layers.GlobalMaxPooling2D()
prediction_layer = tf.keras.layers.Dense(60, activation="sigmoid")
dropout_layer = tf.keras.layers.Dropout(0.2)

model = tf.keras.Sequential([
    base_model,
    maxpool_layer,
#   dropout_layer,
    prediction_layer
])

model.compile(optimizer=tf.keras.optimizers.Adam(lr=LEARNING_RATE),
    loss="binary_crossentropy",
    metrics=["accuracy"]
)

model.summary()

model.fit(
    train_data,
    epochs=NUM_EPOCHS,
    steps_per_epoch=None,
    validation_data=validation_data,
    validation_steps=None,
    use_multiprocessing=False,
    workers=6,
    verbose=2
)

model.save("model.h5")

输出:

Found 2431 images belonging to 60 classes.
Found 573 images belonging to 60 classes.
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
2020-01-25 22:23:40.036326: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2599985000 Hz
2020-01-25 22:23:40.036657: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x6b81c60 executing computations on platform Host. Devices:
2020-01-25 22:23:40.036789: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): <undefined>, <undefined>
2020-01-25 22:23:40.615771: W tensorflow/compiler/jit/mark_for_compilation_pass.cc:1412] (One-time warning): Not using XLA:CPU for cluster because envvar TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set.  If you want XLA:CPU, either set that envvar, or use experimental_jit_scope to enable XLA:CPU.  To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a proper command-line flag, not via TF_XLA_FLAGS) or set the envvar XLA_FLAGS=--xla_hlo_profile.
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
mobilenetv2_1.00_224 (Model) (None, 7, 7, 1280)        2257984
_________________________________________________________________
global_max_pooling2d (Global (None, 1280)              0
_________________________________________________________________
dense (Dense)                (None, 60)                76860
=================================================================
Total params: 2,334,844
Trainable params: 76,860
Non-trainable params: 2,257,984
_________________________________________________________________
Epoch 1/3
2020-01-25 22:23:55.995833: W tensorflow/core/framework/allocator.cc:107] Allocation of 154140672 exceeds 10% of system memory.
2020-01-25 22:23:56.730363: W tensorflow/core/framework/allocator.cc:107] Allocation of 156905472 exceeds 10% of system memory.
2020-01-25 22:24:02.782372: W tensorflow/core/framework/allocator.cc:107] Allocation of 154140672 exceeds 10% of system memory.
2020-01-25 22:24:03.531172: W tensorflow/core/framework/allocator.cc:107] Allocation of 156905472 exceeds 10% of system memory.
2020-01-25 22:24:09.474692: W tensorflow/core/framework/allocator.cc:107] Allocation of 154140672 exceeds 10% of system memory.
/usr/local/lib/python3.7/dist-packages/PIL/TiffImagePlugin.py:788: UserWarning: Corrupt EXIF data.  Expecting to read 4 bytes but only got 0.
  warnings.warn(str(msg))
76/76 - 602s - loss: 0.3851 - acc: 0.9097 - val_loss: 0.1812 - val_acc: 0.9495
Epoch 2/3
76/76 - 616s - loss: 0.1480 - acc: 0.9757 - val_loss: 0.1732 - val_acc: 0.9544
Epoch 3/3
76/76 - 627s - loss: 0.1452 - acc: 0.9760 - val_loss: 0.1767 - val_acc: 0.9516

它说该模型的准确度得分约为95%,这非常好,现在预测图像应该没有问题。但是,这是预测文件:

预报_图片.py

#!/usr/bin/env python

import concurrent.futures
import pandas as pd
import numpy as np
import urllib
import pathlib
import hashlib
import os
import sys
import cv2
import json
import tensorflow as tf
import PIL
import skimage
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMG_HEIGHT = 128
IMG_WIDTH = 128

def load_image(filename):
    img = tf.keras.preprocessing.image.load_img(filename, target_size=(IMG_WIDTH,IMG_HEIGHT))
    img = tf.keras.preprocessing.image.img_to_array(img)
    img = np.expand_dims(img, axis=0) / 255
    return img

from glob import glob
class_names = glob("persons-cropped/*")
class_names = sorted(class_names)

labels_file = open("labels.json", "r")
labels = json.loads(labels_file.read())

print(labels)

model = tf.keras.models.load_model("model.h5")

model.summary()

img = load_image(sys.argv[1])

predictions = model.predict(img, verbose=1)
prediction = predictions.argmax(axis=-1)

print(predictions)
print(prediction)
map_labels = np.vectorize(lambda i: labels[str(i)])
print(map_labels(prediction))

输出:使用Zach Braff图像时:

{'0': 'Abhishek Bachan', '1': 'Alex Rodriguez', '2': 'Ali Landry', '3': 'Alyssa Milano', '4': 'Anderson Cooper', '5': 'Anna Paquin', '6': 'Audrey Tautou', '7': 'Barack Obama', '8': 'Ben Stiller', '9': 'Christina Ricci', '10': 'Clive Owen', '11': 'Cristiano Ronaldo', '12': 'Daniel Craig', '13': 'Danny Devito', '14': 'David Duchovny', '15': 'Denise Richards', '16': 'Diane Sawyer', '17': 'Donald Faison', '18': 'Ehud Olmert', '19': 'Faith Hill', '20': 'Famke Janssen', '21': 'Hugh Jackman', '22': 'Hugh Laurie', '23': 'James Spader', '24': 'Jared Leto', '25': 'Julia Roberts', '26': 'Julia Stiles', '27': 'Karl Rove', '28': 'Katherine Heigl', '29': 'Kevin Bacon', '30': 'Kiefer Sutherland', '31': 'Kim Basinger', '32': 'Mark Ruffalo', '33': 'Meg Ryan', '34': 'Michelle Trachtenberg', '35': 'Michelle Wie', '36': 'Mickey Rourke', '37': 'Miley Cyrus', '38': 'Milla Jovovich', '39': 'Nicole Richie', '40': 'Rachael Ray', '41': 'Robert Gates', '42': 'Ryan Seacrest', '43': 'Sania Mirza', '44': 'Sarah Chalke', '45': 'Sarah Palin', '46': 'Scarlett Johansson', '47': 'Seth Rogen', '48': 'Shahrukh Khan', '49': 'Shakira', '50': 'Stephen Colbert', '51': 'Stephen Fry', '52': 'Steve Carell', '53': 'Steve Martin', '54': 'Tracy Morgan', '55': 'Ty Pennington', '56': 'Viggo Mortensen', '57': 'Wilmer Valderrama', '58': 'Zac Efron', '59': 'Zach Braff'}
2020-01-25 22:58:05.582049: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2599985000 Hz
2020-01-25 22:58:05.582514: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x618f910 executing computations on platform Host. Devices:
2020-01-25 22:58:05.582653: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): <undefined>, <undefined>
2020-01-25 22:58:06.454565: W tensorflow/compiler/jit/mark_for_compilation_pass.cc:1412] (One-time warning): Not using XLA:CPU for cluster because envvar TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set.  If you want XLA:CPU, either set that envvar, or use experimental_jit_scope to enable XLA:CPU.  To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a proper command-line flag, not via TF_XLA_FLAGS) or set the envvar XLA_FLAGS=--xla_hlo_profile.
1/1 [==============================] - 1s 999ms/sample
[[3.23683023e-04 6.47217035e-04 3.90201807e-04 2.69789696e-02
  2.17908323e-02 1.53781831e-01 4.79090214e-03 8.64863396e-04
  1.11432403e-01 8.87395382e-01 3.30170989e-03 2.17252970e-03
  1.78458661e-01 1.09243691e-02 1.47551298e-04 2.62927115e-02
  3.22049320e-01 2.69562006e-04 9.11523938e-01 2.44581699e-03
  7.65213370e-03 2.90286541e-03 1.01376325e-01 6.43432140e-05
  4.43832874e-02 3.94093990e-03 6.90050423e-02 7.47233629e-04
  1.05589628e-03 8.04662704e-07 3.76045704e-03 4.28827941e-01
  1.20029151e-02 1.77664489e-01 5.27173281e-04 2.45797634e-03
  5.89579344e-03 9.46103930e-01 2.79089808e-03 2.09265649e-02
  2.83238888e-02 4.86207008e-03 8.15459788e-02 1.30202770e-02
  1.50602162e-02 1.33922696e-03 1.24056339e-02 5.76970875e-02
  2.65627503e-02 5.18084109e-01 4.89562750e-04 3.15269828e-03
  4.88847494e-04 2.13665128e-01 1.40489936e-02 2.93705761e-02
  5.01989722e-02 1.21492555e-03 1.62564263e-01 2.91267484e-01]]
[37]
['Miley Cyrus']

预测算法一直都是错误的。如果我使用其他Zach Braff图像,则当然对于同一张图片,输出将保持不变,但是在10个测试用例中,它不是Zach Braff,而是始终是其他人。(不仅是麦莉·赛勒斯,还有夏奇拉,史蒂夫·卡雷尔,...)

对于我在此处用作输入的任何类,此模式都会出现。

我在互联网上找不到任何有用的建议,并尝试了所有我可以成像的参数组合都可以使用。我还使用了两个版本的Tensorflow,并确保所有库都是最新的。

穆罕默德(Mohamoud Mohamed)

嘿,我相信您会得到奇怪的预测,因为在模型使用设置为二元交叉熵的损失函数编译时,您的数据分布有60类人员。

二进制交叉熵用于确定最多2个类别。您需要做的是将损失函数更改为分类交叉熵。

 model.compile(optimizer=tf.keras.optimizers.Adam(lr=LEARNING_RATE),
 loss="categorical_crossentropy", # here
 metrics=["accuracy"]
)

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

为什么多GPU tensorflow重新训练不起作用

来自分类Dev

为什么刷新模型不起作用?

来自分类Dev

为什么tensorflow模型总是预测同一类?

来自分类Dev

为什么渲染图像不起作用

来自分类Dev

为什么类检查不起作用?

来自分类Dev

为什么自定义系统类加载器不起作用?

来自分类Dev

为什么在保存时更新相关模型中的字段在Django中不起作用?

来自分类Dev

为什么 Laravel 模型中的隐藏字段在动态设置时不起作用?

来自分类Dev

悬停时,为什么带有CSS过渡的其他图像源不起作用?

来自分类Dev

为什么当我将鼠标悬停在图像上时它不起作用?

来自分类Dev

使用 ajax 加载更多图像时,prettyphoto 不起作用

来自分类Dev

为什么使用javascript / jQuery添加类时CSS过渡不起作用?

来自分类Dev

从案例类中提取Map键类型时为什么不起作用?

来自分类Dev

在检查当前类的实例时,为什么“ x instanceof getClass()”不起作用?

来自分类Dev

Spring注释:为什么在类为@Autowired时@Required不起作用

来自分类Dev

为什么我的jquery切换类在单击时不起作用

来自分类Dev

为什么在this.DataContext = this时绑定到我的类的实例不起作用

来自分类Dev

Spring注释:为什么在类为@Autowired时@Required不起作用

来自分类Dev

为什么我的jquery切换类在单击时不起作用

来自分类Dev

在加载时添加类,javascript 不起作用

来自分类Dev

为什么“加载事件在addEventListener()上不起作用”?

来自分类Dev

为什么在图片上加载事件不起作用?

来自分类Dev

Angular:为什么延迟加载装饰器不起作用?

来自分类Dev

tomcat-为什么目录升级对图像不起作用?

来自分类Dev

为什么我在图像上的CSS动画不起作用?

来自分类Dev

为什么我的响应式图像不起作用?

来自分类Dev

tomcat-为什么目录升级对图像不起作用?

来自分类Dev

为什么图像源中的正斜杠不起作用?

来自分类Dev

功能更改图像不起作用,为什么

Related 相关文章

  1. 1

    为什么多GPU tensorflow重新训练不起作用

  2. 2

    为什么刷新模型不起作用?

  3. 3

    为什么tensorflow模型总是预测同一类?

  4. 4

    为什么渲染图像不起作用

  5. 5

    为什么类检查不起作用?

  6. 6

    为什么自定义系统类加载器不起作用?

  7. 7

    为什么在保存时更新相关模型中的字段在Django中不起作用?

  8. 8

    为什么 Laravel 模型中的隐藏字段在动态设置时不起作用?

  9. 9

    悬停时,为什么带有CSS过渡的其他图像源不起作用?

  10. 10

    为什么当我将鼠标悬停在图像上时它不起作用?

  11. 11

    使用 ajax 加载更多图像时,prettyphoto 不起作用

  12. 12

    为什么使用javascript / jQuery添加类时CSS过渡不起作用?

  13. 13

    从案例类中提取Map键类型时为什么不起作用?

  14. 14

    在检查当前类的实例时,为什么“ x instanceof getClass()”不起作用?

  15. 15

    Spring注释:为什么在类为@Autowired时@Required不起作用

  16. 16

    为什么我的jquery切换类在单击时不起作用

  17. 17

    为什么在this.DataContext = this时绑定到我的类的实例不起作用

  18. 18

    Spring注释:为什么在类为@Autowired时@Required不起作用

  19. 19

    为什么我的jquery切换类在单击时不起作用

  20. 20

    在加载时添加类,javascript 不起作用

  21. 21

    为什么“加载事件在addEventListener()上不起作用”?

  22. 22

    为什么在图片上加载事件不起作用?

  23. 23

    Angular:为什么延迟加载装饰器不起作用?

  24. 24

    tomcat-为什么目录升级对图像不起作用?

  25. 25

    为什么我在图像上的CSS动画不起作用?

  26. 26

    为什么我的响应式图像不起作用?

  27. 27

    tomcat-为什么目录升级对图像不起作用?

  28. 28

    为什么图像源中的正斜杠不起作用?

  29. 29

    功能更改图像不起作用,为什么

热门标签

归档