/*
 * Number-theory helpers for the RSA code: gcd / extended Euclid, Euler's
 * totient, overflow-safe modular multiplication and exponentiation,
 * primality testing, random prime generation and modular inverse.
 */
#include "rsa.h"
#include "rand.h"

#include <stdbool.h>
#include <stdint.h>

/*
 * Greatest common divisor of a and b (iterative Euclid).
 * Works on the full uint64_t range; gcd(a, 0) == a and gcd(0, 0) == 0.
 */
uint64_t gcd(uint64_t a, uint64_t b) {
    /* Computed directly instead of through extended_euclid(), whose int
     * parameters would silently truncate 64-bit arguments. */
    while (b != 0) {
        uint64_t r = a % b;
        a = b;
        b = r;
    }
    return a;
}

/*
 * Extended Euclidean algorithm on ints.
 *
 * Returns g = gcd(a, b). When x / y are non-NULL it also stores Bezout
 * coefficients satisfying a * (*x) + b * (*y) == g. Either output pointer
 * may be NULL.
 */
int extended_euclid(int a, int b, int *x, int *y) {
    if (b == 0) {
        if (x) {
            *x = 1;
        }
        if (y) {
            *y = 0;
        }
        return a;
    }
    int x1, y1;
    int g = extended_euclid(b, a % b, &x1, &y1);
    /* Back-substitute: b*x1 + (a % b)*y1 == g and a % b == a - (a / b) * b. */
    if (x) {
        *x = y1;
    }
    if (y) {
        *y = x1 - (a / b) * y1;
    }
    return g;
}

/*
 * Euler's totient phi(n): how many integers in [1, n] are coprime to n.
 * Factors n by trial division. Values n <= 0 are returned unchanged.
 */
int totient(int n) {
    int result = n;
    /* p <= n / p is p * p <= n without the signed overflow p * p hits
     * once p exceeds sqrt(INT_MAX). n shrinks as factors are removed. */
    for (int p = 2; p <= n / p; p++) {
        if (n % p == 0) {
            /* p is a prime factor: strip every occurrence of it. */
            while (n % p == 0) {
                n /= p;
            }
            result -= result / p;
        }
    }
    /* Whatever remains above 1 is a single prime factor larger than sqrt. */
    if (n > 1) {
        result -= result / n;
    }
    return result;
}

/* (x + y) mod m for x, y already in [0, m), avoiding the 64-bit overflow
 * that the plain sum x + y can hit when m > 2^63. */
static uint64_t addmod_u64(uint64_t x, uint64_t y, uint64_t m) {
    return (x >= m - y) ? x - (m - y) : x + y;
}

/*
 * (a * b) mod m using double-and-add so no intermediate exceeds 64 bits.
 * Correct for every modulus m in [1, 2^64); m must be nonzero.
 */
uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) {
    uint64_t result = 0;
    a %= m;
    while (b > 0) {
        if (b & 1) {
            result = addmod_u64(result, a, m);
        }
        a = addmod_u64(a, a, m); /* a = 2a mod m */
        b >>= 1;
    }
    return result;
}

/*
 * a^b mod m by square-and-multiply, built on mulmod() so it is safe for
 * any 64-bit modulus. m must be nonzero.
 */
uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) {
    /* 1 % m (not 1) so modexp(x, 0, 1) yields 0, the only residue mod 1. */
    uint64_t result = 1 % m;
    a %= m;
    while (b > 0) {
        if (b & 1) {
            result = mulmod(result, a, m);
        }
        a = mulmod(a, a, m);
        b >>= 1;
    }
    return result;
}

/* Deterministic primality test by trial division up to sqrt(n). */
bool is_prime(int n) {
    if (n < 2) {
        return false;
    }
    /* i <= n / i stops at sqrt(n) without overflowing i * i. */
    for (int i = 2; i <= n / i; i++) {
        if (n % i == 0) {
            return false;
        }
    }
    return true;
}

/*
 * Probabilistic Miller-Rabin test using k bases drawn from prand_range().
 * Returns false when n is definitely composite, true when n is prime or a
 * strong pseudoprime to every base tried. Inputs below 5 and even inputs
 * are answered without drawing any random base.
 */
bool miller_rabin(uint64_t n, uint64_t k) {
    if (n < 4) {
        return n == 2 || n == 3;
    }
    if (n % 2 == 0) {
        return false;
    }

    /* Write n - 1 = d * 2^s with d odd. */
    uint64_t d = n - 1;
    uint64_t s = 0;
    while (d % 2 == 0) {
        d /= 2;
        s++;
    }

    for (uint64_t i = 0; i < k; i++) {
        /* n >= 5 here, so the requested range has min (2) <= max (n - 2).
         * NOTE(review): assumes prand_range(min, max) is inclusive — confirm. */
        uint64_t a = prand_range(2, n - 2);
        uint64_t x = modexp(a, d, n);
        if (x == 1 || x == n - 1) {
            continue;
        }
        for (uint64_t r = 1; r < s; r++) {
            x = mulmod(x, x, n);
            if (x == n - 1) {
                break;
            }
        }
        if (x != n - 1) {
            return false; /* a witnesses that n is composite */
        }
    }
    return true; /* likely prime */
}

/*
 * Repeatedly draw prand_range(min, max) until a candidate passes
 * Miller-Rabin, and return it.
 * NOTE(review): never returns if the range contains no prime.
 */
uint64_t gen_prime(uint64_t min, uint64_t max) {
    enum { GEN_PRIME_MR_ROUNDS = 10 };
    uint64_t cand;
    do {
        cand = prand_range(min, max);
    } while (!miller_rabin(cand, GEN_PRIME_MR_ROUNDS));
    return cand;
}

/*
 * Multiplicative inverse of a modulo m: the x in [0, m) with
 * (a * x) mod m == 1.
 *
 * Returns 0 when no inverse exists (gcd(a, m) != 1, or m == 0). 0 is never
 * a valid inverse for m > 1, so the sentinel is unambiguous; for m == 1 the
 * result 0 is also the correct (only) residue.
 */
uint64_t mod_inverse(uint64_t a, uint64_t m) {
    if (m == 0) {
        return 0;
    }

    /* Extended Euclid on (m, a mod m), keeping each coefficient reduced
     * mod m so it stays non-negative in unsigned arithmetic.
     * Invariant: t0 * a == r0 and t1 * a == r1 (mod m). */
    uint64_t r0 = m, r1 = a % m;
    uint64_t t0 = 0, t1 = 1 % m;
    while (r1 != 0) {
        uint64_t q = r0 / r1;
        uint64_t r2 = r0 - q * r1; /* == r0 % r1; q * r1 <= r0, no overflow */
        uint64_t qt = mulmod(q, t1, m);
        /* (t0 - q * t1) mod m without going negative. */
        uint64_t t2 = (t0 >= qt) ? t0 - qt : t0 + (m - qt);
        r0 = r1;
        r1 = r2;
        t0 = t1;
        t1 = t2;
    }
    /* r0 is now gcd(a, m); an inverse exists only when it is 1. */
    return (r0 == 1) ? t0 : 0;
}