归并排序(Merge Sort)

目标:将一个数组按照由低到高(或者由高到低)的顺序排序。

归并排序算法由 冯诺依曼 1945年发明。它是一种高效的排序算法,其最好、平均、最差时间复杂度都是 O(n log n)

归并排序算法使用了 分治法(divide and conquer) ,即将一个大的问题分成更小的问题并解决它们。我认为归并算法就是 拆分 合并

假设你需要将一个大小为 n 的数组排序。归并排序算法的排序步骤是:

  • 将所有的数字放入一个无序的堆。
  • 将堆分成两部分,现在你有 两个无序的堆
  • 持续将无序的堆拆分,直到无法再拆分为止,你将得到 n 个堆,每一个堆中有一个数字。
  • 现在开始将这些堆按照一定顺序按对 合并 。每一次合并过程中,将堆中的数字放入有序的队列。这一点很容易实现,因为每一个独立的堆中的内容都是有序的。
  • 假设你有一个数组 [2, 1, 5, 4, 9] 。这是一个无序的堆。拆分的目的就是一直拆分,直到不可再拆分为止。

    首先,将数组分成两堆: [2, 1] [5, 4, 9] 。还能继续拆吗?可以!

    先看左边的堆 [2, 1] 分成 [2] [1] ,还能继续拆吗?不能,现在来看另一个堆。

    [5, 4, 9] 分成 [5] [4, 9] , 毫无疑问 [5] 已经不能再分,但是 [4, 9] 可以分成 [4] [9]

    拆分的过程结束时,有这样一些堆: [2] [1] [5] [4] [9] 。注意每一个堆中都只有一个数字。

    现在有了一堆拆开的数组,你需要在 合并 它们的同时 对它们进行排序 。记住,合并是逐渐地将很多小数组合并,而不是一次合并出一个大的数组。在每一次合并迭代中,你需要关注的是将一个堆与另一个堆合并。

    现在这些堆就是 [2] [1] [5] [4] [9] ,第一次合并的结果是 [1, 2] [4, 5] [9] 。由于 [9] 落单了,在这一轮合并中它无法与其他堆合并。

    下一轮是合并 [1, 2] [4, 5] ,结果是 [1, 2, 4, 5] [9] 继续落单。

    现在只有两个堆 [1, 2, 4, 5] [9] ,现在是时候将它们合并了,结果就是一个有序的数组 [1, 2, 4, 5, 9]

    递归法(Top-down)

    我们先来看一个 Kotlin 实现的归并排序:

    fun mergeSort(array: IntArray): IntArray {
            if (array.size < 2) {
                return array                                                            //1
            val middleIndex = array.size / 2                                            //2
            val leftArray = mergeSort(array.sliceArray(0 until middleIndex))    //3
            val rightArray = mergeSort(array.sliceArray(middleIndex until array.size))//4
            return merge(leftArray, rightArray)                                               //5
    

    我们逐行解读这个函数:

  • 如果数组是空的或者只有一个元素,就没有必要继续拆分,直接返回即可。
  • 找到数组的中间位置。
  • 根据上一步找到的中间位置,递归拆分数组的左半部分。
  • 同样递归拆分数组的右半部分。
  • 最后,将所有的值合并到一起,保证合并后的结果是有序的。
  • 再来看一下合并算法:

    private fun merge(leftPile: IntArray, rightPile: IntArray): IntArray {
            var leftIndex = 0
            var rightIndex = 0
            var orderedPile = intArrayOf()
            while (leftIndex < leftPile.size && rightIndex < rightPile.size) {
                when {
                    leftPile[leftIndex] < rightPile[rightIndex] -> {
                        orderedPile += leftPile[leftIndex]
                        leftIndex++
                    leftPile[leftIndex] > rightPile[rightIndex] -> {
                        orderedPile += rightPile[rightIndex]
                        rightIndex++
                    else -> {
                        orderedPile += leftPile[leftIndex]
                        leftIndex++
                        orderedPile += rightPile[rightIndex]
                        rightIndex++
            while (leftIndex < leftPile.size) {
                orderedPile += leftPile[leftIndex]
                leftIndex++
            while (rightIndex < rightPile.size) {
                orderedPile += rightPile[rightIndex]
                rightIndex++
            return orderedPile
    

    这个函数看起来可能很可怕,但是它其实很简单:

  • 在合并的过程中需要两个游标用于跟踪两个数组的合并过程。
  • 这个是存放合并结果的数组。一开始它是一个空的数组,但是你在随后的步骤中会将其他数组中的元素添加进去。
  • 这个 while 循环就从左往右逐一比较两个数组中的元素并将它们添加到 orderedPile,这样就保证了结果是有序的。
  • 当前面的 while 结束的时候,意味着 leftPile 或者 rightPile 已经完全合并到了 orderedPile。这个时候,就不再需要比较,直接将另一个数组中剩余的部分直接添加到 orderedPile
  • 为了说明 merge() 函数的运行过程,我们假设现在有两个堆:leftPile = [1, 7, 8] 以及 rightPile = [3, 6, 9]。注意每一个堆中的元素都已经是有序的 -- 这一点在归并排序中是肯定成立的。下面是将两个堆合并的过程:

    leftPile       rightPile       orderedPile
    [ 1, 7, 8 ]    [ 3, 6, 9 ]     [ ]
      l              r
    

    leftIndex(这我们用 l代表)指向 leftPile 的第一个元素 1。rightIndex(我们用 r 代表)指向 3。所以,添加到 orderedPile中的第一个元素时 1,同时将 left index l 向右移动一个位置。

    leftPile       rightPile       orderedPile
    [ 1, 7, 8 ]    [ 3, 6, 9 ]     [ 1 ]
      -->l           r
    

    现在 l 指向 7, 但是 r 还指向 3,我们将最小的那一个元素加入有序堆,所以应该是 3。现在的情况是:

    leftPile       rightPile       orderedPile
    [ 1, 7, 8 ]    [ 3, 6, 9 ]     [ 1, 3 ]
         l           -->r
    

    重复以上步骤。每一步我们都从 leftPile 或者 rightPile中取一个最小值放入 orderedPile:

    leftPile       rightPile       orderedPile
    [ 1, 7, 8 ]    [ 3, 6, 9 ]     [ 1, 3, 6 ]
         l              -->r
    leftPile       rightPile       orderedPile
    [ 1, 7, 8 ]    [ 3, 6, 9 ]     [ 1, 3, 6, 7 ]
         -->l              r
    leftPile       rightPile       orderedPile
    [ 1, 7, 8 ]    [ 3, 6, 9 ]     [ 1, 3, 6, 7, 8 ]
            -->l           r
    

    现在左侧的堆中已经没有数据。我们只需要将右侧堆中剩余的项目添加到 orderedPile。最终的结果是: [ 1, 3, 6, 7, 8, 9 ]

    注意这个算法非常简单:它从左向右遍历两个堆,每一步都取一个最小的数字。最终的结果能够有序是因为我们保证了合并的每一个堆都已经是有序的。

    迭代法(Bottom-up)

    上面我们所实现的归并排序算法称为递归法,因为他首先将数组拆分成更小的堆然后再合并。在排序数组的时候,实际上你可以跳过拆分的步骤立即执行数组元素的合并。这就是所谓的迭代法。

    是时候加强一点难度了。先来看一个完整的迭代法实现:

    fun <T> mergeSortBottomUp(array: Array<T>, isOrderedBofore: (T, T) -> Boolean): Array<T> {
            val n = array.size
            val z = arrayOf(array.clone(), array.clone())          //1
            var d = 0
            var width = 1
            while (width < n) {                                   //2
                var i = 0
                while (i < n) {                                   //3
                    var j = i
                    var l = i
                    var r = i + width
                    val lmax = minOf(l + width, n)
                    val rmax = minOf(r + width, n)
                    while (l < lmax && r < rmax) {               //4
                        if (isOrderedBofore(z[d][l], z[d][r])) {
                            z[1 - d][j] = z[d][l]
                        } else {
                            z[1 - d][j] = z[d][r]
                    while (l < lmax) {
                        z[1 - d][j] = z[d][l]
                    while (r < rmax) {
                        z[1 - d][j] = z[d][r]
                    i += width * 2
                width *= 2
                d = 1 - d                   //5
            return z[d]
    

    这个函数看起来比递归法版本要恐怖多了,但是注意函数体中包含了和 merge() 方法一样的 while 循环。

    我们先用一个实例来说明一下迭代排序的排序过程,假设有一个数组 [6, 2, 8, 1, 5, 4, 12, 3, 9] 需要排序。一开始我们就申请了一个二维数组 z ,里面存放了两份待排序的数组:

    [[6, 2, 8, 1, 5, 4, 12, 3, 9],
     [6, 2, 8, 1, 5, 4, 12, 3, 9]]
    

    第一步从二维数组的第一个元素z[0] 中取出数字,按照两两结对排序合并(也就是合并大小为1的堆,对应 width = 1 ),并将结果存入z[1]:

    [[6, 2, 8, 1, 5, 4, 12, 3, 9],
      ↓  ↓  ↓  ↓  ↓  ↓   ↓  ↓  ↓
       --    --    --     --   |
      ↓  ↓  ↓  ↓  ↓  ↓   ↓  ↓  ↓
     [2, 6, 1, 8, 4, 5, 3, 12, 9]]
    

    然后将合并的宽度加倍(width = 2),从 z[1]中取出数字合并排序,将结果存入 z[0]

    [[1, 2, 6, 8, 3, 4, 5, 12, 9],
      ↑        ↑  ↑         ↑  ↑
      ----------   ---------   |
      ↑        ↑  ↑         ↑  ↑
     [2, 6, 1, 8, 4, 5, 3, 12, 9]]
    

    如此循环,每次都需要将宽度加倍,并且切换待排序数据的来源以及排序结果的存入位置,对应代码中的 d = 1- d,从 z[d]中读数据,排序结果存入z[1-d]

    [[1, 2, 6, 8, 3, 4, 5, 12, 9], ↓ ↓ ↓ ----------------------- | ↓ ↓ ↓ [1, 2, 3, 4, 5, 6, 8, 12, 9]] [[1, 2, 3, 4, 5, 6, 8, 9, 12], ↑ ↑ -------------------------- ↑ ↑ [1, 2, 3, 4, 5, 6, 8, 12, 9]]

    最终得到排序好的数组z[d] : [1, 2, 3, 4, 5, 6, 8, 9, 12]

    再来看一下代码中的关键逻辑:

  • 归并排序算法需要一个临时的数组作为工作区,因为你不能在合并左右堆的同时覆盖它们的内容。因为每次申请一个新的数组对资源是极大的浪费,所以为我们使用了两个数组,然后通过 d 的值在两个数组之间切换,d 的值只能是 0 或者 1.数组 z[d] 用于读取数据,z[1-d] 用于写入数据。这就是所谓的双缓冲区
  • 从概念上讲,迭代法版本和递归法版本的工作原理是一样的。首先它将只有一个元素的堆合并,然后合并有两个元素的堆,然后是有4个元素的堆,等等。堆的大小由 width 决定。一开始,width1,但是在每一次循环迭代结束时,我们将它的值乘以2,所以外层的循环决定了每次合并的堆的大小,并且每次循环之后待合并的子数组都会增大。
  • 这个内部循环逐一检查堆中的每一个元素并将每一对堆合并成一个更大的堆。合并后的结果存入数组 z[1-d]
  • 这里的逻辑和递归版本是一样的。主要的区别是使用了双缓冲区,所以数据从 z[d] 中读出来然后存入 z[1-d]。同时使用了 isOrderedBofore来比较元素,而不是单纯的比较数字大小。所以这个归并排序算法是一个通用算法,你可以用它来排序任何类型的对象。
  • 这个时候,从z[d] 中读取出来的大小为 width 的堆已经被合并成一个大小为 width*2 的大堆并存放在 z[1-d]。在这里我们需要交换两个数组,所以下一步我们就从刚刚创建的新堆中读取数据。
  • 这个函数是一个泛型函数,所以你可以用它来排序任何你需要的类型,只有你提供合适的 isOrderedBofore函数来比较元素。

    使用范例:

            val numList: Array<Int> = arrayOf(21, 3, 12, 45, 6, 9, 56, 67, 1, 43, 0)
            val sortedNum = mergeSortBottomUp(numList) { x, y -> x < y }
            val sortedNumUp = mergeSortBottomUp(numList) { x, y -> x > y }
            val strList: Array<String> = arrayOf("e", "m", "ec", "q", "a", "dx", "adxz", "rf", "po")
            val sortedStr = mergeSortBottomUp(strList) { x: String, y: String -> x < y }
            val sortedStrUp = mergeSortBottomUp(strList) { x: String, y: String -> x > y }
    

    归并排序算法的运行速度取决于待排序数组的大小。数组越大,需要做的事情就越多。

    不管待排序的数组初始状态是否有序,都不会影响归并排序算法的运行速度,因为不管初始状态是否有序,拆分的步骤都不会变。

    所有,它的时间复杂度(最优、平均、最差)都是O(n log n)