코딩/알고리즘

[알고리즘 문제] 25402번 트리와 쿼리

Study_Cat 2024. 5. 29. 19:57

 

출처 : https://www.acmicpc.net/problem/25402

 

 

 

1. 풀이

해당 문제는 주어진 S에 대하여 집합 혹은 그룹을 만들고 각 그룹의 인원수를 Combination을 이용하여 푸는 것 임을 바로 알 수 있다. 이 문제에서 중요한 점은 주어진 집합 S에 대해 그룹의 인원수를 파악하는 것이다. 

 

각 질문마다 S집합에 속한 노드를 chk배열에 확인해두고 단순히 dfs를 돌면 시간초과가 날 것이다. 그 까닭은 dfs는 O(N)이 아니라 O(V+E) 이기 때문이다. 이러한 dfs는 해당 노드에 연결된 모든 간선을 탐색하므로 해당 문제에서 최악의 경우 각 질문마다 O(N) 이 발생하며 결과적로 N*O(N) = O(N^2) 이 되며 시간초과가 난다.

 

위 예시에서 빨간색은 S집합에 속한 노드를 뜻하는데, 인접 배열을 통해 dfs를 할 경우 모든 간선에 대해 탐색하여 필요없는 노드까지 탐색함을 알 수 있다. 

 

 

 

위 문제로 인해 단순한 dfs는 자신과 연결된 모든 간선에 대해 탐색하기에 안되는데, 이를 해결하는 방법은 "트리의 성질" 을 이용하는 것이다. 그래프와 달리 트리는 "부모"  라는 개념이 존재하며 1개의 루트만을 가진다. 즉 기존의 그래프는 무방향 연결로 퍼지는 방식으로 탐색했다면 트리는 단방향 간선으로 탐색함으로 써 필요없는 간선에 대해 탐색하지 않을 수 있다.

 

2. 코드

#include <bits/stdc++.h>
using namespace std;

#define S 250005
int N, M, path[S], cnt[S], p[S],chk[S], epoch;
vector<int> v[S], his;

void update(int n, int bef)
{
    path[n]=bef;
    for(auto&i:v[n]) if(i!=bef) update(i, n);
}

int get_p(int n)
{
    if(chk[path[p[n]]]!=epoch) return p[n];
    p[n]=get_p(path[p[n]]), cnt[p[n]]+=1;
    return p[n];
}

int main()
{
    int x, y;
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cin>>N;
    for(int i=1;i<N;i++){
        cin>>x>>y;
        v[x].push_back(y); v[y].push_back(x);
    }

    update(1,0);

    cin>>M;
    for(epoch=1;epoch<=M;epoch++){

        int Q,n;
        cin>>Q;
        for(int j=0;j<Q;j++){
            cin>>n;
            his.push_back(n);
            cnt[n]=1, chk[n]=epoch, p[n]=n;
        }
        long long sum = 0;
        for(auto&j:his)
            if(p[j]==j) get_p(j);
        for(auto&j:his) sum+=1LL*cnt[j]*(cnt[j]-1)/2;
        cout<<sum<<"\n";
        his.clear();
    }

}

 

 

3. 소감

처음에 dfs를 돌면 시간초과가 난다는 사실은 인지하였지만, 이를 어떻게 해결할 지 생각하기 힘들었다. 시간초과가 나는 이유와 이를 어떻게 처리하면 좋을지 고민하면서 트리의 성질인 루트는 오직 1개라는 사실을 이용하고자 하였고 그룹의 병합을 위해 union_find를 사용했다. 이 문제를 풀면서 자료 구조에 대해 익숙해질 필요가 있다고 느꼈다.

 

트리는 인접 노드를 전부 탐색할 필요없이 조상에 대해서만, 즉 단방향 간선으로 탐색하면 된다!