1. 问题定义

本文目标是解决一个经典算法问题:在两个已排序的数组中,高效地找出它们合并后第 k 小的元素

我们不会直接合并数组,而是设计一个高效的算法。整个过程分为几步:先明确问题,然后分析两个简单但效率不高的解法,接着重点介绍一个基于二分查找的高效解法,最后通过测试验证其正确性。

代码示例仅使用 int 类型,但该算法适用于所有可比较的数据类型,也可用泛型实现。

2. 什么是两个有序数组并集中的第K小元素?

2.1. 第K小元素的概念

通常,找一个数组中的第K小元素会用到选择算法。但本文场景特殊:输入是两个已排序的数组,目标是在它们的并集中找第K小的数。

我们来看一个例子:

nth element problem 1

给定两个有序数组 ab(长度可以不等),我们希望得到它们合并排序后的结果:

nth element problem 2

如上图(c)所示,合并后的有序数组中,第1小的元素是 3,第4小的元素是 20

2.2. 重复元素的处理

重复元素的计数方式需要明确。一个元素可能在数组 a 中出现多次(如 3),也可能同时出现在 b 中。

nth element distinct 1

  • (c) 图表示去重计数。
  • (d) 图表示重复元素也算作不同位置的元素。

本文采用 (d) 的方式处理,即所有出现都算作独立元素,这也是更常见的需求。

nth element distinct 2

3. 两种简单但效率较低的解法

3.1. 合并后排序

最简单粗暴的方法:合并两个数组,排序,然后取第K个元素。

int getKthElementSorted(int[] list1, int[] list2, int k) {

    int length1 = list1.length, length2 = list2.length;
    int[] combinedArray = new int[length1 + length2];
    System.arraycopy(list1, 0, combinedArray, 0, list1.length);
    System.arraycopy(list2, 0, combinedArray, list1.length, list2.length);
    Arrays.sort(combinedArray);

    return combinedArray[k-1];
}

设数组长度分别为 nm,总长 c = n + m。排序时间复杂度为 O(c log c),即 **O((n+m) log(n+m))**。

⚠️ 缺点

  • 需要 O(n+m) 的额外空间。
  • 时间复杂度较高,不够高效。

3.2. 归并过程找第K个

借鉴归并排序的合并过程,我们不需要完全合并,只需归并到第K个元素即可。

基本思路:

  • 用两个指针分别指向两个数组的开头。
  • 比较指针处的元素,较小者计入结果,对应指针前移。
  • 当累计移动了 k 次,最后被选中的元素就是答案。

nth element merge

优化:我们不需要真正构造结果数组,只需模拟指针移动。

public static int getKthElementMerge(int[] list1, int[] list2, int k) {

    int i1 = 0, i2 = 0;

    while(i1 < list1.length && i2 < list2.length && (i1 + i2) < k) {
        if(list1[i1] < list2[i2]) {
            i1++;
        } else {
            i2++;
        }
    }

    if((i1 + i2) < k) {
        return i1 < list1.length ? list1[k - i2 - 1] : list2[k - i1 - 1]; 
    } else if(i1 > 0 && i2 > 0) {
        return Math.max(list1[i1-1], list2[i2-1]);
    } else {
        return i1 == 0 ? list2[i2-1] : list1[i1-1];
    }
}
  • **时间复杂度:O(k)**,非常直观。
  • 优点:空间复杂度 O(1),且易于修改为去重模式。

4. 基于双数组的二分查找(高效解法)

能否突破 O(k) 的限制?**答案是可以,通过二分查找将复杂度降至 O(log(min(n, m)))**。

核心思想:在较短的数组上进行二分查找,决定从每个数组中各取多少个元素

我们实现的骨架方法如下:

int findKthElement(int k, int[] list1, int[] list2)
    throws NoSuchElementException, IllegalArgumentException {

    // 参数校验
    // 特殊情况处理
    // 二分查找主逻辑
}

下面我们按逆序讲解:先看二分逻辑,再看边界处理和参数校验。

4.1. 二分查找逻辑

4.1.1. 确定每个数组取多少元素

