V_GAUSSMIXG global mean, variance and mode of a GMM

Usage: (1) v_gaussmixg(m,v,w)                % plot the mean and mode positions of a GMM
       (2) [mg,vg]=v_gaussmixg(m,v,w)        % find global mean and covariance of a GMM
       (3) [mg,vg,pg]=v_gaussmixg(m,v,w)     % find global mean, covariance and mode of a GMM
       (4) [mg,vg,pg,pv]=v_gaussmixg(m,v,w)  % ... also find log probability of the peak

Inputs:  M(k,p)             = mixture means (one p-dimensional mean per row)
         V(k,p) or V(p,p,k) = variances (diagonal or full)
         W(k,1)             = mixture weights
         N                  = maximum number of modes to find [default 1, or K when called with no output arguments]

Outputs: MG(1,p) = global mean
         VG(p,p) = global covariance
         PG(N,p) = sorted list of N modes
         PV(N,1) = log pdf at the modes PG(N,p) (in decreasing order)

This routine finds the global mean and covariance matrix of a Gaussian mixture model (GMM). It also attempts to find up to N local maxima of the PDF using a combination of the fixed-point and quadratic Newton-Raphson algorithms from [1]. Currently, N must be less than or equal to the number of mixtures K. In general, the PDF surface of a GMM can be very complicated, with many local maxima [2]. As discussed in [1,2], this algorithm is therefore not guaranteed to find the N highest maxima. In [2], it is conjectured that the number of local maxima is <=K in the following cases: (a) P=1, (b) all mixture covariance matrices are equal, and (c) all mixture covariance matrices are multiples of the identity.

Refs:
[1] M. A. Carreira-Perpinan. Mode-finding for mixtures of Gaussian distributions. IEEE Trans. Pattern Analysis and Machine Intelligence, 22(11):1318-1323, 2000. doi: 10.1109/34.888716.
[2] M. A. Carreira-Perpinan and C. K. I. Williams. On the number of modes of a Gaussian mixture. In Proc. Intl. Conf. on Scale Space Theories in Computer Vision, volume LNCS 2695, pages 625-640, Isle of Skye, June 2003. doi: 10.1007/3-540-44935-3_44.
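The global statistics returned in MG and VG follow from the law of total variance. MG = sum_k W(k)*M(k,:), and VG = sum_k W(k)*(V_k + (M(k,:)-MG)'*(M(k,:)-MG)), where V_k is the k-th mixture covariance. The sketch below reproduces the diagonal-covariance expressions used by the routine, mg=w.'*m and vg=mz.'*(mz.*w(:,ones(1,p)))+diag(w.'*v). It uses a two-mixture example whose values were chosen here for illustration only. It is a standalone check and does not call into VOICEBOX.

```matlab
% Illustrative 2-D, two-mixture GMM with diagonal covariances (values are hypothetical)
m = [0 0; 2 1];                   % k=2 mixture means, one per row (p=2)
v = [1 1; 0.5 2];                 % diagonal variances, one row per mixture
w = [0.25; 0.75];                 % mixture weights
[k,p] = size(m);
w = w/sum(w);                     % force weights to sum to one, as the routine does
mg = w.'*m;                       % global mean: weighted average of the mixture means
mz = m - mg(ones(k,1),:);         % mixture means relative to the global mean
vb = mz.'*(mz.*w(:,ones(1,p)));   % between-mixture term: sum_k w(k)*mz(k,:)'*mz(k,:)
vw = diag(w.'*v);                 % within-mixture term: weighted average of the variances
vg = vb + vw;                     % global covariance (law of total variance)
disp(mg); disp(vg);
```

For these inputs the sketch gives mg = [1.5 0.75] and vg = [1.375 0.375; 0.375 1.9375]. The off-diagonal term comes entirely from the spread of the means, because each individual mixture has a diagonal covariance.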
function [mg,vg,pg,pv]=v_gaussmixg(m,v,w,n)
%V_GAUSSMIXG global mean, variance and mode of a GMM
%
%  Usage:    (1) v_gaussmixg(m,v,w)               % plot the mean and mode positions of a GMM
%            (2) [mg,vg]=v_gaussmixg(m,v,w)        % find global mean and covariance of a GMM
%            (3) [mg,vg,pg]=v_gaussmixg(m,v,w)     % find global mean,covariance and mode of a GMM
%            (4) [mg,vg,pg,pv]=v_gaussmixg(m,v,w)  % ... also find log probability of the peak
%
%  Inputs:   M(k,p) = mixture means for pg(p)
%            V(k,p) or V(p,p,k) variances (diagonal or full)
%            W(k,1) = mixture weights
%            N = maximum number of modes to find [default 1]
%
% Outputs:   MG(1,p) = global mean
%            VG(p,p) = global covariance
%            PG(N,p) = sorted list of N modes
%            PV(N,1) = log pdf at the modes PG(N,p) (in decreasing order)
%
% This routine finds the global mean and covariance matrix of a Gaussian Mixture (GMM). It also
% attempts to find up to N local maxima using a combination of the fixed point and quadratic
% Newton-Raphson algorithms from [1]. Currently, N must be less than or equal to the number of
% mixtures K. In general the PDF surface of a GMM can be very complicated with many local maxima [2]
% and, as discussed in [1,2], this algorithm is not guaranteed to find the N highest. In [2], it is
% conjectured that the number of local maxima is <=K for the following cases (a) P=1, (b) all
% mixture covariance matrices are equal and (c) all mixture covariance matrices are multiples of
% the identity.
%
% Refs:
%   [1] M. A. Carreira-Perpinan. Mode-finding for mixtures of gaussian distributions.
%       IEEE Trans. Pattern Anal and Machine Intell, 22 (11): 1318-1323, 2000. doi: 10.1109/34.888716.
%   [2] M. A. Carreira-Perpinan and C. K. I. Williams. On the number of modes of a gaussian mixture.
%       In Proc Intl Conf on Scale Space Theories in Computer Vision, volume LNCS 2695, pages 625-640,
%       Isle of Skye, June 2003. doi: 10.1007/3-540-44935-3_44.

% Bugs/Suggestions:
% (1) Sometimes the mode is not found, e.g. m=[0 1; 1 0];v=[.01 10; 10 .01];
%     has a true mode near (0,0). Could add to the list of mode candidates
%     all the pairwise intersections of the mixtures.
%     Another is: m=[0 0; 10 0.3]; v=[1 1; 1000 .001];
% (2) When merging candidates, we should keep the one with the highest probability
% (3) could preserve the fixed arrays between calls if p and/or k are unchanged
%
% See also: v_gaussmix, v_gaussmixd, v_gaussmixp, v_randvec

%      Copyright (C) Mike Brookes 2000-2012
%      Version: $Id: v_gaussmixg.m 10865 2018-09-21 17:22:45Z dmb $
%
%   VOICEBOX is a MATLAB toolbox for speech processing.
%   Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%   This program is free software; you can redistribute it and/or modify
%   it under the terms of the GNU General Public License as published by
%   the Free Software Foundation; either version 2 of the License, or
%   (at your option) any later version.
%
%   This program is distributed in the hope that it will be useful,
%   but WITHOUT ANY WARRANTY; without even the implied warranty of
%   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%   GNU General Public License for more details.
%
%   You can obtain a copy of the GNU General Public License from
%   http://www.gnu.org/copyleft/gpl.html or by writing to
%   Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Algorithm parameters
nfp=2;              % number of fixed point iterations to do at start
maxloop=60;         % maximum number of iterations
ssf=0.1;            % factor when calculating minimum mode separation
%Sort out arguments
[k,p]=size(m);
if nargin<4
    n=1;
    if nargin<3
        w=ones(k,1);
        if nargin<2
            v=ones(k,p);
        end
    end
end
if ~nargout
    if nargin<4
        n=k;
    end
    nao=4;
else
    nao=nargout;
end

full=numel(size(v))>2 || k==1 && numel(v)>p;   % test for full covariance matrices
if full && p==1                                 % if p=1 then we force diagonal covariance matrices
    v=reshape(v,1,k)';
    full=false;
end
w=w/sum(w);                                     % force w to sum to unity
% calculate the global mean and covariance
mg=w.'*m;
mz=m-mg(ones(k,1),:);                           % means relative to global mean
if nao>2                                        % need to calculate the modes
    nx=k;                                       % number of pg values initially
    kk=reshape(repmat(1:nx,k,1),k*nx,1);        % [1 1 1 2 2 2 ... nx nx nx]'
    km=repmat(1:k,1,nx)';                       % [1:k 1:k ... 1:k]'
    % sort out indexing for all data value pairs; needed to eliminate duplicates
    nxp=nx*(nx-1)/2;
    ja=(1:nxp)';
    jb=floor((3+sqrt(8*ja-3))/2);               % [2 3 3 4 4 4 ... nx]'
    ja=ja-(jb.^2-3*jb+2)/2;                     % [1 1:2 1:3 ... 1:nx-1]'
    jc=ones(nxp,1);
    % sort out indexing for vectorized upper triangular matrix
    npu=p*(p+1)/2;                              % number of distinct elements in a symmetric (p,p) matrix
    kw=1:npu;
    ku=floor((1+sqrt(8*kw-3))/2);               % [1 2 2 3 3 3 ... p]
    kv=kw-(ku.^2-ku)/2;                         % [1 1:2 1:3 ... 1:p]
    zpp=zeros(p,p);
    zpp(kv+p*(ku-1))=kw;
    zpp(ku+p*(kv-1))=kw;
    kw=reshape(zpp,1,[]);                       % maps vectorized upper triangular to vectorized full matrix
    kp=repmat(1:p,1,p);                         % row indices for a (p,p) matrix as a row vector
    kq=reshape(repmat(1:p,p,1),1,p^2);          % col indices for a (p,p) matrix as a row vector
    kr=p*kp-p+kq;                               % transpose indexing for a vectorized (p,p) matrix
    kd=1:p+1:p^2;                               % diagonal indices of a (p,p) matrix
    % unity vectors to make efficient replication
    wk=ones(k,1);
    wnx=ones(nx,1);
    if full
        vg=mz.'*(mz.*w(:,ones(1,p)))+reshape(reshape(v,p^2,k)*w,p,p);
        % now determine the mode
        vi=zeros(p*k,p);                        % stack of k inverse cov matrices each size p*p times -0.5
        vim=zeros(p*k,1);                       % stack of k vectors of the form -0.5*inv(v)*m
        mtk=vim;                                % stack of k vectors of the form m
        lvm=zeros(k,1);
        wpk=repmat((1:p)',k,1);
        for i=1:k                               % loop for each mixture
            [uvk,dvk]=eig(v(:,:,i));            % find eigenvalues
            dvk=diag(dvk);
            if(any(dvk<=0))
                error('Covariance matrix for mixture %d is not positive definite',i);
            end
            vik=-0.5*uvk*diag(dvk.^(-1))*uvk';  % calculate inverse including -0.5 factor
            vi((i-1)*p+(1:p),:)=vik;            % vi contains all mixture inverses stacked on top of each other
            vim((i-1)*p+(1:p))=vik*m(i,:)';     % vim contains vi*m for all mixtures stacked on top of each other
            mtk((i-1)*p+(1:p))=m(i,:)';         % mtk contains all mixture means stacked on top of each other
            lvm(i)=log(w(i))-0.5*sum(log(dvk)); % lvm contains log(w(i)/sqrt(det(v(:,:,i)))) for each mixture
        end
        vif=reshape(permute(reshape(vi,p,k,p),[2 1 3]),k,p^2);  % each scaled inverse covariance matrix (-0.5*inv(v)) as a vectorized row
        vimf=reshape(vim,p,k)';                 % vi*m as a row for each mixture
        ss=sqrt(min(v(repmat(kd,k,1)+repmat(p^2*(0:k-1)',1,p)),[],1))*ssf/sqrt(p); % minimum separation of modes [this is a conservative guess]
    else
        vg=mz.'*(mz.*w(:,ones(1,p)))+diag(w.'*v);
        % now determine the mode
        vi=-0.5*v.^(-1);                        % vi(k,p) = data-independent scale factor in exponent
        vi2=vi(:,ku).*vi(:,kv);                 % vi2(k,npu) = upper triangular Hessian data dependent term
        lvm=log(w)-0.5*sum(log(v),2);           % log of external scale factor (excluding -0.5*p*log(2pi) term)
        vim=vi.*m;
        vim2=vim(:,ku).*vim(:,kv);              % vim2(k,npu) = upper triangular Hessian data independent term
        vimvi=vim(:,kp).*vi(:,kq);              % vimvi(k,p^2) = vectorized Hessian term
        ss=sqrt(min(v,[],1))*ssf/sqrt(p);       % minimum separation of modes [this is a conservative guess]
    end
    pgf=zeros(nx,p);                            % space for fixed point update
    sv=0.01*ss;                                 % convergence threshold
    pg=m;                                       % initialize mode candidates to mixture means
    i=1;                                        % loop counter
    %     gx=zeros(nx,p); %%%%%%%%%%%% temp
    while i<=maxloop
        %         pg00=pg0;  %%%%%%%%%%%% temp
        pg0=pg;                                 % save previous mode candidates pg(nx,p)
        if full
            py=reshape(sum(reshape((vi*pg'-vim(:,wnx)).*(pg(:,wpk)'-mtk(:,wnx)),p,nx*k),1),k,nx)+lvm(:,wnx);
            mx=max(py,[],1);                    % find normalizing factor for each data point to prevent underflow when using exp()
            px=exp(py-mx(wk,:));                % find normalized probability of each mixture for each datapoint
            ps=sum(px,1);                       % total normalized likelihood of each data point
            px=(px./ps(wk,:))';                 % px(nx,k) = relative mixture probabilities for each data point (rows sum to 1)
            % calculate the fixed point update
            pxvif=px*vif;                       % pxvif(nx,p^2)
            pxvimf=px*vimf;                     % pxvimf(nx,p)
            for j=1:nx
                pgf(j,:)=pxvimf(j,:)/reshape(pxvif(j,:),p,p);
            end
        else
            py=reshape(sum((pg(kk,:)-m(km,:)).^2.*vi(km,:),2),k,nx)+lvm(:,wnx);
            mx=max(py,[],1);                    % find normalizing factor for each data point to prevent underflow when using exp()
            px=exp(py-mx(wk,:));                % find normalized probability of each mixture for each datapoint
            ps=sum(px,1);                       % total normalized likelihood of each data point
            px=(px./ps(wk,:))';                 % px(nx,k) = relative mixture probabilities for each data point (rows sum to 1)
            % calculate the fixed point update
            pxvim=px*vim;
            pxvi=px*vi;
            pgf=pxvim./pxvi;                    % fixed point update for all points
        end
        if i>nfp
            % calculate gradient and Hessian; see [1] equations (4) and (5)
            lp=log(ps)+mx;                      % log prob of each data point
            if full
                %                 gx0=gx; %%%%%%%%%%%% temp
                gx=pxvimf-reshape(sum(repmat(pg,p,1).*reshape(pxvif,[],p),2),nx,p);
                vimpg=repmat(vimf,nx,1)-reshape(permute(reshape(pg*vi',nx,p,k),[3 1 2]),[],p); % vimpg(k*nx,p)
                hx1=2*reshape(sum(reshape(repmat(reshape(px',[],1),1,npu).*vimpg(:,ku).*vimpg(:,kv),k,[]),1),nx,[]);
                hx=pxvif+hx1(:,kw);
            else
                gx=pxvim-pxvi.*pg;              % gradient for each data point (one row per point)
                hx1=px*vim2+(px*vi2).*pg(:,ku).*pg(:,kv);
                hx2=(px*(vimvi)).*pg(:,kq);
                hx=2*(hx1(:,kw)-hx2-hx2(:,kr));
                hx(:,kd)=hx(:,kd)+pxvi;
            end
            hx=reshape(hx',p,p,nx);
            for j=1:nx
                if all(eig(hx(:,:,j))<0)        % if negative definite
                    pg(j,:)=pg(j,:)+gx(j,:)/hx(:,:,j);  % do a Newton-Raphson update
                    if full
                        pyj=sum(reshape((vi*pg(j,:)'-vim).*(pg(j,wpk)'-mtk),p,k),1)'+lvm;
                    else
                        pyj=sum((repmat(pg(j,:),k,1)-m).^2.*vi,2)+lvm;
                    end
                    mxj=max(pyj);               % find normalizing factor for each data point to prevent underflow when using exp()
                    pxj=exp(pyj-mxj);           % find normalized probability of each mixture for each datapoint
                    psj=sum(pxj,1);             % total normalized likelihood of each data point
                    lpj=log(psj)+mxj;           % log prob of updated data point
                    if lpj<lp(j)                % check if the probability has decreased
                        pg(j,:)=pgf(j,:);       % if so, do fixed point update
                    end
                else
                    pg(j,:)=pgf(j,:);           % else do fixed point update
                end
            end
        else
            pg=pgf;                             % fixed point update for all points
        end
        if all(all(abs(pg-pg0)<sv(wnx,:))) && i+2<maxloop
            maxloop=min(maxloop,i+2);           % if converged, do two more loops
        end
        %         [all(pg==pgf,2) pg [i; repmat(NaN,nx-1,1)] pg-pg0] %debug: [fixed-point x-mode iteration delta-x]
        jd=all(abs(pg(jb,:)-pg(ja,:))<ss(jc,:),2);  % find duplicate modes
        if any(jd)
            jx=sparse([(1:nx)';ja;jb],[(1:nx)';jb;ja],[wnx;jd;jd]);   % neighbour matrix
            kx=any((jx*jx)>0 & ~jx,2);          % find chains that are not fully connected
            while any(kx)
                kx=any(jx(:,kx),2);             % destroy all links connected to these chains
                jx(kx,:)=0;
                jx(:,kx)=0;
                %                 jx(kx,kx)=1;
                kx=any((jx*jx)>0 & ~jx,2);
            end
            jx([1:nx+1:nx*nx (ja+(jb-1)*nx)'])=0;  % reset the upper triangle + diagonal
            pg(any(jx,2),:)=[];                 % delete the duplicates
            % update nx and anything that depends on it
            nx=size(pg,1);
            pgf=zeros(nx,p);                    % space for fixed point update
            wnx=ones(nx,1);
            nxp=nx*(nx-1)/2;
            ja=ja(1:nxp);
            jb=jb(1:nxp);
            kk=reshape(repmat(1:nx,k,1),k*nx,1);    % [1 1 1 2 2 2 ... nx nx nx]'
            km=reshape(repmat(1:k,1,nx),k*nx,1);    % [1:k 1:k ... 1:k]'
            jc=ones(nxp,1);
        end
        i=i+1;
    end
    % calculate the log pdf at each mode
    if full
        py=reshape(sum(reshape((vi*pg'-vim(:,wnx)).*(pg(:,wpk)'-mtk(:,wnx)),p,nx*k),1),k,nx)+lvm(:,wnx);
        mx=max(py,[],1);                        % find normalizing factor for each data point to prevent underflow when using exp()
        px=exp(py-mx(wk,:));                    % find normalized probability of each mixture for each datapoint
        ps=sum(px,1);                           % total normalized likelihood of each data point
    else
        py=reshape(sum((pg(kk,:)-m(km,:)).^2.*vi(km,:),2),k,nx)+lvm(:,wnx);
        mx=max(py,[],1);                        % find normalizing factor for each data point to prevent underflow when using exp()
        px=exp(py-mx(wk,:));                    % find normalized probability of each mixture for each datapoint
        ps=sum(px,1);                           % total normalized likelihood of each data point
    end
    [pv,ix]=sort((log(ps)+mx)'-0.5*p*log(2*pi),'descend');
    pg=pg(ix,:);
    if n<numel(pv)                              % only keep the first n modes
        pg=pg(1:n,:);
        pv=pv(1:n);
    end
elseif nao>1
    if full
        vg=mz.'*(mz.*w(:,ones(1,p)))+reshape(reshape(v,p^2,k)*w,p,p);
    else
        vg=mz.'*(mz.*w(:,ones(1,p)))+diag(w.'*v);
    end
end

if ~nargout
    % now plot the result
    clf;
    pg1=pg(1,:);
    lpm=v_gaussmixp(mg,m,v,w);
    lpp=pv(1);
    switch p
        case 1
            v_gaussmixp([],m,v,w);
            hold on;
            ylim=get(gca,'ylim')';
            plot([mg mg]',ylim,'-k',[mg mg; mg mg]+[-1 1; -1 1]*sqrt(vg),ylim,':k');
            plot(pg(1),pv(1),'^k');
            if numel(pg)>1
                plot(pg(2:end),pv(2:end),'xk');
            end
            hold off;
            title(sprintf('Mean+-sd = %.3g+-%.3g LogP = %.3g, Mode\\Delta = %.3g LogP = %.3g',mg,sqrt(vg),lpm,pg1,lpp));
            xlabel('x');
        case 2
            v_gaussmixp([],m,v,w);
            hold on;
            t=linspace(0,2*pi,100);
            xysd=chol(vg)'*[cos(t); sin(t)]+repmat(mg',1,length(t));
            plot(xysd(1,:),xysd(2,:),':k',mg(1),mg(2),'ok');
            plot(pg(1,1),pg(1,2),'^k');
            if numel(pv)>1
                plot(pg(2:end,1),pg(2:end,2),'xk');
            end
            hold off;
            title(sprintf('Mean = (%.3g,%.3g) LogP = %.3g, Mode:\\Delta = (%.3g,%.3g) LogP = %.3g',mg,lpm,pg1,pv(1)));
            xlabel('x');
            ylabel('y');
        otherwise
            nx=200;
            nc=ceil(sqrt(p/2));
            nr=ceil(p/nc);
            sdx=sqrt(diag(vg))';                % std deviation
            minx=min([mg; pg],[],1)-1.5*sdx;
            maxx=max([mg; pg],[],1)+1.5*sdx;
            ix=2:p;                             % selected indices
            for i=1:p
                xi=linspace(minx(i),maxx(i),nx)';
                [mm,vm,wm]=v_gaussmixd(mg(ix),m,v,w,[],ix);
                ym=v_gaussmixp(xi,mm,vm,wm')+lpm-v_gaussmixp(mg(i),mm,vm,wm');
                [mp,vp,wp]=v_gaussmixd(pg(1,ix),m,v,w,[],ix);
                yp=v_gaussmixp(xi,mp,vp,wp')+lpp-v_gaussmixp(pg1(i),mp,vp,wp');
                subplot(nr,nc,i);
                plot(xi,ym,'-k',mg(i),lpm,'ok',xi,yp,':b',pg1(i),lpp,'^b');
                v_axisenlarge([-1 -1 -1 -1.05]);
                hold on
                plot([mg(i) mg(i)],get(gca,'ylim'),'-k');
                hold off
                xlabel(sprintf('x[%d], Mean = %.3g, Mode\\Delta = %.3g',i,mg(i),pg1(i)));
                if i==1
                    title(sprintf('Log Prob: Mean = %.3g, Mode\\Delta = %.3g',lpm,lpp));
                end
                ix(i)=i;
            end
    end
end
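The mode search in the listing starts one candidate at each mixture mean (pg=m). For diagonal covariances it applies the fixed-point update pgf=(px*vim)./(px*vi), where px holds the posterior mixture probabilities and the -0.5 factors in vi and vim cancel. In symbols, x <- (sum_k p(k|x)/v_k)^(-1) * sum_k p(k|x)*m_k/v_k.

The sketch below keeps only this fixed-point step, for a hypothetical 1-D, three-mixture example whose means, variances and weights were chosen here for illustration. It has three simplifications relative to the routine:
- It omits the Newton-Raphson refinement.
- It omits the ss-based duplicate merging and instead merges converged candidates by simple rounding.
- It adds a dense grid search as an independent check.

It uses bsxfun so that it should run in both MATLAB and Octave.

```matlab
% Hypothetical 1-D GMM (k=3, p=1); values are illustrative, not from VOICEBOX
m = [0; 3; 7];                 % mixture means
v = [1; 0.5; 1];               % mixture variances
w = [0.3; 0.3; 0.4];           % mixture weights (sum to 1)
% log(w_k * N(x; m_k, v_k)) for a column vector x of query points -> (numel(x),k)
logc = @(x) bsxfun(@plus, log(w.') - 0.5*log(2*pi*v.'), ...
                   -0.5*bsxfun(@rdivide, bsxfun(@minus, x, m.').^2, v.'));
x = m;                         % one candidate per mixture mean, as in the routine
for it = 1:200
    lc = logc(x);                                   % (nx,k) log joint terms
    px = exp(bsxfun(@minus, lc, max(lc,[],2)));     % subtract row max to avoid underflow
    px = bsxfun(@rdivide, px, sum(px,2));           % posteriors p(k|x); rows sum to 1
    x = (px*(m./v)) ./ (px*(1./v));                 % fixed-point update
end
modes = unique(round(x*1e6)/1e6);                   % crude merge of coincident candidates
% independent check: local maxima of the log pdf on a dense grid
xg = linspace(-4, 11, 150001).';
lp = log(sum(exp(logc(xg)), 2));
im = find(lp(2:end-1) > lp(1:end-2) & lp(2:end-1) > lp(3:end)) + 1;
gridmodes = xg(im);
disp([modes gridmodes]);
```

With these well-separated components, the three candidates converge to distinct modes close to 0, 3 and 7. These agree with the grid maxima to within the grid spacing. The fixed-point iteration converges only linearly and can be slow when components nearly merge, whereas Newton-Raphson converges quadratically near a mode. This is why the listing switches to Newton-Raphson steps once the Hessian at a candidate is negative definite.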