BOJ 5529:: 저택

http://acmicpc.net/problem/5529

일본 여행 기념 JOI 문제 풀이입니다. (???)

규칙을 알기 쉽게 정리하자면 다음과 같습니다.

  1. 처음에는 $(1, 1)$에서 위를 바라본 상태로 출발한다.
  2. 이동은 앞으로만 할 수 있다.
  3. 스위치를 만나면 스위치를 누르고 왼쪽 또는 오른쪽으로 꺾을 수 있다.

정리한 규칙을 바탕으로 문제에서 주어진 $8\times 9$ 입력의 정답을 쉽게 알 수 있습니다.

$(1, 1)$ 제외 방 19개를 지나고 스위치를 6번 누르므로 답은 25가 됩니다. 여기서 조금만 생각해보면 버튼이 없는 다른 방은 전혀 신경 쓸 필요가 없고, 따라서 버튼이 있는 방들을 꼭짓점으로 하는 그래프 문제라는 걸 깨닫게 됩니다. 대충 아래 그림과 같은 그래프가 되겠죠.

저는 처음에 이렇게 생각을 했는데, 이 그래프는 스위치 누르기를 전혀 고려하지 않고 있습니다. 스위치를 누르지 않고도 스위치가 있는 방에 도착만 하면 이동 방향을 꺾을 수 있는 형태의 그래프입니다. 스위치를 짝수 번 누르면 상하로만 움직일 수 있고, 홀수 번 누르면 좌우로만 움직일 수 있기 때문에 스위치를 짝수 번 아니면 홀수 번 눌렀는지에 따라 그래프가 달라져야 합니다.

왼쪽이 스위치를 짝수 번 눌렀을 때 그래프, 오른쪽이 홀수 번 눌렀을 때 그래프입니다. 처음에는 왼쪽에서 시작하고, 원하는 경우 스위치를 눌러 반대쪽 그래프의 같은 꼭짓점으로 이동하며 이때 거리는 스위치를 누르는 데 걸리는 시간, 즉 1입니다. 설명이 조금 복잡한데, 두 그래프를 합쳐서 하나로 나타내면 이해하기 쉽습니다.

빨간색 변은 스위치를 짝수 번 눌렀을 때 이동 방법, 초록색 변은 스위치를 홀수 번 눌렀을 때 이동 방법, 파란색 변은 스위치 누르기를 의미합니다. 빨간색과 초록색 변의 가중치는 단순히 두 스위치 사이의 기하학적 거리($x$좌표 또는 $y$좌표의 차)이고 파란색 변의 가중치는 모두 1입니다. 그러면 결국 문제는 $(1, 1)$의 빨간색 꼭짓점에서 시작하여 $(8, 9)$의 빨간색 또는 초록색 꼭짓점까지(어느 쪽이든 상관 없습니다) 최단 거리를 구하는 문제가 됩니다. 그래프에 최단 거리면 다익스트라죠.

약간 디테일한 부분에 신경 써 봅시다. 먼저, $(1, 1)$과 $(m, n)$에 스위치가 없는 경우를 고려합시다. 위 예제는 다행히 두 위치 모두 스위치가 있지만 예제 입력에선 둘 다 없습니다. $(m, n)$에 스위치가 없는 경우는 쉽습니다. 그냥 추가해 줍시다. 어차피 $(m, n)$에 도착하는 순간 더 움직이지 않을 테니 추가한다고 답이 달라지지는 않습니다. 반면에 $(1, 1)$에 스위치가 없다고 추가해버리면 시작하자마자 스위치를 눌러서 오른쪽으로 가버리는, 원래 문제에서 의도하지 않은 행동을 할 수 있기 때문입니다. 따라서 $(1, 1)$에 스위치가 없다면, $(1, 1)$에서 위로 걸어가다가 처음 만나는 스위치를 시작점으로 하면 됩니다. 물론 그래프 상에선 그 스위치에 해당하는 빨간색 꼭짓점이 됩니다. 결론적으론 $x$좌표가 1이고 $y$좌표가 최소인 스위치를 찾으면 된다고 할 수 있겠습니다. (여기서, 시작 꼭짓점의 distance가 0이 아닐 수도 있다는 데 주의합시다!) 만약 $x$ 좌표가 1인 스위치가 없다면? 그땐 절대 제일 왼쪽 세로줄($x$좌표 1)에서 못 벗어나니 그냥 -1을 출력하도록 예외 처리를 해야 합니다. ($m=1$이면 스위치 안 누르고 $(1, 1)$에서 $(1, n)$까지 갈 수 있지만 $m$이 2 이상이랬으니 고려하지 않아도 됩니다.)

