function [coefficients,lassoFun] = trainLasso (X,Y)
%TRAINLASSO Fit a cross-validated lasso GLM on linear + quadratic features.
%
%   [coefficients,lassoFun] = trainLasso(X,Y)
%
%   Inputs:
%     X - predictor matrix, one observation per row.
%     Y - response vector with one entry per row of X. Observations whose
%         response is NaN are discarded before fitting.
%
%   Outputs:
%     coefficients - column vector [intercept; B], where B is the lasso
%                    coefficient vector at FitInfo.Index1SE. The columns B
%                    refers to are [X, quadratic terms], the quadratic
%                    terms being the unique pairwise products X(:,i).*X(:,j).
%     lassoFun     - handle to lassoScore (defined outside this function).
%
%   Side effects: prints the fitting time and fit diagnostics, and draws a
%   histogram of the training residuals in the current figure.

lassoFun = @lassoScore;

% Local switches: enable parallel cross-validation / skip quadratic terms.
parallel = false;
linear = false;

% Drop observations with a NaN response BEFORE building the design matrix,
% so that the design matrix and Y always have the same number of rows.
validRows = ~isnan(Y);
Y = Y(validRows);
X = X(validRows,:);

if linear
    alldataX = X;
else
    numD = size(X,2);
    % All numD^2 pairwise column products ...
    dataQuadraticX = repmat(X,1,numD).*repelem(X,1,numD);
    % ... of which only one copy of each unordered pair (i,j) is kept:
    % the lower triangle (including the diagonal) of the numD-by-numD grid.
    indicesToStay = find(tril(ones(numD,numD)));
    dataQuadraticX = dataQuadraticX(:,indicesToStay);
    alldataX = [X, dataQuadraticX];
end

options = statset('UseParallel',parallel);

tic;
[B,FitInfo] = lassoglm(alldataX,Y,'normal','NumLambda',10,'CV',5,'Options',options);
toc

% Select the model at FitInfo.Index1SE (the one-standard-error choice of
% lambda reported by the cross-validated fit).
indx = FitInfo.Index1SE;
B0 = B(:,indx);
numNonzero = sum(B0 ~= 0);   % do not shadow the builtin NONZEROS
disp (['nonzeros : ', num2str(numNonzero)])

cnst = FitInfo.Intercept(indx);
coefficients = [cnst;B0];

% The model was fitted with the 'normal' distribution, whose default link
% is the identity, so predictions must use the same link.
preds = glmval(coefficients,alldataX,'identity');
preds(isnan(preds)) = 0;

histogram(Y - preds)
title('Residuals from lassoglm model')
% NOTE: these are net signed sums, so over- and under-predictions cancel.
disp (['error : ', num2str(sum(Y - round(preds)))])
disp (['percentage error :', num2str(abs(sum(Y - round(preds))/numel(Y)))])

end