function [B,stats] = lassoglm(x,y,distr,varargin)
%LASSOGLM Lasso / elastic-net regularized fit of a generalized linear model.
%   [B,STATS] = LASSOGLM(X,Y,DISTR,'Name',Value,...) fits one penalized
%   GLM per value of the penalty parameter lambda and returns the
%   coefficients in B and fit information in STATS.
%
%   Inputs
%     X      real 2-D predictor matrix with at least two rows.
%     Y      response. A categorical Y with at most two levels is recoded
%            to 0/1 before fitting.
%     DISTR  distribution name; 'normal' is used when omitted or empty.
%
%   Name/value options (defaults in parentheses)
%     'link' ('canonical'), 'offset' ([]), 'weights' ([]),
%     'alpha' (1), 'numlambda' (100), 'lambdaratio' (1e-4),
%     'lambda' ([]), 'dfmax' ([]), 'standardize' (true),
%     'reltol' (1e-4), 'cv' ('resubstitution'), 'mcreps' (1),
%     'predictornames' ({}), 'options' ([]).
%
%   Outputs
%     B      coefficient matrix with one column per fitted lambda. The
%            columns are returned in the reverse of the internal fitting
%            order (fitting runs from large to small lambda).
%     STATS  struct with fields Intercept, Lambda, Alpha, DF, Deviance and
%            PredictorNames. When cross-validation is requested it also
%            carries SE, LambdaMinDeviance, Lambda1SE, IndexMinDeviance
%            and Index1SE.

if nargin < 2
    error(message('stats:lassoGlm:TooFewInputs'));
end

if nargin < 3 || isempty(distr), distr = 'normal'; end

% First pass: peel off the GLM-specific options; whatever is not
% recognised is handed back in varargin for the second pass.
paramNames = { 'link' 'offset' 'weights'};
paramDflts = {'canonical' [] []};
[link,offset,pwts,~,varargin] = ...
    internal.stats.parseArgs(paramNames, paramDflts, varargin{:});

% Second pass: lasso-specific options.
LRdefault = 1e-4;
pnames = { 'alpha' 'numlambda' 'lambdaratio' 'lambda' ...
    'dfmax' 'standardize' 'reltol' 'cv' 'mcreps' ...
    'predictornames' 'options' };
dflts = { 1 100 LRdefault [] ...
    [] true 1e-4 'resubstitution' 1 ...
    {} []};
[alpha, nLambda, lambdaRatio, lambda, ...
    dfmax, standardize, reltol, cvp, mcreps, predictorNames, parallelOptions] ...
    = internal.stats.parseArgs(pnames, dflts, varargin{:});

% Remember whether the caller fixed the lambda grid; lassoFit only stops
% early on a small deviance when the grid was generated here.
userSuppliedLambda = ~isempty(lambda);

if ~ismatrix(x) || length(size(x)) ~= 2 || ~isreal(x)
    error(message('stats:lassoGlm:XnotaReal2DMatrix'));
end

if isempty(x) || size(x,1) < 2
    error(message('stats:lassoGlm:TooFewObservations'));
end

% Recode a two-level categorical response to 0/1.
if isa(y,'categorical')
    [y, classname] = grp2idx(y);
    nc = length(classname);
    if nc > 2
        error(message('stats:glmfit:TwoLevelCategory'));
    end
    y(y==1) = 0;
    y(y==2) = 1;
end

% Number of predictors as supplied by the caller.
P = size(x,2);

% Silence the glmfit ill-conditioning warning for the duration of the call;
% the previous warning state is restored when this function exits.
wsIllConditioned2 = warning('off','stats:glmfit:IllConditioned');
cleanupIllConditioned2 = onCleanup(@() warning(wsIllConditioned2));

[X, Y, offset, pwts, dataClass, nTrials, binomialTwoColumn] = ...
    glmProcessData(x, y, distr, 'off', offset, pwts);

[~,sqrtvarFun,devFun,linkFun,dlinkFun,ilinkFun,link,mu,eta,muLims,isCanonical,dlinkFunCanonical] = ...
    glmProcessDistrAndLink(Y,distr,link,'off',nTrials,dataClass);

[X,Y,pwts,nLambda,lambda,dfmax,cvp,mcreps,predictorNames,ever_active] = ...
    processLassoParameters(X,Y,pwts,alpha,nLambda,lambdaRatio,lambda,dfmax, ...
    standardize,reltol,cvp,mcreps,predictorNames);

[lambdaMax, nullDev, nullIntercept] = computeLambdaMax(X, Y, pwts, alpha, standardize, ...
    distr, link, dlinkFun, offset, isCanonical, dlinkFunCanonical, devFun);

% Generate the lambda grid only when the caller did not supply one.
if isempty(lambda)
    lambda = computeLambdaSequence(lambdaMax, nLambda, lambdaRatio, LRdefault);
end

% Use wider mean limits for binomial fits than the ones returned by
% glmProcessDistrAndLink.
if strcmp(distr,'binomial')
    muLims = [1.0e-5 1.0-1.0e-5];
end

% Total weight of the data; lambda is scaled by it before fitting and
% scaled back before being returned.
if isempty(pwts) && isscalar(nTrials)
    totalWeight = size(X,1);
elseif ~isempty(pwts) && isscalar(nTrials)
    totalWeight = sum(pwts);
elseif isempty(pwts) && ~isscalar(nTrials)
    totalWeight = sum(nTrials);
else
    % pwts is a row vector at this point (processLassoParameters) while
    % nTrials is a column vector (glmProcessData). Force a common
    % orientation so the product is elementwise and the sum is a scalar,
    % instead of a size error or an N-by-N outer product.
    totalWeight = sum(pwts(:) .* nTrials(:));
end

lambda = lambda * totalWeight;

penalizedFitPartition = @(x,y,offset,pwts,n,wlsfit,b,active,mu,eta,sqrtvarFun) ...
    glmIRLSwrapper(x,y,distr,offset,pwts,dataClass,n, ...
    sqrtvarFun,linkFun,dlinkFun,ilinkFun,devFun,b,active,mu,muLims,wlsfit,nullDev,reltol);

