Hi All!
I am trying to implement a linear programming algorithm in C++, using BLAS and LAPACK for the matrix operations. However, my C++ code is slower than the equivalent MATLAB code on large problems, and the gap grows as the problem size increases.
I wonder whether this is due to optimizations MATLAB uses when calling Intel MKL. Could someone explain why MATLAB sometimes outperforms C++ with BLAS/LAPACK? Is there any way to improve this C++ code, or any compiler options that would make it faster?
Thank you for your time!
The following is my simplified code.
#include <math.h>
#include <mex.h>
#include <string.h>
#include "blas.h"
#include "lapack.h"
/* Two-argument maximum / minimum, defined only if no earlier header has
   already provided them. Each argument is expanded twice, so do not pass
   expressions with side effects. */
#ifndef MAX
#define MAX(A, B) (((A) > (B)) ? (A) : (B))
#endif

#ifndef MIN
#define MIN(A, B) (((A) < (B)) ? (A) : (B))
#endif
void mexFunction( int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[] )
{
double *A, *b, *c, *Ac, *AAT, *peigATA;
double *x, *Ax, *tx, *y, *ATy, *s, *As, *lambda1, *lambda2, *ATlambda1, *Alambda2;
double *ppobj, *pdobj, *ppresi, *pdresi, *pdgap, *piter;
double *tmp, *tmpinv, *tmpres, *res1, *res2;
double pobj, dobj, presi, dresi, dgap, duration;
double gamma, alpha, beta_max, beta_min, tol, final_mu, final_tol_admm, feas_mul;
double temp, beta1 = 1.0, beta2 = 1.0, one = 1.0, mone = -1.0, zero = 0.0;
double mu, tol_admm, bnorm, cnorm, tau, eigATA, lambdaRes1, lambdaRes2;
ptrdiff_t i, j, m, n, m2, mn, feas_count, info, verbose, inc = 1;
ptrdiff_t k, outer, iter_all = 0, count_pbig = 0, count_dbig = 0, count1 = 0, count2 = 0, iter = 0;
char *NTRANS, *TTRANS, *uplo;
NTRANS = "N"; TTRANS = "T"; uplo = "U";
mu = final_mu/pow(gamma,(int)outer);
tol_admm = final_tol_admm/pow(gamma,(int)outer);
/* Ax = A*x */
dgemv(NTRANS,&m,&n,&one,A,&m,x,&inc,&zero,Ax,&inc);
bnorm = dnrm2(&m,b,&inc);
cnorm = dnrm2(&n,c,&inc);
/*prepare stats*/
/* AAT = A*A' */
dgemm(NTRANS,TTRANS,&m,&m,&n,&one,A,&m,A,&m,&zero,AAT,&m);
/* Ac = A*c */
dgemv(NTRANS,&m,&n,&one,A,&m,c,&inc,&zero,Ac,&inc);
/* Compute largest eigenvalue */
memcpy(mxGetPr(rhsAAT[0]),AAT,(m*m)*sizeof(double));
mexCallMATLAB(1,lhsAAT,1,rhsAAT,"eigs");
peigATA = mxGetPr(lhsAAT[0]);
eigATA = *peigATA;
/* Cholesky Factorization: AAT = U */
dpotrf(uplo, &m, AAT, &m, &info);
while (mu >= final_mu) {
k = 0;
iter = iter + 1;
presi = 10000; dresi = 10000; lambdaRes1 = 10000; lambdaRes2 = 10000;
while ((MAX(MAX(MAX(presi,dresi),lambdaRes1),lambdaRes2) > tol_admm)&&(k < round(pow(1/mu,0.5)))) {
k = k + 1;
tau = 0.99/(beta1*eigATA);
// Update x
for (i=0; i<2; i++) {
/* Ax = -b + Ax */
daxpy(&m, &mone, b, &inc, Ax, &inc);
/* Ax = beta1*Ax */
dscal(&m, &beta1, Ax, &inc);
/* Ax = lambda1 + Ax */
daxpy(&m, &one, lambda1, &inc, Ax, &inc);
/* tx = tau*A'*Ax */
dgemv(TTRANS,&m,&n,&tau,A,&m,Ax,&inc,&zero,tx,&inc);
/* tx = -x + tx */
daxpy(&n, &mone, x, &inc, tx, &inc);
/* tx = tau*c + tx */
daxpy(&n, &tau, c, &inc, tx, &inc);
/* x = (-tx+sqrt(tx.^2+4*mu*tau))/2 */
for(j=0; j<n; j++) {
*(x+j) = (-*(tx+j)+sqrt((*(tx+j))*(*(tx+j))+4.0*mu*tau))/2.0;
}
/* Ax = A*x */
dgemv(NTRANS,&m,&n,&one,A,&m,x,&inc,&zero,Ax,&inc);
}
// Update s
/* ATy = -c+ATy */
daxpy(&n, &mone, c, &inc, ATy, &inc);
/* ATy = beta2*ATy*/
dscal(&n, &beta2, ATy, &inc);
/* ATy = lambda2+ATy*/
daxpy(&n, &one, lambda2, &inc, ATy, &inc);
/* s = (-ts+sqrt(ts.^2+4*mu*beta2))/(2*beta2) */
for(j=0; j<n; j++) {
*(s+j) = (-*(ATy+j)+sqrt((*(ATy+j))*(*(ATy+j))+4.0*mu*beta2))/(2.0*beta2);
}
/* As = A*s */
dgemv(NTRANS,&m,&n,&one,A,&m,s,&inc,&zero,As,&inc);
//Update y
/* As = -Ac + As */
daxpy(&m, &mone, Ac, &inc, As, &inc);
/* As = -beta2*As */
temp = mone*beta2; dscal(&m, &temp, As, &inc);
/* As = b + As */
daxpy(&m, &one, b, &inc, As, &inc);
/* As = -Alambda2 + As */
daxpy(&m, &mone, Alambda2, &inc, As, &inc);
/* As = As/beta2 */
temp = 1.0/beta2; dscal(&m, &temp, As, &inc);
/* y = (AAT)^(-1)*As */
dpotrs(uplo, &m, &inc, AAT, &m, As, &m, &info);
dcopy(&m, As, &inc, y, &inc);
/* ATy = A'*y */
dgemv(TTRANS,&m,&n,&one,A,&m,y,&inc,&zero,ATy,&inc);
//Update multipliers
dcopy(&m, Ax, &inc, res1, &inc);
dcopy(&n, ATy, &inc, res2, &inc);
/* res1 = -b + res1 */
daxpy(&m, &mone, b, &inc, res1, &inc);
/* res2 = s + res2 */
daxpy(&n, &one, s, &inc, res2, &inc);
/* res2 = -c + res2 */
daxpy(&n, &mone, c, &inc, res2, &inc);
/* lambda1 = alpha*beta1*res1+lambda1 */
temp = alpha*beta1; daxpy(&m, &temp, res1, &inc, lambda1, &inc);
/* ATlambda1 = A'*lambda1 */
dgemv(TTRANS,&m,&n,&one,A,&m,lambda1,&inc,&zero,ATlambda1,&inc);
/* lambda2 = alpha*beta2*res2+lambda2 */
temp = alpha*beta2; daxpy(&n, &temp, res2, &inc, lambda2, &inc);
/* Alambda2 = A*lambda2 */
dgemv(NTRANS,&m,&n,&one,A,&m,lambda2,&inc,&zero,Alambda2,&inc);
//Stats
/* presi = ||Ax - b||/(1+||b||) */
temp = dnrm2(&m,res1,&inc);
presi = temp/(1.0+bnorm);
/* dresi = ||A'*y+s-c||/(1+||c||) */
temp = dnrm2(&n,res2,&inc);
dresi = temp/(1.0+cnorm);
pobj = ddot(&n, c, &inc, x, &inc);
dobj = ddot(&m, b, &inc, y, &inc);
dgap = fabs(pobj-dobj)/(1.0+fabs(pobj)+fabs(dobj));
/* tmpinv = 1.0./s */
for(j=0; j<n; j++) { *(tmpinv+j) = 1.0/(*(s+j));}
/* tmpres = (mu/beta2)*A*tmpinv */
temp = mu/beta2; dgemv(NTRANS,&m,&n,&temp,A,&m,tmpinv,&inc,&zero,tmpres,&inc);
/* tmpres = -b + tmpres */
daxpy(&n, &mone, b, &inc, tmpres, &inc);
temp = dnrm2(&m,tmpres,&inc); lambdaRes1 = temp/bnorm;
/* tmpinv = 1.0./x */
for(j=0; j<n; j++) { *(tmpinv+j) = 1.0/(*(x+j));}
/* tmpinv = -mu*tmpinv */
temp = mone*mu; dscal(&n, &temp, tmpinv, &inc);
/* tmpinv = c + tmpinv */
daxpy(&n, &one, c, &inc, tmpinv, &inc);
/* tmpinv = ATlambda1 + tmpinv*/
daxpy(&n, &one, ATlambda1, &inc, tmpinv, &inc);
temp = dnrm2(&n,tmpinv,&inc); lambdaRes2 = temp/cnorm;
if (MAX(MAX(presi,dresi),dgap) < tol) {
iter_all = iter_all + k;
return;
}
}
iter_all = iter_all + k;
mu= mu*gamma;
tol_admm = tol_admm*gamma*0.5;
}
return;
}
The following MATLAB function compiles the code:
function Installmex
%INSTALLMEX  Compile BLAS-BRADMM.cpp into the MEX file BRADMMw.
%   The MEX file is linked against MATLAB's own LAPACK/BLAS
%   (libmwlapack/libmwblas). Set hasMKL = 1 below to also append the Intel
%   MKL libraries; MKLHOMEINCLUDE and MKLHOMELIB are then referenced but are
%   not defined in this script, so they must resolve elsewhere (e.g. as
%   functions on the MATLAB path). Set details = 1 to print each mex
%   command before it is run.

