关于我脑洞大开去用多线程优化快速排序这件事

今晚在复习 TopK 问题手写快排时,突发奇想:“既然快排每次都要划分出两个区间重复进行快排,那么我可不可以将新划分出的两个区间用两个新线程去跑 ? ” 于是就有了这篇文章。

初次尝试

如果每次划分新区间都开线程跑,那最后的线程数肯定会爆炸式增长,所以我首先想到用线程池去跑。在线程池中,多余的任务放在阻塞队列执行,保证最大执行线程数不超过 CPU 核心数。既能最大利用 CPU 多核,又不至于让线程数溢出,一举两得。

理论可行,开始实践!

代码

下面是一个原生的快排,我的写法可能跟常规写法不一样,不过效率是一样的:

public static void quickSort(int nums, int l, int r) {
    if (l >= r) return;
    int x = nums[i], i = l, j = r + 1;
    while (i < j) {
        while (nums[--j] > x && i != j);
        if (i == j) nums[j] = x;
        else {
            nums[i] = nums[j];
            while (nums[++i] < x && i != j);
            if (i == j) nums[i] = x;
            else nums[j] = nums[i];
        }
    }
    quickSort(nums, l, j - 1);
    quickSort(nums, j + 1, r);
}

那如何将线程池用到快排里去呢?其实用栈实现迭代写法会更易于理解。

这里线程池也起到了一个存储任务的作用,即任务队列。每次对区间进行划分后,将划分的区间的左右位置存到队列中,留到之后执行,类似于迭代法中的栈。不过线程池的好处就是,它会自动执行,而不需要我们通过循环去取任务执行。

那么先写一个线程需要执行的方法,我们不需要返回值,所以实现 Runnable,如下:

class Task implements Runnable {

    private int left;

    private int right;

    private int[] nums;

    private AtomicInteger count;

    private ExecutorService executor;

    // 传参
    public Task(int left, int right, int[] nums, AtomicInteger count, ExecutorService executor) {
        this.left = left;
        this.right = right;
        this.nums = nums;
        this.count = count;
        this.executor = executor;
    }

    @Override
    public void run() {
        // 划分区间前的移位操作
        int x = nums[left], i = left, j = right + 1;
        while (i < j) {
            while (nums[--j] > x && i != j);
            if (i == j) nums[j] = x;
            else {
                nums[i] = nums[j];
                while (nums[++i] < x && i != j);
                if (i == j) nums[i] = x;
                else nums[j] = nums[i];
            }
        }
        if (left < i - 1) {
            count.getAndIncrement(); // 将未完成任务数 + 1
            // 将下个区间的排序任务交给新线程执行
            executor.execute(new Task(left, i - 1, nums, count, executor));
        }
        if (i + 1 < right) {
            count.getAndIncrement();
            executor.execute(new Task(i + 1, right, nums, count, executor));
        }
        count.getAndDecrement(); // 最后记得扣除任务数
    }
}

因为线程需要传参,所以我们通过构造函数给字段赋值来传,这里一个个解释:

  • leftright :区间左右两边的索引
  • nums :数组
  • count :计数器,用来判断线程池何时将所有任务执行完成
  • executor :线程池

接着,我们来写出主类的结构,如下:

public class ParallelQuickSort {
    public static void main(String[] args) {
        // 生成随机数据
        int[] nums = generateRandomArray(10000000);
        double start = System.currentTimeMillis();
        // Arrays.sort(nums);
        parallelQuickSort(nums);
        double end = System.currentTimeMillis();
        System.out.println(((end - start) / 1000) + " seconds");
        // 验证排序结果
        for (int i = 1; i < nums.length; i++) {
            if (nums[i] < nums[i - 1]) {
                System.out.println("排序失败!");
                break;
            }
        }
    }

