首页 文章

使用Tensorflow Estimator API和SemSeg的图像

提问于
浏览
1

我尝试使用Tensorflow Estimator API实现我的模型 . 作为我使用的数据输入功能

def input_fn():
    train_in_np = sorted(io_utils.loadDataset(join(basepath, r"leftImg8bit/train/*/*")))
    train_out_np = sorted(io_utils.loadDataset(join(basepath, r"gtFine/train/*/*_ignoreLabel.png")))

    train_in = tf.constant(train_in_np)
    train_out = tf.constant(train_out_np)

    tr_data = tf.data.Dataset.from_tensor_slices((train_in, train_out))
    tr_data = tr_data.shuffle(len(train_in_np))
    tr_data = tr_data.repeat(epoch_cnt+1)

    tr_data = tr_data.apply(tf.contrib.data.map_and_batch(parse_files, batch_size=batchSize))
    tr_data = tr_data.prefetch(buffer_size=batchSize)

    iterator = tr_data.make_initializable_iterator()
    return iterator.get_next()

io_utils.loadDataset只返回文件路径列表 . 数据本身由解析

def parse_files(in_file, gt_file):
    image_in = tf.read_file(in_file)
    image_in = tf.image.decode_image(image_in, channels=3)
    image_in.set_shape([None, None, 3])
    image_in = tf.cast(image_in, tf.float32)

    mean, std = tf.nn.moments(image_in, [0, 1])
    image_in = image_in - mean
    image_in = image_in / std

    gt = tf.read_file(gt_file)
    gt = tf.image.decode_image(gt, channels=1)    
    gt.set_shape([None, None, 1])
    gt = tf.cast(gt, tf.int32)

    return {'img':image_in}, gt

我的估计开始于

def estimator_fcn_model_fn(features, labels, mode, params):
    x = tf.feature_column.input_layer(features, params['feature_columns'])

并且要素列定义为

my_feature_columns = []
my_feature_columns.append(tf.feature_column.numeric_column(key='img'))

我跳过的其余代码的可清除性 . 我的问题是功能的形状:

Blockquote ValueError :('排名不支持卷积',2)

x,features和feature_columns的打印输出:

Tensor(“input_layer / concat:0”,shape =(?,1),dtype = float32){'img':tf.Tensor'IteratorGetNext:0'shape =(?,?,?,3)dtype = float32 } [_NumericColumn(key ='img',shape =(1,),default_value = None,dtype = tf.float32,normalizer_fn = None)]

有谁知道如何解决这个问题,我猜它与特征列的类型有关,但我不知道将谁应用于图像 .

1 回答

  • 0

    由于Y.Luo给出了一个提示tf.feature_column.input_layer不适合图像 . 一种更简单的方法是通过密钥直接使用特征字典,这可以通过参数传递以获得更大的灵活性 .

    x = features[params['input_name']]
    

    代替

    x = tf.feature_column.input_layer(features, params['feature_columns'])
    

相关问题