fname{1} = 'BLAS-BRADMM'; ofname{1} = 'BRADMMw'; fcc{1} = 'cpp';
hasMKL = 0;    % 1 to also link Intel MKL
details = 0 ;  % 1 if details of each command are to be printed

try
    % ispc does not appear in MATLAB 5.3
    pc = ispc ;
    mac = ismac ;
catch %#ok
    % if ispc fails, assume we are on a Windows PC if it's not unix
    pc = ~isunix ;
    mac = 0 ;
end

flags = '' ;
is64 = ~isempty (strfind (computer, '64')) ;
if (is64)
    % 64-bit MATLAB
    flags = '-largeArrayDims' ;
end

% MATLAB 8.3.0 and later have a -silent option to keep 'mex' quiet
if (~verLessThan ('matlab', '8.3.0'))
    flags = ['-silent ' flags] ;
end

%---------------------------------------------------------------------------
% BLAS option
%---------------------------------------------------------------------------

% The mex command must be told where to find the LAPACK and BLAS
% libraries, and the right names differ by platform and MATLAB version.
if (pc)
    if (verLessThan ('matlab', '6.5'))
        % MATLAB 6.1 and earlier: use the version supplied here
        lapack = 'lcc_lib/libmwlapack.lib' ;
    elseif (verLessThan ('matlab', '7.5'))
        lapack = 'libmwlapack.lib' ;
    else
        lapack = 'libmwlapack.lib libmwblas.lib' ;
    end
