Home > SuperSegger > trainingConstants > trainLasso.m

trainLasso

PURPOSE ^

Returns the lasso coefficients: coefficients(1) is the constant term and coefficients(2:end) are the coefficients for the parameters in info.

SYNOPSIS ^

function [coefficients,lassoFun] = trainLasso (X,Y)

DESCRIPTION ^

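 The returned coefficients vector is [constant; feature coefficients]:
 coefficients(1) is the constant term and coefficients(2:end) are the
 coefficients for the parameters in info.

 Notes from an earlier interface (X and Y are now passed in directly):
   choice for x : 'segs', 'regs'
   dirname is the directory with already trained seg files
   dirname = '/Users/Stella/Dropbox/100XTrain_/'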

CROSS-REFERENCE INFORMATION ^

This function calls:

This function is called by:

SOURCE CODE ^

function [coefficients,lassoFun] = trainLasso (X,Y)
% TRAINLASSO : fits a lasso-regularized GLM to the features X and response Y.
%
% INPUT :
%       X : feature matrix, one row per observation and one column per
%           parameter (the parameters in info).
%       Y : response vector with one entry per row of X. Observations whose
%           Y is NaN are dropped before training. Presumably 0/1 labels
%           (predictions are rounded and compared with Y below) -- confirm.
% OUTPUT :
%       coefficients : column vector [constant; feature coefficients].
%           coefficients(1) is the constant (intercept) term and
%           coefficients(2:end) are the coefficients for the expanded
%           features: the columns of X followed by their pairwise products
%           (squares and cross terms), unless 'linear' below is true.
%       lassoFun : handle to lassoScore (defined in another file).
%
% Notes from an earlier interface (not parameters of this function):
%   choice for x : 'segs', 'regs'
%   dirname is the directory with already trained seg files, e.g.
%   dirname = '/Users/Stella/Dropbox/100XTrain_/'


% what i used to train it with regularized logistic regression on segments
% disp ('starting training on segments...');
% [Xsegs,Ysegs] = getInfoScores (segDirMod,'segs');
% A = lassoLogisticRegression (Xsegs,Ysegs,parallel);
%updateScores(segDirMod,'segs',A,calculateLassoScores);

lassoFun = @lassoScore;

parallel = false; % true: run the cross validation in parallel
linear = false;   % false: add quadratic (pairwise product) features

% Drop observations with a NaN response BEFORE building the design matrix,
% so that the design matrix and Y keep the same number of rows.
Y = Y(:);
bad_index = isnan(Y);
Y = Y(~bad_index);
X = X(~bad_index,:);

if linear
    alldataX = X;
else
    % construct squares and cross terms: column (j-1)*numD+i of
    % dataQuadraticX is X(:,i).*X(:,j). Keeping the lower triangle
    % (i >= j) keeps each product exactly once.
    numD = size(X,2);
    dataQuadraticX = repmat(X,1,numD).*repelem(X,1,numD);
    indicesToStay = find(tril(ones(numD,numD)));
    dataQuadraticX = dataQuadraticX(:,indicesToStay);
    alldataX = [X, dataQuadraticX];
end

options = statset('UseParallel',parallel);

% Fit a lasso-regularized GLM ('normal' distribution) over 10 lambda values
% with 5-fold cross validation.
% B : fitted coefficients with size (number of predictors x number of lambdas)
% NOTE(review): the fit uses the 'normal' distribution, but the predictions
% below go through the 'logit' link, and the original comment described a
% binomial regression -- confirm which distribution is intended.
tic;
[B,FitInfo] = lassoglm(alldataX,Y,'normal','NumLambda',10,'CV',5,'Options',options);
toc

% plots to look at different lambdas
%lassoPlot(B,FitInfo,'PlotType','CV');
%lassoPlot(B,FitInfo,'PlotType','Lambda','XScale','log');

% Index1SE: largest lambda whose cross-validated deviance is within one
% standard error of the minimum deviance (a sparser, more regularized fit).
indx = FitInfo.Index1SE;
B0 = B(:,indx);
disp (['non-zero coefficients : ', num2str(sum(B0 ~= 0))])

% coefficient vector with the constant term first
cnst = FitInfo.Intercept(indx);
coefficients = [cnst;B0];

% in-sample predictions and residuals
preds = glmval(coefficients,alldataX,'logit');
preds(isnan(preds)) = 0;

% results
histogram(Y - preds) % histogram of residuals
title('Residuals from lassoglm model')
% Count misclassified observations. A signed sum of (Y - round(preds))
% would let false positives and false negatives cancel each other out.
numErrors = sum(Y ~= round(preds));
disp (['error : ', num2str(numErrors)])
disp (['percentage error :', num2str(100*numErrors/numel(Y))])


end

Generated on Thu 19-Jan-2017 13:55:21 by m2html © 2005