3.5 RDD内部的计算机制

RDD的多个Partition分别由不同的Task处理。Task分为两类:shuffleMapTask、resultTask。本节基于源码对RDD的计算过程进行深度解析。

3.5.1 Task解析

Task是计算运行在集群上的基本计算单位。一个Task负责处理RDD的一个Partition,一个RDD的多个Partition会分别由不同的Task去处理,通过之前对RDD的窄依赖关系的讲解,我们可以发现在RDD的窄依赖中,子RDD中Partition的个数基本都大于等于父RDD中Partition的个数,所以Spark计算中对于每一个Stage分配的Task的数目是基于该Stage中最后一个RDD的Partition的个数来决定的。最后一个RDD如果有100个Partition,则Spark对这个Stage分配100个Task。

Task运行于Executor上,而Executor位于CoarseGrainedExecutorBackend(JVM进程)中。

Spark Job中,根据Task所处Stage的位置,我们将Task分为两类:第一类为shuffleMapTask,指Task所处的Stage不是最后一个Stage,也就是Stage的计算结果还没有输出,而是通过Shuffle交给下一个Stage使用;第二类为resultTask,指Task所处Stage是DAG中最后一个Stage,也就是Stage计算结果需要进行输出等操作,计算到此已经结束;简单地说,Spark Job中除了最后一个Stage的Task为resultTask,其他所有Task都为shuffleMapTask。

3.5.2 计算过程深度解析

Spark中的Job本身内部是由具体的Task构成的,基于Spark程序内部的调度模式,即根据宽依赖的关系,划分不同的Stage,最后一个Stage依赖倒数第二个Stage等,我们从最后一个Stage获取结果;在Stage内部,我们知道有一系列的任务,这些任务被提交到集群上的计算节点进行计算,计算节点执行计算逻辑时,复用位于Executor中线程池中的线程,线程中运行的任务调用具体Task的run方法进行计算,此时,如果调用具体Task的run方法,就需要考虑不同Stage内部具体Task的类型,Spark规定最后一个Stage中的Task的类型为resultTask,因为我们需要获取最后的结果,所以前面所有Stage的Task是shuffleMapTask。

RDD在进行计算前,Driver给其他Executor发送消息,让Executor启动Task,在Executor启动Task成功后,通过消息机制汇报启动成功信息给Driver。Task计算示意图如图3-6所示。

详细情况如下:Driver中的CoarseGrainedSchedulerBackend给CoarseGrainedExecutor-Backend发送LaunchTask消息。

(1)首先反序列化TaskDescription。

Spark 2.1.1版本的CoarseGrainedExecutorBackend.scala的receive的源码如下。

1.    override def receive: PartialFunction[Any, Unit] = {
2.  .......
3.  case LaunchTask(data) =>
4.        if (executor == null) {
5.          exitExecutor(1, "Received LaunchTask command but executor was null")
6.        } else {
7.          val taskDesc = ser.deserialize[TaskDescription](data.value)
8.          logInfo("Got assigned task " + taskDesc.taskId)
9.          executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber
            = taskDesc.attemptNumber, taskDesc.name, taskDesc.serializedTask)
10.       }

图3-6 Task计算示意图

Spark 2.2.0版本的CoarseGrainedExecutorBackend.scala的receive的源码与Spark 2.1.1版本相比具有如下特点。

 上段代码中第7行调整为调用TaskDescription的decode方法,解析读取dataIn、taskId、attemptNumber、executorId、name、index等信息,读取相应的JAR、文件、属性,返回TaskDescription值。

 上段代码中第9行executor.launchTask传入的第二个参数更新为封装的taskDesc值。

1.  .......
2.     val taskDesc = TaskDescription.decode(data.value)
3.   ......
4.       executor.launchTask(this, taskDesc)
5.

launchTask中调用了decode方法,TaskDescription.scala的decode的源码如下。

1.  def decode(byteBuffer: ByteBuffer): TaskDescription = {
2.  val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer))
3.  val taskId = dataIn.readLong()
4.  val attemptNumber = dataIn.readInt()
5.  val executorId = dataIn.readUTF()
6.  val name = dataIn.readUTF()
7.  val index = dataIn.readInt()
8.
9.  //读文件
10.     val taskFiles = deserializeStringLongMap(dataIn)
11.
12.     //读取jars包
13.     val taskJars = deserializeStringLongMap(dataIn)
14.
15.     //读取属性
16.     val properties = new Properties()
17.     val numProperties = dataIn.readInt()
18.     for (i <- 0 until numProperties) {
19.       val key = dataIn.readUTF()
20.       val valueLength = dataIn.readInt()
21.       val valueBytes = new Array[Byte](valueLength)
22.       dataIn.readFully(valueBytes)
23.       properties.setProperty(key, new String(valueBytes,
          StandardCharsets.UTF_8))
24.     }
25.
26.     //创建一个子缓冲用于序列化任务将其变成自己的缓冲区(被反序列化后)
27.     val serializedTask = byteBuffer.slice()
28.
29.     new TaskDescription(taskId, attemptNumber, executorId, name, index,
        taskFiles, taskJars,
30.       properties, serializedTask)
31.   }
32. }

