java多线程之ForkJoin框架

Fork/JOIN框架

简介

Fork/Join框架是Java7提供了的一个用于并行执行任务的框架, 是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。

运行流程图

工作窃取模式

工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行

假如我们需要做一个比较大的任务,我们可以把这个任务分割为若干互不依赖的子任务,为了减少线程间的竞争,于是把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应,比如A线程负责处理A队列里的任务。但是有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。

工作窃取算法的优点是充分利用线程进行并行计算,并减少了线程间的竞争,其缺点是在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且消耗了更多的系统资源,比如创建多个线程和多个双端队列。

例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
public class ForkJoinCase extends RecursiveTask<Double> {
private static final long serialVersionUID = 1L;
public ForkJoinCase() {
}
// 进行forkJoin的界限,数量少于100直接进行计算
static final int THRESHOLD = 100;
double[] array;
int start;
int end;
ForkJoinCase(double[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Double compute() {
if (end - start <= THRESHOLD) {
// 如果任务足够小,直接计算:
double sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
}
// 任务太大,一分为二:
int middle = (end + start) / 2;
ForkJoinCase subtask1 = new ForkJoinCase(this.array, start, middle);
ForkJoinCase subtask2 = new ForkJoinCase(this.array, middle, end);
invokeAll(subtask1, subtask2);
Double subresult1 = subtask1.join();
Double subresult2 = subtask2.join();
Double result = subresult1 + subresult2;
return result;
}
public static void main(String[] args) {
double[] array = new double[400];
for(int i=0; i<array.length; i++) {
array[i] = Math.random() * 100;
}
// fork/join task:
ForkJoinPool fjp = new ForkJoinPool(4); // 最大并发数4
ForkJoinTask<Double> task = new ForkJoinCase(array, 0, array.length);
long startTime = System.currentTimeMillis();
Double result = fjp.invoke(task);
long endTime = System.currentTimeMillis();
System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
}
}

注意事项

  • invokeAll()方法:invokeAll的N个任务中,其中N-1个任务会使用fork()交给其它线程执行,但是,它还会留一个任务自己执行,这样,就充分利用了线程池,保证没有空闲的不干活的线程。

附invokeAll源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
public static void invokeAll(ForkJoinTask forkjointask, ForkJoinTask forkjointask1)
{
forkjointask1.fork();
int i;
if((i = forkjointask.doInvoke() & -268435456) != -268435456)
forkjointask.reportException(i);
int j;
if((j = forkjointask1.doJoin() & -268435456) != -268435456)
forkjointask1.reportException(j);
}
public static transient void invokeAll(ForkJoinTask aforkjointask[])
{
Object obj = null;
int i = aforkjointask.length - 1;
for(int j = i; j >= 0; j--)
{
ForkJoinTask forkjointask = aforkjointask[j];
if(forkjointask == null)
{
if(obj == null)
obj = new NullPointerException();
continue;
}
if(j != 0)
{
forkjointask.fork();
continue;
}
if(forkjointask.doInvoke() < -268435456 && obj == null)
obj = forkjointask.getException();
}
for(int k = 1; k <= i; k++)
{
ForkJoinTask forkjointask1 = aforkjointask[k];
if(forkjointask1 == null)
continue;
if(obj != null)
{
forkjointask1.cancel(false);
continue;
}
if(forkjointask1.doJoin() < -268435456)
obj = forkjointask1.getException();
}
if(obj != null)
rethrow(((Throwable) (obj)));
}

参考资料