CP-Algorithms Library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub cp-algorithms/cp-algorithms-aux

:heavy_check_mark: cp-algo/math/cvector.hpp

Depends on

Required by

Verified with

Code

#ifndef CP_ALGO_MATH_CVECTOR_HPP
#define CP_ALGO_MATH_CVECTOR_HPP
#include "../util/complex.hpp"
#include <experimental/simd>
namespace cp_algo::math::fft {
    using ftype = double;
    using point = complex<ftype>;
    using vftype = std::experimental::native_simd<ftype>;
    using vpoint = complex<vftype>;
    static constexpr size_t flen = vftype::size();

    struct cvector {
        static constexpr size_t pre_roots = 1 << 18;
        std::vector<vftype> x, y;
        cvector(size_t n) {
            n = std::max(flen, std::bit_ceil(n));
            x.resize(n / flen);
            y.resize(n / flen);
        }
        template<class pt = point>
        void set(size_t k, pt t) {
            if constexpr(std::is_same_v<pt, point>) {
                x[k / flen][k % flen] = real(t);
                y[k / flen][k % flen] = imag(t);
            } else {
                x[k / flen] = real(t);
                y[k / flen] = imag(t);
            }
        }
        template<class pt = point>
        pt get(size_t k) const {
            if constexpr(std::is_same_v<pt, point>) {
                return {x[k / flen][k % flen], y[k / flen][k % flen]};
            } else {
                return {x[k / flen], y[k / flen]};
            }
        }
        vpoint vget(size_t k) const {
            return get<vpoint>(k);
        }

        size_t size() const {
            return flen * std::size(x);
        }
        void dot(cvector const& t) {
            size_t n = size();
            for(size_t k = 0; k < n; k += flen) {
                set(k, get<vpoint>(k) * t.get<vpoint>(k));
            }
        }
        static const cvector roots;
        template< bool precalc = false, class ft = point>
        static auto root(size_t n, size_t k, ft &&arg) {
            if(n < pre_roots && !precalc) {
                return roots.get<complex<ft>>(n + k);
            } else {
                return complex<ft>::polar(1., arg);
            }
        }
        template<class pt = point, bool precalc = false>
        static void exec_on_roots(size_t n, size_t m, auto &&callback) {
            ftype arg = std::numbers::pi / (ftype)n;
            size_t step = sizeof(pt) / sizeof(point);
            using ft = pt::value_type;
            auto k = [&]() {
                if constexpr(std::is_same_v<pt, point>) {
                    return ft{};
                } else {
                    return ft{[](auto i) {return i;}};
                }
            }();
            for(size_t i = 0; i < m; i += step, k += (ftype)step) {
                callback(i, root<precalc>(n, i, arg * k));
            }
        }

        void ifft() {
            size_t n = size();
            for(size_t i = 1; i < n; i *= 2) {
                for(size_t j = 0; j < n; j += 2 * i) {
                    auto butterfly = [&]<class pt>(size_t k, pt rt) {
                        k += j;
                        auto t = get<pt>(k + i) * conj(rt);
                        set(k + i, get<pt>(k) - t);
                        set(k, get<pt>(k) + t);
                    };
                    if(i < flen) {
                        exec_on_roots<point>(i, i, butterfly);
                    } else {
                        exec_on_roots<vpoint>(i, i, butterfly);
                    }
                }
            }
            for(size_t k = 0; k < n; k += flen) {
                set(k, get<vpoint>(k) /= (ftype)n);
            }
        }
        void fft() {
            size_t n = size();
            for(size_t i = n / 2; i >= 1; i /= 2) {
                for(size_t j = 0; j < n; j += 2 * i) {
                    auto butterfly = [&]<class pt>(size_t k, pt rt) {
                        k += j;
                        auto A = get<pt>(k) + get<pt>(k + i);
                        auto B = get<pt>(k) - get<pt>(k + i);
                        set(k, A);
                        set(k + i, B * rt);
                    };
                    if(i < flen) {
                        exec_on_roots<point>(i, i, butterfly);
                    } else {
                        exec_on_roots<vpoint>(i, i, butterfly);
                    }
                }
            }
        }
    };
    const cvector cvector::roots = []() {
        cvector res(pre_roots);
        for(size_t n = 1; n < res.size(); n *= 2) {
            auto propagate = [&](size_t k, auto rt) {
                res.set(n + k, rt);
            };
            if(n < flen) {
                res.exec_on_roots<point, true>(n, n, propagate);
            } else {
                res.exec_on_roots<vpoint, true>(n, n, propagate);
            }
        }
        return res;
    }();
}
#endif // CP_ALGO_MATH_CVECTOR_HPP
#line 1 "cp-algo/math/cvector.hpp"


#line 1 "cp-algo/util/complex.hpp"


