diff --git a/rsa.c b/rsa.c index 1320f85..b547c93 100644 --- a/rsa.c +++ b/rsa.c @@ -3,28 +3,28 @@ #include #include +#define NULL ((void *)0) + uint64_t gcd(uint64_t a, uint64_t b) { - while (b != 0) { - uint64_t temp = b; - b = a % b; - a = temp; - } - return a; + return extended_euclid(a, b, NULL, NULL); } int extended_euclid(int a, int b, int *x, int *y) { if (b == 0) { - *x = 1; - *y = 0; + if (x) + *x = 1; + if (y) + *y = 0; return a; } int x1, y1; int gcd = extended_euclid(b, a % b, &x1, &y1); - // Update x and y using results from recursive call - *x = y1; - *y = x1 - (a / b) * y1; + if (x) + *x = y1; + if (y) + *y = x1 - (a / b) * y1; return gcd; } @@ -51,18 +51,31 @@ int totient(int n) { return result; } -uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) { - uint64_t result = 1; - a = a % m; // In case a is greater than m +uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) { + uint64_t result = 0; + a %= m; while (b > 0) { - // If b is odd, multiply a with result - if (b % 2 == 1) - result = (result * a) % m; + if (b & 1) { + result = (result + a) % m; // Avoid overflow + } + a = (a * 2) % m; // Double a, keep within mod + b >>= 1; + } - // b must be even now - b = b >> 1; // b = b // 2 - a = (a * a) % m; // Change a to a^2 + return result; +} + +uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) { + uint64_t result = 1; + a %= m; + + while (b > 0) { + if (b & 1) { + result = mulmod(result, a, m); + } + b >>= 1; + a = mulmod(a, a, m); } return result; @@ -70,7 +83,7 @@ uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) { uint64_t gen_prime(uint64_t min, uint64_t max) { uint64_t cand = 0; - while (!miller_rabin(cand, 5)) cand = prand_range(min, max); + while (!miller_rabin(cand, 10)) cand = prand_range(min, max); return cand; } @@ -119,17 +132,17 @@ bool miller_rabin(uint64_t n, uint64_t k) { return true; // Likely prime } -int mod_inverse(int a, int m) { - int m0 = m; - int y = 0, x = 1; +uint64_t mod_inverse(uint64_t a, uint64_t m) { + uint64_t m0 = m; + uint64_t y = 0, x = 1; if (m == 1) return 0; while (a > 1) { // q is quotient - int q = a / m; - int t = m; + uint64_t q = a / m; + uint64_t t = m; // m is remainder now m = a % m; diff --git a/rsa.h b/rsa.h index e56004d..53bc78d 100644 --- a/rsa.h +++ b/rsa.h @@ -21,6 +21,19 @@ uint64_t gcd(uint64_t a, uint64_t b); */ int totient(int n); +/** + * @brief Computes (a * b) % m safely without overflow. + * + * Uses repeated addition and bit shifting to handle large values, + * ensuring correctness even on 32-bit microcontrollers. + * + * @param a The first operand. + * @param b The second operand. + * @param m The modulus. + * @return (a * b) % m computed safely. + */ +uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m); + /** * @brief Modular exponentiation (a^b) mod m * @@ -37,7 +50,7 @@ uint64_t modexp(uint64_t a, uint64_t b, uint64_t m); * @param m The modulus. * @return The modular inverse of a modulo m, or -1 if no inverse exists. */ -int mod_inverse(int a, int m); +uint64_t mod_inverse(uint64_t a, uint64_t m); /** * @brief Generates a random prime number within the given range.