分治算法

分治算法详解及例题

在这之前写过的 全排列 最大子序列和 棋盘覆盖 归并排序 都属于经典的分治策略。这篇博客主要是重新复习一下分治算法,写上另外几道经典的分治算法题

分治算法主要有两部分组成的:

  • 分:递归解决较小的问题
  • 治:从子问题的解构建原问题的解

一般在正文中至少含有两个递归调用的例程叫分治算法,而正文中只有一个递归调用的例程不是分治算法。一般坚持子问题是不相交的。
分析时间复杂度:

一个规模为n的实例可以划分为b个规模为n/b的实例,其中a个实例是需要求解的,为了简化分析,我们假设n是b的幂(每次都可以整的划分),对算法的运行时间,下面的递推关系式是显然的:

其中,a,b的含义已经说过了,f(n)表示将求解得到的a个子问题的解合并起来所需要的时间复杂度。

如何根据a,b以及f的阶来确定这个算法的时间复杂度呢?有下列主定理:(证明参见算法导论)

看起来复杂 其实大部分我们用不到,现在我们来分析一下:
当a < b时,这代表问题分成b个规模,但是却只需要处理a个问题就好,这属于减治策略,经典的算法是二分查找,此时前半部分的时间复杂度大多为log(N) 此时算法的时间复杂度决定在f(n)
当a = b时,这代表问题分成b个规模,而这b个规模都需要处理 当f(n)为O(1)时,算法大部分为O(N),而当f(n)为O(N)时 则是NlogN的算法 如归并排序
以上只是经常见到的分治问题的时间复杂度,具体的看公式。

接下来介绍几个经典的分治算法:

分治最经典的算法 当属归并排序了,重新敲了一个,理解下思想

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/**
* 文件名:mulply.java
* 时间:2014年11月19日下午10:03:59
* 作者:修维康
*/
package divideconquer;
import java.util.Arrays;
public class MergeSort {
public static void MergeSort(Comparable[] a, int low, int high) {
if (low < high) {
int mid = (low + high) / 2;
MergeSort(a, low, mid);
MergeSort(a, mid + 1, high);
Merge(a, low, mid, high);
}
}
private static void Merge(Comparable[] a, int low, int mid, int high) {
Comparable[] temp = new Comparable[high - low + 1];
int index = 0;
int begin1 = low;
int begin2 = mid + 1;
while (begin1 <= mid && begin2 <= high) {
if (a[begin1].compareTo(a[begin2]) < 0)
temp[index++] = a[begin1++];
else
temp[index++] = a[begin2++];
}
while (begin1 <= mid)
temp[index++] = a[begin1++];
while (begin2 <= high)
temp[index++] = a[begin2++];
for (int i = low, j = 0; i <= high; i++, j++)
a[i] = temp[j];
}
public static void main(String[] args) {
Integer[] a = new Integer[] { 5, 4, 3, 2, 1, 9, 7, 7 };
MergeSort(a, 0, a.length - 1);
System.out.println(Arrays.toString(a));
}
}

接下来是整数的乘法运算
思想就是将 整数中从中间一分为二 例如X = 61438521 Y = 94736407 分开XL = 6143 XR = 8521 YL=9473 YR = 6407
X = XL10^4 + XR Y = YL10^4 +YR

