V_GAUSSMIXB approximate Bhattacharyya divergence between two GMMs

Usage: (1) d=v_gaussmixb(mf,vf,wf,mg,vg,wg);          % Estimate Bhattacharyya divergence between {mf,vf,wf} and {mg,vg,wg}
                                                      % vf and vg can independently be full or diagonal covariances

       (2) [d,dbfg]=v_gaussmixb(mf,vf,wf,mg,vg,wg);   % Also calculate exact Bhattacharyya divergence between components of f and components of g

       (3) d=v_gaussmixb(mf,vf,wf,mg,vg,wg,0);        % Calculate upper bound to Bhattacharyya divergence

       (4) [d,dbfg]=v_gaussmixb(mf,vf,wf);            % Calculate Bhattacharyya divergence between components of f. d=0 always in this case.

       (5) v_gaussmixb(mf,vf,wf,mg,vg,wg);            % Plot graphs of distributions (dimension p must equal 1)

Inputs: with kf & kg mixtures, p data dimensions

    mf(kf,p)                  mixture means for GMM f
    vf(kf,p) or vf(p,p,kf)    variances (diagonal or full) for GMM f [default: identity]
    wf(kf,1)                  weights for GMM f - must sum to 1 [default: uniform]
    mg(kg,p)                  mixture means for GMM g [g=f if mg,vg,wg omitted]
    vg(kg,p) or vg(p,p,kg)    variances (diagonal or full) for GMM g [default: identity]
    wg(kg,1)                  weights for GMM g - must sum to 1 [default: uniform]
    nx                        number of samples to use in importance sampling [default: 1000]
                              Set nx=0 to save computation by returning only an upper bound to the Bhattacharyya divergence.

Outputs:
    d                         the approximate Bhattacharyya divergence D_B(f,g)=-log(Int(sqrt(f(x)g(x)) dx)).
                              If nx=0 this will be an upper bound (typically 0.3 to 0.7 too high) rather than an estimate.
    dbfg(kf,kg)               the exact Bhattacharyya divergence between the unweighted components of f and g

The Bhattacharyya divergence, D_B(f,g), between two distributions, f(x) and g(x), is -log(Int(sqrt(f(x)g(x)) dx)).
It is a special case of the Chernoff Bound [2]. The Bhattacharyya divergence [1] satisfies:
    (1) D_B(f,g) >= 0
    (2) D_B(f,g) = 0 iff f = g
    (3) D_B(f,g) = D_B(g,f)
It is not a distance because it does not satisfy the triangle inequality.
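Consider two single Gaussian components, f_a = N(mu_a, Sigma_a) and g_b = N(mu_b, Sigma_b), in p dimensions. For these, the integral has the standard closed form below, and this is the quantity reported in dbfg(a,b). The second line is the same expression rearranged into the form the source listing evaluates term by term.

```latex
\begin{aligned}
D_B(f_a,g_b) &= \tfrac{1}{8}(\mu_a-\mu_b)^{\top}\bar{\Sigma}^{-1}(\mu_a-\mu_b)
  + \tfrac{1}{2}\ln\frac{\det\bar{\Sigma}}{\sqrt{\det\Sigma_a\,\det\Sigma_b}},
  \qquad \bar{\Sigma}=\tfrac{1}{2}(\Sigma_a+\Sigma_b)\\
&= \tfrac{1}{4}(\mu_a-\mu_b)^{\top}(\Sigma_a+\Sigma_b)^{-1}(\mu_a-\mu_b)
  + \tfrac{1}{2}\ln\det(\Sigma_a+\Sigma_b)
  - \tfrac{1}{4}\ln\det\Sigma_a-\tfrac{1}{4}\ln\det\Sigma_b-\tfrac{p}{2}\ln 2
\end{aligned}
```

If mu_a = mu_b and Sigma_a = Sigma_b, both terms vanish and D_B = 0, consistent with property (2). The last term is the constant hpl2 = 0.5*p*log(2) used in the code. Each 0.25*log(det(V)) is obtained as 0.5*log(prod(diag(chol(V)))), because det(V) is the square of the product of the diagonal of its Cholesky factor.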
The Bhattacharyya divergence is a lower bound to the Bayes divergence -log(Int(min(f(x),g(x)) dx)), because
min(f(x),g(x)) <= sqrt(f(x)g(x)) for every x. The Bayes divergence relates to the probability of 2-class
misclassification [1]. With equal class priors, the error probability is 0.5*Int(min(f(x),g(x)) dx), which is
therefore at most 0.5*exp(-D_B(f,g)).

This routine calculates the "variational importance sampling" estimate of (or, if nx=0, the "variational II"
upper bound to) the Bhattacharyya divergence from [3]. It is exact when f and g are single-component Gaussians
and is zero if f=g.

Refs:
    [1] T. Kailath. The divergence and Bhattacharyya distance measures in signal selection.
        IEEE Trans Communication Technology, 15 (1): 52–60, Feb. 1967.
    [2] H. Chernoff. A measure of asymptotic efficiency for tests of a hypothesis based on the
        sum of observations. The Annals of Mathematical Statistics, 23 (4): 493–507, Dec. 1952.
    [3] P. A. Olsen and J. R. Hershey. Bhattacharyya error and divergence using variational
        importance sampling. In Proc. Interspeech Conf., pages 46–49, 2007.

Copyright (C) Mike Brookes 2024
Version: $Id: v_gaussmixk.m 10865 2018-09-21 17:22:45Z dmb $

VOICEBOX is a MATLAB toolbox for speech processing.
Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You can obtain a copy of the GNU General Public License from
http://www.gnu.org/copyleft/gpl.html or by writing to
Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
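The definition and the Bayes-error relation in the help text above can be checked numerically. The following is a minimal standalone sketch: it calls no VOICEBOX routine, and the means and variances are arbitrary illustrative values. For two 1-D Gaussians it compares the closed-form divergence with trapezoidal integration of sqrt(f(x)g(x)). It then evaluates the equal-prior Bayes error 0.5*Int(min(f(x),g(x)) dx) against its Bhattacharyya bound 0.5*exp(-D_B).

```matlab
% Standalone check (no VOICEBOX calls): Bhattacharyya divergence of two 1-D Gaussians
% in closed form versus numerical integration, plus the equal-prior Bayes-error bound.
m1=0;   v1=1;                                  % mean and variance of f (illustrative values)
m2=1.5; v2=2;                                  % mean and variance of g (illustrative values)
% closed form: 0.25*dm^2/(v1+v2) + 0.5*log(v1+v2) - 0.25*log(v1*v2) - 0.5*log(2)
dclosed=0.25*(m1-m2)^2/(v1+v2)+0.5*log(v1+v2)-0.25*log(v1*v2)-0.5*log(2);
x=linspace(-20,20,40001);                      % integration grid (spacing 0.001)
f=exp(-0.5*(x-m1).^2/v1)/sqrt(2*pi*v1);        % density of f
g=exp(-0.5*(x-m2).^2/v2)/sqrt(2*pi*v2);        % density of g
dnum=-log(trapz(x,sqrt(f.*g)));                % D_B(f,g) by the trapezoidal rule
pbayes=0.5*trapz(x,min(f,g));                  % 2-class Bayes error with equal priors
fprintf('D_B closed form %.6f, numerical %.6f\n',dclosed,dnum);
fprintf('Bayes error %.4f <= 0.5*exp(-D_B) = %.4f\n',pbayes,0.5*exp(-dclosed));
```

The integrands decay quickly, so a range of +/-20 is ample here. More widely spread densities would need a larger range.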
function [d,dbfg]=v_gaussmixb(mf,vf,wf,mg,vg,wg,nx)
%V_GAUSSMIXB approximate Bhattacharyya divergence between two GMMs
%
% Usage: (1) d=v_gaussmixb(mf,vf,wf,mg,vg,wg); % Estimate Bhattacharyya divergence between {mf,vf,wf} and {mg,vg,wg}
%                                              % vf and vg can independently be full or diagonal covariances
%
%        (2) [d,dbfg]=v_gaussmixb(mf,vf,wf,mg,vg,wg); % Also calculate exact Bhattacharyya divergence between components of f and components of g
%
%        (3) d=v_gaussmixb(mf,vf,wf,mg,vg,wg,0); % Calculate upper bound to Bhattacharyya divergence
%
%        (4) [d,dbfg]=v_gaussmixb(mf,vf,wf); % Calculate Bhattacharyya divergence between components of f. d=0 always in this case.
%
%        (5) v_gaussmixb(mf,vf,wf,mg,vg,wg); % Plot graphs of distributions (dimension p must equal 1)
%
%  Inputs: with kf & kg mixtures, p data dimensions
%
%        mf(kf,p)                mixture means for GMM f
%        vf(kf,p) or vf(p,p,kf)  variances (diagonal or full) for GMM f [default: identity]
%        wf(kf,1)                weights for GMM f - must sum to 1 [default: uniform]
%        mg(kg,p)                mixture means for GMM g [g=f if mg,vg,wg omitted]
%        vg(kg,p) or vg(p,p,kg)  variances (diagonal or full) for GMM g [default: identity]
%        wg(kg,1)                weights for GMM g - must sum to 1 [default: uniform]
%        nx                      number of samples to use in importance sampling [default: 1000]
%                                Set nx=0 to save computation by returning only an upper bound to the Bhattacharyya divergence.
%
% Outputs:
%        d                       the approximate Bhattacharyya divergence D_B(f,g)=-log(Int(sqrt(f(x)g(x)) dx)).
%                                If nx=0 this will be an upper bound (typically 0.3 to 0.7 too high) rather than an estimate.
%        dbfg(kf,kg)             the exact Bhattacharyya divergence between the unweighted components of f and g
%
% The Bhattacharyya divergence, D_B(f,g), between two distributions, f(x) and g(x), is -log(Int(sqrt(f(x)g(x)) dx)).
% It is a special case of the Chernoff Bound [2]. The Bhattacharyya divergence [1] satisfies:
%    (1) D_B(f,g) >= 0
%    (2) D_B(f,g) = 0 iff f = g
%    (3) D_B(f,g) = D_B(g,f)
% It is not a distance because it does not satisfy the triangle inequality. It is a lower bound to the Bayes
% divergence -log(Int(min(f(x),g(x)) dx)), which relates to the probability of 2-class misclassification [1].
%
% This routine calculates the "variational importance sampling" estimate of (or, if nx=0,
% the "variational II" upper bound to) the Bhattacharyya divergence from [3]. It is exact
% when f and g are single-component Gaussians and is zero if f=g.
%
% Refs:
%   [1] T. Kailath. The divergence and Bhattacharyya distance measures in signal selection.
%       IEEE Trans Communication Technology, 15 (1): 52–60, Feb. 1967.
%   [2] H. Chernoff. A measure of asymptotic efficiency for tests of a hypothesis based on the
%       sum of observations. The Annals of Mathematical Statistics, 23 (4): 493–507, Dec. 1952.
%   [3] P. A. Olsen and J. R. Hershey. Bhattacharyya error and divergence using variational
%       importance sampling. In Proc. Interspeech Conf., pages 46–49, 2007.
%
%      Copyright (C) Mike Brookes 2024
%      Version: $Id: v_gaussmixk.m 10865 2018-09-21 17:22:45Z dmb $
%
%   VOICEBOX is a MATLAB toolbox for speech processing.
%   Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%   This program is free software; you can redistribute it and/or modify
%   it under the terms of the GNU General Public License as published by
%   the Free Software Foundation; either version 2 of the License, or
%   (at your option) any later version.
%
%   This program is distributed in the hope that it will be useful,
%   but WITHOUT ANY WARRANTY; without even the implied warranty of
%   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%   GNU General Public License for more details.
%
%   You can obtain a copy of the GNU General Public License from
%   http://www.gnu.org/copyleft/gpl.html or by writing to
%   Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
maxiter=15;                 % maximum iterations for upper bound calculation
pruneth=0.2;                % prune threshold for importance sampling (prob that any excluded mixture would have been chosen)
[kf,p]=size(mf);
if nargin<2 || isempty(vf)  % if vf input is missing
    vf=ones(kf,p);          % ... default to diagonal covariances
end
if nargin<3 || isempty(wf)  % if wf input is missing
    wf=repmat(1/kf,kf,1);   % ... default to uniform weights
end
if p==1                     % if data dimension is 1
    vf=vf(:);               % ... then variances are always diagonal
    dvf=true;
else
    dvf=ismatrix(vf) && size(vf,1)==kf; % diagonal covariance matrix vf is supplied
end
if nargin<7
    nx=1000;                % default count for importance sampling
end
hpl2=0.5*p*log(2);          % constant correction term in the pairwise divergences
if nargin<4 || isempty(mg)  % only f GMM specified: use this for g GMM as well
    dbfg=zeros(kf,kf);      % space for pairwise divergences
    if kf>1
        nup=kf*(kf-1)/2;                    % number of elements in upper triangle
        gix=ceil((1+sqrt(8*(1:nup)-3))/2);  % column of upper triangle
        fix=(1:nup)-(gix-1).*(gix-2)/2;     % row of upper triangle
        if dvf              % diagonal covariances
            mdif=mf(fix,:)-mf(gix,:);       % difference in means
            vfpg=(vf(fix,:)+vf(gix,:));     % sum of variances
            qldf=0.25*log(prod(vf,2));      % 0.25 * log determinants of f components
            dbfg(fix+kf*(gix-1))=0.25*sum((mdif./vfpg).*mdif,2)+0.5*log(prod(vfpg,2))-qldf(fix)-qldf(gix)-hpl2; % fill in upper triangle
        else                % full covariance matrices
            qldf=zeros(kf,1);
            for jf=1:kf     % precalculate the log determinants for f covariances
                qldf(jf)=0.5*log(prod(diag(chol(vf(:,:,jf))))); % equivalent to 0.25*log(det(vf(:,:,jf)))
            end
            for jf=1:kf-1
                vjf=vf(:,:,jf);                 % covariance matrix for f
                for jg=jf+1:kf
                    vfg=vjf+vf(:,:,jg);         % sum of covariance matrices
                    mdif=mf(jf,:)-mf(jg,:);     % difference in means
                    dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldf(jg)-qldf(jf)-hpl2; % fill in upper triangle
                end
            end
        end
        dbfg(gix+kf*(fix-1))=dbfg(fix+kf*(gix-1)); % now reflect upper triangle divergences into the symmetric lower triangle
    end
    d=0;                                % divergence is always zero if f and g are identical
else                                    % both f and g GMMs are specified as inputs
    kg=size(mg,1);
    if nargin<5 || isempty(vg)
        vg=ones(kg,p);                  % default to diagonal covariances
    end
    if nargin<6 || isempty(wg)
        wg=repmat(1/kg,kg,1);           % default to uniform weights
    end
    if p==1                             % if data dimension is 1
        vg=vg(:);                       % ... then variances are always diagonal
        dvg=true;
    else
        dvg=ismatrix(vg) && size(vg,1)==kg; % diagonal covariance matrix vg is supplied
    end
    % first calculate pairwise Bhattacharyya divergences between the components of f and g
    dbfg=zeros(kf,kg);                  % space for pairwise divergences (overwritten below if f and g are both diagonal)
    dix=1:p+1:p^2;                      % index of diagonal elements in covariance matrix
    if dvf
        if dvg                          % both f and g have diagonal covariances
            fix=repmat((1:kf)',kg,1);               % index into f values
            gix=reshape(repmat(1:kg,kf,1),kf*kg,1); % index into g values
            mdif=mf(fix,:)-mg(gix,:);               % difference in means
            vfpg=(vf(fix,:)+vg(gix,:));             % sum of variances
            qldf=0.25*log(prod(vf,2));              % 0.25 * log determinants of f components
            qldg=0.25*log(prod(vg,2));              % 0.25 * log determinants of g components
            dbfg=reshape(0.25*sum((mdif./vfpg).*mdif,2)+0.5*log(prod(vfpg,2))-qldf(fix)-qldg(gix),kf,kg);
        else                            % diagonal f covariance but not g
            qldf=0.25*log(prod(vf,2));              % precalculate the log determinants for f covariances
            for jg=1:kg                             % loop through g components
                vjg=vg(:,:,jg);                     % full covariance matrix for g
                qldg=0.5*log(prod(diag(chol(vjg)))); % equivalent to 0.25*log(det(vjg))
                for jf=1:kf                         % loop through f components
                    vfg=vjg;                        % take full g covariance matrix
                    vfg(dix)=vfg(dix)+vf(jf,:);     % ... and add diagonal f covariance
                    mdif=mf(jf,:)-mg(jg,:);         % difference in means
                    dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldf(jf)-qldg;
                end
            end
        end
    else
        if dvg                          % diagonal g covariance but not f
            qldg=0.25*log(prod(vg,2));              % precalculate the log determinants for g covariances
            for jf=1:kf                             % loop through f components
                vjf=vf(:,:,jf);                     % full covariance matrix for f
                qldf=0.5*log(prod(diag(chol(vjf)))); % equivalent to 0.25*log(det(vjf))
                for jg=1:kg                         % loop through g components
                    vfg=vjf;                        % take full f covariance matrix
                    vfg(dix)=vfg(dix)+vg(jg,:);     % ... and add diagonal g covariance
                    mdif=mf(jf,:)-mg(jg,:);         % difference in means
                    dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldg(jg)-qldf;
                end
            end
        else                            % both f and g have full covariance matrices
            qldg=zeros(kg,1);
            for jg=1:kg                             % precalculate the log determinants for g covariances
                qldg(jg)=0.5*log(prod(diag(chol(vg(:,:,jg))))); % equivalent to 0.25*log(det(vg(:,:,jg)))
            end
            for jf=1:kf                             % loop through f components
                vjf=vf(:,:,jf);                     % covariance matrix for f
                qldf=0.5*log(prod(diag(chol(vjf)))); % equivalent to 0.25*log(det(vjf))
                for jg=1:kg                         % loop through g components
                    vfg=vjf+vg(:,:,jg);             % calculate sum of covariance matrices
                    mdif=mf(jf,:)-mg(jg,:);         % difference in means
                    dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldg(jg)-qldf;
                end
            end
        end
    end
    dbfg=dbfg-hpl2;                     % subtract the 0.5*p*log(2) correction term from all the pairwise divergences
    %
    % Now calculate the variational bound
    % Note that in [3], the psi and phi symbols are interchanged in (20) and also in the previous
    % line; in addition, the subscript of phi is incorrect in the denominator of (26).
    %
    lwf=repmat(log(wf),1,kg);           % log of f component weights
    lwg=repmat(log(wg'),kf,1);          % log of g component weights
    lhf=repmat(log(1/kf),kf,kg);        % initialize psi_f|g from [3] (cols of exp(lhf) sum to 1)
    lhg=repmat(log(1/kg),kf,kg);        % initialize phi_g|f from [3] (rows of exp(lhg) sum to 1)
    dbfg2=2*dbfg;                       % -log of the squared Bhattacharyya coefficients between components
    dbfg2f=lwf-dbfg2;                   % iteration-independent term used to update lhf (contains the f weights)
    dbfg2g=lwg-dbfg2;                   % iteration-independent term used to update lhg (contains the g weights)
    dbfg2fg=dbfg2(:)-lwf(:)-lwg(:);     % iteration-independent term to calculate the divergence upper bound
    dub=Inf;                            % dummy upper bound for first iteration
    for ip=1:maxiter                    % maximum number of iterations
        dubp=dub;                       % save previous iteration's upper bound
        dub=-v_logsum(0.5*(lhf(:)+lhg(:)-dbfg2fg)); % update the upper bound on Bhattacharyya divergence
        if dub>=dubp                    % quit if no longer decreasing
            break;
        end
        lhg=lhf+dbfg2g;                 % update phi_g|f: proportional to psi_f|g * wg * exp(-dbfg2), cf. [3]-(25)
        lhg=lhg-repmat(v_logsum(lhg,2),1,kg); % normalize phi_g|f as in [3]-(25) (rows of exp(lhg) sum to 1)
        dub=-v_logsum(0.5*(lhf(:)+lhg(:)-dbfg2fg)); % update the upper bound on Bhattacharyya divergence
        lhf=lhg+dbfg2f;                 % update psi_f|g: proportional to phi_g|f * wf * exp(-dbfg2), cf. [3]-(26)
        lhf=lhf-repmat(v_logsum(lhf,1),kf,1); % normalize psi_f|g as in [3]-(26) (cols of exp(lhf) sum to 1)
    end
    if isempty(nx) || nx==0             % only calculate the upper divergence bound
        d=dub;
    else
        [lnwt,jlnwt]=sort(0.5*(lhf(:)+lhg(:)-dbfg2fg)+dub,'descend'); % normalized component log weights (highest first)
        wt=exp(lnwt);
        cwt=cumsum(wt);
        nmix=1+sum(cwt<1-pruneth/nx);   % number of mixtures for <20% chance that any excluded ones would be picked
        [fix,gix]=ind2sub([kf kg],jlnwt(1:nmix)); % mixture indices that are needed
        %
        % now create the sampling GMM distribution
        %
        ws=wt(1:nmix)/cwt(nmix);        % sampling GMM weight vector
        ms=zeros(nmix,p);               % space for sampling GMM means
        vs=zeros(p,p,nmix);             % space for sampling GMM full covariances
        if dvf
            if dvg                      % both f and g have diagonal covariances
                vff=vf(fix,:);
                vgg=vg(gix,:);
                vsumi=1./(vff+vgg);
                vs=2*vff.*vgg.*vsumi;                           % mixture covariance matrix
                ms=vff.*vsumi.*mg(gix,:)+vgg.*vsumi.*mf(fix,:); % mixture means
            else                        % diagonal f covariance but not g
                for jfg=1:nmix
                    vgg=vg(:,:,gix(jfg));
                    vff=vf(fix(jfg),:);
                    vsum=vgg;
                    vsum(dix)=vsum(dix)+vff;                    % add diagonal components
                    vs(:,:,jfg)=2*vgg/vsum.*repmat(vff,p,1);    % mixture covariance matrix
                    ms(jfg,:)=mg(gix(jfg),:)/vsum.*vff+mf(fix(jfg),:)/vsum*vgg; % mixture means
                end
            end
        else
            if dvg                      % diagonal g covariance but not f
                for jfg=1:nmix
                    vff=vf(:,:,fix(jfg));
                    vgg=vg(gix(jfg),:);
                    vsum=vff;
                    vsum(dix)=vsum(dix)+vgg;                    % add diagonal components
                    vs(:,:,jfg)=2*vff/vsum.*repmat(vgg,p,1);    % mixture covariance matrix
                    ms(jfg,:)=mf(fix(jfg),:)/vsum.*vgg+mg(gix(jfg),:)/vsum*vff; % mixture means
                end
            else                        % both f and g have full covariance matrices
                for jfg=1:nmix
                    vff=vf(:,:,fix(jfg));
                    vgg=vg(:,:,gix(jfg));
                    vsum=vff+vgg;
                    vs(:,:,jfg)=2*vff/vsum*vgg;                 % mixture covariance matrix
                    ms(jfg,:)=mf(fix(jfg),:)/vsum*vgg+mg(gix(jfg),:)/vsum*vff; % mixture means
                end
            end
        end
        x=v_randvec(nx,ms,vs,ws);       % draw from sampling distribution
        d=-(v_logsum(0.5*(v_gaussmixp(x,mf,vf,wf)+v_gaussmixp(x,mg,vg,wg))-v_gaussmixp(x,ms,vs,ws)))+log(nx); % Monte Carlo estimate of Bhattacharyya divergence
    end
end
if ~nargout
    switch p
        case 1
            nsd=3;                      % number of std deviations to plot
            nxax=251;                   % number of points on x-axis (MUST be odd for Simpson's rule)
            xlo=min([mf;mg]-nsd*sqrt([vf;vg]));
            xhi=max([mf;mg]+nsd*sqrt([vf;vg]));
            xax=linspace(xlo,xhi,nxax)';
            sint=(xax(2)-xax(1))/3*(4-2*mod(1:nxax,2)-[1 zeros(1,nxax-2) 1]); % Simpson's rule integration weights
            yf=exp(v_gaussmixp(xax,mf,vf,wf));
            yg=exp(v_gaussmixp(xax,mg,vg,wg));
            bayeserr=sint*min(yf,yg)*0.5; % calculate Bayes error
            plot(xax,yf,'-b',xax,yg,'-r',xax,sqrt(yf.*yg),'-g');
            if ~isempty(nx) && nx~=0
                ys=exp(v_gaussmixp(xax,ms,vs,ws));
                hold on
                plot(xax,exp(-d)*ys,'--k');
                hold off
                legend('f(x)','g(x)','\surd(fg)','\approx \surd(fg)','location','northeast');
                v_texthvc(0.02,0.98,sprintf('Bhattacharyya = %.1f%% (>=%.1f%%)\nBayes Err = 0.5 x %.1f%%',100*exp(-d),100*exp(-dub),200*bayeserr),'LTk');
            else
                legend('f(x)','g(x)','\surd(fg)','location','northeast');
            end
            xlabel('x');
            ylabel('Prob density');
    end
end
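The variational bound and the importance-sampling steps in the listing above can be sketched for scalar (p=1) mixtures without any VOICEBOX dependency. This is a sketch under stated assumptions, not a transcription of [3]. The example mixtures are hypothetical, and rng is assumed to be available (MATLAB, or a recent Octave).

The bound is derived as follows. Write f = Sum_ab phi_{b|a} wf(a) f_a and g = Sum_ab psi_{a|b} wg(b) g_b. Applying the Cauchy-Schwarz inequality to sqrt(f g) then gives Int(sqrt(fg)) >= Sum_ab sqrt(wf(a) wg(b) phi_{b|a} psi_{a|b}) exp(-dbfg(a,b)). The sketch maximizes this lower bound one factor at a time, using two updates:

- phi_{b|a} proportional to psi_{a|b}*wg(b)*exp(-2*dbfg(a,b)).
- psi_{a|b} proportional to phi_{b|a}*wf(a)*exp(-2*dbfg(a,b)).

It then samples from the mixture of normalized sqrt(f_a g_b) components to form the Monte Carlo estimate.

```matlab
% Sketch: variational upper bound and importance-sampling estimate of D_B for two
% 1-D GMMs with scalar variances. Standalone: no VOICEBOX functions are used.
mf=[-1; 1];  vf=[0.5; 1.0]; wf=[0.3; 0.7];     % GMM f (hypothetical example values)
mg=[0; 2.5]; vg=[1.0; 0.3]; wg=[0.6; 0.4];     % GMM g (hypothetical example values)
kf=numel(mf); kg=numel(mg);
lse=@(a,dim) max(a,[],dim)+log(sum(exp(a-max(a,[],dim)),dim));   % log-sum-exp along dim
lgm=@(x,m,v,w) lse(log(w')-0.5*log(2*pi*v')-0.5*(x-m').^2./v',2); % log 1-D GMM density at column x

% closed-form divergences between the unweighted components (kf x kg)
VS=vf+vg';                                     % sums of variances
D=0.25*(mf-mg').^2./VS+0.5*log(VS)-0.25*log(vf)-0.25*log(vg')-0.5*log(2);
lw=log(wf)+log(wg')-2*D;                       % log(wf(a)*wg(b)*exp(-2*D(a,b)))

% coordinate ascent on the variational parameters
lpsi=repmat(-log(kf),kf,kg);                   % log psi_{a|b}: each column sums to 1
lphi=repmat(-log(kg),kf,kg);                   % log phi_{b|a}: each row sums to 1
bnd=@(lpsi,lphi) -lse(reshape(0.5*(lpsi+lphi+lw),[],1),1); % upper bound on D_B
dub=bnd(lpsi,lphi);
for it=1:50
    lphi=lpsi+log(wg')-2*D; lphi=lphi-lse(lphi,2); % phi_{b|a} ~ psi_{a|b}*wg(b)*exp(-2D)
    lpsi=lphi+log(wf)-2*D;  lpsi=lpsi-lse(lpsi,1); % psi_{a|b} ~ phi_{b|a}*wf(a)*exp(-2D)
    dnew=bnd(lpsi,lphi);
    done=dnew>dub-1e-10;                       % stop once the bound stops decreasing
    dub=min(dub,dnew);
    if done, break; end
end

% importance sampling from the mixture of normalized sqrt(f_a*g_b) components
lws=0.5*(lpsi+lphi+lw); lws=lws-lse(lws(:),1); % log mixture weights (sum to 1)
ms=(vg'.*mf+vf.*mg')./VS;                      % component means
vs=2*vf.*vg'./VS;                              % component variances
nx=20000; rng(1);                              % sample count; fixed seed for repeatability
cw=cumsum(exp(lws(:)));
comp=sum(rand(nx,1)>cw'/cw(end),2)+1;          % component index for each sample
x=ms(comp)+sqrt(vs(comp)).*randn(nx,1);        % draws from the sampling mixture
lsmp=lgm(x,ms(:),vs(:),exp(lws(:)));           % log sampling density at the draws
dis=-(lse(0.5*(lgm(x,mf,vf,wf)+lgm(x,mg,vg,wg))-lsmp,1)-log(nx));

% reference value by numerical integration on a fine grid
xx=linspace(-12,12,24001)';
dnum=-log(trapz(xx,exp(0.5*(lgm(xx,mf,vf,wf)+lgm(xx,mg,vg,wg)))));
fprintf('bound %.4f  sampling %.4f  numerical %.4f\n',dub,dis,dnum);
```

The printed bound should be at or above the numerically integrated value, and the sampling estimate should be close to it. Because each sqrt(f_a g_b) component is itself a scaled Gaussian, the importance weights stay bounded, which keeps the estimate's variance small.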