Programming/algorithm

[APIO 2010] 특공대 (Commando) 풀이

flcat 2020. 6. 23. 17:42

 

https://www.acmicpc.net/problem/4008

 

4008번: 특공대

입력은 세 줄로 구성된다. 첫 번째 줄에 전체 병사들 수인 양의 정수 n이 주어진다. 두 번째 줄에 특공대의 조정된 전투력 계산 등식의 계수인 세 정수 a, b, c가 주어진다. 마지막 줄에 병사들 1, 2,

www.acmicpc.net

Convex Hull Trick (CHT) 를 이용하는 대표 문제로써 2010년 출제 당시 학생들에게 충공깽을 선사했지만 지금은 국민문제가 되어버린 비운의 문제다.

 

우선 간단한 O(n3), O(n2) 해법을 간단히 소개하고 그 뒤 CHT 를 이용한 O(N) 해법을 소개하도록 하겠다.

 

1. O(N3) 해법

 

간단한 다이나믹 프로그래밍으로 해결 가능하다.

 

 D[n] 을 1부터 n 까지의 병사들을 적당히 나눴을 때 얻을 수 있는 최대 전투력의 합으로 정의하자.

 

그렇다면, D[0]=0 일것이고 (아무 병사도 선택하지 않은 경우 전투력은 0 이다),

S=∑k=j+1iXi(Xi는 i 번째 병사의 전투력)일 때

D[i]=maxj(D[j]+aS2+bS+c)가 될 것이다. 

 

즉 이를 알기 쉬운 말로 풀어쓰면

 

1번부터 i 번째 병사들을 적당히 나눴을 때 얻을 수 있는 최대 전투력은

 

i 보다 작은 j 에 대하여 1번째부터 j 번째 병사들을 적당히 나눴을 때 얻을 수 있는 최대 전투력과 (j + 1) 번째 병사부터 i 번째 병사까지 하나의 그룹으로 묶었을 때 전투력의 합의 최대치이다.

 

2. O(N2) 해법

 

위의 식을 잘 살펴보면 S=∑k=j+1iXi 이라는 부분이 있는데 이 S O(N) 이 아닌 O(1)만에 구할 수 있다.

 

그 방법은 바로 부분 합 (partial sum)을 이용하는 것이다.

 

즉, 처음부터 부분합 S[n]=X[0]+⋯+X[n] O(N) 시간안에 구해놓으면

 

S=∑k=j+1iXi=S[n]−S[j] 이다.

 

따라서 시간 복잡도를 O(N2)에서 O(N)으로 단축할 수 있다.

 

코드는 아래와 같다.

 

#include <iostream>
#define MAXN 1000000
#define SQ(x) ((x) * (x))
typedef long long ll;
using namespace std;
ll N, A, B, C, S[MAXN + 1], D[MAXN + 1];
// j + 1번째 병사부터 n 번째 병사까지를 하나의 그룹으로 묶었을때의 비용함수를 리턴한다.
inline ll evaluate(int n, int j) { 
	return A * (SQ(S[n] - S[j])) + B * (S[n] - S[j]) + C;
}
ll solve() {
	for (int i = 1; i <= N; i++) { 
		// 최댓값이 음수가 될 수 있으므로 j = 0 일 때의 함수값을 D[i] 의 기본 값으로 설정하고 시작한다.
		// 즉 D[i] 의 초기값은 evaluate(i, j=0) + D[j=0] 인데 D[0] = 0 이므로 생략한다.
		D[i] = evaluate(i, 0);
		for (int j = 1; j < i; j++) {
			D[i] = max(D[i], D[j] + evaluate(i, j));
        }
    }
    return D[N];
}
int main() {
	ios::sync_with_stdio(0);
    cin.tie(0); cin>> N >> A >> B >> C;
    // 입력 받음과 동시에 누적합을 구한다.
    for (int i = 1; i <= N; i++) cin>>S[i], S[i] += S[i-1];
    cout << solve();
}

 

3. O(N) 해법

 

오늘의 최종보스 CHT 이다.

 

D[n] 을 구하고 싶다고 가정해보자.

 

O(N2) 해법의 점화식에 의해서 D[n]은 아래와 같이 정의될 수 있다.

 

D[n]=maxi(Di+aS2+bS+c)

=maxi(Di+a(Sn−Si)2+b(Sn−Si)+c)

 

이를 전개해서 풀어보면

=maxi(a(Sn2−2SnSi+Si2)+b(Sn−Si)+c+Di)

이고 i에 종속적이지 않은 항들은 괄호 밖으로 뺄 수 있다.

 

