
trainTree

PURPOSE

trainTree : trains a classification tree.

SYNOPSIS

function [classificationTreePruned, treeFun] = trainTree(X,Y)

DESCRIPTION

 trainTree : trains a classification tree.
  Returns a trained, pruned classifier and displays its cross-validation
  accuracy. This code recreates the classification model trained in the
  Classification Learner app.

   Input:
       X : matrix of predictor variables (21 columns, one row per observation)
       Y : response variable (class labels, 0 or 1)

   Output:
       classificationTreePruned : the trained classification tree, pruned to
        the level selected by 5-fold cross-validation.

       treeFun : handle to the function used for obtaining scores with this
        classifier (treeScore).

 Auto-generated by MATLAB on 24-May-2016 23:47:52
 University of Washington, 2016
 This file is part of SuperSegger.
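
 A minimal usage sketch (X and Y below are synthetic placeholders, not
 SuperSegger data; passing a plain numeric matrix to predict assumes all
 predictors are numeric, as they are here):

     X = rand(200,21);                        % 200 observations, 21 features
     Y = double(rand(200,1) > 0.5);           % binary labels, 0 or 1
     [treePruned, treeFun] = trainTree(X,Y);  % trains, prunes, prints accuracy
     labels = predict(treePruned, X);         % classify 21-column feature rows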

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SOURCE CODE

function [classificationTreePruned, treeFun] = trainTree(X,Y)
% trainTree : trains a classification tree.
%  Returns a trained, pruned classifier and displays its cross-validation
%  accuracy. This code recreates the classification model trained in the
%  Classification Learner app.
%
%   Input:
%       X : matrix of predictor variables (21 columns, one row per observation)
%       Y : response variable (class labels, 0 or 1)
%
%   Output:
%       classificationTreePruned : the trained classification tree, pruned
%        to the level selected by 5-fold cross-validation.
%
%       treeFun : handle to the function used for obtaining scores with
%        this classifier (treeScore).
%
% Auto-generated by MATLAB on 24-May-2016 23:47:52
% University of Washington, 2016
% This file is part of SuperSegger.

treeFun = @treeScore;
trainingData = [Y,X];             % column 1 = response, columns 2:22 = predictors
num_var = size(trainingData,2);   % 22 columns in total

area = X(:,11);                   % unused in this function

% Convert input to table
inputTable = table(trainingData);
inputTable.Properties.VariableNames = {'column'};

% Build per-column variable names: column_1 ... column_22
names = {};
for i = 1 : num_var
    names{end+1} = [inputTable.Properties.VariableNames{1},'_',num2str(i)];
end

% Split the matrix in the input table into one variable per column
inputTable = [inputTable(:,setdiff(inputTable.Properties.VariableNames, {'column'})), array2table(table2array(inputTable(:,{'column'})), 'VariableNames', names)];

% Extract predictors and response
% This code processes the data into the right shape for training the
% classifier.
predictorNames = names(2:end);   % column_2 ... column_22, the 21 predictors
predictors = inputTable(:, predictorNames);
response = inputTable.column_1;
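% (Sketch, not in the original code: the convert/split/extract steps above
% follow the Classification Learner export pattern and are equivalent to
%    inputTable = array2table(trainingData, 'VariableNames', names);
% followed by the same predictors/response extraction.)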

% Train a classifier
% This code specifies all the classifier options and trains the classifier.
classificationTree = fitctree(...
    predictors, ...
    response, ...
    'SplitCriterion', 'gdi', ...
    'MaxNumSplits', 100, ...
    'Surrogate', 'off', ...
    'ClassNames', [0; 1]);
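% 'gdi' is Gini's diversity index, fitctree's default split criterion;
% 'MaxNumSplits' caps the tree at 100 branch-node splits to limit
% complexity before pruning.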
0059 
trainedClassifier.ClassificationTree = classificationTree;
convertMatrixToTableFcn = @(x) table(x, 'VariableNames', {'column'});
splitMatricesInTableFcn = @(t) [t(:,setdiff(t.Properties.VariableNames, {'column'})), array2table(table2array(t(:,{'column'})), 'VariableNames', predictorNames)];
extractPredictorsFromTableFcn = @(t) t(:, predictorNames);
predictorExtractionFcn = @(x) extractPredictorsFromTableFcn(splitMatricesInTableFcn(convertMatrixToTableFcn(x)));
treePredictFcn = @(x) predict(classificationTree, x);
trainedClassifier.predictFcn = @(x) treePredictFcn(predictorExtractionFcn(x));
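% Example use (Xnew is a hypothetical 21-column matrix of new observations;
% note trainedClassifier is local to this function and is not returned):
%   yhat = trainedClassifier.predictFcn(Xnew);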
0067 
% Prune the tree to the best level found by cross-validation
rng(1); % For reproducibility
m = max(classificationTree.PruneList) - 1;
[~,~,~,bestLevel] = cvloss(classificationTree,'SubTrees',0:m,'KFold',5);
classificationTreePruned = prune(classificationTree,'Level',bestLevel);
figure(1);
view(classificationTreePruned,'Mode','graph')
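% cvloss scores every subtree (pruning levels 0:m) by 5-fold
% cross-validation; with the default 'TreeSize','se', bestLevel is the
% smallest tree within one standard error of the minimal loss.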
0076 
% Perform 5-fold cross-validation on the pruned tree
partitionedModel = crossval(classificationTreePruned, 'KFold', 5);

% Compute validation accuracy
validationAccuracy = 1 - kfoldLoss(partitionedModel, 'LossFun', 'ClassifError');

% Compute validation predictions and scores (for inspection; not returned)
[validationPredictions, validationScores] = kfoldPredict(partitionedModel);
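% validationScores has one column per class (here [0 1], in ClassNames
% order); kfoldPredict's labels correspond to the class with the larger
% score for each observation.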

disp(['Training classification tree done with accuracy : ',num2str(validationAccuracy)]);

Generated on Thu 19-Jan-2017 13:55:21 by m2html © 2005