This documentation is automatically generated by competitive-verifier/competitive-verifier
#include "cp-algo/number_theory/discrete_sqrt.hpp"
#ifndef CP_ALGO_NUMBER_THEORY_DISCRETE_SQRT_HPP
#define CP_ALGO_NUMBER_THEORY_DISCRETE_SQRT_HPP
#include "modint.hpp"
#include "../random/rng.hpp"
#include "../math/affine.hpp"
namespace cp_algo::math {
    // Modular square root via the Berlekamp-Rabin randomized algorithm:
    // https://en.wikipedia.org/wiki/Berlekamp-Rabin_algorithm
    //
    // Returns base(0) for b == 0, std::nullopt when b fails Euler's criterion
    // b^((mod-1)/2) == 1, and otherwise some r with r * r == b.
    // NOTE(review): Euler's criterion and the Fermat-based inv() are only
    // meaningful for a prime modulus, so b.mod() is presumably required to be
    // prime; for other moduli the search loop may never terminate.
    //
    // Each attempt draws a random z and computes (x + z)^((mod-1)/2) in
    // base[x] / (x^2 - b) as a*x + c. If +-r are the roots of x^2 - b, then
    // evaluating at them yields the Legendre symbols of z + r and z - r. When
    // those differ, c == 0 and a * r == +-1, so a^{-1} is a root; when they
    // agree, a == 0 and we retry. z == +-r makes one symbol 0, which is why
    // that case is caught separately by the z * z == b check.
    template<modint_type base>
    std::optional<base> sqrt(base b) {
        if(b == base(0)) {
            return base(0);
        }
        auto const half = (b.mod() - 1) / 2;
        if(bpow(b, half) != base(1)) {
            return std::nullopt; // quadratic non-residue
        }
        lin<base> const one(0, 1, b); // the constant 1 modulo x^2 - b
        for(;;) {
            base const z = random::rng();
            if(z * z == b) {
                return z;
            }
            // (x + z)^half modulo x^2 - b
            auto const pw = bpow(lin<base>(1, z, b), half, one);
            if(pw.a != base(0)) {
                return pw.a.inv();
            }
        }
    }
}
#endif // CP_ALGO_NUMBER_THEORY_DISCRETE_SQRT_HPP
#line 1 "cp-algo/number_theory/discrete_sqrt.hpp"
#line 1 "cp-algo/number_theory/modint.hpp"
#line 1 "cp-algo/math/common.hpp"
#include <functional>
#include <cstdint>
namespace cp_algo::math {
#ifdef CP_ALGO_MAXN
    const int maxn = CP_ALGO_MAXN;
#else
    const int maxn = 1 << 19;
#endif
    // Size threshold below which algorithms run their naive versions.
    const int magic = 64;

