首页 文章

在Tensorflow中,获取图表中所有张量的名称

提问于
浏览
78

我用 Tensorflowskflow 创建神经网络;由于某种原因,我想获得给定输入的一些内部张量的值,所以我使用 myClassifier.get_layer_value(input, "tensorName")myClassifierskflow.estimators.TensorFlowEstimator .

但是,我发现很难找到张量名称的正确语法,即使知道它的名字(我在操作和张量之间感到困惑),所以我使用tensorboard绘制图形并查找名称 .

有没有办法在不使用张量板的情况下枚举图中的所有张量?

8 回答

  • 0

    你可以做

    [n.name for n in tf.get_default_graph().as_graph_def().node]
    

    另外,如果您在IPython笔记本中进行原型设计,可以直接在笔记本中显示图形,请参阅Alexander的Deep Dream中的 show_graph 函数notebook

  • 9

    通过使用get_operations,有一种方法可以比雅罗斯拉夫的答案稍快一些 . 这是一个简单的例子:

    import tensorflow as tf
    
    a = tf.constant(1.3, name='const_a')
    b = tf.Variable(3.1, name='variable_b')
    c = tf.add(a, b, name='addition')
    d = tf.multiply(c, a, name='multiply')
    
    for op in tf.get_default_graph().get_operations():
        print(str(op.name))
    
  • 5

    tf.all_variables() 可以为您提供所需的信息 .

    此外,this commit今天在TensorFlow Learn中提供了在估算器中提供函数 get_variable_names ,您可以使用它来轻松检索所有变量名称 .

  • 3

    接受的答案只会为您提供一个包含名称的字符串列表 . 我更喜欢一种不同的方法,它可以(几乎)直接访问张量:

    graph = tf.get_default_graph()
    list_of_tuples = [op.values() for op in graph.get_operations()]
    

    list_of_tuples 现在包含每个张量,每个张量都在一个元组内 . 您也可以调整它以直接获得张量:

    graph = tf.get_default_graph()
    list_of_tuples = [op.values()[0] for op in graph.get_operations()]
    
  • 141

    我想这也会做:

    print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))
    

    但与萨尔瓦多和雅罗斯拉夫的答案相比,我不知道哪一个更好 .

  • 20

    以前的答案很好,我只是想分享一个我写的选择Tensors的实用函数:

    def get_graph_op(graph, and_conds=None, op='and', or_conds=None):
        """Selects nodes' names in the graph if:
        - The name contains all items in and_conds
        - OR/AND depending on op
        - The name contains any item in or_conds
    
        Condition starting with a "!" are negated.
        Returns all ops if no optional arguments is given.
    
        Args:
            graph (tf.Graph): The graph containing sought tensors
            and_conds (list(str)), optional): Defaults to None.
                "and" conditions
            op (str, optional): Defaults to 'and'. 
                How to link the and_conds and or_conds:
                with an 'and' or an 'or'
            or_conds (list(str), optional): Defaults to None.
                "or conditions"
    
        Returns:
            list(str): list of relevant tensor names
        """
        assert op in {'and', 'or'}
    
        if and_conds is None:
            and_conds = ['']
        if or_conds is None:
            or_conds = ['']
    
        node_names = [n.name for n in graph.as_graph_def().node]
    
        ands = {
            n for n in node_names
            if all(
                cond in n if '!' not in cond
                else cond[1:] not in n
                for cond in and_conds
            )}
    
        ors = {
            n for n in node_names
            if any(
                cond in n if '!' not in cond
                else cond[1:] not in n
                for cond in or_conds
            )}
    
        if op == 'and':
            return [
                n for n in node_names
                if n in ands.intersection(ors)
            ]
        elif op == 'or':
            return [
                n for n in node_names
                if n in ands.union(ors)
            ]
    

    所以如果你有一个带有ops的图表:

    ['model/classifier/dense/kernel',
    'model/classifier/dense/kernel/Assign',
    'model/classifier/dense/kernel/read',
    'model/classifier/dense/bias',
    'model/classifier/dense/bias/Assign',
    'model/classifier/dense/bias/read',
    'model/classifier/dense/MatMul',
    'model/classifier/dense/BiasAdd',
    'model/classifier/ArgMax/dimension',
    'model/classifier/ArgMax']
    

    然后跑

    get_graph_op(tf.get_default_graph(), ['dense', '!kernel'], 'or', ['Assign'])
    

    收益:

    ['model/classifier/dense/kernel/Assign',
    'model/classifier/dense/bias',
    'model/classifier/dense/bias/Assign',
    'model/classifier/dense/bias/read',
    'model/classifier/dense/MatMul',
    'model/classifier/dense/BiasAdd']
    
  • 0

    这对我有用:

    for n in tf.get_default_graph().as_graph_def().node:
        print('\n',n)
    
  • 3

    由于OP要求提供张量列表而不是操作/节点列表,因此代码应略有不同:

    graph = tf.get_default_graph()    
    tensors_per_node = [node.values() for node in graph.get_operations()]
    tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]
    

相关问题