(2)Executor会通过launchTask执行Task。

(3)Executor的launchTask方法创建一个TaskRunner实例在threadPool来运行具体的Task。

Spark 2.1.1版本的Executor.scala的launchTask的源码如下。

1.    def launchTask(
2.        context: ExecutorBackend,
3.        taskId: Long,
4.        attemptNumber: Int,
5.        taskName: String,
6.        serializedTask: ByteBuffer): Unit = {
7.  //调用TaskRunner句柄创建TaskRunner对象
8.      val tr = new TaskRunner(context, taskId = taskId, attemptNumber =
        attemptNumber, taskName, serializedTask)
9.  //将创建的TaskRunner对象放入即将进行的堆栈中
10.     runningTasks.put(taskId, tr)
11. //从线程池中分配一条线程给TaskRunner
12.     threadPool.execute(tr)
13.   }

Spark 2.2.0版本的Executor.scala的launchTask的源码与Spark 2.1.1版本相比具有如下特点。

 上段代码中第3~6行调整launchTask方法的第二个参数:传入封装的taskDescription任务描述信息。

 上段代码中第8行构建TaskRunner实例传入的也是taskDescription参数。

1.  def launchTask(context: ExecutorBackend, taskDescription:
    TaskDescription): Unit = {
2.     val tr = new TaskRunner(context, taskDescription)
3.  ......

在TaskRunner的run方法首先会通过statusUpdate给Driver发信息汇报自己的状态,说明自己处于running状态。同时,TaskRunner内部会做一些准备工作,如反序列化Task的依赖,通过网络获取需要的文件、Jar等;然后反序列化Task本身。

Spark 2.1.1版本的Executor.scala的run方法的源码如下。

1.      override def run(): Unit = {
2.        threadId = Thread.currentThread.getId
3.        Thread.currentThread.setName(threadName)
4.        val threadMXBean = ManagementFactory.getThreadMXBean
5.        val taskMemoryManager = new TaskMemoryManager(env.memoryManager,
          taskId)
6.        val deserializeStartTime = System.currentTimeMillis()
7.        val deserializeStartCpuTime = if (threadMXBean.
          isCurrentThreadCpuTimeSupported) {
8.          threadMXBean.getCurrentThreadCpuTime
9.        } else 0L
10.       Thread.currentThread.setContextClassLoader(replClassLoader)
11.       val ser = env.closureSerializer.newInstance()
12.       logInfo(s"Running $taskName (TID $taskId)")
13. //通过statusUpdate给Driver发信息汇报自己状态说明自己是running状态
14.       execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
15.       var taskStart: Long = 0
16.       var taskStartCpu: Long = 0
17.       startGCTime = computeTotalGcTime()
18.       try {
19. //反序列化Task的依赖
20.         val (taskFiles, taskJars, taskProps, taskBytes) =
21.           Task.deserializeWithDependencies(serializedTask)
22.
23.         Executor.taskDeserializationProps.set(taskProps)
24.         updateDependencies(taskFiles, taskJars)
25. //反序列化Task本身
26.        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.
           getContextClassLoader)
27.         task.localProperties = taskProps
28.         task.setTaskMemoryManager(taskMemoryManager)
29. ......

Spark 2.2.0版本的Executor.scala的run方法的源码与Spark 2.1.1版本相比具有如下特点。

 上段代码中第20~21行删掉,即删掉val (taskFiles, taskJars, taskProps, taskBytes) =Task.deserializeWithDependencies(serializedTask)。

 第23行代码Executor.taskDeserializationProps.set方法的参数将taskProps调整为taskDescription.properties。

 第24行代码updateDependencies方法的参数将taskFiles、taskJars分别调整为taskDescription.addedFiles、taskDescription.addedJars。

 第26行代码ser.deserialize方法的第一个参数将taskBytes调整为taskDescription.serializedTask。

 第27行代码将taskProps调整为taskDescription.properties。

1.  ......
2.     Executor.taskDeserializationProps.set(taskDescription.properties)
3.    updateDependencies(taskDescription.addedFiles, taskDescription.
      addedJars)
4.    task = ser.deserialize[Task[Any]](
5.      taskDescription.serializedTask, Thread.currentThread.
        getContextClassLoader)
6.    task.localProperties = taskDescription.properties
7.  ......

(4)调用反序列化后的Task.run方法来执行任务,并获得执行结果。

Spark 2.1.1版本的Executor.scala的run方法的源码如下。

1.     override def run(): Unit = {
2.  ......
3.   //Task计算开始时间
4.    taskStart = System.currentTimeMillis()
5.          taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
6.            threadMXBean.getCurrentThreadCpuTime
7.          } else 0L
8.          var threwException = true
9.          val value = try {
10.  //运行Task的run方法
11.           val res: Any = task.run(
12.             taskAttemptId = taskId,
13.             attemptNumber = attemptNumber,
14.             metricsSystem = env.metricsSystem)
15.           threwException = false
16.           res
17.         } finally {
18.           val   releasedLocks     =   env.blockManager.releaseAllLocksForTask
              (taskId)
19.           val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
20.
21.           if (freedMemory > 0 && !threwException) {
22.             val errMsg = s"Managed memory leak detected; size = $freedMemory
                bytes, TID = $taskId"
23.             if(conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak",false)){
24.               throw new SparkException(errMsg)
25.             } else {
26.               logWarning(errMsg)
27.             }
28.           }
29.
30.           if (releasedLocks.nonEmpty && !threwException) {
31.             val errMsg =
32.               s"${releasedLocks.size} block locks were not released by TID
                  = $taskId:\n" +
33.                 releasedLocks.mkString("[", ", ", "]")
34.             if (conf.getBoolean("spark.storage.exceptionOnPinLeak",
                false)) {
35.               throw new SparkException(errMsg)
36.             } else {
37.               logWarning(errMsg)
38.             }
39.           }
40.         }
41. //计算完成时间
42.         val taskFinish = System.currentTimeMillis()
43. ......

Spark 2.2.0版本的Executor.scala的run方法的源码与Spark 2.1.1版本相比具有如下特点。

 上段代码中第13行attemptNumber调整为taskDescription.attemptNumber。

 上段代码中第40行之后新增一段代码:在任务完成前,循环遍历任务,抓取失败的情况,打印日志提醒用户业务代码可能导致任务失败。

1.  ......
2.    attemptNumber = taskDescription.attemptNumber,
3.  ......
4.          task.context.fetchFailed.foreach { fetchFailure =>
5.  //用户代码在不抛出任何错误的情况下捕获了故障。可能是用户打算这样做的(虽然不太可能),
    //因此我们将记录一个错误并继续下去
6.            logError(s"TID ${taskId} completed successfully though
              internally it encountered " + s"unrecoverable fetch failures!
              Most likely this means user code is incorrectly " + s"swallowing
              Spark's internal ${classOf[FetchFailedException]}",
              fetchFailure) }
7.        ......

task.run方法调用了runTask的方法,而runTask方法是一个抽象方法,runTask方法内部会调用RDD的iterator()方法,该方法就是针对当前Task对应的Partition进行计算的关键所在,在处理的方法内部会迭代Partition的元素,并交给我们自定义的function进行处理。

Task.scala的run方法的源码如下。

1.    final def run(
2.        taskAttemptId: Long,
3.        attemptNumber: Int,
4.        metricsSystem: MetricsSystem): T = {
5.      ......
6.      try {
7.        runTask(context)
8.      } catch
9.  ......

task有两个子类,分别是ShuffleMapTask和ResultTask,下面分别对两者进行讲解。

1.ShuffleMapTask

ShuffleMapTask.scala的源码如下。

1.   override def runTask(context: TaskContext): MapStatus = {
2.      //使用广播变量反序列化RDD
3.      val threadMXBean = ManagementFactory.getThreadMXBean
4.      val deserializeStartTime = System.currentTimeMillis()
5.      val deserializeStartCpuTime = if (threadMXBean.
        isCurrentThreadCpuTimeSupported) {
6.        threadMXBean.getCurrentThreadCpuTime
7.      } else 0L
8.  //创建序列化器
9.      val ser = SparkEnv.get.closureSerializer.newInstance()
10. //反序列化出RDD和依赖关系
11. val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
12.       ByteBuffer.wrap(taskBinary.value), Thread.currentThread.
          getContextClassLoader)
13. //RDD反序列化的时间
14.  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
15.     _executorDeserializeCpuTime = if (threadMXBean.
        isCurrentThreadCpuTimeSupported) {
16.       threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
17.     } else 0L
18.  //创建Shuffle的writer对象,用来将计算结果写入Shuffle管理器
19.     var writer: ShuffleWriter[Any, Any] = null
20.     try {
21. //实例化shuffleManager
22.       val manager = SparkEnv.get.shuffleManager
23.   //对writer对象赋值
24.      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId,
         context)
25. //将计算结果通过writer对象的write方法写入shuffle过程
26.  writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_
     <: Product2[Any, Any]]])
