为了账号安全,请及时绑定邮箱和手机立即绑定
首页 手记 Java Fork/Join并行框架

Java Fork/Join并行框架

2018.06.10 22:45 1277浏览

作者: 一字马胡  
转载标志 【2017-11-03】

更新日志

日期更新内容备注
2017-11-03添加转载标志持续更新

初步了解Fork/Join框架

Fork/Join 框架是java7中加入的一个并行任务框架,可以将任务分割成足够小的小任务,然后让不同的线程来做这些分割出来的小事情,然后完成之后再进行join,将小任务的结果组装成大任务的结果。下面的图片展示了这种框架的工作模型:

Fork/Join工作模型

使用Fork/Join并行框架的前提是我们的任务可以拆分成足够小的任务,而且可以根据小任务的结果来组装出大任务的结果,一个最简单的例子是使用Fork/Join框架来求一个数组中的最大/最小值,这个任务就可以拆成很多小任务,大任务就是寻找一个大数组中的最大/最小值,我们可以将一个大数组拆成很多小数组,然后分别求解每个小数组中的最大/最小值,然后根据这些任务的结果组装出最后的最大最小值,下面的代码展示了如何通过Fork/Join求解数组的最大值:

import java.util.concurrent.ExecutionException;import java.util.concurrent.ForkJoinPool;import java.util.concurrent.Future;import java.util.concurrent.RecursiveTask;import java.util.concurrent.TimeUnit;import java.util.concurrent.TimeoutException;/**
* Created by hujian06 on 2017/9/28.
*
* fork/join demo
*/public class ForkJoinDemo {  /**
   * how to find the max number in array by Fork/Join
   */
  private static class MaxNumber extends RecursiveTask<Integer> {      private int threshold = 2;      private int[] array; // the data array

      private int index0 = 0;      private int index1 = 0;      public MaxNumber(int[] array, int index0, int index1) {          this.array = array;          this.index0 = index0;          this.index1 = index1;
      }      @Override
      protected Integer compute() {          int max = Integer.MIN_VALUE;          if ((index1 - index0) <= threshold) {              for (int i = index0;i <= index1; i ++) {
                  max = Math.max(max, array[i]);
              }

          } else {              //fork/join
              int mid = index0 + (index1 - index0) / 2;
              MaxNumber lMax = new MaxNumber(array, index0, mid);
              MaxNumber rMax = new MaxNumber(array, mid + 1, index1);

              lMax.fork();
              rMax.fork();              int lm = lMax.join();              int rm = rMax.join();

              max = Math.max(lm, rm);

          }          return max;
      }
  }  public static void main(String ... args) throws ExecutionException, InterruptedException, TimeoutException {

      ForkJoinPool pool = new ForkJoinPool();      int[] array = {100,400,200,90,80,300,600,10,20,-10,30,2000,1000};

      MaxNumber task = new MaxNumber(array, 0, array.length - 1);

      Future<Integer> future = pool.submit(task);

      System.out.println("Result:" + future.get(1, TimeUnit.SECONDS));

  }

}

可以通过设置不同的阈值来拆分成小任务,阈值越小代表拆出来的小任务越多。

工作窃取算法

Fork/Join在实现上,大任务拆分出来的小任务会被分发到不同的队列里面,每一个队列都会用一个线程来消费,这是为了获取任务时的多线程竞争,但是某些线程会提前消费完自己的队列。而有些线程没有及时消费完队列,这个时候,完成了任务的线程就会去窃取那些没有消费完成的线程的任务队列,为了减少线程竞争,Fork/Join使用双端队列来存取小任务,分配给这个队列的线程会一直从头取得一个任务然后执行,而窃取线程总是从队列的尾端拉取task。

Frok/Join框架的实现细节

