0
0
0
1. 云栖社区>
2. 博客>
3. 正文

## Prim算法

1. 输入：一个加权连通图，其中顶点集合为V，边集合为E；

2. 初始化：Vn = {x}，其中x为集合V中的任一节点（起始点），Enew = {}；

3. 重复下列操作，直到Vn = V：（在集合E中选取权值最小的边（u, v），其中u为集合Vn中的元素，而v则是V中没有加入Vn的顶点（如果存在有多条满足前述条件即具有相同权值的边，则可任意选取其中之一）；
将v加入集合Vn中，将（u, v）加入集合En中；）

4. 输出：使用集合Vn和En来描述所得到的最小生成树。

## Prim算法实现

```class MST(object):
def __init__(self, graph):
self.graph = graph
self.N = len(self.graph)
pass
def prim(self, start):
index = start
cost, path = [0] * self.N, [0] * self.N
# 初始化起点
known = [x for x in map(lambda x: True if x == start else False, [x for x in range(self.N)])]
path[start] = -1
for i in range(self.N):
cost[i] = self.graph[start][i]
# 遍历其余各个结点
for i in range(1, self.N):
mi = 1e9
# 找出相对最小权重的结点
for j in range(self.N):
if not known[j] and mi > cost[j]:
mi, index = cost[j], j
# 计算路径值
for j in range(self.N):
if self.graph[j][index] == mi:
path[index] = j
known[index] = True
# 更新index连通其它结点的权重
for j in range(self.N):
if not known[j] and cost[j] > self.graph[index][j]:
cost[j] = self.graph[index][j]
print(path)
# 图用临接矩阵表示
MST([
[1e9, 6, 8, 1e9, 7, 1e9, 1e9, 1e9],
[6, 1e9, 7, 1e9, 1e9, 3, 4, 1e9],
[8, 7, 1e9, 1e9, 1e9, 1e9, 6, 1e9],
[1e9, 1e9, 1e9, 1e9, 1e9, 1e9, 1e9, 2],
[7, 1e9, 1e9, 1e9, 1e9, 1e9, 1e9, 1e9],
[1e9, 3, 1e9, 1e9, 1e9, 1e9, 1e9, 9],
[1e9, 4, 6, 1e9, 1e9, 1e9, 1e9, 7],
[1e9, 1e9, 1e9, 2, 1e9, 9, 7, 1e9],
]).prim(0)```
path结果为：[-1, 0, 6, 7, 0, 1, 1, 6]

## Kruskal算法实现

```class Edge(object):
def __init__(self, start, end, weight):
self.start = start
self.end = end
self.weight = weight
def getEdges(self):
edges = []
for i in range(self.vertex):
for j in range(i+1, self.vertex):
if self.graph[i][j] != 1e9:
edge = Edge(i, j, self.graph[i][j])
edges.append(edge)
return edges```

```def kruskal(self):
union = dict.fromkeys([i for i in range(self.vertex)], -1)  # 辅助数组，判断两个结点是否连通
self.edges = self.getEdges()
self.edges.sort(key=lambda x: x.weight)
res = []
def getend(start):
while union[start] >= 0:
start = union[start]
return start
for edge in self.edges:
# 找到连通线路的最后一个结点
n1 = getend(edge.start)
n2 = getend(edge.end)
# 如果为共同的终点则不处理
if n1 != n2:
print('{}----->{}'.format(n1, n2))
(n1, n2) = (n2, n1) if union[n1] < union[n2] else (n1, n2)
union[n2] += union[n1]
union[n1] = n2
res.append(edge)
print(union.values())```

+ 关注