TOP-K问题

前言

TOP-K问题是面试中的常见题型,具体表现为:

  1. 海量数据
  2. 求最大(最小)的K个值.

比如:给定1亿个乱序的整数,求其中最大的1000个

今天一步一步来分析这种题目的解决办法,以及用到的算法思想.

依赖

此文中会依赖到快速排序,堆排序等排序算法,以及数据结构堆.

如果你对上述两者的原理有所了解,可以继续往下看.如果不了解,可以点击链接先看一下基础~.

好,进入正题.(以给定1亿个乱序的整数,求其中最大的1000个为例.)

暴躁方法

这种题目一看就是排序嘛(暴躁吧).大不了就是我把一亿个数字都排序了(不要告诉我内存不够,我不听),然后取前1000个数字返回即可.

那么选择哪种排序方式呢?插入,选择等时间复杂度为O(n2)的就不考虑了,考虑一下快排,归并,堆排序,他们的平均时间复杂度都为O(nlogn).

注:计数排序等就不考虑了,,一亿个数字,范围得多大啊…

放在此题目里面,就是O(1亿 * log 1亿).

优化一下

要想优化这个,我们首先要明白,上面排序中的O(nlogn)都是干啥的

这个n是指每个数字都要遍历一次,这个没问题,你至少每个数字得看到才能知道大小吧.

这个logn是指,每个元素被拿到之后,要去和logn个元素比较才能确定他的大小,以便找到他的具体位置.

发现问题了嘛?在这道题目里面,我们其实是不关心每个元素的具体位置,我们只想要前面的1000个的具体位置.说明这里可以优化一下.

可不可以我先拿到前1000元素,他们就是目前最大的,然后之后的每个元素和这1000个元素比较一下.如果能在大于任意一个,就替换掉他.如果1000个都比他打,则将他丢弃掉.

这样相当于我们遍历1亿个数字,对每一个数字比较1000次. 时间复杂度为:O(1 亿 * 1000).

再优化一下

为了确定每一个元素是否是所有的前1000,真的需要去和1000个依次比较吗?

不需要,这里可以借助这个数据结构.

首先,我们拿前1000个元素,构造一个小顶堆.花费时间为:O(n).

然后,我们拿剩下的所有元素,依次和堆顶元素进行比较,因为堆是小顶堆,小于堆顶就说明小于堆中的所有元素,直接丢弃,大于堆顶则替换堆顶元素,之后调整一下堆,使其继续符合堆的性质.

在这个过程中,我们花费的时间为:O(n) + O(n) * O(log1000).

时间复杂度为:O(nlongk),其中,n为总数据量,即1亿.k为所求数据量,即1000.

换种思路

求最大的1000个元素,真的要将前1000个元素排序吗?不一定吧,我们只需要知道他们都很大,是前1000个,具体他们之间的顺序不需要确定.

那么我们是否可以将数据分成两部分,左边是前1000大,右面是剩余元素.

这种思路是不是有点像快排?将大的放一边,小的放一边.

那么具体的思路就是:

首先对所有数据进行一次快排.分为了左右两部分.

如果左边的元素数量大于1000,那么说明前1000元素都在左边,对右边的直接丢弃.

对左边的元素再次进行快排. 分为左右两部分.如果左边的元素数量为900,小于100.则说明我们拿到了前900个数据.

然后对右边进行快排,找到他的前100个数据.

两部分相加,即是所求的前1000个数据.

在这个过程中,时间复杂度为:O(n)logn.

更正:
时间复杂度为:O(n).

总结

解决TOP-K问题常用的两种方式:

维护K个数据的堆,之后依次使用数据来与堆顶元素比较,要么丢弃,要么替换掉堆顶元素,之后调整堆. 时间复杂度为:O(nlogK).

实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/**
* 使用最小堆求TOP-K问题
*/
private int findTopK2(int[] arr, int k) {
//最小堆
MinHeap heap = new MinHeap();
//建堆
heap.buildMaxHeap(arr, arr.length);
for (int i = k; i < arr.length; i++) {
//如果当前值大于堆顶元素,则替换掉堆顶元素并调整堆
if (arr[i] > arr[0]) {
arr[0] = arr[i];
heap.delete(arr, k);
}
}
//这里其实不需要返回值,只是因为在findTopK1中返回了index,这里也返回index,方便遍历
//其实在此方法执行后,arr数组的钱K位就是所求元素,直接遍历即可
return k - 1;
}

最小堆为额外实现,这里附上代码,看不懂的童鞋可以去依赖中的数据结构之堆之中查看.

最小堆实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package heap;

import java.util.Arrays;