% The IRLS code receives the weights as a column, hence pwts'.
penalizedFit = @(x,y,wlsfit,b,active,mu,eta) ...
    penalizedFitPartition(x,y,offset,pwts',nTrials,wlsfit,b,active,mu,eta,sqrtvarFun);

[B,Intercept,lambda,deviance] = ...
    lassoFit(X,Y,pwts,lambda,alpha,dfmax,standardize,reltol,lambdaMax*totalWeight,ever_active, ...
    penalizedFit,mu,eta,dataClass,userSuppliedLambda,nullDev,nullIntercept);

% Number of nonzero coefficients for each lambda.
df = sum(B~=0,1);

stats = struct();
stats.Intercept = [];
stats.Lambda = [];
stats.Alpha = alpha;
stats.DF = [];
stats.Deviance = [];
stats.PredictorNames = predictorNames;

if ~isequal(cvp,'resubstitution')
    % Cross-validated deviance. Weights and offset travel with the
    % predictors as the first two columns so crossval partitions them
    % together; an all-NaN column marks "not supplied".
    cvfun = @(Xtrain,Ytrain,Xtest,Ytest) lassoFitAndPredict( ...
        Xtrain,Ytrain,Xtest,Ytest, ...
        lambda,alpha,P,standardize,reltol,ever_active, ...
        penalizedFitPartition,distr,link,linkFun,dlinkFun,sqrtvarFun, ...
        isCanonical, dlinkFunCanonical,devFun,dataClass);
    weights = pwts;
    if isempty(weights)
        weights = nan(size(X,1),1);
    end
    if isempty(offset) || isequal(offset,0)
        offset = nan(size(X,1),1);
    end
    if binomialTwoColumn
        response = [nTrials Y];
    else
        response = Y;
    end
    cvDeviance = crossval(cvfun,[weights(:) offset(:) X],response, ...
        'Partition',cvp,'Mcreps',mcreps,'Options',parallelOptions);

    % Rescale each test-set deviance to the size of the full data set.
    cvDeviance = bsxfun(@times,cvDeviance,repmat((size(X,1) ./ cvp.TestSize)', mcreps, 1));
    deviance = mean(cvDeviance);
    se = std(cvDeviance) / sqrt(size(cvDeviance,1));
    minDeviance = min(deviance);
    minIx = find(deviance == minDeviance,1);
    lambdaMin = lambda(minIx);
    % First lambda (in fitting order) whose deviance is within one
    % standard error of the minimum.
    minplus1 = deviance(minIx) + se(minIx);
    seIx = find((deviance(1:minIx) <= minplus1),1,'first');
    if isempty(seIx)
        lambdaSE = [];
    else
        lambdaSE = lambda(seIx);
    end

    stats.SE = se;
    stats.LambdaMinDeviance = lambdaMin;
    stats.Lambda1SE = lambdaSE;
    stats.IndexMinDeviance = minIx;
    stats.Index1SE = seIx;

    % Undo the totalWeight scaling applied to lambda above.
    stats.LambdaMinDeviance = stats.LambdaMinDeviance / totalWeight;
    stats.Lambda1SE = stats.Lambda1SE / totalWeight;
end

% Reverse every per-lambda result so the output order is the reverse of
% the fitting order, and remap the cross-validation indices accordingly.
nLambda = length(lambda);
reverseIndices = nLambda:-1:1;
lambda = lambda(reverseIndices);
lambda = reshape(lambda,1,nLambda);
B = B(:,reverseIndices);
Intercept = Intercept(reverseIndices);
df = df(reverseIndices);
deviance = deviance(reverseIndices);
if ~isequal(cvp,'resubstitution')
    stats.SE = stats.SE(reverseIndices);
    stats.IndexMinDeviance = nLambda - stats.IndexMinDeviance + 1;
    stats.Index1SE = nLambda - stats.Index1SE + 1;
end

stats.Intercept = Intercept;
stats.Lambda = lambda;
stats.DF = df;
stats.Deviance = deviance;

% Undo the totalWeight scaling applied to lambda above.
stats.Lambda = stats.Lambda / totalWeight;

end
0405
0406
0407
0408
0409
0410
0411
0412
0413
function mu = startingVals(distr,y,N)
%STARTINGVALS Initial estimate of the mean for the iterative fit.
%   MU = STARTINGVALS(DISTR,Y,N) returns a starting mean derived from the
%   response Y. N is only used for the 'binomial' distribution.

if strcmp(distr,'poisson')
    % shift away from zero
    mu = y + 0.25;
elseif strcmp(distr,'binomial')
    % pull the observed proportions away from 0 and 1
    mu = (N .* y + 0.5) ./ (N + 1);
elseif any(strcmp(distr,{'gamma' 'inverse gaussian'}))
    % keep the mean strictly positive
    mu = max(y, eps(class(y)));
else
    mu = y;
end
end
0427
0428
0429
0430
0431
function diagnoseSeparation(eta,y,N)
%DIAGNOSESEPARATION Warn when the linear predictor separates the responses.
%   DIAGNOSESEPARATION(ETA,Y,N) sorts the observations by the linear
%   predictor ETA and looks for a cut point below which all responses
%   equal the first sorted response and above which all equal the last.
%   When such a cut exists a 'stats:lassoGlm:PerfectSeparation' warning is
%   issued; otherwise the function returns silently. N holds the trial
%   counts (scalar, or one per observation).

[xs,order] = sort(eta);
if ~isscalar(N)
    N = N(order);
end
ps = y(order);

% Nothing to report when every response is identical or the linear
% predictor does not vary.
if all(ps==ps(1)) || xs(1)==xs(end)
    return
end

% A strictly interior proportion at an end cannot be part of a
% separated region.
interiorFirst = 0<ps(1) && ps(1)<1;
interiorLast = 0<ps(end) && ps(end)<1;
if ps(1)==ps(end) || (interiorFirst && interiorLast)
    return
end

% Tolerance used to treat nearly equal linear predictor values as tied.
tol = 100*max(eps(xs(1)),eps(xs(end)));
nObs = length(ps);

% lo = number of leading observations that share the first response and
% lie clearly below the first observation with a different response.
if interiorFirst
    lo = 0;
else
    lo = find(ps~=ps(1),1,'first')-1;
    lo = sum(xs(1:lo) < xs(lo+1)-tol);
end

% hi = index of the first trailing observation that shares the last
% response and lies clearly above the last observation with a different one.
if interiorLast
    hi = nObs+1;
else
    hi = find(ps~=ps(end),1,'last')+1;
    hi = (nObs+1) - sum(xs(hi:end) > xs(hi-1)+tol);
end

% Several observations in between that span more than the tolerance:
% the data are not separated.
if lo+1<hi-1 && xs(hi-1)-xs(lo+1)>tol
    return
end

if lo+1==hi
    % Clean split: report the midpoint between the two groups.
    xmid = xs(lo) + 0.5*(xs(hi)-xs(lo));
else
    % A tied group sits at the split; report its pooled proportion.
    xmid = xs(lo+1);
    mid = lo+1:hi-1;
    if isscalar(N)
        pmid = mean(ps(mid));
    else
        pmid = sum(ps(mid).*N(mid)) / sum(N(mid));
    end
end

% Assemble the text describing each region that is present.
explanation = '';
if lo>=1
    explanation = sprintf('\n XB<%g: P=%g',xmid,ps(1));
end

if lo+1<hi
    explanation = sprintf('%s\n XB=%g: P=%g',explanation,xmid,pmid);
end

if hi<=nObs
    explanation = sprintf('%s\n XB>%g: P=%g',explanation,xmid,ps(end));
end

warning(message('stats:lassoGlm:PerfectSeparation', explanation));
end
0521
0522
0523
0524
0525
function [x, y, offset, pwts, dataClass, N, binomialTwoColumn] = ...
    glmProcessData(x, y, distr, const, offset, pwts)
%GLMPROCESSDATA Validate and clean the data for a GLM fit.
%   Converts a two-column binomial response to proportions plus trial
%   counts, removes unusable observations, optionally prepends a constant
%   column, and casts X and Y to a common floating-point class.
%
%   Inputs
%     x, y     predictors and response.
%     distr    distribution name; only 'binomial' gets special handling here.
%     const    'on' to prepend a column of ones to x.
%     offset   offset vector, or [] for none.
%     pwts     observation weights, or [] for none.
%
%   Outputs
%     x, y               cleaned data, cast to dataClass.
%     offset             cleaned offset, or 0 when none was supplied.
%     pwts               weights restricted to the retained rows.
%     dataClass          superiorfloat(x,y).
%     N                  binomial trial counts, or 1 when not applicable.
%     binomialTwoColumn  true when y was given as two columns.
%
%   Errors are raised for a malformed binomial response, mismatched input
%   sizes, invalid weights, or fewer than two usable observations.

N = [];
binomialTwoColumn = false;

if strcmp(distr,'binomial')
    if size(y,2) == 1
        % Single column: values must lie in [0,1].
        if any(y < 0 | y > 1)
            error(message('stats:lassoGlm:BadDataBinomialFormat'));
        end
    elseif size(y,2) == 2
        % Two columns: first column divided by the second gives the
        % proportion. Rows with a zero second column become NaN so that
        % they are dropped below.
        binomialTwoColumn = true;
        y(y(:,2)==0,2) = NaN;
        N = y(:,2);
        y = y(:,1) ./ N;
        if any(y < 0 | y > 1)
            error(message('stats:lassoGlm:BadDataBinomialRange'));
        end
    else
        error(message('stats:lassoGlm:MatrixOrBernoulliRequired'));
    end
end

% Remove rows with missing values and check that the inputs agree in size.
[anybad,~,y,x,offset,pwts,N] = dfswitchyard('statremovenan',y,x,offset,pwts,N);
if anybad > 0
    switch anybad
        case 2
            error(message('stats:lassoGlm:InputSizeMismatchX'))
        case 3
            error(message('stats:lassoGlm:InputSizeMismatchOffset'))
        case 4
            error(message('stats:lassoGlm:InputSizeMismatchPWTS'))
    end
end

% Rows that are finite in x and y ...
okrows = all(isfinite(x),2) & all(isfinite(y),2);
% ... and in the offset. The offset is tested row by row, like x and y,
% so that one non-finite entry drops only its own observation instead of
% rejecting every row at once.
if ~isempty(offset)
    okrows = okrows & isfinite(offset(:));
end

if ~isempty(pwts)
    % Weights must be a real, finite, nonnegative vector with one entry
    % per observation.
    if ~isvector(pwts) || ~isreal(pwts) || size(x,1) ~= length(pwts) || ...
            ~all(isfinite(pwts)) || any(pwts<0)
        error(message('stats:lassoGlm:InvalidObservationWeights'));
    end
    % Zero-weight observations are dropped.
    okrows = okrows & pwts(:)>0;
    pwts = pwts(okrows);
end

if sum(okrows)<2
    error(message('stats:lassoGlm:TooFewObservationsAfterNaNs'));
end

% Keep only the usable rows in every per-observation input.
x = x(okrows,:);
y = y(okrows);
if ~isempty(N) && ~isscalar(N)
    N = N(okrows);
end
if ~isempty(offset)
    offset = offset(okrows);
end

if isequal(const,'on')
    x = [ones(size(x,1),1) x];
end
dataClass = superiorfloat(x,y);
x = cast(x,dataClass);
y = cast(y,dataClass);

% Scalar placeholders for "no offset" and "no trial counts".
if isempty(offset), offset = 0; end
if isempty(N), N = 1; end

end
0604
0605
0606
0607
0608
function [X,Y,weights,nLambda,lambda,dfmax,cvp,mcreps,predictorNames,ever_active] = ...
    processLassoParameters(X,Y,weights, alpha, nLambda, lambdaRatio, lambda, dfmax, ...
    standardize, reltol, cvp, mcreps, predictorNames)
%PROCESSLASSOPARAMETERS Validate and normalize the lasso options.
%   Checks each option in turn and raises a 'stats:lassoGlm:*' error for
%   the first invalid one. On success it returns
%     weights         as a row vector (when non-empty),
%     nLambda         rounded down to an integer (when lambda is empty),
%     lambda          as a column sorted in descending order (when given),
%     dfmax           defaulted to, or capped at, the number of predictors,
%     cvp             either 'resubstitution' or a cvpartition object,
%     mcreps          truncated to an integer,
%     predictorNames  as a row cell array (when non-empty),
%     ever_active     a logical row flagging the non-constant columns of X.
%   X and Y are passed through unchanged.

if ~isempty(weights)
    % downstream code expects a row vector
    weights = weights(:)';
end

nObs = size(X,1);
nPred = size(X,2);

% Constant predictors can never enter the model.
ever_active = (range(X) ~= 0);

% --- alpha: real finite scalar in (0,1] ---
alphaOK = isscalar(alpha) && isreal(alpha) && isfinite(alpha) && ...
    alpha > 0 && alpha <= 1;
if ~alphaOK
    error(message('stats:lassoGlm:InvalidAlpha'))
end

% --- standardize: scalar logical, or numeric 0/1 ---
standardizeOK = isscalar(standardize) && ...
    (islogical(standardize) || standardize==0 || standardize==1);
if ~standardizeOK
    error(message('stats:lassoGlm:InvalidStandardize'))
end

% --- lambda, or the parameters used to generate it ---
if isempty(lambda)
    % No explicit grid: validate the grid size and the ratio instead.
    if ~isreal(nLambda) || ~isfinite(nLambda) || nLambda < 1
        error(message('stats:lassoGlm:InvalidNumLambda'));
    end
    nLambda = floor(nLambda);

    if ~isreal(lambdaRatio) || lambdaRatio <0 || lambdaRatio >= 1
        error(message('stats:lassoGlm:InvalidLambdaRatio'));
    end
else
    % Explicit grid: real, nonnegative, stored largest first.
    if ~isreal(lambda) || any(lambda < 0)
        error(message('stats:lassoGlm:NegativeLambda'));
    end
    lambda = sort(lambda(:),1,'descend');
end

% --- reltol: real finite scalar in (0,1) ---
reltolOK = isscalar(reltol) && isreal(reltol) && isfinite(reltol) && ...
    reltol > 0 && reltol < 1;
if ~reltolOK
    error(message('stats:lassoGlm:InvalidRelTol'));
end

% --- dfmax: defaults to the number of predictors ---
if isempty(dfmax)
    dfmax = nPred;
else
    if ~isscalar(dfmax)
        error(message('stats:lassoGlm:DFmaxBadType'));
    end
    try
        dfmax = uint32(dfmax);
    catch
        % Not convertible to an integer: report it as a type problem.
        mm = message('stats:lassoGlm:DFmaxBadType');
        throwAsCaller(MException(mm.Identifier,'%s',getString(mm)));
    end
    if dfmax < 1
        error(message('stats:lassoGlm:DFmaxNotAnIndex'));
    end
    dfmax = min(dfmax,nPred);
end

% --- mcreps: real finite scalar, at least 1 ---
mcrepsOK = isscalar(mcreps) && isreal(mcreps) && isfinite(mcreps) && mcreps >= 1;
if ~mcrepsOK
    error(message('stats:lassoGlm:MCRepsBadType'));
end
mcreps = fix(mcreps);

% --- cv: positive integer K, a cvpartition, or 'resubstitution' ---
if isnumeric(cvp) && isscalar(cvp) && (cvp==round(cvp)) && (0<cvp)
    % K-fold with K no larger than the number of observations.
    if (cvp>nObs)
        error(message('stats:lassoGlm:InvalidCVforX'));
    end
    cvp = cvpartition(nObs,'Kfold',cvp);
elseif isa(cvp,'cvpartition')
    switch lower(cvp.Type)
        case 'resubstitution'
            cvp = 'resubstitution';
        case 'leaveout'
            error(message('stats:lassoGlm:InvalidCVtype'));
        case 'holdout'
            % a holdout partition needs more than one Monte-Carlo repetition
            if mcreps<=1
                error(message('stats:lassoGlm:InvalidMCReps'));
            end
    end
elseif strncmpi(cvp,'resubstitution',length(cvp))
    % accept abbreviations of 'resubstitution'
    cvp = 'resubstitution';
else
    error(message('stats:lassoGlm:InvalidCVtype'));
end

% Repetitions make no sense without a partition.
if strcmp(cvp,'resubstitution') && mcreps ~= 1
    error(message('stats:lassoGlm:InvalidMCReps'));
end

% The partition must match the data and leave at least two training rows.
if isa(cvp,'cvpartition')
    if (cvp.N ~= nObs) || (min(cvp.TrainSize) < 2)
        error(message('stats:lassoGlm:TooFewObservationsForCrossval'));
    end
end

% --- predictor names: cellstr with one name per column of X ---
if ~isempty(predictorNames)
    if ~iscellstr(predictorNames) || length(predictorNames(:)) ~= size(X,2)
        error(message('stats:lassoGlm:InvalidPredictorNames'));
    end
    predictorNames = predictorNames(:)';
end

end
0760
0761
0762
0763
0764
function [estdisp,sqrtvarFun,devFun,linkFun,dlinkFun,ilinkFun,link,mu,eta,muLims,...
    isCanonical,dlinkFunCanonical] = ...
    glmProcessDistrAndLink(y,distr,link,estdisp,N,dataClass)
%GLMPROCESSDISTRANDLINK Set up distribution- and link-specific functions.
%   Inputs
%     y          response (proportions for 'binomial').
%     distr      'normal', 'binomial', 'poisson', 'gamma' or
%                'inverse gaussian'; anything else raises
%                'stats:lassoGlm:BadDistribution'.
%     link       link specification, or 'canonical' to use the canonical
%                link of DISTR.
%     estdisp    passed through, except that it is forced to 'on' for
%                'normal', 'gamma' and 'inverse gaussian'.
%     N          binomial trial counts.
%     dataClass  floating-point class used for the link functions and limits.
%
%   Outputs
%     sqrtvarFun         square root of the variance function of the mean.
%     devFun             per-observation deviance, devFun(mu,y).
%     linkFun, dlinkFun, ilinkFun
%                        link, its derivative and its inverse.
%     link               the resolved link ('canonical' replaced).
%     mu, eta            starting mean and linear predictor.
%     muLims             limits on the mean used during fitting
%                        ([] for 'normal').
%     isCanonical        true when LINK is the canonical link.
%     dlinkFunCanonical  derivative of the canonical link.
%
%   The response is also range-checked for 'poisson' (nonnegative) and
%   'gamma' / 'inverse gaussian' (strictly positive).

% One switch handles the canonical link, the data checks and the
% distribution functions, so that an unknown distribution is always
% reported as BadDistribution (rather than failing later on an undefined
% canonical link).
switch distr
    case 'normal'
        canonicalLink = 'identity';
        sqrtvarFun = @(mu) ones(size(mu));
        devFun = @(mu,y) (y - mu).^2;
        estdisp = 'on';
    case 'binomial'
        canonicalLink = 'logit';
        sqrtN = sqrt(N);
        sqrtvarFun = @(mu) sqrt(mu).*sqrt(1-mu) ./ sqrtN;
        % (y==0) and (y==1) guard the logarithms at the boundaries.
        devFun = @(mu,y) 2*N.*(y.*log((y+(y==0))./mu) + (1-y).*log((1-y+(y==1))./(1-mu)));
    case 'poisson'
        canonicalLink = 'log';
        if any(y < 0)
            error(message('stats:lassoGlm:BadDataPoisson'));
        end
        sqrtvarFun = @(mu) sqrt(mu);
        devFun = @(mu,y) 2*(y .* (log((y+(y==0)) ./ mu)) - (y - mu));
    case 'gamma'
        canonicalLink = 'reciprocal';
        if any(y <= 0)
            error(message('stats:lassoGlm:BadDataGamma'));
        end
        sqrtvarFun = @(mu) mu;
        devFun = @(mu,y) 2*(-log(y ./ mu) + (y - mu) ./ mu);
        estdisp = 'on';
    case 'inverse gaussian'
        canonicalLink = -2;
        if any(y <= 0)
            error(message('stats:lassoGlm:BadDataInvGamma'));
        end
        sqrtvarFun = @(mu) mu.^(3/2);
        devFun = @(mu,y) ((y - mu)./mu).^2 ./ y;
        estdisp = 'on';
    otherwise
        error(message('stats:lassoGlm:BadDistribution'));
end

if isequal(link, 'canonical'), link = canonicalLink; end

% Link function, its derivative and its inverse.
[linkFun,dlinkFun,ilinkFun] = dfswitchyard('stattestlink',link,dataClass);

% Starting values for the mean and the linear predictor.
mu = startingVals(distr,y,N);
eta = linkFun(mu);

% Limits on the mean used by the fitting loop.
switch distr
    case 'binomial'
        % keep mu strictly inside (0,1)
        muLims = [eps(dataClass) 1-eps(dataClass)];
    case {'poisson' 'gamma' 'inverse gaussian'}
        % keep mu strictly positive
        muLims = realmin(dataClass).^.25;
    otherwise
        muLims = [];
end

isCanonical = isequal(link, canonicalLink);
[~, dlinkFunCanonical] = dfswitchyard('stattestlink', canonicalLink, dataClass);

end
0847
0848
0849
0850
0851
function [b,mu,eta,varargout] = glmIRLS(x,y,distr,offset,pwts,dataClass,N, ...
    sqrtvarFun,linkFun,dlinkFun,ilinkFun,b,active,mu,muLims, ...
    wlsfit,nullDev,devFun,reltol)
%GLMIRLS Iteratively reweighted least squares with a pluggable WLS solver.
%   Each pass builds the adjusted response and weights from the current
%   mean MU, calls WLSFIT(z-offset, x, weights, b, active) for new
%   coefficients, and recomputes ETA and MU. The loop ends when the
%   coefficients stop changing (relative to max(1e-6, 2*RELTOL)), when the
%   summed deviance drops below 1e-3*NULLDEV, or after the iteration
%   limit is exceeded.
%
%   Returns the coefficients B, the final MU and ETA, and the active set
%   reported by WLSFIT as the first optional output.

% The warnings raised in here are switched off for the duration of the
% call; the onCleanup objects restore the previous states on exit.
savedIterLimit = warning('off','stats:lassoGlm:IterationLimit');
savedSeparation = warning('off','stats:lassoGlm:PerfectSeparation');
savedScaling = warning('off','stats:lassoGlm:BadScaling');
restoreIterLimit = onCleanup(@() warning(savedIterLimit));
restoreSeparation = onCleanup(@() warning(savedSeparation));
restoreScaling = onCleanup(@() warning(savedScaling));

if isempty(pwts)
    pwts = 1;
end

iterLim = 100;
scalingWarned = false;
seps = sqrt(eps);

% Coefficient convergence tolerance.
convcrit = max(1e-6,2*reltol);

eta = linkFun(mu);

% At most iterLim+1 passes; iter keeps the number of the last pass.
for iter = 1:(iterLim+1)

    % Adjusted dependent variable for this pass.
    deta = dlinkFun(mu);
    z = eta + (y - mu) .* deta;

    % Square roots of the IRLS weights.
    sqrtw = sqrt(pwts) ./ (abs(deta) .* sqrtvarFun(mu));

    % Raise tiny (but nonzero) weights to a floor relative to the
    % largest one, warning once.
    wtol = max(sqrtw)*eps(dataClass)^(2/3);
    tiny = (sqrtw < wtol);
    if any(tiny)
        tiny = tiny & (sqrtw ~= 0);
        if any(tiny)
            sqrtw(tiny) = wtol;
            if ~scalingWarned
                warning(message('stats:lassoGlm:BadScaling'));
            end
            scalingWarned = true;
        end
    end

    bPrev = b;
    [b,active] = wlsfit(z - offset, x, sqrtw.^2, b, active);

    % Updated linear predictor and mean.
    eta = offset + x * b;
    mu = ilinkFun(eta);

    % Keep the mean inside the limits for this distribution.
    switch distr
        case 'binomial'
            if any(mu < muLims(1) | muLims(2) < mu)
                mu = max(min(mu,muLims(2)),muLims(1));
            end
        case {'poisson' 'gamma' 'inverse gaussian'}
            if any(mu < muLims(1))
                mu = max(mu,muLims(1));
            end
    end

    % Stop when no coefficient moved by more than the tolerance ...
    coefMoved = any(abs(b-bPrev) > convcrit * max(seps, abs(bPrev)));
    if ~coefMoved
        break;
    end

    % ... or when the fit already explains nearly all of the null deviance.
    if sum(devFun(mu,y)) < (1.0e-3 * nullDev)
        break;
    end

end

if iter > iterLim
    warning(message('stats:lassoGlm:IterationLimit'));
    % For binomial fits, check whether the data are separated.
    if isequal(distr,'binomial')
        diagnoseSeparation(eta,y,N);
    end
end

varargout{1} = active;

end
0950
0951
0952
0953
0954
function [B,active,varargout] = glmIRLSwrapper(X,Y,distr,offset,pwts,dataClass,N, ...
    sqrtvarFun,linkFun,dlinkFun,ilinkFun,devFun,b,active,mu,muLims, ...
    wlsfit,nullDev,reltol)
%GLMIRLSWRAPPER Run glmIRLS with an intercept column and package the result.
%   Prepends a column of ones to X, runs the IRLS fit, and splits the
%   result into the predictor coefficients B and an intercept.
%
%   Optional outputs, in order:
%     1. struct with fields Intercept and Deviance (weighted sum of the
%        per-observation deviances at the fitted mean),
%     2. the fitted mean,
%     3. the fitted linear predictor.

% Design matrix with the intercept in the first column.
designMatrix = [ones(size(X,1),1) X];

% No weights means unit weight for every observation.
if isempty(pwts), pwts=1; end

[coefs,mu,eta,active] = glmIRLS(designMatrix,Y,distr,offset,pwts,dataClass,N, ...
    sqrtvarFun,linkFun,dlinkFun,ilinkFun,b,active,mu,muLims, ...
    wlsfit,nullDev,devFun,reltol);

% Separate the intercept from the predictor coefficients.
B = coefs(2:end);

extras.Intercept = coefs(1);
extras.Deviance = sum(pwts.* devFun(mu,Y));

varargout = {extras, mu, eta};

end
0994
0995
0996
0997
0998
function dev = lassoFitAndPredict(Xtrain,Ytrain,Xtest,Ytest, ...
    lambda,alpha,dfmax,standardize,reltol,ever_active, ...
    penalizedFitPartition,distr,link,linkFun,dlinkFun,sqrtvarFun, ...
    isCanonical, dlinkFunCanonical,devFun,dataClass)
%LASSOFITANDPREDICT Fit on a training fold and score a test fold.
%   Fits the penalized GLM on the training data for every value in LAMBDA
%   and returns DEV, a row vector holding the weighted test-set deviance
%   of each fit.
%
%   Data layout: column 1 of Xtrain/Xtest holds observation weights and
%   column 2 the offset; a NaN in either column means that quantity was
%   not supplied. A two-column response holds the binomial trial counts
%   in column 1 and the proportions in column 2.

% ---- unpack the training fold ----
wTrain = Xtrain(:,1);
if any(isnan(wTrain))
    wTrain = [];            % NaN marker: no weights
end
oTrain = Xtrain(:,2);
if any(isnan(oTrain))
    oTrain = 0;             % NaN marker: no offset
end
Xtrain = Xtrain(:,3:end);

nTrain = 1;
if size(Ytrain,2) == 2
    nTrain = Ytrain(:,1);
    Ytrain = Ytrain(:,2);
end

isBinomial = isequal(distr,'binomial');

% Starting values; for binomial data the variance and deviance functions
% are rebuilt around this fold's trial counts.
mu = startingVals(distr,Ytrain,nTrain);
eta = linkFun(mu);
if isBinomial
    sqrtvarFun = @(mu) sqrt(mu).*sqrt(1-mu) ./ sqrt(nTrain);
    devFun = @(mu,y) 2*nTrain.*(y.*log((y+(y==0))./mu) + (1-y).*log((1-y+(y==1))./(1-mu)));
end

penalizedFit = @(x,y,wlsfit,b,active,mu,eta) penalizedFitPartition(x,y, ...
    oTrain,wTrain,nTrain,wlsfit,b,active,mu,eta,sqrtvarFun);

[lambdaMax, nullDev, nullIntercept] = computeLambdaMax(Xtrain, Ytrain, wTrain, ...
    alpha, standardize, distr, link, dlinkFun, oTrain, isCanonical, dlinkFunCanonical, devFun);

% Total weight of the training fold, used to put lambdaMax on the same
% scale as LAMBDA.
haveWeights = ~isempty(wTrain);
haveCounts = ~isscalar(nTrain);
if haveWeights && haveCounts
    totalWeight = sum(wTrain .* nTrain);
elseif haveWeights
    totalWeight = sum(wTrain);
elseif haveCounts
    totalWeight = sum(nTrain);
else
    totalWeight = size(Xtrain,1);
end

lambdaMax = lambdaMax * totalWeight;

[B,Intercept] = lassoFit(Xtrain,Ytrain, ...
    wTrain,lambda,alpha,dfmax,standardize,reltol, ...
    lambdaMax,ever_active,penalizedFit,mu,eta,dataClass,true,nullDev,nullIntercept);
coefWithIntercept = [Intercept; B];

% ---- unpack the test fold ----
wTest = Xtest(:,1);
if any(isnan(wTest))
    wTest = ones(size(Xtest,1),1);
end
oTest = Xtest(:,2);
if any(isnan(oTest))
    oTest = 0;
end
Xtest = Xtest(:,3:end);

nTest = 1;
if size(Ytest,2) == 2
    nTest = Ytest(:,1);
    Ytest = Ytest(:,2);
end

% Binomial deviance built around the test fold's trial counts.
if isBinomial
    devFun = @(mu,y) 2*nTest.*(y.*log((y+(y==0))./mu) + (1-y).*log((1-y+(y==1))./(1-mu)));
end

% The offset is only passed to glmval when one was supplied.
offsetArgs = {};
if ~isequal(oTest,0)
    offsetArgs = {'Offset',oTest};
end

% ---- weighted test deviance for each fitted model ----
numFits = size(coefWithIntercept,2);
dev = zeros(1,numFits);
for k=1:numFits
    mu = glmval(coefWithIntercept(:,k), Xtest, link, offsetArgs{:});
    pointDev = devFun(mu,Ytest);
    dev(k) = sum(wTest' * pointDev);
end

end
1101
1102
1103
1104
1105
function [B,Intercept,lambda,varargout] = ...
    lassoFit(X,Y,weights,lambda,alpha,dfmax,standardize,reltol, ...
    lambdaMax,ever_active,penalizedFit,mu,eta,dataClass,userSuppliedLambda,nullDev,nullIntercept)
%LASSOFIT Fit the penalized GLM along a sequence of lambda values.
%   Inputs
%     X, Y          predictors and response.
%     weights       observation weights, or [] for none.
%     lambda        penalty values, processed in the order given.
%     alpha, reltol passed on to the penalized least-squares solver.
%     dfmax         the path stops once more than DFMAX predictors are active.
%     standardize   true to centre and scale X before fitting.
%     lambdaMax     lambda values at or above this are skipped; they keep
%                   zero coefficients, the null intercept and null deviance.
%     ever_active   logical row flagging predictors that may enter the model.
%     penalizedFit  handle called as
%                   penalizedFit(X0,Y,wlsfit,[intercept;b],active,mu,eta).
%     mu, eta       starting mean and linear predictor, carried from one
%                   lambda to the next (warm start).
%     dataClass     class of the coefficient vector.
%     userSuppliedLambda
%                   when false, the path also stops once the deviance
%                   falls below 1e-3*NULLDEV.
%     nullDev, nullIntercept
%                   deviance and intercept of the intercept-only model.
%
%   Outputs
%     B             coefficients on the original scale of X, one column
%                   per retained lambda.
%     Intercept     row vector of intercepts.
%     lambda        the retained lambda values (possibly truncated).
%     varargout{1}  row vector of deviances.

[~,P] = size(X);
nLambda = length(lambda);

% Constant predictors can never become active.
constantPredictors = (range(X)==0);
ever_active = ever_active & ~constantPredictors;

% Weights, when present, are used as a normalized row vector.
observationWeights = ~isempty(weights);
if observationWeights
    weights = weights(:)';
    weights = weights / sum(weights);
end

if standardize
    if ~observationWeights
        % Centre and scale, using the 1/N normalization for the spread.
        [X0,muX,sigmaX] = zscore(X,1);
        % Avoid dividing by zero for constant predictors.
        sigmaX(constantPredictors) = 1;
    else
        % Weighted centre and scale.
        muX = weights*X;
        X0 = bsxfun(@minus,X,muX);
        sigmaX = sqrt( weights*(X0.^2) );
        % Avoid dividing by zero for constant predictors.
        sigmaX(constantPredictors) = 1;
        X0 = bsxfun(@rdivide, X0, sigmaX);
    end
else
    % No centring here: the intercept is fitted explicitly by penalizedFit.
    X0 = X;
    sigmaX = 1;
    muX = zeros(1,size(X,2));
end

B = zeros(P,nLambda);

b = zeros(P,1,dataClass);

% Every lambda starts out with the intercept-only results; entries that
% are fitted below overwrite them.
if nLambda > 0
    Extras(nLambda) = struct('Intercept', nullIntercept, 'Deviance', nullDev);
    for i=1:nLambda-1, Extras(i) = Extras(nLambda); end
    intercept = nullIntercept;
end

active = false(1,P);

for i = 1:nLambda

    lam = lambda(i);
    % Progress line printed for every lambda value.
    disp(['iteration ', num2str(i), ' out of ' , num2str(nLambda)]);
    if lam >= lambdaMax
        % Penalty large enough that the intercept-only fit is kept.
        continue;
    end

    % Weighted least-squares solver bound to this lambda.
    wlsfit = @(x,y,weights,b,active) glmPenalizedWlsWrapper(y,x,b,active,weights,lam, ...
        alpha,reltol,ever_active);

    % NOTE(review): 'intercept' is never updated after its initial
    % assignment, so every fit starts from the null intercept — confirm
    % this is intended.
    [b,active,extras,mu,eta] = penalizedFit(X0,Y,wlsfit,[intercept;b],active,mu,eta);

    B(:,i) = b;

    Extras(i) = extras;

    % Too many active predictors: discard this fit and everything after it.
    if sum(active) > dfmax
        lambda = lambda(1:(i-1));
        B = B(:,1:(i-1));
        Extras = Extras(:,1:(i-1));
        break
    end

    % For a generated lambda grid, stop once the fit explains almost all
    % of the null deviance; this fit is kept.
    if ~(userSuppliedLambda || isempty(nullDev))
        if extras.Deviance < 1.0e-3 * nullDev
            lambda = lambda(1:i);
            B = B(:,1:i);
            Extras = Extras(:,1:i);
            break
        end
    end

end

% Put the coefficients back on the original scale of X.
B = bsxfun(@rdivide, B, sigmaX');
B(~ever_active,:) = 0;

% Collect the intercept and deviance of every retained fit.
nKept = length(lambda);
if nKept == 0
    Intercept = [];
    deviance = [];
else
    Intercept = zeros(1,nKept);
    deviance = zeros(1,nKept);
    for i=1:nKept
        Intercept(i) = Extras(i).Intercept;
        deviance(i) = Extras(i).Deviance;
    end
    % Account for the centring of X.
    Intercept = Intercept - muX*B;
end

varargout{1} = deviance;

end
1293
1294
1295
1296
1297
function potentially_active = thresholdScreen(X0, wX0, Y0, ...
    b, active, threshold)
%THRESHOLDSCREEN Flag predictors whose weighted inner product with the
%   current residual exceeds THRESHOLD in magnitude.
%   The residual is computed from the active predictors only; the result
%   is a logical row vector with one entry per column of wX0.

residual = Y0 - X0(:,active)*b(active,:);
innerProducts = residual' * wX0;
potentially_active = abs(innerProducts) > threshold;
end
1305
1306
1307
1308
1309
function [b,active,wX2,wX2calculated,shrinkFactor] = ...
    cdescentCycleNewCandidates(X0, weights, wX0, wX2, wX2calculated, Y0, ...
    b, active, shrinkFactor, threshold, candidates)
%CDESCENTCYCLENEWCANDIDATES One coordinate-descent pass over candidate
%   predictors that are not yet active.
%   A candidate whose weighted inner product with the residual exceeds
%   THRESHOLD in magnitude receives a soft-thresholded coefficient and is
%   marked active. The weighted sum of squares wX2(j) is computed the
%   first time predictor j is needed and folded into shrinkFactor(j).

residual = Y0 - X0(:,active)*b(active,:);
bStart = b;

for j = find(candidates)

    rho = wX0(:,j)' * residual;
    excess = abs(rho) - threshold;

    if excess > 0
        % Lazily compute the weighted sum of squares for this predictor.
        if ~wX2calculated(j)
            wX2(j) = weights * X0(:,j).^2;
            wX2calculated(j) = true;
            shrinkFactor(j) = wX2(j) + shrinkFactor(j);
        end

        % Soft-thresholded update; the predictor joins the active set.
        b(j) = sign(rho) .* excess ./ shrinkFactor(j);
        active(j) = true;
    end

    % Remove the change in this coefficient from the residual.
    residual = residual - X0(:,j)*(b(j)-bStart(j));
end

end
1340
1341
1342
1343
1344
function [b,active] = ...
    cdescentCycleNoRecalc(X0, wX0, wX2, Y0, b, active, shrinkFactor, threshold)
%CDESCENTCYCLENORECALC One coordinate-descent pass over the active set.
%   Each active coefficient is replaced by its soft-thresholded update.
%   A coefficient whose update does not exceed THRESHOLD in magnitude is
%   set to zero and its predictor is removed from the active set. wX2 and
%   shrinkFactor are used as supplied.

residual = Y0 - X0(:,active)*b(active,:);
% Contribution of each coefficient at the start of the pass.
ownTerm = b.*wX2;
bStart = b;

for j = find(active)

    rho = wX0(:,j)' * residual + ownTerm(j);
    excess = abs(rho) - threshold;

    if excess > 0
        b(j) = sign(rho) .* excess ./ shrinkFactor(j);
    else
        % Shrunk all the way to zero: drop from the active set.
        b(j) = 0;
        active(j) = false;
    end

    % Remove the change in this coefficient from the residual.
    residual = residual - X0(:,j)*(b(j)-bStart(j));
end

end
1370
1371
1372
1373
1374
function [b,varargout] = ...
    penalizedWls(X,Y,b,active,weights,lambda,alpha,reltol)
%PENALIZEDWLS Penalized weighted least squares by coordinate descent.
%   Alternates between (1) cycling over the current active set until the
%   coefficients change by no more than RELTOL (relative to
%   max(1,|previous value|)) and (2) screening the inactive predictors
%   for new candidates. The loop ends when screening leaves the active
%   set unchanged, or when admitting the candidates does not move the
%   coefficients by more than the tolerance.
%
%   Returns the coefficients B; the final active set is the first
%   optional output.

weights = weights(:)';

nPredictors = size(X,2);

wX = bsxfun(@times,X,weights');

% Weighted sums of squares, computed up front only for active predictors;
% the rest are filled in when a predictor first becomes a candidate.
wX2 = zeros(nPredictors,1);
wX2(active) = (weights * X(:,active).^2)';
wX2calculated = active;

threshold = lambda * alpha;

shrinkFactor = wX2 + lambda * (1-alpha);

while true

    bBefore = b;
    activeBefore = active;

    [b,active] = cdescentCycleNoRecalc(X,wX,wX2,Y, b,active,shrinkFactor,threshold);

    % Keep cycling over the active set until it has settled.
    stillMoving = any( abs(b(activeBefore) - bBefore(activeBefore)) > ...
        reltol * max(1.0,abs(bBefore(activeBefore))) );
    if stillMoving
        continue
    end

    % Active set has converged: look for predictors that should enter.
    bSettled = b;
    potentially_active = thresholdScreen(X,wX,Y,b,active,threshold);
    new_candidates = potentially_active & ~active;
    if any(new_candidates)
        [b,activeAfter,wX2,wX2calculated,shrinkFactor] = ...
            cdescentCycleNewCandidates(X,weights,wX,wX2,wX2calculated,Y, ...
            b,active,shrinkFactor,threshold,new_candidates);
    else
        activeAfter = active;
    end

    % No change in the active set: done.
    if isequal(activeAfter, active)
        break
    end

    combined = active | activeAfter;
    candidatesMoved = any( abs(b(combined) - bSettled(combined)) > ...
        reltol * max(1.0,abs(bSettled(combined))) );
    if candidatesMoved
        % The new predictors matter: adopt them and iterate again.
        active = activeAfter;
        continue
    end

    % The candidates changed nothing beyond the tolerance. If they would
    % enlarge the active set, discard their update; otherwise accept the
    % new active set.
    if sum(activeAfter) > sum(active)
        b = bSettled;
    else
        active = activeAfter;
    end
    break

end

varargout{1} = active;

end
1446
1447
1448
1449
1450
function [b,varargout] = glmPenalizedWlsWrapper(X,Y,b,active,weights, ...
    lambda,alpha,reltol,ever_active)
%GLMPENALIZEDWLSWRAPPER Penalized WLS for a design matrix whose first
%   column is the intercept.
%   Drops the first column of X, centres the remaining predictors and Y
%   by their weighted means, solves the penalized problem for b(2:end),
%   and then recovers the intercept from the weighted means. Coefficients
%   of predictors not flagged in EVER_ACTIVE are forced to zero. The
%   active set returned by penalizedWls is the first optional output.

predictors = X(:,2:end);

weights = weights(:)';
unitWeights = weights / sum(weights);

% Weighted centring of the predictors and of the response.
predictorMeans = unitWeights * predictors;
predictors = bsxfun(@minus,predictors,predictorMeans);

responseMean = unitWeights * Y;
Y = Y - responseMean;

[slopes,varargout{1}] = penalizedWls(predictors, Y, b(2:end), ...
    active,weights,lambda,alpha,reltol);

slopes(~ever_active,:) = 0;

% Intercept implied by the centring.
b = [responseMean-predictorMeans*slopes; slopes];

end
1482
function [lambdaMax, nullDev, nullIntercept] = computeLambdaMax(X, Y, weights, alpha, standardize, ...
    distr, link, dlinkFun, offset, isCanonical, dlinkFunCanonical, devFun)
%COMPUTELAMBDAMAX Penalty bound derived from the intercept-only fit.
%   Fits the intercept-only GLM with glmfit and returns
%     lambdaMax      largest absolute (weighted) inner product between the
%                    centred / standardized predictors and the residual
%                    of that fit, divided by ALPHA (and by the number of
%                    observations when there are no weights),
%     nullDev        deviance of the intercept-only fit,
%     nullIntercept  its coefficient.

% glmfit warnings are switched off while the null model is fitted; the
% onCleanup objects restore the previous states on exit.
savedIllCond = warning('off','stats:glmfit:IllConditioned');
savedIterLimit = warning('off','stats:glmfit:IterationLimit');
savedSeparation = warning('off','stats:glmfit:PerfectSeparation');
savedScaling = warning('off','stats:glmfit:BadScaling');
restoreIllCond = onCleanup(@() warning(savedIllCond));
restoreIterLimit = onCleanup(@() warning(savedIterLimit));
restoreSeparation = onCleanup(@() warning(savedSeparation));
restoreScaling = onCleanup(@() warning(savedScaling));

observationWeights = ~isempty(weights);
if observationWeights
    weights = weights(:)';
    % Normalized weights are used for centring and scaling.
    normalizedweights = weights / sum(weights);
end

nObs = size(X,1);

% Centre the predictors, and scale them too when standardizing.
if observationWeights
    muX = normalizedweights * X;
    X0 = bsxfun(@minus,X,muX);
    if standardize
        sigmaX = sqrt( normalizedweights * (X0.^2) );
        % Avoid dividing by zero for constant predictors.
        sigmaX(range(X)==0) = 1;
        X0 = bsxfun(@rdivide, X0, sigmaX);
    end
elseif standardize
    X0 = zscore(X,1);
else
    X0 = bsxfun(@minus,X,mean(X,1));
end

% Intercept-only fit; the offset is passed on only when it is not a scalar.
constantTerm = ones(length(Y),1);
offsetArgs = {};
if ~isscalar(offset)
    offsetArgs = {'offset',offset};
end
[coeffs,nullDev] = glmfit(constantTerm,Y,distr,'constant','off', ...
    'link',link,'weights',weights,offsetArgs{:});
predictedMu = glmval(coeffs,constantTerm,link,'constant','off',offsetArgs{:});

nullIntercept = coeffs;

% Deviance of the plain mean of Y, for comparison with the glmfit result.
meanFit = mean(Y)*ones(length(Y),1);
meanDeviances = devFun(meanFit, Y);
if observationWeights
    muDev = weights * meanDeviances;
else
    muDev = sum(meanDeviances);
end

% If glmfit did clearly worse than the plain mean and its last warning was
% about bad scaling, use the plain mean as the null prediction.
if (muDev - nullDev) / max([1.0 muDev nullDev]) < - 1.0e-4
    [~, lastid] = lastwarn;
    if strcmp(lastid,'stats:glmfit:BadScaling')
        predictedMu = meanFit;
        warning(message('stats:lassoGlm:DifficultLikelihood'));
    end
end

% Rescale the predictors for a non-canonical link.
if ~isCanonical
    X0 = bsxfun( @times, X0, dlinkFunCanonical(predictedMu) ./ dlinkFun(predictedMu) );
end

nullResidual = Y - predictedMu;
if observationWeights
    wX0 = bsxfun(@times, X0, normalizedweights');
    dotp = abs(sum(bsxfun(@times, wX0, nullResidual)));
    lambdaMax = max(dotp) / alpha;
else
    dotp = abs(X0' * nullResidual);
    lambdaMax = max(dotp) / (nObs*alpha);
end

end
1605
function lambda = computeLambdaSequence(lambdaMax, nLambda, lambdaRatio, LRdefault)
%COMPUTELAMBDASEQUENCE Descending, geometrically spaced lambda grid.
%   Returns a column running from LAMBDAMAX down to LAMBDAMAX*LAMBDARATIO
%   in equal steps on the log scale. With NLAMBDA equal to 1 the result is
%   just LAMBDAMAX. With LAMBDARATIO equal to 0 the grid is built with
%   LRDEFAULT as the ratio and its last entry is then replaced by 0.

if nLambda==1
    lambda = lambdaMax;
    return
end

% A zero ratio cannot be used on the log scale.
endAtZero = (lambdaRatio==0);
if endAtZero
    lambdaRatio = LRdefault;
end

lambdaMin = lambdaMax * lambdaRatio;
logTop = log(lambdaMax);
logBottom = log(lambdaMin);
logStep = -(logTop - logBottom)/(nLambda-1);
lambda = exp(logTop:logStep:logBottom)';

% Pin the last entry to its exact target value.
if endAtZero
    lambda(end) = 0;
else
    lambda(end) = lambdaMin;
end

end
1634