My code:

from keras import backend as K  # assumed: K is the Keras backend

def f(x):
    import tensorflow as tf
    print('x.shape={0}'.format(x.shape))
    idx = K.cast(x[:, :, 0:2]*15.5+15.5, "int64")
    print('idx.shape={0}'.format(idx.shape))
    st_z = tf.SparseTensor(idx, values=0.0, dense_shape=[K.shape(x)[0], 32, 32])

Output:

x.shape=(?, 21, 3)
idx.shape=(?, 21, 2)

Error:

ValueError: Shape (?, 21, 2) must have rank 2

How can I fix this? Thanks.
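The error comes from the `indices` argument. `tf.SparseTensor` expects a rank-2 int64 tensor of shape `[N, ndims]` (one row per stored entry), a 1-D `values` tensor of length `N`, and a 1-D `dense_shape` of length `ndims`. Here `idx` has shape `(?, 21, 2)`, so it is rank 3. In addition, each coordinate row lacks the batch index that a 3-D `dense_shape` needs, and `values=0.0` is a scalar rather than one value per index. One possible fix is sketched below. It assumes TF 2.x eager execution, that `x[..., 0:2]` lies in `[-1, 1]` (which the `*15.5 + 15.5` scaling maps onto `[0, 31]`), and that each point should be stored as `1.0`, since the original `0.0` would produce an all-zero tensor. The function name `coords_to_sparse` and its `grid_size` parameter are hypothetical, and `tf.cast` replaces `K.cast`.

```python
import tensorflow as tf


def coords_to_sparse(x, grid_size=32):
    """Hypothetical helper: scatter the (x, y) points of each sample into a
    SparseTensor of shape [batch, grid_size, grid_size].

    Assumes x has shape [batch, n_points, 3] and x[..., 0:2] is in [-1, 1].
    """
    x = tf.convert_to_tensor(x, dtype=tf.float32)
    # Dynamic batch size: the static batch dim is None inside a Keras model.
    batch = tf.shape(x, out_type=tf.int64)[0]

    # (batch, n_points, 2) integer grid coordinates, as in the question.
    half = (grid_size - 1) / 2.0  # 15.5 for a 32x32 grid
    xy = tf.cast(x[:, :, 0:2] * half + half, tf.int64)

    # Batch index for every point, broadcast to (batch, n_points, 1).
    b = tf.zeros_like(xy[:, :, :1]) + tf.range(batch, dtype=tf.int64)[:, None, None]

    # Rank-2 indices of shape (batch * n_points, 3), as SparseTensor requires.
    indices = tf.reshape(tf.concat([b, xy], axis=-1), [-1, 3])

    # values must be 1-D with one entry per index row (assumed to be 1.0).
    values = tf.ones(tf.shape(indices)[:1], dtype=tf.float32)

    grid = tf.constant(grid_size, dtype=tf.int64)
    dense_shape = tf.stack([batch, grid, grid])

    # Put the indices into the canonical row-major order that
    # ops such as tf.sparse.to_dense check for.
    return tf.sparse.reorder(tf.SparseTensor(indices, values, dense_shape))


x = tf.constant([[[-1.0, -1.0, 0.0], [1.0, 1.0, 0.0]],
                 [[0.0, 0.0, 0.0], [1.0, -1.0, 0.0]]])
st = coords_to_sparse(x)
print(st.indices.shape)  # (4, 3)
```

Inside a Keras `Lambda` layer the batch dimension is `None` statically, which is why the sketch reads it with `tf.shape` rather than `x.shape`. If two points can fall into the same grid cell, the indices will contain duplicates and `tf.sparse.to_dense` will reject them. In that case `tf.scatter_nd`, which sums duplicate updates into a dense tensor, may be a simpler choice.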