POJ3417 Network – 树上差分 – lca
题意:给出一个无向图,分别给出n-1条树边(主要边)和m条非树边(附加边),这个无向图可以看做一棵树外加m条附加边,你可以切断一条主要边和一条附加边,求切割后,能够使这个无向图不再连通的切割方案数(即使只切断一条主要边就可以使图不连通,你也需要再切断一条附加边)
我们先考虑只有一条附加边(x,y)时,这时这张图就是一棵基环树
我们发现如果x,y之间有一条附加边,则这条边和x到y的路径组成了一个环,如果说我们要切割x到y的路径上的一条主要边,我们必须要再切断这条附加边,才能使图不再连通
那么如果x,y之间有两条或以上附加边,若我们已切割了x到y路径上的一条主要边,那么是无法通过仅再切割一条附加边来使图不再连通
因而我们每次读入一条附加边,就给x到y的路径上的所有主要边记录上“被覆盖一次”,这样再去遍历所有主要边
对于我们想要切割的一条主要边,有以下3种情况
- 若这条边被覆盖0次,则可以任意再切断一条附加边
- 若这条边被覆盖1次,那么只能再切断唯一的一条附加边
- 若这条边被覆盖2次及以上,没有可行的方案
现在的问题是如何快速求出每条边被覆盖了多少次,对于这类问题,可以类比序列差分,有树上差分算法
设差分数组dif初值为0,若x,y有一条附加边,则dif[x]++,dif[y]++,dif[lca(x,y)]-=2
设f(x)为以x为根的子树中所有节点dif之和,则f(x)就是x到其父节点的边被覆盖的次数
没错,求的是子树和,所以说求各种类似前缀和的东西真的很好用,一般可以用这些方式对区间O(1)求解
为了更快地解决问题,有必要总结各种思想与算法,并分析其适用于什么样的情况,要时常复习自己的博客啊…
比如说这题给我的启示是,树上问题可以类比序列问题,比如说求树上s到t点的路径和,就可以用树上前缀和
最后处理答案的时候注意,若f[x] == 0,则x可能为根节点,要特判,根节点没有父节点!!!
一般我们记录一个点“上面的一条边”
特别注意,LCA若是循环从20开始,则数组要开到21以上!不然会越界RE
#include
#include
#include
#include
#include
using namespace std;
#define debug(x) cerr << #x << "=" << x << endl;
const int MAXN = 100000 + 10;
const int INF = 1<<30;
int n,m,tot,last[MAXN],f[MAXN][21],d[MAXN],dif[MAXN],vis[MAXN],ans,sta[MAXN];
struct Edge {
int u,v,w,to;
Edge(){}
Edge(int u, int v, int to) : u(u), v(v), to(to) {}
}e[MAXN * 2];
inline void add(int u, int v) {
e[++tot] = Edge(u,v,last[u]);
last[u] = tot;
}
void lca_dfs(int now) {
for(int i=last[now]; i; i=e[i].to) {
int v = e[i].v;
if(d[v]) continue;
f[v][0] = now;
d[v] = d[now] + 1;
lca_dfs(v);
}
}
void ica_init() {
d[1] = 1;
lca_dfs(1);
for(int k=1; k<=20; k++) {//LCA若是循环从20开始,则数组要开到21以上!不然会越界RE
for(int i=1; i<=n; i++) {
f[i][k] = f[f[i][k-1]][k-1];
}
}
}
int lca(int x, int y) {
if(d[x] > d[y]) swap(x, y); //不妨设x浅于y
for(int i=20; i>=0; i--) {
if(d[f[y][i]] >= d[x]) y = f[y][i];
}
if(x == y) return x;
for(int i=20; i>=0; i--) {
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
}
return f[x][0];
}
int dfs(int now) {
vis[now] = 1;
sta[now] = dif[now];
for(int i=last[now]; i; i=e[i].to) {
int v = e[i].v;
if(vis[v]) continue;
sta[now] += dfs(v);
}
return sta[now];
}
int main() {
cin >> n >> m;
for(int i=1; iint u,v;
scanf("%d %d", &u, &v);
add(u, v);
add(v, u);
}
ica_init();
for(int i=1; i<=m; i++) {
int u, v;
scanf("%d %d", &u, &v);
dif[u]++;
dif[v]++;
dif[lca(u,v)] -= 2;
}
dfs(1);
for(int i=1; i<=n; i++) {
if(sta[i] == 0 && i != 1) {
ans += m;
}
if(sta[i] == 1) {
ans++;
}
}
printf("%d\n", ans);
return 0;
}