    public static void parallelQuickSort(int[] nums) {
        if (nums == null || nums.length == 0) return;
        // 这里图简单,直接用内置线程池
        ExecutorService executor = Executors.newFixedThreadPool(20);
        // count其实代表了未完成的任务数,包括正在执行和等待执行的
        AtomicInteger count = new AtomicInteger(1);
        executor.execute(new Task(0, nums.length - 1, nums, count, executor));
        // 自旋判断是否已经执行完毕
        while (count.get() > 0) {
            try {
                System.out.println("count:" + count.get());
                Thread.sleep(10);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        // 以下皆为关闭线程池的措施
        executor.shutdown();
        try {
            executor.awaitTermination(60, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    // 随机生成数据
    public static int[] generateRandomArray(int size) {
        Random random = new Random();
        int[] array = new int[size];
        for (int i = 0; i < size; i++) {
            array[i] = random.nextInt(size) + 1;
        }
        return array;
    }
}

整个流程其实就是将迭代法中的队列换成了可以自动执行的线程池,而计数器因为存在并发操作,所以使用原子类确保线程安全。

测试

我们利用随机函数生成随机整型数据,分别使用原生的 Arrays.sort 和我们的多线程快排来测试,结果如下:

数据量 Arrays.sort ParallelQuickSort
10000000 1.466 秒 2.305 秒
1000000 0.112 秒 0.208 秒

可以看出,我们的多线程竟然比单线程的原生方法还慢,几乎差了一倍,这是什么原因呢?

经过我的一波分析和查阅资料,基本锁定原因:多线程的频繁上下文切换。

在代码中,我们可以看到我并没有对“划分区间给新线程跑”这一行为做限制,以至于即使区间再小也会扔到线程池去执行。而这之间消耗的线程切换时间可要比直接用单线程跑要多。所以我们可以针对这一点进行优化。

二次优化

我们只需要对 run() 作以下修改并且添加一个普通的快排方法即可,经过我测试,当区间长度小于 10000 时直接使用快排效果最好,如下:

@Override
public void run() {
    int x = nums[left], i = left, j = right + 1;
    while (i < j) {
        while (nums[--j] > x && i != j);
        if (i == j) nums[j] = x;
        else {
            nums[i] = nums[j];
            while (nums[++i] < x && i != j);
            if (i == j) nums[i] = x;
            else nums[j] = nums[i];
        }
    }
    // 当区间长度小于 10000 时直接使用快排效果
    if (right - left <= 10000) {
        quicksort(nums, left, i - 1);
        quicksort(nums, i + 1, right);
    } else {
        if (left < i - 1) {
            count.getAndIncrement();
            executor.execute(new MyRunnable(left, i - 1, nums, count, executor));
        }
        if (i + 1 < right) {
            count.getAndIncrement();
            executor.execute(new MyRunnable(i + 1, right, nums, count, executor));
        }
    }
    count.getAndDecrement();
}

public static void quicksort(int[] nums, int l, int r) {
    if (l >= r) return;
    int x = nums[l], i = l, j = r + 1;
    while (i < j) {
        while (nums[--j] > x && i != j);
        if (i == j) nums[j] = x;
        else {
            nums[i] = nums[j];
            while (nums[++i] < x && i != j);
            if (i == j) nums[i] = x;
            else nums[j] = nums[i];
        }
    }
    quicksort(nums, l, j - 1);
    quicksort(nums, j + 1, r);
}

最终测试

优化完后,我们再来测试,依旧是五次结果取平均,结果如下:

数据量 Arrays.sort ParallelQuickSort
10000000 1.544 秒 0.339 秒
1000000 0.111 秒 0.054 秒

可以看出,在千万级数据量下快了将近五倍,只能说成效非常明显。

总结

最后我思考了一下,为什么原生的快排方法不使用多线程,可能的原因又如下几点:

  • 大量数据放在内存中很占空间的,更多会采用多路归并排序(外部排序)
  • 多线程会消耗 CPU 资源,仅仅用来排序感觉多少有点浪费
  • 两者的差距在千万级数据量下才开始有明显差距,大部分情况下不会有这么高

p.s. 我后来才知道 Java 的 Arrays.parallelSort() 方法就是原生的多线程快排😂