function [classificationTreePruned, treeFun] = trainTree(X,Y)
%TRAINTREE Train, cross-validate, and prune a binary classification tree.
%   [classificationTreePruned, treeFun] = TRAINTREE(X, Y) fits a decision
%   tree to the predictor matrix X (observations in rows) and the binary
%   label vector Y (classes 0 and 1), selects the optimal pruning level by
%   5-fold cross-validation, displays the pruned tree as a graph, and
%   prints the 5-fold validation accuracy of the pruned tree.
%
%   Outputs:
%     classificationTreePruned - pruned ClassificationTree model
%     treeFun                  - handle to treeScore (scoring helper
%                                defined elsewhere in this file; NOTE:
%                                not visible here -- confirm it exists)

% Scoring function handle returned to the caller.
treeFun = @treeScore;

% Name predictors 'column_2'..'column_(n+1)' so the trained model and the
% graph view match the historical naming (column_1 was the response when
% the data were packed as [Y, X]).
numPredictors = size(X, 2);
predictorNames = arrayfun(@(k) sprintf('column_%d', k), ...
    2:numPredictors + 1, 'UniformOutput', false);

% Fit an (unpruned) tree with Gini's diversity index as split criterion.
classificationTree = fitctree( ...
    X, ...
    Y, ...
    'PredictorNames', predictorNames, ...
    'SplitCriterion', 'gdi', ...
    'MaxNumSplits', 100, ...
    'Surrogate', 'off', ...
    'ClassNames', [0; 1]);

% Pick the pruning level minimizing 5-fold cross-validated loss.
% rng(1) makes the fold assignment (and hence bestLevel) reproducible.
rng(1);
% Guard against a stump: max(PruneList) - 1 would be negative when the
% tree has no splits, so clamp the deepest subtree level at 0.
maxPruneLevel = max(0, max(classificationTree.PruneList) - 1);
[~, ~, ~, bestLevel] = cvloss(classificationTree, ...
    'SubTrees', 0:maxPruneLevel, 'KFold', 5);
classificationTreePruned = prune(classificationTree, 'Level', bestLevel);

% Show the pruned tree structure.
figure(1);
view(classificationTreePruned, 'Mode', 'graph')

% Report the 5-fold validation accuracy of the pruned tree.
partitionedModel = crossval(classificationTreePruned, 'KFold', 5);
validationAccuracy = 1 - kfoldLoss(partitionedModel, 'LossFun', 'classiferror');
disp(['Training classification tree done with accuracy : ', num2str(validationAccuracy)]);