1. 简介

查找一个序列中的最大/最小值时间复杂度均为 O(N)O(N),而查询一个序列中第 K(K=0N1)K(K = 0 \cdots N-1) 大的数时间复杂度最坏情况下即为排序的最好时间复杂度 O(NlogN)O(N \log N)(只考虑比较排序),但利用快排的 PartitionPartition 思想也可以达到期望 O(N)O(N) 的时间复杂度,最坏情况下 O(N2)O(N^2) 的时间复杂度。

2. 思想

  • 沿用快排中的 PartitionPartition 思想,选择一个枢轴,然后将小于枢轴的数都交换到枢轴左边,大于枢轴的数都交换到枢轴右边。

  • 然后判断:

  1. 如果枢轴左边小于等于枢轴的序列大小等于 KK,则说明第 KK 小的数即为枢轴。
  2. 如果枢轴左边小于等于枢轴的序列大小大于 KK,则说明第 KK 小的数一定在枢轴左边的序列。
  3. 如果枢轴左边小于等于枢轴的序列大小小于 KK,则说明第 KK 小的数一定在枢轴右边的序列。

【注】同样,在快排中采用的使划分尽量均衡的方法也可以用到此处,从而尽可能避免出现最坏情况。

3. 代码

3.1 基础版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <bits/stdc++.h>
using namespace std;

#ifndef _KTH_
#define _KTH_
#define ll int

// 比较函数
template <typename T>
bool compare(const T & t1, const T & t2) {
return t1 < t2;
}
// PARTITION
template <typename T>
T* partition(T *s, T *t, bool (*cmp)(const T &, const T &)) {
T x = *s;
while(s < t) {
while(s < t && !cmp(*t,x)) --t;
swap(*s,*t);
while(s < t && !cmp(x,*s)) ++s;
swap(*s,*t);
}
return s;
}
// 查找第 k 大的数
template <typename T>
T* findKth(T *s, T *t, ll k, bool (*cmp)(const T &, const T &) = compare) {
T *mid = partition(s,t-1,cmp);
if(mid-s == k) {
return mid;
} else if(mid-s > k) {
return findKth(s,mid,k,cmp);
} else {
return findKth(mid+1,t,k+(s-mid)-1,cmp);
}
}
#endif

3.2 随机化版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <bits/stdc++.h>
using namespace std;

#ifndef _KTH_
#define _KTH_
#define ll int

// 比较函数
template <typename T>
bool compare(const T & t1, const T & t2) {
return t1 < t2;
}
// PARTITION
template <typename T>
T* partition(T *s, T *t, bool (*cmp)(const T &, const T &)) {
T *p = s + (rand()%(t-s+1));
swap(*p,*s);
T x = *s;
while(s < t) {
while(s < t && !cmp(*t,x)) --t;
swap(*s,*t);
while(s < t && !cmp(x,*s)) ++s;
swap(*s,*t);
}
return s;
}
// 查找第 k 大的数
template <typename T>
T* FindKth(T *s, T *t, ll k, bool (*cmp)(const T &, const T &)) {
T *mid = partition(s,t-1,cmp);
if(mid-s == k) {
return mid;
} else if(mid-s > k) {
return FindKth(s,mid,k,cmp);
} else {
return FindKth(mid+1,t,k+(s-mid)-1,cmp);
}
}
// 查找第 k 大的数(随机化版本)
template <typename T>
T* findKth(T *s, T *t, ll k, bool (*cmp)(const T &, const T &) = compare) {
srand(time(NULL));
return FindKth(s,t,k,cmp);
}
#endif

3.3 三数取中版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <bits/stdc++.h>
using namespace std;

#ifndef _KTH_
#define _KTH_
#define ll int

// 比较函数
template <typename T>
bool compare(const T & t1, const T & t2) {
return t1 < t2;
}
// PARTITION
template <typename T>
T* partition(T *s, T *t, bool (*cmp)(const T &, const T &)) {
if(s < t) {
T *mid = s + (t-s)/2;
if(cmp(*t,*s)) {
swap(*t,*s);
}
if(cmp(*mid,*s)) {
swap(*mid,*s);
}
if(cmp(*t,*mid)) {
swap(*mid,*t);
}
swap(*mid,*s);
}
T x = *s;
while(s < t) {
while(s < t && !cmp(*t,x)) --t;
swap(*s,*t);
while(s < t && !cmp(x,*s)) ++s;
swap(*s,*t);
}
return s;
}
// 查找第 k 大的数
template <typename T>
T* findKth(T *s, T *t, ll k, bool (*cmp)(const T &, const T &) = compare) {
T *mid = partition(s,t-1,cmp);
if(mid-s == k) {
return mid;
} else if(mid-s > k) {
return findKth(s,mid,k,cmp);
} else {
return findKth(mid+1,t,k+(s-mid)-1,cmp);
}
}
#endif

