題目的意思是:
給一個有n個點,m條邊的無向圖
兩點之間可以存在多條邊
現在每次隨機增加一條邊
問使得全部點都連通需要增加多少次(期望值)
首先,求出所有連通分量。用并查集。
每次隨機增加一條邊的時候一共有兩種情況:
1)這條邊連接了兩個不同的連通分量,它的概率是p
2)這條邊在一個連通分量里,它的概率是q = 1 - p
前者可以改變連通分量的數量,后者不能
如果把當前圖的狀態視為一個子問題
那么就可以用動態規劃解決問題了
圖的狀態可以表示為:有多少個連通分量,每個連通分量包含多少個點
比如說圖的狀態 (2, 3, 3) 表示有三個連通分量,每個連通分量包含的點的個數分別為 2, 3, 3
動態規劃的轉移方程為:
f = p*(1+r) + p*q*(2+r) + p*q^2*(3+r) ....
其中r為p發生后,新狀態的期望值
這個東西高中的時候學過,呵呵。
而1)中也包含多種情況,需要兩兩枚舉
最大的問題是,f的值是一個無限數列,它的極值很難求。但無論如何,有高手求出來了。。在這里:http://archive.cnblogs.com/a/1325929/
它的極值是 f = p * (1 / (1 - q) + r) / (1 - q)
我對照了一下標程,確實是這個。
后來我自己推導了一下,發現它可以化成多個等比數列相加的形式,求和后,發現當n趨向于無窮大的時候,它的極限就是上面這個公式。
(注意:i*q^i, 當0<q<1,i趨向于無窮大的時候等于0)
這樣程序就可以寫了。動態規劃保存每個圖的狀態。
如果用python寫,只要建立一個tuple到float的映射就可以了。非常方便。
java中也有List<int>到Double的映射。
c里面估計就得用hash了。
py代碼,參照標程寫的。
fi = open('in')
fo = open('out')
dp = {():0}
ti = 0
def get(s):
if s in dp:
return dp[s]
q = sum([i*(i-1) for i in s])*1.0/2/nn
res = 0
for i in range(len(s)):
for j in range(len(s)):
if i < j:
l = list(s)
del l[max(i,j)]
del l[min(i,j)]
l.append(s[i]+s[j])
l.sort()
r = get(tuple(l))
p = s[i]*s[j]*1.0/nn
res += p*(1+r-r*q)/pow(1-q,2)
dp[s] = res
return res
while 1:
a = fi.readline().split()
if a == None or len(a) != 2:
break
N, M = int(a[0]), int(a[1])
nn = N*(N-1)/2
s = [ i for i in range(N) ]
for i in range(M):
u, v = [ int(i) for i in fi.readline().split() ]
u -= 1
v -= 1
k = s[u]
for j in range(N):
if s[j] == k:
s[j] = s[v]
ss = [ s.count(i) for i in set(s) ]
ss.sort()
print '----', ti
mine = get(tuple(ss))
ans = float(fo.readline().strip())
print 'mine', mine, 'ans', ans
print len(dp)
ti += 1
標程
用很簡潔的代碼寫了并查集,值得借鑒!
import java.util.*;
import java.io.File;
import java.io.PrintWriter;
import java.io.FileNotFoundException;
public class interconnect_pm {
private static int nn;
public static void main(String[] args) throws FileNotFoundException {
Scanner in = new Scanner(new File("in"));
PrintWriter out = new PrintWriter("ans.out");
int n = in.nextInt();
nn = (n * (n - 1)) / 2;
int m = in.nextInt();
int[] p = new int[n];
for (int i = 0; i < n; i++) p[i] = i;
for (int i = 0; i < m; i++) {
int u = in.nextInt();
int v = in.nextInt();
u--;
v--;
int k = p[u];
for (int j = 0; j < n; j++) {
if (p[j] == k) {
p[j] = p[v];
}
}
}
List<Integer> st = new ArrayList<Integer>();
for (int i = 0; i < n; i++) {
int s = 0;
for (int j = 0; j < n; j++) {
if (p[j] == i) s++;
}
if (s > 0) {
st.add(s);
}
}
Collections.sort(st);
List<Integer> fn = new ArrayList<Integer>();
fn.add(n);
mem.put(fn, 0.0);
out.println(get(st));
System.out.println(mem.size());
out.close();
}
static Map<List<Integer>, Double> mem = new HashMap<List<Integer>, Double>();
private static double get(List<Integer> st) {
Double ret = mem.get(st);
if (ret != null) return ret;
int m = st.size();
int[][] a = new int[m][m];
for (int i = 0; i < m; i++) {
for (int j = i + 1; j < m; j++) {
a[i][j] = st.get(i) * st.get(j);
}
}
int s = 0;
for (int i = 0; i < m; i++) {
s += st.get(i) * (st.get(i) - 1) / 2;
}
double res = 0;
for (int i = 0; i < m; i++) {
for (int j = i + 1; j < m; j++) {
List<Integer> ss = new ArrayList<Integer>(st.size() - 1);
boolean q = true;
int z = st.get(i) + st.get(j);
for (int k = 0; k < m; k++) {
if (k != i && k != j) {
int zz = st.get(k);
if (q && zz >= z) {
q = false;
ss.add(z);
}
ss.add(zz);
}
}
if (q)
ss.add(z);
double p = a[i][j] * 1.0 / (nn - s);
double e = a[i][j] * 1.0 / ((1 - s * 1.0 / nn) * (1 - s * 1.0 / nn) * nn);
e = e + get(ss) * p;
res += e;
}
}
System.out.println(st);
mem.put(st, res);
return res;
}
}