Wednesday, November 20, 2013

Dart vs Java performance comparison with Recursive Matrix Multiplication (Part 3)

I ported the last Java implementation of recursive matrix multiplication to Dart to see the actual performance difference between the two languages.
Of course, the outcome of such a comparison depends on the type of application and the level of performance expected.
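The recursion that mult0 implements in the code below can be summarized as follows. This is a summary of that code, not a formula taken from the earlier parts. Each n x n matrix is split into four n/2 x n/2 blocks, and every result block accumulates two block products. The recursion stops at 2 x 2 blocks. M(n) denotes the number of scalar multiplications for an n x n product:

```latex
\begin{pmatrix} C_{00} & C_{01} \\ C_{10} & C_{11} \end{pmatrix}
\mathrel{+}=
\begin{pmatrix} A_{00} & A_{01} \\ A_{10} & A_{11} \end{pmatrix}
\begin{pmatrix} B_{00} & B_{01} \\ B_{10} & B_{11} \end{pmatrix},
\qquad
C_{ij} \mathrel{+}= A_{i0} B_{0j} + A_{i1} B_{1j}, \quad i, j \in \{0, 1\}

M(n) = 8\, M(n/2), \qquad M(2) = 2^3 = 8
\;\Longrightarrow\; M(n) = n^3
```

The recursive version therefore performs exactly dim^3 scalar multiply-adds, the same count as the simple triple loop. That is why time/dim^3 is a natural ratio to report. In the code, block offsets are always multiples of the current block size, so "offset | dim0" equals "offset + dim0".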

As we have seen, Java allows a considerably fast implementation of recursive matrix multiplication. It would also be interesting to compare its performance with C++.
In any case, the 4-thread Java version reaches roughly 1.7 billion integer multiplications per second (about 1 billion multiplications in 0.6 sec), which I would call close to assembler-level performance.
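That throughput figure can be checked against the 4-thread Java log below. The recursive algorithm performs dim^3 scalar multiply-adds, so throughput is dim^3 divided by the elapsed time. The helper name multiplyAddsPerSecond is only for illustration:

```dart
/// Scalar multiply-adds per second for one dim x dim matrix product,
/// given an elapsed time in milliseconds taken from the logs.
double multiplyAddsPerSecond(int dim, int elapsedMillis) {
  // The recursive (and the simple) algorithm performs dim^3 multiply-adds.
  final operations = dim * dim * dim;
  return operations / (elapsedMillis / 1000);
}
```

For example, multiplyAddsPerSecond(1024, 602) gives about 1.78e9 for the 4-thread Java run. multiplyAddsPerSecond(4096, 484160) gives about 1.42e8 for the Dart run.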

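Following are the measured results for Dart and, for comparison, for Java. In the Dart and 4-thread Java logs, "ration(time/dim^3)" is the elapsed time in milliseconds divided by dim^3 and multiplied by 10^6, i.e. nanoseconds per scalar multiply-add.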
Compared with the single-threaded Java version, Dart is about 4 times slower (484,160 ms vs 115,860 ms at dim 4096). Compared with the 4-thread (4-core) Java version, it is roughly 12 to 13 times slower (38,704 ms at dim 4096).
So a factor of 4 is, in a sense, close to Java. However, Dart does not allow memory to be shared among isolates, so even with isolates the matrix elements would have to be copied between them, which adds cost. Still, it would be interesting to measure the performance with isolates.
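Here is a minimal sketch of what an isolate version might look like. It is written for current Dart: Isolate.run requires Dart 2.19 or later, so it did not exist when this post was written. The sketch splits the result rows into bands, one per isolate, and uses the simple triple loop instead of the recursive mult0 to stay short. The function names are hypothetical, and the sketch has not been benchmarked. The point it illustrates is that each worker isolate receives its own copy of a and b through the closure it runs, which is the copying cost mentioned above.

```dart
import 'dart:isolate';
import 'dart:math';

/// Computes rows [start, end) of a * b for square matrices (simple loop).
List<List<int>> multiplyRows(
    List<List<int>> a, List<List<int>> b, int start, int end) {
  final n = b.length;
  return [
    for (var i = start; i < end; i++)
      List<int>.generate(n, (j) {
        var v = 0;
        for (var k = 0; k < n; k++) {
          v += a[i][k] * b[k][j];
        }
        return v;
      }),
  ];
}

/// Runs one band of rows in a new isolate. Kept as a separate function so
/// the closure captures only a, b, start and end, which get copied over.
Future<List<List<int>>> _rowBandInIsolate(
    List<List<int>> a, List<List<int>> b, int start, int end) {
  return Isolate.run(() => multiplyRows(a, b, start, end));
}

/// Multiplies square matrices a and b using up to [workers] isolates.
Future<List<List<int>>> parallelMultiply(
    List<List<int>> a, List<List<int>> b, int workers) async {
  final n = a.length;
  final band = (n + workers - 1) ~/ workers; // rows per isolate, rounded up
  final bands = <Future<List<List<int>>>>[];
  for (var start = 0; start < n; start += band) {
    bands.add(_rowBandInIsolate(a, b, start, min(start + band, n)));
  }
  final parts = await Future.wait(bands); // bands come back in order
  return [for (final part in parts) ...part];
}
```

Usage: final c = await parallelMultiply(a, b, 4); splits the work over four isolates.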
Although Dart has no built-in array type, if List is internally backed by an array (delegating to native code), this alone should not cause a major performance difference. Another likely reason is the representation of int: Dart's int is (as of this writing) arbitrary-precision, so a similar slowdown might appear in Java if long were replaced by BigInteger.
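For reference, Dart does offer contiguous, fixed-length typed arrays in dart:typed_data, such as Int64List. Below is a minimal, unbenchmarked sketch in current Dart of the simple multiplication over matrices stored row-major in a single flat Int64List. The name multiplyFlat is hypothetical. Int64List is available on the Dart VM but not when compiling to JavaScript, where Int32List would be the closest alternative.

```dart
import 'dart:typed_data';

/// Multiplies two n x n matrices stored row-major in flat Int64Lists,
/// where element (i, j) lives at index i * n + j.
Int64List multiplyFlat(Int64List a, Int64List b, int n) {
  final c = Int64List(n * n); // typed lists start zero-filled
  for (var i = 0; i < n; i++) {
    for (var j = 0; j < n; j++) {
      var v = 0;
      for (var k = 0; k < n; k++) {
        v += a[i * n + k] * b[k * n + j];
      }
      c[i * n + j] = v;
    }
  }
  return c;
}
```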

[Dart]
rec_test_loop: 12
>> dim: 2, elapsed time: 5 milli sec, dim^3; 8, ration(time/dim^3)=625000.0
>> dim: 4, elapsed time: 0 milli sec, dim^3; 64, ration(time/dim^3)=0.0
>> dim: 8, elapsed time: 1 milli sec, dim^3; 512, ration(time/dim^3)=1953.125
>> dim: 16, elapsed time: 32 milli sec, dim^3; 4096, ration(time/dim^3)=7812.5
>> dim: 32, elapsed time: 4 milli sec, dim^3; 32768, ration(time/dim^3)=122.0703125
>> dim: 64, elapsed time: 5 milli sec, dim^3; 262144, ration(time/dim^3)=19.073486328125
>> dim: 128, elapsed time: 16 milli sec, dim^3; 2097152, ration(time/dim^3)=7.62939453125
>> dim: 256, elapsed time: 119 milli sec, dim^3; 16777216, ration(time/dim^3)=7.092952728271484
>> dim: 512, elapsed time: 1239 milli sec, dim^3; 134217728, ration(time/dim^3)=9.231269359588623
>> dim: 1024, elapsed time: 8027 milli sec, dim^3; 1073741824, ration(time/dim^3)=7.475726306438446
>> dim: 2048, elapsed time: 60587 milli sec, dim^3; 8589934592, ration(time/dim^3)=7.0532551035285
>> dim: 4096, elapsed time: 484160 milli sec, dim^3; 68719476736, ration(time/dim^3)=7.045455276966095

[Java with single thread]
>> dim: 2, elapsed time: 1 milli sec
>> dim: 4, elapsed time: 0 milli sec
>> dim: 8, elapsed time: 0 milli sec
>> dim: 16, elapsed time: 0 milli sec
>> dim: 32, elapsed time: 5 milli sec
>> dim: 64, elapsed time: 15 milli sec
>> dim: 128, elapsed time: 5 milli sec
>> dim: 256, elapsed time: 30 milli sec
>> dim: 512, elapsed time: 217 milli sec
>> dim: 1024, elapsed time: 1915 milli sec
>> dim: 2048, elapsed time: 14691 milli sec
>> dim: 4096, elapsed time: 115860 milli sec

[Java with 4 threads (4 cores)]
>> dim: 2, elapsed time: 0.0 milli sec, dim^p; 8, ration(time/dim^3)=0
>> dim: 4, elapsed time: 0.0 milli sec, dim^p; 64, ration(time/dim^3)=0
>> dim: 8, elapsed time: 1.0 milli sec, dim^p; 512, ration(time/dim^3)=1,953.125
>> dim: 16, elapsed time: 1.0 milli sec, dim^p; 4,096, ration(time/dim^3)=244.141
>> dim: 32, elapsed time: 3.0 milli sec, dim^p; 32,768, ration(time/dim^3)=91.553
>> dim: 64, elapsed time: 17.0 milli sec, dim^p; 262,144, ration(time/dim^3)=64.85
>> dim: 128, elapsed time: 4.0 milli sec, dim^p; 2,097,152, ration(time/dim^3)=1.907
>> dim: 256, elapsed time: 18.0 milli sec, dim^p; 16,777,216, ration(time/dim^3)=1.073
>> dim: 512, elapsed time: 105.0 milli sec, dim^p; 134,217,728, ration(time/dim^3)=0.782
>> dim: 1024, elapsed time: 602.0 milli sec, dim^p; 1,073,741,824, ration(time/dim^3)=0.561
>> dim: 2048, elapsed time: 4771.0 milli sec, dim^p; 8,589,934,592, ration(time/dim^3)=0.555
>> dim: 4096, elapsed time: 38704.0 milli sec, dim^p; 68,719,476,736, ration(time/dim^3)=0.563
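The slowdown factors quoted earlier can be recomputed directly from the logged times. The maps below simply copy the millisecond values for the three largest dimensions from the logs above; the names are only for illustration.

```dart
/// Elapsed times in milliseconds, copied from the logs in this post.
const Map<int, int> dartMs = {1024: 8027, 2048: 60587, 4096: 484160};
const Map<int, int> java1Ms = {1024: 1915, 2048: 14691, 4096: 115860};
const Map<int, int> java4Ms = {1024: 602, 2048: 4771, 4096: 38704};

/// How many times slower the Dart run was than [javaMs] at [dim].
double slowdown(Map<int, int> javaMs, int dim) => dartMs[dim]! / javaMs[dim]!;

/// Prints the slowdown against both Java runs for every logged dimension.
void printSlowdowns() {
  for (final dim in dartMs.keys) {
    print('dim $dim: '
        '${slowdown(java1Ms, dim).toStringAsFixed(2)}x vs 1 thread, '
        '${slowdown(java4Ms, dim).toStringAsFixed(2)}x vs 4 threads');
  }
}
```

At dim 4096 this gives about 4.18 against single-threaded Java and about 12.51 against 4-thread Java. At dim 1024 the 4-thread factor is about 13.3.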

Here is another puzzling result: the simple (naive) matrix multiplication degrades badly in Dart. Up to dim 512 the simple version is faster than the recursive one, but from dim 1024 on it becomes slower, and at dim 4096 it is about 3.3 times slower (1,597,558 ms vs 484,160 ms).

[Simple Matrix]
>> dim: 2, elapsed time: 4 milli sec
>> dim: 4, elapsed time: 0 milli sec
>> dim: 8, elapsed time: 0 milli sec
>> dim: 16, elapsed time: 4 milli sec
>> dim: 32, elapsed time: 6 milli sec
>> dim: 64, elapsed time: 2 milli sec
>> dim: 128, elapsed time: 7 milli sec
>> dim: 256, elapsed time: 55 milli sec
>> dim: 512, elapsed time: 483 milli sec
>> dim: 1024, elapsed time: 12269 milli sec
>> dim: 2048, elapsed time: 118535 milli sec
>> dim: 4096, elapsed time: 1597558 milli sec

[Recursive Matrix]
>> dim: 2, elapsed time: 5 milli sec, dim^3; 8, ration(time/dim^3)=625000.0
>> dim: 4, elapsed time: 0 milli sec, dim^3; 64, ration(time/dim^3)=0.0
>> dim: 8, elapsed time: 1 milli sec, dim^3; 512, ration(time/dim^3)=1953.125
>> dim: 16, elapsed time: 32 milli sec, dim^3; 4096, ration(time/dim^3)=7812.5
>> dim: 32, elapsed time: 4 milli sec, dim^3; 32768, ration(time/dim^3)=122.0703125
>> dim: 64, elapsed time: 5 milli sec, dim^3; 262144, ration(time/dim^3)=19.073486328125
>> dim: 128, elapsed time: 16 milli sec, dim^3; 2097152, ration(time/dim^3)=7.62939453125
>> dim: 256, elapsed time: 119 milli sec, dim^3; 16777216, ration(time/dim^3)=7.092952728271484
>> dim: 512, elapsed time: 1239 milli sec, dim^3; 134217728, ration(time/dim^3)=9.231269359588623
>> dim: 1024, elapsed time: 8027 milli sec, dim^3; 1073741824, ration(time/dim^3)=7.475726306438446
>> dim: 2048, elapsed time: 60587 milli sec, dim^3; 8589934592, ration(time/dim^3)=7.0532551035285
>> dim: 4096, elapsed time: 484160 milli sec, dim^3; 68719476736, ration(time/dim^3)=7.045455276966095
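One plausible explanation (not verified here) for the collapse of the simple version beyond dim 512 concerns its inner loop. That loop reads m1[k][j] for successive k, touching a different row List on each step. Once the matrices no longer fit in cache, most of those reads may miss. A common remedy is to reorder the loops to i-k-j, so the innermost loop walks along rows of both m1 and the result. Here is a minimal sketch in current Dart, with a hypothetical function name multIkj:

```dart
/// Simple n x n multiplication with the loops ordered i-k-j, so the
/// innermost loop walks along a row of m1 and a row of the result.
List<List<int>> multIkj(List<List<int>> m0, List<List<int>> m1) {
  final dim = m0.length;
  final m2 = List<List<int>>.generate(dim, (_) => List<int>.filled(dim, 0));
  for (var i = 0; i < dim; i++) {
    final row0 = m0[i]; // row i of the left matrix
    final row2 = m2[i]; // row i of the result
    for (var k = 0; k < dim; k++) {
      final a = row0[k];
      final row1 = m1[k]; // row k of the right matrix
      for (var j = 0; j < dim; j++) {
        row2[j] += a * row1[j];
      }
    }
  }
  return m2;
}
```

It computes the same product as SimpleMatrix.mult. Only the traversal order of memory changes.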

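The following is the Dart code used for this benchmark. It was written for the Dart 1.x SDK of 2013; current null-safe Dart no longer accepts the List constructor with a length that is used here (new List(dim)).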

library rec_matrix_v6;

import "dart:math";

Random rand = new Random(); 

List<List<int>> _randomMatrix(int dim) {
  //final List<List<int>> matrix = new int[dim][dim];
  final List<List<int>> matrix = new List(dim);
  for (int i = 0; i < dim; i++) {
    matrix[i] = new List(dim);
    for (int j = 0; j < dim; j++) {
      matrix[i][j] = rand.nextInt(50);
    }
  }
  return matrix;
}
List<List<int>> _zeroMatrix(int dim) {
  final List<List<int>> matrix = new List(dim);
  for (int i = 0; i < dim; i++) {
    matrix[i] = new List(dim);
    for (int j = 0; j < dim; j++) {
      matrix[i][j] = 0;
    }
  }
  return matrix;
}

//
// AbstractMatrix
//
abstract class AbstractMatrix {
  int _dim;
  List<List<int>> _matrix;
  
  AbstractMatrix(List<List<int>> matrix) {
    this._dim = matrix.length;
    this._matrix = matrix;
  }
  
  int dim() {
    return _dim;
  }

  SimpleMatrix add(SimpleMatrix m) {
    List<List<int>> m0 = _matrix;
    List<List<int>> m1 = m._matrix;
    List<List<int>> m2 = new List(_dim);
    for (int i = 0; i < _dim; i++) {
      m2[i] = new List(_dim);
      for (int j = 0; j < _dim; j++) {
        m2[i][j] = m0[i][j]+m1[i][j];
      }
    }
    return new SimpleMatrix(m2);
  }

  SimpleMatrix subtract(SimpleMatrix m) {
    List<List<int>> m0 = _matrix;
    List<List<int>> m1 = m._matrix;
    List<List<int>> m2 = new List(_dim);
    for (int i = 0; i < _dim; i++) {
      m2[i] = new List(_dim);
      for (int j = 0; j < _dim; j++) {
        m2[i][j] = m0[i][j]-m1[i][j];
      }
    }
    return new SimpleMatrix(m2);
  }
  
  String toString() {
    return toString1(_matrix);
  }
  
  String toString1(List<List<int>> matrix) {
    int dim = matrix.length;
    StringBuffer sb = new StringBuffer();
    sb.write("(\n");
    for (int i = 0; i < dim; i++) {
      for (int j = 0; j < dim; j++) {
        int v = matrix[i][j];
        if (j != 0) {
          sb.write(", ");
        }
        sb.write(v);
      }
      sb.write("\n");
    }
    sb.write(")\n");
    return sb.toString();
  }
  
  bool equals(Object obj) {
    if (obj is SimpleMatrix) {
      SimpleMatrix m = obj;
      if (_dim != m._dim) {
        return false;
      }
      List<List<int>>matrix1 = m._matrix;
      for (int i = 0; i < _dim; i++) {
        for (int j = 0; j < _dim; j++) {
          if (_matrix[i][j] != matrix1[i][j]) {
            return false;
          }
        }
      }
      return true;
    } else {
      return false;
    }
  }
}

//
// SimpleMatrix
//
class SimpleMatrix extends AbstractMatrix {
  SimpleMatrix(List<List<int>> matrix): super(matrix) {}
  
  SimpleMatrix mult(SimpleMatrix m) {
    List<List<int>> m0 = _matrix;
    List<List<int>> m1 = m._matrix;
    List<List<int>> m2 = new List(_dim);
    for (int i = 0; i < _dim; i++) {
      m2[i] = new List(_dim);
      for (int j = 0; j < _dim; j++) {
        int v = 0;
        for (int k = 0; k < _dim; k++) {
          v += m0[i][k]*m1[k][j];
        }
        m2[i][j] = v;
      }
    }
    return new SimpleMatrix(m2);
  }
}

//
// RecMatrix
//
class RecMatrix extends AbstractMatrix {
  RecMatrix(List<List<int>> matrix): super(matrix) {}
  
  static RecMatrix randomMatrix(int dim) {
    List<List<int>> matrix = _randomMatrix(dim);
    return new RecMatrix(matrix);
  }

  RecMatrix mult(RecMatrix m) {
    RecMatrix m2 = new RecMatrix(_zeroMatrix(_dim));
    m2.mult0(_dim, _matrix, 0, 0, m._matrix, 0, 0);
    return m2;
  }
  
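  // Adds the product of two dim x dim blocks into the block of this matrix at
  // (row_index1, column_index2): the block of m1 whose top-left corner is
  // (row_index1, column_index1) times the block of m2 at (row_index2, column_index2).
  // Block offsets are always multiples of dim, so "offset | dim0" (and "offset | i"
  // at the 2x2 leaf) is the same as adding the smaller power of two.
  void mult0(final int dim, final List<List<int>> m1, final int row_index1, final int column_index1, final List<List<int>> m2, final int row_index2, final int column_index2) {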
    if (dim == 2) {
      for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
          int i1 = row_index1 | i;
          int j2 = column_index2 | j;
          int v = 0;
          for (int k = 0; k < 2; k++) {
            int j1 = column_index1 | k;
            int i2 = row_index2 | k;
            v += m1[i1][j1]*m2[i2][j2];
          }
          _matrix[i1][j2] += v;
        }
      }
    } else {
      final int dim0 = dim >> 1;
      for (int i = 0; i < 2; i++) {
        final int r_idx1 = (i == 0) ? row_index1 : (row_index1 | dim0);
        for (int j = 0; j < 2; j++) {
          final int c_idx2 = (j == 0) ? column_index2: (column_index2 | dim0);
          for (int k = 0; k < 2; k++) {
            final int c_idx1 = (k == 0) ? column_index1: (column_index1 | dim0);
            final int r_idx2 = (k == 0) ? row_index2 : (row_index2 | dim0);
            mult0(dim0, m1, r_idx1, c_idx1, m2, r_idx2, c_idx2);
          }
        }
      }
    }
  }
}

//
// Tests
//

// simple_matrix_test1
SimpleMatrix simple_matrix_test1(List<List<int>> matrx1, List<List<int>> matrx2, bool show) {
  SimpleMatrix m1 = new SimpleMatrix(matrx1);
  SimpleMatrix m2 = new SimpleMatrix(matrx2);
  SimpleMatrix m3 = m1.mult(m2);
  if (show) {
    print("m1: ${m1}");
    print("m2: ${m2}");
    print("m3: ${m3}");
  }
  return m3;
}

DateTime simple_matrix_test2(int dim) {
  List<List<int>> matrx1 = _randomMatrix(dim);
  List<List<int>> matrx2 = _randomMatrix(dim);
  
  simple_matrix_test1(matrx1, matrx2, false);

  return new DateTime.now();
}

void simple_test_loop(int max) {
  int dim = 1;
  for (int i = 1; i <= max; i++) {
    dim = dim*2;
    DateTime startTime = new DateTime.now();
    DateTime endTime = simple_matrix_test2(dim);
    Duration du = endTime.difference(startTime);
    print(">> dim: ${dim}, elapsed time: ${du.inMilliseconds} milli sec");
  }
}

// rec_matrix_test1
RecMatrix rec_matrix_test1(List<List<int>> matrx1, List<List<int>> matrx2, bool show) {
  RecMatrix rm1 = new RecMatrix(matrx1);
  RecMatrix rm2 = new RecMatrix(matrx2);
  RecMatrix rm3 = rm1.mult(rm2);
  if (show) {
    print("m1: ${rm1}");
    print("m2: ${rm2}");
    print("m3: ${rm3}");
  }
  return rm3;
}

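// Returns the end time. rec_test_loop takes the start time before calling this,
// so the reported time also includes generating the two random input matrices.
DateTime rec_matrix_test2(int dim) {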
  List<List<int>> matrx1 = _randomMatrix(dim);
  List<List<int>> matrx2 = _randomMatrix(dim);
  
  rec_matrix_test1(matrx1, matrx2, false);
  return new DateTime.now();
}

void rec_test_loop(int max) {
  int dim = 1;
  for (int i = 1; i <= max; i++) {
    dim = dim*2;
    DateTime startTime = new DateTime.now();
    DateTime endTime = rec_matrix_test2(dim);
    Duration duration = endTime.difference(startTime);
    int b_duration = duration.inMilliseconds;
    int dim_p3 = dim*dim*dim;
    double b_r = (b_duration/dim_p3)*1000000;
    print(">> dim: ${dim}, elapsed time: ${b_duration} milli sec, dim^3; ${dim_p3}, ration(time/dim^3)=${b_r}");
  }
}

// verify_matrix_test1
void verify_matrix_test1(int dim, bool show) {
  List<List<int>> matrx1 = _randomMatrix(dim);
  List<List<int>> matrx2 = _randomMatrix(dim);
  
  SimpleMatrix sm3 = simple_matrix_test1(matrx1, matrx2, show);
  RecMatrix rm3 = rec_matrix_test1(matrx1, matrx2, show);
  SimpleMatrix srm3 = new SimpleMatrix(rm3._matrix);
  if (!sm3.equals(srm3)) {
    print("!!verify_matrix_test1: not equals ");
    print("m1: ${new SimpleMatrix(matrx1)}");
    print("m2: ${new SimpleMatrix(matrx2)}");
    print("sm3: ${sm3}");
    print("srm3:${srm3}");
    throw new Exception();
  } else if (show) {
    print("m1: ${new SimpleMatrix(matrx1)}");
    print("m2: ${new SimpleMatrix(matrx2)}");
    print("m3: ${sm3}");
  }
}

void verify_test_loop(int max, bool show) {
  int dim = 1;
  for (int i = 1; i <= max; i++) {
    dim = dim*2;
    verify_matrix_test1(dim, show);
  }
}

void test1(int max) {
  /*
  print("verify_test_loop: ${max}");
  verify_test_loop(max, false);
  
  print("simple_test_loop: ${max}");
  simple_test_loop(max);
  */
  print("rec_test_loop: ${max}");
  rec_test_loop(max);
}

void main() {
  int max = 12;
  test1(max);
  print(">> done");
  //verify_matrix_test1(2, true);
  //verify_matrix_test1(4, true);
  //verify_test_loop(max, false);
  
  //simple_test_loop(max);
  
  //rec_matrix_test1(_randomMatrix(4), _randomMatrix(4), true);
  //rec_test_loop(max);
}
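In the benchmark above, the measured interval includes generating the two random input matrices. The first, small dimensions also pay for JIT compilation, which may explain noisy entries such as 32 ms at dim 16. Below is a minimal sketch of a timing helper in current Dart. It builds the inputs before starting a Stopwatch and does untimed warm-up runs first. timeMultiplyMillis and Multiply are hypothetical names, and multiply can be any function with the shown signature.

```dart
import 'dart:math';

/// Signature of the multiplication functions being timed (hypothetical).
typedef Multiply = List<List<int>> Function(
    List<List<int>> a, List<List<int>> b);

/// Times one call of [multiply] on random dim x dim inputs, in milliseconds.
int timeMultiplyMillis(int dim, Multiply multiply,
    {int warmups = 1, int? seed}) {
  final rand = Random(seed);
  List<List<int>> randomMatrix() => List<List<int>>.generate(
      dim, (_) => List<int>.generate(dim, (_) => rand.nextInt(50)));
  // Inputs are generated before the clock starts.
  final a = randomMatrix();
  final b = randomMatrix();
  // Untimed runs, so a JIT can compile the hot loops before measuring.
  for (var w = 0; w < warmups; w++) {
    multiply(a, b);
  }
  final stopwatch = Stopwatch()..start();
  multiply(a, b);
  stopwatch.stop();
  return stopwatch.elapsedMilliseconds;
}
```

Usage: timeMultiplyMillis(1024, myMultiply, warmups: 2) times a single multiplication after two warm-up runs.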

2 comments:

  1. Nice series of posts, but advice for the next time - don't include entire code in post. Use gists or something similar. Short snippets would be fine too. :)

    1. Thanks for the comment.
      I use git and Google Drive whenever I think it is useful.
      For this kind of algorithm discussion, checking the code is the most important part of the blog. Indeed, I did not put many comments in it, so your impression is understandable. In addition, since this is complete working code, it is easiest to test by just copying and pasting it.