在上面的示例代码中,我们发现Fork/Join的任务是通过ForkJoinPool来执行的,所以框架的一个核心是任务的fork和join,然后就是这个ForkJoinPool。关于任务的fork和join,我们可以想象,而且也是由我们的代码自己控制的,所以要分析Fork/Join,那么ForkJoinPool最值得研究。

ForkJoinPool的类关系图

上面的图片展示了ForkJoinPool的类关系图,可以看到本质上它就是一个Executor。在ForkJoinPool里面,有两个特别重要的成员如下:

    volatile WorkQueue[] workQueues;  
    final ForkJoinWorkerThreadFactory factory;

workQueues 用于保存向ForkJoinPool提交的任务,而具体的执行有ForkJoinWorkerThread执行,而ForkJoinWorkerThreadFactory可以用于生产出ForkJoinWorkerThread。可以看一些ForkJoinWorkerThread,可以发现每一个ForkJoinWorkerThread会有一个pool和一个workQueue,和我们上面描述的是一致的,每个线程都被分配了一个任务队列,而执行这个任务队列的线程由pool提供。

下面我们看一下当我们fork的时候发生了什么:

    public final ForkJoinTask<V> fork() {
        Thread t;        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);        else
            ForkJoinPool.common.externalPush(this);        return this;
    }

看上面的fork代码,可以看到首先取到了当前线程,然后判断是否是我们的ForkJoinPool专用线程,如果是,则强制类型转换(向下转换)成ForkJoinWorkerThread,然后将任务push到这个线程负责的队列里面去。如果当前线程不是ForkJoinWorkerThread类型的线程,那么就会走else之后的逻辑,大概的意思是首先尝试将任务提交给当前线程,如果不成功,则使用例外的处理方法,关于底层实现较为复杂,和我们使用Fork/Join关系也不太大,如果希望搞明白具体原理,可以看源码。

下面看一下join的流程:

    public final V join() {        int s;        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);        return getRawResult();
    }    
        private int doJoin() {        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;        return (s = status) < 0 ? s :
            ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
            (w = (wt = (ForkJoinWorkerThread)t).workQueue).
            tryUnpush(this) && (s = doExec()) < 0 ? s :
            wt.pool.awaitJoin(w, this, 0L) :
            externalAwaitDone();
    }    
        final int doExec() {        int s; boolean completed;        if ((s = status) >= 0) {            try {
                completed = exec();
            } catch (Throwable rex) {                return setExceptionalCompletion(rex);
            }            if (completed)
                s = setCompletion(NORMAL);
        }        return s;
    }    
     /**
     * Implements execution conventions for RecursiveTask.
     */
    protected final boolean exec() {
        result = compute();        return true;
    }

上面展示了主要的调用链路,我们发现最后落到了我们在代码里编写的compute方法,也就是执行它,所以,我们需要知道的一点是,fork仅仅是分割任务,只有当我们执行join的时候,我们的额任务才会被执行。

如何使用Fork/Join并行框架

前文首先展示了一个求数组中最大值得例子,然后介绍了“工作窃取算法”,然后分析了Fork/Join框架的一些细节,下面才是我们最关心的,怎么使用Fork/Join框架呢?

为了使用Fork/Join框架,我们只需要继承类RecursiveTask或者RecursiveAction。前者适用于有返回值的场景,而后者适合于没有返回值的场景。

基于fork/Join框架实现归并排序

直接放可执行的代码:

import java.util.Random;import java.util.concurrent.ForkJoinPool;import java.util.concurrent.RecursiveAction;/**
 * Created by hujian06 on 2017/10/23.
 *
 * merge sort by fork/join
 */public class ForkJoinMergeSortDemo {    public static void main(String ... args) {        new Worker().runWork();
    }

}class Worker {    private static final boolean isDebug = false;    public void runWork() {        int[] array = mockArray(200000000, 1000000); // mock the data

        forkJoinCase(array);
        normalCase(array);

    }    private void printArray(int[] arr) {        if (isDebug == false) {            return;
        }        for (int i = 0; i < arr.length; i ++) {
            System.out.print(arr[i] + " ");
        }

        System.out.println();
    }    private void forkJoinCase(int[] array) {
        ForkJoinPool pool = new ForkJoinPool();

        MergeSortTask mergeSortTask = new MergeSortTask(array, 0, array.length - 1);        long start = System.currentTimeMillis();

        pool.invoke(mergeSortTask);        long end = System.currentTimeMillis();

        printArray(array);

        System.out.println("[for/join mode]Total cost: " + (end - start) / 1000.0 + " s, for " +
                array.length + " items' sort work.");
    }    private void normalCase(int[] array) {        long start = System.currentTimeMillis();        new MergeSortWorker().sort(array, 0, array.length - 1);        long end = System.currentTimeMillis();

        printArray(array);

        System.out.println("[normal mode]Total cost: " + (end - start) / 1000.0 + " s, for " +
                array.length + " items' sort work.");
    }    private static final  int[] mockArray(int length, int up) {        if (length <= 0) {            return null;
        }        int[] array = new int[length];

        Random random = new Random(47);        for (int i = 0; i < length; i ++) {
            array[i] = random.nextInt(up);
        }        return array;
    }
}class MergeSortTask extends RecursiveAction {    private static final int threshold = 100000;    private final MergeSortWorker mergeSortWorker = new MergeSortWorker();    private int[] data;    private int left;    private int right;    public MergeSortTask(int[] array, int l, int r) {        this.data = array;        this.left = l;        this.right = r;
    }    @Override
    protected void compute() {        if (right - left < threshold) {
            mergeSortWorker.sort(data, left, right);
        } else {            int mid = left + (right - left) / 2;
            MergeSortTask l = new MergeSortTask(data, left, mid);
            MergeSortTask r = new MergeSortTask(data, mid + 1, right);

            invokeAll(l, r);

            mergeSortWorker.merge(data, left, mid, right);
        }
    }
}class MergeSortWorker {    // Merges two subarrays of arr[].
    // First subarray is arr[l..m]
    // Second subarray is arr[m+1..r]
    void merge(int arr[], int l, int m, int r) {        // Find sizes of two subarrays to be merged
        int n1 = m - l + 1;        int n2 = r - m;        /* Create temp arrays */
        int L[] = new int[n1];        int R[] = new int[n2];        /*Copy data to temp arrays*/
        for (int i = 0; i < n1; ++i)
            L[i] = arr[l + i];        for (int j = 0; j < n2; ++j)
            R[j] = arr[m + 1 + j];        /* Merge the temp arrays */

        // Initial indexes of first and second subarrays
        int i = 0, j = 0;        // Initial index of merged subarry array
        int k = l;        while (i < n1 && j < n2) {            if (L[i] <= R[j]) {
                arr[k ++] = L[i ++];
            } else {
                arr[k ++] = R[j ++];
            }
        }        /* Copy remaining elements of L[] if any */
        while (i < n1) {
            arr[k ++] = L[i ++];
        }        /* Copy remaining elements of R[] if any */
        while (j < n2) {
            arr[k ++] = R[j ++];
        }
    }    // Main function that sorts arr[l..r] using
    // merge()
    void sort(int arr[], int l, int r) {        if (l < r) {            // Find the middle point
            int m = l + (r - l) / 2;            // Sort first and second halves
            sort(arr, l, m);
            sort(arr, m + 1, r);            // Merge the sorted halves
            merge(arr, l, m, r);
        }
    }
}





点击查看更多内容
0人点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
手记
粉丝
55
获赞与收藏
259

关注TA,一起探索更多经验知识

同主题相似文章浏览排行榜

风间影月说签约讲师

51篇手记,涉及Java、MySQL、Redis、Spring等方向

进入讨论

Tony Bai 说签约讲师

152篇手记,涉及Go、C、Java、Python等方向

进入讨论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消