diff --git a/main.c b/main.c index 0ed24e6..d503654 100644 --- a/main.c +++ b/main.c @@ -7,6 +7,7 @@ #include #define LED_PIN PD6 +#define RANDOM void exit_blink() { for (int i = 0; i < 4; i++) { @@ -26,6 +27,30 @@ void enter_blink() { } } +void test_mulmod() { + ASSERT_EQ(mulmod(3, 2, 4), 2); + ASSERT_EQ((3 * 2) % 4, 2); + + ASSERT_EQ(mulmod(31, 3, 8), 5); + ASSERT_EQ(mulmod((u64)1 << 63, 2, 1000000007ULL), 582344008); +} + +void test_modexp() { + ASSERT_EQ(modexp(3, 2, 4), 1); + ASSERT_EQ((3 ^ 2) % 4, 1); + + ASSERT_EQ(modexp(31, 3, 8), 7); + ASSERT_EQ(modexp((u64)1 << 63, 2, 1000000007ULL), 319908071); +} + +void debug_string(char *str) { + printf("Got string: %s\n", str); + for (int i = 0; i < strlen(str); i++) { + printf("decoded[%d] = '%c' (ASCII: %d)\n", i, str[i], + str[i]); // Print decoded chars and ASCII values + } +} + int main() { SystemInit(); sprand(0); @@ -36,43 +61,79 @@ int main() { enter_blink(); #ifdef RANDOM - uint64_t p = gen_prime(1 << 15, 1 << 16); - uint64_t q = p; +#define W 16 + const int64_t p = gen_prime(1 << (W - 1), 1 << W); - while (p == q) p = gen_prime(1 << 15, 1 << 16); + int64_t qprev = p; + while (p == qprev) qprev = gen_prime(1 << (W - 1), 1 << W); + + const i64 q = qprev; +#undef W #else - uint64_t p = 56857; - uint64_t q = 47963; + int64_t p = 56857; + int64_t q = 47963; #endif - uint64_t n = p * q; - uint64_t phi_n = (p - 1) * (q - 1); + int64_t n = p * q; + int64_t phi_n = (p - 1) * (q - 1); // 'e' is public. E for encrypt. - uint64_t e = 0; - while (gcd(e, phi_n) != 1) e = prand_range(3, phi_n - 1); + int64_t e = 0; + do { + e = prand_range(3, phi_n - 1); + } while (gcd(e, phi_n) != 1); // 'd' is our private key. D as in decrypt - uint64_t d = mod_inverse(e, phi_n); + int64_t d = mod_inverse(e, phi_n); if (d == 0 || d == 1) { printf("Modular inverse not found..."); while (1); } + { + char test = 'o'; + u64 enc = modexp(test, e, n); + char dec = (char)modexp(enc, d, n); + + if (dec != test) { + printf("ERROR: %c != %c => %d != %d\n", test, dec, test, dec); + // while (1); + } + } + { + char test = 'c'; + + u64 p = 3, q = 11; + u64 n = p * q; + u64 e = 7; + u64 d = 3; + u64 enc = modexp(test, e, n); + char dec = (char)modexp(enc, d, n); + + if (dec != test) { + printf("ERROR: %c != %c => %d != %d\n", test, dec, test, dec); + } else + printf("INFO: %c == %c => %d == %d\n", test, dec, test, dec); + } + char msg[] = "Hello"; - uint64_t coded[sizeof(msg)] = {0}; + int64_t coded[sizeof(msg)] = {0}; char decoded[sizeof(msg)] = {0}; // Encode the message - for (int i = 0; i < sizeof(msg); i++) { - coded[i] = (uint64_t)modexp((uint64_t)msg[i], e, n); + for (int i = 0; i < strlen(msg); i++) { + coded[i] = modexp((int64_t)msg[i], e, n); } // Decode the message - for (int i = 0; i < sizeof(msg); i++) { - decoded[i] = (char)modexp(coded[i], d, n); + for (int i = 0; i < strlen(msg); i++) { + int64_t dec = modexp(coded[i], d, n); + decoded[i] = dec & 0xFF; } + test_mulmod(); + test_modexp(); + { printf("P: %u\n", (uint32_t)p); printf("Q: %u\n", (uint32_t)q); @@ -83,6 +144,12 @@ int main() { printf("Message: %s\n", msg); printf("Decoded: %s\n", decoded); + + for (int i = 0; i < strlen(msg); i++) { + printf("coded[%d] = 0x%016lx\n", i, (unsigned long)coded[i]); + } + + debug_string(decoded); } // Exit and hang forever