有许多在线教程指导如何使用预训练模型在Keras中预测一次图像 . 对于我的情况,我在Keras中使用VGG16模型,我需要连续预测图像,所以我使用for循环来加载图像,然后将其传递给预测函数,它运行良好,但是一个预测时间太长(我的约800ms)机器,仅限CPU),这里是代码:
# one full prediction function cost 800ms
def predict(image):
image = img_to_array(image)
image = image / 255
image = np.expand_dims(image, axis=0)
# build the VGG16 network, this line code cost 400~500ms
model = keras.applications.VGG16(include_top=True, weights='imagenet')
# Do predication
prediction = model.predict(image)
'''
Process predication results
'''
'''
some preprocess
'''
for img in imgs_list:
predict(img)
上面的代码可以很好地工作,但每次预测都花费太多时间,整个功能需要800ms,而构建VGG网络需要500ms,成本太高 . 我想为连续预测模式的每个预测删除这500ms .
我尝试将“model = keras.applications.VGG16(include_top = True,weights ='imagenet')”此行代码放到预测函数的外部,全局定义或传递“model”作为函数的参数,但程序将返回错误并在第一次成功预测后结束 .
Traceback (most recent call last):
File "/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1075, in _run
subfeed, allow_tensor=True, allow_operation=False)
File "/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3590, in as_graph_element
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
File "/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3669, in _as_graph_element_locked
raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("input_1:0", shape=(?, ?, ?, 3), dtype=float32) is not an element of this graph.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "multi_classifier.py", line 256, in
predict(current_file_path)
File "multi_classifier.py", line 184, in predict
bottleneck_prediction = model_1.predict(image)
File "/home/zi/venv/lib/python3.5/site-packages/keras/engine/training.py", line 1835, in predict
verbose=verbose, steps=steps)
File "/home/zi/venv/lib/python3.5/site-packages/keras/engine/training.py", line 1331, in _predict_loop
batch_outs = f(ins_batch)
File "/home/zi/venv/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py", line 2482, in __call__
**self.session_kwargs)
File "/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/home/zi/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1078, in _run
'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("input_1:0", shape=(?, ?, ?, 3), dtype=float32) is not an element of this graph.
看起来我需要为每个预测实例化一个VGG模型,如何更改代码以节省模型构建时间?谢谢 .