设为首页收藏本站

EPS数据狗论坛

 找回密码
 立即注册

QQ登录

只需一步,快速开始

查看: 1559|回复: 0

matlab练习程序(神经网络识别mnist手写数据集)

[复制链接]

30

主题

333

金钱

471

积分

入门用户

发表于 2019-7-12 16:39:18 | 显示全部楼层 |阅读模式


mnist数据集训练数据一共有28*28*60000个像素,标签有60000个。
测试数据一共有28*28*10000个,标签10000个。
这里神经网络输入层是784个像素,用了100个隐含层,最终10个输出结果。
arc代表的是神经网络结构,可以增加隐含层,不过我试了没太大效果,毕竟梯度消失。
因为是最普通的神经网络,最终识别错误率大概在5%左右。

迭代曲线:
1.png

代码如下:
  1. clear all;

  2. close all;
  3. clc;load mnist_uint8;

  4. train_x = double(train_x) / 255;
  5. test_x  = double(test_x)  / 255;
  6. train_y = double(train_y);
  7. test_y  = double(test_y);

  8. mu=mean(train_x);   
  9. sigma=max(std(train_x),eps);
  10. train_x=bsxfun(@minus,train_x,mu);          %每个样本分别减去平均值
  11. train_x=bsxfun(@rdivide,train_x,sigma);     %分别除以标准差

  12. test_x=bsxfun(@minus,test_x,mu);
  13. test_x=bsxfun(@rdivide,test_x,sigma);

  14. arc = [784 100 10]; %输入784,隐含层100,输出10
  15. n=numel(arc);

  16. W = cell(1,n-1);    %权重矩阵

  17. for i=2:n
  18.     W{i-1} = (rand(arc(i),arc(i-1)+1)-0.5) * 8 *sqrt(6 / (arc(i)+arc(i-1)));

  19. end

  20. learningRate = 2;   %训练速度
  21. numepochs = 5;      %训练5遍
  22. batchsize = 100;    %一次训练100个数据

  23. m = size(train_x, 1);       %数据总量
  24. numbatches = m / batchsize;    %一共有numbatches这么多组

  25. %% 训练
  26. L = zeros(numepochs*numbatches,1);
  27. ll=1;

  28. for i = 1 :
  29. numepochs
  30.     kk = randperm(m);   

  31. for l = 1 :
  32. numbatches
  33.         batch_x = train_x(kk((l - 1) * batchsize + 1 : l * batchsize), :);
  34.         batch_y = train_y(kk((l - 1) * batchsize + 1 : l * batchsize), :);      

  35. %% 正向传播
  36.         mm = size(batch_x,1);
  37.         x = [ones(mm,1) batch_x];
  38.         a{1} = x;        

  39. for ii = 2 : n-1
  40.             a{ii} = 1.7159*tanh(2/3.*(a{ii - 1} * W{ii - 1}'));   
  41.             a{ii} = [ones(mm,1) a{ii}];        

  42. end
  43.         
  44.         a{n} = 1./(1+exp(-(a{n - 1} * W{n - 1}')));
  45.         e = batch_y - a{n};
  46.         L(ll) = 1/2 * sum(sum(e.^2)) / mm;
  47.         ll=ll+1;      

  48. %% 反向传播
  49.         d{n} = -e.*(a{n}.*(1 - a{n}));        

  50. for ii = (n - 1) : -1 : 2
  51.             d_act = 1.7159 * 2/3 * (1 - 1/(1.7159)^2 * a{ii}.^2);            
  52.             if ii+1==n   
  53.                 d{ii} = (d{ii + 1} * W{ii}) .* d_act;
  54.             else
  55.                 d{ii} = (d{ii + 1}(:,2:end) * W{ii}).* d_act;            end         
  56.         end
  57.          
  58.         for ii = 1 : n-1
  59.             if ii + 1 == n
  60.                 dW{ii} = (d{ii + 1}' * a{ii}) / size(d{ii + 1}, 1);
  61.             else
  62.                 dW{ii} = (d{ii + 1}(:,2:end)' * a{ii}) / size(d{ii + 1}, 1);      
  63.             end
  64.         end
  65.          
  66.        %% 更新参数        

  67. for ii = 1 : n - 1      
  68.             W{ii} = W{ii} - learningRate*dW{ii};        

  69. end
  70.               
  71.     end

  72. end

  73. %% 测试,相当于把正向传播再走一遍
  74. mm = size(test_x,1);
  75. x = [ones(mm,1) test_x];
  76. a{1} = x;

  77. for ii = 2 : n-1   
  78.     a{ii} = 1.7159 * tanh( 2/3 .* (a{ii - 1} * W{ii - 1}'));  
  79.     a{ii} = [ones(mm,1) a{ii}];

  80. end

  81. a{n} = 1./(1+exp(-(a{n - 1} * W{n - 1}')));

  82. [~, i] = max(a{end},[],2);
  83. labels = i;                         %识别后打的标签
  84. [~, expected] = max(test_y,[],2);
  85. bad = find(labels ~= expected);     %有哪些识别错了
  86. er = numel(bad) / size(x, 1);       %错误率

  87. plot(L);
复制代码



测试数据:
测试数据.rar (14.02 MB, 下载次数: 0, 售价: 1 金钱)
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

关闭

站长推荐上一条 /1 下一条

客服中心
关闭
在线时间:
周一~周五
8:30-17:30
QQ群:
653541906
联系电话:
010-85786021-8017
在线咨询
客服中心

意见反馈|网站地图|手机版|小黑屋|EPS数据狗论坛 ( 京ICP备09019565号-3 )   

Powered by BFIT! X3.4

© 2008-2028 BFIT Inc.

快速回复 返回顶部 返回列表