1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
#include <bits/stdc++.h>
#include <vector>
using namespace std;
typedef long long ll;
// Modulus for all arithmetic. 998244353 = 119 * 2^23 + 1, so it admits
// power-of-two NTT lengths; G = 3 is the root fed to the NTT.
const ll MOD = 998244353;
const ll G = 3;
// Modular exponentiation by repeated squaring: returns a^k mod MOD.
// Assumes 0 <= a < MOD so that a * a fits in a long long.
inline ll qpow(ll a, ll k) {
    ll result = 1;
    for (; k; k >>= 1, a = a * a % MOD)
        if (k & 1) result = result * a % MOD;
    return result;
}
// Number-theoretic transform modulo MOD with root G.
// Usage: set FFT::n to a power of two (at most 8192, bounded by the r[] table),
// then call DFT / IDFT; each transforms the first n entries of the array in place.
namespace FFT {
int n, r[10005];  // n: transform length; r: bit-reversal permutation of [0, n)

// Iterative in-place NTT. op == 1 is the forward transform; op == -1 uses the
// inverse root (the 1/n scaling is applied by IDFT).
void NTT(ll *a, int op) {
    int k = 0;
    while ((1 << k) < n) ++k;  // k = log2(n)
    for (int i = 0; i < n; ++i) {
        // Bit-reverse i from r[i >> 1]. The top bit is set only for odd i; guarding
        // it avoids evaluating 1 << (k - 1) with k == 0 (n == 1), a negative shift (UB).
        r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (1 << (k - 1)) : 0);
        if (i < r[i]) swap(a[i], a[r[i]]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        ll wn = qpow(G, (MOD - 1) / len);      // principal len-th root of unity
        if (op == -1) wn = qpow(wn, MOD - 2);  // its inverse for the inverse transform
        for (int i = 0; i < n; i += len) {
            ll wk = 1;
            for (int j = 0; j < half; ++j, wk = wk * wn % MOD) {
                const ll p = a[i + j];
                const ll q = wk * a[i + j + half] % MOD;
                a[i + j] = (p + q) % MOD;
                a[i + j + half] = (p - q + MOD) % MOD;
            }
        }
    }
}

// Forward transform of a[0..n).
void DFT(ll *a) { NTT(a, 1); }

// Inverse transform of a[0..n), including the multiplication by n^{-1}.
void IDFT(ll *a) {
    NTT(a, -1);
    const ll n_inv = qpow(n, MOD - 2);
    for (int i = 0; i < n; ++i) a[i] = a[i] * n_inv % MOD;
}
}  // namespace FFT
// f[i]: i-th inclusion-exclusion term. fac / inv / facinv: factorials, modular
// inverses and inverse factorials, filled up to index 4000 at startup.
ll f[1005];
ll fac[4005], inv[4005], facinv[4005];
// NTT work buffers. The transform length is the smallest power of two greater
// than a+b+c+d-4i, which is already 4096 once that sum reaches 2048, so the
// previous 4005 entries overflowed. 1 << 13 covers every length FFT::r supports.
const int kPolySize = 1 << 13;
ll x[kPolySize], y[kPolySize], z[kPolySize], w[kPolySize];
// Binomial coefficient C(n, k) mod MOD from the factorial tables.
// Returns 0 when k lies outside [0, n]; the tables must cover index n (n <= 4000).
ll C(int n, int k) {
    if (k < 0 || k > n) return 0;  // avoids negative / out-of-range table indices
    return fac[n] * facinv[k] % MOD * facinv[n - k] % MOD;
}
// Input values read by main: n is the sequence length, and a, b, c, d are the
// per-symbol occurrence caps.
int n;
int a, b, c, d;
// m: largest number of forced blocks in the inclusion-exclusion sum.
int m;
// Running signed total of the inclusion-exclusion terms, mod MOD.
ll ans = 0;
int main() { fac[0] = facinv[0] = 1; fac[1] = inv[1] = facinv[1] = 1; for (int i = 2; i <= 4000; ++i) { fac[i] = fac[i - 1] * 1ll * i % MOD; inv[i] = (MOD - MOD / i) * inv[MOD % i] % MOD; facinv[i] = facinv[i - 1] * inv[i] % MOD; } scanf("%d%d%d%d%d", &n, &a, &b, &c, &d); m = min(min(min(min(n / 4, a), b), c), d); for (int i = 0; i <= m; ++i) { memset(x, 0, sizeof(x)); memset(y, 0, sizeof(y)); memset(z, 0, sizeof(z)); memset(w, 0, sizeof(w)); for (int j = 0; j <= a - i; ++j) x[j] = facinv[j]; for (int j = 0; j <= b - i; ++j) y[j] = facinv[j]; for (int j = 0; j <= c - i; ++j) z[j] = facinv[j]; for (int j = 0; j <= d - i; ++j) w[j] = facinv[j]; FFT::n = 1; while (FFT::n <= a + b + c + d - 4 * i) FFT::n <<= 1; FFT::DFT(x), FFT::DFT(y), FFT::DFT(z), FFT::DFT(w); for (int j = 0; j < FFT::n; ++j) x[j] = x[j] * y[j] % MOD * z[j] % MOD * w[j] % MOD; FFT::IDFT(x); f[i] = C(n - 3 * i, i) * fac[n - 4 * i] % MOD * x[n - 4 * i] % MOD; if (i & 1) ans = (ans - f[i] + MOD) % MOD; else ans = (ans + f[i]) % MOD; } printf("%lld", ans); return 0; }
|