Summarize
#2128. 「HAOI2015」数字串拆分
Solution
一道神仙题。对矩阵初学者极不友好。
$f$ 函数的计算非常简单,递推公式如下:
$$
f(i)=f(i-1)+f(i-2)+..+f(i-m)
$$
容易看出, $f(i)$ 可以使用矩阵加速递推。以 $m=3$ 为例,构造矩阵
$$
\begin{aligned}
A=\left[ \begin{matrix} 1 & 1 & 1 \
1 & 0 & 0 \
0 & 1 & 0 \end{matrix} \right]
\end{aligned}
$$
那么 $f(n)$ 即为 $(A^n)_{1,1}$ 。
但我们要求的是 $g(S)$ 。考虑将 $g(S)$ 表示为矩阵的形式,最终答案为矩阵的第 $[1,1]$ 项。
考虑 $g(123)$ 的推导。
$$
\begin{aligned}
g(123) &= f(1+2+3)+f(12+3)+f(1+23)+f(123) \
&= A^{1+2+3}+A^{12+3}+A^{1+23}+A^{123} \
&= A^1\times A^2\times A^3+A^{12}\times A^3+A^1\times A^{23}+A^{123}
\end{aligned}
$$
转化为以上形式后,记 $g’(n)$ 为将数字 $S$ 的前 $n$ 项分解后得到的矩阵;在本例中,$g’(3)=g(123)$ 。
可以得到 $g’(i)$ 的递推公式:
$$
g’(i)=g’(i-1)\times A^{\text{num[i,i]}}+g’(i-2)\times A^{\text{num[i-1,i]}}+..+g’(0)\times A^{\text{num[1,i]}}
$$
其中 $\text{num[l,r]}$ 表示字符串中 $[l,r]$ 区间所对应的数字。例如当 $S=”12345”$ ,$\text{num[2,4]}=234$ 。
考虑 $A^{\text{num[l,r]}}$ 的计算。可以通过预处理加速计算。记
$$
P[i][j]=A^{i*10^j}
$$
使用该数组,可以在 $O(n)$ 时间复杂度内计算 $A^{\text{num[l,r]}}$ 。
至此,算法骨架已经成型;时间复杂度为 $O(n\times m^3+n^2\times m^3)$ 。实现时需要注意细节优化(如计算 $g’$ 时使用恰当的循环顺序)。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
| #include<bits/stdc++.h> using namespace std; typedef long long ll;
const int mod = 998244353;
struct Matrix { int m, n; ll g[6][6]; };
Matrix mul(Matrix a, Matrix b) { Matrix ret; ret.m = a.m, ret.n = b.n; for (int i = 1; i <= ret.n; ++i) { for (int j = 1; j <= ret.m; ++j) { ret.g[i][j] = 0; for (int k = 1; k <= a.n; ++k) ret.g[i][j] = (ret.g[i][j] + a.g[i][k] * b.g[k][j]) % mod; } } return ret; }
Matrix add(Matrix a, Matrix b) { Matrix ret; ret.m = a.m, ret.n = a.n; for (int i = 1; i <= a.n; ++i) for (int j = 1; j <= a.m; ++j) ret.g[i][j] = (a.g[i][j] + b.g[i][j]) % mod; return ret; }
Matrix power10(Matrix a) { Matrix ret = a; ret = mul(ret, ret); ret = mul(ret, ret); ret = mul(ret, a); ret = mul(ret, ret); return ret; }
void Init_Unit_Matrix(Matrix& a, int n) { a.m = a.n = n; for (int i = 1; i <= n; ++i) for (int j = 1; j <= n; ++j) a.g[i][j] = (i == j) ? 1 : 0; }
void Init_Empty_Matrix(Matrix& a, int n) { a.m = a.n = n; for (int i = 1; i <= n; ++i) for (int j = 1; j <= n; ++j) a.g[i][j] = 0; }
char s[505];
int N, M; Matrix A; Matrix P[10][505]; Matrix Q[505]; Matrix f[505];
int main() { cin >> (s + 1); N = strlen(s + 1); cin >> M; Init_Unit_Matrix(P[0][0], M); A.n = A.m = M; for (int i = 1; i <= M; ++i) { for (int j = 1; j <= M; ++j) { A.g[i][j] = 0; if (i == 1) A.g[i][j] = 1; if (i - 1 == j) A.g[i][j] = 1; } } for (int i = 1; i <= 9; ++i) { P[i][0] = mul(P[i - 1][0], A); } for (int j = 1; j <= N; ++j) { for (int i = 0; i <= 9; ++i) { P[i][j] = power10(P[i][j - 1]); } } Init_Unit_Matrix(f[0], M); for (int i = 1; i <= N; ++i) { Init_Empty_Matrix(f[i], M); Matrix t; Init_Unit_Matrix(t, M); for (int j = i, c = 0; j >= 1; --j, ++c) { t = mul(t, P[s[j] - '0'][c]); f[i] = add(f[i], mul(f[j - 1], t)); } } cout << f[N].g[1][1]; return 0; }
|