https://www.acmicpc.net/problem/5254
문제를 간단하게 요약하자면 배열 a1, a2, a3, … an에서
① 1 <= i < j <= n인 임의의 i, j에 대해 ai ~ aj의 값을 전부 aj로 바꿀때 배열 전체 합의 최대값을 구하고
② 또 임의의 i, j에 대해 ai ~ aj의 값을 전부 ai로 바꿀때의 최대값도 구하는 것이다.
일단 문제를 관찰해보면 ①을 해결하면 ②는 배열을 뒤집어서 똑같은 방법을 사용하면 해결이 가능하다는 것을 알 수 있다.
따라서 ①을 해결할 방법을 생각해보자.
D[x]를 i=x일때 배열 전체 합의 최대값이라고 하고 S[x]를 a1 ~ ax의 합이라고 하자. 즉, 누적합이다. (S[0]=0)
D[x]는 다음과 같이 나타낼 수 있다.
D[x]=max(S[y-1] + (x - y + 1) * ax + S[n] - S[x]) (1 <= y < x)
해당 식은 다음과 같이 정리가 가능하다.
D[x]=max(S[y-1] - y * ax) + (x + 1) * ax + S[n] - S[x] (1 <= y < x)
모든 y에 대해 나이브하게 탐색을 진행하면 시간복잡도가 O(N^2)이므로 tle가 나오지만 max안의 식이 일차함수 형태이므로 CHT를 적용할 수 있다.
이 경우 시간복잡도를 O(N)까지 줄이는 것이 가능하다.
아래의 코드는 CHT를 리차오 트리로 구현하여 시간복잡도가 O(NlogN)이지만 3 <= n <= 300000이므로 충분하다.
AC code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll inf = 2e18;
struct Line {
ll a,b;
ll get(ll x) {
return a*x+b;
}
};
struct Node {
int l,r;
ll s,e;
Line line;
};
struct Li_Chao {
vector<Node> tree;
void init(ll s,ll e) {
tree.resize(0);
tree.push_back({-1,-1,s,e,{0,-inf}});
}
void update(int node,Line v) {
ll s=tree[node].s,e=tree[node].e;
ll m=s+e>>1;
Line low=tree[node].line, high=v;
if(low.get(s)>high.get(s)) swap(low,high);
if(low.get(e)<=high.get(e)) {
tree[node].line=high;
return;
}
if(low.get(m)<high.get(m)) {
tree[node].line=high;
if(tree[node].r==-1) {
tree[node].r=tree.size();
tree.push_back({-1,-1,m+1,e,{0,-inf}});
}
update(tree[node].r,low);
} else {
tree[node].line=low;
if(tree[node].l==-1) {
tree[node].l=tree.size();
tree.push_back({-1,-1,s,m,{0,-inf}});
}
update(tree[node].l,high);
}
}
ll query(int node,ll x) {
if(node==-1) return -inf;
ll s=tree[node].s,e=tree[node].e;
ll m=s+e>>1;
if(x<=m) return max(tree[node].line.get(x),query(tree[node].l,x));
else return max(tree[node].line.get(x),query(tree[node].r,x));
}
} seg;
bool cmp(pair<ll,ll> a,pair<ll,ll> b) {
if(a.first!=b.first) return a.first<b.first;
else return a.second<b.second;
}
ll N,c[300001],s[300001],sol;
void func() {
for(int i=1;i<=N;i++) s[i]=c[i]+s[i-1];
seg.init(-1e9,1e9);
sol=-inf;
for(int i=1;i<=N;i++) {
sol=max(sol,seg.query(0,c[i])+(i+1)*c[i]+s[N]-s[i]);
seg.update(0,{-i,s[i-1]});
}
cout<<sol<<"\n";
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cin>>N;
for(int i=1;i<=N;i++) cin>>c[i];
func();
reverse(c+1,c+1+N);
func();
}
댓글남기기