我们假设从 list1nElementsList1 个元素,那么从 list2 就取 k - nElementsList1 个元素。

int nElementsList2 = k - nElementsList1;

例如,k=8,我们先尝试从 list1 取4个,list2 取4个。

nth element binary a

关键判断:

  • 如果 list1 的第4个元素 > list2 的第4个元素,说明 list1 取多了,需要减少 nElementsList1
  • 否则,说明 list1 取少了,需要增加 nElementsList1

nth element binary b

代码实现:

int right = k;
int left = 0;
do {
    nElementsList1 = ((left + right) / 2) + 1;
    nElementsList2 = k - nElementsList1;

    if(nElementsList2 > 0) {
        if (list1[nElementsList1 - 1] > list2[nElementsList2 - 1]) {
            right = nElementsList1 - 2;
        } else {
            left = nElementsList1;
        }
    }
} while(!kthSmallesElementFound(list1, list2, nElementsList1, nElementsList2));

4.1.2. 停止条件

循环在以下任一条件满足时停止:

  1. 元素相等:从两个数组取出的最后一个元素相等,直接返回该元素。 nth element binary c 2

  2. 交叉小于:这是核心逻辑。

    • list1 取出的最大元素 < list2 未取出的最小元素。
    • list2 取出的最大元素 < list1 未取出的最小元素。 nth element binary d

    ✅ 此时,取出的 k 个元素一定就是最小的 k 个,它们的最大值即为答案。

判断代码:

private static boolean foundCorrectNumberOfElementsInBothLists(int[] list1, int[] list2, int nElementsList1, int nElementsList2) {

    if(nElementsList2 < 1) {
        return true;
    }

    if(list1[nElementsList1-1] == list2[nElementsList2-1]) {
        return true;
    }

    if(nElementsList1 == list1.length) {
        return list1[nElementsList1-1] <= list2[nElementsList2];
    }

    if(nElementsList2 == list2.length) {
        return list2[nElementsList2-1] <= list1[nElementsList1];
    }

    return list1[nElementsList1-1] <= list2[nElementsList2] && list2[nElementsList2-1] <= list1[nElementsList1];
}

4.1.3. 返回结果

根据 nElementsList2 的值决定返回值:

  • nElementsList2 == 0:所有 k 个元素都来自 list1,返回 list1[nElementsList1-1]
  • 否则:返回两个取出部分最大值中的较大者 max(list1[nElementsList1-1], list2[nElementsList2-1])

nth element binary e

代码:

return nElementsList2 == 0 ? list1[nElementsList1-1] : max(list1[nElementsList1-1], list2[nElementsList2-1]);

4.2. 左右边界初始化

初始的 leftright 不能简单设为 0k,需要根据实际情况调整。

  • **左边界 left**:如果 k 大于 list2 的长度,则 list1 至少要取 k - list2.length 个元素。
  • **右边界 right**:不能超过 list1 的长度,也不能超过 k-1
// 如果 k 大于 list2 的长度,调整左边界
int left = k < list2.length ? 0 : k - list2.length - 1;

// 初始右边界不能超过 list1 的长度
int right = Math.min(k-1, list1.length - 1);

nth element left border

4.3. 特殊情况处理

在二分前处理几个边界情况,可以简化主逻辑:

// 找最小值,即第1小
if(k == 1) {
    return Math.min(list1[0], list2[0]);
}

// 找最大值,即第 (m+n) 小
if(list1.length + list2.length == k) {
    return Math.max(list1[list1.length-1], list2[list2.length-1]);
}

// 交换数组,确保 list1 不是更小的那个(简化逻辑)
if(k <= list2.length && list2[k-1] < list1[0]) {
    int[] temp = list1;
    list1 = list2;
    list2 = temp;
}

4.4. 参数校验

防止 null 指针和数组越界:

void checkInput(int k, int[] list1, int[] list2) throws NoSuchElementException, IllegalArgumentException {

    if(list1 == null || list2 == null || k < 1) { 
        throw new IllegalArgumentException(); 
    }
 
    if(list1.length == 0 || list2.length == 0) { 
        throw new IllegalArgumentException(); 
    } 

    if(k > list1.length + list2.length) {
        throw new NoSuchElementException();
    }
}