else
    if (verLessThan ('matlab', '7.5'))
        lapack = '-lmwlapack' ;
    else
        lapack = '-lmwlapack -lmwblas' ;
    end
end

if (is64 && ~verLessThan ('matlab', '7.8'))
    % versions 7.8 and later on 64-bit platforms use a 64-bit BLAS
    fprintf ('with 64-bit BLAS\n') ;
    flags = [flags ' -DBLAS64'] ;
end

if (~(pc || mac))
    % for POSIX timing routine
    lapack = [lapack ' -lrt'] ;
end

include = '';
mkl = '';
if hasMKL
    % NOTE(review): BLAS-BRADMM.cpp passes ptrdiff_t (64-bit on 64-bit
    % platforms) integer arguments, whereas MKL's LP64 interface
    % (libmkl_intel_lp64) takes 32-bit integers and ILP64
    % (libmkl_intel_ilp64) takes 64-bit ones -- confirm which is intended.
    % Also note these libraries come after libmwblas on the link line;
    % confirm which BLAS implementation is actually resolved.
    include = ['-I',MKLHOMEINCLUDE];
    if mac
        mkl = ['', MKLHOMELIB,filesep,'libmkl_intel_lp64.dylib '];
        mkl = [mkl, '', MKLHOMELIB,filesep,'libmkl_core.dylib '];
        mkl = [mkl, '', MKLHOMELIB,filesep,'libmkl_intel_thread.dylib -ldl -lm '];
    else
        mkl = ['', MKLHOMELIB,filesep,'libmkl_intel_lp64.so '];
        mkl = [mkl, '', MKLHOMELIB,filesep,'libmkl_core.so '];
        mkl = [mkl, '', MKLHOMELIB,filesep,'libmkl_intel_thread.so -ldl -lm -lrt'];
    end
end

if (verLessThan ('matlab', '7.0'))
    % do not attempt to compile with large file support
    include = [include ' -DNLARGEFILE'] ;
elseif (~pc)
    % Linux/Unix require these flags for large file support
    include = [include ' -D_FILE_OFFSET_BITS=64 -D_LARGEFILE64_SOURCE'] ;
end

if (verLessThan ('matlab', '6.5'))
    % logical class does not exist in MATLAB 6.1 or earlier
    include = [include ' -DMATLAB6p1_OR_EARLIER'] ;
end

% compile each mexFunction
for k = 1:length(fname)
    s = sprintf ('mex %s -DDLONG -O %s %s.%s -output %s', flags, ...
        include, fname{k}, fcc{k}, ofname{k}) ;
    % The library lists need separating spaces: plain concatenation glued
    % them onto the -output name (e.g. "-output BRADMMw-lmwlapack") and
    % glued the MKL path onto "-lrt".
    s = [s ' ' lapack ' ' mkl] ;
    cmd (s, details) ;
end
%------------------------------------------------------------------------------
function cmd (s, details)
%CMD  Run a MATLAB command string, optionally echoing it first.
%   cmd(S, DETAILS) prints S on its own line when DETAILS is true/nonzero,
%   then evaluates S with EVAL inside this function's workspace.
echo_it = details;
if echo_it
    fprintf('%s\n', s);
end
eval(s);