27.       writer.stop(success = true).get
28.     } catch {
29.       case e: Exception =>
30.         try {
31.           if (writer != null) {
32.             writer.stop(success = false)
33.           }
34.         } catch {
35.           case e: Exception =>
36.             log.debug("Could not stop writer", e)
37.         }
38.         throw e
39.     }
40.   }

首先,ShuffleMapTask会反序列化RDD及其依赖关系,然后通过调用RDD的iterator方法进行计算,而iterator方法中进行的最终运算的方法是compute()。

RDD.scala的iterator方法的源码如下。

1.   final def iterator(split: Partition, context: TaskContext): Iterator[T]
     = {//判断此RDD的持久化等级是否为NONE(不进行持久化)
2.    if (storageLevel != StorageLevel.NONE) {
3.      getOrCompute(split, context)
4.    } else {
5.      computeOrReadCheckpoint(split, context)
6.    }
7.  }

其中,RDD.scala的computeOrReadCheckpoint的源码如下。

1.   private[spark] def computeOrReadCheckpoint(split: Partition, context:
     TaskContext): Iterator[T] =
2.  {
3.    if (isCheckpointedAndMaterialized) {
4.      firstParent[T].iterator(split, context)
5.    } else {
6.      compute(split, context)
7.    }
8.  }

