gaussmixg

PURPOSE

GAUSSMIXG global mean, variance and mode of a GMM

SYNOPSIS

function [mg,vg,pg,pv]=gaussmixg(m,v,w,n)

DESCRIPTION

GAUSSMIXG global mean, variance and mode of a GMM

 Usage: (1) gaussmixg(m,v,w)               % plot the mean and mode positions of a GMM
        (2) [mg,vg]=gaussmixg(m,v,w)       % find global mean and covariance of a GMM
        (3) [mg,vg,pg]=gaussmixg(m,v,w)    % find global mean, covariance and mode of a GMM
        (4) [mg,vg,pg,pv]=gaussmixg(m,v,w) % ... also find the log pdf at each mode
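As a usage sketch for forms (2) and (4), the lines below build a two-component, one-dimensional GMM and request both modes. The parameter values are arbitrary illustrative choices, and the sketch assumes that VOICEBOX is on the MATLAB path.

```matlab
% Two-component 1-D GMM: means 0 and 4, unit variances, equal weights
% (illustrative values; assumes VOICEBOX is on the MATLAB path)
m = [0; 4];          % M(k,p) with k=2 mixtures, p=1 dimension
v = [1; 1];          % V(k,p): diagonal variances
w = [0.5; 0.5];      % W(k,1): mixture weights
[mg,vg] = gaussmixg(m,v,w);           % form (2): global mean and variance
[~,~,pg,pv] = gaussmixg(m,v,w,2);     % form (4) with N=2: modes and their log pdf
fprintf('global mean %.3f, global variance %.3f\n', mg, vg);
disp([pg pv]);       % one row per mode: [position, log pdf]
```

The two components are mirror images about x=2. As a result, MG is 2 and VG is 5, which is a between-mixture variance of 4 plus a within-mixture variance of 1. The two modes lie slightly inside 0 and 4 and have equal log pdf, so their order in PG carries no meaning.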

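  Inputs:  M(k,p) = mixture means (one p-dimensional mean per row)
           V(k,p) or V(p,p,k) = variances (diagonal) or covariance matrices (full) [default: unit variances]
           W(k,1) = mixture weights, normalized internally to sum to 1 [default: equal weights]
           N      = maximum number of modes to find [default 1, or K if there are no output arguments]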

 Outputs: MG(1,p) = global mean
          VG(p,p) = global covariance
          PG(N,p) = up to N modes, sorted by decreasing pdf
          PV(N,1) = log pdf at the modes PG(N,p) (in decreasing order)
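MG is the weighted mean of the mixture means. By the law of total covariance, VG is the weighted average of the mixture covariances plus the weighted scatter of the means about MG. The sketch below shows that calculation using the same expressions as the listing. It uses arbitrary illustrative values and computes VG twice: once with V in the diagonal V(k,p) form and once with the equivalent full V(p,p,k) form. The names vd, vf, vgd and vgf are hypothetical.

```matlab
% Illustrative 2-D mixture with k=3 components (values are arbitrary)
m  = [0 0; 3 1; -1 2];       % M(k,p): one mean per row
vd = [1 2; 0.5 0.5; 2 1];    % V(k,p): diagonal variances
w  = [2; 1; 1];              % W(k,1): unnormalized weights
w  = w/sum(w);               % normalize, as gaussmixg does
[k,p] = size(m);

mg = w.'*m;                  % global mean, size (1,p)
mz = m - repmat(mg,k,1);     % means relative to the global mean
% between-mixture scatter plus weighted within-mixture (diagonal) variances
vgd = mz.'*(mz.*repmat(w,1,p)) + diag(w.'*vd);

% The same mixture expressed with full covariance matrices V(p,p,k)
vf = zeros(p,p,k);
for i = 1:k
    vf(:,:,i) = diag(vd(i,:));
end
vgf = mz.'*(mz.*repmat(w,1,p)) + reshape(reshape(vf,p^2,k)*w,p,p);
disp(mg); disp(vgd);
```

For these values, both forms give MG = [0.5 0.75] and VG = [3.375 -0.125; -0.125 2.0625]. Every mixture covariance here is diagonal, so the off-diagonal term comes entirely from the scatter of the means.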

  This routine finds the global mean and covariance matrix of a Gaussian mixture model (GMM). It also
  attempts to find up to N local maxima of the pdf using a combination of the fixed-point and quadratic
  Newton-Raphson algorithms from [1]. Currently, N must be less than or equal to the number of
  mixtures, K, because the search starts from the K mixture means. In general, the pdf surface of a GMM
  can be very complicated with many local maxima [2] and, as discussed in [1,2], this algorithm is not
  guaranteed to find the N highest ones. In [2], it is conjectured that the number of local maxima is
  at most K in the following cases: (a) P=1, (b) all mixture covariance matrices are equal, and
  (c) all mixture covariance matrices are multiples of the identity.
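The one-dimensional sketch below makes the hybrid update concrete. It starts from each mixture mean. When p''(x) < 0, it takes a Newton-Raphson step x - p'(x)/p''(x); in one dimension this corresponds to the negative-definite Hessian test in the listing. It falls back to the fixed-point update of [1] when the pdf is not locally concave or when the step would lower the pdf. This is a simplification, not the routine itself: it runs a fixed 20 iterations and omits the initial fixed-point iterations, the convergence test and the duplicate merging. The parameter values are illustrative only.

```matlab
% 1-D GMM with illustrative values: two equal components, means 0 and 4
m = [0; 4];  v = [1; 1];  w = [0.5; 0.5];
w = w/sum(w);
gmmpdf = @(x) sum(w.*exp(-0.5*(x-m).^2./v)./sqrt(2*pi*v));   % pdf at scalar x
modes = m;                     % one mode candidate per mixture mean
for it = 1:20                  % fixed iteration count (no convergence test)
    for j = 1:numel(modes)
        x = modes(j);
        r = w.*exp(-0.5*(x-m).^2./v)./sqrt(2*pi*v);   % weighted component densities at x
        px = sum(r);                                  % pdf p(x)
        r = r/px;                                     % mixture responsibilities (sum to 1)
        xfp = sum(r.*m./v)/sum(r./v);                 % fixed-point update of [1]
        g = sum(r.*(m-x)./v);                         % p'(x)/p(x)
        h = sum(r.*(((m-x)./v).^2 - 1./v));           % p''(x)/p(x)
        xnew = xfp;                                   % default: fixed-point update
        if h < 0                                      % pdf locally concave: try Newton-Raphson
            xnr = x - g/h;                            % Newton step towards p'(x)=0
            if gmmpdf(xnr) >= px                      % accept only if the pdf does not decrease
                xnew = xnr;
            end
        end
        modes(j) = xnew;
    end
end
logp = log(arrayfun(gmmpdf, modes));                  % log pdf at each mode
disp([modes logp]);
```

With these values, the two modes lie slightly inside 0 and 4 and have equal log pdf.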

 Refs:
   [1] M. Á. Carreira-Perpiñán. Mode-finding for mixtures of Gaussian distributions.
       IEEE Trans. Pattern Anal. and Machine Intell., 22(11): 1318–1323, 2000. doi: 10.1109/34.888716.
   [2] M. Á. Carreira-Perpiñán and C. K. I. Williams. On the number of modes of a Gaussian mixture.
       In Proc. Intl. Conf. on Scale Space Theories in Computer Vision, LNCS vol. 2695, pages 625–640,
       Isle of Skye, June 2003. doi: 10.1007/3-540-44935-3_44.

SOURCE CODE

function [mg,vg,pg,pv]=gaussmixg(m,v,w,n)
%GAUSSMIXG global mean, variance and mode of a GMM
%
% Usage: (1) gaussmixg(m,v,w)               % plot the mean and mode positions of a GMM
%        (2) [mg,vg]=gaussmixg(m,v,w)       % find global mean and covariance of a GMM
%        (3) [mg,vg,pg]=gaussmixg(m,v,w)    % find global mean, covariance and mode of a GMM
%        (4) [mg,vg,pg,pv]=gaussmixg(m,v,w) % ... also find the log pdf at each mode
%
%  Inputs:  M(k,p) = mixture means (one p-dimensional mean per row)
%           V(k,p) or V(p,p,k) = variances (diagonal) or covariance matrices (full) [default: unit variances]
%           W(k,1) = mixture weights, normalized internally to sum to 1 [default: equal weights]
%           N      = maximum number of modes to find [default 1, or K if there are no output arguments]
%
% Outputs: MG(1,p) = global mean
%          VG(p,p) = global covariance
%          PG(N,p) = up to N modes, sorted by decreasing pdf
%          PV(N,1) = log pdf at the modes PG(N,p) (in decreasing order)
%
%  This routine finds the global mean and covariance matrix of a Gaussian mixture model (GMM). It also
%  attempts to find up to N local maxima of the pdf using a combination of the fixed-point and quadratic
%  Newton-Raphson algorithms from [1]. Currently, N must be less than or equal to the number of
%  mixtures, K, because the search starts from the K mixture means. In general, the pdf surface of a GMM
%  can be very complicated with many local maxima [2] and, as discussed in [1,2], this algorithm is not
%  guaranteed to find the N highest ones. In [2], it is conjectured that the number of local maxima is
%  at most K in the following cases: (a) P=1, (b) all mixture covariance matrices are equal, and
%  (c) all mixture covariance matrices are multiples of the identity.
%
% Refs:
%   [1] M. Á. Carreira-Perpiñán. Mode-finding for mixtures of Gaussian distributions.
%       IEEE Trans. Pattern Anal. and Machine Intell., 22(11): 1318–1323, 2000. doi: 10.1109/34.888716.
%   [2] M. Á. Carreira-Perpiñán and C. K. I. Williams. On the number of modes of a Gaussian mixture.
%       In Proc. Intl. Conf. on Scale Space Theories in Computer Vision, LNCS vol. 2695, pages 625–640,
%       Isle of Skye, June 2003. doi: 10.1007/3-540-44935-3_44.

% Bugs/Suggestions:
% (1) Sometimes the mode is not found, e.g. m=[0 1; 1 0];v=[.01 10; 10 .01];
%     has a true mode near (0,0). Could add to the list of mode candidates
%     all the pairwise intersections of the mixtures.
%     Another is: m=[0 0; 10 0.3]; v=[1 1; 1000 .001];
% (2) When merging candidates, we should keep the one with the highest probability
% (3) could preserve the fixed arrays between calls if p and/or k are unchanged
%
% See also: gaussmix, gaussmixd, gaussmixp, randvec

%      Copyright (C) Mike Brookes 2000-2012
%      Version: $Id: gaussmixg.m 3227 2013-07-04 15:42:04Z dmb $
%
%   VOICEBOX is a MATLAB toolbox for speech processing.
%   Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%   This program is free software; you can redistribute it and/or modify
%   it under the terms of the GNU General Public License as published by
%   the Free Software Foundation; either version 2 of the License, or
%   (at your option) any later version.
%
%   This program is distributed in the hope that it will be useful,
%   but WITHOUT ANY WARRANTY; without even the implied warranty of
%   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%   GNU General Public License for more details.
%
%   You can obtain a copy of the GNU General Public License from
%   http://www.gnu.org/copyleft/gpl.html or by writing to
%   Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Algorithm parameters
nfp=2;          % number of fixed point iterations to do at start
maxloop=60;     % maximum number of iterations
ssf=0.1;        % factor when calculating minimum mode separation
% Sort out arguments
[k,p]=size(m);
if nargin<4
    n=1;
    if nargin<3
        w=ones(k,1);
        if nargin<2
            v=ones(k,p);
        end
    end
end
if ~nargout
    if nargin<4
        n=k;
    end
    nao=4;
else
    nao=nargout;
end

full=numel(size(v))>2 || k==1 && numel(v)>p; % test for full covariance matrices
if full && p==1   % if p=1 then we force diagonal covariance matrices
    v=reshape(v,1,k)';
    full=false;
end
w=w/sum(w);  % force w to sum to unity
% calculate the global mean and covariance
mg=w.'*m;
mz=m-mg(ones(k,1),:);   % means relative to global mean
if nao>2                % need to calculate the modes
    nx=k;               % number of pg values initially
    kk=reshape(repmat(1:nx,k,1),k*nx,1); % [1 1 1 2 2 2 ... nx nx nx]'
    km=repmat(1:k,1,nx)'; % [1:k 1:k ... 1:k]'
    % index all pairs of mode candidates; needed to eliminate duplicates
    nxp=nx*(nx-1)/2;
    ja=(1:nxp)';
    jb=floor((3+sqrt(8*ja-3))/2);       % [2 3 3 4 4 4 ... nx]'
    ja=ja-(jb.^2-3*jb+2)/2;             % [1 1:2 1:3 ... 1:nx-1]'
    jc=ones(nxp,1);
    % sort out indexing for vectorized upper triangular matrix
    npu=p*(p+1)/2;                      % number of distinct elements in a symmetric (p,p) matrix
    kw=1:npu;
    ku=floor((1+sqrt(8*kw-3))/2);       % [1 2 2 3 3 3 ... p]
    kv=kw-(ku.^2-ku)/2;                 % [1 1:2 1:3 ... 1:p]
    zpp=zeros(p,p);
    zpp(kv+p*(ku-1))=kw;
    zpp(ku+p*(kv-1))=kw;
    kw=reshape(zpp,1,[]);               % maps vectorized upper triangular to vectorized full matrix
    kp=repmat(1:p,1,p);                 % row indices for a (p,p) matrix as a row vector
    kq=reshape(repmat(1:p,p,1),1,p^2);  % col indices for a (p,p) matrix as a row vector
    kr=p*kp-p+kq;                       % transpose indexing for a vectorized (p,p) matrix
    kd=1:p+1:p^2;                       % diagonal indices of a (p,p) matrix
    % vectors of ones for efficient replication
    wk=ones(k,1);
    wnx=ones(nx,1);
    if full
        vg=mz.'*(mz.*w(:,ones(1,p)))+reshape(reshape(v,p^2,k)*w,p,p);
        % now determine the mode
        vi=zeros(p*k,p);                    % stack of k inverse cov matrices each size p*p times -0.5
        vim=zeros(p*k,1);                   % stack of k vectors of the form -0.5*inv(v)*m
        mtk=vim;                             % stack of k vectors of the form m
        lvm=zeros(k,1);
        wpk=repmat((1:p)',k,1);
        for i=1:k    % loop for each mixture
            [uvk,dvk]=eig(v(:,:,i));      % find eigenvalues and eigenvectors
            dvk=diag(dvk);
            if(any(dvk<=0))
                error('Covariance matrix for mixture %d is not positive definite',i);
            end
            vik=-0.5*uvk*diag(dvk.^(-1))*uvk';   % calculate inverse including -0.5 factor
            vi((i-1)*p+(1:p),:)=vik;           % vi contains all mixture inverses stacked on top of each other
            vim((i-1)*p+(1:p))=vik*m(i,:)';   % vim contains vi*m for all mixtures stacked on top of each other
            mtk((i-1)*p+(1:p))=m(i,:)';       % mtk contains all mixture means stacked on top of each other
            lvm(i)=log(w(i))-0.5*sum(log(dvk));       % lvm(i) = log(w(i)) - 0.5*log(det(v)) for each mixture
        end
        vif=reshape(permute(reshape(vi,p,k,p),[2 1 3]),k,p^2); % each -0.5*inv(v) matrix as a vectorized row
        vimf=reshape(vim,p,k)'; % vi*m as a row for each mixture
        ss=sqrt(min(v(repmat(kd,k,1)+repmat(p^2*(0:k-1)',1,p)),[],1))*ssf/sqrt(p);   % minimum separation of modes [this is a conservative guess]
    else
        vg=mz.'*(mz.*w(:,ones(1,p)))+diag(w.'*v);
        % now determine the mode
        vi=-0.5*v.^(-1);                % vi(k,p) = data-independent scale factor in exponent
        vi2=vi(:,ku).*vi(:,kv);         % vi2(k,npu) = upper triangular Hessian data dependent term
        lvm=log(w)-0.5*sum(log(v),2);   % log of external scale factor (excluding -0.5*p*log(2pi) term)
        vim=vi.*m;
        vim2=vim(:,ku).*vim(:,kv);      % vim2(k,npu) = upper triangular Hessian data independent term
        vimvi=vim(:,kp).*vi(:,kq);      % vimvi(k,p^2) = vectorized Hessian term
        ss=sqrt(min(v,[],1))*ssf/sqrt(p);   % minimum separation of modes [this is a conservative guess]
    end
    pgf=zeros(nx,p);                    % space for fixed point update
    sv=0.01*ss;                         % convergence threshold
    pg=m;  % initialize mode candidates to mixture means
    i=1;   % loop counter
    while i<=maxloop
        pg0=pg;                         % save previous mode candidates pg(nx,p)
        if full
            py=reshape(sum(reshape((vi*pg'-vim(:,wnx)).*(pg(:,wpk)'-mtk(:,wnx)),p,nx*k),1),k,nx)+lvm(:,wnx);
            mx=max(py,[],1);                % find normalizing factor for each data point to prevent underflow when using exp()
            px=exp(py-mx(wk,:));            % find normalized probability of each mixture for each datapoint
            ps=sum(px,1);                   % total normalized likelihood of each data point
            px=(px./ps(wk,:))';             % px(nx,k) = relative mixture probabilities for each data point (rows sum to 1)
            % calculate the fixed point update
            pxvif=px*vif;    % pxvif(nx,p^2)
            pxvimf=px*vimf;  % pxvimf(nx,p)
            for j=1:nx
                pgf(j,:)=pxvimf(j,:)/reshape(pxvif(j,:),p,p);
            end
        else
            py=reshape(sum((pg(kk,:)-m(km,:)).^2.*vi(km,:),2),k,nx)+lvm(:,wnx);
            mx=max(py,[],1);                % find normalizing factor for each data point to prevent underflow when using exp()
            px=exp(py-mx(wk,:));            % find normalized probability of each mixture for each datapoint
            ps=sum(px,1);                   % total normalized likelihood of each data point
            px=(px./ps(wk,:))';             % px(nx,k) = relative mixture probabilities for each data point (rows sum to 1)
            % calculate the fixed point update
            pxvim=px*vim;
            pxvi=px*vi;
            pgf=pxvim./pxvi;       % fixed point update for all points
        end
        if i>nfp
            % calculate gradient and Hessian; see [1] equations (4) and (5)
            lp=log(ps)+mx;              % log prob of each data point
            if full
                gx=pxvimf-reshape(sum(repmat(pg,p,1).*reshape(pxvif,[],p),2),nx,p);
                vimpg=repmat(vimf,nx,1)-reshape(permute(reshape(pg*vi',nx,p,k),[3 1 2]),[],p); % vimpg(k*nx,p)
                hx1=2*reshape(sum(reshape(repmat(reshape(px',[],1),1,npu).*vimpg(:,ku).*vimpg(:,kv),k,[]),1),nx,[]);
                hx=pxvif+hx1(:,kw);
            else
                gx=pxvim-pxvi.*pg;               % gradient for each data point (one row per point)
                hx1=px*vim2+(px*vi2).*pg(:,ku).*pg(:,kv);
                hx2=(px*(vimvi)).*pg(:,kq);
                hx=2*(hx1(:,kw)-hx2-hx2(:,kr));
                hx(:,kd)=hx(:,kd)+pxvi;
            end
            hx=reshape(hx',p,p,nx);
            for j=1:nx
                if all(eig(hx(:,:,j))<0)  % if the Hessian is negative definite
                    pg(j,:)=pg(j,:)+gx(j,:)/hx(:,:,j); % do a Newton-Raphson update
                    if full
                        pyj=sum(reshape((vi*pg(j,:)'-vim).*(pg(j,wpk)'-mtk),p,k),1)'+lvm;
                    else
                        pyj=sum((repmat(pg(j,:),k,1)-m).^2.*vi,2)+lvm;
                    end
                    mxj=max(pyj);                % find normalizing factor to prevent underflow when using exp()
                    pxj=exp(pyj-mxj);            % find normalized probability of each mixture at the updated point
                    psj=sum(pxj,1);              % total normalized likelihood of the updated point
                    lpj=log(psj)+mxj;            % log prob of updated data point
                    if lpj<lp(j)       % check if the probability has decreased
                        pg(j,:)=pgf(j,:);   % if so, do fixed point update
                    end
                else
                    pg(j,:)=pgf(j,:);   % else do fixed point update
                end
            end
        else
            pg=pgf;       % fixed point update for all points
        end
        if all(all(abs(pg-pg0)<sv(wnx,:))) && i+2<maxloop
            maxloop=min(maxloop,i+2);        % if converged, do at most two more iterations
        end
        %         [all(pg==pgf,2) pg [i; repmat(NaN,nx-1,1)] pg-pg0]  %debug: [fixed-point x-mode iteration delta-x]
        jd=all(abs(pg(jb,:)-pg(ja,:))<ss(jc,:),2);   % find duplicate modes
        if any(jd)
            jx=sparse([(1:nx)';ja;jb],[(1:nx)';jb;ja],[wnx;jd;jd]);   % neighbour matrix
            kx=any((jx*jx)>0 & ~jx,2);  % find chains that are not fully connected
            while any(kx)
                kx=any(jx(:,kx),2);     % destroy all links connected to these chains
                jx(kx,:)=0;
                jx(:,kx)=0;
                % jx(kx,kx)=1;
                kx=any((jx*jx)>0 & ~jx,2);
            end
            jx([1:nx+1:nx*nx (ja+(jb-1)*nx)'])=0; % reset the upper triangle + diagonal
            pg(any(jx,2),:)=[];   % delete the duplicates
            % update nx and anything that depends on it
            nx=size(pg,1);
            pgf=zeros(nx,p);                    % space for fixed point update
            wnx=ones(nx,1);
            nxp=nx*(nx-1)/2;
            ja=ja(1:nxp);
            jb=jb(1:nxp);
            kk=reshape(repmat(1:nx,k,1),k*nx,1); % [1 1 1 2 2 2 ... nx nx nx]'
            km=reshape(repmat(1:k,1,nx),k*nx,1); % [1:k 1:k ... 1:k]'
            jc=ones(nxp,1);
        end
        i=i+1;
    end
    % calculate the log pdf at each mode
    if full
        py=reshape(sum(reshape((vi*pg'-vim(:,wnx)).*(pg(:,wpk)'-mtk(:,wnx)),p,nx*k),1),k,nx)+lvm(:,wnx);
        mx=max(py,[],1);                % find normalizing factor for each data point to prevent underflow when using exp()
        px=exp(py-mx(wk,:));            % find normalized probability of each mixture for each datapoint
        ps=sum(px,1);                   % total normalized likelihood of each data point
    else
        py=reshape(sum((pg(kk,:)-m(km,:)).^2.*vi(km,:),2),k,nx)+lvm(:,wnx);
        mx=max(py,[],1);                % find normalizing factor for each data point to prevent underflow when using exp()
        px=exp(py-mx(wk,:));            % find normalized probability of each mixture for each datapoint
        ps=sum(px,1);                   % total normalized likelihood of each data point
    end
    [pv,ix]=sort((log(ps)+mx)'-0.5*p*log(2*pi),'descend');
    pg=pg(ix,:);
    if n<numel(pv) % only keep the first n modes
        pg=pg(1:n,:);
        pv=pv(1:n);
    end
elseif nao>1
    if full
        vg=mz.'*(mz.*w(:,ones(1,p)))+reshape(reshape(v,p^2,k)*w,p,p);
    else
        vg=mz.'*(mz.*w(:,ones(1,p)))+diag(w.'*v);
    end
end

if ~nargout
    % now plot the result
    clf;
    pg1=pg(1,:);
    lpm=gaussmixp(mg,m,v,w);
    lpp=pv(1);
    switch p
        case 1
            gaussmixp([],m,v,w);
            hold on;
            ylim=get(gca,'ylim')';
            plot([mg mg]',ylim,'-k',[mg mg; mg mg]+[-1 1; -1 1]*sqrt(vg),ylim,':k');
            plot(pg(1),pv(1),'^k');
            if numel(pg)>1
                plot(pg(2:end),pv(2:end),'xk');
            end
            hold off;
            title(sprintf('Mean+-sd = %.3g+-%.3g LogP = %.3g, Mode\\Delta = %.3g LogP = %.3g',mg,sqrt(vg),lpm,pg1,lpp));
            xlabel('x');
        case 2
            gaussmixp([],m,v,w);
            hold on;
            t=linspace(0,2*pi,100);
            xysd=chol(vg)'*[cos(t); sin(t)]+repmat(mg',1,length(t));
            plot(xysd(1,:),xysd(2,:),':k',mg(1),mg(2),'ok');
            plot(pg(1,1),pg(1,2),'^k');
            if numel(pv)>1
                plot(pg(2:end,1),pg(2:end,2),'xk');
            end
            hold off;
            title(sprintf('Mean = (%.3g,%.3g) LogP = %.3g, Mode:\\Delta = (%.3g,%.3g) LogP = %.3g',mg,lpm,pg1,pv(1)));
            xlabel('x');
            ylabel('y');
        otherwise
            nx=200;
            nc=ceil(sqrt(p/2));
            nr=ceil(p/nc);
            sdx=sqrt(diag(vg))';  % std deviation
            minx=min([mg; pg],[],1)-1.5*sdx;
            maxx=max([mg; pg],[],1)+1.5*sdx;
            ix=2:p; % selected indices
            for i=1:p
                xi=linspace(minx(i),maxx(i),nx)';
                [mm,vm,wm]=gaussmixd(mg(ix),m,v,w,[],ix);
                ym=gaussmixp(xi,mm,vm,wm')+lpm-gaussmixp(mg(i),mm,vm,wm');
                [mp,vp,wp]=gaussmixd(pg(1,ix),m,v,w,[],ix);
                yp=gaussmixp(xi,mp,vp,wp')+lpp-gaussmixp(pg1(i),mp,vp,wp');
                subplot(nr,nc,i);
                plot(xi,ym,'-k',mg(i),lpm,'ok',xi,yp,':b',pg1(i),lpp,'^b');
                axisenlarge([-1 -1 -1 -1.05]);
                hold on
                plot([mg(i) mg(i)],get(gca,'ylim'),'-k');
                hold off
                xlabel(sprintf('x[%d], Mean = %.3g, Mode\\Delta = %.3g',i,mg(i),pg1(i)));
                if i==1
                    title(sprintf('Log Prob: Mean = %.3g, Mode\\Delta = %.3g',lpm,lpp));
                end
                ix(i)=i;
            end
    end
end
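The duplicate-mode check in the listing enumerates every pair of mode candidates (ja,jb) with ja<jb from a single linear index. It does this with jb=floor((3+sqrt(8*ja-3))/2) followed by ja=ja-(jb.^2-3*jb+2)/2. The standalone snippet below reproduces those two lines so that the ordering can be inspected. It only illustrates that indexing and is not part of the routine; nx=4 is an arbitrary choice.

```matlab
% Enumerate all candidate pairs (ja,jb) with ja<jb, as in the duplicate-mode check
nx  = 4;                             % number of mode candidates (arbitrary, for illustration)
nxp = nx*(nx-1)/2;                   % number of distinct pairs
ja  = (1:nxp)';                      % linear pair index
jb  = floor((3+sqrt(8*ja-3))/2);     % second member of each pair
ja  = ja-(jb.^2-3*jb+2)/2;           % first member of each pair
disp([ja jb]);                       % one pair per row
```

The pairs are ordered by their second member. As a result, the first nx*(nx-1)/2 entries only ever involve candidates 1 to nx. This is why the listing can simply truncate ja and jb after duplicate candidates have been deleted.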

Generated on Tue 19-Sep-2017 12:07:31 by m2html © 2003