Java并发编程--线程池之ForkJoinPool

摘要

  • 本文介绍线程池ForkJoinPool相关技术

  • 本文基于jdk1.8

Fork/Join框架介绍

  • Fork/Join是一个是一个并行计算的框架,主要就是用来支持分治任务模型的,这个计算框架里的 Fork对应的是分治任务模型里的任务分解,Join 对应的是结果合并。

  • 它的核心思想是将一个大任务分成许多小任务,然后并行执行这些小任务,最终将它们的结果合并成一个大的结果。

  • Fork/Join模式是实现任务并行性的一种常用模式,它将大任务递归地分解成小任务,然后利用多线程并行执行这些小任务。

  • Fork/Join框架的主要组成部分是ForkJoinPoolForkJoinTaskForkJoinPool是一个线程池,它用于管理Fork/Join任务的执行。ForkJoinTask是一个抽象类,用于表示可以被分割成更小部分的任务。

  • Fork/Join框架更适合执行CPU密集型任务,同时需要避免在ForkJoinPool中提交大量的阻塞型任务,以免影响整个线程池的性能。

应用场景

  • 1.递归分解型任务
    Fork/Join框架特别适用于递归分解型的任务,例如排序、归并、遍历等。这些任务通常可以将大的任务分解成若干个子任务,每个子任务可以独立执行,并且可以通过归并操作将子任务的结果合并成一个有序的结果。

  • 2.数组处理
    Fork/Join框架还可以用于数组的处理,例如数组的排序、查找、统计等。在处理大型数组时,Fork/Join框架可以将数组分成若干个子数组,并行地处理每个子数组,最后将处理后的子数组合并成一个有序的大数组。

  • 3.并行化算法
    Fork/Join框架还可以用于并行化算法的实现,例如并行化的图像处理算法、并行化的机器学习算法等。在这些算法中,可以将问题分解成若干个子问题,并行地解决每个子问题,然后将子问题的结果合并起来得到最终的解决方案。

  • 4.大数据处理
    Fork/Join框架还可以用于大数据处理,例如大型日志文件的处理、大型数据库的查询等。在处理大数据时,可以将数据分成若干个分片,并行地处理每个分片,最后将处理后的分片合并成一个完整的结果

ForkJoinPool介绍

  • ForkJoinPoolFork/Join框架中的线程池类,它用于管理Fork/Join任务的线程,基于工作窃取算法(work-stealing)来实现任务的分配和执行。

  • ForkJoinPool类包括一些重要的方法,例如submit()invoke()shutdown()awaitTermination()等,用于提交任务、执行任务、关闭线程池和等待任务的执行结果。ForkJoinPool类中还包括一些参数,例如线程池的大小、工作线程的优先级、任务队列的容量等,可以根据具体的应用场景进行设置。

  • 常用API

API/方法 返回值 描述
ForkJoinPool(int parallelism) 使用指定的并行度创建一个新的ForkJoinPool。
invoke(ForkJoinTask<T> task) T 同步执行给定的任务,并返回结果。
submit(ForkJoinTask<T> task) Future<T> 异步执行给定的任务,并返回一个Future对象,可以用于获取任务的执行结果。
execute(ForkJoinTask<?> task) 异步执行给定的任务,没有返回值。
awaitTermination(long timeout, TimeUnit unit) boolean 阻塞当前线程,直到所有任务执行完成或超过指定的超时时间,并返回是否成功终止。
isShutdown() boolean 判断ForkJoinPool是否已经关闭。
isTerminated() boolean 判断ForkJoinPool中的所有任务是否已经执行完成。
shutdown() 优雅地关闭ForkJoinPool,不再接受新的任务,并等待已提交的任务执行完成。
shutdownNow() List<Runnable> 强制关闭ForkJoinPool,尝试取消所有正在执行的任务,并返回等待执行的任务列表。
getParallelism() int 获取ForkJoinPool的并行度,即同时执行任务的线程数。
getPoolSize() int 获取ForkJoinPool中当前的工作线程数。
getActiveThreadCount() int 获取ForkJoinPool中当前活动的线程数。
getQueuedTaskCount() long 获取ForkJoinPool中当前等待执行的任务数。
getRunningThreadCount() int 获取ForkJoinPool中当前正在执行任务的线程数。
getStealCount() long 获取ForkJoinPool中总共发生的工作窃取次数。