즉,

=maxi{−2aSnSi+aSi2−bSi+Di}+aSn2+bSn+c

이다.

 

이는 i에 대한 최적화 문제이며 i에 종속적이지 않은 항들 (=aSn2+bSn+c)은 maxi{−2aSnSi+aSi2−bSi+Di}를 구한뒤 그냥 마지막에 더해주면 된다. 따라서 우선, −2aSnSi+aSi2−bSi+Di를 고려해보자.

 

자 이제 Sn=x로 치환해보자. 그러면,

−2aSnSi+aSi2−bSi+Di=(−2aSi)x+(aSi2−bSi+Di)

가 되면 이는 x=Sn에 대한 일차함수 (선분) 이다!

 

n 보다 작은 i들에 대해 −2aSnSi+aSi2−bSi+Di  Sn에 대한 일차함수의 집합으로 나타낼 수 있음을 발견하였다.

 

그럼 이 선분들로 도대체 뭘 어떻게 해야하는가?

 

우선 위 일차함수들을 그래프로 나타내보자.

(출처: dyngina 님 블로그)

 

여기서 1,2,3,4 는 n보다 작은 i들에 대한 일차함수이다.

 

이 중, 현재 Sn의 값에서 따라 (x 축) 일차함수의 값이 최대가 되는 선분을 고르면 그것이 바로 −2aSnSi+aSi2−bSi+Di 의 최댓값이다.

 

그럼 과연 최적의 선분은 어떻게 구할것인가? 모든 선분을 다 살펴보면 간단하겠지만, 그럼 O(N2)으로 위에서 식을 힘들게 조작한 이유가 사라진다.

 

그 방법은 바로 Li-Chao Tree 라는 세그먼트 트리의 변형을 사용하여 log⁡(N)번만에 최대값을 반환하는 선분을 구하는 것인데, 사실 이 방법은 너무 복잡하다.

 

사실, 이 문제의 여러가지 특수 조건들을 사용하면 Li-Chao Tree 라는것을 사용할 필요 없이 그것도 Amortized O(1)번만에 최대값을 반환하는 선분을 구할 수 있다.

 

여기서 아래와 같은 점에 주목해보자.

 

1) a<0 이다.

2) xi>0 이다.

 

위 두가지 관찰로부터, −2aSi 는 항상 양수라는걸 알 수 있다. 따라서, 선분의 기울기는 항상 양수이다.

 

또한 두번째 관찰로부터 아래와 같은 사실을 발견할 수 있다.

 

3) j<i이면, Sj<Si이다. 

 

이는 당연한것이다. xi가 항상 양수이므로 뒤의 누적합이 앞의 누적합보다 무조건 크다.

 

그리고 세번째 관찰로부터 우리는 중요한 사실을 하나 얻을 수 있다.

 

** 선분의 기울기는 뒤로 갈수록 계속 증가한다. **

 

이 사실로부터 알 수 있는 것은 선분 후보 1번, 2번이 있을 때

 

후보2번이 후보1번의 값보다 크다면 후보1번은 더이상 앞으로 쓰일일이 죽어도 없다는것이다.

 

따라서, 선분 후보들을 deque (덱)으로 관리하고 후보 2번이 후보1번보다 좋다면 후보1번을 후보 리스트에서 제외시켜주면 된다.

 

그러면 후보 2번이 후보 1번보다 좋은지 어떻게 알까?

 

이는 간단하게 후보 1번과 후보 2번의 교점을 구함으로써 알 수 있다.

 

만약 후보 1번과 후보 2번의 교점이 Sn보다 작다면 후보 2번의 기울기가 후보1번의 기울기보다 크기 때문에 후보2번의 값이 무조건 더 크다.

 

따라서 우리는 후보 1번과 2번의 교점이 Sn보다 작다는것만 검사하면 된다.

 

이 과정을 조건을 만족할때까지 계속하면 Sn에 대한 선분의 최대값을 구할 수 있다.

 

이는 아래의 코드와 같이 구현될 수 있다.

 

// 현재 선분 후보에 있는 선분중 맨 처음 두 선분의 번호를 candid[0], candid[1] 이라고 하자.
// candid[1]번째 선분의 기울기가 candid[0]번째 선분의 기울기보다 크므로, 만약 두 선분의 교점이 S[n] 보다 작다면
// x = S[n] 일 때 candid[1]번째 선분을 쓰는것이 candid[0]번째 선분을 쓰는것보다 더 좋다.
while(candid.size() > 1 && intersect(candid[0], candid[1]) < S[n]) candid.pop_front();

