Hi All!
I am trying to implement a linear programming algorithm in C++, using BLAS and LAPACK for the matrix operations. However, my C++ code runs slower than the equivalent MATLAB code when the problem is large, and the gap grows as the problem size increases.
I am wondering whether this is caused by optimization tricks that MATLAB uses when calling Intel MKL. Could someone explain why MATLAB sometimes outperforms C++ with BLAS/LAPACK? Is there any way to improve this C++ code, or any compiler option that would make it faster?
Thank you for your time!
The following is my simplified code.
#include <math.h> #include <mex.h> #include <string.h> #include "blas.h" #include "lapack.h" #if !defined(MAX) #define MAX(A, B) ((A) > (B) ? (A) : (B)) #endif #if !defined(MIN) #define MIN(A, B) ((A) < (B) ? (A) : (B)) #endif void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] ) { double *A, *b, *c, *Ac, *AAT, *peigATA; double *x, *Ax, *tx, *y, *ATy, *s, *As, *lambda1, *lambda2, *ATlambda1, *Alambda2; double *ppobj, *pdobj, *ppresi, *pdresi, *pdgap, *piter; double *tmp, *tmpinv, *tmpres, *res1, *res2; double pobj, dobj, presi, dresi, dgap, duration; double gamma, alpha, beta_max, beta_min, tol, final_mu, final_tol_admm, feas_mul; double temp, beta1 = 1.0, beta2 = 1.0, one = 1.0, mone = -1.0, zero = 0.0; double mu, tol_admm, bnorm, cnorm, tau, eigATA, lambdaRes1, lambdaRes2; ptrdiff_t i, j, m, n, m2, mn, feas_count, info, verbose, inc = 1; ptrdiff_t k, outer, iter_all = 0, count_pbig = 0, count_dbig = 0, count1 = 0, count2 = 0, iter = 0; char *NTRANS, *TTRANS, *uplo; NTRANS = "N"; TTRANS = "T"; uplo = "U"; mu = final_mu/pow(gamma,(int)outer); tol_admm = final_tol_admm/pow(gamma,(int)outer); /* Ax = A*x */ dgemv(NTRANS,&m,&n,&one,A,&m,x,&inc,&zero,Ax,&inc); bnorm = dnrm2(&m,b,&inc); cnorm = dnrm2(&n,c,&inc); /*prepare stats*/ /* AAT = A*A' */ dgemm(NTRANS,TTRANS,&m,&m,&n,&one,A,&m,A,&m,&zero,AAT,&m); /* Ac = A*c */ dgemv(NTRANS,&m,&n,&one,A,&m,c,&inc,&zero,Ac,&inc); /* Compute largest eigenvalue */ memcpy(mxGetPr(rhsAAT[0]),AAT,(m*m)*sizeof(double)); mexCallMATLAB(1,lhsAAT,1,rhsAAT,"eigs"); peigATA = mxGetPr(lhsAAT[0]); eigATA = *peigATA; /* Cholesky Factorization: AAT = U */ dpotrf(uplo, &m, AAT, &m, &info); while (mu >= final_mu) { k = 0; iter = iter + 1; presi = 10000; dresi = 10000; lambdaRes1 = 10000; lambdaRes2 = 10000; while ((MAX(MAX(MAX(presi,dresi),lambdaRes1),lambdaRes2) > tol_admm)&&(k < round(pow(1/mu,0.5)))) { k = k + 1; tau = 0.99/(beta1*eigATA); // Update x for (i=0; i<2; i++) { /* Ax = -b + Ax */ daxpy(&m, &mone, b, &inc, 
Ax, &inc); /* Ax = beta1*Ax */ dscal(&m, &beta1, Ax, &inc); /* Ax = lambda1 + Ax */ daxpy(&m, &one, lambda1, &inc, Ax, &inc); /* tx = tau*A'*Ax */ dgemv(TTRANS,&m,&n,&tau,A,&m,Ax,&inc,&zero,tx,&inc); /* tx = -x + tx */ daxpy(&n, &mone, x, &inc, tx, &inc); /* tx = tau*c + tx */ daxpy(&n, &tau, c, &inc, tx, &inc); /* x = (-tx+sqrt(tx.^2+4*mu*tau))/2 */ for(j=0; j<n; j++) { *(x+j) = (-*(tx+j)+sqrt((*(tx+j))*(*(tx+j))+4.0*mu*tau))/2.0; } /* Ax = A*x */ dgemv(NTRANS,&m,&n,&one,A,&m,x,&inc,&zero,Ax,&inc); } // Update s /* ATy = -c+ATy */ daxpy(&n, &mone, c, &inc, ATy, &inc); /* ATy = beta2*ATy*/ dscal(&n, &beta2, ATy, &inc); /* ATy = lambda2+ATy*/ daxpy(&n, &one, lambda2, &inc, ATy, &inc); /* s = (-ts+sqrt(ts.^2+4*mu*beta2))/(2*beta2) */ for(j=0; j<n; j++) { *(s+j) = (-*(ATy+j)+sqrt((*(ATy+j))*(*(ATy+j))+4.0*mu*beta2))/(2.0*beta2); } /* As = A*s */ dgemv(NTRANS,&m,&n,&one,A,&m,s,&inc,&zero,As,&inc); //Update y /* As = -Ac + As */ daxpy(&m, &mone, Ac, &inc, As, &inc); /* As = -beta2*As */ temp = mone*beta2; dscal(&m, &temp, As, &inc); /* As = b + As */ daxpy(&m, &one, b, &inc, As, &inc); /* As = -Alambda2 + As */ daxpy(&m, &mone, Alambda2, &inc, As, &inc); /* As = As/beta2 */ temp = 1.0/beta2; dscal(&m, &temp, As, &inc); /* y = (AAT)^(-1)*As */ dpotrs(uplo, &m, &inc, AAT, &m, As, &m, &info); dcopy(&m, As, &inc, y, &inc); /* ATy = A'*y */ dgemv(TTRANS,&m,&n,&one,A,&m,y,&inc,&zero,ATy,&inc); //Update multipliers dcopy(&m, Ax, &inc, res1, &inc); dcopy(&n, ATy, &inc, res2, &inc); /* res1 = -b + res1 */ daxpy(&m, &mone, b, &inc, res1, &inc); /* res2 = s + res2 */ daxpy(&n, &one, s, &inc, res2, &inc); /* res2 = -c + res2 */ daxpy(&n, &mone, c, &inc, res2, &inc); /* lambda1 = alpha*beta1*res1+lambda1 */ temp = alpha*beta1; daxpy(&m, &temp, res1, &inc, lambda1, &inc); /* ATlambda1 = A'*lambda1 */ dgemv(TTRANS,&m,&n,&one,A,&m,lambda1,&inc,&zero,ATlambda1,&inc); /* lambda2 = alpha*beta2*res2+lambda2 */ temp = alpha*beta2; daxpy(&n, &temp, res2, &inc, lambda2, &inc); /* Alambda2 = 
A*lambda2 */ dgemv(NTRANS,&m,&n,&one,A,&m,lambda2,&inc,&zero,Alambda2,&inc); //Stats /* presi = ||Ax - b||/(1+||b||) */ temp = dnrm2(&m,res1,&inc); presi = temp/(1.0+bnorm); /* dresi = ||A'*y+s-c||/(1+||c||) */ temp = dnrm2(&n,res2,&inc); dresi = temp/(1.0+cnorm); pobj = ddot(&n, c, &inc, x, &inc); dobj = ddot(&m, b, &inc, y, &inc); dgap = fabs(pobj-dobj)/(1.0+fabs(pobj)+fabs(dobj)); /* tmpinv = 1.0./s */ for(j=0; j<n; j++) { *(tmpinv+j) = 1.0/(*(s+j));} /* tmpres = (mu/beta2)*A*tmpinv */ temp = mu/beta2; dgemv(NTRANS,&m,&n,&temp,A,&m,tmpinv,&inc,&zero,tmpres,&inc); /* tmpres = -b + tmpres */ daxpy(&n, &mone, b, &inc, tmpres, &inc); temp = dnrm2(&m,tmpres,&inc); lambdaRes1 = temp/bnorm; /* tmpinv = 1.0./x */ for(j=0; j<n; j++) { *(tmpinv+j) = 1.0/(*(x+j));} /* tmpinv = -mu*tmpinv */ temp = mone*mu; dscal(&n, &temp, tmpinv, &inc); /* tmpinv = c + tmpinv */ daxpy(&n, &one, c, &inc, tmpinv, &inc); /* tmpinv = ATlambda1 + tmpinv*/ daxpy(&n, &one, ATlambda1, &inc, tmpinv, &inc); temp = dnrm2(&n,tmpinv,&inc); lambdaRes2 = temp/cnorm; if (MAX(MAX(presi,dresi),dgap) < tol) { iter_all = iter_all + k; return; } } iter_all = iter_all + k; mu= mu*gamma; tol_admm = tol_admm*gamma*0.5; } return; }
The following is the MATLAB script I use to compile the MEX file.
function Installmex
%INSTALLMEX  Compile BLAS-BRADMM.cpp into the MEX file BRADMMw.
%   Links against MATLAB's bundled LAPACK/BLAS (libmwlapack / libmwblas) and,
%   when hasMKL is nonzero, additionally against Intel MKL. Adds
%   -largeArrayDims and -DBLAS64 on 64-bit MATLAB, and -silent on MATLAB
%   8.3 (R2014a) and later.