ForkJoinPool的创建

  • ForkJoinPool中有四个核心参数,用于控制线程池的并行数、工作线程的创建、异常处理和模式指定等。各参数解释如下:

    • int parallelism:指定并行级别(parallelism level)。
      ForkJoinPool将根据这个设定,决定工作线程的数量。如果未设置的话,将使用Runtime.getRuntime().availableProcessors()来设置并行级别;
      1
      2
      3
      public ForkJoinPool(int parallelism) {
      this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
      }
    • ForkJoinWorkerThreadFactory factory:ForkJoinPool在创建线程时,会通过factory来创建。
      注意,这里需要实现的是ForkJoinWorkerThreadFactory,而不是ThreadFactory。如果你不指定factory,那么将由默认的DefaultForkJoinWorkerThreadFactory负责线程的创建工作;
    • UncaughtExceptionHandler handler:指定异常处理器
      当任务在运行中出错时,将由设定的处理器处理;
    • boolean asyncMode:设置队列的工作模式。
      asyncModetrue时,将使用先进先出队列,而为false时则使用后进先出的模式。
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      public ForkJoinPool(int parallelism,
      ForkJoinWorkerThreadFactory factory,
      UncaughtExceptionHandler handler,
      boolean asyncMode) {
      this(checkParallelism(parallelism),
      checkFactory(factory),
      handler,
      asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
      "ForkJoinPool-" + nextPoolId() + "-worker-");
      checkPermission();
      }
1
2
3
4
5
6
7
//获取处理器数量
int processors = Runtime.getRuntime().availableProcessors();
//构建forkjoin线程池
ForkJoinPool forkJoinPool = new ForkJoinPool(processors);

//此时等同于无参创建
ForkJoinPool forkJoinPool = new ForkJoinPool();

ForkJoinTask介绍

  • ForkJoinTaskFork/Join框架中的抽象类,它定义了执行任务的基本接口。用户可以通过继承ForkJoinTask类来实现自己的任务类,并重写其中的compute()方法来定义任务的执行逻辑。

  • 通常情况下我们不需要直接继承ForkJoinTask类,而只需要继承它的子类,Fork/Join框架提供了以下三个子类:

    • RecursiveAction:用于递归执行但不需要返回结果的任务。
    • RecursiveTask :用于递归执行需要返回结果的任务。
    • CountedCompleter<T> :在任务完成执行后会触发执行一个自定义的钩子函数
  • ForkJoinTask 最核心的是 fork() 方法和 join()方法,承载着主要的任务协调作用,一个用于任务提交,一个用于结果获取。

    • fork()–提交任务
      fork()方法用于向当前任务所运行的线程池中提交任务。如果当前线程是ForkJoinWorkerThread类型,将会放入该线程的工作队列,否则放入common线程池的工作队列中。
    • join()–获取任务执行结果
      join()方法用于获取任务的执行结果。调用join()时,将阻塞当前线程直到对应的子任务完成运行并返回结果

ForkJoinPool 与 ForkJoinTask 使用示例

  • 利用fork-join实现数组归并排序算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

