v_gaussmixb

PURPOSE ^

V_GAUSSMIXB approximate Bhattacharyya divergence between two GMMs

SYNOPSIS ^

function [d,dbfg]=v_gaussmixb(mf,vf,wf,mg,vg,wg,nx)

DESCRIPTION ^

V_GAUSSMIXB approximate Bhattacharyya divergence between two GMMs

 Inputs: with kf & kg mixtures, p data dimensions

   mf(kf,p)                mixture means for GMM f
   vf(kf,p) or vf(p,p,kf)  variances (diagonal or full) for GMM f [default: identity]
   wf(kf,1)                weights for GMM f - must sum to 1 [default: uniform]
   mg(kg,p)                mixture means for GMM g [g=f if mg,vg,wg omitted]
   vg(kg,p) or vg(p,p,kg)  variances (diagonal or full) for GMM g [default: identity]
   wg(kg,1)                weights for GMM g - must sum to 1 [default: uniform]
   nx                      number of samples to use in importance sampling [default: 1000]
                           Set nx=0 to save computation by returning only an upper bound to the Bhattacharyya divergence.

 Outputs:
   d             the approximate Bhattacharyya divergence D_B(f,g)=-log(Int(sqrt(f(x)g(x)) dx)).
                 If nx=0 this will be an upper bound (typically 0.3 to 0.7 too high) rather than an estimate.
   dbfg(kf,kg)   the exact Bhattacharyya divergence between the unweighted components of f and g

 The Bhattacharyya divergence, D_B(f,g), between two distributions, f(x) and g(x), is -log(Int(sqrt(f(x)g(x)) dx)).
 It is a special case of the Chernoff bound [2]. The Bhattacharyya divergence [1] satisfies:
     (1) D_B(f,g) >= 0
     (2) D_B(f,g) = 0 iff f = g
     (3) D_B(f,g) = D_B(g,f)
 It is not a distance because it does not satisfy the triangle inequality. It closely approximates the Bayes
 divergence -log(Int(min(f(x),g(x)) dx)), which relates to the probability of 2-class misclassification [1].
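For two single-component Gaussians the divergence has a closed form, which is what the dbfg output evaluates pairwise. A minimal numpy sketch of that formula (illustrative only; the function and variable names below are not part of VOICEBOX):

```python
import numpy as np

def bhatt_gauss(mf, vf, mg, vg):
    # Closed-form Bhattacharyya divergence between N(mf,vf) and N(mg,vg):
    #   0.25*d'*(vf+vg)^-1*d + 0.5*log det(vf+vg)
    #   - 0.25*log det(vf) - 0.25*log det(vg) - 0.5*p*log(2)
    # with d = mf - mg and p the data dimension.
    mf = np.atleast_1d(np.asarray(mf, float))
    mg = np.atleast_1d(np.asarray(mg, float))
    vf = np.atleast_2d(np.asarray(vf, float))
    vg = np.atleast_2d(np.asarray(vg, float))
    p = mf.size
    d = mf - mg
    vfg = vf + vg
    return (0.25 * d @ np.linalg.solve(vfg, d)
            + 0.5 * np.linalg.slogdet(vfg)[1]
            - 0.25 * np.linalg.slogdet(vf)[1]
            - 0.25 * np.linalg.slogdet(vg)[1]
            - 0.5 * p * np.log(2.0))
```

With identical components the result is 0, and swapping the two arguments leaves it unchanged, matching properties (2) and (3).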

 This routine calculates the "variational importance sampling" estimate of (or, if nx=0,
 the "variational II" upper bound to) the Bhattacharyya divergence from [3]. It is exact
 when f and g are single-component Gaussians and is zero if f=g.
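The importance-sampling idea behind the nx>0 path: exp(-D_B) = Int(sqrt(f(x)g(x)) dx) is estimated by drawing samples x from a sampling density h and averaging sqrt(f(x)g(x))/h(x). A 1-D numpy sketch of this idea under simplified assumptions (the proposal here is just a broad hand-picked Gaussian, not the variationally optimized GMM that the routine constructs):

```python
import numpy as np

rng = np.random.default_rng(0)

def npdf(x, m, v):                       # 1-D Gaussian density
    return np.exp(-0.5*(x - m)**2/v)/np.sqrt(2*np.pi*v)

mf, vf = 0.0, 1.0                        # Gaussian f
mg, vg = 2.0, 1.0                        # Gaussian g

x = rng.normal(1.0, 2.0, 200000)         # draws from proposal h = N(1, 4)
w = np.sqrt(npdf(x, mf, vf)*npdf(x, mg, vg))/npdf(x, 1.0, 4.0)
d_est = -np.log(w.mean())                # Monte Carlo estimate of D_B(f,g)

# 1-D closed form for comparison:
d_exact = 0.25*(mf - mg)**2/(vf + vg) + 0.5*np.log((vf + vg)/(2*np.sqrt(vf*vg)))
```

The estimator's variance depends strongly on how well h matches the integrand, which is why the routine builds its proposal from the variationally weighted component pairs.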

 Refs:
 [1] T. Kailath. The divergence and Bhattacharyya distance measures in signal selection.
     IEEE Trans Communication Technology, 15 (1): 52–60, Feb. 1967.
 [2] H. Chernoff. A measure of asymptotic efficiency for tests of a hypothesis based on the
     sum of observations. The Annals of Mathematical Statistics, 23 (4): 493–507, Dec. 1952.
 [3] P. A. Olsen and J. R. Hershey. Bhattacharyya error and divergence using variational
     importance sampling. In Proc. Interspeech Conf., pages 46–49, 2007.

       Copyright (C) Mike Brookes 2024
      Version: $Id: v_gaussmixk.m 10865 2018-09-21 17:22:45Z dmb $

   VOICEBOX is a MATLAB toolbox for speech processing.
   Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You can obtain a copy of the GNU General Public License from
   http://www.gnu.org/copyleft/gpl.html or by writing to
   Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SOURCE CODE ^

0001 function [d,dbfg]=v_gaussmixb(mf,vf,wf,mg,vg,wg,nx)
0002 %V_GAUSSMIXB approximate Bhattacharyya divergence between two GMMs
0003 %
0004 % Inputs: with kf & kg mixtures, p data dimensions
0005 %
0006 %   mf(kf,p)                mixture means for GMM f
0007 %   vf(kf,p) or vf(p,p,kf)  variances (diagonal or full) for GMM f [default: identity]
0008 %   wf(kf,1)                weights for GMM f - must sum to 1 [default: uniform]
0009 %   mg(kg,p)                mixture means for GMM g [g=f if mg,vg,wg omitted]
0010 %   vg(kg,p) or vg(p,p,kg)  variances (diagonal or full) for GMM g [default: identity]
0011 %   wg(kg,1)                weights for GMM g - must sum to 1 [default: uniform]
0012 %   nx                      number of samples to use in importance sampling [default: 1000]
0013 %                           Set nx=0 to save computation by returning only an upper bound to the Bhattacharyya divergence.
0014 %
0015 % Outputs:
0016 %   d             the approximate Bhattacharyya divergence D_B(f,g)=-log(Int(sqrt(f(x)g(x)) dx)).
0017 %                 If nx=0 this will be an upper bound (typically 0.3 to 0.7 too high) rather than an estimate.
0018 %   dbfg(kf,kg)   the exact Bhattacharyya divergence between the unweighted components of f and g
0019 %
0020 % The Bhattacharyya divergence, D_B(f,g), between two distributions, f(x) and g(x), is -log(Int(sqrt(f(x)g(x)) dx)).
0021 % It is a special case of the Chernoff bound [2]. The Bhattacharyya divergence [1] satisfies:
0022 %     (1) D_B(f,g) >= 0
0023 %     (2) D_B(f,g) = 0 iff f = g
0024 %     (3) D_B(f,g) = D_B(g,f)
0025 % It is not a distance because it does not satisfy the triangle inequality. It closely approximates the Bayes
0026 % divergence -log(Int(min(f(x),g(x)) dx)), which relates to the probability of 2-class misclassification [1].
0027 %
0028 % This routine calculates the "variational importance sampling" estimate of (or, if nx=0,
0029 % the "variational II" upper bound to) the Bhattacharyya divergence from [3]. It is exact
0030 % when f and g are single-component Gaussians and is zero if f=g.
0031 %
0032 % Refs:
0033 % [1] T. Kailath. The divergence and Bhattacharyya distance measures in signal selection.
0034 %     IEEE Trans Communication Technology, 15 (1): 52–60, Feb. 1967.
0035 % [2] H. Chernoff. A measure of asymptotic efficiency for tests of a hypothesis based on the
0036 %     sum of observations. The Annals of Mathematical Statistics, 23 (4): 493–507, Dec. 1952.
0037 % [3] P. A. Olsen and J. R. Hershey. Bhattacharyya error and divergence using variational
0038 %     importance sampling. In Proc. Interspeech Conf., pages 46–49, 2007.
0039 %
0040 %       Copyright (C) Mike Brookes 2024
0041 %      Version: $Id: v_gaussmixk.m 10865 2018-09-21 17:22:45Z dmb $
0042 %
0043 %   VOICEBOX is a MATLAB toolbox for speech processing.
0044 %   Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html
0045 %
0046 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0047 %   This program is free software; you can redistribute it and/or modify
0048 %   it under the terms of the GNU General Public License as published by
0049 %   the Free Software Foundation; either version 2 of the License, or
0050 %   (at your option) any later version.
0051 %
0052 %   This program is distributed in the hope that it will be useful,
0053 %   but WITHOUT ANY WARRANTY; without even the implied warranty of
0054 %   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
0055 %   GNU General Public License for more details.
0056 %
0057 %   You can obtain a copy of the GNU General Public License from
0058 %   http://www.gnu.org/copyleft/gpl.html or by writing to
0059 %   Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA.
0060 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0061 maxiter=15;                                                     % maximum iterations for upper bound calculation
0062 pruneth=0.2;                                                    % prune threshold for importance sampling (prob that any excluded mixture would have been chosen)
0063 [kf,p]=size(mf);
0064 if nargin<2 || isempty(vf)                                      % if vf input is missing
0065     vf=ones(kf,p);                                              % ... default to diagonal covariances
0066 end
0067 if nargin<3 || isempty(wf)                                      % if wf input is missing
0068     wf=repmat(1/kf,kf,1);                                       % ... default to uniform weights
0069 end
0070 if p==1                                                         % if data dimension is 1
0071     vf=vf(:);                                                   % ... then variances are always diagonal
0072     dvf=true;
0073 else
0074     dvf=ismatrix(vf) && size(vf,1)==kf;                         % diagonal covariance matrix vf is supplied
0075 end
0076 if nargin<7
0077     nx=1000;                                                    % default count for importance sampling
0078 end
0079 hpl2=0.5*p*log(2);                                              % constant (p/2)*log(2) term used in the component divergences
0080 if nargin<5                                                     % only f GMM specified: use this for g GMM as well
0081     dbfg=zeros(kf,kf);                                          % space for pairwise divergences
0082     if kf>1
0083         nup=kf*(kf-1)/2;                                        % number of elements in upper triangle
0084         gix=ceil((1+sqrt(8*(1:nup)-3))/2);                      % column of upper triangle
0085         fix=(1:nup)-(gix-1).*(gix-2)/2;                         % row of upper triangle
0086         if dvf                                                  % diagonal covariances
0087             mdif=mf(fix,:)-mf(gix,:);                           % difference in means
0088             vfpg=(vf(fix,:)+vf(gix,:));                         % sum of variances
0089             qldf=0.25*log(prod(vf,2));
0090             dbfg(fix+kf*(gix-1))=0.25*sum((mdif./vfpg).*mdif,2)+0.5*log(prod(vfpg,2))-qldf(fix)-qldf(gix)-hpl2; % fill in upper triangle
0091         else                                                    % full covariance matrices
0092             qldf=zeros(kf,1);
0093             for jf=1:kf                                         % precalculate the log determinants for f covariances
0094                 qldf(jf)=0.5*log(prod(diag(chol(vf(:,:,jf))))); % equivalent to 0.25*log(det(vf(:,:,jf)))
0095             end
0096             for jf=1:kf-1
0097                 vjf=vf(:,:,jf);                                 % covariance matrix for f
0098                 for jg=jf+1:kf
0099                     vfg=vjf+vf(:,:,jg);
0100                     mdif=mf(jf,:)-mf(jg,:);                     % difference in means
0101                     dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldf(jg)-qldf(jf)-hpl2; % fill in upper triangle
0102                 end
0103             end
0104         end
0105         dbfg(gix+kf*(fix-1))=dbfg(fix+kf*(gix-1));              % now reflect upper triangle divergences into the symmetric lower triangle
0106     end
0107     d=0;                                                        % divergence is always zero if f and g are identical
0108 else                                                            % both f and g GMMs are specified as inputs
0109     kg=size(mg,1);
0110     if nargin<5 || isempty(vg)
0111         vg=ones(kg,p);                                          % default to diagonal covariances
0112     end
0113     if nargin<6 || isempty(wg)
0114         wg=repmat(1/kg,kg,1);                                   % default to uniform weights
0115     end
0116     if p==1                                                     % if data dimension is 1
0117         vg=vg(:);                                               % ... then variances are always diagonal
0118         dvg=true;
0119     else
0120         dvg=ismatrix(vg) && size(vg,1)==kg;                     % diagonal covariance matrix vg is supplied
0121     end
0122     % first calculate pairwise Bhattacharya divergences between the components of f and g
0123     dbfg=zeros(kf,kg);                                          % space for pairwise component divergences (overwritten below if f and g are both diagonal)
0124     dix=1:p+1:p^2;                                              % index of diagonal elements in covariance matrix
0125     if dvf
0126         if dvg                                                  % both f and g have diagonal covariances
0127             fix=repmat((1:kf)',kg,1);                           % index into f values
0128             gix=reshape(repmat(1:kg,kf,1),kf*kg,1);             % index into g values
0129             mdif=mf(fix,:)-mg(gix,:);                           % difference in means
0130             vfpg=(vf(fix,:)+vg(gix,:));                         % sum of variances
0131             qldf=0.25*log(prod(vf,2));                          % 0.25 * log determinants of f components
0132             qldg=0.25*log(prod(vg,2));                          % 0.25 * log determinants of g components
0133             dbfg=reshape(0.25*sum((mdif./vfpg).*mdif,2)+0.5*log(prod(vfpg,2))-qldf(fix)-qldg(gix),kf,kg);
0134         else                                                    % diagonal f covariance but not g
0135             qldf=0.25*log(prod(vf,2));                          % precalculate the log determinants for f covariances
0136             for jg=1:kg                                         % loop through g components
0137                 vjg=vg(:,:,jg);                                 % full covariance matrix for g
0138                 qldg=0.5*log(prod(diag(chol(vjg))));            % equivalent to 0.25*log(det(vjg))
0139                 for jf=1:kf                                     % loop through f components
0140                     vfg=vjg;                                    % take full g covariance matrix
0141                     vfg(dix)=vfg(dix)+vf(jf,:);                 % ... and add diagonal f covariance
0142                     mdif=mf(jf,:)-mg(jg,:);                     % difference in means
0143                     dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldf(jf)-qldg;
0144                 end
0145             end
0146         end
0147     else
0148         if dvg                                                  % diagonal g covariance but not f
0149             qldg=0.25*log(prod(vg,2));                          % precalculate the log determinants for g covariances
0150             for jf=1:kf                                         % loop through f components
0151                 vjf=vf(:,:,jf);                                 % full covariance matrix for f
0152                 qldf=0.5*log(prod(diag(chol(vjf))));            % equivalent to 0.25*log(det(vjf))
0153                 for jg=1:kg                                     % loop through g components
0154                     vfg=vjf;                                    % take full f covariance matrix
0155                     vfg(dix)=vfg(dix)+vg(jg,:);                 % ... and add diagonal g covariance
0156                     mdif=mf(jf,:)-mg(jg,:);                     % difference in means
0157                     dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldg(jg)-qldf;
0158                 end
0159             end
0160         else                                                    % both f and g have full covariance matrices
0161             qldg=zeros(kg,1);
0162             for jg=1:kg                                         % precalculate the log determinants for g covariances
0163                 qldg(jg)=0.5*log(prod(diag(chol(vg(:,:,jg))))); % equivalent to 0.25*log(det(vg(:,:,jg)))
0164             end
0165             for jf=1:kf                                         % loop through f components
0166                 vjf=vf(:,:,jf);                                 % covariance matrix for f
0167                 qldf=0.5*log(prod(diag(chol(vjf))));            % equivalent to 0.25*log(det(vjf))
0168                 for jg=1:kg                                     % loop through g components
0169                     vfg=vjf+vg(:,:,jg);                         % calculate sum of covariance matrices
0170                     mdif=mf(jf,:)-mg(jg,:);                     % difference in means
0171                     dbfg(jf,jg)=0.25*(mdif/vfg)*mdif'+log(prod(diag(chol(vfg))))-qldg(jg)-qldf;
0172                 end
0173             end
0174         end
0175     end
0176     dbfg=dbfg-hpl2;                                             % subtract the constant (p/2)*log(2) term from all the calculated divergences
0177     %
0178     % Now calculate the variational bound
0179     % Note that in [3], the psi and phi symbols are interchanged in (20) and also in the previous
0180     % line; in addition, the subscript of phi is incorrect in the denominator of (26).
0181     %
0182     lwf=repmat(log(wf),1,kg);                                   % log of f component weights
0183     lwg=repmat(log(wg'),kf,1);                                  % log of g component weights
0184     lhf=repmat(log(1/kf),kf,kg);                                % initialize psi_f|g from [3] (cols of exp(lhf) sum to 1)
0185     lhg=repmat(log(1/kg),kf,kg);                                % initialize phi_g|f from [3] (rows of exp(lhg) sum to 1)
0186     dbfg2=2*dbfg;                                               % log of squared Bhattacharyya measure lower bound
0187     dbfg2f=lwf-dbfg2;                                           % iteration-independent term used to update lhg
0188     dbfg2g=lwg-dbfg2;                                           % iteration-independent term used to update lhf
0189     dbfg2fg=dbfg2(:)-lwf(:)-lwg(:);                             % iteration-independent term used to calculate the divergence upper bound
0190     dub=Inf;                                                    % dummy upper bound for first iteration
0191     for ip=1:maxiter                                            % maximum number of iterations
0192         dubp=dub;                                               % save previous iteration's upper bound
0193         dub=-v_logsum(0.5*(lhf(:)+lhg(:)-dbfg2fg));             % update the upper bound on the Bhattacharyya divergence
0194         if dub>=dubp                                            % quit if no longer decreasing
0195             break;
0196         end
0197         lhg=lhf+dbfg2f;                                         % update phi_g|f as in numerator of [3]-(25)
0198         lhg=lhg-repmat(v_logsum(lhg,2),1,kg);                   % normalize phi_g|f as in [3]-(25) (rows of exp(lhg) sum to 1)
0199         dub=-v_logsum(0.5*(lhf(:)+lhg(:)-dbfg2fg));             % update the upper bound on the Bhattacharyya divergence
0200         lhf=lhg+dbfg2g;                                         % update psi_f|g as in numerator of [3]-(26)
0201         lhf=lhf-repmat(v_logsum(lhf,1),kf,1);                   % normalize psi_f|g as in [3]-(26) (cols of exp(lhf) sum to 1)
0202     end
0203     if isempty(nx) || nx==0                                     % only calculate the upper divergence bound
0204         d=dub;
0205     else
0206         [lnwt,jlnwt]=sort(0.5*(lhf(:)+lhg(:)-dbfg2fg)+dub,'descend'); % normalized component log weights (highest first)
0207         wt=exp(lnwt);
0208         cwt=cumsum(wt);
0209         nmix=1+sum(cwt<1-pruneth/nx);                           % number of mixtures for <20% chance that any excluded ones would be picked
0210         [fix,gix]=ind2sub([kf kg],jlnwt(1:nmix));               % mixture indices that are needed
0211         %
0212         % now create the sampling GMM distribution
0213         %
0214         ws=wt(1:nmix)/cwt(nmix);                                % sampling GMM weight vector
0215         ms=zeros(nmix,p);                                       % space for sampling GMM means
0216         vs=zeros(p,p,nmix);                                     % space for sampling GMM full covariances
0217         if dvf
0218             if dvg                                              % both f and g have diagonal covariances
0219                 vff=vf(fix,:);
0220                 vgg=vg(gix,:);
0221                 vsumi=1./(vff+vgg);
0222                 vs=2*vff.*vgg.*vsumi;                           % mixture covariance matrix
0223                 ms=vff.*vsumi.*mg(gix,:)+vgg.*vsumi.*mf(fix,:); % mixture means
0224             else                                                % diagonal f covariance but not g
0225                 for jfg=1:nmix
0226                     vgg=vg(:,:,gix(jfg));
0227                     vff=vf(fix(jfg),:);
0228                     vsum=vgg;
0229                     vsum(dix)=vsum(dix)+vff;                    % add diagonal components
0230                     vs(:,:,jfg)=2*vgg/vsum.*repmat(vff,p,1);    % mixture covariance matrix
0231                     ms(jfg,:)=mg(gix(jfg),:)/vsum.*vff+mf(fix(jfg),:)/vsum*vgg; % mixture means
0232                 end
0233             end
0234         else
0235             if dvg                                              % diagonal g covariance but not f
0236                 for jfg=1:nmix
0237                     vff=vf(:,:,fix(jfg));
0238                     vgg=vg(gix(jfg),:);
0239                     vsum=vff;
0240                     vsum(dix)=vsum(dix)+vgg;                    % add diagonal components
0241                     vs(:,:,jfg)=2*vff/vsum.*repmat(vgg,p,1);    % mixture covariance matrix
0242                     ms(jfg,:)=mf(fix(jfg),:)/vsum.*vgg+mg(gix(jfg),:)/vsum*vff; % mixture means
0243                 end
0244             else                                                % both f and g have full covariance matrices
0245                 for jfg=1:nmix
0246                     vff=vf(:,:,fix(jfg));
0247                     vgg=vg(:,:,gix(jfg));
0248                     vsum=vff+vgg;
0249                     vs(:,:,jfg)=2*vff/vsum*vgg;                 % mixture covariance matrix
0250                     ms(jfg,:)=mf(fix(jfg),:)/vsum*vgg+mg(gix(jfg),:)/vsum*vff; % mixture means
0251                 end
0252             end
0253         end
0254         x=v_randvec(nx,ms,vs,ws);                               % draw from sampling distribution
0255         d=-(v_logsum(0.5*(v_gaussmixp(x,mf,vf,wf)+v_gaussmixp(x,mg,vg,wg))-v_gaussmixp(x,ms,vs,ws)))+log(nx); % Monte Carlo estimate of the Bhattacharyya divergence
0256     end
0257 end
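The sampling components constructed in the listing above are the normalized Gaussians proportional to sqrt(N(mf,Vf)*N(mg,Vg)); the covariance expression 2*vff/vsum*vgg used in the full-covariance branch is the matrix harmonic mean 2*(Vf^-1+Vg^-1)^-1. A quick numpy check of that identity (illustrative only, not part of the toolbox):

```python
import numpy as np

rng = np.random.default_rng(1)
p = 3
a = rng.standard_normal((p, p))
b = rng.standard_normal((p, p))
vf = a @ a.T + p*np.eye(p)              # random SPD covariance for f
vg = b @ b.T + p*np.eye(p)              # random SPD covariance for g

# form used in the code: vs = 2*vff/vsum*vgg, i.e. 2*Vf*(Vf+Vg)^-1*Vg
vs_code = 2*vf @ np.linalg.solve(vf + vg, vg)
# equivalent harmonic-mean form: 2*(Vf^-1 + Vg^-1)^-1
vs_harm = 2*np.linalg.inv(np.linalg.inv(vf) + np.linalg.inv(vg))
```

Both forms agree and are symmetric, since 2*Vf*(Vf+Vg)^-1*Vg = 2*(Vf - Vf*(Vf+Vg)^-1*Vf); the code's form avoids inverting the individual covariances.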

Generated by m2html © 2003