我正在尝试使用Keras实现V-net . 这个过程非常简单(感谢Keras),但我遇到了一些问题 . Vnet是FCNN(完全卷积神经网络),因此理论上输入的大小可以变化,只涉及卷积 . 然而,在收缩/缩小/卷积路径和解除/上/解卷积路径之间跳过层(长连接)的需要使我有必要选择固定的输入大小 .

然而,我已经尝试根据2个连接层中的差异动态地实现Lambda层图层裁剪 .

这是我的实现假设带有“channels_last”数据格式的3d输入 .

def CropToConcat3D():
"""
   Layer cropping bigger to smaller with concatenation.
"""

    def crop_to_3D(inputs):
    """
        inputs = bigger_input,smaller_input
        Crop bigger input to smaller input to
        have same dimension.
    """
        bigger_input,smaller_input = inputs
        bigger_input_size=bigger.get_shape().as_list()
        smaller_input_size=smaller.get_shape().as_list()
        _,bh,bw,bd,_ = bigger_input_size
        _,sh,sw,sd,_ = smaller_input_size
        if (bh is None) and (bw is None) and (bd is None):
            cropped_to_smaller_input = bigger_input
        else:
            cropped_to_smaller_input = bigger_input
            dh,dw,dd = bh-sh, bw-sw, bd-sd
            q2dh,r2dh,q2dw,r2dw,q2dd,r2dd = dh//2,dh%2,dw//2,dw%2,dd//2,dd%2
            cropped_to_smaller_input = bigger_input[:,q2dh:bh-(q2dh+r2dh),
                                                      q2dw:bw-(q2dw+r2dw),
                                                      q2dd:bd-(q2dd+r2dd),:]
        return K.concatenate([smaller_input,cropped_to_smaller_input],
                           axis=-1)

    return Lambda(crop_to_3D)

对于None的情况,有必要在编译模型时处理变量大小,但是当使用fit,predict或其他方法运行模型时,忽略else分支 .

这些图层使用如下,并且应该从张量动态获取尺寸(我使用张量流后端_keras_shape应该与其他后端一起工作) .

outputs = CropToConcat3D()([bigger,smaller])

所以代码似乎在运行,但是else分支总是被忽略,在Keras中是否有一组编译的属性或指令,我没有考虑到这一点?我还检查了声明一个固定输入,在这种情况下,访问了else分支,但如果我给出一个不同形状的输入,输入层会抛出一个错误,因为输入不符合模型尺寸 .

我尝试了在何时/何地获取/存储维度的代码变体,但行为始终是相同的 .

谢谢大家的见解 .

解决了

问题是否则计算图中没有考虑其他语句 . 使用tf.cond解决了这个问题 . 网必须运行一次这个指令每个前馈操作只一次(语句是编译我想到其中一个分支)这是非常难看,但工作 .

def CropToConcat3D():
"""
    inputs = bigger_input,smaller_input
    Crop bigger input to smaller input to
    have same dimension.
"""

def control_copy_crop3D(inputs):
    bigger_input,smaller_input = inputs
    def simple_concat_3D():
        return K.concatenate([bigger_input,smaller_input], axis=-1)
    def crop_to_concat_3D():
        bigger_shape, smaller_shape = tf.shape(bigger_input), \
                                      tf.shape(smaller_input)
        sh,sw,sd = smaller_shape[1],smaller_shape[2],smaller_shape[3]
        bh,bw,bd = bigger_shape[1],bigger_shape[2],bigger_shape[3]
        dh,dw,dd = bh-sh, bw-sw, bd-sd
        q2dh,r2dh,q2dw,r2dw,q2dd,r2dd = dh//2,dh%2,dw//2,dw%2,dd//2,dd%2
        cropped_to_smaller_input = bigger_input[:,q2dh:bh-(q2dh+r2dh),
                                                  q2dw:bw-(q2dw+r2dw),
                                                  q2dd:bd-(q2dd+r2dd),:]
        return K.concatenate([smaller_input,cropped_to_smaller_input], axis=-1)

    smaller_shape = tf.shape(smaller_input)
    sh = smaller_shape[1]
    return tf.cond(tf.Variable(sh is None,dtype=tf.bool),simple_concat_3D,
                   crop_to_concat_3D)

return Lambda(control_copy_crop3D)