/**
* 利用fork-join实现数组归并排序算法
*/
public class MergeSortRecursiveTask extends RecursiveTask<int[]> {

private final int threshold; //拆分的阈值,低于此阈值就不再进行拆分
private int[] arrayToSort; //要排序的数组

public MergeSortRecursiveTask(final int[] arrayToSort, final int threshold) {
this.arrayToSort = arrayToSort;
this.threshold = threshold;
}

@Override
protected int[] compute() {
//拆分后的数组长度小于阈值,直接进行排序
if (arrayToSort.length <= threshold) {
// 调用jdk提供的排序方法
Arrays.sort(arrayToSort);
return arrayToSort;
}

// 对数组进行拆分
int midpoint = arrayToSort.length / 2;
int[] leftArray = Arrays.copyOfRange(arrayToSort, 0, midpoint);
int[] rightArray = Arrays.copyOfRange(arrayToSort, midpoint, arrayToSort.length);

MergeSortRecursiveTask leftTask = new MergeSortRecursiveTask(leftArray, threshold);
MergeSortRecursiveTask rightTask = new MergeSortRecursiveTask(rightArray, threshold);

// 调用任务,阻塞当前线程,直到所有子任务执行完成
// 使用invokeAll()方法同时提交多个任务,以提高任务的并行度,相当于同时fork并join
invokeAll(leftTask, rightTask);

//提交任务,多个任务同时提交时,推荐使用invokeAll()
// leftTask.fork();
// rightTask.fork();
//合并结果
// leftTask.join();
// rightTask.join();

// 合并排序结果
arrayToSort = merge(leftTask.getSortedArray(), rightTask.getSortedArray());
return arrayToSort;
}

private int[] getSortedArray() {
return arrayToSort;
}

/**
* 合并两个有序数组,并返回合并后的有序数组
*/
private int[] merge(final int[] leftArray, final int[] rightArray) {
// 定义用于合并结果的数组
int[] mergedArray = new int[leftArray.length + rightArray.length];
int mergedArrayPos = 0;
// 利用双指针进行两个数的比较
int leftArrayPos = 0;
int rightArrayPos = 0;
while (leftArrayPos < leftArray.length && rightArrayPos < rightArray.length) {
// 比较左右数组中的元素大小,并将较小的元素放入合并结果数组中
if (leftArray[leftArrayPos] <= rightArray[rightArrayPos]) {
mergedArray[mergedArrayPos] = leftArray[leftArrayPos];
leftArrayPos++;
} else {
mergedArray[mergedArrayPos] = rightArray[rightArrayPos];
rightArrayPos++;
}
mergedArrayPos++;
}

// 将剩余的左数组元素放入合并结果数组中
while (leftArrayPos < leftArray.length) {
mergedArray[mergedArrayPos] = leftArray[leftArrayPos];
leftArrayPos++;
mergedArrayPos++;
}

// 将剩余的右数组元素放入合并结果数组中
while (rightArrayPos < rightArray.length) {
mergedArray[mergedArrayPos] = rightArray[rightArrayPos];
rightArrayPos++;
mergedArrayPos++;
}

// 返回合并后的有序数组
return mergedArray;
}


/**
* 随机生成数组
* @param size 数组的大小
*/
private static int[] buildRandomIntArray(final int size) {
int[] arrayToCalculateSumOf = new int[size];
Random generator = new Random();
for (int i = 0; i < arrayToCalculateSumOf.length; i++) {
arrayToCalculateSumOf[i] = generator.nextInt(10000);
}
return arrayToCalculateSumOf;
}

public static void main(String[] args) {
//拆分的阈值
int threshold = 20;
int[] arrayToSortByMergeSort = buildRandomIntArray(2000);
System.out.print("排序前: ");
for (int element : arrayToSortByMergeSort) {
System.out.print(element + " ");
}
System.out.println();
//利用forkjoin排序
MergeSortRecursiveTask mergeSortRecursiveTask = new MergeSortRecursiveTask(arrayToSortByMergeSort, threshold);
//构建forkjoin线程池
ForkJoinPool forkJoinPool = new ForkJoinPool();
long startTime = System.nanoTime();
//执行排序任务
final int[] mergeSortArray = forkJoinPool.invoke(mergeSortRecursiveTask);
System.out.print("排序后: ");
for (int element : mergeSortArray) {
System.out.print(element + " ");
}
System.out.println();

long duration = System.nanoTime() - startTime;
System.out.println("forkjoin排序时间: " + (duration / (1000f * 1000f)) + "毫秒");
}
}

ForkJoinPool工作原理

  • ForkJoinPool 内部有多个任务队列,当我们通过 ForkJoinPoolinvoke() 或者 submit() 方法提交任务时,ForkJoinPool 根据一定的路由规则把任务提交到一个任务队列中,如果任务在执行过程中会创建出子任务,那么子任务会提交到工作线程对应的任务队列中。

  • 如果工作线程对应的任务队列空了,是不是就没活儿干了呢?不是的,ForkJoinPool 支持一种叫做任务窃取的机制,如果工作线程空闲了,那它可以窃取其他工作任务队列里的任务。如此一来,所有的工作线程都不会闲下来了。

小贴士

工作窃取

  • 工作窃取,就是允许空闲线程从繁忙线程的双端队列中窃取任务。
  • 默认情况下,工作线程从它自己的双端队列的头部获取任务。但是,当自己的任务为空时,线程会从其他繁忙线程双端队列的尾部中获取任务。这种方法,最大限度地减少了线程竞争任务的可能性。
  • ForkJoinPool执行流程

总结

  • Fork/Join是一种基于分治思想的模型,在并发处理计算型任务时有着显著的优势。其效率的提升主要得益于两个方面:

    • 任务切分:将大的任务分割成更小粒度的小任务,让更多的线程参与执行;
    • 任务窃取:通过任务窃取,充分地利用空闲线程,并减少竞争。
  • 在使用ForkJoinPool时,需要特别注意任务的类型是否为纯函数计算类型,也就是这些任务不应该关心状态或者外界的变化,这样才是最安全的做法。如果是阻塞类型任务,那么你需要谨慎评估技术方案。虽然ForkJoinPool也能处理阻塞类型任务,但可能会带来复杂的管理成本。