Home > voicebox > gausprod.m

gausprod

PURPOSE ^

GAUSPROD calculates a product of gaussians [G,U,K]=GAUSPROD(M,C)

SYNOPSIS ^

function [g,u,k]=gausprod(m,c,e)

DESCRIPTION ^

GAUSPROD calculates a product of gaussians [G,U,K]=GAUSPROD(M,C)
 calculates the product of n d-dimensional multivariate gaussians
 this product is itself a gaussian
 Inputs: m(d,n) - each column is the mean of one of the gaussians
         c(d,d,n) - contains the d#d covariance matrix for each gaussian
                    Alternatives: (i) c(d,n) if diagonal (ii) c(n) if c*I or (iii) omitted if I
         e(d,d,n) - contains orthogonal eigenvector matrices and c(d,n) contains eigenvalues.
                    Covariance matrix is E(:,:,k)*diag(c(:,k))*E(:,:,k)'
                    c(d,n) can then contain 0 and Inf entries

 Outputs: g log gain (= log(integral) )
          u(d,1) mean vector
          k(d,d) or k(d) or k(1) = covariance matrix, diagonal or multiple of I (same form as input)

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

function [g,u,k]=gausprod(m,c,e)
%GAUSPROD calculates a product of gaussians [G,U,K]=GAUSPROD(M,C)
% calculates the product of n d-dimensional multivariate gaussians
% this product is itself a gaussian
% Inputs: m(d,n) - each column is the mean of one of the gaussians
%         c(d,d,n) - contains the d#d covariance matrix for each gaussian
%                    Alternatives: (i) c(d,n) if diagonal (ii) c(n) if c*I or (iii) omitted if I
%                    For form (ii) c may be a row or a column vector.
%         e(d,d,n) - contains orthogonal eigenvector matrices and c(d,n) contains eigenvalues.
%                    Covariance matrix is E(:,:,k)*diag(c(:,k))*E(:,:,k)'
%                    c(d,n) can then contain 0 and Inf entries
%                    [not yet implemented: supplying e raises an error]
%
% Outputs: g log gain (= log(integral) )
%          u(d,1) mean vector
%          k(d,d) or k(d) or k(1) = covariance matrix, diagonal or multiple of I (same form as input)
%
% If called with no output arguments, the input gaussians (dotted) and their
% product (solid black) are plotted: as log10(pdf) curves when d==1, otherwise
% as ellipses in the plane of the first two columns of an orthogonal basis
% (the identity, or the left singular vectors of k for full covariances).

% this version works with singular covariance matrices provided that their null spaces are distinct
% we could improve it slightly by doing the pseudo inverses locally and keeping track of null spaces

%       Copyright (C) Mike Brookes 2004
%      Version: $Id: gausprod.m 713 2011-10-16 14:45:43Z dmb $
%
%   VOICEBOX is a MATLAB toolbox for speech processing.
%   Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%   This program is free software; you can redistribute it and/or modify
%   it under the terms of the GNU General Public License as published by
%   the Free Software Foundation; either version 2 of the License, or
%   (at your option) any later version.
%
%   This program is distributed in the hope that it will be useful,
%   but WITHOUT ANY WARRANTY; without even the implied warranty of
%   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%   GNU General Public License for more details.
%
%   You can obtain a copy of the GNU General Public License from
%   http://www.gnu.org/copyleft/gpl.html or by writing to
%   Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

[d,n]=size(m);
if nargin>2
    error('third argument not yet implemented in gausprod');
end
if nargin<2     % all covariance matrices = I
    c=ones(n,1);
end

% Work out which covariance representation has been supplied.
%   ctype = 1 : multiples of the identity, c(n,1)
%   ctype = 2 : diagonal, c(d,n)
%   ctype = 3 : full, c(d,d,n)
sc=size(c);
if length(sc)==2 && d>1 && n>1 && sc(1)==1 && sc(2)==n
    % a 1-by-n row cannot be the diagonal form when d>1, so it must be the
    % "multiple of I" form given as a row vector: convert it to a column
    c=c(:);
    sc=size(c);
end
if length(sc)>2 || (n==1 && d>1 && sc(1)==d && sc(2)==d)
    % size() drops a trailing singleton dimension, so a single full d#d
    % covariance arrives as a 2-D array; the diagonal form would be d#1
    ctype=3;
elseif sc(2)==1 && n>1
    ctype=1;
else
    ctype=2;
end

if ~nargout     % save input data for plotting
    m0=m;
    c0=c;
end

