hungarian.cpp (3)
Bài toán
Thực hiện thuật toán Hungary trên đồ thị hai phía.
Độ phức tạp
tối đa lên đến : O(n^4)
Code này của Nguyễn Tiến Trung Kiên
#include <stdio.h>
#include <vector>
#include <queue>
using namespace std;
void minimize(int &a, int b){ if (a>b) a=b; }
typedef pair<int, int> ii;
#define X first
#define Y second
int n, m;
vector<ii> a[12309]; //
int fx[12309], fy[12309]; //
int matchX[12309], matchY[12309]; //
int d[12309]; //
int bfs(){
queue<int> qu;
int i, u, v;
for (i=1; i<=n; i++) if (matchX[i]==0) qu.push(i);
for (i=1; i<=n; i++) d[i]=0;
while (qu.size()){
u=qu.front(); qu.pop();
for (i=0; v=a[u][i].Y; i++)
if (d[v]==0 && a[u][i].X-fx[u]-fy[v]==0){
d[v]=u;
if (matchY[v]==0) return v;
else qu.push(matchY[v]);
}
}
return 0;
}
void enlarge(int u){
int x, y;
while (y=u){
x = d[y];
u = matchX[x];
matchX[x]=y;
matchY[y]=x;
}
}
void subX_addY(int start){
int delta = 1000111000;
int u, i, v;
static bool usedX[12309];
static bool usedY[12309];
for (i=1; i<=n; i++) usedX[i]=false;
for (i=1; i<=n; i++) usedY[i]=false;
usedX[start] = true;
for (i=1; i<=n; i++)
if (d[i]){
usedX[matchY[i]] = true;
usedY[i] = true;
}
// evaluate delta
for (u=1; u<=n; u++)
if (usedX[u])
for (i=0; v=a[u][i].Y; i++)
if (!usedY[v])
minimize(delta, a[u][i].X - fx[u] - fy[v]);
for (i=1; i<=n; i++) if (usedX[i]) fx[i] += delta;
for (i=1; i<=n; i++) if (usedY[i]) fy[i] -= delta;
}
int cost(int u, int w){
int i, v;
for (i=0; v=a[u][i].Y; i++)
if (v==w) return a[u][i].X;
}
int trace(int u){
int i, r=0;
i=u;
do {
i=matchX[i];
r += cost(i, matchX[i]);
d[i]=1;
} while (i!=u);
return r;
}
main(){
int i, p, q, w, u, T;
int sum=0 , count = 0; //
scanf("%d", &T);
while (T--){
scanf("%d%d", &n, &m);
for (i=1; i<=n; i++) {
a[i].clear();
fx[i]=fy[i]=d[i]=0;
matchX[i] = matchY[i] = 0;
}
sum=0 , count = 0;
for (i=1; i<=m; i++){
scanf("%d%d%d", &p, &q, &w);
if (p!=q)
a[p].push_back(ii(w, q));
}
for (i=1; i<=n; i++) a[i].push_back(ii());
for (i=1; i<=n; i++)
do {
u = bfs();
if (u==0) subX_addY(i);
else enlarge(u);
} while (u==0);
for (i=1; i<=n; i++) d[i]=0;
for (i=1; i<=n; i++) if (d[i]==0) { sum += trace(i); count ++; }
printf("%d\n", sum);
}
}
Nhận xét
Code trên đã được nộp thành công trên một bài tập của SPOJ.