% src = pwd;
% sdet = 'src';
fname{1} = 'BLAS-BRADMM';
ofname{1} = 'BRADMMw';
fcc{1} = 'cpp';
hasMKL = 0;   % with MKL or not
details = 0;  % 1 if details of each command are to be printed

try
    % ispc does not appear in MATLAB 5.3
    pc = ispc;
    mac = ismac;
catch %#ok
    % if ispc fails, assume we are on a Windows PC if it's not unix
    pc = ~isunix;
    mac = 0;
end

% if (~pc) && (~mac)
%     mex -O -largeArrayDims -lmwlapack -lmwblas sfmult.cpp
%     mex -O -largeArrayDims -lmwlapack -lmwblas dfeast.cpp
%     return
% end

flags = '';
is64 = ~isempty(strfind(computer, '64'));
if (is64)
    % 64-bit MATLAB
    flags = '-largeArrayDims';
end

% MATLAB 8.3.0 now has a -silent option to keep 'mex' from burbling too much
if (~verLessThan('matlab', '8.3.0'))
    flags = ['-silent ' flags];
end

%---------------------------------------------------------------------------
% BLAS option
%---------------------------------------------------------------------------

% This is exceedingly ugly. The MATLAB mex command needs to be told where to
% find the LAPACK and BLAS libraries, which is a real portability nightmare.
if (pc)
    if (verLessThan('matlab', '6.5'))
        % MATLAB 6.1 and earlier: use the version supplied here
        lapack = 'lcc_lib/libmwlapack.lib';
    elseif (verLessThan('matlab', '7.5'))
        lapack = 'libmwlapack.lib';
    else
        lapack = 'libmwlapack.lib libmwblas.lib';
    end
