我正在寻找一种简单的方法来从Tensorflow中的当前张量中删除一组张量,并且我有一个难以合理的解决方案 .
例如,假设我有以下当前张量:
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
我想删除这个张量中的两个项目(2.0和5.0) .
在创建之后将这个张量转换为[1.0,3.0,4.0,6.0]的最佳方法是什么?
提前谢谢了 .
您可以调用 tf.unstack 以获取子张量的列表 . 然后,您可以修改列表并调用 tf.stack 从列表中构造张量 . 例如,以下代码从以下位置删除[2.0,5.0]列:
tf.unstack
tf.stack
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a') a_vecs = tf.unstack(a, axis=1) del a_vecs[1] a_new = tf.stack(a_vecs, 1)
另一种方法是使用分割或切片功能 . 如果张量很大,这将是有用的 .
Method 1: 利用split功能 .
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a') split1, split2, split3 = tf.split(a, [1, 1, 1], 1) a_new = tf.concat([split1, split3], 1)
Method 2: 利用slice功能 .
slice1 = tf.slice(a, [0, 0], [2, 1]) slice2 = tf.slice(a, [0, 2], [2, 1]) a_new = tf.concat([slice1, slice2], 1)
在这两种情况下,a_new都会有
[[ 1. 3.] [ 4. 6.]]
2 回答
您可以调用
tf.unstack
以获取子张量的列表 . 然后,您可以修改列表并调用tf.stack
从列表中构造张量 . 例如,以下代码从以下位置删除[2.0,5.0]列:另一种方法是使用分割或切片功能 . 如果张量很大,这将是有用的 .
Method 1: 利用split功能 .
Method 2: 利用slice功能 .
在这两种情况下,a_new都会有