XY = XLYL10^8 + (XLYR + XRYL)10^4 +XRYR
在这里T(N) = 4T(N/2) +O(N) 时间复杂度还是为N^2 我们并没有改进算法
但是经观察可以得到
XLYR +XRYL = (XL - XR)(YR-YL) + XLYL + XRYR
而XLYL和XRYR在上面求过了 因此此时算法是T(N) = 3T(N/2) +O(N),运用公式 时间复杂度是O(N^log 3)
下面的算法没有用到这个,因为输入的位数可能不相等,想要实现上面的算法 只需要略微一改就好,懒得写了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/**
* 文件名:Mutiply.java
* 时间:2014年11月21日上午8:39:48
* 作者:修维康
*/
package divideconquer;
/**
* 类名:Mutiply 说明:分治求大整数的乘法
* 如果位数相同 则分治时可减少一次乘法。
*
*/
public class Mutiply {
public static long mutiply(long x, long y) {
if (getBit(x) > 2 && getBit(y) > 2) {
long xBit = getBit(x) - getBit(x) / 2;
long yBit = getBit(y) - getBit(y) / 2;
long xL = (long)(x / Math.pow(10, xBit));
long xR = (long)(x % Math.pow(10, xBit));
long yL = (long)(y / Math.pow(10, yBit));
long yR = (long)(y % Math.pow(10, yBit));
//System.out.println(xBit +" "+ yBit+" "+xL+" "+xR +" " + yL +" "+yR);
return (long) ((mutiply(xL,yL) * (Math.pow(10, xBit + yBit))) + mutiply(xL,yR)
* (Math.pow(10, xBit)) + mutiply(xR,yL) * (Math.pow(10, yBit)) + mutiply(xR,yR));
}
return x * y;
}
public static int getBit(long a) {
int cout = 0;
while (a > 0) {
cout++;
a /= 10;
}
return cout;
}
/**
* 方法名:main 说明:测试
*/
public static void main(String[] args) {
// TODO Auto-generated method stub
System.out.println(mutiply(61438521,94736407));
}
}

接下来是Strassen矩阵的乘法(这里是拷贝别人的)
理论上的分析就不多写了,涉及很多数学上的东西,计算麻烦,直接上书上的图:

分治策略将2个二阶矩阵采用下列方式来计算:

数一下,这样来计算2个二阶矩阵的乘法用了7次乘法,18次加法。而蛮力法用了8次乘法和4次加法。当然,这还不能体现出它的优越性,它的优越性表现在当矩阵的阶趋于无穷大时的渐进效率。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
package divideconquer;
public class Strassen {
// 该程序可以对两个同阶的2^n阶的矩阵采用Strassen算法做矩阵乘法
/**
*
*/
public static void main(String[] args) {
// TODO Auto-generated method stub
int[][] a = { { 1, 0, 2, 1 }, { 4, 1, 1, 0 }, { 0, 1, 3, 0 },
{ 5, 0, 2, 1 } };
int[][] b = { { 0, 1, 0, 1 }, { 2, 1, 0, 4 }, { 2, 0, 1, 1 },
{ 1, 3, 5, 0 } };
int[][] result = StrassenMul(a, b);
System.out.println("输出矩阵:");
for (int i = 0; i < result.length; i++) {
for (int j = 0; j < result.length; j++)
System.out.print(result[i][j] + " ");
System.out.println();
}
}
public static int[][] StrassenMul(int[][] a, int[][] b) { // a,b均是2的乘方的方阵
int[][] result = new int[a.length][a.length];
if (a.length == 2) // 如果a,b均是2阶的,递归结束条件
result = StrassMul(a, b);
else // 否则(即a,b都是4,8,16...阶的)
{
// a的四个子矩阵
int[][] A00 = copyArrays(a, 1);
int[][] A01 = copyArrays(a, 2);
int[][] A10 = copyArrays(a, 3);
int[][] A11 = copyArrays(a, 4);
// b的四个子矩阵
int[][] B00 = copyArrays(b, 1);
int[][] B01 = copyArrays(b, 2);
int[][] B10 = copyArrays(b, 3);
int[][] B11 = copyArrays(b, 4);
// 递归调用
int[][] m1 = StrassenMul(addArrays(A00, A11), addArrays(B00, B11));
int[][] m2 = StrassenMul(addArrays(A10, A11), B00);
int[][] m3 = StrassenMul(A00, subArrays(B01, B11));
int[][] m4 = StrassenMul(A11, subArrays(B10, B00));
int[][] m5 = StrassenMul(addArrays(A00, A01), B11);
int[][] m6 = StrassenMul(subArrays(A10, A00), addArrays(B00, B01));
int[][] m7 = StrassenMul(subArrays(A01, A11), addArrays(B10, B11));
// 得到result的四个子矩阵
int[][] C00 = addArrays(m7, subArrays(addArrays(m1, m4), m5));// m1+m4-m5+m7
int[][] C01 = addArrays(m3, m5); // m3+m5
int[][] C10 = addArrays(m2, m4); // m2+m4
int[][] C11 = addArrays(m6, subArrays(addArrays(m1, m3), m2));// m1+m3-m2+m6
// 也可以按照下列方法来求C
// C00 = addArrays(StrassenMul(A00,B00),StrassenMul(A01,B10));
// C01 = addArrays(StrassenMul(A00,B01),StrassenMul(A01,B11));
// C10 = addArrays(StrassenMul(A10,B00),StrassenMul(A11,B10));
// C11 = addArrays(StrassenMul(A10,B01),StrassenMul(A11,B11));
// 将四个子矩阵合并成result
Merge(result, C00, 1);
Merge(result, C01, 2);
Merge(result, C10, 3);
Merge(result, C11, 4);
}
return result;
}
private static void Merge(int[][] result, int[][] C, int flag) {
// 将C复制到result的相应位置
switch (flag) {
case 1:
for (int i = 0; i < result.length / 2; i++)
for (int j = 0; j < result.length / 2; j++)
result[i][j] = C[i][j];
break;
case 2:
for (int i = 0; i < result.length / 2; i++)
for (int j = result.length / 2; j < result.length; j++)
result[i][j] = C[i][j - result.length / 2];
break;
case 3:
for (int i = result.length / 2; i < result.length; i++)
for (int j = 0; j < result.length / 2; j++)
result[i][j] = C[i - result.length / 2][j];
break;
case 4:
for (int i = result.length / 2; i < result.length; i++)
for (int j = result.length / 2; j < result.length; j++)
result[i][j] = C[i - result.length / 2][j - result.length
/ 2];
break;
}
}
private static int[][] copyArrays(int[][] a, int flag) {
// 得到分割矩阵的子矩阵
int[][] result = new int[a.length / 2][a.length / 2];
switch (flag) {
case 1:
for (int i = 0; i < a.length / 2; i++)
for (int j = 0; j < a.length / 2; j++)
result[i][j] = a[i][j];
break;
case 2:
for (int i = 0; i < a.length / 2; i++)
for (int j = a.length / 2; j < a.length; j++)
result[i][j - a.length / 2] = a[i][j];
break;
case 3:
for (int i = a.length / 2; i < a.length; i++)
for (int j = 0; j < a.length / 2; j++)
result[i - a.length / 2][j] = a[i][j];
break;
case 4:
for (int i = a.length / 2; i < a.length; i++)
for (int j = a.length / 2; j < a.length; j++)
result[i - a.length / 2][j - a.length / 2] = a[i][j];
break;
}
return result;
}
private static int[][] StrassMul(int[][] a, int[][] b) {
// 计算2个二阶的矩阵乘法
// Strassen方法使用了7次乘法,18次加法(传统方法是8次乘法4次加法)
int[][] result = new int[2][2];
int m1 = (a[0][0] + a[1][1]) * (b[0][0] + b[1][1]);
int m2 = (a[1][0] + a[1][1]) * b[0][0];
int m3 = a[0][0] * (b[0][1] - b[1][1]);
int m4 = a[1][1] * (b[1][0] - b[0][0]);
int m5 = (a[0][0] + a[0][1]) * b[1][1];
int m6 = (a[1][0] - a[0][0]) * (b[0][0] + b[0][1]);
int m7 = (a[0][1] - a[1][1]) * (b[1][0] + b[1][1]);
result[0][0] = m1 + m4 - m5 + m7;
result[0][1] = m3 + m5;
result[1][0] = m2 + m4;
result[1][1] = m1 + m3 - m2 + m6;
return result;
}
private static int[][] addArrays(int[][] a, int[][] b) {
// 求2个同阶矩阵的和
int[][] result = new int[a.length][a.length];
// System.out.println(result.length);
for (int i = 0; i < result.length; i++)
for (int j = 0; j < result.length; j++)
// for(int j = 0;i < result.length;j++)
result[i][j] = a[i][j] + b[i][j];
return result;
}
private static int[][] subArrays(int[][] a, int[][] b) {
// 矩阵减法
int[][] result = new int[a.length][a.length];
for (int i = 0; i < result.length; i++)
for (int j = 0; j < result.length; j++)
// for(int j = 0;i < result.length;j++)
result[i][j] = a[i][j] - b[i][j];
return result;
}
}