    // Binary exponentiation by recursive squaring: folds n copies of x with
    // the (assumed associative) operation op, seeded with `one`.
    // n == 0 yields `one`; otherwise, with r = bpow(x, n / 2, one, op), the
    // result is op(r, r), followed by op(., x) when n is odd.
    template<typename T, typename N, typename One, typename Op>
    auto bpow(T const& x, N n, One const& one, Op op) {
        if(n == 0) {
            return one;
        }
        auto acc = bpow(x, n / 2, one, op);
        acc = op(acc, acc);
        if(n % 2) {
            acc = op(acc, x);
        }
        return acc;
    }
    // x^n under operator*, seeded with ans (normally the multiplicative identity).
    template<typename T, typename N, typename Ans>
    auto bpow(T x, N n, Ans ans) {
        return bpow(x, n, ans, std::multiplies{});
    }
    // x^n with T(1) as the identity.
    template<typename T, typename N>
    T bpow(T const& x, N n) {
        return bpow(x, n, T(1));
    }
}
#line 4 "cp-algo/number_theory/modint.hpp"
#include <iostream>
#include <cassert>
namespace cp_algo::math {
// CRTP base for modular integers. The derived type `modint` provides static
// mod() and remod() and may override setr()/getr()/operator*= to keep its own
// internal representation; the members here reach those hooks via to_modint().
// The stored value r is reduced lazily: += and -= only keep it below remod().
template<typename modint, typename IntT>
struct modint_base {
    using Int = IntT;
    using UInt = std::make_unsigned_t<Int>;
    static constexpr size_t bits = sizeof(Int) * 8;
    // Signed/unsigned types twice as wide as Int, for products of two residues.
    using Int2 = std::conditional_t<bits <= 32, int64_t, __int128_t>;
    using UInt2 = std::conditional_t<bits <= 32, uint64_t, __uint128_t>;
    static Int mod() {
        return modint::mod();
    }
    static Int remod() {
        return modint::remod();
    }
    static UInt2 modmod() {
        return UInt2(mod()) * mod();
    }
    modint_base(): r(0) {}
    // Residue of rr modulo mod(), for any rr (negative values included).
    // The signed % is applied first so nothing leaves Int2's range; adding
    // mod()^2 in unsigned arithmetic instead wraps around for rr < -mod()^2
    // and produces a wrong residue.
    modint_base(Int2 rr) {
        Int2 const md = mod();
        Int2 res = rr % md; // in (-md, md): built-in % truncates toward zero
        if(res < 0) {
            res += md;
        }
        to_modint().setr(UInt(res));
    }
    // Computes x^(mod() - 2), which is the inverse by Fermat's little theorem
    // when mod() is prime and the value is nonzero.
    modint inv() const {
        return bpow(to_modint(), mod() - 2);
    }
    modint operator - () const {
        modint neg;
        // For r > 0, -r wraps to a huge value so min picks remod() - r; r == 0 stays 0.
        neg.r = std::min(-r, remod() - r);
        return neg;
    }
    modint& operator /= (const modint &t) {
        return to_modint() *= t.inv();
    }
    modint& operator *= (const modint &t) {
        r = UInt(UInt2(r) * t.r % mod());
        return to_modint();
    }
    modint& operator += (const modint &t) {
        // Subtract remod() once if r >= remod(); otherwise r - remod() wraps and min keeps r.
        r += t.r; r = std::min(r, r - remod());
        return to_modint();
    }
    modint& operator -= (const modint &t) {
        // Add remod() back if the subtraction wrapped below zero.
        r -= t.r; r = std::min(r, r + remod());
        return to_modint();
    }
    modint operator + (const modint &t) const {return modint(to_modint()) += t;}
    modint operator - (const modint &t) const {return modint(to_modint()) -= t;}
    modint operator * (const modint &t) const {return modint(to_modint()) *= t;}
    modint operator / (const modint &t) const {return modint(to_modint()) /= t;}
    // Comparisons order by the residue reported by the derived type's getr().
    // Why <=> doesn't work?..
    auto operator == (const modint &t) const {return to_modint().getr() == t.getr();}
    auto operator != (const modint &t) const {return to_modint().getr() != t.getr();}
    auto operator <= (const modint &t) const {return to_modint().getr() <= t.getr();}
    auto operator >= (const modint &t) const {return to_modint().getr() >= t.getr();}
    auto operator < (const modint &t) const {return to_modint().getr() < t.getr();}
    auto operator > (const modint &t) const {return to_modint().getr() > t.getr();}
    // Symmetric representative: values above mod() / 2 are returned as negatives.
    Int rem() const {
        UInt R = to_modint().getr();
        return R - (R > (UInt)mod() / 2) * mod();
    }
    void setr(UInt rr) {
        r = rr;
    }
    UInt getr() const {
        return r;
    }
    // Only use these if you really know what you're doing!
    static UInt modmod8() {return UInt(8 * modmod());}
    void add_unsafe(UInt t) {r += t;}
    void pseudonormalize() {r = std::min(r, r - modmod8());}
    modint const& normalize() {
        if(r >= (UInt)mod()) {
            r %= mod();
        }
        return to_modint();
    }
    void setr_direct(UInt rr) {r = rr;}
    UInt getr_direct() const {return r;}
protected:
    UInt r;
private:
    modint& to_modint() {return static_cast<modint&>(*this);}
    modint const& to_modint() const {return static_cast<modint const&>(*this);}
};
template<typename modint>
concept modint_type = std::is_base_of_v<modint_base<modint, typename modint::Int>, modint>;
template<modint_type modint>
decltype(std::cin)& operator >> (decltype(std::cin) &in, modint &x) {
typename modint::UInt r;
auto &res = in >> r;
x.setr(r);
return res;
}
template<modint_type modint>
decltype(std::cout)& operator << (decltype(std::cout) &out, modint const& x) {
return out << x.getr();
}
// Modular integer with the compile-time modulus m; decltype(m) is the
// underlying integer type. The residue is stored directly (no Montgomery
// form), and remod() == m is the bound used by the base's += / -=.
template<auto m>
struct modint: modint_base<modint<m>, decltype(m)> {
    using Base = modint_base<modint<m>, decltype(m)>;
    using Base::Base;
    static constexpr typename Base::Int mod() { return m; }
    static constexpr typename Base::UInt remod() { return m; }
    typename Base::UInt getr() const { return Base::r; }
};
// Inverse of an odd integer x modulo 2^k, where k is the bit width of x's
// type, returned in the matching unsigned type. Newton/Hensel lifting: if
// y * x == 1 (mod 2^j) then y * (2 - x * y) * x == 1 (mod 2^(2j)), so the
// number of correct low bits doubles each step (for types at least as wide
// as int, the products wrap modulo 2^k).
template<typename T>
inline constexpr auto inv2(T x) {
    assert(x % 2); // even numbers have no inverse modulo a power of two
    using U = std::make_unsigned_t<T>;
    U inv = 1; // already correct modulo 2 because x is odd
    while(inv * x != 1) {
        inv *= 2 - x * inv;
    }
    return inv;
}
// Modular integer whose modulus is chosen at run time via switch_mod() or
// with_mod(). The modulus and its precomputed constants live in thread_local
// statics, so each thread has its own current modulus (initially 1).
// For an odd modulus the stored Base::r is in Montgomery form
// (value * 2^Base::bits mod mod()); for an even modulus it is stored plainly.
// NOTE(review): the base class's += / -= keep r below remod() == 2 * mod(),
// and m_reduce needs ab + M * mod() to fit in UInt2; this presumably limits
// mod() to below 2^(Base::bits - 2) — confirm before using larger moduli.
template<typename Int = int64_t>
struct dynamic_modint: modint_base<dynamic_modint<Int>, Int> {
using Base = modint_base<dynamic_modint<Int>, Int>;
using Base::Base;
// Reduces a double-width value.
// Even mod(): plain ab % mod().
// Odd mod(): Montgomery reduction with R = 2^Base::bits. Since imod() is
// -mod()^{-1} modulo R, ab + m * mod() is divisible by R, and the shifted
// result is congruent to ab * R^{-1} modulo mod() (not necessarily < mod()).
static Base::UInt m_reduce(Base::UInt2 ab) {
if(mod() % 2 == 0) [[unlikely]] {
return typename Base::UInt(ab % mod());
} else {
// Local m (shadows the static modulus m): product taken modulo R in UInt, then widened.
typename Base::UInt2 m = typename Base::UInt(ab) * imod();
return typename Base::UInt((ab + m * mod()) >> Base::bits);
}
}
// Converts a plain value into the stored representation: for odd mod(),
// m_reduce(a * (R^2 mod mod())) is congruent to a * R; for even mod(), a is kept as is.
static Base::UInt m_transform(Base::UInt a) {
if(mod() % 2 == 0) [[unlikely]] {
return a;
} else {
return m_reduce(a * pw128());
}
}
// Montgomery multiplication: reducing the product of two Montgomery-form
// values once yields the Montgomery form of the product.
dynamic_modint& operator *= (const dynamic_modint &t) {
Base::r = m_reduce(typename Base::UInt2(Base::r) * t.r);
return *this;
}
// Stores rr (a plain value) in the internal representation.
void setr(Base::UInt rr) {
Base::r = m_transform(rr);
}
// Returns the plain residue: leaves Montgomery form, then subtracts mod()
// once if the result is >= mod() (otherwise res - mod() wraps to a huge value).
Base::UInt getr() const {
typename Base::UInt res = m_reduce(Base::r);
return std::min(res, res - mod());
}
static Int mod() {return m;}
// Bound used by the base class's += / -= to lazily reduce the stored value.
static Int remod() {return 2 * m;}
// -mod()^{-1} modulo 2^Base::bits (0 for an even modulus).
static Base::UInt imod() {return im;}
// A value in [1, mod()] congruent to 2^(2 * Base::bits) modulo mod(); the
// name presumably refers to the 64-bit case, where this is 2^128.
static Base::UInt2 pw128() {return r2;}
// Sets the current thread's modulus and recomputes imod() and pw128().
static void switch_mod(Int nm) {
m = nm;
im = m % 2 ? inv2(-m) : 0;
// (2^(2 * bits) - 1) % m + 1, i.e. 2^(2 * bits) mod m without overflowing UInt2.
r2 = static_cast<Base::UInt>(static_cast<Base::UInt2>(-1) % m + 1);
}
// Wrapper for temp switching: runs callback() with the current thread's
// modulus set to tmp and returns its result. The previous modulus is restored
// by the destructor of the local guard `_`, so also when callback() throws.
// NOTE(review): values created under tmp are presumably not meaningful once
// the previous modulus is restored.
auto static with_mod(Int tmp, auto callback) {
struct scoped {
Int prev = mod();
~scoped() {switch_mod(prev);}
} _;
switch_mod(tmp);
return callback();
}
private:
// Per-thread modulus state; see switch_mod().
static thread_local Int m;
static thread_local Base::UInt im, r2;
};
// Initial state corresponds to modulus 1: im = -1 is the inverse of -1
// modulo 2^bits, and r2 = 0 is congruent to 2^(2 * bits) modulo 1.
template<typename Int>
Int thread_local dynamic_modint<Int>::m = 1;
template<typename Int>
dynamic_modint<Int>::Base::UInt thread_local dynamic_modint<Int>::im = -1;
template<typename Int>
dynamic_modint<Int>::Base::UInt thread_local dynamic_modint<Int>::r2 = 0;
}
#line 1 "cp-algo/random/rng.hpp"
#include <chrono>
#include <random>
namespace cp_algo::random {
    // Returns the next value of a program-wide 64-bit Mersenne Twister that is
    // seeded once, on first use, from std::chrono::steady_clock.
    // `inline` because the function is defined in a header: without it every
    // translation unit including the header emits its own external definition,
    // violating the one-definition rule (multiple-definition link error).
    // Not synchronized: concurrent calls from several threads race on the
    // shared generator state.
    inline uint64_t rng() {
        static std::mt19937_64 gen(
            std::chrono::steady_clock::now().time_since_epoch().count()
        );
        return gen();
    }
}
#line 1 "cp-algo/math/affine.hpp"
#include <optional>
#include <utility>
#line 6 "cp-algo/math/affine.hpp"
#include <tuple>
namespace cp_algo::math {
    // a * x + b
    // When c is engaged, the same pair also denotes the residue class
    // a * x + b in base[x] / (x^2 - c), which operator* multiplies.
    template<typename base>
    struct lin {
        base a = 1, b = 0;
        std::optional<base> c;
        lin() {} // identity map x
        lin(base b): a(0), b(b) {} // constant b
        lin(base a, base b): a(a), b(b) {}
        lin(base a, base b, base _c): a(a), b(b), c(_c) {}
        // polynomial product modulo x^2 - c; both operands must carry the same c.
        // (a x + b)(a' x + b') = a a' x^2 + (a b' + b a') x + b b', with x^2 = c.
        // Declared const so that const lin values can be multiplied too.
        lin operator * (const lin& t) const {
            assert(c && t.c && *c == *t.c);
            return {a * t.b + b * t.a, b * t.b + a * t.a * (*c), *c};
        }
        // a * (t.a * x + t.b) + b, i.e. *this applied after t; the result has no c.
        lin apply(lin const& t) const {
            return {a * t.a, a * t.b + b};
        }
        // Replaces *this with t applied after *this (this also clears c).
        void prepend(lin const& t) {
            *this = t.apply(*this);
        }
        base eval(base x) const {
            return a * x + b;
        }
    };
    // (ax+b) / (cx+d), represented by the matrix [[a, b], [c, d]]
    template<typename base>
    struct linfrac {
        base a, b, c, d;
        linfrac(): a(1), b(0), c(0), d(1) {} // x, identity for composition
        linfrac(base a): a(a), b(1), c(1), d(0) {} // a + 1/x, for continued fractions
        linfrac(base a, base b, base c, base d): a(a), b(b), c(c), d(d) {}
        // composition of two linfracs: *this applied after t
        // (the matrix product of *this and t)
        linfrac operator * (linfrac t) const {
            return t.prepend(linfrac(*this));
        }
        // negates every coefficient (same map, negated matrix)
        linfrac operator-() const {
            return {-a, -b, -c, -d};
        }
        // adjugate matrix: *this * adj() == (ad - bc) * identity, i.e. the inverse map
        linfrac adj() const {
            return {d, -b, -c, a};
        }
        // replaces *this with t applied after *this (left-multiplies the matrix by t)
        linfrac& prepend(linfrac const& t) {
            t.apply(a, c);
            t.apply(b, d);
            return *this;
        }
        // apply linfrac to A/B: (A, B) <- (a A + b B, c A + d B)
        void apply(base &A, base &B) const {
            std::tie(A, B) = std::pair{a * A + b * B, c * A + d * B};
        }
    };
}
#line 6 "cp-algo/number_theory/discrete_sqrt.hpp"
namespace cp_algo::math {
    // Modular square root via the Berlekamp-Rabin randomized algorithm:
    // https://en.wikipedia.org/wiki/Berlekamp-Rabin_algorithm
    //
    // Returns base(0) for b == 0, std::nullopt when b fails Euler's criterion
    // b^((mod-1)/2) == 1, and otherwise some r with r * r == b.
    // NOTE(review): Euler's criterion and the Fermat-based inv() are only
    // meaningful for a prime modulus, so b.mod() is presumably required to be
    // prime; for other moduli the search loop may never terminate.
    //
    // Each attempt draws a random z and computes (x + z)^((mod-1)/2) in
    // base[x] / (x^2 - b) as a*x + c. If +-r are the roots of x^2 - b, then
    // evaluating at them yields the Legendre symbols of z + r and z - r. When
    // those differ, c == 0 and a * r == +-1, so a^{-1} is a root; when they
    // agree, a == 0 and we retry. z == +-r makes one symbol 0, which is why
    // that case is caught separately by the z * z == b check.
    template<modint_type base>
    std::optional<base> sqrt(base b) {
        if(b == base(0)) {
            return base(0);
        }
        auto const half = (b.mod() - 1) / 2;
        if(bpow(b, half) != base(1)) {
            return std::nullopt; // quadratic non-residue
        }
        lin<base> const one(0, 1, b); // the constant 1 modulo x^2 - b
        for(;;) {
            base const z = random::rng();
            if(z * z == b) {
                return z;
            }
            // (x + z)^half modulo x^2 - b
            auto const pw = bpow(lin<base>(1, z, b), half, one);
            if(pw.a != base(0)) {
                return pw.a.inv();
            }
        }
    }
}