shuffle讀過程源碼分析 上一篇中,我們分析了shuffle在map階段的寫過程。簡單回顧一下,主要是將ShuffleMapTask計算的結果數據在記憶體中按照分區和key進行排序,過程中由於記憶體限制會溢寫出多個磁碟文件,最後會對所有的文件和記憶體中剩餘的數據進行歸併排序並溢寫到一個文件中,同時會記 ...
shuffle讀過程源碼分析
上一篇中,我們分析了shuffle在map階段的寫過程。簡單回顧一下,主要是將ShuffleMapTask計算的結果數據在記憶體中按照分區和key進行排序,過程中由於記憶體限制會溢寫出多個磁碟文件,最後會對所有的文件和記憶體中剩餘的數據進行歸併排序並溢寫到一個文件中,同時會記錄每個分區(reduce端分區)的數據在文件中的偏移,並且把分區和偏移的映射關係寫到一個索引文件中。
好了,簡單回顧了寫過程後,我們不禁思考,reduce階段的數據讀取的具體過程是什麼樣的?數據讀取的發生的時機是什麼?
首先應該回答後一個問題:數據讀取發生的時機是什麼?我們知道,rdd的計算鏈根據shuffle被切分為不同的stage,一個stage的開始階段一般就是從讀取上一階段的數據開始,也就是說stage讀取數據的過程其實就是reduce過程,然後經過該stage的計算鏈後得到結果數據,再然後就會把這些數據寫入到磁碟供下一個stage讀取,這個寫入的過程實際上就是map輸出過程,而這個過程我們之前已經分析過了。本篇我們要分析的是reduce階段讀取數據的過程。
啰嗦了這麼一大段,其實就是為了引出數據讀取的入口,還是要回到ShuffleMapTask,這裡我只貼部分代碼:
// shuffle管理器
val manager = SparkEnv.get.shuffleManager
// 獲取一個shuffle寫入器
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
// 這裡可以看到rdd計算的核心方法就是iterator方法
// SortShuffleWriter的write方法可以分為幾個步驟:
// 將上游rdd計算出的數據(通過調用rdd.iterator方法)寫入記憶體緩衝區,
// 在寫的過程中如果超過 記憶體閾值就會溢寫磁碟文件,可能會寫多個文件
// 最後將溢寫的文件和記憶體中剩餘的數據一起進行歸併排序後寫入到磁碟中形成一個大的數據文件
// 這個排序是先按分區排序,在按key排序
// 在最後歸併排序後寫的過程中,沒寫一個分區就會手動刷寫一遍,並記錄下這個分區數據在文件中的位移
// 所以實際上最後寫完一個task的數據後,磁碟上會有兩個文件:數據文件和記錄每個reduce端partition數據位移的索引文件
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
// 主要是刪除中間過程的溢寫文件,向記憶體管理器釋放申請的記憶體
writer.stop(success = true).get
讀取數據的代碼其實就是rdd.iterator(partition, context),
iterator方法主要是處理rdd緩存的邏輯,如果有緩存就會從緩存中讀取(通過BlockManager),如果沒有緩存就會進行實際的計算,發現最終調用RDD.compute方法進行實際的計算,這個方法是一個抽象方法,是由子類實現的具體的計算邏輯,用戶代碼中對於RDD做的一些變換操作實際上最終都會體現在compute方法中。
另一方面,我們知道,map,filter這類運算元不是shuffle操作,不會導致stage的劃分,所以我們想看shuffle讀過程就要找一個Shuffle類型的操作,我們看一下RDD.groupBy,最終調用了groupByKey方法
RDD.groupByKey
def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope {
// groupByKey shouldn't use map side combine because map side combine does not
// reduce the amount of data shuffled and requires all map side data be inserted
// into a hash table, leading to more objects in the old gen.
val createCombiner = (v: V) => CompactBuffer(v)
val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v
val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2
val bufs = combineByKeyWithClassTag[CompactBuffer[V]](
createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false)
bufs.asInstanceOf[RDD[(K, Iterable[V])]]
}
最終調用了combineByKeyWithClassTag
RDD.combineByKeyWithClassTag
做一些判斷,檢查一些非法情況,然後處理一下分區器,最後返回一個ShuffledRDD,所以接下來我們分析一下ShuffleRDD的compute方法
def combineByKeyWithClassTag[C](
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true,
serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope {
require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
// 如果key是Array類型,是不支持在map端合併的
// 並且也不支持HashPartitioner
if (keyClass.isArray) {
if (mapSideCombine) {
throw new SparkException("Cannot use map-side combining with array keys.")
}
if (partitioner.isInstanceOf[HashPartitioner]) {
throw new SparkException("HashPartitioner cannot partition array keys.")
}
}
// 聚合器,用於對數據進行聚合
val aggregator = new Aggregator[K, V, C](
self.context.clean(createCombiner),
self.context.clean(mergeValue),
self.context.clean(mergeCombiners))
// 如果分區器相同,就不需要shuffle了
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(iter => {
val context = TaskContext.get()
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
} else {
// 返回一個ShuffledRDD
new ShuffledRDD[K, V, C](self, partitioner)
.setSerializer(serializer)
.setAggregator(aggregator)
.setMapSideCombine(mapSideCombine)
}
}
ShuffleRDD.compute
通過shuffleManager獲取一個讀取器,數據讀取的邏輯在讀取器里。
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
// 通過shuffleManager獲取一個讀取器
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[(K, C)]]
}
SortShuffleManager.getReader
無需多說,直接看BlockStoreShuffleReader
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
BlockStoreShuffleReader.read
顯然,這個方法才是核心所在。總結一下主要步驟:
- 獲取一個包裝的迭代器ShuffleBlockFetcherIterator,它迭代的元素是blockId和這個block對應的讀取流,很顯然這個類就是實現reduce階段數據讀取的關鍵
- 將原始讀取流轉換成反序列化後的迭代器
- 將迭代器轉換成能夠統計度量值的迭代器,這一系列的轉換和java中對於流的各種裝飾器很類似
- 將迭代器包裝成能夠相應中斷的迭代器。每讀一條數據就會檢查一下任務有沒有被殺死,這種做法是為了儘量及時地響應殺死任務的請求,比如從driver端發來殺死任務的消息。
- 利用聚合器對結果進行聚合。這裡再次利用了AppendonlyMap這個數據結構,前面shuffle寫階段也用到這個數據結構,它的內部是一個以數組作為底層數據結構的,以線性探測法線性的hash表。
- 最後對結果進行排序。
所以很顯然,我們想知道的shuffle讀取數據的具體邏輯就藏在ShuffleBlockFetcherIterator中
private[spark] class BlockStoreShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
context: TaskContext,
serializerManager: SerializerManager = SparkEnv.get.serializerManager,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C] with Logging {
private val dep = handle.dependency
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
// 獲取一個包裝的迭代器,它迭代的元素是blockId和這個block對應的讀取流
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
val serializerInstance = dep.serializer.newInstance()
// Create a key/value iterator for each stream
// 將原始讀取流轉換成反序列化後的迭代器
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
// 轉換成能夠統計度量值的迭代器,這一系列的轉換和java中對於流的各種裝飾器很類似
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
},
context.taskMetrics().mergeShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
// 每讀一條數據就會檢查一下任務有沒有被殺死,
// 這種做法是為了儘量及時地響應殺死任務的請求,比如從driver端發來殺死任務的消息
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
// 利用聚合器對結果進行聚合
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// Sort the output if there is a sort ordering defined.
// 最後對結果進行排序
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
}
}
ShuffleBlockFetcherIterator
這個類比較複雜,仔細看在類初始化的代碼中會調用initialize方法。
其次,我們應該註意它的構造器中的參數,
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
// 如果沒有啟用外部shuffle服務,就是BlockTransferService
blockManager.shuffleClient,
blockManager,
// 通過mapOutputTracker組件獲取每個分區對應的數據block的物理位置
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
// 獲取幾個配置參數
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
ShuffleBlockFetcherIterator.initialize
- 首先將本地的block和遠程的block分隔開
- 然後開始發送請求拉取遠程數據。這個過程中會有一些約束條件限制拉取數據請求的數量,主要是正在獲取的總數據量的限制,請求併發數限制;每個遠程地址同時拉取的塊數也會有限制,但是這個閾值預設是Integer.MAX_VALUE
- 獲取本地的block數據
其中,獲取本地數據較為簡單,主要就是通過本節點的BlockManager來獲取塊數據,並通過索引文件獲取數據指定分區的數據。
我們著重分析遠程拉取的部分
private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
// 向TaskContext中添加一個回調,在任務完成時做一些清理工作
context.addTaskCompletionListener(_ => cleanup())
// Split local and remote blocks.
// 將本地的block和遠程的block分隔開
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
assert ((0 == reqsInFlight) == (0 == bytesInFlight),
"expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
// Send out initial requests for blocks, up to our maxBytesInFlight
// 發送遠程拉取數據的請求
// 儘可能多地發送請求
// 但是會有一定的約束:
// 全局性的約束,全局拉取數據的rpc線程併發數,全局拉取數據的數據量限制
// 每個遠程地址的限制:每個遠程地址同時拉取的塊數不能超過一定閾值
fetchUpToMaxBytes()
// 記錄已經發送的請求個數,仍然會有一部分沒有發送請求
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
// Get Local Blocks
// 獲取本地的block數據
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
ShuffleBlockFetcherIterator.splitLocalRemoteBlocks
我們首先來看如何切分遠程和本地的數據塊,總結一下這個方法:
- 首先將同時拉取的數據量的大小除以5作為每次請求拉取的數據量的限制,這麼做的原因是為了允許同時從5個節點拉取數據,因為節點的網路環境可能並不穩定,同時從多個節點拉取數據有助於減少網路波動對性能帶來的影響,而對整體的同時拉取數據量的限制主要是為了限制本機網路流量的使用
- 迴圈遍歷每一個節點地址(這裡是BlockManagerId),
- 如果地址與本機地址相同,那麼對應的blocks就是本地block
對於遠程block,則要根據同時拉取數據量大小的限制將每個節點的所有block切分成多個請求(FetchRequest),確保這些請求單次的拉取數據量不會太大
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. // 之所以將請求大小減小到maxBytesInFlight / 5, // 是為了並行化地拉取數據,最毒允許同時從5個節點拉取數據 val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress) // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] // Tracks total number of blocks (including zero sized blocks) // 記錄總的block數量 var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { totalBlocks += blockInfos.size // 如果地址與本地的BlockManager相同,就是本地block if (address.executorId == blockManager.blockManagerId.executorId) { // Filter out zero-sized blocks localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) numBlocksToFetch += localBlocks.size } else { val iterator = blockInfos.iterator var curRequestSize = 0L var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() // Skip empty blocks if (size > 0) { curBlocks += ((blockId, size)) remoteBlocks += blockId numBlocksToFetch += 1 curRequestSize += size } else if (size < 0) { throw new BlockException(blockId, "Negative block size " + size) } // 如果超過每次請求的數據量限制,那麼創建一次請求 if (curRequestSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { // Add this FetchRequest remoteRequests += new FetchRequest(address, curBlocks) logDebug(s"Creating fetch request of $curRequestSize at $address " + s"with ${curBlocks.size} blocks") curBlocks = new ArrayBuffer[(BlockId, Long)] curRequestSize = 0 } } // Add in the final request // 掃尾方法,最後剩餘的塊創建一次請求 if (curBlocks.nonEmpty) { remoteRequests += new FetchRequest(address, curBlocks) } } } logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") remoteRequests }
ShuffleBlockFetcherIterator.fetchUpToMaxBytes
回到initialize方法中,在完成本地與遠程block的切分後,我們得到了一批封裝好的數據拉取請求,將這些請求加到隊列中,接下來要做的是通過rpc客戶端發送這些請求,
這個方法邏輯還是相對簡單,主要邏輯就是兩個迴圈,先發送延緩隊列中的請求,然後發送正常的請求;之所以會有延緩隊列是因為這些請求在第一次待發送時因為數據量超過閾值或者請求數量超過閾值而不能發送,所以就被放到延緩隊列中,而這裡的處理也是優先發送延緩隊列中的請求。每個請求在發送前必須要滿足下麵幾個條件才會被髮送:
- 當前正在拉取的數據量不能超過閾值maxReqsInFlight(預設48m);這裡會有一個問題,如果某個block的數據量超過maxReqsInFlight值呢?這種情況下會等當前已經沒有進行中的數據拉取請求才會發送這個請求,因為在對當前請求數據量閾值進行判斷時會檢查bytesInFlight == 0,如果這個條件滿足就不會檢查本次請求的數據量是否會超過閾值。
- 當前正在拉取的請求數據量不能超過閾值(預設Int.MaxValue)
- 每個遠程地址的同時請求數量也會有限制(預設Int.MaxValue)
最後符合條件的請求就會被髮送,這裡要提出的一點是如果一次請求的數據量超過maxReqSizeShuffleToMem值,那麼就會寫入磁碟的一個臨時文件中,而這個閾值的預設值是Long.MaxValue,所以預設情況下是沒有限制的。
// 發送請求 // 儘可能多地發送請求 // 但是會有一定的約束: // 全局性的約束,全局拉取數據的rpc線程併發數,全局拉取數據的數據量限制 // 每個遠程地址的限制:每個遠程地址同時拉取的塊數不能超過一定閾值 private def fetchUpToMaxBytes(): Unit = { // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host // immediately, defer the request until the next time it can be processed. // Process any outstanding deferred fetch requests if possible. if (deferredFetchRequests.nonEmpty) { for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { while (isRemoteBlockFetchable(defReqQueue) && !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) { val request = defReqQueue.dequeue() logDebug(s"Processing deferred fetch request for $remoteAddress with " + s"${request.blocks.length} blocks") send(remoteAddress, request) if (defReqQueue.isEmpty) { deferredFetchRequests -= remoteAddress } } } } // Process any regular fetch requests if possible. while (isRemoteBlockFetchable(fetchRequests)) { val request = fetchRequests.dequeue() val remoteAddress = request.address // 如果超過了同時拉取的塊數的限制,那麼將這個請求放到延緩隊列中,留待下次請求 if (isRemoteAddressMaxedOut(remoteAddress, request)) { logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) defReqQueue.enqueue(request) deferredFetchRequests(remoteAddress) = defReqQueue } else { send(remoteAddress, request) } } // 發送一個請求,並且累加記錄請求的塊的數量, // 以用於在下次請求時檢查請求塊的數量是否超過閾值 def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { sendRequest(request) numBlocksInFlightPerAddress(remoteAddress) = numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size } // 這個限制是對所有的請求而言,不分具體是哪個遠程節點 // 檢查當前的請求的數量是否還有餘量 // 當前請求的大小是否還有餘量 // 這主要是為了限制併發數和網路流量的使用 def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { fetchReqQueue.nonEmpty && (bytesInFlight == 0 || (reqsInFlight + 1 <= maxReqsInFlight && bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) } // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a // given remote address. // 檢測正在拉取的塊的數量是否超過閾值 // 每個地址都有一個同事拉取塊數的限制 def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > maxBlocksInFlightPerAddress } }
ShuffleBlockFetcherIterator.next
通過上一個方法的分析,我們能夠看出來,初始化時發起的拉取數據的請求並未將所有請求全部發送出去,並且還會有請求因為超過閾值而被放入延緩隊列中,那麼這些未發送的請求是什麼時候被再次發送的呢?答案就在next方法中。我們知道ShuffleBlockFetcherIterator是一個迭代器,所以外部調用者對元素的訪問是通過next方法,所以很容易想到next方法中肯定會有發送拉取數據請求的邏輯。
總結一下:
- 首先從結果隊列中獲取一個拉取成功的結果(結果隊列是一個阻塞隊列,如果沒有拉取成功的結果會阻塞調用者)
- 拿到一個結果後檢查這個結果是拉取成功還是拉取失敗,如果失敗則直接拋異常(重試的邏輯實在rpc客戶端實現的,不是在這裡實現)
- 如果是一個成功的結果,首先要更新一下一些任務度量值,更新一些內部的簿記量,如正在拉取的數據量
- 將拉取到的位元組緩衝包裝成一個位元組輸入流
- 通過外部傳進來的函數對流再包裝一次,通過外部傳進來的函數再包裝一次,一般是解壓縮和解密
- 而且流被壓縮或者加密過,如果塊的大小比較小,那麼要將這個流拷貝一份,這樣就會實際出發解壓縮和解密,以此來儘早暴露塊損壞的 問題
最後一句關鍵語句,再次發起一輪拉取數據請求的發 送,因為經過next處理之後,已經有拉取成功的數據了,正在拉取的數據量和請求數量可能減小了,這就為發送新的請求騰出空間
override def next(): (BlockId, InputStream) = { if (!hasNext) { throw new NoSuchElementException } numBlocksProcessed += 1 var result: FetchResult = null var input: InputStream = null // Take the next fetched result and try to decompress it to detect data corruption, // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch // is also corrupt, so the previous stage could be retried. // For local shuffle block, throw FailureFetchResult for the first IOException. while (result == null) { val startFetchWait = System.currentTimeMillis() result = results.take() val stopFetchWait = System.currentTimeMillis() shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 // 主要是更新一些度量值 shuffleMetrics.incRemoteBytesRead(buf.size) if (buf.isInstanceOf[FileSegmentManagedBuffer]) { shuffleMetrics.incRemoteBytesReadToDisk(buf.size) } shuffleMetrics.incRemoteBlocksFetched(1) } bytesInFlight -= size if (isNetworkReqDone) { reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) } // 將位元組緩衝包裝成一個位元組輸入流 val in = try { buf.createInputStream() } catch { // The exception could only be throwed by local shuffle block case e: IOException => assert(buf.isInstanceOf[FileSegmentManagedBuffer]) logError("Failed to create input stream from local block", e) buf.release() throwFetchFailedException(blockId, address, e) } // 通過外部傳進來的函數再包裝一次,一般是增加壓縮和加密的功能 input = streamWrapper(blockId, in) // Only copy the stream if it's wrapped by compression or encryption, also the size of // block is small (the decompressed block is smaller than maxBytesInFlight) // 如果塊的大小比較小,而且流被壓縮或者加密過,那麼需要將這個流拷貝一份 if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) { val originalInput = input val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) try { // Decompress the whole block at once to detect any corruption, which could increase // the memory usage tne potential increase the chance of OOM. // TODO: manage the memory used here, and spill it into disk in case of OOM. Utils.copyStream(input, out) out.close() input = out.toChunkedByteBuffer.toInputStream(dispose = true) } catch { case e: IOException => buf.release() if (buf.isInstanceOf[FileSegmentManagedBuffer] || corruptedBlocks.contains(blockId)) { throwFetchFailedException(blockId, address, e) } else { logWarning(s"got an corrupted block $blockId from $address, fetch again", e) corruptedBlocks += blockId fetchRequests += FetchRequest(address, Array((blockId, size))) result = null } } finally { // TODO: release the buf here to free memory earlier originalInput.close() in.close() } } // 拉取失敗,拋異常 // 這裡思考一下:拉取塊數據肯定是有重試機制的,但是這裡拉取失敗之後直接拋異常是為何?? // 答案是:重試機制並不是正在這裡實現 的,而是在rpc客戶端發送拉取請求時實現了重試機制 // 也就是說如果到這裡是失敗的話,說明已經經過重試後還是失敗的,所以這裡直接拋異常就行了 case FailureFetchResult(blockId, address, e) => throwFetchFailedException(blockId, address, e) } // Send fetch requests up to maxBytesInFlight // 這裡再次發送拉取請求,因為前面已經有成功拉取到的數據, // 所以正在拉取中的數據量就會減小,所以就能為新的請求騰出空間 fetchUpToMaxBytes() } currentResult = result.asInstanceOf[SuccessFetchResult] (currentResult.blockId, new BufferReleasingInputStream(input, this)) }
總結
到此,我們就把shuffle讀的過程大概分析完了。整體下來,感覺主幹邏輯不是很複雜,但是裡面有很多細碎邏輯,所以上面的分析還是比較碎,這裡把整個過程的主幹邏輯再提煉一下,以便能有個整體的認識:
- 首先,在一些shuffle類型的RDD中,它的計算方法compute會通過ShuffleManager獲取一個block數據讀取器BlockStoreShuffleReader
- 通過BlockStoreShuffleReader中的read方法進行數據的讀取,一個reduce端分區的數據一般會依賴於所有的map端輸出的分區數據,所以數據一般會在多個executor(註意是executor節點,通過BlockManagerId唯一標識,一個物理節點可能會運行多個executor節點)節點上,而且每個executor節點也可能會有多個block,在shuffle寫過程的分析中我們也提到,每個map最後時輸出一個數據文件和索引文件,也就是一個block,但是因為一個節點
- 這個方法通過ShuffleBlockFetcherIterator對象封裝了遠程拉取數據的複雜邏輯,並且最終將拉取到的數據封裝成流的迭代器的形式
- 對所有的block的流進行層層裝飾,包括反序列化,任務度量值(讀入數據條數)統計,每條數據可中斷,
- 對數據進行聚合
- 對聚合後的數據進行排序
所以,從這裡我們也能看出來,新版的shuffle機制中,也就是SortShuffleManager,用戶代碼對於shuffle之後的rdd拿到的是經過排序的數據(如果指定排序器的話)。