https://www.acmicpc.net/problem/14401
MST(최소 스패닝 트리)+ CCW(선분 교차 판정)
아무 생각 없이 MST문제인 줄 알고 풀었는데 틀리고 뒤늦게 기하적인 요소가 있다는 걸 깨달았다.
문제의 기본 틀은 MST 변형과 같다. 노드들의 좌표를 받아 모든 간선을 다 계산하고(N^2) 기존에 연결된 간선들에 맞게 유니온 한 뒤 disjoint set의 개수(부모가 자신인 노도의 개수)를 세고 크루스칼 알고리즘을 이용해 그 개수보다 하나 더 적게 간선을 선택해주면 된다. 단, 문제의 요구는 최소 비용이 아니라 최대 비용이므로 정렬을 역순으로 해줘야 된다. 참고로 비용은 거리의 제곱이므로 변수들은 전부 정수만 사용해도 충분하다.
하지만 문제를 잘 읽어보면 간선이 노드를 지나가면 안된다는 조건이 있다. 따라서 모든 간선을 계산할 때 간선이 노드를 하나라도 지나간다면 그 간선은 제외시켜야 된다. 이 작업에서 간선 하나당 연산 횟수가 N이 곱해지는데(모든 노드에 대해서 CCW) 노드의 개수가 500이하이므로 시간은 충분하다.
기하쪽의 문제를 거의 풀어보지 않아서 CCW 알고리즘의 정확한 구현을 모르는 관계로 내 마음대로 선분 위에 점이 있는지 확인하는 함수를 만들었다.
점 a, b로 이어진 선분이 있고 점 c가 있을때 직선ab의 기울기와 직선 ac의 기울기가 같고 점 c의 x, y좌표 각각이 a, b의 좌표 사이에 있으면 선분 ab가 점 c를 지나간다고 할 수 있다.
CCW를 좀 찾아보니 함수 자체가 훨씬 간결한 것 같아서 다음에 CCW 알고리즘과 함께 기하쪽을 공부를 해봐야겠다.
특정 알고리즘의 원리와 구현 방법에 대한 글도 블로그에 써볼 생각이었는데 아마 첫글은 CCW가 될 거 같다.
AC code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll N,M,par[501],t1,t2,cnt=-1,sol;
vector<pair<ll,ll>> v1(1);
vector<vector<ll>> v2;
ll find(ll a) {
if(par[a]==a) return a;
else return par[a]=find(par[a]);
}
ll uni(ll a,ll b) {
ll pa=find(a),pb=find(b);
if(pa>pb) swap(pa,pb);
if(pa==pb) return -1;
return par[pb]=pa;
}
ll dist(pair<ll,ll> &a,pair<ll,ll> &b) {
return (a.first-b.first)*(a.first-b.first)+(a.second-b.second)*(a.second-b.second);
}
bool check(pair<ll,ll> &a,pair<ll,ll> &b,pair<ll,ll> &c) {
if((a.second-b.second)*(a.first-c.first)!=(a.second-c.second)*(a.first-b.first)) return false;
ll mnX=min(a.first,b.first),mnY=min(a.second,b.second),mxX=max(a.first,b.first),mxY=max(a.second,b.second);
if(mnX>c.first||mxX<c.first||mnY>c.second||mxY<c.second) return false;
return true;
}
bool cmp(vector<ll> &a,vector<ll> &b) {
return a[0]>b[0];
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cin>>N>>M;
for(ll i=1;i<=N;i++) {
cin>>t1>>t2;
v1.push_back({t1,t2});
par[i]=i;
}
for(ll i=0;i<M;i++) {
cin>>t1>>t2;
uni(t1,t2);
}
for(ll i=1;i<=N;i++) if(par[i]==i) cnt++;
for(ll i=1;i<N;i++) for(ll k=i+1;k<=N;k++) {
ll tmp=0;
for(ll j=1;j<=N;j++) {
if(j==i||j==k) continue;
if(tmp=check(v1[i],v1[k],v1[j])) break;
}
if(!tmp) v2.push_back({dist(v1[i],v1[k]),i,k});
}
sort(v2.begin(),v2.end(),cmp);
for(ll i=0;cnt;i++) if(uni(v2[i][1],v2[i][2])>0) {
sol+=v2[i][0];
cnt--;
}
cout<<sol;
}
댓글남기기