#include <ch32fun.h>
#include <rand.h>
#include <rsa.h>
#include <stdint.h>
#include <stdio.h>

#define LED_PIN PD6

// Pulse LED_PIN four times: 50 ms driven high, then 50 ms driven low.
// main() calls this right before parking in its final endless loop.
void exit_blink() {
    // Two phases per pulse: even phases drive the pin high, odd phases low.
    for (int phase = 0; phase < 2 * 4; phase++) {
        int level = (phase % 2 == 0) ? FUN_HIGH : FUN_LOW;
        funDigitalWrite(LED_PIN, level);
        Delay_Ms(50);
    }
}

// Pulse LED_PIN twice: 200 ms driven high, then 200 ms driven low.
// main() calls this once the LED pin has been configured as an output.
void enter_blink() {
    // Two phases per pulse: even phases drive the pin high, odd phases low.
    for (int phase = 0; phase < 2 * 2; phase++) {
        int level = (phase % 2 == 0) ? FUN_HIGH : FUN_LOW;
        funDigitalWrite(LED_PIN, level);
        Delay_Ms(200);
    }
}

// Toy RSA demo: generate a small key pair, encrypt and decrypt "Hello"
// one character at a time, print the key material and the round-trip
// result, then blink the LED and hang.
int main() {
    // Bounds handed to gen_prime() for both prime factors. Provided
    // gen_prime() stays within them, n = p * q fits in 32 bits, so the
    // narrowed printouts at the end lose nothing.
    enum { PRIME_MIN = 1 << 15, PRIME_MAX = 1 << 16 };

    SystemInit();
    // NOTE(review): constant seed; presumably this makes every run
    // generate the same key pair. Confirm against rand.h.
    sprand(0);

    funGpioInitAll();
    funPinMode(LED_PIN, GPIO_Speed_10MHz | GPIO_CNF_OUT_PP);

    enter_blink();

    uint64_t p = gen_prime(PRIME_MIN, PRIME_MAX);
    uint64_t q = gen_prime(PRIME_MIN, PRIME_MAX);

    // The two factors must be distinct; redraw p until they are.
    while (p == q) {
        p = gen_prime(PRIME_MIN, PRIME_MAX);
    }

    uint64_t n = p * q;
    uint64_t phi_n = (p - 1) * (q - 1);

    // 'e' is public. E for encrypt. Redraw until it is coprime with
    // phi_n, otherwise it has no inverse modulo phi_n.
    uint64_t e = prand_range(3, phi_n - 1);
    while (gcd(e, phi_n) != 1) {
        e = prand_range(3, phi_n - 1);
    }

    // 'd' is our private key. D as in decrypt
    uint64_t d = mod_inverse(e, phi_n);
    if (d == 0 || d == 1) {
        // The code treats 0 and 1 as "no usable inverse" and halts here.
        printf("Modular inverse not found...");
        while (1);
    }

    char msg[] = "Hello";
    uint64_t coded[sizeof(msg)] = {0};
    char decoded[sizeof(msg)] = {0};

    // Encode the message. sizeof(msg) counts the terminating NUL, so the
    // terminator goes through the same round trip as the text. Going via
    // unsigned char keeps a byte >= 0x80 from sign-extending into a huge
    // 64-bit value on ABIs where plain char is signed.
    for (size_t i = 0; i < sizeof(msg); i++) {
        coded[i] = (uint64_t)modexp((uint64_t)(unsigned char)msg[i], e, n);
    }

    // Decode the message
    for (size_t i = 0; i < sizeof(msg); i++) {
        decoded[i] = (char)modexp(coded[i], d, n);
    }

    // Report the key material and the round-trip result. %u requires an
    // unsigned int argument; uint32_t can be a different type (for
    // example unsigned long) depending on the toolchain, so cast to
    // unsigned int explicitly. Assumes unsigned int is at least 32 bits
    // wide on the target.
    printf("P: %u\n", (unsigned int)p);
    printf("Q: %u\n", (unsigned int)q);
    printf("N: %u\n", (unsigned int)n);
    printf("Phi_N: %u\n", (unsigned int)phi_n);
    printf("Pubkey (e): %u\n", (unsigned int)e);
    printf("Privkey (d): %u\n", (unsigned int)d);

    printf("Message: %s\n", msg);
    printf("Decoded: %s\n", decoded);

    // Exit and hang forever
    exit_blink();
    while (1);
}