首页 文章

如何明确地使用矩阵和矩阵运算在MXNet中构建模型?

提问于
浏览
11

我可以使用预构建的高级函数(如 FullyConnected )创建模型 . 例如:

X = mx.sym.Variable('data')
P  = mx.sym.FullyConnected(data = X, name = 'fc1', num_hidden = 2)

通过这种方式,我得到一个符号变量 P ,它取决于符号变量 X . 换句话说,我有计算图,可用于定义模型并执行 fitpredict 等操作 .

现在,我想以不同的方式通过 X 表达 P . 更详细地说,我想使用低级张量运算(如矩阵乘法)和表示模型参数的符号变量(湖泊权重矩阵)来指定 PX "explicitly"之间的关系,而不是使用高级功能(如 FullyConnected ) . ) .

例如为了实现与上述相同,我尝试了以下内容:

W = mx.sym.Variable('W')
B = mx.sym.Variable('B')
P = mx.sym.broadcast_plus(mx.sym.dot(X, W), B)

但是,以这种方式获得的 P 并不等同于之前获得的 P . 我不能以同样的方式使用它 . 特别是,据我所知,MXNet抱怨 WB 没有值(这是有道理的) .

我还尝试以另一种方式声明 WB (这样它们确实有值):

w = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = np.array([7.0, 8.0])

W = mx.nd.array(w)
B = mx.nd.array(b)

它不起作用 . 我猜MXNet抱怨因为它需要一个符号变量,但它会得到nd-arrays .

所以,我的问题是如何使用低级张量运算(如矩阵乘法)和表示模型参数的显式对象(如权重矩阵)来构建模型 .

1 回答

  • 5

    你可能想看看Gluon API . 例如,这里是从头开始构建MLP的指南,包括分配参数:

    #######################
    #  Allocate parameters for the first hidden layer
    #######################
    W1 = nd.random_normal(shape=(num_inputs, num_hidden), scale=weight_scale, ctx=model_ctx)
    b1 = nd.random_normal(shape=num_hidden, scale=weight_scale, ctx=model_ctx)
    
    params = [W1, b1, ...]
    

    将它们附加到自动渐变

    for param in params:
        param.attach_grad()
    

    定义模型:

    def net(X):
        #######################
        #  Compute the first hidden layer
        #######################
        h1_linear = nd.dot(X, W1) + b1
        ...
    

    并执行它

    epochs = 10
    learning_rate = .001
    smoothing_constant = .01
    
    for e in range(epochs):
        ...
        for i, (data, label) in enumerate(train_data):
            data = data.as_in_context(model_ctx).reshape((-1, 784))
            label = label.as_in_context(model_ctx)
            ...
            with autograd.record():
                output = net(data)
                loss = softmax_cross_entropy(output, label_one_hot)
            loss.backward()
            SGD(params, learning_rate)
    

    您可以在直涂料中看到完整的示例:

    http://gluon.mxnet.io/chapter03_deep-neural-networks/mlp-scratch.html

相关问题