K-Way Merge Sort

Created: 2010-10-17 15:09:38
This was a fun assignment for an Algorithms class last spring. The assignment was simple: write an algorithm that performs a merge sort given an array of integers and the number of splits, k. It was surprising to see how the efficiency changed with the number of splits.

So what differentiates a K-Way Merge sort from a regular Merge sort?
Well, a normal Merge sort recursively splits the given data set into two sets of data.

For example, in the set:
[4,6,2,8,9,1,7,5,3]

We would split it into:
[4,6,2,8] and [9,1,7,5,3]

Then the same algorithm is run recursively on the left and right pieces, and the results are merged together:
merge_sort(left);
merge_sort(right);
merge(left, right);

Because we ran merge_sort on the left and right pieces first, we can assume they are now sorted; all that is left is putting the two sorted arrays together. That makes the actual merge_sort method pretty trivial; it's merge() that gets interesting. And since we split into a left and a right, you can see how splitting three ways, into a left, middle, and right, would be very similar. With a K-Way merge sort we just split the data into however many chunks (k) we want, hoping for better efficiency because the chunk size shrinks much faster: the recursion bottoms out after roughly log_k(n) levels instead of log_2(n).
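To make the k-way split concrete, here is a small Java sketch (the class and method names are hypothetical, not from the downloadable source) that computes where each of k chunks starts within a range [low, high]. It uses the i*n/k boundary formula that also appears in the merge code later in this post; with k = 2 on the example array it reproduces the [4,6,2,8] / [9,1,7,5,3] split above.

```java
import java.util.Arrays;

// Hypothetical helper: split the index range [low, high] into k contiguous chunks.
class KChunks {
    // Returns k + 1 entries: starts[i] is the first index of chunk i, and
    // starts[k] == high + 1, so chunk i ends at starts[i + 1] - 1.
    static int[] chunkStarts(int low, int high, int k) {
        int n = high - low + 1;
        int[] starts = new int[k + 1];
        for (int i = 0; i <= k; i++) {
            starts[i] = low + i * n / k;   // same i*n/k boundary formula as the merge code
        }
        return starts;
    }

    public static void main(String[] args) {
        int[] data = {4, 6, 2, 8, 9, 1, 7, 5, 3};
        for (int k = 2; k <= 3; k++) {
            int[] starts = chunkStarts(0, data.length - 1, k);
            StringBuilder line = new StringBuilder("k=" + k + ":");
            for (int i = 0; i < k; i++) {
                line.append(' ').append(Arrays.toString(
                        Arrays.copyOfRange(data, starts[i], starts[i + 1])));
            }
            System.out.println(line);
        }
    }
}
```

Chunk sizes differ by at most one, so no chunk is ever empty as long as the range has at least k elements.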

The Merge
This is the interesting part of the algorithm. In the traditional two-way merge sort, the base case is when we get down to two elements. With two elements things are simple: we just put the two elements in order and return them. But when we have two sets of already sorted data, what is the best way to merge them?

Because both pieces are already individually sorted, we can merge them quickly. We start at the beginning of each array, compare the two front elements, move the smaller one to the result, and advance in the array it came from.

For example, let's take our earlier example, with each half sorted:
[2,4,6,8] and [1,3,5,7,9]

Compare the first number in each set: 1 is less than 2, so we move the 1 to the result set and continue:
Result:
[1,?,?,?,?,?,?,?,?]

[2,4,6,8] and [x,3,5,7,9]

Next, compare 2 and 3; 2 is smaller:
Result:
[1,2,?,?,?,?,?,?,?]

[x,4,6,8] and [x,3,5,7,9]

Then we compare 4 and 3, and so forth until the result is full. Each comparison moves one element into the result, so merging n total elements takes O(n) time.
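Here is a minimal Java sketch of the two-pointer merge just walked through (the class name is made up for illustration). Taking from the left array when the fronts are equal (<=) keeps equal elements in their original order, which makes the merge stable.

```java
import java.util.Arrays;

// Illustrative two-way merge of two already-sorted arrays.
class TwoWayMerge {
    static int[] merge(int[] left, int[] right) {
        int[] result = new int[left.length + right.length];
        int i = 0, j = 0, r = 0;
        // Compare the current front of each array and copy the smaller one.
        while (i < left.length && j < right.length) {
            if (left[i] <= right[j]) {
                result[r++] = left[i++];
            } else {
                result[r++] = right[j++];
            }
        }
        // One side is exhausted; copy whatever remains of the other.
        while (i < left.length) result[r++] = left[i++];
        while (j < right.length) result[r++] = right[j++];
        return result;
    }

    public static void main(String[] args) {
        int[] merged = merge(new int[]{2, 4, 6, 8}, new int[]{1, 3, 5, 7, 9});
        System.out.println(Arrays.toString(merged)); // [1, 2, 3, 4, 5, 6, 7, 8, 9]
    }
}
```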

There is one trick here. Because the left and right pieces always sit next to each other in the original array, instead of passing them as two separate arrays we can turn:
merge(left, right)

into:
merge(data, low, high)

To find the split point, we only need some quick math:
(high+low)/2

We still have to copy the data from low to high into a temporary array so we can write the merged result back into the original one, but this lets us sort the original array in place and limits the extra allocations to those temporary arrays.
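Here is a sketch of the index-based version in Java. It assumes the left run is data[low..mid-1] and the right run is data[mid..high], with mid = (high+low)/2, which reproduces the [4,6,2,8] / [9,1,7,5,3] split from the example; under that convention a two-element range would produce an empty left run, so ranges of two or fewer elements are the base case. The method names are illustrative and may not match the downloadable source.

```java
import java.util.Arrays;

// Illustrative two-way merge sort working on index ranges of a single array.
class IndexMergeSort {
    // Sorts data[low..high] in place.
    static void mergeSort(int[] data, int low, int high) {
        if (high - low < 2) {
            // Base case: one or two elements, so at most one swap is needed.
            if (high > low && data[low] > data[high]) {
                int t = data[low];
                data[low] = data[high];
                data[high] = t;
            }
            return;
        }
        int mid = (high + low) / 2;      // left run: [low, mid-1], right run: [mid, high]
        mergeSort(data, low, mid - 1);
        mergeSort(data, mid, high);
        merge(data, low, high);
    }

    // Merges the sorted runs data[low..mid-1] and data[mid..high] (mid recomputed the same way).
    static void merge(int[] data, int low, int high) {
        int mid = (high + low) / 2;
        int[] temp = Arrays.copyOfRange(data, low, high + 1);  // the only extra allocation
        int leftEnd = mid - low;         // left run occupies temp[0 .. leftEnd-1]
        int i = 0, j = leftEnd;          // fronts of the left and right runs
        for (int out = low; out <= high; out++) {
            // Take from the left run if the right run is used up or its front is not smaller.
            if (j >= temp.length || (i < leftEnd && temp[i] <= temp[j])) {
                data[out] = temp[i++];
            } else {
                data[out] = temp[j++];
            }
        }
    }

    static int[] sortedCopy(int[] a) {
        int[] copy = a.clone();
        mergeSort(copy, 0, copy.length - 1);
        return copy;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(sortedCopy(new int[]{4, 6, 2, 8, 9, 1, 7, 5, 3})));
    }
}
```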

So what about the K-Way merge?
Working with merge(data, low, high) is simpler than trying to pass k separate small arrays around, but the concept is the same.

The base case is a little less trivial, since our smallest merge will have up to k elements in it, and for a large data set k could be 10 to 50 partitions or more. To handle this, we sort them by putting them into a heap and pulling them back out. In the source code below I didn't write the heap myself, but it is fairly straightforward to do so.
if (high < low + k) {
    // The subarray has k or fewer elements:
    // just make one big heap and do deleteMins on it.
    // (BinaryHeap, MergesortHeapNode and EmptyHeapException are not shown here.)
    Comparable[] subarray = new MergesortHeapNode[high - low + 1];
    for (int i = 0, j = low; i < subarray.length; i++, j++) {
        // The second argument (which subarray) is only a placeholder in the base case
        subarray[i] = new MergesortHeapNode(data[j], 0);
    }
    BinaryHeap heap = BinaryHeap.buildHeap(subarray);
    // deleteMin() yields the elements in ascending order, so write them back in place
    for (int j = low; j <= high; j++) {
        try {
            data[j] = ((MergesortHeapNode) heap.deleteMin()).getKey();
        }
        catch (EmptyHeapException e) {
            System.out.println("Tried to delete from an empty heap.");
        }
    }
}
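If you don't have the BinaryHeap and MergesortHeapNode classes from the assignment source, the base case can be sketched with java.util.PriorityQueue (the standard library's binary min-heap) standing in for them. This is a substitution, not the original code: add each value in data[low..high] to the heap, then poll() them back out in ascending order. The class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.PriorityQueue;

// Hypothetical base case: heap-sort the small range data[low..high] in place.
class HeapBaseCase {
    static void heapSortRange(int[] data, int low, int high) {
        PriorityQueue<Integer> heap = new PriorityQueue<>();  // min-heap by natural order
        for (int j = low; j <= high; j++) {
            heap.add(data[j]);
        }
        for (int j = low; j <= high; j++) {
            data[j] = heap.poll();  // poll() removes and returns the current minimum
        }
    }

    public static void main(String[] args) {
        int[] data = {9, 4, 6, 2, 8};
        heapSortRange(data, 1, 3);  // sort only indices 1..3
        System.out.println(Arrays.toString(data)); // [9, 2, 4, 6, 8]
    }
}
```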

For the non-base case, where we have k sorted subarrays, we run into a similar problem: we can no longer do a simple "is left less than right?" comparison. Instead we put the first element of every subarray into a heap to find out which subarray has the lowest front element. Once the heap tells us which subarray that is, we take its element, put it in its place in the result, and refill the heap with the next element from the subarray we just took from.
else {
    // Divide the array into k subarrays and do a k-way merge.
    // Each subarray must already be sorted at this point (the recursive
    // calls that sort them are not part of this excerpt).
    final int subarrSize = high - low + 1;
    final int[] tempArray = new int[subarrSize];

    // Copy data[low..high] into a temp array so the merge can write back into data
    for (int i = low; i < high + 1; i++)
        tempArray[i - low] = data[i];

    // subarrayIndex[i] is the next unread position of subarray i within tempArray
    final int[] subarrayIndex = new int[k];
    for (int i = 0; i < k; i++)
        subarrayIndex[i] = i * subarrSize / k;

    // Build a heap from the first element of each subarray
    Comparable[] subarray = new MergesortHeapNode[k];
    for (int i = 0; i < k; i++)
        subarray[i] = new MergesortHeapNode(tempArray[subarrayIndex[i]++], i);

    BinaryHeap heap = BinaryHeap.buildHeap(subarray);

    // For each position low..high, take the lowest front element among the k subarrays
    for (int i = low; i < high + 1; i++) {
        try {
            // Take the lowest element and put it back into the original array
            MergesortHeapNode a = (MergesortHeapNode) heap.deleteMin();
            data[i] = a.getKey();
            int which = a.getWhichSubarray();
            // If that subarray still has elements, push its next one onto the heap
            if (subarrayIndex[which] < (which + 1) * subarrSize / k) {
                // Read the value (the original incremented tempArray[...] itself here by mistake)
                heap.insert(new MergesortHeapNode(tempArray[subarrayIndex[which]], which));
                // Advance the index of the subarray the lowest element came from
                subarrayIndex[which]++;
            }
        } catch (EmptyHeapException e) {
            System.out.println("Tried to delete from an empty heap.");
        }
    }
}
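For reference, here is a self-contained sketch of the whole k-way merge sort in Java. It swaps java.util.PriorityQueue in for BinaryHeap/MergesortHeapNode (each heap entry is a hypothetical int[]{value, whichSubarray} pair) and, unlike the excerpt above, includes the recursive calls that sort each of the k subarrays before merging. It assumes k >= 2; the class and method names are illustrative and may not match the downloadable source.

```java
import java.util.Arrays;
import java.util.PriorityQueue;

// Illustrative k-way merge sort; heap entries are int[]{value, whichSubarray}.
class KWayMergeSort {
    static void sort(int[] data, int low, int high, int k) {
        if (k < 2) throw new IllegalArgumentException("k must be at least 2");
        int n = high - low + 1;
        if (n <= k) {
            // Base case: k or fewer elements, so heap-sort them directly.
            PriorityQueue<Integer> heap = new PriorityQueue<>();
            for (int j = low; j <= high; j++) heap.add(data[j]);
            for (int j = low; j <= high; j++) data[j] = heap.poll();
            return;
        }
        // Subarray i is data[low + i*n/k .. low + (i+1)*n/k - 1]; sort each one first.
        for (int i = 0; i < k; i++) {
            sort(data, low + i * n / k, low + (i + 1) * n / k - 1, k);
        }
        int[] temp = Arrays.copyOfRange(data, low, high + 1);
        int[] next = new int[k];  // next unread index of each subarray within temp
        PriorityQueue<int[]> heap = new PriorityQueue<>((a, b) -> Integer.compare(a[0], b[0]));
        for (int i = 0; i < k; i++) {
            next[i] = i * n / k;
            heap.add(new int[]{temp[next[i]++], i});  // seed with each subarray's first element
        }
        for (int j = low; j <= high; j++) {
            int[] min = heap.poll();  // lowest front element across all subarrays
            data[j] = min[0];
            int w = min[1];
            if (next[w] < (w + 1) * n / k) {  // subarray w still has elements: refill from it
                heap.add(new int[]{temp[next[w]++], w});
            }
        }
    }

    static int[] sortedCopy(int[] a, int k) {
        int[] copy = a.clone();
        sort(copy, 0, copy.length - 1, k);
        return copy;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(sortedCopy(new int[]{4, 6, 2, 8, 9, 1, 7, 5, 3}, 3)));
    }
}
```

On the choice of k: each element passes through a heap of size k at each of roughly log_k(n) levels, at O(log k) per heap operation, so the total is O(n log k * log_k n) = O(n log n), the same bound as the two-way version. The differences you see between values of k therefore come from constant factors rather than the asymptotic bound.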

Well, I hope everyone learned something; if not, it's always good practice. Feel free to leave a comment if you are interested in more.

Download the source here: kwaymerge.zip (11 KB)