矩阵乘法cache优化

简介:

好文要转,太棒了~~~~~~~~~~~~~~~~~~~~~~~~~

题目地址:http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1113

昨晚为了优化这个题目弄到2点多,今天一早就写博,我真是太不蛋定了,哈哈。

做OJ的朋友都知道快速幂,我就不罗嗦了,我说的主要是矩阵乘法实现层面的优化。


最开始我的代码耗时1156ms,代码如下:

  1. void  mat_mul( int (*a)[MAXN], int (*b)[MAXN], int (*c)[MAXN], int n ) {  
  2.         int  i, j, k;  
  3.         ULL  sum;  
  4.         for( i = 0; i < n; ++i ) {  
  5.                 for( j = 0; j < n; ++j ) {  
  6.                         sum = 0;  
  7.                         for( k = 0; k < n; ++k ) sum = ( sum + (ULL)a[i][k] * (ULL)b[k][j] ) % P;  
  8.                         c[i][j] = sum;  
  9.                 }  
  10.         }  
  11. }  
因为结果需要取模,所以为了节省内存,我用了int数组存储矩阵(如果用long long,内存增加一倍)。

我脑海中就是模拟笔算,每次确定一个c的一个位置,然后用a对应的行和b对应的列去累加乘积。

用一个local variable sum来存储和,最后赋值给c[i][j],这样sum应该会被优化成一个寄存器。

代码的问题是:1 每次计算都要取模,而且是在最内层循环,是程序运算量最大的地方;2 对b的取值是按列取的,增大了cpu cache的miss率,我们都知道,按照顺序读取内存是最有效率的。


我想,也许数据没有那么强,可以取个巧,累计时不取模(假设sum一直不会溢出),最后才取模,这样取模操作就由O(n^3),降为O(n^2)。

结果得到了一个WA,也就是说,数据累加和超过了unsigned long long的最大范围(2^64-1,大约是18*10^18)代码如下:

  1. void  mat_mul( int (*a)[MAXN], int (*b)[MAXN], int (*c)[MAXN], int n ) {  
  2.         int  i, j, k;  
  3.         ULL  sum;  
  4.         for( i = 0; i < n; ++i ) {  
  5.                 for( j = 0; j < n; ++j ) {  
  6.                         sum = 0;  
  7.                         for( k = 0; k < n; ++k ) sum += (ULL)a[i][k] * (ULL)b[k][j];  
  8.                         c[i][j] = sum % P;  
  9.                 }  
  10.         }  
  11. }  

接下去我又做了一些尝试,包括:1 输出部分之前有个if判断,将这个分支拆开;2 尝试将unsigned long long换成long long;3 和排名靠前的网友对比代码(比我快100ms左右),暂时只发现他是用long long存储矩阵的。

结果时间没有变化,甚至3还导致我的内存增大一倍。我想哭了,同样都是3重循环,做人的差距咋就这么大捏?

于是我又认真的研究了锋巅网友的代码,为啥就比我快,终于让我发现了,原来他循环的顺序和我有区别。

照着修改得到了一个828ms的代码:

  1. void  mat_mul( int (*a)[MAXN], int (*b)[MAXN], int (*c)[MAXN], int n ) {  
  2.         int  i, j, k, *p1, *p2, *end;  
  3.         ULL  tmp;  
  4.         memset( c, 0, sizeof( A[0] ) );  
  5.         for( i = 0; i < n; ++i ) {  
  6.                 for( k = 0; k < n; ++k ) {  
  7.                         tmp = a[i][k];  
  8.                         for( p1 = c[i], p2 = b[k], end = p1 + n; p1 != end; ++p1, ++p2 )  
  9.                                 *p1 = ( *p1 + tmp * (*p2) ) % P;  // c[i][j] = ( c[i][j] + a[i][k] * b[k][j] ) % P;  
  10.                 }  
  11.         }  
  12. }  
这份代码将k循环移动到了中间,最内层变成了j循环(我用指针改写了,含义就是注释的那句,效率应该不会比二维引用的效率差,这个有待确定)。

这个代码的意义在于:在最内层循环,对于c和b的访问都是顺序的了,而这个循环中a[i][k]不变,这样就更好的利用了cpu cache。矩阵越大,这个加速效果越明显。

以后的实际工程,应该也会用到这个思路。


最后就是对取模的优化,既然全部累加不行,那我就部分累加,然后取一次模,这样终究可以减少取模这种最耗时的操作。

