我正在使用Keras使用下面的代码为变分自动编码器的解码器进行采样 . 令我困惑的是为什么定义Lambda函数的行执行两次?

请注意:

  • 在下面的代码中,我把'if(i_b == 2):'只是为了更容易的临时调试 .

  • 解码器定义之前的行't = Lambda(...'执行两次,如上所述 .

这是一个示例输出:

epsilon:Tensor("lambda_21/Print:0",shape =(3,3),dtype = float32)
epsilon:Tensor("lambda_21/Print_1:0",shape =(3,3),dtype = float32)

batch_size = 100
original_dim = 784
latent_dim = 3
intermediate_dim = 256 # Size of the hidden layer.
epochs = 3

x = Input(batch_shape=(batch_size, original_dim))
def create_encoder(input_dim):
    # Encoder network.
    encoder = Sequential(name='encoder')
    encoder.add(InputLayer([input_dim]))
    encoder.add(Dense(intermediate_dim, activation='relu'))
    encoder.add(Dense(2 * latent_dim))
    return encoder

encoder = create_encoder(original_dim)

get_t_mean = Lambda(lambda h: h[:, :latent_dim])
get_t_log_var = Lambda(lambda h: h[:, latent_dim:])
h = encoder(x)
t_mean = get_t_mean(h)
t_log_var = get_t_log_var(h)

def sampling(args):
    t_mean, t_log_var = args

    t_var = tf.exp(t_log_var)
    batch_size = int(t_mean.shape[0])
    latent_dim = int(t_mean.shape[1])

    sampled_list = []

    epsilon_distribution = tf_dist.MultivariateNormalDiag(np.zeros(latent_dim, dtype=np.float32), np.identity(latent_dim, dtype=np.float32))

    for i_b in range(batch_size):
        epsilon = epsilon_distribution.sample()
        # epsilon = np.random.multivariate_normal(np.zeros(latent_dim), np.identity(latent_dim))

        if(i_b == 2):
            epsilon = tf.Print(epsilon, [epsilon], 'epsilon: ')
            print('epsilon: ', epsilon)
            break

        sampled_list.append(t_mean[i_b] + epsilon * t_var[i_b])

    with tf.variable_scope("sampling", reuse=tf.AUTO_REUSE):
        sampled_tensor = tf.stack(values=sampled_list, name='sampled_tensor')

    return sampled_tensor

t = Lambda(sampling)([t_mean, t_log_var])

def create_decoder(input_dim):
    # Decoder network
    decoder = Sequential(name='decoder')
    decoder.add(InputLayer([input_dim]))
    decoder.add(Dense(intermediate_dim, activation='relu'))
    decoder.add(Dense(original_dim, activation='sigmoid'))
    return decoder

decoder = create_decoder(latent_dim)
x_decoded_mean = decoder(t)