I wrote AlexNet in TensorFlow to run it on the MNIST dataset, and I get a ValueError saying: Negative dimension size caused by subtracting 2 from 1 for 'pool5' (op: 'MaxPool') with input shapes: [?,1,1,1024]. How can I solve this?
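As far as I can tell, the error means the input to pool5 is already 1x1 spatially, so a 2x2 VALID pool cannot fit. A quick sanity trace of the spatial sizes, assuming the usual VALID-padding formula output = floor((n - k) / s) + 1 and that SAME padding at stride 1 preserves the size:

# Hypothetical helper just to trace spatial sizes (not part of my model)
def valid_out(n, k, s):
    return (n - k) // s + 1

n = 28                   # MNIST images are 28x28
n = valid_out(n, 11, 4)  # conv1: 11x11, stride 4, VALID -> 5
n = valid_out(n, 2, 2)   # pool1: 2x2, stride 2, VALID   -> 2
                         # conv2: SAME, stride 1         -> still 2
n = valid_out(n, 2, 2)   # pool2: 2x2, stride 2, VALID   -> 1
                         # conv3-conv5: SAME, stride 1   -> still 1
print(n)                 # 1: pool5 would have to subtract 2 from 1, hence the error

Here is my full code: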

from __future__ import print_function

import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
import os
import random
import matplotlib.pyplot as plt
import numpy as np

# Parameters
learning_rate = 0.001
training_iters = 100000
batch_size = 1000
display = True
display_step_console = 5
learn_from_scratch = False
train = False

# Network Parameters
n_input = 784 # MNIST data input (img shape: 28*28)
n_classes = 10 # MNIST total classes (0-9 digits)
dropout = 0.80 # Dropout, probability to keep units

# tf Graph input
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
keep_prob = tf.placeholder(tf.float32) #dropout (keep probability)

# Create AlexNet model
def conv1st(name, l_input, w, b):
    cov = tf.nn.conv2d(l_input, w, strides=[1, 4, 4, 1], padding='VALID')
    return tf.nn.relu(tf.nn.bias_add(cov,b), name=name)

def conv2d(name, l_input, w, b):
    cov = tf.nn.conv2d(l_input, w, strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(tf.nn.bias_add(cov,b), name=name)

def max_pool(name, l_input, k, s):
    return tf.nn.max_pool(l_input, ksize=[1, k, k, 1], strides=[1, s, s, 1], 
                          padding='VALID', name=name)

def norm(name, l_input, lsize=4):
    return tf.nn.lrn(l_input, lsize, bias=1.0, alpha=0.001 / 9.0, beta=0.75, 
                     name=name)

def alex_net(_X, _weights, _biases, _dropout):
    # Reshape input picture
    _X = tf.reshape(_X, shape=[-1, 28, 28, 1])

    # Convolution Layer
    conv1 = conv1st('conv1', _X, _weights['wc1'], _biases['bc1'])

    # Max Pooling (down-sampling)
    pool1 = max_pool('pool1', conv1, k=2, s=2)
    # Apply Normalization
    norm1 = norm('norm1', pool1, lsize=4)
    # Apply Dropout
    norm1 = tf.nn.dropout(norm1, _dropout)

    # Convolution Layer
    conv2 = conv2d('conv2', norm1, _weights['wc2'], _biases['bc2'])
    # Max Pooling (down-sampling)
    pool2 = max_pool('pool2', conv2, k=2, s=2)
    # Apply Normalization
    norm2 = norm('norm2', pool2, lsize=4)
    # Apply Dropout
    norm2 = tf.nn.dropout(norm2, _dropout)

    # Convolution Layer
    conv3 = conv2d('conv3', norm2, _weights['wc3'], _biases['bc3'])
    conv4 = conv2d('conv4', conv3, _weights['wc4'], _biases['bc4'])
    conv5 = conv2d('conv5', conv4, _weights['wc5'], _biases['bc5'])
    # Max Pooling (down-sampling)
    pool5 = max_pool('pool5', conv5, k=2, s=2)
    # Apply Normalization
    norm5 = norm('norm5', pool5, lsize=4)
    # Apply Dropout
    norm5 = tf.nn.dropout(norm5, _dropout)

    # Fully connected layer
    dense1 = tf.reshape(norm5, [-1, _weights['wd1'].get_shape().as_list()[0]]) 
                        # Reshape pool5 output to fit the dense layer input
    dense1 = tf.nn.relu(tf.matmul(dense1, _weights['wd1']) + _biases['bd1'], 
                        name='fc1') # Relu activation

    dense2 = tf.nn.relu(tf.matmul(dense1, _weights['wd2']) + _biases['bd2'], 
                        name='fc2') # Relu activation

    # Output, class prediction
    out = tf.matmul(dense2, _weights['out']) + _biases['out']
    return out
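
# A quick debugging check (assuming TF1's Tensor.get_shape()): printing the
# static shapes inside alex_net makes the collapse visible, e.g.
#     print(conv1.get_shape())  # (?, 5, 5, 96)
#     print(pool1.get_shape())  # (?, 2, 2, 96)
#     print(pool2.get_shape())  # (?, 1, 1, 256)
# so by conv5 the tensor is already (?, 1, 1, 1024), matching the error.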

# Store layers weight & bias
_weights = {
        'wc1': tf.Variable(tf.random_normal([11, 11, 1, 96])),
        'wc2': tf.Variable(tf.random_normal([5, 5, 96, 256])),
        'wc3': tf.Variable(tf.random_normal([3, 3, 256, 512])),
        'wc4': tf.Variable(tf.random_normal([3, 3, 512, 1024])),
        'wc5': tf.Variable(tf.random_normal([3, 3, 1024, 1024])),
        #'wd1': tf.Variable(tf.random_normal([8*8*256, 1024])),
        'wd1': tf.Variable(tf.random_normal([6*6*256, 3072])),
        'wd2': tf.Variable(tf.random_normal([3072, 4096])),
        'out': tf.Variable(tf.random_normal([4096, n_classes]))
}

_biases = {
    'bc1': tf.Variable(tf.random_normal([96])),
    'bc2': tf.Variable(tf.random_normal([256])),
    'bc3': tf.Variable(tf.random_normal([512])),
    'bc4': tf.Variable(tf.random_normal([1024])),
    'bc5': tf.Variable(tf.random_normal([1024])),
    'bd1': tf.Variable(tf.random_normal([3072])),
    'bd2': tf.Variable(tf.random_normal([4096])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}
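
# Sanity note: 'wd1' assumes the flattened conv output is 6*6*256 = 9216, but
# per the shape trace above conv5 already comes out as (?, 1, 1, 1024), so this
# dimension presumably needs rechecking too once pool5 is sorted out.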

############### NOT SO IMPORTANT ANYMORE ###############

# Construct model
pred = alex_net(x, _weights, _biases, keep_prob)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,
                                                              labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

#Create summary scalars and operation
n1 = tf.summary.scalar("cost", cost)
n2 = tf.summary.scalar("accuracy", accuracy)
train_summary_op = tf.summary.merge([n1,n2])

#Do writer
log_dir = "./alexnet-classification-model-checkpoints/summary"
train_writer = tf.summary.FileWriter(log_dir + '/train',
                                     graph=tf.get_default_graph())


# Initializing the variables
init = tf.global_variables_initializer()

saver = tf.train.Saver(max_to_keep=1) 

initial_step = 0

# Launch the graph
with tf.Session() as sess:
    if not learn_from_scratch:
        if os.path.isfile('./alexnet-classification-model-checkpoints/step.txt'):
            with open("alexnet-classification-model-checkpoints/step.txt", "r") as file:
                step = file.read()
                print(step)
                initial_step = int(step)

        if (os.path.isfile('./alexnet-classification-model-checkpoints/checkpoint') and
                os.path.isfile('./alexnet-classification-model-checkpoints/my-model.ckpt.meta')):
            saver = tf.train.import_meta_graph('alexnet-classification-model-checkpoints/my-model.ckpt.meta')
            saver.restore(sess, 'alexnet-classification-model-checkpoints/my-model.ckpt')
            print("Loaded model successfully!")
        else:
            print("A saved model wasn't found, starting from scratch")
            sess.run(init)
    else:
        sess.run(init)



    if train:                
        print("Starting training!")

        step = 1
        # Keep training until reach max iterations
        while step * batch_size <= training_iters:
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # Run optimization op (backprop)
            sess.run(optimizer, feed_dict={x: batch_x, y: batch_y,
                                           keep_prob: dropout})

            if step % display_step_console == 0:
                if display:
                    batch_x_eval, batch_y_eval = mnist.train.next_batch(500, shuffle=True)

                    # Calculate batch loss and accuracy
                    loss, acc, summary = sess.run([cost, accuracy, train_summary_op],
                                                  feed_dict={x: batch_x_eval,
                                                             y: batch_y_eval,
                                                             keep_prob: 1.0})

                    train_writer.add_summary(summary, global_step=((step + initial_step) * batch_size))

                    print("Iter " + str((step + initial_step) * batch_size) +
                          ", Minibatch Loss= " + "{:.6f}".format(loss) +
                          ", Training Accuracy= " + "{:.5f}".format(acc))
                else:
                    print("Still training... {}%".format(round((step * batch_size / training_iters) * 100, 2)))
            step += 1
        print("Optimization Finished!")

        savePath = saver.save(sess, 'alexnet-classification-model-checkpoints/my-model.ckpt')
        with open("alexnet-classification-model-checkpoints/step.txt", "w") as file:
            file.write(str(initial_step + step))
        print("Saved!")




    # Calculate accuracy on the MNIST test set
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: mnist.test.images,
                                        y: mnist.test.labels,
                                        keep_prob: 1.0}))

    # Classify one random test image using the already-built graph
    num = random.randint(0, mnist.test.images.shape[0] - 1)
    img = mnist.test.images[num]

    cls = sess.run(tf.argmax(pred, 1), feed_dict={x: img.reshape(1, n_input),
                                                  keep_prob: 1.0})
    cls2 = mnist.test.labels[num]

    plt.imshow(img.reshape(28, 28), cmap=plt.cm.binary)
    print('NN predicted', cls, np.argmax(cls2))
    plt.show()