% In every case the gaussians are combined pairwise in a binary tree: on each
% pass gaussian j is merged into gaussian k=j-jj, and the stride doubles until
% everything has been accumulated into entry 1. gj holds the accumulated log
% gain excluding the (2*pi) factors, which are added once at the end.
jj=1;
jj2=2;
gj=zeros(n,1);
switch ctype
    case 1                  % covariance matrices are multiples of the identity
        while jj<n
            for j=1+jj:jj2:n        % we combine the gaussians in pairs
                k=j-jj;
                cjk=c(k)+c(j);
                dm=m(:,k)-m(:,j);
                gj(k)=gj(k)+gj(j)-0.5*(d*log(cjk)+dm'*dm/cjk);
                m(:,k)=(c(k)*m(:,j)+c(j)*m(:,k))/cjk;
                c(k)=c(k)*c(j)/cjk;
            end
            jj=jj2;
            jj2=2*jj;
        end
        g=gj(1);
        k=c(1);
        u=m(:,1);
    case 2                  % diagonal covariance matrices
        while jj<n
            for j=1+jj:jj2:n        % we combine the gaussians in pairs
                k=j-jj;
                cjk=c(:,k)+c(:,j);
                dm=m(:,k)-m(:,j);
                ix=cjk>d*max(cjk)*eps;      % calculate the pseudo inverse directly
                piv=zeros(d,1);
                piv(ix)=cjk(ix).^(-1);
                % sum(log()) rather than log(prod()): the product of many
                % variances can under/overflow even when its log is moderate
                gj(k)=gj(k)+gj(j)-0.5*(sum(log(cjk))+piv'*dm.^2);
                m(:,k)=piv.*(c(:,k).*m(:,j)+c(:,j).*m(:,k));
                c(:,k)=c(:,k).*piv.*c(:,j);
            end
            jj=jj2;
            jj2=2*jj;
        end
        g=gj(1);
        k=c(:,1);
        u=m(:,1);
    otherwise               % full covariance matrices
        while jj<n
            for j=1+jj:jj2:n        % we combine the gaussians in pairs
                k=j-jj;
                cjk=c(:,:,k)+c(:,:,j);
                dm=m(:,k)-m(:,j);
                piv=pinv(cjk);
                gj(k)=gj(k)+gj(j)-0.5*(log(det(cjk))+dm'*piv*dm);
                m(:,k)=c(:,:,k)*piv*m(:,j)+c(:,:,j)*piv*m(:,k);
                c(:,:,k)=c(:,:,k)*piv*c(:,:,j);
                c(:,:,k)=0.5*(c(:,:,k)+c(:,:,k)');  % ensure exactly symmetric
            end
            jj=jj2;
            jj2=2*jj;
        end
        g=gj(1);
        k=c(:,:,1);
        u=m(:,1);
end
g=g-0.5*(n-1)*d*log(2*pi);  % each of the n-1 pairwise merges contributes (2*pi)^(-d/2)

if ~nargout                 % plot results if no output arguments
    if d==1                 % one-dimensional vectors
        x0=linspace(-3,3,100)';
        x=zeros(length(x0),n);
        y=x;
        for j=1:n
            x(:,j)=x0+m0(1,j);
            cj=c0(j);
            y(:,j)=normpdf(x0,0,sqrt(cj));
        end
        % k is a variance, but normpdf expects a standard deviation
        plot(x,log10(y),':',x0+u,log10(normpdf(x0,0,sqrt(k))*exp(g)),'k-');
        ylabel('Log10(pdf)');
    else
        % choose the orthogonal basis whose first two columns span the plotting plane
        if ctype==3
            [uk,sk,vk]=svd(k); %#ok<ASGLU,NASGU> only uk is needed
        else
            uk=eye(d);
        end
        u2=uk(:,1:2);
        t0=linspace(0,2*pi,100);
        x=zeros(length(t0),n);
        y=x;
        x0=[cos(t0); sin(t0)];      % unit circle directions
        for j=1:n
            switch ctype
                case 1              % multiple of the identity
                    cj=c0(j)*eye(2);
                case 2              % diagonal
                    cj=u2'*diag(c0(:,j))*u2;
                otherwise           % full
                    cj=u2'*c0(:,:,j)*u2;
            end
            mj=u2'*m0(:,j);
            v=sqrt(sum((x0'/cj).*x0',2).^(-1));    % radius at which x'*inv(cj)*x = 1
            x(:,j)=mj(1)+v.*x0(1,:)';
            y(:,j)=mj(2)+v.*x0(2,:)';
        end

        % same projection for the product gaussian
        switch ctype
            case 1
                cj=k*eye(2);
            case 2
                cj=u2'*diag(k)*u2;
            otherwise
                cj=u2'*k*u2;
        end
        mj=u2'*u;
        v=sqrt(sum((x0'/cj).*x0',2).^(-1));
        plot(x,y,':',mj(1)+v.*x0(1,:)',mj(2)+v.*x0(2,:)','k-');
        axis equal;
    end
end

Generated on Thu 02-Feb-2012 09:15:04 by m2html © 2003