또 신경 쓸 것 하나는 스위치 좌표에서 변을 계산해 내는 건데, 모든 스위치를 $x$좌표로 정렬($x$좌표가 같으면 $y$좌표로 정렬)했을 때 인접한 두 스위치의 $x$좌표가 같으면 그 사이에 빨간색 변이 있습니다. 마찬가지로 $y$좌표로 정렬하면 초록색 변을 계산할 수 있고요. 위 예제 가지고 직접 해보면 무슨 말인지 이해할 수 있으니 자세한 설명은 생략합니다.

아, 그리고 모든 방을 뺑뺑이 돌면 답이 대충 $m\times n + k$니까 $m$, $n$, $k$가 크면 답이 int 범위를 넘습니다. long long을 씁시다.

시간 복잡도는 변을 계산하는 데 $O(k\log{k})$가 걸리고, 꼭짓점의 개수가 $2k$, 변의 개수가 많아봤자 꼭짓점 당 3개씩 해서 $6k$이니까 다익스트라를 쓰는 데 역시 $O(k\log{k})$가 걸리므로 총 $O(k\log{k})$입니다.

#include<cstdio>
#include<algorithm>
#include<vector>
#include<queue>
#include<tuple>
 
#define INF (1LL << 42)
 
 
using namespace std;
 
typedef tuple<long long, int, int> T;
 
union Switch {
    int crd[2];
    struct {int x, y, i;};
};
 
int m, n, k, mn_idx = -1, start = -1;
Switch sw[200001];
 
vector<int> adj[200001][2];
long long dist[200001][2];
bool visited[200001][2];
priority_queue<T, vector<T>, greater<T>> pq;
 
 
bool cmp_x(Switch &a, Switch &b) {return a.x < b.x || (a.x == b.x && a.y < b.y);}
bool cmp_y(Switch &a, Switch &b) {return a.y < b.y || (a.y == b.y && a.x < b.x);}
bool cmp_i(Switch &a, Switch &b) {return a.i < b.i;}
 
 
int main(void) {
    scanf("%d%d%d", &m, &n, &k);
    for (int i = 0; i < k; i++) {
        scanf("%d%d", &sw[i].x, &sw[i].y);
        sw[i].i = i;
        if (sw[i].x == m && sw[i].y == n) mn_idx = i;
    }
    if (mn_idx == -1) {
        sw[k].x = m; sw[k].y = n; sw[k].i = k;
        mn_idx = k++;
    }
    
    for (int m : {0, 1}) {
        sort(sw, sw+k, m == 0? cmp_x : cmp_y);
        if (m == 0 && sw[0].x == 1) start = sw[0].i;
        for (int i = 0; i < k; i++) {
            if (i != 0 && sw[i-1].crd[m] == sw[i].crd[m])
                adj[sw[i].i][m].push_back(sw[i-1].i);
            if (i != k-1 && sw[i+1].crd[m] == sw[i].crd[m])
                adj[sw[i].i][m].push_back(sw[i+1].i);
        }
    }
    
    if (start == -1) {
        printf("-1");
        return 0;
    }
    
    sort(sw, sw+k, cmp_i);
    for (int i = 0; i < k; i++) dist[i][0] = dist[i][1] = INF;
 
    dist[start][0] = sw[start].y - 1;
    pq.push(T(dist[start][0], start, 0));
 
    int cur, mod, d;
    while (!pq.empty()) {
        do {
            cur = get<1>(pq.top());
            mod = get<2>(pq.top());
            pq.pop();
        } while (!pq.empty() && visited[cur][mod]);
 
        if (visited[cur][mod]) break;
 
        visited[cur][mod] = true;
        for (int nxt : adj[cur][mod]) {
            d = abs(sw[nxt].crd[1-mod] - sw[cur].crd[1-mod]);
            if (dist[nxt][mod] > dist[cur][mod] + d) {
                dist[nxt][mod] = dist[cur][mod] + d;
                pq.push(T(dist[nxt][mod], nxt, mod));
            }
        }
        if (dist[cur][1-mod] > dist[cur][mod] + 1) {
            dist[cur][1-mod] = dist[cur][mod] + 1;
            pq.push(T(dist[cur][1-mod], cur, 1-mod));
        }
    }
    
    long long mov_t = min(dist[mn_idx][0], dist[mn_idx][1]);
    printf("%lld", mov_t != INF? mov_t : -1LL);
 
    return 0;
}