Tuesday, November 19, 2013

Java implementation of Recursive Matrix Multiplication Algorithm (Part 1)

I tried to build a test case that can be used to evaluate a language's capabilities.
Since I was also interested in the concurrent nature of the recursive matrix multiplication algorithm, I decided to use it as that sample.
The subject is also interesting in its own right, with a long history and a large body of research.

I first came across the algorithm in a paper about open source hardware:

http://www.adapteva.com/wp-content/uploads/2013/07/adapteva_ipdps.pdf

http://www.parallella.org/board/

Later, though, I recognized that the idea of recursive matrix multiplication is rooted in Strassen's fast matrix multiplication, which multiplies in O(n^(log_2 7)) ≈ O(n^2.81) time versus the ordinary O(n^3) (see the book Modern Computer Algebra).
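
For reference, here is a minimal sketch (mine, not from this post) of one level of Strassen's scheme on 2x2 int matrices, next to the ordinary eight-product formula. The class name Strassen2x2 is hypothetical. Applied recursively to half-size blocks instead of ints, seven products per level give T(n) = 7T(n/2) + O(n^2) = O(n^(log_2 7)). The operand order inside each product is kept as written, because it matters once the entries are non-commuting blocks.

```java
import java.util.Arrays;

// Hypothetical sketch: one level of Strassen's algorithm on 2x2 int matrices.
final class Strassen2x2 {
    // Ordinary product: 8 multiplications.
    static int[][] naive(int[][] a, int[][] b) {
        int[][] c = new int[2][2];
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                c[i][j] = a[i][0] * b[0][j] + a[i][1] * b[1][j];
        return c;
    }

    // Strassen's product: 7 multiplications, 18 additions/subtractions.
    static int[][] strassen(int[][] a, int[][] b) {
        int a11 = a[0][0], a12 = a[0][1], a21 = a[1][0], a22 = a[1][1];
        int b11 = b[0][0], b12 = b[0][1], b21 = b[1][0], b22 = b[1][1];
        int m1 = (a11 + a22) * (b11 + b22);
        int m2 = (a21 + a22) * b11;
        int m3 = a11 * (b12 - b22);
        int m4 = a22 * (b21 - b11);
        int m5 = (a11 + a12) * b22;
        int m6 = (a21 - a11) * (b11 + b12);
        int m7 = (a12 - a22) * (b21 + b22);
        return new int[][] {
            { m1 + m4 - m5 + m7, m3 + m5 },
            { m2 + m4, m1 - m2 + m3 + m6 }
        };
    }

    public static void main(String[] args) {
        int[][] a = { { 1, 2 }, { 3, 4 } };
        int[][] b = { { 5, 6 }, { 7, 8 } };
        System.out.println(Arrays.deepToString(strassen(a, b))); // [[19, 22], [43, 50]]
    }
}
```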

These days, however, rather than reducing the number of block multiplications from 8 to 7, it would be more effective to use as many cores as possible.

This algorithm provides a clean separation of jobs that can be delegated to multiple cores. That is the more interesting aspect of the algorithm, and in practice it is more effective than reducing the number of multiplications (which may provide about a 20% improvement).
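
As a rough illustration of that separation (a sketch under my own assumptions, not code from this series), the four quadrants of C = A·B can be computed as independent tasks on a fixed pool of 4 threads. Each task reads the shared inputs and writes only its own quadrant, so no locking is needed. The class name ParallelQuadrants and the plain int[][] representation are hypothetical, n is assumed even, and each task uses a plain loop over k; in the recursive version each quadrant would instead be two block products and one block addition.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch: one level of 2x2 block decomposition, one task per output quadrant.
final class ParallelQuadrants {
    static int[][] multiply(int[][] a, int[][] b) {
        final int n = a.length;      // assumed square with even n
        final int h = n / 2;
        final int[][] c = new int[n][n];
        ExecutorService pool = Executors.newFixedThreadPool(4);
        try {
            List<Future<?>> tasks = new ArrayList<>();
            for (int bi = 0; bi < 2; bi++) {
                for (int bj = 0; bj < 2; bj++) {
                    final int r0 = bi * h, c0 = bj * h;
                    // C_ij = A_i1*B_1j + A_i2*B_2j: k < h covers the first product, k >= h the second.
                    // Each task writes only its own quadrant of c.
                    tasks.add(pool.submit(() -> {
                        for (int i = r0; i < r0 + h; i++) {
                            for (int j = c0; j < c0 + h; j++) {
                                int sum = 0;
                                for (int k = 0; k < n; k++) sum += a[i][k] * b[k][j];
                                c[i][j] = sum;
                            }
                        }
                    }));
                }
            }
            for (Future<?> f : tasks) {
                try {
                    f.get(); // waits for the task; also makes its writes visible here
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException(e);
                } catch (ExecutionException e) {
                    throw new IllegalStateException(e.getCause());
                }
            }
        } finally {
            pool.shutdown();
        }
        return c;
    }

    public static void main(String[] args) {
        int[][] c = multiply(new int[][] { { 1, 2 }, { 3, 4 } }, new int[][] { { 5, 6 }, { 7, 8 } });
        System.out.println(Arrays.deepToString(c)); // [[19, 22], [43, 50]]
    }
}
```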

Originally I was considering writing this in Go, but I am not comfortable with Go's syntax, so I decided to do it in Java first.
In a sense, the key to implementing a fast algorithm in Java is to avoid creating intermediate objects, since otherwise garbage collection ends up taking more time than the reduced calculation saves.
So what we need to do is to write an "objectless" object-oriented program.
I think this is an interesting subject, and it is somewhat related to Go's style of object orientation.
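
One possible reading of "objectless", sketched under my own assumptions (this is not the author's later version): store the matrices as plain int[][] arrays whose size is a power of two, describe each block by (row, column) offsets into the shared arrays instead of by an object, and let each recursive call accumulate into the output block in place. The class InPlaceRecMult and its LEAF cutoff are hypothetical. Apart from the result array, nothing is allocated during the multiplication.

```java
import java.util.Arrays;

// Hypothetical sketch: recursive 2x2 block multiplication without intermediate objects.
final class InPlaceRecMult {
    static final int LEAF = 16; // assumed cutoff below which a plain triple loop is used

    // C[cr.., cc..] += A[ar.., ac..] * B[br.., bc..] for n x n blocks
    static void mulAdd(int[][] a, int ar, int ac,
                       int[][] b, int br, int bc,
                       int[][] c, int cr, int cc, int n) {
        if (n <= LEAF) {
            for (int i = 0; i < n; i++)
                for (int k = 0; k < n; k++) {
                    int aik = a[ar + i][ac + k];
                    for (int j = 0; j < n; j++)
                        c[cr + i][cc + j] += aik * b[br + k][bc + j];
                }
            return;
        }
        int h = n / 2;
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                for (int k = 0; k < 2; k++)
                    // C_ij += A_ik * B_kj: the same 8 block products as the object version
                    mulAdd(a, ar + i * h, ac + k * h,
                           b, br + k * h, bc + j * h,
                           c, cr + i * h, cc + j * h, h);
    }

    static int[][] multiply(int[][] a, int[][] b) {
        int n = a.length;          // assumed square, power of two (or <= LEAF)
        int[][] c = new int[n][n]; // the only allocation
        mulAdd(a, 0, 0, b, 0, 0, c, 0, 0, n);
        return c;
    }

    public static void main(String[] args) {
        int[][] c = multiply(new int[][] { { 1, 2 }, { 3, 4 } }, new int[][] { { 5, 6 }, { 7, 8 } });
        System.out.println(Arrays.deepToString(c)); // [[19, 22], [43, 50]]
    }
}
```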

The first implementation in Java is thoroughly object oriented. Although I could use a more extreme style to avoid casting, this level of implementation is the most common one and is easy to compare with other languages where generics are weak (or not supported). Later I may add the code for such an extreme version as well.
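
For comparison, the cast-free style can be approximated with a self-referential (F-bounded) generic type, so that add and mult accept and return the concrete type. This is only a sketch of the idea under my own naming (Ring, Z, Mat2 and Mat2Demo are hypothetical), not the extreme version the post refers to. Nesting Mat2<Mat2<Z>> gives a 4x4 matrix over the integers, mirroring what RecMatrixClass builds below, with the compiler checking element types instead of run-time casts.

```java
// Hypothetical sketch: a ring interface parameterized by its own implementing type.
interface Ring<T extends Ring<T>> {
    T add(T other);
    T mult(T other);
}

// Integers as ring elements.
final class Z implements Ring<Z> {
    final int v;
    Z(int v) { this.v = v; }
    public Z add(Z o) { return new Z(v + o.v); }
    public Z mult(Z o) { return new Z(v * o.v); }
    @Override public String toString() { return Integer.toString(v); }
}

// A 2x2 matrix [[a, b], [c, d]] over any ring R; it is itself a ring element.
final class Mat2<R extends Ring<R>> implements Ring<Mat2<R>> {
    final R a, b, c, d;
    Mat2(R a, R b, R c, R d) { this.a = a; this.b = b; this.c = c; this.d = d; }

    public Mat2<R> add(Mat2<R> o) {
        return new Mat2<>(a.add(o.a), b.add(o.b), c.add(o.c), d.add(o.d));
    }

    // No casts: the entries of o are statically known to be R. Order a.mult(o.x) is kept.
    public Mat2<R> mult(Mat2<R> o) {
        return new Mat2<>(a.mult(o.a).add(b.mult(o.c)), a.mult(o.b).add(b.mult(o.d)),
                          c.mult(o.a).add(d.mult(o.c)), c.mult(o.b).add(d.mult(o.d)));
    }

    @Override public String toString() {
        return "[[" + a + ", " + b + "], [" + c + ", " + d + "]]";
    }
}

final class Mat2Demo {
    public static void main(String[] args) {
        Mat2<Z> m = new Mat2<>(new Z(1), new Z(2), new Z(3), new Z(4));
        Mat2<Z> p = new Mat2<>(new Z(5), new Z(6), new Z(7), new Z(8));
        System.out.println(m.mult(p)); // [[19, 22], [43, 50]]
    }
}
```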

This first implementation follows the idea of the original algorithm closely.
Namely, we treat each inner matrix block as an element of the coefficient ring of the matrix (a 2x2 matrix over a coefficient ring R).
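
In formulas (my notation, not the post's): with A, B and C = AB each split into 2x2 blocks whose entries are (n/2)x(n/2) matrices, the product is the ordinary 2x2 formula applied over the block ring. Because block multiplication is not commutative, every product keeps the order A_ik B_kj, which is exactly what AbstractMatrix.mult in the code does with m0.get(i, k).mult(m1.get(k, j)).

```latex
\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}
\begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix}
=
\begin{pmatrix}
A_{11}B_{11} + A_{12}B_{21} & A_{11}B_{12} + A_{12}B_{22} \\
A_{21}B_{11} + A_{22}B_{21} & A_{21}B_{12} + A_{22}B_{22}
\end{pmatrix},
\qquad C_{ij} = \sum_{k=1}^{2} A_{ik} B_{kj}
```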


package org.zen;

import java.util.Random;

public class Matrix_v3 {


 public static interface IRingClass {
  IRing create();
  
  IRing zero();

  IRing unit();
 }

 public static interface IRing {
  IRing add(IRing r);

  IRing subtract(IRing r);

  IRing mult(IRing r);
 }

 public static interface IMatrix extends IRing {
  int dim();

  IRing get(int row, int column);

  void set(int row, int column, IRing r);
 }

 public static class IntRingClass implements IRingClass {
  static final IRing zero = new IntRing(0);
  static final IRing unit = new IntRing(1);

  @Override
  public IRing create() {
   return new IntRing(0);
  }

  public IRing create(int i) {
   return new IntRing(i);
  }

  @Override
  public IRing zero() {
   return zero;
  }

  @Override
  public IRing unit() {
   return unit;
  }
 }

 public static class IntRing implements IRing {
  final int i;

  public IntRing(int i) {
   this.i = i;
  }

  @Override
  public IRing add(IRing r) {
   IntRing ir = (IntRing) r;
   return new IntRing(i + ir.i);
  }

  @Override
  public IRing subtract(IRing r) {
   IntRing ir = (IntRing) r;
   return new IntRing(i - ir.i);
  }

  @Override
  public IRing mult(IRing r) {
   IntRing ir = (IntRing) r;
   return new IntRing(i * ir.i);
  }
  
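  @Override
  public String toString() {
   return ""+i;
  }

  // value equality for ring elements (two IntRings are equal if they hold the same int)
  @Override
  public boolean equals(Object obj) {
   return (obj instanceof IntRing) && ((IntRing) obj).i == i;
  }

  @Override
  public int hashCode() {
   return i;
  }
 }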

 public static abstract class AbstractMatrixClass implements IRingClass {
  protected final IRingClass coefficientRing;
  protected final int dim;

  protected AbstractMatrix zero = null;
  protected AbstractMatrix unit = null;

  public AbstractMatrixClass(IRingClass coefficientRing, int dim) {
   this.coefficientRing = coefficientRing;
   this.dim = dim;
  }
  
  void init() {
   this.zero = (AbstractMatrix)create();
   this.unit = (AbstractMatrix)create();
   for (int i = 0; i < dim; i++) {
    for (int j = 0; j < dim; j++) {
     zero.set(i, j, coefficientRing.zero());
     unit.set(i, j, (i == j) ? coefficientRing.unit(): coefficientRing.zero());
    }
   }
  }

  @Override
  public IRing zero() {
   if (zero == null) {
    init();
   }
   return zero;
  }

  @Override
  public IRing unit() {
   if (unit == null) {
    init();
   }
   return unit;
  }

 }

 public static abstract class AbstractMatrix implements IMatrix {
  
  final AbstractMatrixClass _matrixClass;
  
  public AbstractMatrix(final AbstractMatrixClass _matrixClass) {
   this._matrixClass = _matrixClass;
  }

  @Override
  public int dim() {
   return _matrixClass.dim;
  }

  @Override
  public IRing add(IRing m) {
   AbstractMatrix m0 = (AbstractMatrix) this;
   AbstractMatrix m1 = (AbstractMatrix) m;
   AbstractMatrix m2 = (AbstractMatrix)_matrixClass.create();
   for (int i = 0; i < dim(); i++) {
    for (int j = 0; j < dim(); j++) {
     IRing r = m0.get(i, j).add(m1.get(i, j));
     m2.set(i, j, r);
    }
   }
   return m2;
  }

  @Override
  public IRing subtract(IRing m) {
   AbstractMatrix m0 = (AbstractMatrix) this;
   AbstractMatrix m1 = (AbstractMatrix) m;
   AbstractMatrix m2 = (AbstractMatrix)_matrixClass.create();
   for (int i = 0; i < dim(); i++) {
    for (int j = 0; j < dim(); j++) {
     IRing r = m0.get(i, j).subtract(m1.get(i, j));
     m2.set(i, j, r);
    }
   }
   return m2;
  }

  @Override
  public IRing mult(IRing m) {
   AbstractMatrix m0 = (AbstractMatrix) this;
   AbstractMatrix m1 = (AbstractMatrix) m;
   AbstractMatrix m2 = (AbstractMatrix)_matrixClass.create();
   for (int i = 0; i < dim(); i++) {
    for (int j = 0; j < dim(); j++) {
     IRing r = _matrixClass.coefficientRing.zero();
     for (int k = 0; k < dim(); k++) {
      r = r.add(m0.get(i, k).mult(m1.get(k, j)));
     }
     m2.set(i, j, r);
    }
   }
   return m2;
  }

  protected void rangeCheck(int row, int column) {
   if (row < 0 || row >= dim()) {
    throw new IndexOutOfBoundsException("row: " + row + ", dim: " + dim());
   }
   if (column < 0 || column >= dim()) {
    throw new IndexOutOfBoundsException("column: " + column + ", dim: " + dim());
   }
  }
 }

 public static class SimpleMatrixClass extends AbstractMatrixClass {
  
  public SimpleMatrixClass(final IRingClass coefficientRing, int dim) {
   super(coefficientRing, dim);
  }

  @Override
  public AbstractMatrix create() {
   return new SimpleMatrix(this);
  }
  
  public static SimpleMatrix create(int[][] matrix) {
   int dim = matrix.length;
   SimpleMatrixClass simpleMatrixClass = new SimpleMatrixClass(new IntRingClass(), dim);
   SimpleMatrix sm = new SimpleMatrix(simpleMatrixClass);
   for (int i = 0; i < dim; i++) {
    for (int j = 0; j < dim; j++) {
     sm.set(i, j, new IntRing(matrix[i][j]));
    }
   }
   return sm;
  }
  
  public SimpleMatrix randomMatrix() {
   int[][] matrix = _randomMatrix(dim);
   SimpleMatrix sm = new SimpleMatrix(this);
   for (int i = 0; i < dim; i++) {
    for (int j = 0; j < dim; j++) {
     sm.set(i, j, new IntRing(matrix[i][j]));
    }
   }
   return sm;
  }  
 }
 
 public static class SimpleMatrix extends AbstractMatrix {
  protected final IRing[][] _matrix;

  public SimpleMatrix(final AbstractMatrixClass _matrixClass) {
   super(_matrixClass);
   _matrix = new IRing[_matrixClass.dim][_matrixClass.dim];
   for (int i = 0; i < _matrixClass.dim; i++) {
    for (int j = 0; j < _matrixClass.dim; j++) {
     _matrix[i][j] = _matrixClass.coefficientRing.create();
    }
   }
  }
  public SimpleMatrix(final AbstractMatrixClass _matrixClass, IRing[][] _matrix) {
   super(_matrixClass);
   this._matrix = _matrix;
  }

  @Override
  public IRing get(int row, int column) {
   rangeCheck(row, column);
   return _matrix[row][column];
  }

  @Override
  public void set(int row, int column, IRing r) {
   rangeCheck(row, column);
   _matrix[row][column] = r;
  }
  
  @Override
  public String toString() {
   int dim = _matrix.length;
   StringBuilder sb = new StringBuilder();
   sb.append("(\n");
   for (int i = 0; i < dim; i++) {
    for (int j = 0; j < dim; j++) {
     String s = _matrix[i][j].toString();
     if (j != 0) {
      sb.append(", ");
     }
     sb.append(s);
    }
    sb.append("\n");
   }
   sb.append(")\n");
   return sb.toString();
  }
  
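  @Override
  public boolean equals(Object obj) {
   if (obj instanceof SimpleMatrix) {
    SimpleMatrix sm = (SimpleMatrix)obj;
    int dim = sm.dim();
    if (dim() != dim) {
     return false;
    }
    for (int i = 0; i < dim; i++) {
     for (int j = 0; j < dim; j++) {
      // compare entries by value (IRing.equals), not by reference
      if (!_matrix[i][j].equals(sm._matrix[i][j])) return false;
     }
    }
    return true;
   } else {
    return false;
   }
  }

  // keep hashCode consistent with equals
  @Override
  public int hashCode() {
   return java.util.Arrays.deepHashCode(_matrix);
  }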
 }

 static Random rand = new Random(); 
 static int[][] _randomMatrix(int dim) {
  final int[][] matrix = new int[dim][dim];
  for (int i = 0; i < dim; i++) {
   for (int j = 0; j < dim; j++) {
    matrix[i][j] = rand.nextInt(50);
   }
  }
  return matrix;
 }
 static int[][] _zeroMatrix(int dim) {
  final int[][] matrix = new int[dim][dim];
  for (int i = 0; i < dim; i++) {
   for (int j = 0; j < dim; j++) {
    matrix[i][j] = 0;
   }
  }
  return matrix;
 }
 
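 public static class RecMatrixClass extends SimpleMatrixClass {
  final int flat_dim;
  public RecMatrixClass(int flat_dim) {
   // the argument check must run before super(...): with a bad flat_dim the recursive
   // constructor call would overflow the stack before a check placed after it is reached
   super(coefficientRingFor(flat_dim), 2);
   this.flat_dim = flat_dim;
  }

  // 2x2 blocks over IntRing at the bottom level, otherwise 2x2 blocks of half-size RecMatrix
  static IRingClass coefficientRingFor(int flat_dim) {
   if (flat_dim < 2 || Integer.bitCount(flat_dim) != 1) {
    throw new IllegalArgumentException(">> RecMatrixClass flat_dim must be a power of 2 (>= 2): "+flat_dim);
   }
   return (flat_dim == 2) ? new IntRingClass() : new RecMatrixClass(flat_dim >> 1);
  }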
  
  public RecMatrix randomMatrix() {
   return new RecMatrix(this, _randomMatrix(flat_dim), flat_dim, 0, 0);
  }  
  
  @Override
  public AbstractMatrix create() {
   return new RecMatrix(this, new int[flat_dim][flat_dim], flat_dim, 0, 0);
  }

  public AbstractMatrix create(int[][] matrix, int dim, int row_index, int column_index) {
   return new RecMatrix(this, matrix, dim, row_index, column_index);
  }

 }
 
 public static class RecMatrix extends SimpleMatrix {
  public RecMatrix(final AbstractMatrixClass _matrixClass, int[][] matrix, int dim, int row_index, int column_index) {
   super(_matrixClass);
   
   if (_matrixClass.coefficientRing instanceof RecMatrixClass) {
    RecMatrixClass rmCls = (RecMatrixClass)_matrixClass.coefficientRing;
    for (int i = 0; i < 2; i++) {
     for (int j = 0; j < 2; j++) {
      final int dim0 = dim >> 1;
      final int row_index0 = (i == 0)?row_index:row_index|dim0;
      final int column_index0 = (j == 0)?column_index:column_index|dim0;
      set(i, j, rmCls.create(matrix, dim0, row_index0, column_index0));
     }
    }
   } else if (_matrixClass.coefficientRing instanceof IntRingClass) {
    IntRingClass irCls = (IntRingClass)_matrixClass.coefficientRing;
    for (int i = 0; i < 2; i++) {
     for (int j = 0; j < 2; j++) {
      set(i, j, irCls.create(matrix[row_index|i][column_index|j]));
     }
    }
   }
  }

  public String toString() {
   int[][] rep = getMatrixRep();
   int dim = rep.length;
   StringBuilder sb = new StringBuilder();
   sb.append("(\n");
   for (int i = 0; i < dim; i++) {
    for (int j = 0; j < dim; j++) {
     int v = rep[i][j];
     if (j != 0) {
      sb.append(", ");
     }
     sb.append(v);
    }
    sb.append("\n");
   }
   sb.append(")\n");
   return sb.toString();
  }
  
  public int[][] getMatrixRep() {
   int flat_dim = ((RecMatrixClass)_matrixClass).flat_dim;
   int[][] matrix = new int[flat_dim][flat_dim];
   setMatrixRep(matrix, flat_dim, 0, 0);
   return matrix;
  }
  
  void setMatrixRep(int[][] matrix, int dim, int row_index, int column_index) {
   for (int i = 0; i < _matrixClass.dim; i++) {
    for (int j = 0; j < _matrixClass.dim; j++) {
     IRing ring = get(i, j);
     if (ring instanceof RecMatrix) {
      final int dim0 = dim >> 1;
      final int row_index0 = (i == 0)?row_index:row_index|dim0;
      final int column_index0 = (j == 0)?column_index:column_index|dim0;
      ((RecMatrix)ring).setMatrixRep(matrix, dim0, row_index0, column_index0);
     } else if (ring instanceof IntRing) {
      IntRing intRing = (IntRing)ring;
      matrix[row_index+i][column_index+j] = intRing.i;
     } else {
      throw new RuntimeException(">>  setMatrixRep unsupported ring: "+ring);
     }
    }
   }
  }
 }
 
 //
 // Tests
 //
 public static long simple_matrix_test1(int dim) {
  SimpleMatrixClass simpleMatrixClass = new SimpleMatrixClass(new IntRingClass(), dim);
  SimpleMatrix sm1 = simpleMatrixClass.randomMatrix();
  SimpleMatrix sm2 = simpleMatrixClass.randomMatrix();
  SimpleMatrix sm3 = (SimpleMatrix)sm1.mult(sm2);
  
  long end = System.currentTimeMillis();
  /*
  System.out.println("sm1: "+sm1);
  System.out.println("sm2: "+sm2);
  System.out.println("sm3: "+sm3);
  */
  return end;
 }
 public static long rec_matrix_test1(int dim) {
  RecMatrixClass recMatrixClass = new RecMatrixClass(dim);
  RecMatrix rm1 = recMatrixClass.randomMatrix();
  RecMatrix rm2 = recMatrixClass.randomMatrix();
  RecMatrix rm3 = (RecMatrix)rm1.mult(rm2);
  
  long end = System.currentTimeMillis();
  /*
  System.out.println("rm1: "+rm1);
  System.out.println("rm2: "+rm2);
  System.out.println("rm3: "+rm3);
  */
  return end;
 }
 
 public static void simple_test_loop(int max) {
  int dim = 1;
  for (int i = 1; i <= max; i++) {
   dim = dim*2;
   long startTime = System.currentTimeMillis();
   long endTime = simple_matrix_test1(dim);
   System.out.println(">> dim: "+dim+", elapsed time: "+(endTime - startTime)+" milli sec");
  }
 }
 public static void rec_test_loop(int max) {
  int dim = 1;
  for (int i = 1; i <= max; i++) {
   dim = dim*2;
   long startTime = System.currentTimeMillis();
   long endTime = rec_matrix_test1(dim);
   System.out.println(">> dim: "+dim+", elapsed time: "+(endTime - startTime)+" milli sec");
  }
 }
 
 public static void main(String[] args) {
  int max = 8;
  //simple_test_loop(max);
  rec_test_loop(max);
  
 }
}

Here are the execution results:
[simple]
>> dim: 2, elapsed time: 2 milli sec
>> dim: 4, elapsed time: 0 milli sec
>> dim: 8, elapsed time: 1 milli sec
>> dim: 16, elapsed time: 2 milli sec
>> dim: 32, elapsed time: 6 milli sec
>> dim: 64, elapsed time: 13 milli sec
>> dim: 128, elapsed time: 31 milli sec
>> dim: 256, elapsed time: 227 milli sec

------------------
[rec]
>> dim: 2, elapsed time: 3 milli sec
>> dim: 4, elapsed time: 0 milli sec
>> dim: 8, elapsed time: 2 milli sec
>> dim: 16, elapsed time: 32 milli sec
>> dim: 32, elapsed time: 378 milli sec
>> dim: 64, elapsed time: 494 milli sec
>> dim: 128, elapsed time: 1620 milli sec
>> dim: 256, elapsed time: 13757 milli sec
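
A rough reading of the largest step (my arithmetic on the numbers above, not the author's): the simple version is the ordinary triple loop, so doubling n should multiply its time by about 2^3 = 8. These are single runs without JIT warm-up, so only the biggest sizes say much.

```latex
\begin{aligned}
\text{simple:}\quad & \frac{T(256)}{T(128)} = \frac{227}{31} \approx 7.3 \qquad (\text{expected} \approx 2^3 = 8) \\
\text{rec:}\quad & \frac{T(256)}{T(128)} = \frac{13757}{1620} \approx 8.5 \\
\text{rec / simple at } n = 256\text{:}\quad & \frac{13757}{227} \approx 60.6
\end{aligned}
```

At this step the two versions grow at a similar rate, and most of the gap is a constant factor of roughly 60.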

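The main point of this approach is that there is no special multiplication method for the recursive matrix; the algorithm is reflected entirely in how the coefficient ring elements are created recursively.
But the performance is quite low. A likely contributor is allocation: every add and mult calls create(), which for RecMatrixClass allocates a new int[flat_dim][flat_dim] array and builds a full block tree that is then overwritten, and every IntRing operation allocates a new object.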
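
To make "the algorithm is reflected in how the coefficient ring elements are created" concrete, here is a small self-contained sketch (a hypothetical helper, not part of Matrix_v3) that follows the same index arithmetic as the RecMatrix constructor. At each level the block size halves, and choosing the second row or column block ORs the half size into the offset, as in row_index0 = row_index|dim0. Because every offset is a multiple of the current block size, OR behaves like addition here, and the (i, j) choices from the outermost level inward are just the bits of the flat row and column indices.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper mirroring RecMatrix's index arithmetic (not part of Matrix_v3).
final class BlockPath {
    // Returns the (i, j) block choices, outermost first, that lead to flat entry (row, col).
    static List<String> path(int row, int col, int flatDim) {
        List<String> steps = new ArrayList<>();
        int rowIndex = 0, colIndex = 0;
        for (int half = flatDim >> 1; half >= 1; half >>= 1) {
            int i = (row & half) != 0 ? 1 : 0;
            int j = (col & half) != 0 ? 1 : 0;
            // same update as row_index0 = (i == 0) ? row_index : row_index | dim0
            rowIndex = (i == 0) ? rowIndex : rowIndex | half;
            colIndex = (j == 0) ? colIndex : colIndex | half;
            steps.add("(" + i + "," + j + ")");
        }
        // the accumulated offsets must land on the entry itself (fails if row/col is out of range)
        if (rowIndex != row || colIndex != col) throw new IllegalStateException("out of range");
        return steps;
    }

    public static void main(String[] args) {
        System.out.println(path(5, 3, 8)); // [(1,0), (0,1), (1,1)]
    }
}
```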

Next, I will show a new version that improves performance significantly, along with the effect of using 4 threads (indeed 4 times faster).