else
    if (verLessThan('matlab', '7.5'))
        lapack = '-lmwlapack';
    else
        lapack = '-lmwlapack -lmwblas';
    end
end

if (is64 && ~verLessThan('matlab', '7.8'))
    % versions 7.8 and later on 64-bit platforms use a 64-bit BLAS
    fprintf('with 64-bit BLAS\n');
    flags = [flags ' -DBLAS64'];
end

if (~(pc || mac))
    % for POSIX timing routine
    lapack = [lapack ' -lrt'];
end

include = '';
mkl = '';
if hasMKL
    % NOTE(review): MKLHOMEINCLUDE and MKLHOMELIB are not defined anywhere in
    % this script; they must be set (e.g. from the MKL installation path)
    % before enabling hasMKL, or this branch errors.
    % NOTE(review): the C source passes ptrdiff_t (64-bit) integers to
    % BLAS/LAPACK under -largeArrayDims, whereas MKL's *lp64* interface
    % expects 32-bit integers -- the ilp64 interface is presumably required;
    % verify before enabling.
    include = ['-I', MKLHOMEINCLUDE];
    if mac
        mkl = ['', MKLHOMELIB, filesep, 'libmkl_intel_lp64.dylib '];
        mkl = [mkl, '', MKLHOMELIB, filesep, 'libmkl_core.dylib '];
        mkl = [mkl, '', MKLHOMELIB, filesep, 'libmkl_intel_thread.dylib -ldl -lm '];
    else
        mkl = ['', MKLHOMELIB, filesep, 'libmkl_intel_lp64.so '];
        mkl = [mkl, '', MKLHOMELIB, filesep, 'libmkl_core.so '];
        mkl = [mkl, '', MKLHOMELIB, filesep, 'libmkl_intel_thread.so -ldl -lm -lrt'];
    end
end

if (verLessThan('matlab', '7.0'))
    % do not attempt to compile CHOLMOD with large file support
    include = [include ' -DNLARGEFILE'];
elseif (~pc)
    % Linux/Unix require these flags for large file support
    include = [include ' -D_FILE_OFFSET_BITS=64 -D_LARGEFILE64_SOURCE'];
end

if (verLessThan('matlab', '6.5'))
    % logical class does not exist in MATLAB 6.1 or earlier
    include = [include ' -DMATLAB6p1_OR_EARLIER'];
end

% compile each mexFunction
for k = 1:length(fname)
    s = sprintf('mex %s -DDLONG -O %s %s.%s -output %s', flags, ...
        include, fname{k}, fcc{k}, ofname{k});
    % Separate each part with a space: plain concatenation glued the library
    % flags onto the -output name (e.g. "BRADMMw-lmwlapack") and, with MKL on
    % Linux, glued "-lrt" onto the MKL library path.
    s = [s ' ' lapack ' ' mkl]; %#ok
    cmd(s, details);
end

%------------------------------------------------------------------------------
function cmd(s, details)
%CMD  Evaluate a build command string, printing it first when details is nonzero.
if (details)
    fprintf('%s\n', s);
end
eval(s);