題意 "題目鏈接" Sol NOI的題都這麼良心麽。。 先交個$n^4$暴力 = 75 hash優化一下 = 90 然後$90$到$100$分之間至少差了$10$難度臺階= = $90$分的暴力hash就比較trival了。 考慮怎麼優化。 顯然我們只要找出所有形如$AA$的字元串就行了,設$pre ...
題意
Sol
NOI的題都這麼良心麽。。
先交個\(n^4\)暴力 => 75
hash優化一下 => 90
然後\(90\)到\(100\)分之間至少差了\(10\)難度臺階= =
\(90\)分的暴力hash就比較trival了。
考慮怎麼優化。 顯然我們只要找出所有形如\(AA\)的字元串就行了,設\(pre[i]\)表示以\(i\)為端點,向前的所有\(AA\)的數量,\(suf[i]\)表示以\(i\)為端點,向後的所有\(AA\)的數量
這樣最終答案就是\(\sum_{i = 1}^{N - 1} pre[i] * suf[i + 1]\)
那怎麼求\(pre\)呢?(\(suf\)同理,只要把原串翻轉一下就和pre一樣了)
首先枚舉一個長度\(len\),然後每隔\(len\)個點打一個標記。
比如\(abcabca\)在長度為\(len = 3\)的時候是這樣的\(abc|abc|a\)
對於相鄰的兩個標記,我們二分找出他們的\(LCS\)和\(LIS\),然後考慮在第一個標記左端點的所有點的貢獻,一個顯然的結論是:(其實也不是很顯然,自己舉幾個例子試試吧)
在\([(i - pre + 1, min(i, i - pre + 1 + (pre + suf - len )) + 1])\)內的點會產生貢獻(可能不是很嚴格,但是可以A。。)
然後暴力加就可以A了(因為數據水。。)
實際上可以直接差分一波
時間複雜度:\(O(nlog^2n)\)
mmp居然卡unsigned 自然溢出
#include<bits/stdc++.h>
#define ull unsigned long long
#define LL long long
using namespace std;
const int MAXN = 1e5 + 10, mod = 998244353;
inline LL read() {
char c = getchar(); int x = 0, f = 1;
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
int N, sum1[MAXN], sum2[MAXN];
ull ha[MAXN], base = 123, po[MAXN];
char s[MAXN];
ull gethash(int l, int r) {
if(l <= 0) return ha[r];
return (ha[r] - ha[l - 1] * po[r - l + 1] % mod + mod) % mod;
}
void Get(char *a, int *sum) {
po[0] = 1; ha[0] = 0; memset(ha, 0, sizeof(ha));
for(int i = 1; i <= N; i++) ha[i] = (ha[i - 1] * base % mod + a[i]) % mod, po[i] = base * po[i - 1] % mod;
for(int len = 1; len <= N / 2 + 1; len++) {
for(int i = len; i < N; i += len) {
int ll = i, rr = i + len, l = 0, r = 0, pre = 0, suf = 0, ans = 0;
if(rr > N) continue;
l = 1, r = len; ans = 0;
while(l <= r) {
int mid = l + r >> 1;
if(gethash(ll - mid + 1, ll) == gethash(rr - mid + 1, rr)) l = mid + 1, ans = mid;
else r = mid - 1;
}
pre = ans;
l = 1, r = len; ans = 0;
while(l <= r) {
int mid = l + r >> 1;
if(gethash(ll + 1, ll + mid) == gethash(rr + 1, rr + mid)) l = mid + 1, ans = mid;
else r = mid - 1;
}
suf = ans;
if(pre + suf < len) continue;
//for(int j = i - pre + 1; j <= min(i, i - pre + 1 + (pre + suf - len )); j++) sum[j]++;
sum[i - pre + 1]++;
sum[min(i, i - pre + 1 + (pre + suf - len )) + 1]--;
}
}
for(int i = 1; i <= N; i++) sum[i] += sum[i - 1];
}
void solve() {
memset(sum1, 0, sizeof(sum1));
memset(sum2, 0, sizeof(sum2));
scanf("%s", s + 1); N = strlen(s + 1);
Get(s, sum1);
reverse(s + 1, s + N + 1);
Get(s, sum2);
reverse(sum2 + 1, sum2 + N + 1);
LL ans = 0;
for(int i = 1; i <= N - 1; i++) ans += sum1[i + 1] * sum2[i];
cout << ans << endl;
}
signed main() {
//freopen("a.in", "r", stdin);
for(int T = read(); T; T--, solve());
return 0;
}