
neuralNetTrain

PURPOSE ^

neuralNetTrain : trains a neural network to predict different classes.

SYNOPSIS ^

function [net,fun] = neuralNetTrain (X, Y, numCandidates)

DESCRIPTION ^

 neuralNetTrain : trains a neural network to predict different classes.
 Solves a pattern recognition problem with a neural network.
 Script generated by the Neural Pattern Recognition app.

 INPUT : 
   X : input data, one example per row.
   Y : target data (binary 0/1 column vector).
   numCandidates : number of candidate networks to train; the network
                   with the lowest misclassification rate is returned.

 OUTPUT : 
   net : the trained neural network (the best of the candidates).
   fun : function handle used to obtain scores from the neural network.

 University of Washington, 2016
 This file is part of SuperSegger.
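
 A minimal usage sketch (hypothetical data; per the source below, rows of
 X are examples, Y is a 0/1 column vector, and the trained patternnet
 scores column-wise input):

    X = rand(200, 4);                      % 200 made-up examples, 4 features each
    Y = double(rand(200, 1) > 0.5);        % binary 0/1 targets
    [net, fun] = neuralNetTrain(X, Y, 3);  % train 3 candidates, keep the best
    scores = net(X');                      % one example per column when scoring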

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SOURCE CODE ^

function [net,fun] = neuralNetTrain (X, Y, numCandidates)
% neuralNetTrain : trains a neural network to predict different classes.
% Solves a pattern recognition problem with a neural network.
% Script generated by the Neural Pattern Recognition app.
%
% INPUT :
%   X : input data, one example per row.
%   Y : target data (binary 0/1 column vector).
%   numCandidates : number of candidate networks to train; the network
%                   with the lowest misclassification rate is returned.
%
% OUTPUT :
%   net : the trained neural network (the best of the candidates).
%   fun : function handle used to obtain scores from the neural network.
%
% University of Washington, 2016
% This file is part of SuperSegger.

fun = @scoreNeuralNet;

if ~exist('numCandidates','var') || isempty(numCandidates)
    numCandidates = 1;
end

% One-hot targets: row 1 flags the class Y == 0, row 2 the class Y == 1.
% patternnet expects one example per column, hence the transposes.
t = [(Y == 0),Y]';
x = X';

% Create a pattern recognition network with two hidden layers
% of 10 and 15 neurons.
net = patternnet([10 15]);

% Choose a training function.
% For a list of all training functions type: help nntrain
% 'trainlm' is usually fastest.
% 'trainbr' takes longer but may be better for challenging problems.
% 'trainscg' uses less memory, suitable in low-memory situations.
net.trainFcn = 'trainbr';  % Bayesian regularization backpropagation.
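% If training is too slow or memory-bound on your data (an assumption about
% your setup, not the SuperSegger default), the low-memory alternative
% mentioned above can be swapped in:
%   net.trainFcn = 'trainscg';  % scaled conjugate gradient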

% Set up the division of data for training, validation, and testing.
net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 15/100;
net.divideParam.testRatio = 15/100;

% Allow many validation failures before stopping, so that training only
% halts on sustained (real) overfitting rather than a transient blip.
net.trainParam.max_fail = net.trainParam.epochs / 10;

% Count the examples in each class (row 1 of t flags Y == 0).
numTotal = numel(t(1,:));
numTrue  = sum(t(1,:)==0);   % examples with Y == 1
numFalse = sum(t(1,:)==1);   % examples with Y == 0

% Inverse-frequency error weights: every example is weighted by
% numTotal / (size of its class), so both classes contribute equally
% to the training error even when the data set is imbalanced.
errorWeights = {[(t(1,:)==1) * numTotal / numFalse; ...
                 (t(2,:)==1) * numTotal / numTrue]};
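% Worked example (illustrative numbers, not SuperSegger data): with 100
% samples of which 20 have Y == 0, the 20 samples get weight 100/20 = 5
% and the other 80 get weight 100/80 = 1.25, so each class accounts for
% half of the total weighted error.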

% Train numCandidates independent networks; the final argument passes the
% per-example error weights to train.
candidates = cell(1, numCandidates);
for i = 1:numCandidates
    candidates{i} = train(net, x, t, {}, {}, errorWeights);
end

% Test each candidate: the misclassification rate is the fraction of
% examples whose predicted class (column-wise argmax via vec2ind)
% disagrees with the target class.
tind = vec2ind(t);
percentError = zeros(1, numCandidates);
for i = 1:numCandidates
    net = candidates{i};
    y = net(x);
    yind = vec2ind(y);
    percentError(i) = sum(tind ~= yind)/numel(tind);
end

% Keep the candidate with the lowest misclassification rate.
[bestError, index] = min(percentError);
net = candidates{index};
disp(['Misclassification rate of the selected network : ',num2str(bestError)]);

end
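
To inspect the winning candidate further, the stock toolbox diagnostics
apply (a sketch; perform and plotconfusion ship with MATLAB's neural
network toolbox and are not part of the SuperSegger script):

    y = net(x);                        % scores for the training inputs
    performance = perform(net, t, y);  % cross-entropy performance
    plotconfusion(t, y);               % per-class confusion matrix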

Generated on Thu 19-Jan-2017 13:55:21 by m2html © 2005