[PS] 백준 2213번: 트리의 독립집합

|

문제

내가 이해한 문제 내용

트리의 독립집합은 인접하지 않은 정점들의 부분집합으로 정의되는데, 이 때 각 정점에 가중치가 주어질 경우 가중치의 최댓값을 갖는 최대 독립집합을 구하시오.

접근 방식

사회망 서비스 문제와 상당히 유사하기 때문에 DFS + DP로 접근했다.

어려웠던 점 & 배운 점

  • DFS를 짤 때 상태값이 들어가서 많이 헤맸다.
  • 독립집합을 출력하려면 해당 원소들을 추적해야 하는데 이 부분에서 각 노드에 원소들을 계속 저장하는 방식으로 시도했기 때문에 비효율적으로 짰다.
  • 혼자 풀긴 풀었지만 조금 더 개선하고자 Crocus님의 풀이를 보게 되었는데 너무 깔끔했고 추적하는 방식에 있어서 재귀를 사용한 것 + DFS에서 visit 배열을 사용하지 않았던 것이 인상적이었다.

시간복잡도

최적의 풀이를 기준으로 잡는다면,

  • DFS + DP하는데 $O(V+E)$
  • 독립집합 추적하는데 $O(V)$
  • 독립집합 정렬하는데 $O(VlgV)$
  • 최종적으로 $O(VlgV)$

코드

주석부분이 내가 통과한 코드고 아닌부분이 최적의 풀이이다.

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

int n,v1,v2;
int w[10001],visit[10001],dp[10001][2];
vector<int> g[10001],set[10001][2];
vector<int> ans;

void dfs(int v)
{
    visit[v] = 1;
    dp[v][0] = w[v];
    //set[v][0].push_back(v);

    for(int i=0; i<g[v].size(); i++){
        int next = g[v][i];
        if(!visit[next]){
            dfs(next);
            dp[v][0] += dp[next][1];
            //for(int j=0; j<set[next][1].size(); j++) set[v][0].push_back(set[next][1][j]);

            if(dp[next][0]>dp[next][1]){
                dp[v][1] += dp[next][0];
                //for(int j=0; j<set[next][0].size(); j++) set[v][1].push_back(set[next][0][j]);
            } else{
                dp[v][1] += dp[next][1];
                //for(int j=0; j<set[next][1].size(); j++) set[v][1].push_back(set[next][1][j]);
            }
        }
    }
}

void addSet(int prev, int here, int state)
{
    if(!state){
        ans.push_back(here);
        for(auto child : g[here]){
            if(prev==child) continue;
            addSet(here,child,1);
        }
    }
    else{
        for(auto child : g[here]){
            if(prev==child) continue;
            printf("%d\n",child);
            if(dp[child][0]>dp[child][1])
                addSet(here,child,0);
            else
                addSet(here,child,1);
        }
    }
}


int main(void)
{
    scanf("%d",&n);
    for(int i=1; i<=n; i++)
        scanf("%d",&w[i]);
    while(scanf("%d%d",&v1,&v2)!=EOF){
        g[v1].push_back(v2);
        g[v2].push_back(v1);
    }
    dfs(1);
    dp[1][0]>dp[1][1] ? addSet(-1,1,0) : addSet(-1,1,1);
    //sort(set[1][0].begin(),set[1][0].end());
    //sort(set[1][1].begin(),set[1][1].end());

    sort(ans.begin(), ans.end());

    printf("%d\n",max(dp[1][0],dp[1][1]));
    // if(dp[1][0]>dp[1][1]){
    //     for(int i=0; i<set[1][0].size(); i++)
    //         printf("%d ",set[1][0][i]);
    // } else{
    //     for(int i=0; i<set[1][1].size(); i++)
    //         printf("%d ",set[1][1][i]);
    // }

    for(auto node : ans) printf("%d ",node);

    return 0;
}