最接近点对点问题

和最大子序列和的递归算法很相似。将点按照x坐标非降序排序,找到中间的值 将其分成两部分,s1 s2 点要么在s1中 要么在s2中 在要么一个在s1一个在s2中,递归算,注意基准情况,当只有2个点或者3个点的时候就可以不用再分了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package divideconquer;
import java.util.Arrays;
class Point implements Comparable<Point> {
int x;
int y;
public Point(int x, int y) {
this.x = x;
this.y = y;
}
@Override
public int compareTo(Point o) {
// TODO Auto-generated method stub
if (this.x > o.x)
return 1;
else if (this.x < o.x)
return -1;
return 0;
}
}
public class ClosedPoint {
private static double getDistance(Point a, Point b) {
return Math.pow(Math.pow(b.x - a.x, 2) + Math.pow(b.y - a.y, 2), 0.5);
}
public static Point[] getPoints(Point[] points) {
Arrays.sort(points);
if (points.length == 2 || points.length == 3)
return getPoint(points);
else {
// 分成两部分 s1 s2;
Point[] result = new Point[2];
int mid = points.length / 2;
// copyOfRange为左闭右开
Point[] s1 = Arrays.copyOfRange(points, 0, mid);
Point[] s2 = Arrays.copyOfRange(points, mid, points.length);
// 得到s1 中最短距离d1 s2中最短距离d2
Point[] result1 = getPoints(s1);
Point[] result2 = getPoints(s2);
// System.out.println(result1.length);
// System.out.println(result2.length);
double d1 = getDistance(result1[0], result1[1]);
double d2 = getDistance(result2[0], result2[1]);
// 得到 横跨s2 s1的最短距离d3
double minDistance = d1 < d2 ? d1 : d2;
result = minDistance == d1 ? result1 : result2;
for (int i = 0; i < s2.length; i++) {
//优化,如果s2到中间的距离大于minDistance则也不需要继续往后走了
if (s2[i].x - s1[s1.length - 1].x > minDistance)
break;
for (int j = s1.length - 1; j >= 0; j--)
if (s2[i].x - s1[j].x > minDistance)
break;
else {
double d3 = getDistance(s2[i], s1[j]);
if (d3 < minDistance) {
minDistance = d3;
result[0] = s2[i];
result[1] = s1[j];
}
}
}
return result;
}
}
private static Point[] getPoint(Point[] points) {
Point[] result = new Point[2];
if (points.length == 2) {
result[0] = points[0];
result[1] = points[1];
} else {
double d1 = getDistance(points[0], points[1]);
double d2 = getDistance(points[0], points[2]);
double d3 = getDistance(points[1], points[2]);
if (d1 < d2 && d1 < d3) {
result[0] = points[0];
result[1] = points[1];
}
if (d2 < d1 && d2 < d3) {
result[0] = points[0];
result[1] = points[2];
}
if (d3 < d1 && d3 < d2) {
result[0] = points[1];
result[1] = points[2];
}
}
return result;
}
public static void main(String[] args) {
Point[] points = new Point[] { new Point(1, 1), new Point(1, 6),
new Point(1, 4), new Point(1, 7), new Point(1, 12),
new Point(2, 18), new Point(3, 20), new Point(5, 30) };
Point[] result = getPoints(points);
System.out.print("(" + result[0].x + "," + result[0].y + ")," + "("
+ result[1].x + "," + result[1].y + ")");
}
}

之前写的棋盘铺盖,快排,和找第K大的值,最大子序列和问题 都属于分治策略,至此分治算法完。