1. 概述

fork/join 框架在 Java 7 中引入。它基于分而治之的思想,通过尝试利用所有可用处理器内核来帮助加速并行计算。

什么是分而治之?它分为任务分解,和结果合并两个阶段。

首先是fork 。通过递归方式将一个复杂任务分解为更小的独立子任务,直至子任务简单到无需再分。

分完之后, “join” 部分开始,将所有子任务结果递归地合并为一个结果。如果任务的返回值为 void类型,那么程序只需等待所有子任务执行完毕。

为了提高并行计算效率,fork/join 框架使用一个名为 ForkJoinPool 的线程池。该线程池负责管理类型为 ForkJoinWorkerThread 的工作线程。

ForkJoinPool

ForkJoinPool 是整个框架的核心,它实现了 ExecutorService 接口。

我们知道一个工作线程(Worker Thread)同一时间只能执行一个任务, ForkJoinPool 不会为每个子任务创建一个独立的线程,而是每个线程都维护了一个双端队列(deque),用来存储需要执行的任务。

这种架构对于借助工作窃取算法平衡线程的工作负载至关重要。

2.1. 工作窃取算法

何为工作窃取算法?

简单来说 – 空闲的线程尝试从其他繁忙线程的deque双端队列中窃取一个任务来执行

默认情况下,一个工作线程从 deque 头部读取任务。如果队列为空,则该线程会从其他繁忙线程的 deque 尾部或全局队列中获取一个任务。

这种算法最大限度地避免发生线程竞争任务,同时减少线程寻找任务的次数。

2.2. 实例化 ForkJoinPool

在 Java 8 中,获取 ForkJoinPool 实例最便捷的方法是使用其静态方法 commonPool()。顾名思义,它返回公共池的引用,公共池是每个 ForkJoinTask 的默认线程池。

根据 Oracle文档,建议使用预定义的公共线程池以减少资源消耗,避免每个任务都创建一个单独的线程池。

Java 7 中,需要我们自己实现单例模式,例如用饿汉式:

public static ForkJoinPool forkJoinPool = new ForkJoinPool(2);

获取实例:

ForkJoinPool forkJoinPool = PoolUtil.forkJoinPool;

通过 ForkJoinPool 构造函数我们可以创建自定义线程池,自定义并行级别(parallelism),线程创建工厂(ThreadFactory),异常处理器(ExceptionHandler)。上面例子中,parallelism参数为2, 表示该线程池将使用2个处理器内核。

3. ForkJoinTask <V>

ForkJoinTask 是我们任务的基类。实际中,我们应该继承它的两个子类:无返回值的 RecursiveAction 和带返回值的 RecursiveTask<V>。两者都有一个抽象方法 compute(),在里面实现我们的任务执行逻辑。

3.1. RecursiveAction 示例

下面例子中,我们将变量 workload 的所有字母转为大写并打印日志。本例仅仅是用于演示目的,这个任务没有实际意义。

为了演示框架任务分解行为,本例使用 createSubtask() 方法在workload.length()大于设定阈值时分解任务

workload 被递归地分解为子串,并创建基于这些子串的 CustomRecursiveTask 实例。

结果返回一个子任务集合 List<CustomRecursiveAction>。

使用 invokeAll() 将集合中的任务提交到 ForkJoinPool

    public class CustomRecursiveAction extends RecursiveAction {
    
        private String workload = "";
        private static final int THRESHOLD = 4;
    
        private static Logger logger = 
          Logger.getAnonymousLogger();
    
        public CustomRecursiveAction(String workload) {
            this.workload = workload;
        }
    
        @Override
        protected void compute() {
            if (workload.length() > THRESHOLD) {
                ForkJoinTask.invokeAll(createSubtasks());
            } else {
               processing(workload);
            }
        }
    
        private List<CustomRecursiveAction> createSubtasks() {
            List<CustomRecursiveAction> subtasks = new ArrayList<>();
    
            String partOne = workload.substring(0, workload.length() / 2);
            String partTwo = workload.substring(workload.length() / 2, workload.length());
    
            subtasks.add(new CustomRecursiveAction(partOne));
            subtasks.add(new CustomRecursiveAction(partTwo));
    
            return subtasks;
        }
    
        private void processing(String work) {
            String result = work.toUpperCase();
            logger.info("This result - (" + result + ") - was processed by " 
              + Thread.currentThread().getName());
        }
    }

