#include "rsa.h"
#include "funconfig.h"
#include "rand.h"
#include <stdbool.h>

// Greatest common divisor of a and b via the iterative Euclidean algorithm.
// gcd(a, 0) == a, so gcd(0, 0) == 0.
u64 gcd(u64 a, u64 b) {
    while (b != 0) {
        u64 rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}

// Extended Euclidean algorithm (iterative form).
//
// Returns gcd(a, b). When x / y are non-NULL, stores Bezout coefficients
// such that a*x + b*y == gcd(a, b).
//
// The coefficients are mathematically signed, but they are computed and
// stored in the unsigned type u64. A negative coefficient therefore comes
// back wrapped (modular unsigned arithmetic), and the identity above holds
// in wrapping u64 arithmetic. Callers must reinterpret or reduce the
// coefficients accordingly.
u64 extended_euclid(u64 a, u64 b, u64 *x, u64 *y) {
    // (r_prev, r) walk the remainder sequence; (s_prev, s) and (t_prev, t)
    // hold the coefficients of a and b that produce each remainder.
    u64 r_prev = a, r = b;
    u64 s_prev = 1, s = 0;
    u64 t_prev = 0, t = 1;

    while (r != 0) {
        u64 q = r_prev / r;
        u64 next;

        next = r_prev - q * r; // == r_prev % r
        r_prev = r;
        r = next;

        next = s_prev - q * s;
        s_prev = s;
        s = next;

        next = t_prev - q * t;
        t_prev = t;
        t = next;
    }

    if (x != NULL)
        *x = s_prev;
    if (y != NULL)
        *y = t_prev;

    return r_prev;
}

// Euler's totient phi(n): the count of integers in [1, n] coprime to n.
//
// Uses trial division to find each distinct prime factor p and applies
// phi(n) = n * prod(1 - 1/p).
// Returns 0 for n == 0 (phi is undefined there; this avoids looping forever,
// since 0 is divisible by every p).
u64 totient(u64 n) {
    if (n == 0)
        return 0;

    // Keep the running result in u64: an int would truncate large n.
    u64 result = n;

    // `p <= n / p` is the overflow-free form of `p * p <= n`.
    // n shrinks as factors are removed, so the bound tightens as we go.
    for (u64 p = 2; p <= n / p; p++) {
        if (n % p == 0) {
            // p is a prime factor of n: strip every occurrence of it.
            while (n % p == 0)
                n /= p;
            // result is still divisible by p here, so this is exact.
            result -= result / p;
        }
    }

    // Whatever remains above 1 is a single prime factor larger than sqrt.
    if (n > 1)
        result -= result / n;

    return result;
}

// Computes (a * b) mod m without overflowing u64, using double-and-add over
// the bits of b.
//
// Precondition: m != 0 (`a % m` with m == 0 is undefined behavior).
//
// Every intermediate value is kept in [0, m). The modular additions are
// written as "x >= m - y ? x - (m - y) : x + y". That form is valid for
// x, y < m and never exceeds the range of u64, even when m is close to the
// type's maximum. The naive `(x + y) % m` overflows once m exceeds half the
// range.
u64 mulmod(u64 a, u64 b, u64 m) {
    u64 result = 0;
    a %= m;

    while (b > 0) {
        if (b & 1) {
            // result = (result + a) mod m, without overflow
            result = (result >= m - a) ? result - (m - a) : result + a;
        }
        // a = (2 * a) mod m, without overflow
        a = (a >= m - a) ? a - (m - a) : a + a;
        b >>= 1;
    }

    return result;
}

// Computes (a ^ b) mod m by right-to-left binary exponentiation.
// Products are reduced with mulmod.
//
// Precondition: m != 0.
// Result is always in [0, m). In particular, modexp(a, 0, 1) == 0.
u64 modexp(u64 a, u64 b, u64 m) {
    // Start from 1 mod m rather than 1, so m == 1 yields 0 even when b == 0.
    u64 result = 1 % m;
    a %= m;

    while (b > 0) {
        if (b & 1) {
            result = mulmod(result, a, m);
        }
        b >>= 1;
        a = mulmod(a, a, m);
    }

    return result;
}

// Returns a random probable prime drawn from prand_range(min, max).
//
// Candidates are drawn repeatedly until one passes miller_rabin with
// 10 rounds.
// NOTE(review): whether prand_range's bounds are inclusive is defined in
// rand.h — confirm there.
// WARNING: never returns if the range contains no value that passes the
// test (e.g. min > max, or a prime-free interval).
u64 gen_prime(u64 min, u64 max) {
    u64 candidate;

    do {
        candidate = prand_range(min, max);
    } while (!miller_rabin(candidate, 10));

    return candidate;
}

// Deterministic primality test by trial division.
//
// Returns true iff n is prime. Only divisors up to sqrt(n) are tried, and
// only odd ones after 2, so the cost is O(sqrt(n)) divisions.
bool is_prime(u64 n) {
    if (n < 2)
        return false;

    // 2 is the only even prime.
    if (n % 2 == 0)
        return n == 2;

    // u64 index: an int index would overflow (UB) for large n.
    // `i <= n / i` is the overflow-free form of `i * i <= n`.
    for (u64 i = 3; i <= n / i; i += 2) {
        if (n % i == 0)
            return false;
    }

    return true;
}

// Probabilistic Miller-Rabin primality test with k random witnesses.
//
// Returns false if n is certainly composite (or n < 2). Returns true if n is
// prime or survived all k rounds ("probably prime").
// With k == 0, every odd n >= 5 is reported as probably prime.
//
// Small and even n are decided directly. This also keeps the witness range
// [2, n - 2] non-empty: for n == 2 or 3 it would otherwise be inverted.
bool miller_rabin(u64 n, u64 k) {
    if (n < 2)
        return false;
    if (n < 4)
        return true; // 2 and 3
    if (n % 2 == 0)
        return false;

    // Write n - 1 = d * 2^s with d odd.
    u64 d = n - 1;
    u64 s = 0;

    while (d % 2 == 0) {
        d /= 2;
        s++;
    }

    for (u64 i = 0; i < k; i++) {
        u64 a = prand_range(2, n - 2);
        u64 x = modexp(a, d, n);

        if (x == 1 || x == n - 1)
            continue;

        // Square up to s - 1 times looking for -1 (mod n).
        for (u64 r = 1; r < s; r++) {
            x = mulmod(x, x, n);
            if (x == n - 1)
                break;
        }

        if (x != n - 1)
            return false; // a is a witness: n is composite
    }

    return true; // Likely prime
}

// Modular multiplicative inverse: returns x in [1, m) with (a * x) mod m == 1.
//
// Returns 0 when no inverse exists: gcd(a, m) != 1, a % m == 0, or m == 0.
// Also returns 0 for m == 1 (where 0 is the canonical representative).
//
// Runs the extended Euclidean algorithm on (m, a mod m), tracking only the
// coefficient of a. Those coefficients strictly alternate in sign, so the
// code stores their magnitudes in unsigned u64 plus a sign flag; the
// magnitude recurrence is |t_next| = |t_prev| + q * |t|. Every magnitude is
// bounded by m, so nothing overflows. The previous version kept signed
// values in u64 and tested `x < 0`, which is always false, so it returned
// wrapped garbage.
u64 mod_inverse(u64 a, u64 m) {
    if (m < 2)
        return 0;

    u64 r_prev = m, r = a % m;
    u64 t_prev = 0, t = 1;       // magnitudes of a's coefficient
    bool t_prev_negative = true; // t_prev is 0 here, so its sign is irrelevant

    while (r != 0) {
        u64 q = r_prev / r;
        u64 next;

        next = r_prev % r;
        r_prev = r;
        r = next;

        next = t_prev + q * t;
        t_prev = t;
        t = next;

        // Signs alternate, so each shift flips the sign of t_prev.
        t_prev_negative = !t_prev_negative;
    }

    if (r_prev != 1)
        return 0; // gcd(a, m) != 1: a is not invertible modulo m

    // Map a negative coefficient -t_prev to its representative m - t_prev.
    return t_prev_negative ? m - t_prev : t_prev;
}