搜档网
当前位置:搜档网 › 用神经网络做训练和分类的matlab代码

用神经网络做训练和分类的matlab代码

close all
clear all
clc
x=xlsread('training_data.xls',['B2:G401']);
y=xlsread('training_data.xls',['I2:K401']);
inputs = x';
targets = y';

% 创建一个模式识别网络(两层BP网络),同时给出中间层神经元的个数,这里使用20
hiddenLayerSize = 20;
net = patternnet(hiddenLayerSize);

% 对数据进行预处理,这里使用了归一化函数(一般不用修改)
% For a list of all processing functions type: help nnprocess
net.inputs{1}.processFcns = {'removeconstantrows','mapminmax'};
net.outputs{2}.processFcns = {'removeconstantrows','mapminmax'};


% 把训练数据分成三部分,训练网络、验证网络、测试网络
% For a list of all data division functions type: help nndivide
net.divideFcn = 'dividerand'; % Divide data randomly
net.divideMode = 'sample'; % Divide up every sample
net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 15/100;
net.divideParam.testRatio = 15/100;

% 训练函数
% For a list of all training functions type: help nntrain
net.trainFcn = 'trainlm'; % Levenberg-Marquardt

% 使用均方误差来评估网络
% For a list of all performance functions type: help nnperformance
net.performFcn = 'mse'; % Mean squared error

% 画图函数
% For a list of all plot functions type: help nnplot
net.plotFcns = {'plotperform','plottrainstate','ploterrhist', ...
'plotregression', 'plotfit'};


% 开始训练网络(包含了训练和验证的过程)
[net,tr] = train(net,inputs,targets);

% 测试网络
outputs = net(inputs);
errors = gsubtract(targets,outputs);
performance = perform(net,targets,outputs)

% 获得训练、验证和测试的结果
trainTargets = targets .* tr.trainMask{1};
valTargets = targets .* tr.valMask{1};
testTargets = targets .* tr.testMask{1};
trainPerformance = perform(net,trainTargets,outputs)
valPerformance = perform(net,valTargets,outputs)
testPerformance = perform(net,testTargets,outputs)

% 可以查看网络的各个参数
view(net)

% 根据画图的结果,决定是否满意
% Uncomment these lines to enable various plots.
figure, plotperform(tr)
figure, plottrainstate(tr)
figure, plotconfusion(targets,outputs)
figure, ploterrhist(errors)


%如果你对该次训练满意,可以保存训练好网络
save('training_net.mat','net','tr');



下面是用来分类的代码
clear all
close all
clc
load 'training_net.mat'


%% You can change the filename, sheet name, and range
%导入测试数据
new_input = xlsread('new_data.xls',['A2:F25']);
new_output = round(net(new_input'));
xlswrite('new_data.xls',new_output','result','G2');

%把二进制转换成对应的类别
new_output=new_output';
[r c]=size(new_output);
my_category=zeros(r,1);
for i=1:r
my_category(i,1)=4*new_output(i,1)+2*new_output(i,2)+1*new_output(i,3);
end
xlswrite('new_data.xls',my_category,'result','J2');
%% End of Change



相关主题