LeetCode_4_两个排序数组的中位数

LeetCode_4_两个排序数组的中位数

题目描述

两个排序数组的中位数 Median of Two Sorted Arrays

给定两个大小为 m 和 n 的有序数组 nums1nums2
请找出这两个有序数组的中位数。要求算法的时间复杂度为 $O(\log (m+n))$ 。

示例 1:

nums1 = [1, 3]
nums2 = [2]
中位数是 2.0

示例 2:

nums1 = [1, 2]
nums2 = [3, 4]
中位数是 (2 + 3)/2 = 2.5

Tags: 数组,二分查找,分治算法

解决思路

题目给定两个有序的数组,找出中位数并不困难,最简单的一个循环就行,时间复杂度 $O(m+n)$ 。但是题目要求时间复杂度为 $O(\log (m+n))$ 就是一个大问题,具体思路采用二分查找,如下:

其中k 表示要求的中位数的位置,初始化 $k = (m + n + 1) / 2$ 和 $k = (m + n + 2) / 2$,即 $m + n$ 是偶数时为中间两个, $m + n$ 是奇数时为中位数的后面一个。

首先将 A 划分为左右两部分,其中 $i = k / 2$ ,$length(left_A) = i,length(right_A) = k - i$

left_A             |        right_A
A[0], A[1], ..., A[i-1]  |  A[i], A[i+1], ..., A[m-1]

然后将 B 划分为左右两部分,其中 $j = k / 2$ ,$length(left_A) = j,length(right_A) = k - j$

left_B             |        right_B
B[0], B[1], ..., B[j-1]  |  B[j], B[j+1], ..., B[n-1]

则,A、B 数组可分为左右两部分,如下:

left          |          right
A[0], A[1], ..., A[i-1]  |  A[i], A[i+1], ..., A[m-1]
B[0], B[1], ..., B[j-1]  |  B[j], B[j+1], ..., B[n-1]

划重点来了

如果,$A[i-1] > B[j-1]$,则说明,中位数一定在 $left_A、right_A、right_B$中,这样就排除了$left_B$,想想这是为什么?
因为有$A[i-1] > B[j-1]$ 不妨设 $A[0] > B[p],p=0,1,2,\ldots,j-1$ 则中位数一定不在$left_B$中,所以缩小查找范围。

left          |          right
A[0], A[1], ..., A[i-1]  |  A[i], A[i+1], ..., A[m-1]
                         |  B[j], B[j+1], ..., B[n-1]

如果,$A[i-1] < B[j-1]$,则说明,中位数一定在 $right_A、left_B、right_B$中,这样就排除了$left_A$,想想这是为什么?理由同上。

left          |          right
                         |  A[i], A[i+1], ..., A[m-1]
B[0], B[1], ..., B[j-1]  |  B[j], B[j+1], ..., B[n-1]

接下来,将缩小范围的新数组 A 或者 B 按同样的划分,同样的缩小范围,就可得到最后结果。

由上可得,方法为二分查找,定义递归函数,确定递归函数的出口为:

  • 当某个数组的数都被取完了,那么直接返回另一个数组剩下的的第 k 个元素即可。
  • 当 k = 1 时,也就是只需再找一个数即可,也就是取两者当前较小的那个即可。

如何确定奇数个数取中位数,偶数个数取中间两个数的平均数,我想,利用整数除法的规则,很容易就得到 中位数为 (m + n + 1) >> 1 这个数与 (m + n + 2) >> 1 取平均值即可,当个数为奇数时,两个数就是中间一个数。

Code

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        int m = nums1.size();
        int n = nums2.size();
        int l = (m + n + 1) >> 1;
        int r = (m + n + 2) >> 1;
        int* ln = m > 0 ? &nums1[0] : nullptr;
        int* rn = n > 0 ? &nums2[0] : nullptr;
        return ( getKth(ln, m, rn, n, l) + getKth(ln, m, rn, n, r) ) / 2.0;
    }
    int getKth(int* ln, int m, int* rn, int n, int k) {
        if (m > n) {
            return getKth(rn, n, ln, m, k);
        }
        if (m == 0) {
            return rn[k - 1];
        }
        if (k == 1) {
            return min(ln[0], rn[0]);
        }
        int i = min(m, k / 2), j = min(n, k / 2);
        if (ln[i - 1] > rn [j - 1]) {
            return getKth(ln, m, rn + j, n - j, k - j);
        } else {
            return getKth(ln + i, m - i, rn, n, k - i);
        }
    }
};

相关推荐