/**
* created by huyanshi on 2019/1/16
*/
public class MinHeap {


public static void main(String[] args) {

MinHeap heap = new MinHeap();

int[] data = {15, 13, 1, 5, 20, 12, 8, 9, 11};
// 测试建堆
heap.buildMaxHeap(data, data.length);
System.out.println(Arrays.toString(data));

// 测试插入
int[] newArr = heap.insert(data, 14, data.length);
System.out.println(Arrays.toString(newArr));

// 测试删除
heap.delete(newArr, data.length + 1);
System.out.println(Arrays.toString(newArr));
}

/**
* 建立最小堆
*/
public void buildMaxHeap(int[] array, int len) {
//从最后一个非叶子节点开始,逐个节点向前进行下沉
for (int i = (len / 2 - 1); i >= 0; i--) {
down(array, len, i);
}
}

/**
* 插入元素
*/
public int[] insert(int[] array, int a, int len) {
//copy到新数组,长度加1,并将添加的值放入末尾
int[] newArr = Arrays.copyOf(array, len + 1);
newArr[len] = a;
//对新加入的值进行上浮
return up(newArr, len + 1, len);
}

/**
* 删除元素
*/
public void delete(int[] array, int len) {
//交换第一个和最后一个节点
exchange(array, 0, len - 1);
//从根节点进行下沉
down(array, len - 1, 0);
}

/**
* 元素下沉
*/
private void down(int[] array, int len, int i) {
int minIndex = i;
//如果有左子树,且左子树小于父节点,则将最小指针指向左子树
if (i * 2 + 1 < len && array[i * 2 + 1] < array[minIndex]) {
minIndex = i * 2 + 1;
}
//如果有右子树,且右子树小于父节点,则将最小指针指向右子树
if (i * 2 + 2 < len && array[i * 2 + 2] < array[minIndex]) {
minIndex = i * 2 + 2;
}
//如果父节点不是最小值,则将父节点与最小值交换,并且递归调整与父节点交换的位置。
if (minIndex != i) {
exchange(array, minIndex, i);
down(array, len, minIndex);
}
}


/**
* 元素上浮
*/
private int[] up(int[] array, int len, int i) {
//如果上浮到根节点了,则结束
if (i == 0) {
return array;
}
//如果当前节点小于其父节点,则交换值,并且对其父节点进行上浮.
if (array[i] < array[(i - 1) / 2]) {
exchange(array, i, (i - 1) / 2);
array = up(array, len, (i - 1) / 2);
}
return array;
}

/**
* 交换数组在两个下标的值
*/
private void exchange(int[] input, int i, int j) {
int tmp = input[i];
input[i] = input[j];
input[j] = tmp;
}

}

快速选择算法(类快速排序)

对数据进行分隔(使用快排思想):

  1. 若切分后的左子数组的长度 > k,则前k大元素必出现在左子数组中;
  2. 若切分后的左子数组的长度 = k,则前k大元素为左子数组.
  3. 若切分后的左子树组的长度 s < k, 则左子数组为前s大元素,在右子数组中寻找前k-s大元素.

时间复杂度为:O(nlogK).

更正:

经”灰灰是菇凉”提醒: 此处的时间复杂度为:O(n).

快速选择算法中:花费时间为:n + n/2 + n/4 + ... + n/n(极限情况,k=1) = n(1+ 1/2 + 1/4 + ...+ 1/n) < O(2n) = O(n).

实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/**
* 找到前k大的元素,返回的index为第k大元素的下标,遍历0-index即可
*/
private int findTopK1(int[] arr, int start, int end, int k) {
int index = -1;
if (start < end) {
//切分,找到标志位的下标
int mid = partition(arr, start, end);
//计算左边数组长度
int leftLen = mid - start + 1;
//如果左边长度=k,则直接返回
if (leftLen == k) {
index = mid;
//左边长度大于k,在左边数组寻找前k大元素
} else if (leftLen > k) {
index = findTopK1(arr, start, mid, k);
//左边长度小于k,在右边数组寻找前k-len个元素.
} else {
index = findTopK1(arr, mid + 1, end, k - leftLen);
}
//返回第k大元素的下标.
return index;
}
return index;
}


public int partition(int[] a, int start, int end) {
//以最左边的值为基准
int base = a[start];
//start一旦等于end,就说明左右两个指针合并到了同一位置,可以结束此轮循环。
while (start < end) {
while (start < end && a[end] >= base) {
//从右边开始遍历,如果比基准值大,就继续向左走
end--;
}
//上面的while循环结束时,就说明当前的a[end]的值比基准值小,应与基准值进行交换
if (start < end) {
//交换
exchange(a, start, end);
//交换后,此时的那个被调换的值也同时调到了正确的位置(基准值左边),因此左边也要同时向后移动一位
start++;
}
while (start < end && a[start] <= base) {
//从左边开始遍历,如果比基准值小,就继续向右走
start++;
}
//上面的while循环结束时,就说明当前的a[start]的值比基准值大,应与基准值进行交换
if (start < end) {
//交换
exchange(a, start, end);
//交换后,此时的那个被调换的值也同时调到了正确的位置(基准值右边),因此右边也要同时向前移动一位
end--;
}
}
//这里返回start或者end皆可,此时的start和end都为基准值所在的位置
return end;
}

public void exchange(int[] input, int i, int j) {
int tmp = input[i];
input[i] = input[j];
input[j] = tmp;
}

测试代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public static void main(String[] args) {
TopKProblem tk = new TopKProblem();
int k = 3;
int array[] = {20, 100, 4, 2, 87, 9, 8, 5, 46, 26};
int index = tk.findTopK1(array, 0, array.length - 1, k);
System.out.println("使用快速选择:");
for (int i = 0; i <= index; i++) {
System.out.print(array[i] + ",");
}
System.out.println();

int k1 = 3;
int array1[] = {20, 100, 4, 2, 87, 9, 8, 5, 46, 26};
int index1 = tk.findTopK2(array1, k1);
System.out.println("使用堆:");
for (int i = 0; i <= index1; i++) {
System.out.print(array[i] + ",");
}

}

输出结果如下:

1
2
3
4
使用快速选择:
2,4,5,
使用堆:
2,4,5,

注:经本人多次调节k测试,代码及结果正确,如发现错误可以随时评论及邮件,将尽快修正.

完。




ChangeLog

2019-01-16 完成

以上皆为个人所思所得,如有错误欢迎评论区指正。

欢迎转载,烦请署名并保留原文链接。

联系邮箱:huyanshi2580@gmail.com

更多学习笔记见个人博客——>呼延十