in modules/core/src/matmul.cpp [869:1462]
/*
 * Generalized matrix multiplication.
 *
 * The small-matrix fast paths below compute D = alpha*A*B + beta*C explicitly;
 * the general paths hand alpha, beta and flags to type-specific kernels
 * (GEMMSingleMul_*, GEMMBlockMul_*, GEMMStore_*) defined outside this function.
 *
 * matA, matB  Input matrices. Both must have the same type, one of CV_32FC1,
 *             CV_64FC1, CV_32FC2 or CV_64FC2 (asserted below).
 * alpha       Scale factor applied to the product.
 * matC        Optional addend. It is only read when beta != 0; when present it
 *             must have the same type as A and the size of D (or the transposed
 *             size of D when GEMM_3_T is set).
 * beta        Scale factor applied to C.
 * _matD       Output; (re)created with the product's size and the type of A.
 * flags       Combination of GEMM_1_T, GEMM_2_T and GEMM_3_T selecting which of
 *             A, B and C are used transposed.
 *
 * Shape or type mismatches are reported through CV_Assert.
 */
void cv::gemm( InputArray matA, InputArray matB, double alpha,
InputArray matC, double beta, OutputArray _matD, int flags )
{
// Optional accelerated back-ends, tried before the CPU code.
// NOTE(review): CV_OCL_RUN presumably returns from gemm() when the given call
// succeeds - confirm against the macro definition.
#ifdef HAVE_CLAMDBLAS
CV_OCL_RUN(ocl::haveAmdBlas() && matA.dims() <= 2 && matB.dims() <= 2 && matC.dims() <= 2 && _matD.isUMat() &&
matA.cols() > 20 && matA.rows() > 20 && matB.cols() > 20, // since it works incorrectly for small sizes
ocl_gemm_amdblas(matA, matB, alpha, matC, beta, _matD, flags))
#endif
#ifdef HAVE_OPENCL
CV_OCL_RUN(_matD.isUMat() && matA.dims() <= 2 && matB.dims() <= 2 && matC.dims() <= 2,
ocl_gemm(matA, matB, alpha, matC, beta, _matD, flags))
#endif
// Tile edge length (in elements) and tile area used by the blocked path.
const int block_lin_size = 128;
const int block_size = block_lin_size * block_lin_size;
// Stand-ins for C in the small-matrix fast paths when C is absent. Four entries
// are enough because those paths read at most index 3 of the zero array.
static double zero[] = {0,0,0,0};
static float zerof[] = {0,0,0,0};
// C is not even fetched when beta == 0, so it is treated as absent from here on.
Mat A = matA.getMat(), B = matB.getMat(), C = beta != 0 ? matC.getMat() : Mat();
Size a_size = A.size(), d_size;
// len is the inner (shared) dimension of the product.
int i, len = 0, type = A.type();
CV_Assert( type == B.type() && (type == CV_32FC1 || type == CV_64FC1 || type == CV_32FC2 || type == CV_64FC2) );
// Derive the size of D and the inner dimension from the transposition flags of
// A and B, and check that the inner dimensions agree.
switch( flags & (GEMM_1_T|GEMM_2_T) )
{
// Neither A nor B transposed: inner dimension is A's width and B's height.
case 0:
d_size = Size( B.cols, a_size.height );
len = B.rows;
CV_Assert( a_size.width == len );
break;
// A used transposed: inner dimension is A's height.
case 1:
d_size = Size( B.cols, a_size.width );
len = B.rows;
CV_Assert( a_size.height == len );
break;
// B used transposed: inner dimension is B's width.
case 2:
d_size = Size( B.rows, a_size.height );
len = B.cols;
CV_Assert( a_size.width == len );
break;
// Both A and B used transposed.
case 3:
d_size = Size( B.rows, a_size.width );
len = B.cols;
CV_Assert( a_size.height == len );
break;
}
// When C participates it must match D's type and size (transposed size if GEMM_3_T).
if( !C.empty() )
{
CV_Assert( C.type() == type &&
(((flags&GEMM_3_T) == 0 && C.rows == d_size.height && C.cols == d_size.width) ||
((flags&GEMM_3_T) != 0 && C.rows == d_size.width && C.cols == d_size.height)));
}
_matD.create( d_size.height, d_size.width, type );
Mat D = _matD.getMat();
// C shares D's buffer and has to be transposed: transpose it in place up front
// and drop the flag so the code below sees an untransposed C.
if( (flags & GEMM_3_T) != 0 && C.data == D.data )
{
transpose( C, C );
flags &= ~GEMM_3_T;
}
// Fast path: no transposition flags, inner dimension 2..4, and D has at least
// one dimension equal to the inner dimension. The products are fully unrolled.
if( flags == 0 && 2 <= len && len <= 4 && (len == d_size.width || len == d_size.height) )
{
if( type == CV_32F )
{
float* d = D.ptr<float>();
const float *a = A.ptr<float>(),
*b = B.ptr<float>(),
*c = (const float*)C.data;
// Row strides expressed in elements rather than bytes.
size_t d_step = D.step/sizeof(d[0]),
a_step = A.step/sizeof(a[0]),
b_step = B.step/sizeof(b[0]),
c_step = C.data ? C.step/sizeof(c[0]) : 0;
// No C: read zeros instead. c_step is 0 here, so c stays on the zero array.
if( !c )
c = zerof;
switch( len )
{
case 2:
// D has len columns: emit one complete row of D per iteration.
// All of B is read on every iteration, so B must not share D's buffer.
if( len == d_size.width && b != d )
{
for( i = 0; i < d_size.height; i++, d += d_step, a += a_step, c += c_step )
{
float t0 = a[0]*b[0] + a[1]*b[b_step];
float t1 = a[0]*b[1] + a[1]*b[b_step+1];
d[0] = (float)(t0*alpha + c[0]*beta);
d[1] = (float)(t1*alpha + c[1]*beta);
}
}
// Otherwise emit one complete column of D per iteration. All of A is read
// on every iteration, so A must not share D's buffer.
else if( a != d )
{
// c_step0 is the per-column advance of c.
int c_step0 = 1;
if( c == zerof )
{
// Keep c pinned on the zero array and make c[c_step*k] index inside it.
c_step0 = 0;
c_step = 1;
}
for( i = 0; i < d_size.width; i++, d++, b++, c += c_step0 )
{
float t0 = a[0]*b[0] + a[1]*b[b_step];
float t1 = a[a_step]*b[0] + a[a_step+1]*b[b_step];
d[0] = (float)(t0*alpha + c[0]*beta);
d[d_step] = (float)(t1*alpha + c[c_step]*beta);
}
}
// Neither variant is usable: leave the switch and use the generic code below.
else
break;
// Result fully computed by one of the two variants above.
return;
case 3:
// Row-at-a-time variant (see case 2).
if( len == d_size.width && b != d )
{
for( i = 0; i < d_size.height; i++, d += d_step, a += a_step, c += c_step )
{
float t0 = a[0]*b[0] + a[1]*b[b_step] + a[2]*b[b_step*2];
float t1 = a[0]*b[1] + a[1]*b[b_step+1] + a[2]*b[b_step*2+1];
float t2 = a[0]*b[2] + a[1]*b[b_step+2] + a[2]*b[b_step*2+2];
d[0] = (float)(t0*alpha + c[0]*beta);
d[1] = (float)(t1*alpha + c[1]*beta);
d[2] = (float)(t2*alpha + c[2]*beta);
}
}
// Column-at-a-time variant (see case 2).
else if( a != d )
{
int c_step0 = 1;
if( c == zerof )
{
c_step0 = 0;
c_step = 1;
}
for( i = 0; i < d_size.width; i++, d++, b++, c += c_step0 )
{
float t0 = a[0]*b[0] + a[1]*b[b_step] + a[2]*b[b_step*2];
float t1 = a[a_step]*b[0] + a[a_step+1]*b[b_step] + a[a_step+2]*b[b_step*2];
float t2 = a[a_step*2]*b[0] + a[a_step*2+1]*b[b_step] + a[a_step*2+2]*b[b_step*2];
d[0] = (float)(t0*alpha + c[0]*beta);
d[d_step] = (float)(t1*alpha + c[c_step]*beta);
d[d_step*2] = (float)(t2*alpha + c[c_step*2]*beta);
}
}
else
break;
return;
case 4:
// Row-at-a-time variant (see case 2).
if( len == d_size.width && b != d )
{
for( i = 0; i < d_size.height; i++, d += d_step, a += a_step, c += c_step )
{
float t0 = a[0]*b[0] + a[1]*b[b_step] + a[2]*b[b_step*2] + a[3]*b[b_step*3];
float t1 = a[0]*b[1] + a[1]*b[b_step+1] + a[2]*b[b_step*2+1] + a[3]*b[b_step*3+1];
float t2 = a[0]*b[2] + a[1]*b[b_step+2] + a[2]*b[b_step*2+2] + a[3]*b[b_step*3+2];
float t3 = a[0]*b[3] + a[1]*b[b_step+3] + a[2]*b[b_step*2+3] + a[3]*b[b_step*3+3];
d[0] = (float)(t0*alpha + c[0]*beta);
d[1] = (float)(t1*alpha + c[1]*beta);
d[2] = (float)(t2*alpha + c[2]*beta);
d[3] = (float)(t3*alpha + c[3]*beta);
}
}
// Column-at-a-time variant (see case 2).
// NOTE(review): len is 4 in this case, so `len <= 16` is always true, whereas
// the CV_64F twin of this branch tests `d_size.width <= 16` - confirm which
// condition is intended.
else if( len <= 16 && a != d )
{
int c_step0 = 1;
if( c == zerof )
{
c_step0 = 0;
c_step = 1;
}
for( i = 0; i < d_size.width; i++, d++, b++, c += c_step0 )
{
float t0 = a[0]*b[0] + a[1]*b[b_step] + a[2]*b[b_step*2] + a[3]*b[b_step*3];
float t1 = a[a_step]*b[0] + a[a_step+1]*b[b_step] +
a[a_step+2]*b[b_step*2] + a[a_step+3]*b[b_step*3];
float t2 = a[a_step*2]*b[0] + a[a_step*2+1]*b[b_step] +
a[a_step*2+2]*b[b_step*2] + a[a_step*2+3]*b[b_step*3];
float t3 = a[a_step*3]*b[0] + a[a_step*3+1]*b[b_step] +
a[a_step*3+2]*b[b_step*2] + a[a_step*3+3]*b[b_step*3];
d[0] = (float)(t0*alpha + c[0]*beta);
d[d_step] = (float)(t1*alpha + c[c_step]*beta);
d[d_step*2] = (float)(t2*alpha + c[c_step*2]*beta);
d[d_step*3] = (float)(t3*alpha + c[c_step*3]*beta);
}
}
else
break;
return;
}
}
// Same fast path for double precision; the structure mirrors the CV_32F code above.
if( type == CV_64F )
{
double* d = D.ptr<double>();
const double *a = A.ptr<double>(),
*b = B.ptr<double>(),
*c = (const double*)C.data;
// Row strides expressed in elements rather than bytes.
size_t d_step = D.step/sizeof(d[0]),
a_step = A.step/sizeof(a[0]),
b_step = B.step/sizeof(b[0]),
c_step = C.data ? C.step/sizeof(c[0]) : 0;
// No C: read zeros instead.
if( !c )
c = zero;
switch( len )
{
case 2:
// One row of D per iteration; B must not share D's buffer.
if( len == d_size.width && b != d )
{
for( i = 0; i < d_size.height; i++, d += d_step, a += a_step, c += c_step )
{
double t0 = a[0]*b[0] + a[1]*b[b_step];
double t1 = a[0]*b[1] + a[1]*b[b_step+1];
d[0] = t0*alpha + c[0]*beta;
d[1] = t1*alpha + c[1]*beta;
}
}
// One column of D per iteration; A must not share D's buffer.
else if( a != d )
{
int c_step0 = 1;
if( c == zero )
{
// Keep c pinned on the zero array and make c[c_step*k] index inside it.
c_step0 = 0;
c_step = 1;
}
for( i = 0; i < d_size.width; i++, d++, b++, c += c_step0 )
{
double t0 = a[0]*b[0] + a[1]*b[b_step];
double t1 = a[a_step]*b[0] + a[a_step+1]*b[b_step];
d[0] = t0*alpha + c[0]*beta;
d[d_step] = t1*alpha + c[c_step]*beta;
}
}
// Neither variant is usable: fall through to the generic code below.
else
break;
return;
case 3:
if( len == d_size.width && b != d )
{
for( i = 0; i < d_size.height; i++, d += d_step, a += a_step, c += c_step )
{
double t0 = a[0]*b[0] + a[1]*b[b_step] + a[2]*b[b_step*2];
double t1 = a[0]*b[1] + a[1]*b[b_step+1] + a[2]*b[b_step*2+1];
double t2 = a[0]*b[2] + a[1]*b[b_step+2] + a[2]*b[b_step*2+2];
d[0] = t0*alpha + c[0]*beta;
d[1] = t1*alpha + c[1]*beta;
d[2] = t2*alpha + c[2]*beta;
}
}
else if( a != d )
{
int c_step0 = 1;
if( c == zero )
{
c_step0 = 0;
c_step = 1;
}
for( i = 0; i < d_size.width; i++, d++, b++, c += c_step0 )
{
double t0 = a[0]*b[0] + a[1]*b[b_step] + a[2]*b[b_step*2];
double t1 = a[a_step]*b[0] + a[a_step+1]*b[b_step] + a[a_step+2]*b[b_step*2];
double t2 = a[a_step*2]*b[0] + a[a_step*2+1]*b[b_step] + a[a_step*2+2]*b[b_step*2];
d[0] = t0*alpha + c[0]*beta;
d[d_step] = t1*alpha + c[c_step]*beta;
d[d_step*2] = t2*alpha + c[c_step*2]*beta;
}
}
else
break;
return;
case 4:
if( len == d_size.width && b != d )
{
for( i = 0; i < d_size.height; i++, d += d_step, a += a_step, c += c_step )
{
double t0 = a[0]*b[0] + a[1]*b[b_step] + a[2]*b[b_step*2] + a[3]*b[b_step*3];
double t1 = a[0]*b[1] + a[1]*b[b_step+1] + a[2]*b[b_step*2+1] + a[3]*b[b_step*3+1];
double t2 = a[0]*b[2] + a[1]*b[b_step+2] + a[2]*b[b_step*2+2] + a[3]*b[b_step*3+2];
double t3 = a[0]*b[3] + a[1]*b[b_step+3] + a[2]*b[b_step*2+3] + a[3]*b[b_step*3+3];
d[0] = t0*alpha + c[0]*beta;
d[1] = t1*alpha + c[1]*beta;
d[2] = t2*alpha + c[2]*beta;
d[3] = t3*alpha + c[3]*beta;
}
}
// Column variant, additionally limited to outputs at most 16 columns wide.
// NOTE(review): the CV_32F twin of this branch tests `len <= 16` instead -
// confirm which condition is intended.
else if( d_size.width <= 16 && a != d )
{
int c_step0 = 1;
if( c == zero )
{
c_step0 = 0;
c_step = 1;
}
for( i = 0; i < d_size.width; i++, d++, b++, c += c_step0 )
{
double t0 = a[0]*b[0] + a[1]*b[b_step] + a[2]*b[b_step*2] + a[3]*b[b_step*3];
double t1 = a[a_step]*b[0] + a[a_step+1]*b[b_step] +
a[a_step+2]*b[b_step*2] + a[a_step+3]*b[b_step*3];
double t2 = a[a_step*2]*b[0] + a[a_step*2+1]*b[b_step] +
a[a_step*2+2]*b[b_step*2] + a[a_step*2+3]*b[b_step*3];
double t3 = a[a_step*3]*b[0] + a[a_step*3+1]*b[b_step] +
a[a_step*3+2]*b[b_step*2] + a[a_step*3+3]*b[b_step*3];
d[0] = t0*alpha + c[0]*beta;
d[d_step] = t1*alpha + c[c_step]*beta;
d[d_step*2] = t2*alpha + c[c_step*2]*beta;
d[d_step*3] = t3*alpha + c[c_step*3]*beta;
}
}
else
break;
return;
}
}
}
// Generic path: reached when the fast path did not apply or declined the input.
{
size_t b_step = B.step;
GEMMSingleMulFunc singleMulFunc;
GEMMBlockMulFunc blockMulFunc;
GEMMStoreFunc storeFunc;
// matD is where results are written: D itself, or tmat when D aliases an input.
Mat *matD = &D, tmat;
size_t tmat_size = 0;
const uchar* Cdata = C.data;
size_t Cstep = C.data ? (size_t)C.step : 0;
// Single scratch allocation shared by all temporary buffers below.
AutoBuffer<uchar> buf;
// Pick the kernel triple that matches the element type.
if( type == CV_32FC1 )
{
singleMulFunc = (GEMMSingleMulFunc)GEMMSingleMul_32f;
blockMulFunc = (GEMMBlockMulFunc)GEMMBlockMul_32f;
storeFunc = (GEMMStoreFunc)GEMMStore_32f;
}
else if( type == CV_64FC1 )
{
singleMulFunc = (GEMMSingleMulFunc)GEMMSingleMul_64f;
blockMulFunc = (GEMMBlockMulFunc)GEMMBlockMul_64f;
storeFunc = (GEMMStoreFunc)GEMMStore_64f;
}
else if( type == CV_32FC2 )
{
singleMulFunc = (GEMMSingleMulFunc)GEMMSingleMul_32fc;
blockMulFunc = (GEMMBlockMulFunc)GEMMBlockMul_32fc;
storeFunc = (GEMMStoreFunc)GEMMStore_32fc;
}
else
{
CV_Assert( type == CV_64FC2 );
singleMulFunc = (GEMMSingleMulFunc)GEMMSingleMul_64fc;
blockMulFunc = (GEMMBlockMulFunc)GEMMBlockMul_64fc;
storeFunc = (GEMMStoreFunc)GEMMStore_64fc;
}
// D shares its buffer with A or B: compute into a temporary matrix and copy it
// into D at the very end.
if( D.data == A.data || D.data == B.data )
{
tmat_size = (size_t)d_size.width*d_size.height*CV_ELEM_SIZE(type);
// Allocate tmat later, once the size of buf is known
matD = &tmat;
}
// B is a contiguous single column (d_size.width == 1) or single row (len == 1)
// and is not flagged as transposed: re-describe it as transposed, with a row
// step of 0 for the single-column case and of one element otherwise.
if( (d_size.width == 1 || len == 1) && !(flags & GEMM_2_T) && B.isContinuous() )
{
b_step = d_size.width == 1 ? 0 : CV_ELEM_SIZE(type);
flags |= GEMM_2_T;
}
// The comment block below is disabled legacy code that dispatched to an external
// BLAS routine; only the `if` following its closing marker is live.
/*if( (d_size.width | d_size.height | len) >= 16 && icvBLAS_GEMM_32f_p != 0 )
{
blas_func = type == CV_32FC1 ? (icvBLAS_GEMM_32f_t)icvBLAS_GEMM_32f_p :
type == CV_64FC1 ? (icvBLAS_GEMM_32f_t)icvBLAS_GEMM_64f_p :
type == CV_32FC2 ? (icvBLAS_GEMM_32f_t)icvBLAS_GEMM_32fc_p :
type == CV_64FC2 ? (icvBLAS_GEMM_32f_t)icvBLAS_GEMM_64fc_p : 0;
}
if( blas_func )
{
const char* transa = flags & GEMM_1_T ? "t" : "n";
const char* transb = flags & GEMM_2_T ? "t" : "n";
int lda, ldb, ldd;
if( C->data.ptr )
{
if( C->data.ptr != D->data.ptr )
{
if( !(flags & GEMM_3_T) )
cvCopy( C, D );
else
cvTranspose( C, D );
}
}
if( CV_MAT_DEPTH(type) == CV_32F )
{
Complex32f _alpha, _beta;
lda = A->step/sizeof(float);
ldb = b_step/sizeof(float);
ldd = D->step/sizeof(float);
_alpha.re = (float)alpha;
_alpha.im = 0;
_beta.re = C->data.ptr ? (float)beta : 0;
_beta.im = 0;
if( CV_MAT_CN(type) == 2 )
lda /= 2, ldb /= 2, ldd /= 2;
blas_func( transb, transa, &d_size.width, &d_size.height, &len,
&_alpha, B->data.ptr, &ldb, A->data.ptr, &lda,
&_beta, D->data.ptr, &ldd );
}
else
{
CvComplex64f _alpha, _beta;
lda = A->step/sizeof(double);
ldb = b_step/sizeof(double);
ldd = D->step/sizeof(double);
_alpha.re = alpha;
_alpha.im = 0;
_beta.re = C->data.ptr ? beta : 0;
_beta.im = 0;
if( CV_MAT_CN(type) == 2 )
lda /= 2, ldb /= 2, ldd /= 2;
blas_func( transb, transa, &d_size.width, &d_size.height, &len,
&_alpha, B->data.ptr, &ldb, A->data.ptr, &lda,
&_beta, D->data.ptr, &ldd );
}
}
else*/ if( ((d_size.height <= block_lin_size/2 || d_size.width <= block_lin_size/2) &&
len <= 10000) || len <= 10 ||
(d_size.width <= block_lin_size &&
d_size.height <= block_lin_size && len <= block_lin_size) )
{
// Small enough problem: a single kernel call over the whole matrices, no tiling.
// Only the temporary output (if any) needs scratch memory here.
if( tmat_size > 0 ) {
buf.allocate(tmat_size);
tmat = Mat(d_size.height, d_size.width, type, (uchar*)buf );
}
singleMulFunc( A.ptr(), A.step, B.ptr(), b_step, Cdata, Cstep,
matD->ptr(), matD->step, a_size, d_size, alpha, beta, flags );
}
else
{
// Blocked path: D is produced tile by tile, iterating over i (rows of D),
// j (columns of D) and k (inner dimension).
// Remember the transposition state now; GEMM_1_T is cleared from flags below.
int is_a_t = flags & GEMM_1_T;
int is_b_t = flags & GEMM_2_T;
int elem_size = CV_ELEM_SIZE(type);
int dk0_1, dk0_2;
size_t a_buf_size = 0, b_buf_size, d_buf_size;
uchar* a_buf = 0;
uchar* b_buf = 0;
uchar* d_buf = 0;
// di/dj/dk are the extents of the current tile along i/j/k.
int j, k, di = 0, dj = 0, dk = 0;
// dm0/dn0/dk0 are the nominal tile extents along i/j/k.
int dm0, dn0, dk0;
// Byte strides: *_step0 advances along the first tile index of each operand
// (i for A and C, k for B), *_step1 along the second (k for A, j for B and C).
size_t a_step0, a_step1, b_step0, b_step1, c_step0, c_step1;
// Element size of the intermediate accumulation buffer: twice elem_size for
// CV_32F depth, elem_size otherwise.
// NOTE(review): presumably the block kernels accumulate 32-bit data in a wider
// type - confirm in GEMMBlockMul_*.
int work_elem_size = elem_size << (CV_MAT_DEPTH(type) == CV_32F ? 1 : 0);
if( !is_a_t )
a_step0 = A.step, a_step1 = elem_size;
else
a_step0 = elem_size, a_step1 = A.step;
if( !is_b_t )
b_step0 = b_step, b_step1 = elem_size;
else
b_step0 = elem_size, b_step1 = b_step;
if( C.empty() )
{
// No C: zero strides, and the C-transposition flag is meaningless.
c_step0 = c_step1 = 0;
flags &= ~GEMM_3_T;
}
else if( !(flags & GEMM_3_T) )
c_step0 = C.step, c_step1 = elem_size;
else
c_step0 = elem_size, c_step1 = C.step;
// Nominal tile extents: at most block_lin_size along i and j, and an inner
// extent bounded by block_size / dm0, block_size / dn0 and len.
dm0 = std::min( block_lin_size, d_size.height );
dn0 = std::min( block_lin_size, d_size.width );
dk0_1 = block_size / dm0;
dk0_2 = block_size / dn0;
dk0 = std::min( dk0_1, dk0_2 );
dk0 = std::min( dk0, len );
// Shrink dm0/dn0 if a tile would exceed block_size elements.
if( dk0*dm0 > block_size )
dm0 = block_size / dk0;
if( dk0*dn0 > block_size )
dn0 = block_size / dk0;
// dk0_1 is reused: dn0 plus 1/8 slack plus 2, rounded down to an even number.
// The slack leaves room for tiles that absorb a short remainder (see the loops).
dk0_1 = (dn0+dn0/8+2) & -2;
b_buf_size = (size_t)(dk0+dk0/8+1)*dk0_1*elem_size;
d_buf_size = (size_t)(dk0+dk0/8+1)*dk0_1*work_elem_size;
if( is_a_t )
{
// Tiles of A are transposed explicitly into a_buf inside the loop, so the
// kernels must no longer see GEMM_1_T.
a_buf_size = (size_t)(dm0+dm0/8+1)*((dk0+dk0/8+2)&-2)*elem_size;
flags &= ~GEMM_1_T;
}
// One allocation, carved up in this order: d_buf | b_buf | a_buf | tmat storage.
buf.allocate(d_buf_size + b_buf_size + a_buf_size + tmat_size);
d_buf = (uchar*)buf;
b_buf = d_buf + d_buf_size;
if( is_a_t )
a_buf = b_buf + b_buf_size;
if( tmat_size > 0 )
tmat = Mat(d_size.height, d_size.width, type, b_buf + b_buf_size + a_buf_size );
for( i = 0; i < d_size.height; i += di )
{
di = dm0;
// If this tile reaches the end, or what would remain after it is less than
// 1/8 of a tile, stretch this tile to cover the rest.
if( i + di >= d_size.height || 8*(i + di) + di > 8*d_size.height )
di = d_size.height - i;
for( j = 0; j < d_size.width; j += dj )
{
// Default destination: the tile's position inside the output matrix.
uchar* _d = matD->ptr() + i*matD->step + j*elem_size;
const uchar* _c = Cdata + i*c_step0 + j*c_step1;
size_t _d_step = matD->step;
dj = dn0;
// Same remainder-absorbing rule as for di.
if( j + dj >= d_size.width || 8*(j + dj) + dj > 8*d_size.width )
dj = d_size.width - j;
// Clear bit 16 (and anything above it) at the start of every output tile; it
// is set again after the first k-step below.
// NOTE(review): bit 16 presumably tells the block kernel to accumulate onto
// the existing partial result - confirm in GEMMBlockMul_*.
flags &= 15;
// The inner dimension needs several k-steps: accumulate into d_buf, whose
// rows are dj work elements long, instead of writing to the output directly.
if( dk0 < len )
{
_d = d_buf;
_d_step = dj*work_elem_size;
}
for( k = 0; k < len; k += dk )
{
const uchar* _a = A.ptr() + i*a_step0 + k*a_step1;
size_t _a_step = A.step;
const uchar* _b = B.ptr() + k*b_step0 + j*b_step1;
size_t _b_step = b_step;
Size a_bl_size;
dk = dk0;
// Same remainder-absorbing rule as for di.
if( k + dk >= len || 8*(k + dk) + dk > 8*len )
dk = len - k;
// Size of the A tile as it is stored in memory.
if( !is_a_t )
a_bl_size.width = dk, a_bl_size.height = di;
else
a_bl_size.width = di, a_bl_size.height = dk;
// A is used transposed: transpose this tile into a_buf and point the kernel
// at the untransposed copy.
if( a_buf && is_a_t )
{
_a_step = dk*elem_size;
GEMM_TransposeBlock( _a, A.step, a_buf, _a_step, a_bl_size, elem_size );
std::swap( a_bl_size.width, a_bl_size.height );
_a = a_buf;
}
// The tile does not span D's full width: copy the needed part of B into
// b_buf so its rows are tightly packed.
if( dj < d_size.width )
{
Size b_size;
if( !is_b_t )
b_size.width = dj, b_size.height = dk;
else
b_size.width = dk, b_size.height = dj;
_b_step = b_size.width*elem_size;
GEMM_CopyBlock( _b, b_step, b_buf, _b_step, b_size, elem_size );
_b = b_buf;
}
// Several k-steps: the block kernel adds this partial product into d_buf.
// Single k-step: the single-call kernel receives C, alpha and beta and writes
// the tile straight to the output.
if( dk0 < len )
blockMulFunc( _a, _a_step, _b, _b_step, _d, _d_step,
a_bl_size, Size(dj,di), flags );
else
singleMulFunc( _a, _a_step, _b, _b_step, _c, Cstep,
_d, _d_step, a_bl_size, Size(dj,di), alpha, beta, flags );
flags |= 16;
}
// After all k-steps, pass the accumulated tile together with C, alpha and beta
// to the store kernel, which writes the tile of the output matrix.
if( dk0 < len )
storeFunc( _c, Cstep, _d, _d_step,
matD->ptr(i) + j*elem_size,
matD->step, Size(dj,di), alpha, beta, flags );
}
}
}
// Results were computed into the temporary because D aliased an input: copy them over.
if( matD != &D )
matD->copyTo(D);
}
}