function [d,dbfg]=v_gaussmixb(mf,vf,wf,mg,vg,wg,nx)
%V_GAUSSMIXB approximate Bhattacharyya divergence between two GMMs
%
% Inputs: with kf & kg mixtures, p data dimensions
%
%   mf(kf,p)                 mixture means for GMM f
%   vf(kf,p) or vf(p,p,kf)   variances (diagonal or full) for GMM f [default: identity]
%   wf(kf,1)                 weights for GMM f - must sum to 1 [default: uniform]
%   mg(kg,p)                 mixture means for GMM g [g=f if mg,vg,wg omitted]
%   vg(kg,p) or vg(p,p,kg)   variances (diagonal or full) for GMM g [default: identity]
%   wg(kg,1)                 weights for GMM g - must sum to 1 [default: uniform]
%   nx                       number of samples to use in importance sampling [default: 1000]
%                            Set nx=0 to save computation by returning only an upper bound to the Bhattacharyya divergence.
%
% Outputs:
%   d             the approximate Bhattacharyya divergence D_B(f,g)=-log(Int(sqrt(f(x)g(x)) dx)).
%                 If nx=0 this will be an upper bound (typically 0.3 to 0.7 too high) rather than an estimate.
%   dbfg(kf,kg)   the exact Bhattacharyya divergence between the unweighted components of f and g
%
% The Bhattacharyya divergence, D_B(f,g), between two distributions, f(x) and g(x), is -log(Int(sqrt(f(x)g(x)) dx)).
% It is a special case of the Chernoff Bound [2]. The Bhattacharyya divergence [1] satisfies:
%     (1) D_B(f,g) >= 0
%     (2) D_B(f,g) = 0 iff f = g
%     (3) D_B(f,g) = D_B(g,f)
% It is not a distance because it does not satisfy the triangle inequality. It closely approximates the Bayes
% divergence -log(Int(min(f(x),g(x)) dx)) which relates to the probability of 2-class misclassification [1].
%
% This routine calculates the "variational importance sampling" estimate of (or, if nx=0,
% the "variational II" upper bound to) the Bhattacharyya divergence from [3]. It is exact
% when f and g are single-component Gaussians and is zero if f=g.
%
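% Example (an illustrative sketch of typical usage; the GMM parameters below are arbitrary values chosen for the example):
%
%     mf=[0 0; 2 1]; vf=[1 1; 0.5 2]; wf=[0.6; 0.4];     % GMM f: two diagonal-covariance components in 2 dimensions
%     mg=[0.5 0.5; 2 2]; vg=[1 2; 1 1]; wg=[0.5; 0.5];   % GMM g: two diagonal-covariance components
%     [d,dbfg]=v_gaussmixb(mf,vf,wf,mg,vg,wg);           % importance-sampling estimate using the default nx=1000 samples
%     dub=v_gaussmixb(mf,vf,wf,mg,vg,wg,0);              % nx=0 returns the cheaper variational upper bound instead
%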
% Refs:
%    [1] T. Kailath. The divergence and Bhattacharyya distance measures in signal selection.
%        IEEE Trans. Communication Technology, 15(1):52–60, Feb. 1967.
%    [2] H. Chernoff. A measure of asymptotic efficiency for tests of a hypothesis based on the
%        sum of observations. The Annals of Mathematical Statistics, 23(4):493–507, Dec. 1952.
%    [3] P. A. Olsen and J. R. Hershey. Bhattacharyya error and divergence using variational
%        importance sampling. In Proc. Interspeech Conf., pages 46–49, 2007.
%
%       Copyright (C) Mike Brookes 2024
%      Version: $Id: v_gaussmixk.m 10865 2018-09-21 17:22:45Z dmb $
%
%   VOICEBOX is a MATLAB toolbox for speech processing.
%   Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%   This program is free software; you can redistribute it and/or modify
%   it under the terms of the GNU General Public License as published by
%   the Free Software Foundation; either version 2 of the License, or
%   (at your option) any later version.
%
%   This program is distributed in the hope that it will be useful,
%   but WITHOUT ANY WARRANTY; without even the implied warranty of
%   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%   GNU General Public License for more details.
%
%   You can obtain a copy of the GNU General Public License from
%   http://www.gnu.org/copyleft/gpl.html or by writing to
%   Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
maxiter=15;                             % maximum iterations for upper bound calculation
pruneth=0.2;                            % prune threshold for importance sampling (prob that any excluded mixture would have been chosen)
[kf,p]=size(mf);
if nargin<2 || isempty(vf)              % if vf input is missing
    vf=ones(kf,p);                      % ... default to identity covariances
end
if nargin<3 || isempty(wf)              % if wf input is missing
    wf=repmat(1/kf,kf,1);               % ... default to uniform weights
end
if p==1                                 % if data dimension is 1
    vf=vf(:);                           % ... then variances are always diagonal
    dvf=true;
else
    dvf=ismatrix(vf) && size(vf,1)==kf; % true if diagonal covariances were supplied for f
end
if nargin<7
    nx=1000;                            % default sample count for importance sampling
end
hpl2=0.5*p*log(2);                      % (p/2)*log(2): constant term in the pairwise Gaussian Bhattacharyya divergence
if nargin<4                             % only the f GMM was specified: use it for the g GMM as well
    dbfg=zeros(kf,kf);                  % space for pairwise divergences
    if kf>1
        nup=kf*(kf-1)/2;                        % number of elements in upper triangle
        gix=ceil((1+sqrt(8*(1:nup)-3))/2);      % column of upper triangle
        fix=(1:nup)-(gix-1).*(gix-2)/2;         % row of upper triangle
        if dvf                                  % diagonal covariances
            mdif=mf(fix,:)-mf(gix,:);           % difference in means
            vfpg=(vf(fix,:)+vf(gix,:));         % sum of variances
            qldf=0.25*log(prod(vf,2));          % 0.25 * log determinants of f components
            dbfg(fix+kf*(gix-1))=0.25*sum((mdif./vfpg).*mdif,2)+0.5*log(prod(vfpg,2))-qldf(fix)-qldf(gix)-hpl2; % fill in upper triangle
        else                                    % full covariance matrices
            qldf=zeros(kf,1);
            for jf=1:kf                         % precalculate the log determinants for f covariances
                qldf(jf)=0.5*log(prod(diag(chol(vf(:,:,jf)))));     % equivalent to 0.25*log(det(vf(:,:,jf)))
            end
            for jf=1:kf-1
                vjf=vf(:,:,jf);                 % covariance matrix for f
                for jg=jf+1:kf
                    vfg=vjf+vf(:,:,jg);         % sum of covariance matrices
                    mdif=mf(jf,:)-mf(jg,:);     % difference in means
                    dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldf(jg)-qldf(jf)-hpl2; % fill in upper triangle
                end
            end
        end
        dbfg(gix+kf*(fix-1))=dbfg(fix+kf*(gix-1));  % now reflect upper triangle divergences into the symmetric lower triangle
    end
    d=0;                                % divergence is always zero if f and g are identical
else                                    % both f and g GMMs are specified as inputs
    kg=size(mg,1);
    if nargin<5 || isempty(vg)
        vg=ones(kg,p);                  % default to identity covariances
    end
    if nargin<6 || isempty(wg)
        wg=repmat(1/kg,kg,1);           % default to uniform weights
    end
    if p==1                             % if data dimension is 1
        vg=vg(:);                       % ... then variances are always diagonal
        dvg=true;
    else
        dvg=ismatrix(vg) && size(vg,1)==kg; % true if diagonal covariances were supplied for g
    end
    % first calculate the pairwise Bhattacharyya divergences between the components of f and g
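    % For Gaussian components f_i=N(mf_i,Vf_i) and g_j=N(mg_j,Vg_j), the quantity evaluated below is the
    % standard closed-form Bhattacharyya divergence between two Gaussians:
    %     dbfg(i,j) = (1/8)*(mf_i-mg_j)'*((Vf_i+Vg_j)/2)^(-1)*(mf_i-mg_j) + (1/2)*log(det((Vf_i+Vg_j)/2)/sqrt(det(Vf_i)*det(Vg_j)))
    % which the code accumulates as 0.25*mdif*(Vf_i+Vg_j)^(-1)*mdif' + 0.5*log(det(Vf_i+Vg_j))
    % - 0.25*log(det(Vf_i)) - 0.25*log(det(Vg_j)), with the constant hpl2=(p/2)*log(2) subtracted once at the end.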
    dbfg=zeros(kf,kg);                  % space for pairwise divergences (overwritten below if f and g are both diagonal)
    dix=1:p+1:p^2;                      % index of diagonal elements in covariance matrix
    if dvf
        if dvg                          % both f and g have diagonal covariances
            fix=repmat((1:kf)',kg,1);                   % index into f values
            gix=reshape(repmat(1:kg,kf,1),kf*kg,1);     % index into g values
            mdif=mf(fix,:)-mg(gix,:);                   % difference in means
            vfpg=(vf(fix,:)+vg(gix,:));                 % sum of variances
            qldf=0.25*log(prod(vf,2));                  % 0.25 * log determinants of f components
            qldg=0.25*log(prod(vg,2));                  % 0.25 * log determinants of g components
            dbfg=reshape(0.25*sum((mdif./vfpg).*mdif,2)+0.5*log(prod(vfpg,2))-qldf(fix)-qldg(gix),kf,kg);
        else                            % diagonal f covariance but not g
            qldf=0.25*log(prod(vf,2));                  % precalculate the log determinants for f covariances
            for jg=1:kg                                 % loop through g components
                vjg=vg(:,:,jg);                         % full covariance matrix for g
                qldg=0.5*log(prod(diag(chol(vjg))));    % equivalent to 0.25*log(det(vjg))
                for jf=1:kf                             % loop through f components
                    vfg=vjg;                            % take full g covariance matrix
                    vfg(dix)=vfg(dix)+vf(jf,:);         % ... and add diagonal f covariance
                    mdif=mf(jf,:)-mg(jg,:);             % difference in means
                    dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldf(jf)-qldg;
                end
            end
        end
    else
        if dvg                          % diagonal g covariance but not f
            qldg=0.25*log(prod(vg,2));                  % precalculate the log determinants for g covariances
            for jf=1:kf                                 % loop through f components
                vjf=vf(:,:,jf);                         % full covariance matrix for f
                qldf=0.5*log(prod(diag(chol(vjf))));    % equivalent to 0.25*log(det(vjf))
                for jg=1:kg                             % loop through g components
                    vfg=vjf;                            % take full f covariance matrix
                    vfg(dix)=vfg(dix)+vg(jg,:);         % ... and add diagonal g covariance
                    mdif=mf(jf,:)-mg(jg,:);             % difference in means
                    dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldg(jg)-qldf;
                end
            end
        else                            % both f and g have full covariance matrices
            qldg=zeros(kg,1);
            for jg=1:kg                                 % precalculate the log determinants for g covariances
                qldg(jg)=0.5*log(prod(diag(chol(vg(:,:,jg)))));     % equivalent to 0.25*log(det(vg(:,:,jg)))
            end
            for jf=1:kf                                 % loop through f components
                vjf=vf(:,:,jf);                         % covariance matrix for f
                qldf=0.5*log(prod(diag(chol(vjf))));    % equivalent to 0.25*log(det(vjf))
                for jg=1:kg                             % loop through g components
                    vfg=vjf+vg(:,:,jg);                 % calculate sum of covariance matrices
                    mdif=mf(jf,:)-mg(jg,:);             % difference in means
                    dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldg(jg)-qldf;
                end
            end
        end
    end
    dbfg=dbfg-hpl2;                     % subtract the constant (p/2)*log(2) from all the pairwise divergences
    %
    % Now calculate the variational bound
    % Note that in [3], the psi and phi symbols are interchanged in (20) and also in the previous
    % line; in addition, the subscript of phi is incorrect in the denominator of (26).
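    %
    % The loop below alternately updates the variational parameters psi_f|g and phi_g|f of [3] so as to
    % tighten the upper bound on the divergence, which is evaluated at each iteration as
    %     dub = -log( sum_ij sqrt(wf(i)*wg(j)*psi(i,j)*phi(i,j)) * exp(-dbfg(i,j)) )
    % where exp(-dbfg(i,j)) is the Bhattacharyya coefficient between component i of f and component j of g.
    % The iteration stops as soon as the bound no longer decreases, or after maxiter iterations.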
    %
    lwf=repmat(log(wf),1,kg);           % log of f component weights
    lwg=repmat(log(wg'),kf,1);          % log of g component weights
    lhf=repmat(log(1/kf),kf,kg);        % initialize psi_f|g from [3] (cols of exp(lhf) sum to 1)
    lhg=repmat(log(1/kg),kf,kg);        % initialize phi_g|f from [3] (rows of exp(lhg) sum to 1)
    dbfg2=2*dbfg;                       % twice the pairwise divergences (-log of the squared Bhattacharyya measure)
    dbfg2f=lwf-dbfg2;                   % iteration-independent term used to update lhg
    dbfg2g=lwg-dbfg2;                   % iteration-independent term used to update lhf
    dbfg2fg=dbfg2(:)-lwf(:)-lwg(:);     % iteration-independent term used to calculate the divergence upper bound
    dub=Inf;                            % dummy upper bound for first iteration
    for ip=1:maxiter                    % maximum number of iterations
        dubp=dub;                       % save previous iteration's upper bound
        dub=-v_logsum(0.5*(lhf(:)+lhg(:)-dbfg2fg));     % update the upper bound on the Bhattacharyya divergence
        if dub>=dubp                    % quit if no longer decreasing
            break;
        end
        lhg=lhf+dbfg2f;                                 % update phi_g|f as in numerator of [3]-(25)
        lhg=lhg-repmat(v_logsum(lhg,2),1,kg);           % normalize phi_g|f as in [3]-(25) (rows of exp(lhg) sum to 1)
        dub=-v_logsum(0.5*(lhf(:)+lhg(:)-dbfg2fg));     % update the upper bound on the Bhattacharyya divergence
        lhf=lhg+dbfg2g;                                 % update psi_f|g as in numerator of [3]-(26)
        lhf=lhf-repmat(v_logsum(lhf,1),kf,1);           % normalize psi_f|g as in [3]-(26) (cols of exp(lhf) sum to 1)
    end
    if isempty(nx) || nx==0             % only calculate the upper divergence bound
        d=dub;
    else                                % estimate the divergence by importance sampling
        [lnwt,jlnwt]=sort(0.5*(lhf(:)+lhg(:)-dbfg2fg)+dub,'descend');   % normalized component log weights (highest first)
        wt=exp(lnwt);
        cwt=cumsum(wt);
        nmix=1+sum(cwt<1-pruneth/nx);               % number of mixtures for <20% chance that any excluded ones would be picked
        [fix,gix]=ind2sub([kf kg],jlnwt(1:nmix));   % mixture indices that are needed
        %
        % now create the sampling GMM distribution
        %
        ws=wt(1:nmix)/cwt(nmix);        % sampling GMM weight vector
        ms=zeros(nmix,p);               % space for sampling GMM means
        vs=zeros(p,p,nmix);             % space for sampling GMM full covariances
        if dvf
            if dvg                      % both f and g have diagonal covariances
                vff=vf(fix,:);
                vgg=vg(gix,:);
                vsumi=1./(vff+vgg);
                vs=2*vff.*vgg.*vsumi;                           % mixture covariance matrix
                ms=vff.*vsumi.*mg(gix,:)+vgg.*vsumi.*mf(fix,:); % mixture means
            else                        % diagonal f covariance but not g
                for jfg=1:nmix
                    vgg=vg(:,:,gix(jfg));
                    vff=vf(fix(jfg),:);
                    vsum=vgg;
                    vsum(dix)=vsum(dix)+vff;                    % add diagonal components
                    vs(:,:,jfg)=2*vgg/vsum.*repmat(vff,p,1);    % mixture covariance matrix
                    ms(jfg,:)=mg(gix(jfg),:)/vsum.*vff+mf(fix(jfg),:)/vsum*vgg;     % mixture means
                end
            end
        else
            if dvg                      % diagonal g covariance but not f
                for jfg=1:nmix
                    vff=vf(:,:,fix(jfg));
                    vgg=vg(gix(jfg),:);
                    vsum=vff;
                    vsum(dix)=vsum(dix)+vgg;                    % add diagonal components
                    vs(:,:,jfg)=2*vff/vsum.*repmat(vgg,p,1);    % mixture covariance matrix
                    ms(jfg,:)=mf(fix(jfg),:)/vsum.*vgg+mg(gix(jfg),:)/vsum*vff;     % mixture means
                end
            else                        % both f and g have full covariance matrices
                for jfg=1:nmix
                    vff=vf(:,:,fix(jfg));
                    vgg=vg(:,:,gix(jfg));
                    vsum=vff+vgg;
                    vs(:,:,jfg)=2*vff/vsum*vgg;                 % mixture covariance matrix
                    ms(jfg,:)=mf(fix(jfg),:)/vsum*vgg+mg(gix(jfg),:)/vsum*vff;      % mixture means
                end
            end
        end
        x=v_randvec(nx,ms,vs,ws);       % draw nx samples from the sampling distribution
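        % Importance sampling: Int(sqrt(f(x)g(x)) dx) = E_s[sqrt(f(x)g(x))/s(x)] where s(x) is the sampling
        % GMM constructed above, so the integral is estimated by the sample mean (1/nx)*sum_n sqrt(f(x_n)g(x_n))/s(x_n)
        % evaluated in the log domain; the divergence estimate d is the negative log of this mean.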
        d=-(v_logsum(0.5*(v_gaussmixp(x,mf,vf,wf)+v_gaussmixp(x,mg,vg,wg))-v_gaussmixp(x,ms,vs,ws)))+log(nx);   % Monte Carlo estimate of the Bhattacharyya divergence
    end
end