可以套用此模版开发我们自己的 RecursiveAction类。创建一个对象表示我们的总任务,选择一个合适的阈值,定义一个用于分解任务的方法,以及实际处理任务的方法。

3.2. RecursiveTask<V>

对于带返回值的任务,实现逻辑类似,只是需要把每个子任务的结果合并到一个结果中:

    public class CustomRecursiveTask extends RecursiveTask<Integer> {
        private int[] arr;
    
        private static final int THRESHOLD = 20;
    
        public CustomRecursiveTask(int[] arr) {
            this.arr = arr;
        }
    
        @Override
        protected Integer compute() {
            if (arr.length > THRESHOLD) {
                return ForkJoinTask.invokeAll(createSubtasks())
                  .stream()
                  .mapToInt(ForkJoinTask::join)
                  .sum();
            } else {
                return processing(arr);
            }
        }
    
        private Collection<CustomRecursiveTask> createSubtasks() {
            List<CustomRecursiveTask> dividedTasks = new ArrayList<>();
            dividedTasks.add(new CustomRecursiveTask(
              Arrays.copyOfRange(arr, 0, arr.length / 2)));
            dividedTasks.add(new CustomRecursiveTask(
              Arrays.copyOfRange(arr, arr.length / 2, arr.length)));
            return dividedTasks;
        }
    
        private Integer processing(int[] arr) {
            return Arrays.stream(arr)
              .filter(a -> a > 10 && a < 27)
              .map(a -> a * 10)
              .sum();
        }
    }

本例中,变量 arr 表示我们的任务。createSubtasks() 方法递归地将一个大任务分解为小任务,直到小于阈值时不再分解。然后,invokeAll() 方法将子任务提交到公共池,并返回一个 Future 集合。

为了触发执行,为每个子任务调用 join() 方法。

这里使用了 Java 8 中的 Stream API 实现。sum()方法将子结果合并为最终结果。

4. 提交任务到 ForkJoinPool 中

要将任务提交到 ForkJoinPool 线程池中,可以使用:

submit()execute() 方法 :

forkJoinPool.execute(customRecursiveTask);
int result = customRecursiveTask.join();

invoke() 方法 fork 任务并等待返回结果,不需要手动 join 操作。

int result = forkJoinPool.invoke(customRecursiveTask);

invokeAll() 方法将 ForkJoinTask 任务批量提交到 ForkjoinPool。将任务作为参数传入(该方法有多个重载方法,可以传2个任务,或变长参数,或集合形式),fork然后按顺序返回一个 Future 集合。

或者,你也可以单独使用 fork() 和 join() 方法。fork() 提交一个任务到线程池中, 调用join()方法等待任务执行完毕。如果是 RecursiveAction 类型的任务,join() 返回 null,如果是RecursiveTask<V>类型,则返回任务执行结果。

customRecursiveTaskFirst.fork();
result = customRecursiveTaskLast.join();

在上面 RecursiveTask<V> 例子中我们使用 invokeAll() 批量提交子任务到线程池中。 同样的工作也可以通过 fork()join() 来完成,不过这会影响结果的排序。

为了避免混淆,通常最好使用 invokeAll() 方法将多个任务提交到 ForkJoinPool

5. 总结

在处理大型任务时,使用 fork/join 框架能加快处理速度。但前提是遵守以下几个原则:

  • 尽可能少使用线程池 – 大多数场景下,一个应用最好使用一个线程池
  • 使用默认的公共线程池l, 如果不需要特殊调优
  • 使用一个合理的阈值 将 ForkJoinTask 拆分为子任务
  • 避免在ForkJoinTask中编写阻塞代码

本文用到的例子源码可以从 GitHub 上获取。