我的代码:

output3d = Dense(21*3, name='imediate_3d', activation='tanh')(output3d)

reshape_out = Reshape((21, 3), input_shape=(21*3,), name='reshape_to_21_3')(output3d)

def proj_output_shape(shp):
    return (None, 32, 32, 1)

def gaussian(x, mu, sigma):
    # from: https://github.com/keras-team/keras/issues/3720
    return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2))

def make_kernel(sigma):
    # from: https://github.com/keras-team/keras/issues/3720
    # ...
    return kernel

def mysparse_to_dense(sparse_indices,
                    output_shape,
                    sparse_values,
                    default_value=0,
                    validate_indices=True,
                    name=None):
    import tensorflow as tf
    from tensorflow.python.ops.array_ops import fill
    from tensorflow.python.ops import variables
    #values = [sparse_indices, sparse_values]
    default_value = tf.convert_to_tensor(default_value,
                                         dtype=tf.float32)
    tmp = fill(output_shape, default_value)
    print('tmp={0}'.format(tmp))
    filled = variables.Variable(tmp)
    print('filled={0}'.format(filled))
    return tf.scatter_nd_add(filled,
                             indices=sparse_indices,
                             updates=sparse_values, name=name)

def f(x):
    import tensorflow as tf
    print('x.shape={0}'.format(x.shape))
    idx = K.cast(x[:, :, 0:2]*15.5+15.5, "int64")
    print('idx.shape={0}'.format(idx.shape))

    z = mysparse_to_dense(idx, (K.shape(x)[0], 32, 32), 1.0, 0.0, name='sparse_tensor')
    print('z={0}'.format(z))
    #z = tf.sparse_add(tf.zeros(z.dense_shape), z)
    z = K.reshape(z, (K.shape(x)[0], 32, 32, 1))
    fil = make_kernel(1.0)
    fil = K.reshape(fil, (5, 5, 1, 1))

    print('z.shape={0}'.format(z.shape), z)
    print('fil.shape={0}'.format(fil.shape), fil)

    r = K.conv2d(z,kernel=fil, padding='same', data_format="channels_last")
    print('r.shape={0}'.format(r.shape), r)

    return r

proj_out = Lambda(lambda x: f(x), 
                  output_shape=proj_output_shape, name='projection')(reshape_out)

输出:

x.shape=(?, 21, 3)
idx.shape=(?, 21, 2)
tmp=Tensor("projection_7/Fill:0", shape=(?, 32, 32), dtype=float32)

为什么 print('filled={0}'.format(filled)) 永远不会跑?谢谢

UPDATE

batch_size 传入 tf.Variable 时,程序将挂起 . 怎么修?

def f(x):shape = tf.Variable([K.shape(x)[0],32,32])