序列计数

Time Limit: 30 Sec Memory Limit: 128 MB

Description

Alice想要得到一个长度为n的序列,序列中的数都是不超过m的正整数,而且这n个数的和是p的倍数。Alice还希望,这n个数中,至少有一个数是质数。Alice想知道,有多少个序列满足她的要求。

Input

一行三个数,n,m,p。

Output

一行一个数,满足Alice的要求的序列数量,答案对20170408取模。

Sample Input

3 5 3

Sample Output

33

HINT

1<=n<=10^9,1<=m<=2×10^7,1<=p<=100

Solution

先考虑容斥,用Ans=全部的方案数 - 一个质数都没有的方案,那么我们首先想到了一个暴力DP,令 f[i][j] 表示选了前 i 个数,%p时余数为 j 的方案数。那么显然 %p 同余的可以分为一类,那么就可以用矩阵乘法来优化这个DP了。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include<bits/stdc++.h>
using namespace std;
typedef long long s64;

const int MaxM = 2e7+5;
const int ONE = 105;
const int MOD = 20170408;

int n,m,p;
int prime[1300005],p_num;
int Record[ONE][2],a[ONE][ONE],b[ONE][ONE];
bool isp[MaxM];

inline int get()
{
int res=1,Q=1; char c;
while( (c=getchar())<48 || c>57)
if(c=='-')Q=-1;
if(Q) res=c-48;
while((c=getchar())>=48 && c<=57)
res=res*10+c-48;
return res*Q;
}

void Getp(int MaxN)
{
isp[1] = 1;
for(int i=2; i<=MaxN; i++)
{
if(!isp[i])
prime[++p_num] = i;
for(int j=1; j<=p_num, i*prime[j]<=MaxN; j++)
{
isp[i * prime[j]] = 1;
if(i % prime[j] == 0) break;
}
}
}

void Mul(int a[ONE][ONE],int b[ONE][ONE],int ans[ONE][ONE])
{
int record[ONE][ONE];
for(int i=0;i<p;i++)
for(int j=0;j<p;j++)
{
record[i][j] = 0;
for(int k=0;k<p;k++)
record[i][j] = (s64)(record[i][j] + (s64)a[i][k]*b[k][j] % MOD) %MOD;
}

for(int i=0;i<p;i++)
for(int j=0;j<p;j++)
ans[i][j] = record[i][j];
}

void Quickpow(int a[ONE][ONE],int b[ONE][ONE],int t)
{
while(t)
{
if(t&1) Mul(a,b,a);
Mul(b,b,b);
t>>=1;
}
}

int Solve(int PD)
{
memset(a,0,sizeof(a));
memset(b,0,sizeof(b));

for(int i=0;i<p;i++)
for(int j=0;j<p;j++)
b[i][j] = Record[((i-j)%p+p)%p][PD];

for(int i=0;i<p;i++)
a[i][i] = 1;

Quickpow(a,b,n);
return a[0][0];
}

int main()
{
n=get(); m=get(); p=get(); Getp(m);
for(int i=1;i<=m;i++)
{
int x = i%p;
Record[x][0]++;
if(isp[i]) Record[x][1]++;
}

printf("%d",(Solve(0)-Solve(1)+MOD) % MOD);
}