In this post we will look at how we might build a general-purpose parallel reduce using the fork/join model.

By the way, you don't need to do this if you are on Java 1.8 or later: Streams already support parallel reduction, and you should use them rather than rolling your own. This post is for educational purposes only.
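For reference, here is a minimal sketch of what the Streams route looks like. The class and method names (`StreamReduceDemo`, `parallelSum`, `parallelMax`) are made up for illustration; `Arrays.stream`, `parallel`, `sum` and `reduce` are the standard JDK calls.

```java
import java.util.Arrays;

final class StreamReduceDemo {
    // Sums an int[] with a parallel stream; the runtime decides how to split the work.
    static int parallelSum(int[] array) {
        return Arrays.stream(array).parallel().sum();
    }

    // The same kind of reduction, spelled out with an identity and an associative operator.
    static int parallelMax(int[] array) {
        return Arrays.stream(array).parallel().reduce(Integer.MIN_VALUE, Math::max);
    }

    public static void main(String[] args) {
        int[] data = {3, -7, 12, 5};
        System.out.println(parallelSum(data)); // 13
        System.out.println(parallelMax(data)); // 12
    }
}
```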

The idea of reduction is simple: given a dataset, combine its elements into a single value. For example, given an array of integers, return the sum of all its elements. Such an operation can be done with a simple loop and an accumulator.

int sum = 0;
for (int element : array) {
    sum += element;
}

The only problem with the code above is that it doesn't scale to multiple cores. Even with an unlimited number of processors, the loop still runs in O(n) time, because each addition depends on the result of the previous one. By the end of this post, you should be able to convince yourself that the same operation can be done in O(lg n) time as the number of processors tends to infinity.

**A Divide and Conquer sum.**

Instead of calculating the sum by scanning the entire array sequentially, we can use a divide-and-conquer algorithm, very similar to merge sort: divide the array into two parts and solve each subproblem recursively. As in merge sort, we are aiming for a recursion tree of depth O(lg n), although in practice our tree will be shallower than that, because we stop splitting once a part is small enough to be summed sequentially.
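A rough way to see where the O(lg n) comes from, assuming that splitting and combining each cost constant time and that the sequential cutoff is a constant: write T_1 for the total work, T_∞ for the length of the longest chain of dependent steps, and T_p for the running time on p processors under a greedy scheduler.

```latex
\begin{align*}
T_1(n)      &= 2\,T_1(n/2) + \Theta(1) = \Theta(n) \\
T_\infty(n) &= T_\infty(n/2) + \Theta(1) = \Theta(\lg n) \\
T_p(n)      &\le \frac{T_1(n)}{p} + T_\infty(n) = O\!\left(\frac{n}{p} + \lg n\right)
\end{align*}
```

The two halves run in parallel, so only one of them counts towards T_∞. As p grows, the n/p term vanishes and the lg n term is what remains.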

**A Naive Implementation.**

A naive way would be to have a shared atomic accumulator and split the array into parts recursively. When a part is small enough to be processed sequentially, we calculate its partial sum and add it to the accumulator. At the end, the accumulator holds the total sum of the elements.

The algorithm is definitely correct; it derives its correctness from the sequential one. The problem is that it suffers from contention: all worker threads have to coordinate when they are ready to contribute their partial sums. This introduces a sequential bottleneck, and according to Amdahl's law, the sequential part of an algorithm is what limits its achievable speedup.

One more problem with the algorithm described above is that it is not cache-friendly. Atomics are usually implemented with a compare-and-swap (CAS) instruction. Although CAS scales better than a mutex because it doesn't suffer from context switching, every update still invalidates the accumulator's cache line on the other cores.
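To make the naive approach concrete, here is a minimal sketch of it. The names `NaiveSharedSum` and `SumAction` and the cutoff of 1,000 are assumptions chosen for illustration, not something prescribed by the framework.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.atomic.AtomicLong;

final class NaiveSharedSum {
    private static final int THRESHOLD = 1_000; // hypothetical sequential cutoff

    static final class SumAction extends RecursiveAction {
        private final int[] a;
        private final int low;  // inclusive
        private final int high; // exclusive
        private final AtomicLong total;

        SumAction(int[] a, int low, int high, AtomicLong total) {
            this.a = a;
            this.low = low;
            this.high = high;
            this.total = total;
        }

        @Override
        protected void compute() {
            if (high - low <= THRESHOLD) {
                long partial = 0;
                for (int i = low; i < high; i++) {
                    partial += a[i];
                }
                // Every leaf task ends up here: this shared update is the contention point.
                total.addAndGet(partial);
                return;
            }
            int mid = (low + high) >>> 1;
            invokeAll(new SumAction(a, low, mid, total), new SumAction(a, mid, high, total));
        }
    }

    static long sum(int[] a) {
        AtomicLong total = new AtomicLong();
        ForkJoinPool.commonPool().invoke(new SumAction(a, 0, a.length, total));
        return total.get();
    }
}
```

Each leaf touches the accumulator only once, so the contention grows with the number of leaves rather than the number of elements; a smaller cutoff makes it worse.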

**A Better Solution.**

A better solution would be a shared-nothing one, where the computation looks like a balanced binary tree. When we reach a leaf (a partial sum), we combine it with its sibling's partial sum, building the aggregated result bottom-up.

In pseudocode:

sum(a) {
    if (a.length <= sequential_threshold) {
        sum = 0
        for (i : a) {
            sum += i
        }
        return sum
    } else {
        mid = a.length / 2
        left = a[0..mid)
        right = a[mid..a.length)
        leftTask = fork(sum(left))
        rightTask = fork(sum(right))
        return leftTask.join() + rightTask.join()
    }
}
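As a concrete version of the pseudocode, here is a minimal sketch for summing an `int[]`. The class name `ForkJoinSum` and the cutoff value are assumptions. Instead of copying sub-arrays it passes index bounds around, and it forks only one half while computing the other on the current thread, which is a common fork/join idiom.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

final class ForkJoinSum extends RecursiveTask<Long> {
    private static final int THRESHOLD = 1_000; // hypothetical sequential cutoff
    private final int[] a;
    private final int low;  // inclusive
    private final int high; // exclusive

    ForkJoinSum(int[] a, int low, int high) {
        this.a = a;
        this.low = low;
        this.high = high;
    }

    @Override
    protected Long compute() {
        if (high - low <= THRESHOLD) {
            long sum = 0;
            for (int i = low; i < high; i++) {
                sum += a[i];
            }
            return sum;
        }
        int mid = (low + high) >>> 1;
        ForkJoinSum left = new ForkJoinSum(a, low, mid);
        ForkJoinSum right = new ForkJoinSum(a, mid, high);
        left.fork();                     // schedule the left half asynchronously
        long rightSum = right.compute(); // work on the right half in this thread
        return left.join() + rightSum;   // no shared state: results flow up the tree
    }

    static long sum(int[] a) {
        return ForkJoinPool.commonPool().invoke(new ForkJoinSum(a, 0, a.length));
    }
}
```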

**A General Purpose Reduce**

But this looks like a pattern. We often accumulate, not necessarily to sum: counting is a similar thing, and so are MAX and MIN. So, can we build a general-purpose reduce? Of course we can. However, there is a limit to what can be achieved with this pattern: the operation has to be associative. In simple words, the way the operands are grouped doesn't matter, which is exactly what lets us combine partial results from different subtrees. Sum, for example, is associative.

A + (B + C) = (A + B) + C

So is MAX:

MAX(MAX(1, -77), MAX(13, 4)) = MAX(1, MAX(-77, MAX(13, 4)))

Division doesn't have this property, since grouping matters: (8 / 4) / 2 = 1, but 8 / (4 / 2) = 4.
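A quick way to test an operator before trusting it in a parallel reduce is to regroup a few operands and compare. The sketch below uses a hypothetical helper, `sameUnderRegrouping`; passing for one triple does not prove associativity, but a single failure is enough to rule an operator out.

```java
import java.util.function.BinaryOperator;

final class AssociativityCheck {
    // True when (a op b) op c equals a op (b op c) for this particular triple.
    static <T> boolean sameUnderRegrouping(BinaryOperator<T> op, T a, T b, T c) {
        T leftFirst = op.apply(op.apply(a, b), c);
        T rightFirst = op.apply(a, op.apply(b, c));
        return leftFirst.equals(rightFirst);
    }

    public static void main(String[] args) {
        System.out.println(sameUnderRegrouping(Integer::sum, 8, 4, 2));        // true
        System.out.println(sameUnderRegrouping(Integer::max, 8, 4, 2));        // true
        System.out.println(sameUnderRegrouping((x, y) -> x / y, 8, 4, 2));     // false: 1 vs 4
    }
}
```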

**Ingredients for our Recipe**

Now it’s time to think of the tools needed for that.

- A source collection, one that we can split efficiently. A linked list would be a horrible choice, because splitting it takes linear time. Arrays to the rescue! An array can be split in constant time by passing index bounds around, just as in merge sort. Moreover, arrays get the best out of cache lines because of locality.
- An operator, a binary operator in particular: something that takes two parameters, does something, and returns a result of the same type. For that, we will be using BinaryOperator from JDK 1.8.
- The ForkJoin framework.
- An identity or seed, a value that we can use as a baseline (an “initialiser”) for our computation. For sum it would be 0; for MAX it would be Integer.MIN_VALUE.
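The operator and the identity have to fit together: combining the identity with any value must give that value back. The sketch below shows this with a plain sequential fold; `IdentityDemo` and `fold` are hypothetical names used only for this illustration.

```java
import java.util.function.BinaryOperator;

final class IdentityDemo {
    // Sequential fold: start from the identity and apply the operator left to right.
    static <T> T fold(T[] a, T identity, BinaryOperator<T> op) {
        T result = identity;
        for (T element : a) {
            result = op.apply(result, element);
        }
        return result;
    }

    public static void main(String[] args) {
        Integer[] negatives = {-5, -2, -9};
        System.out.println(fold(negatives, 0, Integer::sum));                 // -16
        System.out.println(fold(negatives, Integer.MIN_VALUE, Integer::max)); // -2
        // 0 is not an identity for max, so the result is not even in the array.
        System.out.println(fold(negatives, 0, Integer::max));                 // 0
    }
}
```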

**Java Implementation**

The code for this is remarkably simple. All we need is a simple RecursiveTask and a tiny wrapper around it that hides some details.

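import java.util.concurrent.RecursiveTask;
import java.util.function.BinaryOperator;

// Reduces a[low..high] (both bounds inclusive) with an associative operator.
// Declared static because it is meant to be nested inside another class.
static class ReductionTask<T> extends RecursiveTask<T> {
    private static final int THRESHOLD = 1000;
    private final T[] a;
    private final int low;
    private final int high;
    private final BinaryOperator<T> operator;

    public ReductionTask(T[] a, BinaryOperator<T> operator) {
        this(a, 0, a.length - 1, operator);
    }

    private ReductionTask(T[] a, int low, int high, BinaryOperator<T> operator) {
        this.a = a;
        this.low = low;
        this.high = high;
        this.operator = operator;
    }

    @Override
    protected T compute() {
        // Small enough: fold the range sequentially, seeded with its first element.
        if ((high - low) <= THRESHOLD) {
            T result = a[low];
            for (int i = low + 1; i <= high; i++) {
                result = operator.apply(result, a[i]);
            }
            return result;
        }
        // Unsigned shift gives the midpoint even if low + high overflows.
        int mid = (low + high) >>> 1;
        ReductionTask<T> leftTask = new ReductionTask<>(a, low, mid, operator);
        ReductionTask<T> rightTask = new ReductionTask<>(a, mid + 1, high, operator);
        // Fork the left half, compute the right half on this thread, then join.
        leftTask.fork();
        T right = rightTask.compute();
        T left = leftTask.join();
        // Keep left-to-right order so non-commutative operators still work.
        return operator.apply(left, right);
    }
}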

And a tiny class to submit the task to the ForkJoinPool:

import java.util.concurrent.ForkJoinPool;
import java.util.function.BinaryOperator;

public class ArrayReducer<T> {
    private final T[] a;
    private final T identity;
    private final BinaryOperator<T> operator;

    public ArrayReducer(T[] a, T identity, BinaryOperator<T> operator) {
        this.a = a;
        this.identity = identity;
        this.operator = operator;
    }

    public T compute() {
        // An empty array reduces to the identity.
        if (a.length == 0) {
            return identity;
        }
        // invoke() runs the task in the pool and blocks until the result is ready,
        // so there is no need to submit() first or to handle checked exceptions.
        T result = ForkJoinPool.commonPool().invoke(new ReductionTask<>(a, operator));
        return operator.apply(identity, result);
    }
}

Now we can use this abstraction to perform our parallel sum as follows:

// a must be an Integer[]; a generic T[] cannot be a primitive int[].
ArrayReducer<Integer> reduce = new ArrayReducer<>(a, 0, (x, y) -> x + y);
int sum = reduce.compute();

The same can be used for MAX as well:

ArrayReducer<Integer> reduce = new ArrayReducer<>(a, Integer.MIN_VALUE, Math::max);
int max = reduce.compute();