首页 文章

raise TypeError('Cannot interpret feed_dict key as Tensor: ' + e.args[0])

提问于
浏览
1

我使用 code 重新训练了模型,然后克隆了此 repo 并按照其中的说明操作,用新生成的 labels.txt 和 graph.pb 文件替换了原有文件。当我提交图像并使用以下代码进行分类时:

# Default number of top predictions returned when the caller gives no ``k``.
MAX_K = 10

# Absolute directory of this module; the model files are resolved beneath it.
_BASE_PATH = os.path.abspath(os.path.dirname(__file__))

# Location of the serialized graph definition read by ``load_graph``.
TF_GRAPH = "{base_path}/inception_model/graph.pb".format(base_path=_BASE_PATH)
# Location of the labels file, one label per line.
TF_LABELS = "{base_path}/inception_model/labels.txt".format(base_path=_BASE_PATH)


def load_graph():
    """Create a session, import the model graph and read the label list.

    The serialized ``GraphDef`` is read from ``TF_GRAPH`` and imported with an
    empty name prefix, so node names are used exactly as stored in the file.
    Labels are read from ``TF_LABELS``, one per line, with trailing
    whitespace stripped.

    Returns:
        A ``(sess, softmax_tensor, label_lines)`` tuple: the created
        ``tf.Session``, the tensor named ``final_result:0`` looked up in the
        session's graph, and the list of label strings.
    """
    sess = tf.Session()
    with tf.gfile.FastGFile(TF_GRAPH, 'rb') as graph_file:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph_file.read())
        tf.import_graph_def(graph_def, name='')
    # Read the labels inside a context manager so the file handle is closed
    # instead of being left open for the lifetime of the process.
    with tf.gfile.GFile(TF_LABELS) as labels_file:
        label_lines = [line.rstrip() for line in labels_file]
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    return sess, softmax_tensor, label_lines


# Load the model once when the module is imported; tf_classify reuses this
# session, output tensor and label list on every call.
SESS, GRAPH_TENSOR, LABELS = load_graph()


@csrf_exempt
def classify_api(request):
    """Classify an uploaded image and return the result as JSON.

    Only POST requests are processed. The image is taken from the uploaded
    file field ``image`` or, failing that, from the POST field ``image64``
    (text after the first comma is base64-decoded). The optional POST field
    ``k`` sets how many predictions to return and defaults to ``MAX_K``.

    Args:
        request: The incoming HTTP request.

    Returns:
        A ``JsonResponse`` whose payload always has ``success``. When the
        classifier returns results, ``success`` is ``True`` and
        ``confidence`` maps each label to its score as a float. For
        non-POST requests or an empty result, ``success`` stays ``False``.
    """
    data = {"success": False}

    if request.method == "POST":
        # The context manager closes the temporary file even when decoding
        # or classification raises.
        with NamedTemporaryFile() as tmp_f:
            if request.FILES.get("image", None) is not None:
                image_bytes = request.FILES["image"].read()
                image = Image.open(io.BytesIO(image_bytes))
                image.save(tmp_f, image.format)
            elif request.POST.get("image64", None) is not None:
                # Keep only what follows the first comma (presumably a
                # data-URL header precedes it -- confirm against clients).
                base64_data = request.POST.get("image64", None).split(',', 1)[1]
                tmp_f.write(b64decode(base64_data))

            # tf_classify reopens the file by name, so buffered bytes must be
            # flushed to disk before it reads them.
            tmp_f.flush()
            classify_result = tf_classify(tmp_f, int(request.POST.get('k', MAX_K)))

        if classify_result:
            data["success"] = True
            data["confidence"] = {
                res[0]: float(res[1]) for res in classify_result
            }

    return JsonResponse(data)


def tf_classify(image_file, k=MAX_K):
    """Run the loaded model on an image file and return the top ``k`` labels.

    Args:
        image_file: Object whose ``name`` attribute is the path of the image
            to classify; the file is reopened by that path and read as bytes.
        k: Number of highest-scoring predictions to return.

    Returns:
        A list of ``[label, score]`` pairs ordered from highest to lowest
        score. Empty when ``k`` is not positive.
    """
    # A non-positive k would make the ``[-k:]`` slice below select the wrong
    # range (k == 0 selects every label), so answer with no predictions.
    if k <= 0:
        return []

    # Read through a context manager so the file handle is not leaked.
    with tf.gfile.FastGFile(image_file.name, 'rb') as image_reader:
        image_data = image_reader.read()

    predictions = SESS.run(GRAPH_TENSOR, {'DecodeJpeg/contents:0': image_data})
    # Use the first row of the output, truncated to the number of labels.
    predictions = predictions[0][:len(LABELS)]
    # argsort is ascending: take the last k indices and reverse them.
    top_k = predictions.argsort()[-k:][::-1]
    return [[LABELS[node_id], predictions[node_id]] for node_id in top_k]

然后它显示以下错误:

TypeError: Cannot interpret feed_dict key as Tensor: The name 'DecodeJpeg/contents:0' refers to a Tensor which does not exist. The operation, 'DecodeJpeg/contents', does not exist in the graph.

1 回答

  • 0

    你的问题出在这一行:

    predictions = SESS.run(GRAPH_TENSOR, {'DecodeJpeg/contents:0': image_data})
    

    feed_dict 字典中的键应该是张量,而不是字符串。您可以先按名称查找张量:

    # Resolve the input tensor object from the default graph by its name,
    # then use that object (not the name string) as the feed_dict key.
    # NOTE(review): the reported error says this operation is absent from
    # the graph, so this lookup may fail as well -- verify the input tensor
    # name in the retrained graph.
    data_tensor = tf.get_default_graph().get_tensor_by_name('DecodeJpeg/contents:0')
    predictions = SESS.run(GRAPH_TENSOR, {data_tensor: image_data})
    

相关问题