要么改变世界,要么适应世界

利用KM算法求解最大权匹配

2020-11-27 13:28:00
88
目录

前言

KM算法是一种计算机算法,功能是求完备匹配下的最大权匹配。在一个二分图内,左顶点为X,右顶点为Y,现对于每组左右连接Xi->Yj有权wij,求一种匹配使得所有wij的和最大

参考博客:KM算法入门

算法实现

题目背景

奔小康挣大钱

题解

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;
const int MAXN = 305;
const int INF = 0x3f3f3f3f;

int love[MAXN][MAXN];
int ex_user[MAXN];
int ex_house[MAXN];
bool vis_user[MAXN];
bool vis_house[MAXN];
int match[MAXN];
int slack[MAXN];

int N;

bool dfs(int user) {
    vis_user[user] = true;
    for (int house = 0; house < N; ++house) {
        if (vis_house[house]) continue;
        int gap = ex_user[user] + ex_house[house] - love[user][house];
        if (gap == 0) {
            vis_house[house] = true;
            if (match[house] == -1 || dfs(match[house])) {
                match[house] = user;
                return true;
            }
        } else {
            slack[house] = min(slack[house], gap);
        }
    }
    return false;
}

int KM() {
    memset(match, -1, sizeof match);
    memset(ex_house, 0, sizeof ex_house);
    for (int i = 0; i < N; ++i) {
        ex_user[i] = love[i][0];
        for (int j = 1; j < N; ++j) {
            ex_user[i] = max(ex_user[i], love[i][j]);
        }
    }
    for (int i = 0; i < N; ++i) {
        fill(slack, slack + N, INF);
       //如果为用户 i 匹配房子失败 则通过降低期望值进行尝试匹配
        while (1) {
            memset(vis_user, false, sizeof vis_user);
            memset(vis_house, false, sizeof vis_house);
            // 如果为用户 i 匹配房子成功
            if (dfs(i)) break;
            int d = INF;
            for (int j = 0; j < N; ++j)
                if (!vis_house[j]) d = min(d, slack[j]);
            for (int j = 0; j < N; ++j) {
                if (vis_user[j]) ex_user[j] -= d;
                if (vis_house[j])
                    ex_house[j] += d;
                else
                    slack[j] -= d;
            }
        }
    }
    int res = 0;
    for (int i = 0; i < N; ++i) res += love[match[i]][i];
    return res;
}

int main() {
    while (~scanf("%d", &N)) {
        for (int i = 0; i < N; ++i)
            for (int j = 0; j < N; ++j) scanf("%d", &love[i][j]);
        printf("%d\n", KM());
    }
    return 0;
}
历史评论
开始评论