我试图实现CNN以正确地检测 322x322 二进制图像中的四边形的边缘,即黑色背景图像上的白色四边形,例如,

网络输出被设计为表示夹在一起的x和y,因此对于来自数据集的任何给定图像,应激发/激发四个( 4 )对应的x和y神经元对,以便我的训练标签是一个热门标签如果第一个四边形具有以下边缘( x,y ),那么我的图像的每个xy边对的编码(多类分类)是10x10,其中可能的(x,y) - 值来自[ 0-9 ] =>( 12 )( 38 )( 43 )( 57 )我的标签就像

[image] [
enter image description here
] 2

对于每个四边形 . 这是使用deeplearning4j在java中的网络配置 .

MultiLayerConfiguration EDGE_PERCEPTRON = new NeuralNetConfiguration.Builder()
                    .seed(seed)
                    .l2(0.0005)
                    .weightInit(WeightInit.XAVIER)
                    .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, lrSchedule)))
                    .list()
                    .layer(0, new ConvolutionLayer.Builder(8, 8)
                            .stride(2, 2)
                            .padding(0, 0)
                            .nIn(channels)
                            .nOut(filtersize10)
                            .activation(Activation.IDENTITY)
                            .build())
                    .layer(1, new ConvolutionLayer.Builder(4, 4)
                            .stride(2, 2)
                            .padding(0, 0)
                            .nOut(filtersize20)
                            .activation(Activation.IDENTITY)
                            .build())
                    .layer(2, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                            .kernelSize(2, 2)
                            .stride(1, 1)
                            .padding(0, 0)
                            .build())
                    .layer(3, new DenseLayer.Builder().activation(Activation.SIGMOID)
                            .nOut(outputSize).build())
                    .layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.XENT)
                            .nOut(outputSize)
                            .activation(Activation.SIGMOID)
                            .build())
                    .setInputType(InputType.convolutionalFlat(height, width, channels))
                    .backprop(true).pretrain(false).build();

这是它的图形表示

我的配置有什么问题,或者你认为我做错了什么 . 网络没有从数据集中学到任何东西,预测都是统一的,都小于0.1 . 如何正确配置网络以提供正确的预测?