RDD的compute方法是一个抽象方法,每个RDD都需要重写的方法。

此时,选择查看MapPartitionsRDD已经实现的compute方法,可以发现compute方法的实现是通过f方法实现的,而f方法就是我们创建MapPartitionsRDD时输入的操作函数。

1.      private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
2.      var prev: RDD[T],
3.      f: (TaskContext, Int, Iterator[T]) => Iterator[U],  //(TaskContext,
        partition index, iterator)
4.      preservesPartitioning: Boolean = false)
5.    extends RDD[U](prev) {
6.
7.    override val partitioner = if (preservesPartitioning) firstParent[T].
      partitioner else None
8.
9.    override def getPartitions: Array[Partition] = firstParent[T].
      partitions
10.
11.   override def compute(split: Partition, context: TaskContext): Iterator
      [U] =
12.     f(context, split.index, firstParent[T].iterator(split, context))
13.
14.   override def clearDependencies() {
15.     super.clearDependencies()
16.     prev = null
17.   }
18. }

MapPartitionsRDD.scala的源码如下。

注意:通过迭代器的不断叠加,将每个RDD的小函数合并成一个大的函数流

然后在计算具体的Partition之后,通过shuffleManager获得的shuffleWriter把当前Task计算的结果根据具体的shuffleManager实现写入到具体的文件中,操作完成后会把MapStatus发送给Driver端的DAGScheduler的MapOutputTracker。

2.ResultTask

Driver端的DAGScheduler的MapOutputTracker把shuffleMapTask执行的结果交给ResultTask,ResultTask根据前面Stage的执行结果进行shuffle后产生整个job最后的结果。

ResultTask.scala的runTask的源码如下。

1.   override def runTask(context: TaskContext): U = {
2.     //使用广播变量反序列化RDD及函数
3.     val threadMXBean = ManagementFactory.getThreadMXBean
4.     val deserializeStartTime = System.currentTimeMillis()
5.     val deserializeStartCpuTime = if (threadMXBean.
       isCurrentThreadCpuTimeSupported) {
6.       threadMXBean.getCurrentThreadCpuTime
7.     } else 0L
8.     //创建序列化器
9.     val ser = SparkEnv.get.closureSerializer.newInstance()
10.    //反序列RDD和func处理函数
11.    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T])
       => U)](
12.      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.
         getContextClassLoader)
13.    _executorDeserializeTime = System.currentTimeMillis() -
       deserializeStartTime
14.    _executorDeserializeCpuTime = if (threadMXBean.
       isCurrentThreadCpuTimeSupported) {
15.      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
16.    } else 0L
17.
18.    func(context, rdd.iterator(partition, context))
19.  }

而ResultTask的runTask方法中反序列化生成func函数,最后通过func函数计算出最终的结果。