本文最后更新于 426 天前,其中的信息可能已经过时,如有错误请发送邮件到 wuxianglongblog@163.com
Theano 实例:人工神经网络
神经网络的模型可以参考 UFLDL 的教程,这里不做过多描述。
http://ufldl.stanford.edu/wiki/index.php/%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C
import theano import theano.tensor as T import numpy as np from load import mnist
Using gpu device 1: Tesla K10.G2.8GB (CNMeM is disabled)
我们在这里使用一个简单的三层神经网络:输入 - 隐层 - 输出。
对于网络的激活函数,隐层用 sigmoid
函数,输出层用 softmax
函数,其模型如下:
def model(X, w_h, w_o): """ input: X: input data w_h: hidden unit weights w_o: output unit weights output: Y: probability of y given x """ # 隐层 h = T.nnet.sigmoid(T.dot(X, w_h)) # 输出层 pyx = T.nnet.softmax(T.dot(h, w_o)) return pyx
使用随机梯度下降的方法进行训练:
def sgd(cost, params, lr=0.05): """ input: cost: cost function params: parameters lr: learning rate output: update rules """ grads = T.grad(cost=cost, wrt=params) updates = [] for p, g in zip(params, grads): updates.append([p, p - g * lr]) return updates
对于 MNIST
手写数字的问题,我们使用一个 784 × 625 × 10
即输入层大小为 784
,隐层大小为 625
,输出层大小为 10
的神经网络来模拟,最后的输出表示数字为 0
到 9
的概率。
为了对权重进行更新,我们需要将权重设为 shared 变量:
def floatX(X): return np.asarray(X, dtype=theano.config.floatX) def init_weights(shape): return theano.shared(floatX(np.random.randn(*shape) * 0.01))
因此变量初始化为:
X = T.matrix() Y = T.matrix() w_h = init_weights((784, 625)) w_o = init_weights((625, 10))
模型输出为:
py_x = model(X, w_h, w_o)
预测的结果为:
y_x = T.argmax(py_x, axis=1)
模型的误差函数为:
cost = T.mean(T.nnet.categorical_crossentropy(py_x, Y))
更新规则为:
updates = sgd(cost, [w_h, w_o])
定义训练和预测的函数:
train = theano.function(inputs=[X, Y], outputs=cost, updates=updates, allow_input_downcast=True) predict = theano.function(inputs=[X], outputs=y_x, allow_input_downcast=True)
训练:
导入 MNIST 数据:
trX, teX, trY, teY = mnist(onehot=True)
训练 100 轮,正确率为 0.956:
for i in range(100): for start, end in zip(range(0, len(trX), 128), range(128, len(trX), 128)): cost = train(trX[start:end], trY[start:end]) print "{0:03d}".format(i), np.mean(np.argmax(teY, axis=1) == predict(teX))
000 0.7028 001 0.8285 002 0.8673 003 0.883 004 0.89 005 0.895 006 0.8984 007 0.9017 008 0.9047 009 0.907 010 0.9089 011 0.9105 012 0.9127 013 0.914 014 0.9152 015 0.9159 016 0.9169 017 0.9173 018 0.918 019 0.9185 020 0.919 021 0.9197 022 0.9201 023 0.9205 024 0.9206 025 0.9212 026 0.9219 027 0.9228 028 0.9228 029 0.9229 030 0.9236 031 0.9244 032 0.925 033 0.9255 034 0.9263 035 0.927 036 0.9274 037 0.9278 038 0.928 039 0.9284 040 0.9289 041 0.9294 042 0.9298 043 0.9302 044 0.9311 045 0.932 046 0.9325 047 0.9332 048 0.934 049 0.9347 050 0.9354 051 0.9358 052 0.9365 053 0.9372 054 0.9377 055 0.9385 056 0.9395 057 0.9399 058 0.9405 059 0.9411 060 0.9416 061 0.9422 062 0.9427 063 0.9429 064 0.9431 065 0.9438 066 0.9444 067 0.9446 068 0.9449 069 0.9453 070 0.9458 071 0.9462 072 0.9469 073 0.9475 074 0.9474 075 0.9476 076 0.948 077 0.949 078 0.9497 079 0.95 080 0.9503 081 0.9507 082 0.9507 083 0.9515 084 0.9519 085 0.9521 086 0.9523 087 0.9529 088 0.9536 089 0.9538 090 0.9542 091 0.9545 092 0.9544 093 0.9546 094 0.9547 095 0.9549 096 0.9552 097 0.9554 098 0.9557 099 0.9562