834. Sum of Distances in Tree
Hard
There is an undirected connected tree with n nodes labeled from 0 to n - 1 and n - 1 edges.
You are given the integer n and the array edges where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree.
Return an array answer of length n where answer[i] is the sum of the distances between the ith node in the tree and all other nodes.
Example 1:
Input: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: The tree is shown above.
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8.
Hence, answer[0] = 8, and so on.
Example 2:
Input: n = 1, edges = []
Output: [0]
Example 3:
Input: n = 2, edges = [[1,0]]
Output: [1,1]
Constraints:
- 1 <= n <= 3 * 104
- edges.length == n - 1
- edges[i].length == 2
- 0 <= ai, bi < n
- ai != bi
- The given input represents a valid tree.
문제 풀이
- 주어진 트리에서 각 노드에 도착하는 거리들의 총합을 구해야한다.
- 모든 노드를 시작점으로 하고 각각의 노드에 도착하는 거리를 더한 값을 정답배열에 넣고 리턴해야한다.
- 모든 노드를 탐색해야하기 때문에 완전탐색 접근을 생각할 수 있으나, 문제 제한사항에서는 최악의 경우 3 * 10^4 개의 노드가 존재한다.
- 그렇기 때문에 완전탐색으로 모든 노드를 탐색할때 O(N)이 걸리고, 각각의 노드를 모두 시작점으로 한번씩 탐색해야하기 때문에 O(N^2) 시간복잡도로 시간초과가 난다.
- 부분 트리로 거리의 총합을 계산하는 방법을 사용하여 O(N)으로 시간복잡도를 줄일 수 있다.
- 만약 a와 b트리가 한선으로 연결되어 있고 두개의 거리 총합을 구하는 식은 다음과 같다.
- a와 b가 연결된 거리 총합 = a의 거리 총합 + b의 거리 총합 + b의 총 노드 개수
- a와 b노드를 연결하는 간선이 한개 존재하기 때문에 a에서 b 각각의 노드에 도달하려면 1 씩 더해줘야한다.
- 예를 들어 a에는 0노드가 있고 b에는 1노드와 2노드, 3노드가 일렬로 연결되어 있다고 가정해보자.
- 0노드의 총거리는 0이며, b노드에는 1에서 2로 가는 거리 1과 1에서 3으로 가는 거리 2가 있어 3이된다.
- a의 노드 0에서 1까지 가는 길이는 개수 0개에서 1을 더해 1이고, 0에서 2까지 가려면 b노드에서 계산된 2에 간선 1을 추가하여 2가되고, 0에서 3까지 가는 거리는 1에서 3으로 가는 거리 2에 1을 더해 3이 있다.
- 결과적으로 0이 최상위 트리라고 했을때, b트리의 총 거리수 3 + b트리의 총 노드 개수 3을 계산하여 6이된다.
- 다시말해 a - b 로 연결된 트리의 정답은 a 정답 + (b의 1노드에서 2노드 + 1 + b의 1노드에서 3노드 + 1 + a에서 b 시작노드로 연결되는 간선 1) 이된다.
- 이를위해, dfs 후위 탐색을 사용하여 각각의 노드들이 가진 노드의 개수를 cnt 배열에 더해준다.
- 그리고 위식을 사용하여 정답에는 node위치와 child 위치의 정답을 구하기 위해 node의 정답 + child 정답 + child 노드의 총 개수를 더한다.
- 위식과 똑같이 a(node 트리)정답 + b(child 트리)정답 + b(child 트리) 노드 총 개수로 탐색을 완료한다.
- dfs를 끝내면 자식 노드들을 가지고 있는 노드들의 정답들과 자식들의 개수가 저장된다.
- 나머지 노드들의 정답을 채우기 위해서는 위의 식을 응용하여 dfs 선위 탐색을 진행한다.
- 1. a - b 의 정답 = a 정답 + b 정답 + b 노드의 총 개수
- 2. b - a 의 정답 = b 정답 + a 정답 + a 노드의 총 개수
- 이 두식을 합쳐 사용하는데, 현재 자식들이 존재하는 노드들의 정답은 저장되어 있다.
- 1번 식은 정답이 저장되어있는, 즉 자식들이 존재하는 노드이며 2번 식은 정답이 없는 노드들이다.
- 이 b노드의 값을 구하기 위해 2번 식에서 1번 식을 빼서 사용한다.
- 계산을 하게된 결과는 다음과 같다.
- b - a 의 정답 - a - b 의 정답 = a 노드의 총 개수 - b 노드의 총 개수
- b - a 의 정답 = a - b 의 정답 + a 노드의 총 개수 - b 노드의 총 개수
- a - b 의 정답과 b 노드의 총 개수는 미리 구해놓았기 때문에 바로 사용할 수 있다.
- a 노드의 총 개수는 전체 노드 개수에서 b 노드의 총 개수를 빼면 쉽게 구할 수 있다.
- 즉, b - a 의 정답 = a - b 의 정답 - b 노드의 총 개수 + a - b의 총 노드 개수 - b 노드의 총 개수
- 이식을 이용하여 dfs를 한번더 진행하면서 위식을 사용하여 정답들을 저장한다.
- 선위 탐색으로 dfs를 진행하는 이유는 순차적으로 먼저 계산을 해야 다음 정답을 참고할때 사용할 수 있다.
- 자식들이 있는 노드들은 첫번째 dfs에서 정답을 구했고, 두번째 dfs 에서 위 공식을 사용하여 트리 모양을 변경시켜 답을 구한다고 생각하면 이해하기 쉽다.
- 각각의 노드들을 2번 완전 탐색 하기 때문에 O(N + N)이며 최종적으로는 O(N) 의 시간이 걸린다.
소스 코드
class Solution:
def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
g = defaultdict(set)
for a, b in edges:
g[a].add(b)
g[b].add(a)
cnt = [1] * n
ans = [0] * n
def dfs(node, parent):
for child in g[node]:
if child != parent:
dfs(child, node)
cnt[node] += cnt[child]
ans[node] += ans[child] + cnt[child]
def dfs2(node, parent):
for child in g[node]:
if child != parent:
ans[child] = ans[node] - cnt[child] + n - cnt[child]
dfs2(child, node)
dfs(0, None)
dfs2(0, None)
return ans
'컴퓨터공학 > LeetCode 1000' 카테고리의 다른 글
[LeetCode] 6. Zigzag Conversion (1) | 2023.02.04 |
---|---|
[LeetCode] 1056. Confusing Number (0) | 2023.01.02 |
[LeetCode] 886. Possible Bipartition (0) | 2022.12.22 |
[LeetCode] 1971. Find if Path Exists in Graph (0) | 2022.12.21 |
[LeetCode] 1066. Campus Bikes II (0) | 2022.12.21 |