首页 文章

如何在TensorFlow中将张量转换为numpy数组?

提问于
浏览
104

在使用带有Python绑定的Tensorflow时,如何将张量转换为numpy数组?

6 回答

  • 67

    你需要:

    • 将图像张量以某种格式(jpeg,png)编码为二进制张量

    • 评估(运行)会话中的二进制张量

    • 将二进制文件转换为流

    • 提供给PIL图像

    • (可选)使用matplotlib显示图像

    码:

    import tensorflow as tf
    import matplotlib.pyplot as plt
    import PIL
    
    ...
    
    image_tensor = <your decoded image tensor>
    jpeg_bin_tensor = tf.image.encode_jpeg(image_tensor)
    
    with tf.Session() as sess:
        # display encoded back to image data
        jpeg_bin = sess.run(jpeg_bin_tensor)
        jpeg_str = StringIO.StringIO(jpeg_bin)
        jpeg_image = PIL.Image.open(jpeg_str)
        plt.imshow(jpeg_image)
    

    这对我有用 . 你可以在ipython笔记本中试试 . 只是不要忘记添加以下行:

    %matplotlib inline
    
  • 1

    也许你可以尝试,这种方法:

    import tensorflow as tf
    W1 = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    array = W1.eval(sess)
    print (array)
    
  • 0

    Session.runeval 返回的任何张量都是NumPy数组 .

    >>> print(type(tf.Session().run(tf.constant([1,2,3]))))
    <class 'numpy.ndarray'>
    

    要么:

    >>> sess = tf.InteractiveSession()
    >>> print(type(tf.constant([1,2,3]).eval()))
    <class 'numpy.ndarray'>
    

    或者,等效地:

    >>> sess = tf.Session()
    >>> with sess.as_default():
    >>>    print(type(tf.constant([1,2,3]).eval()))
    <class 'numpy.ndarray'>
    

    EDIT: Session.runeval() 返回的任何张量都不是NumPy数组 . 例如,稀疏张量作为SparseTensorValue返回:

    >>> print(type(tf.Session().run(tf.SparseTensor([[0, 0]],[1],[1,2]))))
    <class 'tensorflow.python.framework.sparse_tensor.SparseTensorValue'>
    
  • 4

    要从张量转换回numpy数组,您只需在转换张量上运行 .eval() 即可 .

  • 88

    我已经面对并解决了张量 - > ndarray转换的特定情况下的张量代表(对抗)图像,用cleverhans库/教程获得 .

    我认为我的问题/答案(here)也可能是其他案例的一个有用的例子 .

    我是TensorFlow的新手,我的经验结论是:

    为了成功,似乎tensor.eval()方法可能还需要输入占位符的值 . Tensor可以像需要其输入值(提供到 feed_dict )中的函数一样工作,以便返回输出值,例如,

    array_out = tensor.eval(session=sess, feed_dict={x: x_input})
    

    请注意,在我的情况下,占位符名称是 x ,但我想您应该找到输入占位符的正确名称 . x_input 是包含输入数据的标量值或数组 .

    在我的情况下,也提供 sess 是强制性的 .

    我的例子也涵盖了matplotlib图像可视化部分,但这是OT .

  • 0

    一个简单的例子可能是,

    import tensorflow as tf
        import numpy as np
        a=tf.random_normal([2,3],0.0,1.0,dtype=tf.float32)  #sampling from a std normal
        print(type(a))
        #<class 'tensorflow.python.framework.ops.Tensor'>
        tf.InteractiveSession()  # run an interactive session in Tf.
    

    n现在,如果我们想要将张量a转换为numpy数组

    a_np=a.eval()
        print(type(a_np))
        #<class 'numpy.ndarray'>
    

    就如此容易!

相关问题