자 그럼 D[n]을 이 최고의 선분을 이용하여 구할 수 있다. 아까 상수는 선분에 포함시키지 않았으므로 이 선분을 더하는 것을 잊으면 안된다. D[n]은 아래와 같이 구할 수 있다.

D[n] = -2*A*S[i]*S[n] + A * SQ(S[i]) - B * S[i] + D[i] + evaluate(n);

그렇다면 S_n 에 대한 선분을 선분 후보리스트에 넣어야하는데 그냥 넣으면 될까?

 

아쉽게도 그냥 넣으면 안된다.

 

위에서 Sn에 대한 최고의 선분을 찾을때 우리는 다음과 같은 가정을 했다.

 

Sn−1 에 대해 "2번 후보는 1번 후보 다음으로 좋은 후보이다.

 

하지만 Sn에 대한 선분을 집어넣으므로써 이 가정은 깨질 수 있다.

 

위 그림을 다시 살펴보도록 하자.

 

(출처: dyngina 님 블로그)

 

선분이 1번 2번만 존재할때는 1번 다음의 최고의 후보가 2번이다.

 

하지만 3번을 집어넣게 되면 1번 다음 최고의 후보는 2번이 아닌 3번이다.

 

우리의 Sn 이 1,3 번 선분의 교점과 1, 2번 선분의 교점의 사이에 있다고 가정해보자. 만약 3번을 집어넣었는데 2번을 빼지 않는다면, 우리의 알고리즘은 1,2번의 교점이 Sn 보다 크므로 1번을 최고의 선분이라고 생각하고 종료할 것이다. 하지만 최고의 선분은 3번 (혹은 4번)이다.

 

그리고 2번의 기울기는 3번의 기울기보다 작기 때문에 1,3번의 교점이 1,2번의 교점보다 전에 있다면, 2번이 최고 선분이 되는일은 죽어도 네버 없다.

 

따라서 저러한 2번 선분을 제거해줘야 한다.

 

그러면 "저러한 2번 선분"의 기준은 무엇인가?

 

우리의 덱 (deque) 의 마지막에 있는 선분의 번호를 각각 j  k (j<k) 라고 하자. 또한 Sn에 대한 선분의 번호를 n이라 하자.

 

만약 n번 선분과 j번 선분의 교점이 j번 선분과 k번 선분의 교점보다 전에 있다면 k번째 선분은 아무런 쓸모가 없는 녀석이므로 제거해줘야 한다.

 

이는 아래와 같이 구현할 수 있다.

 

// 현재 선분 후보에 들어 있는 선분중 맨 끝 두 선분의 번호를 END(1), END(0) 라고 하고, 새로 추가된 선분의 번호를 n 이라 할 때
// n 번 선분과 END(1) 번 선분의 교점이 END(0) 선분과 END(1) 번 선분의 교점보다 앞에 있다면 n 번째 선분이 END(0) 번째 선분보다 항상 기울기가
// 크기 때문에 END(0) 번째 선분을 이용할 일이 절대 없다. 따라서 그냥 없앤다.
while (candid.size() > 1 && intersect(n, END(1)) < intersect(END(0), END(1)))
	candid.pop_back();
candid.push_back(n);

n번째 선분을 집어넣으므로써 필요없어지는 선분들을 제거 한 뒤, n번째 선분을 목록에 추가한다.

 

각 선분은 n번의 루프동안 한번 삭제되면 다시 방문되지 않는다. 따라서 전체 시간복잡도는 O(N)이다.

 

이 알고리즘을 해결하는 전체 코드는 아래와 같다.

 