#include <cmath>
namespace cp_algo {
    // Custom implementation, since std::complex is UB on non-floating types
    template<typename T>
    struct complex {
        using value_type = T;
        T x, y;
        constexpr complex() {}
        constexpr complex(T x): x(x), y(0) {}
        constexpr complex(T x, T y): x(x), y(y) {}
        complex& operator *= (T t) {x *= t; y *= t; return *this;}
        complex& operator /= (T t) {x /= t; y /= t; return *this;}
        complex operator * (T t) const {return complex(*this) *= t;}
        complex operator / (T t) const {return complex(*this) /= t;}
        complex& operator += (complex t) {x += t.x; y += t.y; return *this;}
        complex& operator -= (complex t) {x -= t.x; y -= t.y; return *this;}
        complex operator * (complex t) const {return {x * t.x - y * t.y, x * t.y + y * t.x};}
        complex operator / (complex t) const {return *this * t.conj() / t.norm();}
        complex operator + (complex t) const {return complex(*this) += t;}
        complex operator - (complex t) const {return complex(*this) -= t;}
        complex& operator *= (complex t) {return *this = *this * t;}
        complex& operator /= (complex t) {return *this = *this / t;}
        complex operator - () const {return {-x, -y};}
        complex conj() const {return {x, -y};}
        T norm() const {return x * x + y * y;}
        T abs() const {return std::sqrt(norm());}
        T real() const {return x;}
        T imag() const {return y;}
        T& real() {return x;}
        T& imag() {return y;}
        static complex polar(T r, T theta) {return {r * cos(theta), r * sin(theta)};}
        auto operator <=> (complex const& t) const = default;
    };
    template<typename T>
    complex<T> operator * (auto x, complex<T> y) {return y *= x;}
    template<typename T> complex<T> conj(complex<T> x) {return x.conj();}
    template<typename T> T norm(complex<T> x) {return x.norm();}
    template<typename T> T abs(complex<T> x) {return x.abs();}
    template<typename T> T& real(complex<T> &x) {return x.real();}
    template<typename T> T& imag(complex<T> &x) {return x.imag();}
    template<typename T> T real(complex<T> const& x) {return x.real();}
    template<typename T> T imag(complex<T> const& x) {return x.imag();}
    template<typename T> complex<T> polar(T r, T theta) {return complex<T>::polar(r, theta);}
}

#line 4 "cp-algo/math/cvector.hpp"
#include <experimental/simd>
namespace cp_algo::math::fft {
    using ftype = double;
    using point = complex<ftype>;
    using vftype = std::experimental::native_simd<ftype>;
    using vpoint = complex<vftype>;
    static constexpr size_t flen = vftype::size();

    struct cvector {
        static constexpr size_t pre_roots = 1 << 18;
        std::vector<vftype> x, y;
        cvector(size_t n) {
            n = std::max(flen, std::bit_ceil(n));
            x.resize(n / flen);
            y.resize(n / flen);
        }
        template<class pt = point>
        void set(size_t k, pt t) {
            if constexpr(std::is_same_v<pt, point>) {
                x[k / flen][k % flen] = real(t);
                y[k / flen][k % flen] = imag(t);
            } else {
                x[k / flen] = real(t);
                y[k / flen] = imag(t);
            }
        }
        template<class pt = point>
        pt get(size_t k) const {
            if constexpr(std::is_same_v<pt, point>) {
                return {x[k / flen][k % flen], y[k / flen][k % flen]};
            } else {
                return {x[k / flen], y[k / flen]};
            }
        }
        vpoint vget(size_t k) const {
            return get<vpoint>(k);
        }

        size_t size() const {
            return flen * std::size(x);
        }
        void dot(cvector const& t) {
            size_t n = size();
            for(size_t k = 0; k < n; k += flen) {
                set(k, get<vpoint>(k) * t.get<vpoint>(k));
            }
        }
        static const cvector roots;
        template< bool precalc = false, class ft = point>
        static auto root(size_t n, size_t k, ft &&arg) {
            if(n < pre_roots && !precalc) {
                return roots.get<complex<ft>>(n + k);
            } else {
                return complex<ft>::polar(1., arg);
            }
        }
        template<class pt = point, bool precalc = false>
        static void exec_on_roots(size_t n, size_t m, auto &&callback) {
            ftype arg = std::numbers::pi / (ftype)n;
            size_t step = sizeof(pt) / sizeof(point);
            using ft = pt::value_type;
            auto k = [&]() {
                if constexpr(std::is_same_v<pt, point>) {
                    return ft{};
                } else {
                    return ft{[](auto i) {return i;}};
                }
            }();
            for(size_t i = 0; i < m; i += step, k += (ftype)step) {
                callback(i, root<precalc>(n, i, arg * k));
            }
        }

        void ifft() {
            size_t n = size();
            for(size_t i = 1; i < n; i *= 2) {
                for(size_t j = 0; j < n; j += 2 * i) {
                    auto butterfly = [&]<class pt>(size_t k, pt rt) {
                        k += j;
                        auto t = get<pt>(k + i) * conj(rt);
                        set(k + i, get<pt>(k) - t);
                        set(k, get<pt>(k) + t);
                    };
                    if(i < flen) {
                        exec_on_roots<point>(i, i, butterfly);
                    } else {
                        exec_on_roots<vpoint>(i, i, butterfly);
                    }
                }
            }
            for(size_t k = 0; k < n; k += flen) {
                set(k, get<vpoint>(k) /= (ftype)n);
            }
        }
        void fft() {
            size_t n = size();
            for(size_t i = n / 2; i >= 1; i /= 2) {
                for(size_t j = 0; j < n; j += 2 * i) {
                    auto butterfly = [&]<class pt>(size_t k, pt rt) {
                        k += j;
                        auto A = get<pt>(k) + get<pt>(k + i);
                        auto B = get<pt>(k) - get<pt>(k + i);
                        set(k, A);
                        set(k + i, B * rt);
                    };
                    if(i < flen) {
                        exec_on_roots<point>(i, i, butterfly);
                    } else {
                        exec_on_roots<vpoint>(i, i, butterfly);
                    }
                }
            }
        }
    };
    const cvector cvector::roots = []() {
        cvector res(pre_roots);
        for(size_t n = 1; n < res.size(); n *= 2) {
            auto propagate = [&](size_t k, auto rt) {
                res.set(n + k, rt);
            };
            if(n < flen) {
                res.exec_on_roots<point, true>(n, n, propagate);
            } else {
                res.exec_on_roots<vpoint, true>(n, n, propagate);
            }
        }
        return res;
    }();
}

Back to top page