4.5. 完整代码

public static int findKthElement(int k, int[] list1, int[] list2) throws NoSuchElementException, IllegalArgumentException {

    checkInput(k, list1, list2);

    if(k == 1) {
        return Math.min(list1[0], list2[0]);
    }

    if(list1.length + list2.length == k) {
        return Math.max(list1[list1.length-1], list2[list2.length-1]);
    }

    if(k <= list2.length && list2[k-1] < list1[0]) {
        int[] temp = list1;
        list1 = list2;
        list2 = temp;
    }

    int left = k < list2.length ? 0 : k - list2.length - 1; 
    int right = Math.min(k-1, list1.length - 1); 

    int nElementsList1, nElementsList2; 

    do { 
        nElementsList1 = ((left + right) / 2) + 1; 
        nElementsList2 = k - nElementsList1; 

        if(nElementsList2 > 0) {
            if (list1[nElementsList1 - 1] > list2[nElementsList2 - 1]) {
                right = nElementsList1 - 2;
            } else {
                left = nElementsList1;
            }
        }
    } while(!foundCorrectNumberOfElementsInBothLists(list1, list2, nElementsList1, nElementsList2));

    return nElementsList2 == 0 ? list1[nElementsList1-1] : Math.max(list1[nElementsList1-1], list2[nElementsList2-1]);
}

private static boolean foundCorrectNumberOfElementsInBothLists(int[] list1, int[] list2, int nElementsList1, int nElementsList2) {
    if(nElementsList2 < 1) {
        return true;
    }
    if(list1[nElementsList1-1] == list2[nElementsList2-1]) {
        return true;
    }
    if(nElementsList1 == list1.length) {
        return list1[nElementsList1-1] <= list2[nElementsList2];
    }
    if(nElementsList2 == list2.length) {
        return list2[nElementsList2-1] <= list1[nElementsList1];
    }
    return list1[nElementsList1-1] <= list2[nElementsList2] && list2[nElementsList2-1] <= list1[nElementsList1];
}

5. 算法测试

测试用例应覆盖各种边界情况。一个有效的测试策略是:将高效算法的结果与简单算法(合并排序法)的结果进行对比

生成随机有序数组:

int[] sortedRandomIntArrayOfLength(int length) {
    int[] intArray = new Random().ints(length).toArray();
    Arrays.sort(intArray);
    return intArray;
}

单次测试:

private void randomTest() {
    Random random = new Random();
    int length1 = Math.abs(random.nextInt()) % 1000 + 1;
    int length2 = Math.abs(random.nextInt()) % 1000 + 1;

    int[] list1 = sortedRandomIntArrayOfLength(length1);
    int[] list2 = sortedRandomIntArrayOfLength(length2);

    int k = (Math.abs(random.nextInt()) + 1) % (length1 + length2);

    int result = findKthElement(k, list1, list2);
    int result2 = getKthElementSorted(list1, list2, k);
    int result3 = getKthElementMerge(list1, list2, k);

    assertEquals(result2, result);
    assertEquals(result2, result3);
}

运行大量测试:

@Test
void randomTests() {
    java.util.stream.IntStream.range(1, 100000).forEach(i -> randomTest());
}

6. 总结

本文探讨了在两个有序数组中找第K小元素的三种方法:

  • ✅ **O((n+m)log(n+m))**:合并排序,简单但慢。
  • ✅ **O(k)**:归并模拟,空间友好,适合 k 较小的情况。
  • ✅ **O(log(min(n, m)))**:双数组二分查找,理论最优,但代码复杂。

踩坑提醒:虽然二分法时间复杂度最低,但其代码逻辑复杂,边界条件多。在实际项目中,如果 k 不是特别大,推荐使用 O(k) 的归并法,它足够快且易于理解和维护。只有在性能是硬性要求且数据量巨大时,才考虑使用二分法。

所有代码已上传至 GitHub 仓库


原始标题:Find the Kth Smallest Element in Two Sorted Arrays in Java