首页 文章

为Tensor的重量分配一个新值并保存回模型

提问于
浏览
0

TF v1.3 中,我可以使用 tf.train.import_meta_graphrestore(sess,tf.train.latest_checkpoint 恢复模型的元和权重 . 运行 sess.run() 将给出张量值 .

我的问题是我如何能够将新值分配(重写)到张量值并将其保存回模型中以便进一步处理 . 说,我有给定图层的这些值:

>>>(sess.run('MobilenetV1/Conv2d_0/weights:0'))
...
1.55560642e-01,   1.29789323e-01,   2.59163193e-02,
8.00046027e-02,   4.73752208e-02,  -5.41094005e-01,
-8.93476382e-02,  -9.48717445e-02]]]], dtype=float32)

如何使用tf.assign()为最后8个打印值分配不同的值并将其保存回检查点 .

1 回答

  • 1

    也许......这样可以为你服务:

    import tensorflow as tf
    import numpy as np
    tf.reset_default_graph()
    
    
    with tf.variable_scope('scope1') as scope:
        w = tf.get_variable('w', shape=[4,4])
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        np_w = sess.run(w)
        print(np_w)
        np_w[2:,2:] = np.ones((2,2))
        print("  .  '"*10)
        print(np_w)
        print("  .  '"*10)
        with tf.variable_scope(scope, reuse=True):
            v=tf.get_variable('w')
            sess.run(tf.assign(w, np_w))
        print(sess.run(w))
    

    作为一个例子,我用 get_variable 创建了一个随机的4x4矩阵,并将子矩阵重新分配给1 . 希望这可以帮助 .

    EDIT

    要从您保存的模型中访问变量w,然后为其分配一个新值:

    w = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v.name == 'w:0'][0]
    

    现在你可以像我上面推荐的那样继续黑客攻击 . w应该与已恢复模型中的变量具有相同的名称,例如 MobilenetV1/Conv2d_0/weights .

相关问题