分析数据,假设a和b矩阵的数据都接近最大可能值(对于P取模,最大值是P-1,P的值大约是10^9),那么一次乘积就是10^18,那么一个unsigned long long大约可以放18个累加。我取了16个,每累加16次(取16一方面是因为已经比较接近18了,另一方面是可以很好地利用位操作),取一次模,这样取模次数大约变成原来的1/16,当然判断16这个次数又增加了分支,不过这个相对于取模的优化,已经几乎可以忽略了。耗时484ms,代码如下(有点ugly):

  1. void  mat_mul( int  (*a)[MAXN], int (*b)[MAXN], int  (*c)[MAXN], int n ) {  
  2.         int  i, j, k, L, *p2;  
  3.         ULL  tmp[MAXN], con;  
  4.         //memset( c, 0, sizeof( A[0] ) );  
  5.         for( i = 0; i < n; ++i ) {  
  6.                 memset( tmp, 0, sizeof( tmp ) );  
  7.                 for( k = 0, L = (n & ~15); k < L; ++k ) {  
  8.                         con = a[i][k];  
  9.                         for( j = 0, p2 = b[k]; j < n; ++j, ++p2 )  
  10.                                 tmp[j] += con * (*p2);  
  11.                         if( ( k & 15 ) == 15 ) {  
  12.                                 for( j = 0; j < n; ++j ) tmp[j] %= P;  
  13.                         }  
  14.                 }  
  15.                   
  16.                 for( ; k < n; ++k ) {  
  17.                         con = a[i][k];  
  18.                         for( j = 0, p2 = b[k]; j < n; ++j, ++p2 )  
  19.                                 tmp[j] += con * (*p2);  
  20.                 }  
  21.                  for( j = 0; j < n; ++j )  
  22.                         c[i][j] = tmp[j] % P;  
  23.         }  
  24. }  
代码变长了,L是16的整倍数,后面的循环是处理不足16剩下的。用了tmp数组来优化(按我的理解,栈变量会比全局变量的访问更快)。


当然,这个题还可以优化IO,因为输入输出量很大,但是这样带来的速度提升意义不大,就没再修改。


本来以为不用写了,没想到刚才又做了一个优化,竟然达到了265ms,既然如此,还是再写一下吧。。。

代码如下:

  1. void  mat_mul( int  (*a)[MAXN], int (*b)[MAXN], int  (*c)[MAXN], int n ) {  
  2.         int  i, j, k, L, *p2;  
  3.         ULL  tmp[MAXN], con;  
  4.         //memset( c, 0, sizeof( A[0] ) );  
  5.         for( i = 0; i < n; ++i ) {  
  6.                 memset( tmp, 0, sizeof( tmp ) );  
  7.                 for( k = 0, L = (n & ~15); k < L; ) {  
  8.   
  9. #define  OP  do { for( con = a[i][k], p2 = b[k], j = 0; j < n; ++j, ++p2 ) \  
  10.                         tmp[j] += con * (*p2); \  
  11.                     ++k; } while(0)  
  12.                         OP; OP; OP; OP;  
  13.                         OP; OP; OP; OP;  
  14.                         OP; OP; OP; OP;  
  15.                         OP; OP; OP; OP;  
  16.                           
  17.                         for( j = 0; j < n; ++j ) tmp[j] %= P;  
  18.                 }  
  19.                   
  20.                 for( ; k < n; ) {  
  21.                         OP;  
  22.                 }  
  23.                  for( j = 0; j < n; ++j )  
  24.                         c[i][j] = tmp[j] % P;  
  25.         }  
  26. }  


这个代码相当于去掉了分支预测,手动将16个操作展开,没想到效果这么明显。
相关文章
|
1天前
|
缓存 并行计算 负载均衡
大模型推理优化实践:KV cache复用与投机采样
在本文中,我们将详细介绍两种在业务中实践的优化策略:多轮对话间的 KV cache 复用技术和投机采样方法。我们会细致探讨这些策略的应用场景、框架实现,并分享一些实现时的关键技巧。
|
2天前
|
前端开发
R语言用HESSIAN-FREE 、NELDER-MEAD优化方法对数据进行参数估计
R语言用HESSIAN-FREE 、NELDER-MEAD优化方法对数据进行参数估计
24 2
|
1月前
|
存储 缓存 算法
深入探究LRU缓存机制:优化内存利用与提升性能
深入探究LRU缓存机制:优化内存利用与提升性能
114 1
|
8月前
|
芯片 Anolis
性能优化特性之:TLBI - TLB range优化
本文介绍了倚天实例上的内存优化特性:TLBi,并从优化原理、使用方法进行详细阐述。
|
10月前
|
缓存 编译器 C++
C/C++性能提升之cache分析
C/C++性能提升之cache分析
241 0
|
11月前
|
机器学习/深度学习
Lesson 5.3 ROC-AUC 的计算方法、基本原理与核心特性
Lesson 5.3 ROC-AUC 的计算方法、基本原理与核心特性
|
11月前
|
机器学习/深度学习 移动开发 编解码
Skip-Attention:一种能显著降低Transformer计算量的模型轻量化方法
Skip-Attention:一种能显著降低Transformer计算量的模型轻量化方法
418 0
【5分钟+】计算机系统结构:CPU性能公式
【5分钟+】计算机系统结构:CPU性能公式
598 0
【5分钟+】计算机系统结构:CPU性能公式
|
算法 C语言 C++
非线性优化--NLopt算法使用及C++实例
非线性优化--NLopt算法使用及C++实例
非线性优化--NLopt算法使用及C++实例
|
算法
内存管理——页面置换算法计算缺页率
内存管理——页面置换算法计算缺页率
422 0