diff --git a/include/CppCore.Test/Math/Util.h b/include/CppCore.Test/Math/Util.h index ab258fb3..5374c681 100644 --- a/include/CppCore.Test/Math/Util.h +++ b/include/CppCore.Test/Math/Util.h @@ -894,6 +894,24 @@ namespace CppCore { namespace Test { namespace Math return true; } + template + INLINE static bool upowmod() + { + uint64_t a1[N64]; uint64_t b1[N64]; uint64_t m1[N64]; uint64_t r1[N64]; + uint64_t a2[N64]; uint64_t b2[N64]; uint64_t m2[N64]; uint64_t r2[N64]; + CppCore::Random::Default64 prng; + for (size_t i = 0; i < 100; i++) { + prng.fill(a1); CppCore::clone(a2, a1); + prng.fill(b1); CppCore::clone(b2, b1); + prng.fill(m1); CppCore::clone(m2, m1); + CppCore::upowmod_single(a1, b1, m1, r1); + CppCore::upowmod(a2, b2, m2, r2); + if (!CppCore::equal(r1, r2)) + return false; + } + return true; + } + INLINE static bool upow32() { for (uint32_t base = 0; base < 20; base++) @@ -1553,6 +1571,8 @@ namespace CppCore { namespace Test { namespace VS { namespace Math { TEST_METHOD(UMULMOD64) { Assert::AreEqual(true, CppCore::Test::Math::Util::umulmod64()); } TEST_METHOD(UPOWMOD32) { Assert::AreEqual(true, CppCore::Test::Math::Util::upowmod32()); } TEST_METHOD(UPOWMOD64) { Assert::AreEqual(true, CppCore::Test::Math::Util::upowmod64()); } + TEST_METHOD(UPOWMOD128) { Assert::AreEqual(true, CppCore::Test::Math::Util::upowmod<2>()); } + TEST_METHOD(UPOWMOD256) { Assert::AreEqual(true, CppCore::Test::Math::Util::upowmod<4>()); } TEST_METHOD(UPOW32) { Assert::AreEqual(true, CppCore::Test::Math::Util::upow32()); } TEST_METHOD(UPOW64) { Assert::AreEqual(true, CppCore::Test::Math::Util::upow64()); } TEST_METHOD(UDIVMOD32) { Assert::AreEqual(true, CppCore::Test::Math::Util::udivmod32()); } diff --git a/include/CppCore/Math/Primes.h b/include/CppCore/Math/Primes.h index 14e9c460..98df104c 100644 --- a/include/CppCore/Math/Primes.h +++ b/include/CppCore/Math/Primes.h @@ -122,7 +122,7 @@ namespace CppCore /// SPRP with precalculated t, s and d and work memory m /// template - INLINE static bool sprp(const UINT& n, UINT& a, const UINT& t, const uint32_t& s, const UINT& d, UINT& r, UINT m[3]) + INLINE static bool sprp(const UINT& n, const UINT& a, const UINT& t, const uint32_t& s, const UINT& d, UINT& r, UINT m[3]) { assert(a != 1U); CppCore::upowmod(a, d, n, r, m); diff --git a/include/CppCore/Math/Util.h b/include/CppCore/Math/Util.h index 3c389552..d6ee6b06 100644 --- a/include/CppCore/Math/Util.h +++ b/include/CppCore/Math/Util.h @@ -3822,7 +3822,7 @@ namespace CppCore /// a^b mod m /// template - INLINE static void upowmod(UINT& a, const UINT& b, const UINT& m, UINT& r, UINT t[3]) + INLINE static void upowmod_single(UINT& a, const UINT& b, const UINT& m, UINT& r, UINT t[3]) { assert((&a != &r) && (&b != &r) && (&m != &r)); assert(!CppCore::testzero(m)); @@ -3848,10 +3848,63 @@ namespace CppCore /// a^b mod m /// template - INLINE static void upowmod(UINT& a, const UINT& b, const UINT& m, UINT& r) + INLINE static void upowmod_single(UINT& a, const UINT& b, const UINT& m, UINT& r) { CPPCORE_ALIGN_OPTIM(UINT) t[3]; - CppCore::upowmod(a, b, m, r, t); + CppCore::upowmod_single(a, b, m, r, t); + } + + /// + /// a^b mod m + /// + template + INLINE static void upowmod(const UINT& a, const UINT& b, const UINT& m, UINT& r, UINT t[3]) + { + assert((&a != &r) && (&b != &r) && (&m != &r)); + assert(!CppCore::testzero(m)); + CppCore::clear(r); + constexpr auto NUMBITS = sizeof(UINT)*8U; + const auto LZB = CppCore::lzcnt(b); + if (LZB == NUMBITS) CPPCORE_UNLIKELY { + if (NUMBITS-CppCore::lzcnt(m) != 1U) CPPCORE_LIKELY + *(uint32_t*)&r = 1U; + return; + } + *(uint32_t*)&r = 1U; + + // Precompute powers: base^0, base^1, ..., base^(2^k - 1) + constexpr size_t TABLE_SIZE = 1U << K; + CPPCORE_ALIGN_OPTIM(UINT) powers[TABLE_SIZE]; + CppCore::clear(powers[0]); *(uint32_t*)&powers[0] = 1U; + CppCore::clone(powers[1], a); + for (size_t i = 2; i < TABLE_SIZE; i++) + CppCore::umulmod(powers[i-1], a, m, powers[i], t); + + // Find the position of the highest bit in exp + // Round up to multiple of k + const auto HIDX = NUMBITS - LZB; + const auto HIDX_K = ((HIDX + K - 1U) / K) * K; + + // Process k bits at a time from left to right (MSB to LSB) + for (int pos = HIDX_K - K; pos >= 0; pos -= K) { + + if (pos < HIDX_K - K) + for (size_t i = 0; i < K; i++) + CppCore::umulmod(r, r, m, r, t); + const uint32_t N = MIN(K, NUMBITS-pos); + const uint32_t CHUNK = CppCore::getbits32(b, pos, N); + CppCore::umulmod(r, powers[CHUNK], m, r, t); + } + } + + /// + /// a^b mod m + /// + template + INLINE static void upowmod(const UINT& a, const UINT& b, const UINT& m, UINT& r) + { + CPPCORE_ALIGN_OPTIM(UINT) t[3]; + CppCore::upowmod(a, b, m, r, t); } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/CppCore.Test/Test.cpp b/src/CppCore.Test/Test.cpp index 8ea3e87c..2d3c88e7 100644 --- a/src/CppCore.Test/Test.cpp +++ b/src/CppCore.Test/Test.cpp @@ -378,6 +378,8 @@ int main() TEST(CppCore::Test::Math::Util::umulmod64, "umulmod64: ", std::endl); TEST(CppCore::Test::Math::Util::upowmod32, "upowmod32: ", std::endl); TEST(CppCore::Test::Math::Util::upowmod64, "upowmod64: ", std::endl); + TEST(CppCore::Test::Math::Util::upowmod<2>, "upowmod128: ", std::endl); + TEST(CppCore::Test::Math::Util::upowmod<4>, "upowmod256: ", std::endl); TEST(CppCore::Test::Math::Util::upow32, "upow32: ", std::endl); TEST(CppCore::Test::Math::Util::upow64, "upow64: ", std::endl); TEST(CppCore::Test::Math::Util::udivmod32, "udivmod32: ", std::endl);