dc3.cpp
Bài toán
In ra suffix array (mảng hậu tố) của một xâu kí tự dùng phương pháp DC3.
Độ phức tạp
thời gian: O(n)
bộ nhớ: O(n)
Code này của Nguyễn Tiến Trung Kiên
#include <assert.h>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
#define long long long
#define SetLength(a, n, t) a = (t *)calloc(n, sizeof(t))
bool leq(int a1, int a2, int b1, int b2) {
if (a1 != b1) return a1 <= b1;
return a2 <= b2;
}
bool leq(int a1, int a2, int a3, int b1, int b2, int b3) {
if (a1 != b1) return a1 <= b1;
return leq(a2, a3, b2, b3);
}
int intncmp(int a[], int b[], int n) {
for (int i = 0; i < n; ++i)
if (a[i] != b[i]) return a[i] - b[i];
return 0;
}
void radix_sort(int a[], int b[], int n, int r[], int k) {
vector<int> Count(k + 1), Start(k + 1);
for (int i = 0; i < n; ++i)
Count[r[a[i]]]++;
for (int i = 1; i <= k; ++i)
Start[i] = Start[i - 1] + Count[i - 1];
for (int i = 0; i < n; ++i)
b[Start[r[a[i]]]++] = a[i];
}
void sexiffus(int a[], int sa[], int n, int k) {
int n0 = (n + 2) / 3, n1 = (n + 1) / 3, n2 = n / 3, n02 = n0 + n2, cnt;
int *r, *sa12, *r0, *sa0, *Rank;
SetLength(r, n02 + 3, int);
SetLength(sa12, n02 + 3, int);
SetLength(r0, n0, int);
SetLength(sa0, n0, int);
SetLength(Rank, n + 3, int);
cnt = 0;
for (int i = 0; i < n + n0 - n1; ++i)
if (i % 3) r[cnt++] = i;
radix_sort(r, sa12, n02, a + 2, k);
radix_sort(sa12, r, n02, a + 1, k);
radix_sort(r, sa12, n02, a, k);
int Name = 0;
for (int i = 0; i < n02; ++i) {
if (i == 0 || intncmp(a + sa12[i - 1], a + sa12[i], 3) != 0) Name++;
Rank[sa12[i]] = Name;
}
cnt = 0;
for (int i = 0; i < n + n0 - n1; ++i)
if (i % 3 == 1) r[cnt++] = i;
for (int i = 0; i < n + n0 - n1; ++i)
if (i % 3 == 2) r[cnt++] = i; // r = {1,4,7,...,2,5,8,...}
for (int i = 0; i < n02; ++i)
r[i] = Rank[r[i]];
if (Name == n02)
for (int i = 0; i < n02; ++i)
sa12[r[i] - 1] = i;
else {
sexiffus(r, sa12, n02, Name);
for (int i = 0; i < n02; ++i)
r[sa12[i]] = i + 1;
}
for (int i = 0; i < n02; ++i)
if (sa12[i] < n0)
sa12[i] = sa12[i] * 3 + 1;
else
sa12[i] = (sa12[i] - n0) * 3 + 2;
for (int i = 0; i < n02; ++i)
Rank[sa12[i]] = i + 1;
cnt = 0;
for (int i = 0; i < n + n0 - n1; ++i)
if (i % 3 == 0) sa0[cnt++] = i; // r0 = {0,3,6,...}
radix_sort(sa0, r0, n0, Rank + 1, n02);
radix_sort(r0, sa0, n0, a, k);
int i = 0, j = n0 - n1, l = 0;
while (i < n0 && j < n02) {
int x = sa0[i], y = sa12[j];
if (sa12[j] % 3 == 1
? leq(a[x], Rank[x + 1], a[y], Rank[y + 1])
: leq(a[x], a[x + 1], Rank[x + 2], a[y], a[y + 1], Rank[y + 2]))
sa[l++] = sa0[i++];
else
sa[l++] = sa12[j++];
}
while (i < n0) sa[l++] = sa0[i++];
while (j < n02) sa[l++] = sa12[j++];
free(r), free(sa12), free(r0), free(sa0), free(Rank);
}
#define N 200005
int n, a[N], sa[N];
char s[N];
int main() {
gets(s);
n = strlen(s);
for (int i = 0; i < n; ++i)
a[i] = s[i];
sexiffus(a, sa, n, 128);
for (int i = 0; i < n; ++i)
printf("%d\n", sa[i]);
}
Nhận xét
Code này đã được dùng để nộp cho một bài trên SPOJ.
Tham khảo
http://www.spoj.com/problems/SARRAY/