3.4 三路划分版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <bits/stdc++.h>
using namespace std;

#ifndef _KTH_
#define _KTH_
#define ll int

// 比较函数
template <typename T>
bool compare(const T & t1, const T & t2) {
return t1 < t2;
}
// 查找第 k 大的数
template <typename T>
T* findKth(T *s, T *t, ll k, bool (*cmp)(const T &, const T &) = compare) {
T *p = s + 1;
T *lt = s, *gt = t-1;
T x = *s;
while(p <= gt) {
if(cmp(*p,x)) {
swap(*(lt++),*(p++));
} else if(cmp(x,*p)) {
swap(*(gt--),*p);
} else {
p++;
}
}
if(lt-s <= k && k < gt+1-s) {
return lt;
} else if(lt-s > k) {
return findKth(s,lt,k,cmp);
} else {
return findKth(gt+1,t,k-(gt-s+1),cmp);
}
}
#endif

3.5 随机化 + 三路划分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <bits/stdc++.h>
using namespace std;

#ifndef _KTH_
#define _KTH_
#define ll int

// 比较函数
template <typename T>
bool compare(const T & t1, const T & t2) {
return t1 < t2;
}
// 查找第 k 大的数
template <typename T>
T* FindKth(T *s, T *t, ll k, bool (*cmp)(const T &, const T &) = compare) {
// 随机化
swap(*(s+(rand()%(t-s))),*s);
// 三路划分
T *p = s + 1;
T *lt = s, *gt = t-1;
T x = *s;
while(p <= gt) {
if(cmp(*p,x)) {
swap(*(lt++),*(p++));
} else if(cmp(x,*p)) {
swap(*(gt--),*p);
} else {
p++;
}
}
if(lt-s <= k && k < gt+1-s) {
return lt;
} else if(lt-s > k) {
return FindKth(s,lt,k,cmp);
} else {
return FindKth(gt+1,t,k-(gt-s+1),cmp);
}
}
// 查找第 k 大的数(随机化版本)
template <typename T>
T* findKth(T *s, T *t, ll k, bool (*cmp)(const T &, const T &) = compare) {
srand(time(NULL));
return FindKth(s,t,k,cmp);
}
#endif

3.6 随机化 + 三数取中 + 三路划分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <bits/stdc++.h>
using namespace std;

#ifndef _KTH_
#define _KTH_
#define ll int

// 比较函数
template <typename T>
bool compare(const T & t1, const T & t2) {
return t1 < t2;
}
// 查找第 k 大的数
template <typename T>
T* FindKth(T *s, T *t, ll k, bool (*cmp)(const T &, const T &) = compare) {
// 随机化 + 三数取中
T *x = s + (rand()%(t-s));
T *y = s + (rand()%(t-s));
T *z = s + (rand()%(t-s));
if(cmp(*z,*x)) {
swap(*x,*z);
}
if(cmp(*y,*x)) {
swap(*x,*y);
}
if(cmp(*z,*y)) {
swap(*y,*z);
}
swap(*y,*s);
// 三路划分
T *p = s + 1;
T *lt = s, *gt = t-1;
T pivot = *s;
while(p <= gt) {
if(cmp(*p,pivot)) {
swap(*(lt++),*(p++));
} else if(cmp(pivot,*p)) {
swap(*(gt--),*p);
} else {
p++;
}
}
if(lt-s <= k && k < gt+1-s) {
return lt;
} else if(lt-s > k) {
return FindKth(s,lt,k,cmp);
} else {
return FindKth(gt+1,t,k-(gt-s+1),cmp);
}
}
// 查找第 k 大的数(随机化版本)
template <typename T>
T* findKth(T *s, T *t, ll k, bool (*cmp)(const T &, const T &) = compare) {
srand(time(NULL));
return FindKth(s,t,k,cmp);
}
#endif