Tuesday, November 19, 2013

Java Implementation of Recursive Matrix Multiplication Algorithm (Part 2)

This new implementation of recursive matrix multiplication produced a few unexpected observations.
The recursive implementation (without concurrency) is faster than simple matrix multiplication (without recursion)!
I still do not understand the reason, but it may be related to memory swapping or some similarly low-level effect, such as CPU cache behavior.

As the results below show, at dim 4096 the recursive version is about 9 times faster than the simple version (115860 ms vs. 1023754 ms).

[simple_test_loop]
>> dim: 2, elapsed time: 0 milli sec
>> dim: 4, elapsed time: 0 milli sec
>> dim: 8, elapsed time: 0 milli sec
>> dim: 16, elapsed time: 0 milli sec
>> dim: 32, elapsed time: 1 milli sec
>> dim: 64, elapsed time: 6 milli sec
>> dim: 128, elapsed time: 6 milli sec
>> dim: 256, elapsed time: 24 milli sec
>> dim: 512, elapsed time: 206 milli sec
>> dim: 1024, elapsed time: 6478 milli sec
>> dim: 2048, elapsed time: 90698 milli sec
>> dim: 4096, elapsed time: 1023754 milli sec

[rec_matrix_test1]
>> dim: 2, elapsed time: 1 milli sec
>> dim: 4, elapsed time: 0 milli sec
>> dim: 8, elapsed time: 0 milli sec
>> dim: 16, elapsed time: 0 milli sec
>> dim: 32, elapsed time: 5 milli sec
>> dim: 64, elapsed time: 15 milli sec
>> dim: 128, elapsed time: 5 milli sec
>> dim: 256, elapsed time: 30 milli sec
>> dim: 512, elapsed time: 217 milli sec
>> dim: 1024, elapsed time: 1915 milli sec
>> dim: 2048, elapsed time: 14691 milli sec
>> dim: 4096, elapsed time: 115860 milli sec
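
One possible explanation for the gap, offered as a guess rather than something measured here, is memory access order. In SimpleMatrix.mult the inner loop reads m1[k][j] with k changing, so it walks down a column and touches a different row array on every step, while the recursive version works on small blocks. The following minimal sketch (the class and method names are hypothetical, not part of Matrix_v6) computes the same product with two loop orders; timing both on large matrices would be one way to test the guess.

```java
import java.util.Arrays;

final class LoopOrderSketch {
    // Textbook i-j-k order: the inner loop reads b[k][j] down a column,
    // so each step touches a different row array of b.
    static long[][] multiplyIjk(long[][] a, long[][] b) {
        int n = a.length;
        long[][] c = new long[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                long sum = 0;
                for (int k = 0; k < n; k++) {
                    sum += a[i][k] * b[k][j];
                }
                c[i][j] = sum;
            }
        }
        return c;
    }

    // i-k-j order: the inner loop walks b[k] and c[i] left to right,
    // i.e. along rows that are stored contiguously.
    static long[][] multiplyIkj(long[][] a, long[][] b) {
        int n = a.length;
        long[][] c = new long[n][n];
        for (int i = 0; i < n; i++) {
            long[] cRow = c[i];
            for (int k = 0; k < n; k++) {
                long aik = a[i][k];
                long[] bRow = b[k];
                for (int j = 0; j < n; j++) {
                    cRow[j] += aik * bRow[j];
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        long[][] a = {{1, 2}, {3, 4}};
        long[][] b = {{5, 6}, {7, 8}};
        // Both orders must produce the same product.
        System.out.println(Arrays.deepToString(multiplyIjk(a, b)));
        System.out.println(Arrays.deepToString(multiplyIkj(a, b)));
    }
}
```

Both methods perform the same dim^3 multiplications, so any timing difference between them would come from how memory is traversed, not from the amount of arithmetic.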

The first implementation was so slow that I thought there was no point in implementing a fast algorithm in Java. The new version, however, is surprisingly fast, and no GC is involved during the calculation because no intermediate objects are created. Compared to the previous version (results below), the dim 256 case is 13757/20 ≈ 688 times faster. Also, if we look closely at the time per unit of calculation, it is almost constant, so the running time grows roughly linearly with the number of multiplications (dim^3).
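
To illustrate what "no intermediate objects" means, here is a minimal sketch contrasting two ways of working with a quadrant of a matrix. The class and method names are hypothetical and the previous version is not shown here, so this only illustrates the general idea: copying a quadrant allocates a new array at every recursion step, while passing row and column offsets (as mult0 does) allocates nothing.

```java
final class QuadrantSketch {
    // Copying approach: allocates a new dim x dim array for the quadrant.
    static long[][] copyQuadrant(long[][] m, int rowOffset, int colOffset, int dim) {
        long[][] q = new long[dim][dim];
        for (int i = 0; i < dim; i++) {
            System.arraycopy(m[rowOffset + i], colOffset, q[i], 0, dim);
        }
        return q;
    }

    // Offset approach: reads the same quadrant in place, allocating nothing.
    static long sumQuadrantInPlace(long[][] m, int rowOffset, int colOffset, int dim) {
        long sum = 0;
        for (int i = 0; i < dim; i++) {
            for (int j = 0; j < dim; j++) {
                sum += m[rowOffset + i][colOffset + j];
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        long[][] m = {
            {1, 2, 3, 4},
            {5, 6, 7, 8},
            {9, 10, 11, 12},
            {13, 14, 15, 16}
        };
        // Bottom-right quadrant, addressed by offsets (2, 2) and size 2.
        System.out.println(sumQuadrantInPlace(m, 2, 2, 2));
    }
}
```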

[rec]
>> dim: 2, elapsed time: 3 milli sec
>> dim: 4, elapsed time: 0 milli sec
>> dim: 8, elapsed time: 2 milli sec
>> dim: 16, elapsed time: 32 milli sec
>> dim: 32, elapsed time: 378 milli sec
>> dim: 64, elapsed time: 494 milli sec
>> dim: 128, elapsed time: 1620 milli sec
>> dim: 256, elapsed time: 13757 milli sec

Further, if we introduce multiple threads (4 threads on 4 cores), the performance roughly quadruples. That may be a bit exaggerated, but the speedup is in the 3-4x range.
The number of threads should be close to the number of CPU cores to get maximum performance.
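
One way to keep the thread count close to the core count without hard-coding 4 is to ask the JVM. The sketch below is an illustration under assumptions, not the code used for the measurements: it uses a fixed thread pool sized by Runtime.getRuntime().availableProcessors() and splits the work by result row instead of by quadrant. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

final class PoolSizeSketch {
    // Multiplies two square matrices, computing each result row as one task
    // on a pool whose size matches the number of available cores.
    static long[][] multiply(final long[][] a, final long[][] b)
            throws InterruptedException, ExecutionException {
        final int n = a.length;
        final long[][] c = new long[n][n];
        int workers = Runtime.getRuntime().availableProcessors();
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        try {
            List<Future<?>> pending = new ArrayList<>();
            for (int row = 0; row < n; row++) {
                final int i = row;
                // Each task writes only to row i of c, so tasks never collide.
                pending.add(pool.submit(() -> {
                    for (int k = 0; k < n; k++) {
                        for (int j = 0; j < n; j++) {
                            c[i][j] += a[i][k] * b[k][j];
                        }
                    }
                }));
            }
            for (Future<?> f : pending) {
                f.get(); // wait for the row and surface any failure
            }
        } finally {
            pool.shutdown();
        }
        return c;
    }

    public static void main(String[] args) throws Exception {
        long[][] a = {{1, 2}, {3, 4}};
        long[][] b = {{5, 6}, {7, 8}};
        System.out.println(java.util.Arrays.deepToString(multiply(a, b)));
    }
}
```

A pool also avoids creating new Thread objects for every multiplication, although for large matrices that cost is small compared to the arithmetic.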

>> dim: 2, elapsed time: 0.0 milli sec, dim^p; 8, ration(time/dim^3)=0
>> dim: 4, elapsed time: 0.0 milli sec, dim^p; 64, ration(time/dim^3)=0
>> dim: 8, elapsed time: 1.0 milli sec, dim^p; 512, ration(time/dim^3)=1,953.125
>> dim: 16, elapsed time: 1.0 milli sec, dim^p; 4,096, ration(time/dim^3)=244.141
>> dim: 32, elapsed time: 3.0 milli sec, dim^p; 32,768, ration(time/dim^3)=91.553
>> dim: 64, elapsed time: 17.0 milli sec, dim^p; 262,144, ration(time/dim^3)=64.85
>> dim: 128, elapsed time: 4.0 milli sec, dim^p; 2,097,152, ration(time/dim^3)=1.907
>> dim: 256, elapsed time: 18.0 milli sec, dim^p; 16,777,216, ration(time/dim^3)=1.073
>> dim: 512, elapsed time: 105.0 milli sec, dim^p; 134,217,728, ration(time/dim^3)=0.782
>> dim: 1024, elapsed time: 602.0 milli sec, dim^p; 1,073,741,824, ration(time/dim^3)=0.561
>> dim: 2048, elapsed time: 4771.0 milli sec, dim^p; 8,589,934,592, ration(time/dim^3)=0.555
>> dim: 4096, elapsed time: 38704.0 milli sec, dim^p; 68,719,476,736, ration(time/dim^3)=0.563

dim^3 is the actual number of multiplications of long values, and the ratio column is time/dim^3 scaled by 10^6, i.e. nanoseconds per multiplication. So it took only 0.6 seconds to perform about 1 billion multiplications (dim 1024). Not bad. Even if we change long to double, the performance is almost the same.
And this time also includes the index calculations needed to traverse the matrix locations.
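
The per-multiplication figure can be reproduced with a few lines of arithmetic. This is a small sketch with hypothetical names; it uses double instead of the BigDecimal chain in rec_test_loop, which is precise enough for values of this size.

```java
final class UnitCostSketch {
    // Nanoseconds per scalar multiplication: elapsed ms * 10^6 / dim^3.
    static double nanosPerMultiplication(double elapsedMillis, int dim) {
        double multiplications = (double) dim * dim * dim;
        return elapsedMillis * 1_000_000.0 / multiplications;
    }

    public static void main(String[] args) {
        // Elapsed times taken from the 4-thread run listed in the post.
        int[] dims = {1024, 2048, 4096};
        double[] millis = {602.0, 4771.0, 38704.0};
        for (int i = 0; i < dims.length; i++) {
            System.out.printf("dim %d: %.3f ns per multiplication%n",
                    dims[i], nanosPerMultiplication(millis[i], dims[i]));
        }
    }
}
```

The printed values can be compared against the ratio column of the output for dim 1024 to 4096.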

A matrix of dimension 1000 can express a network of 1000 nodes, so there should be practical applications.
This algorithm also makes it very easy to utilize more cores (4, 16, 256, ...), so it should scale as the number of cores increases in the future.
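
As a sketch of how the quadrant split could spread over more than 4 cores, the same recursion can be expressed with the fork/join framework so that every level of the recursion, not only the top one, produces parallel tasks. This is an assumption-laden illustration and not the code measured above: the class names and the threshold parameter are hypothetical, the dimension is assumed to be a power of two, and threshold must be at least 1.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

final class ForkJoinMultiplySketch {
    // Adds (block of a at aRow,aCol) x (block of b at bRow,bCol)
    // into the block of c at (aRow, bCol). All blocks are dim x dim.
    static final class BlockTask extends RecursiveAction {
        private final long[][] a, b, c;
        private final int dim, aRow, aCol, bRow, bCol, threshold;

        BlockTask(long[][] a, long[][] b, long[][] c, int dim,
                  int aRow, int aCol, int bRow, int bCol, int threshold) {
            this.a = a; this.b = b; this.c = c; this.dim = dim;
            this.aRow = aRow; this.aCol = aCol;
            this.bRow = bRow; this.bCol = bCol;
            this.threshold = threshold;
        }

        private BlockTask sub(int h, int ar, int ac, int br, int bc) {
            return new BlockTask(a, b, c, h, ar, ac, br, bc, threshold);
        }

        @Override
        protected void compute() {
            if (dim <= threshold) {
                // Small block: multiply sequentially.
                for (int i = 0; i < dim; i++) {
                    for (int k = 0; k < dim; k++) {
                        long v = a[aRow + i][aCol + k];
                        for (int j = 0; j < dim; j++) {
                            c[aRow + i][bCol + j] += v * b[bRow + k][bCol + j];
                        }
                    }
                }
                return;
            }
            int h = dim / 2;
            // The two k-halves add into the same result quadrants, so they
            // run one after the other; within each half the four result
            // quadrants are independent and run in parallel.
            for (int k = 0; k < 2; k++) {
                int ac = aCol + k * h;
                int br = bRow + k * h;
                invokeAll(
                        sub(h, aRow, ac, br, bCol),
                        sub(h, aRow, ac, br, bCol + h),
                        sub(h, aRow + h, ac, br, bCol),
                        sub(h, aRow + h, ac, br, bCol + h));
            }
        }
    }

    static long[][] multiply(long[][] a, long[][] b, int threshold) {
        int n = a.length;
        long[][] c = new long[n][n];
        ForkJoinPool.commonPool().invoke(
                new BlockTask(a, b, c, n, 0, 0, 0, 0, threshold));
        return c;
    }
}
```

A call such as ForkJoinMultiplySketch.multiply(a, b, 64) leaves the scheduling to the common pool, whose default size follows the number of available processors, so the same code would use 4, 16, or more cores without changes. The best threshold value is not known here and would need to be measured.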

At this point, I checked a paper on a Strassen algorithm implementation on an NVIDIA GPU.

The numbers were actually quite impressive.
At dim = 2048, it took only 48 milliseconds, so it is about 100 times faster than the 4-thread (4-core) version!

BTW, the following is my CPU spec. It is an Intel i5 with 4 cores, and the machine has 32 GB of RAM.

Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                4
On-line CPU(s) list:   0-3
Thread(s) per core:    1
Core(s) per socket:    4
Socket(s):             1
NUMA node(s):          1
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 58
Stepping:              9
CPU MHz:               1600.000
BogoMIPS:              6935.22
Virtualization:        VT-x
L1d cache:             32K                                                                                                    
L1i cache:             32K                                                                                                    
L2 cache:              256K                                                                                                   
L3 cache:              6144K                                                                                                  
NUMA node0 CPU(s):     0-3                       

The following code was used for this testing.

package org.zen;

import java.math.BigDecimal;
import java.text.DecimalFormat;
import java.util.Random;

public class Matrix_v6 {
 static Random rand = new Random(); 
 static long[][] _randomMatrix(int dim) {
  final long[][] matrix = new long[dim][dim];
  for (int i = 0; i < dim; i++) {
   for (int j = 0; j < dim; j++) {
    matrix[i][j] = rand.nextInt(50);
   }
  }
  return matrix;
 }
 static long[][] _zeroMatrix(int dim) {
  final long[][] matrix = new long[dim][dim];
  for (int i = 0; i < dim; i++) {
   for (int j = 0; j < dim; j++) {
    matrix[i][j] = 0;
   }
  }
  return matrix;
 }
 
 //
 // AbstractMatrix
 //
 public static abstract class AbstractMatrix {
  protected int _dim;
  protected long[][] _matrix;
  
  public AbstractMatrix(long[][] matrix) {
   this._dim = matrix.length;
   this._matrix = matrix;
  }
  
  public int dim() {
   return _dim;
  }

  public SimpleMatrix add(SimpleMatrix m) {
   long[][] m0 = _matrix;
   long[][] m1 = m._matrix;
   long[][] m2 = new long[_dim][_dim];
   for (int i = 0; i < _dim; i++) {
    for (int j = 0; j < _dim; j++) {
     m2[i][j] = m0[i][j]+m1[i][j];
    }
   }
   return new SimpleMatrix(m2);
  }

  public SimpleMatrix subtract(SimpleMatrix m) {
   long[][] m0 = _matrix;
   long[][] m1 = m._matrix;
   long[][] m2 = new long[_dim][_dim];
   for (int i = 0; i < _dim; i++) {
    for (int j = 0; j < _dim; j++) {
     m2[i][j] = m0[i][j]-m1[i][j];
    }
   }
   return new SimpleMatrix(m2);
  }
  
  public String toString() {
   return toString(_matrix);
  }
  
  public static String toString(long[][] matrix) {
   int dim = matrix.length;
   StringBuilder sb = new StringBuilder();
   sb.append("(\n");
   for (int i = 0; i < dim; i++) {
    for (int j = 0; j < dim; j++) {
     long v = matrix[i][j];
     if (j != 0) {
      sb.append(", ");
     }
     sb.append(v);
    }
    sb.append("\n");
   }
   sb.append(")\n");
   return sb.toString();
  }
  
  public boolean equals(Object obj) {
   if (obj instanceof SimpleMatrix) {
    SimpleMatrix m = (SimpleMatrix)obj;
    if (_dim != m._dim) {
     return false;
    }
    long[][]matrix1 = m._matrix;
    for (int i = 0; i < _dim; i++) {
     for (int j = 0; j < _dim; j++) {
      if (_matrix[i][j] != matrix1[i][j]) {
       return false;
      }
     }
    }
    return true;
   } else {
    return false;
   }
  }
 }
 
 //
 // SimpleMatrix
 //
 public static class SimpleMatrix extends AbstractMatrix {
  public SimpleMatrix(long[][] matrix) {
   super(matrix);
  }
  
  public SimpleMatrix mult(SimpleMatrix m) {
   long[][] m0 = _matrix;
   long[][] m1 = m._matrix;
   long[][] m2 = new long[_dim][_dim];
   for (int i = 0; i < _dim; i++) {
    for (int j = 0; j < _dim; j++) {
     long v = 0; // long, so the sum of long products is not silently narrowed to int
     for (int k = 0; k < _dim; k++) {
      v += m0[i][k]*m1[k][j];
     }
     m2[i][j] = v;
    }
   }
   return new SimpleMatrix(m2);
  }
 }

 //
 // RecMatrix
 //
 public static class RecMatrix extends AbstractMatrix {
  public RecMatrix(long[][] matrix) {
   super(matrix);
  }
  
  public static RecMatrix randomMatrix(int dim) {
   long[][] matrix = _randomMatrix(dim);
   return new RecMatrix(matrix);
  }

  public RecMatrix mult(RecMatrix m) {
   RecMatrix m2 = new RecMatrix(new long[_dim][_dim]);
   m2.mult0(_dim, _matrix, 0, 0, m._matrix, 0, 0);
   return m2;
  }
  public void mult0(final int dim, final long[][] m1, final int row_index1, final int column_index1, final long[][] m2, final int row_index2, final int column_index2) {
   if (dim == 2) {
    for (int i = 0; i < 2; i++) {
     int i1 = row_index1 | i;
     for (int j = 0; j < 2; j++) {
      int j2 = column_index2 | j;
      long v = 0;
      for (int k = 0; k < 2; k++) {
       int j1 = column_index1 | k;
       int i2 = row_index2 | k;
       v += m1[i1][j1]*m2[i2][j2];
      }
      _matrix[i1][j2] += v;
     }
    }
   } else if (dim == _dim) {
    Thread[] threads = new Thread[4];
    int idx = 0;
    final int dim0 = dim >> 1;
    for (int i = 0; i < 2; i++) {
     final int r_idx1 = (i == 0) ? row_index1 : (row_index1 | dim0);
     for (int j = 0; j < 2; j++) {
      final int c_idx2 = (j == 0) ? column_index2: (column_index2 | dim0);
      // One thread per result quadrant; each thread writes only to its own quadrant.
      threads[idx] = new Thread(new Runnable() {
       public void run() {
        for (int k = 0; k < 2; k++) {
         final int c_idx1 = (k == 0) ? column_index1: (column_index1 | dim0);
         final int r_idx2 = (k == 0) ? row_index2 : (row_index2 | dim0);
         mult0(dim0, m1, r_idx1, c_idx1, m2, r_idx2, c_idx2);
        }
       }
      });
      threads[idx++].start();
     }
    }
    for (int i = 0; i < threads.length; i++) {
     try {
      threads[i].join();
     } catch (InterruptedException e) {
      // Restore the interrupt flag instead of silently swallowing it.
      Thread.currentThread().interrupt();
     }
    }
   } else {
    final int dim0 = dim >> 1;
    for (int i = 0; i < 2; i++) {
     final int r_idx1 = (i == 0) ? row_index1 : (row_index1 | dim0);
     for (int j = 0; j < 2; j++) {
      final int c_idx2 = (j == 0) ? column_index2: (column_index2 | dim0);
      for (int k = 0; k < 2; k++) {
       final int c_idx1 = (k == 0) ? column_index1: (column_index1 | dim0);
       final int r_idx2 = (k == 0) ? row_index2 : (row_index2 | dim0);
       mult0(dim0, m1, r_idx1, c_idx1, m2, r_idx2, c_idx2);
      }
     }
    }
   }
  }
 }
 
 //
 // Tests
 //
 
 // simple_matrix_test1
 public static SimpleMatrix simple_matrix_test1(long[][] matrx1, long[][] matrx2, boolean show) {
  SimpleMatrix m1 = new SimpleMatrix(matrx1);
  SimpleMatrix m2 = new SimpleMatrix(matrx2);
  SimpleMatrix m3 = m1.mult(m2);
  if (show) {
   System.out.println("m1: "+m1);
   System.out.println("m2: "+m2);
   System.out.println("m3: "+m3);
  }
  return m3;
 }
 
 public static long simple_matrix_test1(int dim) {
  long[][] matrx1 = _randomMatrix(dim);
  long[][] matrx2 = _randomMatrix(dim);
  
  simple_matrix_test1(matrx1, matrx2, false);

  long end = System.currentTimeMillis();
  return end;
 }
 public static void simple_test_loop(int max) {
  int dim = 1;
  for (int i = 1; i <= max; i++) {
   dim = dim*2;
   long startTime = System.currentTimeMillis();
   long endTime = simple_matrix_test1(dim);
   System.out.println(">> dim: "+dim+", elapsed time: "+(endTime - startTime)+" milli sec");
  }
 }
 
 // rec_matrix_test1
 public static RecMatrix rec_matrix_test1(long[][] matrx1, long[][] matrx2, boolean show) {
  RecMatrix rm1 = new RecMatrix(matrx1);
  RecMatrix rm2 = new RecMatrix(matrx2);
  RecMatrix rm3 = rm1.mult(rm2);
  if (show) {
   System.out.println("m1: "+rm1);
   System.out.println("m2: "+rm2);
   System.out.println("m3: "+rm3);
  }
  return rm3;
 }
 
 public static long rec_matrix_test1(int dim) {
  long[][] matrx1 = _randomMatrix(dim);
  long[][] matrx2 = _randomMatrix(dim);
  
  rec_matrix_test1(matrx1, matrx2, false);

  long end = System.currentTimeMillis();
  return end;
 }
 
 private static String formatBigDecimal(BigDecimal bd){
     DecimalFormat df = new DecimalFormat();
     return df.format(bd);
 }
 
 public static void rec_test_loop(int max) {
  int dim = 1;
  for (int i = 1; i <= max; i++) {
   dim = dim*2;
   long startTime = System.currentTimeMillis();
   long endTime = rec_matrix_test1(dim);
   double duration = endTime - startTime;
   BigDecimal b_duration = new BigDecimal(duration);
   BigDecimal dim_p3=new BigDecimal(dim).multiply(new BigDecimal(dim)).multiply(new BigDecimal(dim));
   BigDecimal b_r = b_duration.divide(new BigDecimal(dim)).divide(new BigDecimal(dim)).divide(new BigDecimal(dim)).multiply(new BigDecimal(1000000));
   System.out.println(">> dim: "+dim+", elapsed time: "+duration+" milli sec, dim^p; "+formatBigDecimal(dim_p3)+", ration(time/dim^3)="+formatBigDecimal(b_r));
   //System.out.println(">> dim: "+dim+", elapsed time: "+duration+" milli sec, dim^3: "+dim_p3);
  }
 }
 
 // verify_matrix_test1
 public static void verify_matrix_test1(int dim, boolean show) {
  long[][] matrx1 = _randomMatrix(dim);
  long[][] matrx2 = _randomMatrix(dim);
  
  SimpleMatrix sm3 = simple_matrix_test1(matrx1, matrx2, show);
  RecMatrix rm3 = rec_matrix_test1(matrx1, matrx2, show);
  SimpleMatrix srm3 = new SimpleMatrix(rm3._matrix);
  if (!sm3.equals(srm3)) {
   System.out.println("!!verify_matrix_test1: not equals ");
   System.out.println("m1: "+new SimpleMatrix(matrx1));
   System.out.println("m2: "+new SimpleMatrix(matrx2));
   System.out.println("sm3: "+sm3);
   System.out.println("srm3:"+srm3);
   throw new RuntimeException();
  } else if (show) {
   System.out.println("m1: "+new SimpleMatrix(matrx1));
   System.out.println("m2: "+new SimpleMatrix(matrx2));
   System.out.println("m3: "+sm3);
  }
 }
 
 public static void verify_test_loop(int max, boolean show) {
  int dim = 1;
  for (int i = 1; i <= max; i++) {
   dim = dim*2;
   verify_matrix_test1(dim, show);
  }
 }
 
 public static void test1(int max) {
  /*
  System.out.println("verify_test_loop: "+max);
  verify_test_loop(max, false);
  
  System.out.println("simple_test_loop: "+max);
  simple_test_loop(max);
  */
  System.out.println("rec_test_loop: "+max);
  rec_test_loop(max);
 }
 
 public static void main(String[] args) {
  int max = 12;
  test1(max);
  System.out.println(">> done");
  //verify_matrix_test1(2, false);
  //verify_matrix_test1(4, false);
  //verify_test_loop(max, false);
  
  //simple_test_loop(max);
  
  //rec_matrix_test1(4);
  //rec_test_loop(max);
 }
}
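
One caveat about the measurements: both test loops take the start time before the random matrices are generated, so each figure includes generation time, and the small dimensions run before the JIT compiler has warmed up. A minimal sketch of a timing helper that measures only the work passed to it, after a few warm-up runs, could look like this (the class and method names are hypothetical, and the warm-up count is an arbitrary choice):

```java
final class TimingSketch {
    // Runs the work warmupRuns times untimed, then times one more run.
    // Returns the elapsed time of the final run in nanoseconds.
    static long timeNanos(Runnable work, int warmupRuns) {
        for (int i = 0; i < warmupRuns; i++) {
            work.run();
        }
        long start = System.nanoTime();
        work.run();
        return System.nanoTime() - start;
    }

    public static void main(String[] args) {
        final int n = 256;
        final long[][] a = new long[n][n];
        final long[][] b = new long[n][n];
        final long[][] c = new long[n][n];
        // The matrices are created outside the timed region.
        long nanos = timeNanos(() -> {
            for (int i = 0; i < n; i++) {
                for (int k = 0; k < n; k++) {
                    for (int j = 0; j < n; j++) {
                        c[i][j] += a[i][k] * b[k][j];
                    }
                }
            }
        }, 3);
        System.out.println("elapsed: " + (nanos / 1_000_000.0) + " ms");
    }
}
```

With Matrix_v6, the Runnable would wrap only the call to mult on matrices that were created beforehand.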
