This is clearly a partial derangement problem. If $x$ is the number of permutations with exactly $k$ prime numbers found away from their natural positions and $y = n!$ is the total number of permutations, then the answer is $x \cdot y^{-1} \bmod (10^9 + 123)$.
Code
- C++
#include <bits/stdc++.h>
using namespace std;
// 64-bit signed integer type used for every modular quantity in this file.
using ll = long long;
// Modulus for all arithmetic (10^9 + 123), written as an exact integer literal.
// NOTE(review): inv() raises to the power mod - 2, which presumes mod is prime — confirm.
constexpr ll mod = 1000000123;
// Primes in increasing order; filled by preComputation().
std::vector<ll> primes;
// facts[i] holds i! modulo mod for 0 <= i <= 10000000; filled by preComputation().
std::vector<ll> facts(10000001);
// ------------------ Modular Multiplicative inverse -------------------------
// Returns num raised to the power pow, reduced modulo `mod`, using binary
// (square-and-multiply) exponentiation.
// pow is expected to be non-negative; the loop ends once every bit of pow
// has been shifted out.
ll modularExponent(ll num, ll pow) {
    ll base = num % mod;
    ll result = 1ll;
    for (; pow != 0; pow >>= 1ll) {
        // Multiply the current power of the base in when the lowest bit is set.
        if ((pow & 1ll) != 0) {
            result = (result * base) % mod;
        }
        // Square for the next bit; both factors are already reduced.
        base = (base * base) % mod;
    }
    return result;
}
// Returns v^(mod - 2) modulo `mod`, which is the modular multiplicative
// inverse of v when `mod` is prime and v is not a multiple of it
// (Fermat's little theorem). NOTE(review): primality of mod is assumed here.
ll inv(ll v) {
    const ll fermatExponent = mod - 2ll;
    return modularExponent(v, fermatExponent);
}
// -------- Precomputation of all primes less than the upperbound -------------
void preComputation() {
bitset<10000000> bt;
for (int i = 2; i <= 10000000; i++) {
if (!bt[i]) {
primes.push_back(i);
for (int j = 2 * i; j <= 10000000; j += i) {
bt[j] = 1;
}
}
}
facts[0] = facts[1] = 1;
for (ll i = 2; i <= 10000000; i++) {
facts[i] = facts[i - 1] * i;
facts[i] %= mod;
}
}
// Looks up a! modulo `mod` in the table filled by preComputation().
// No bounds check is performed: a must be a valid index of `facts`.
ll fact(ll a) {
    const ll value = facts[a];
    return value;
}
// ---------------- Partial Dearrangements ----------
// Counts, modulo `mod`, the permutations of (moves + dontCare) elements in
// which none of `moves` designated elements stays at its own position, while
// the remaining `dontCare` elements are unrestricted.
//
// Evaluated by inclusion-exclusion over the designated elements that are
// forced to stay in place:
//     sum_{i = 0..moves} (-1)^i * C(moves, i) * (moves + dontCare - i)!
// This is iterative and needs one modular inverse per term, so it uses no
// memo table and no recursion.
//
// For moves <= 0 the result is simply fact(dontCare).
// Requires preComputation() to have run and moves + dontCare to be a valid
// index of the factorial table.
ll dearrangements(ll moves, ll dontCare) {
    if (moves <= 0) return fact(dontCare);

    const ll total = moves + dontCare;
    ll choose = 1;  // C(moves, i) modulo mod, updated incrementally
    ll ans = 0;
    for (ll i = 0; i <= moves; ++i) {
        const ll term = (choose * fact(total - i)) % mod;
        if (i & 1ll) {
            // Odd i: subtract, adding mod first to stay non-negative.
            ans = (ans - term + mod) % mod;
        } else {
            ans = (ans + term) % mod;
        }
        if (i == moves) break;  // no further binomial coefficient is needed
        // C(moves, i + 1) = C(moves, i) * (moves - i) / (i + 1)
        choose = (choose * ((moves - i) % mod)) % mod;
        choose = (choose * inv(i + 1)) % mod;
    }
    return ans;
}
//------------ Combination ---------------
// Returns the binomial coefficient C(n, r) modulo `mod`, computed as
// n! * (r! * (n - r)!)^(-1) from the precomputed factorial table.
// Returns 0 when r < 0 or r > n (no such selection exists), which also keeps
// the factorial lookups inside the table.
// Requires preComputation() to have run and n to be a valid index of `facts`.
ll nCr(ll n, ll r) {
    if (r < 0 || r > n) return 0;
    // One inverse of the combined denominator instead of two separate ones.
    const ll denominator = (fact(r) * fact(n - r)) % mod;
    return (fact(n) * inv(denominator)) % mod;
}
// --------------- Something important* -----------
void solve() {
ll n, k;
cin >> n >> k;
ll tot_primes = upper_bound(primes.begin(), primes.end(), n) - primes.begin();
ll ans = dearrangements(k, n - tot_primes);
ans *= nCr(tot_primes, k);
ans %= mod;
ans *= inv(fact(n));
ans %= mod;
cout << ans << "\n";
}
// --------------- main, still not main -----------------
int main() {
preComputation();
ll tc = 1;
while (tc--) {
solve();
}
}
//------ The end, now get your a$$ out of here --------------
// Well, if you can fully optimise this code help other folks like me by contributing to this problem.