HS_development_log

세그먼트 트리(Segment Tree) / Java 본문

Algorithm-이론

세그먼트 트리(Segment Tree) / Java

DevHyeonseong 2020. 1. 13. 19:28
반응형

자바를 이용한 세그먼트 트리의 기본구조입니다.

이진트리 구조로 구현되었습니다.

구간합을 예시로 들어서 구현했습니다.

 

세그먼트 트리 초기화

 

예를 들어 [ 1 , 2 , 3 , 4 , 5 , 6 ] 의 배열을 원소로갖는 세그먼트 트리의 모습은 아래 그림처럼 나와야합니다.

완성됐을때의 세그먼트 트리

아래 코드로 이러한 모습이 구현이 가능합니다.

	public int init(int start, int end, int node) {
		if(start == end) { /* 리프노드이거나 자식노드들이 구간합이 모두구해졌을 경우 */
			return tree[node] = arr[start]; /* 구간합 트리에 넣어준다 */
		}
		/* 반씩 나눠서  재귀적으로 자식노드들의 구간합을 구해준다 */
		int mid = (start+end)/2;
		return tree[node] = init(start, mid, node*2) + init(mid+1, end, node*2+1);
	}

 

코드 동작 순서

각 구간합 옆에있는 숫자가 코드가 실행되는 순서입니다. 세모안에들어가있는 것은 구간합을 구하지못하여 계속 재귀적 탐색을 하는경우고 동그라미 안에 들어가있는 것은 구간합을 구해서 세그먼트 트리에 값을 채우는 경우입니다.

 

세그먼트 트리 구간합

	public int sum(int start, int end, int node, int left, int right) {
		if(left>end || right < start) {
			return 0;
		}
		if(left <=start && end <=right) {
			return tree[node];
		}
		/* 필요한 구간마다 밑에서부터 구간합을 가지고 올라온다 */
		int mid = (start+end)/2;
		return sum(start, mid, node*2, left, right) + sum(mid+1, end, node*2+1, left, right);
	}

코드 동작 순서

<인덱스번호> 입니다.

예를들어 2~4번까지의 구간합을 구하고 싶다면

<0~5> 까지의 인덱스를 갖는 트리의 자식노드로 들어가야합니다.

<0~2> 까지의 구간합에서 2번의 합을 구해야하므로 자식노드로 들어갑니다.

<0~1>은 필요없으므로 0을리턴, <2>는 필요한값이므로 구간합인 3을리턴해서 위로올립니다.

그러면 <0~2> 까지의 구간합에서 <2>에 대한정보인 3을 구했습니다.

같은 방식으로 <3~5> 구간에서 <3~4>까지의 구간에 대한정보를 구하면 9

따라서 두개를 더해주면 <2~4>까지의 구간합은 12라는것을 알 수 있습니다.

 

세그먼트 트리 값 수정

	public void update(int start, int end, int node, int index, int dif) {
		if(index < start || index > end) {
			return;
		}
		tree[node] += dif; /* 변경된 값만큼 더해주고 */
		if(start == end) {
			return;
		}
		/* 변경된 값이 속해있는 구간의 구간합을 모두 바꿔준다 */
		int mid = (start + end)/2;
		update(start, mid, node*2, index, dif);
		update(mid+1, end, node*2+1, index, dif);
	}

값을 수정하는것은 구간합을 구해주는 방식과 비슷합니다.

수정하고싶은 인덱스가 포함되어있는 모든구간을 바뀐만큼 변경해주면 됩니다.

 

 

소스코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import java.util.*;
public class SegmentTree{
    public static void main(String[] args) {
        Tree st = new Tree();
        st.init(0, st.arr.length-1,1);
        
        Scanner scan = new Scanner(System.in);
        
        System.out.print("n번 부터 m번 까지의 구간합 : n,m 을 입력하시오 : ");
        System.out.println(st.sum(0, st.arr.length-11, scan.nextInt(), scan.nextInt()));
        
        System.out.print("n번째 인덱스의 값을 m 으로 변경 : n,m 을 입력하시오 : ");
        int n = scan.nextInt();
        int m = scan.nextInt();
        st.update(0, st.arr.length-110, m-st.arr[n]);
        System.out.println("변경된 세그먼트트리의 전체합 : " + st.sum(0, st.arr.length-1105));
    }
}
class Tree {
    int arr[]; // 구간합을 만들 요소들
    int tree[]; // 구간합 트리
    public Tree() {
        Scanner scan = new Scanner(System.in);
        System.out.print("배열 크기 입력 : ");
        arr = new int[scan.nextInt()];
        tree = new int[arr.length*4];
        System.out.print("배열 요소 입력 : ");
        for(int i=0;i<arr.length;i++) {
            arr[i] = scan.nextInt();
        }
    }
    public int init(int start, int end, int node) {
        if(start == end) { /* 리프노드이거나 자식노드들이 구간합이 모두구해졌을 경우 */
            return tree[node] = arr[start]; /* 구간합 트리에 넣어준다 */
        }
        /* 반씩 나눠서  재귀적으로 자식노드들의 구간합을 구해준다 */
        int mid = (start+end)/2;
        return tree[node] = init(start, mid, node*2+ init(mid+1, end, node*2+1);
    }
    public int sum(int start, int end, int node, int left, int right) {
        if(left>end || right < start) {
            return 0;
        }
        if(left <=start && end <=right) {
            return tree[node];
        }
        /* 필요한 구간마다 밑에서부터 구간합을 가지고 올라온다 */
        int mid = (start+end)/2;
        return sum(start, mid, node*2, left, right) + sum(mid+1, end, node*2+1, left, right);
    }
    public void update(int start, int end, int node, int index, int dif) {
        if(index < start || index > end) {
            return;
        }
        tree[node] += dif; /* 변경된 값만큼 더해주고 */
        if(start == end) {
            return;
        }
        /* 변경된 값이 속해있는 구간의 구간합을 모두 바꿔준다 */
        int mid = (start + end)/2;
        update(start, mid, node*2, index, dif);
        update(mid+1, end, node*2+1, index, dif);
    }
    public void print() {
        for(int i=1;i<this.tree.length;i++) {
            System.out.print(i + "번째 인덱스 : " + this.tree[i]);
            System.out.println();
        }
    }
}
 
cs
반응형