#include <deque> #include <iostream>
#define MAXN 1000000 #define SQ(x) ((x) * (x))
#define END(diff) (candid[candid.size() - 1 - diff])
typedef long long ll;
using namespace std; 
ll N, A, B, C, S[MAXN + 1], D[MAXN + 1];
// i 번째 선분과 j 번째 선분 (j < i)의 교점을 찾는다. 
inline double intersect(int i, int j) {
	return (A * (SQ(S[j]) - SQ(S[i])) - B * (S[j] - S[i]) + D[j] - D[i]) / (2.0*A*(S[j]- S[i])); 
} 
// 비용함수 A * (S_n)^2 + B(S_n) + C 를 리턴한다. 
inline ll evaluate(int n) {
	return A * SQ(S[n]) + B * S[n] + C;
}
ll solve() {
// 선분 후보들을 저장하는 덱(deque).
	deque<int> candid;
	candid.push_back(0);
	for (int n = 1; n <= N; n++) {
		// 현재 선분 후보에 있는 선분중 맨 처음 두 선분의 번호를 candid[0], candid[1] 이라고 하자. 
		// candid[1]번째 선분의 기울기가 candid[0]번째 선분의 기울기보다 크므로, 만약 두 선분의 교점이 S[n] 보다 작다면 
		// x = S[n] 일 때 candid[1]번째 선분을 쓰는것이 candid[0]번째 선분을 쓰는것보다 더 좋다. 
		while(candid.size() > 1 && intersect(candid[0], candid[1]) < S[n]) candid.pop_front(); 
		int i = candid.front(); 
		D[n] = -2*A*S[i]*S[n] + A * SQ(S[i]) - B * S[i] + D[i] + evaluate(n); 
		// 현재 선분 후보에 들어 있는 선분중 맨 끝 두 선분의 번호를 END(1), END(0) 라고 하고, 새로 추가된 선분의 번호를 n 이라 할 때 
		// n 번 선분과 END(1) 번 선분의 교점이 END(0) 선분과 END(1) 번 선분의 교점보다 앞에 있다면 n 번째 선분이 END(0) 번째 선분보다 항상 기울기가 
		// 크기 때문에 END(0) 번째 선분을 이용할 일이 절대 없다. 따라서 그냥 없앤다. 
		while (candid.size() > 1 && intersect(n, END(1)) < intersect(END(0), END(1)))
			candid.pop_back();
		candid.push_back(n);
	} 
	return D[N]; 
} 
int main() {
	ios::sync_with_stdio(0);
	cin.tie(0); cin>> N >> A >> B >> C;
	// 입력 받음과 동시에 누적합을 구한다.
	for (int i = 1; i <= N; i++) cin>>S[i], S[i] += S[i-1];
	cout << solve(); 
}


출처: https://conankuns.tistory.com/14 [난쿤이의 끄적끄적]

 

아래는 자바로 푼 코드이다.

import java.io.*;
import java.util.LinkedList;
import java.util.StringTokenizer;

public class Main {
    static int n;
    static long a,b,c;
    static long s[] = new long[1000001];
    static long d[] = new long[1000001];

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st;

        n = Integer.parseInt(br.readLine());

        st = new StringTokenizer(br.readLine());
        a = Integer.parseInt(st.nextToken());
        b = Integer.parseInt(st.nextToken());
        c = Integer.parseInt(st.nextToken());

        st = new StringTokenizer(br.readLine());
        for(int i = 1 ; i <= n ; i++) {
            s[i] = Integer.parseInt(st.nextToken());
            s[i] += s[i-1];
        }

        bw.write(solve()+"");

        br.close();
        bw.flush();
        bw.close();
    }
    private static double intersect(int i,int j) {
        return (a * ((s[j] * s[j]) - (s[i] * s[i])) - b * (s[j] - s[i]) + d[j] - d[i]) / (2.0 * a * (s[j] - s[i]));
    }
    private static long evaluate(int n) {
        return a * (s[n]*s[n]) + b * s[n] + c;
    }

    private static long solve () {

        LinkedList<Integer> candid = new LinkedList<>();
        candid.addLast(0);

        for(int k = 1 ; k <= n ; k++) {
            while(candid.size() > 1 && intersect(candid.getFirst(), candid.get(1)) < s[k]) {
                candid.removeFirst();
            }
            int i = candid.getFirst();
            d[k] = -2 * a * s[i] * s[k] + a * (s[i]*s[i]) - b * s[i] + d[i] + evaluate(k);

            while (candid.size() > 1 && intersect(k,candid.get(candid.size()-2)) < intersect(candid.getLast(), candid.get(candid.size()-2))) {
                candid.removeLast();
            }
            candid.addLast(k);
        }
        return d[n];
    }
}

 

 

---------

출처 : https://conankuns.tistory.com/14

 

[APIO 2010] 특공대 (Commando) 풀이

https://www.acmicpc.net/problem/4008 4008번: 특공대 입력은 세 줄로 구성된다. 첫 번째 줄에 전체 병사들 수인 양의 정수 n이 주어진다. 두 번째 줄에 특공대의 조정된 전투력 계산 등식의 계수인 세 정수 a

conankuns.tistory.com

코난쿤님이 재밌는 문제를 추천해주시고 문제풀이도 최대한 자세히 해주셨다.

 

혼자였으면 지금 실력으론 어림도 없는 문제 아닌가 싶다.

 

친절히 알려주시는 코난쿤님에게 이 글을 빌어 감사하다고 말씀드리고 싶다.