From 2bad1303dc24ff30da64e0ad039a5a806cde51fc Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Fri, 14 Feb 2025 04:07:03 +0100 Subject: [PATCH 1/9] Type aliases for RSA related functionality --- rsa.c | 56 ++++++++++++++++++++++++++------------------------------ rsa.h | 19 ++++++++++--------- 2 files changed, 36 insertions(+), 39 deletions(-) diff --git a/rsa.c b/rsa.c index b547c93..ab9ec95 100644 --- a/rsa.c +++ b/rsa.c @@ -1,15 +1,11 @@ #include "rsa.h" +#include "funconfig.h" #include "rand.h" #include -#include -#define NULL ((void *)0) +u64 gcd(u64 a, u64 b) { return extended_euclid(a, b, NULL, NULL); } -uint64_t gcd(uint64_t a, uint64_t b) { - return extended_euclid(a, b, NULL, NULL); -} - -int extended_euclid(int a, int b, int *x, int *y) { +u64 extended_euclid(u64 a, u64 b, u64 *x, u64 *y) { if (b == 0) { if (x) *x = 1; @@ -18,8 +14,8 @@ int extended_euclid(int a, int b, int *x, int *y) { return a; } - int x1, y1; - int gcd = extended_euclid(b, a % b, &x1, &y1); + u64 x1, y1; + u64 gcd = extended_euclid(b, a % b, &x1, &y1); if (x) *x = y1; @@ -29,7 +25,7 @@ int extended_euclid(int a, int b, int *x, int *y) { return gcd; } -int totient(int n) { +u64 totient(u64 n) { int result = n; // Check for prime factors @@ -51,13 +47,13 @@ int totient(int n) { return result; } -uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) { - uint64_t result = 0; +u64 mulmod(u64 a, u64 b, u64 m) { + u64 result = 0; a %= m; while (b > 0) { if (b & 1) { - result = (result + a) % m; // Avoid overflow + result = (result + a) % m; } a = (a * 2) % m; // Double a, keep within mod b >>= 1; @@ -66,8 +62,8 @@ uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) { return result; } -uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) { - uint64_t result = 1; +u64 modexp(u64 a, u64 b, u64 m) { + u64 result = 1; a %= m; while (b > 0) { @@ -81,14 +77,14 @@ uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) { return result; } -uint64_t gen_prime(uint64_t min, uint64_t max) { - uint64_t cand = 0; +u64 gen_prime(u64 min, u64 max) { + u64 cand = 0; while (!miller_rabin(cand, 10)) cand = prand_range(min, max); return cand; } -bool is_prime(int n) { +bool is_prime(u64 n) { if (n < 2) return false; @@ -100,26 +96,26 @@ bool is_prime(int n) { return true; } -bool miller_rabin(uint64_t n, uint64_t k) { +bool miller_rabin(u64 n, u64 k) { if (n < 2) return false; - uint64_t d = n - 1; - uint64_t s = 0; + u64 d = n - 1; + u64 s = 0; while (d % 2 == 0) { d /= 2; s++; } - for (uint64_t i = 0; i < k; i++) { - uint64_t a = prand_range(2, n - 2); - uint64_t x = modexp(a, d, n); + for (u64 i = 0; i < k; i++) { + u64 a = prand_range(2, n - 2); + u64 x = modexp(a, d, n); if (x == 1 || x == n - 1) continue; - for (uint64_t r = 1; r < s; r++) { + for (u64 r = 1; r < s; r++) { x = modexp(x, 2, n); if (x == n - 1) break; @@ -132,17 +128,17 @@ bool miller_rabin(uint64_t n, uint64_t k) { return true; // Likely prime } -uint64_t mod_inverse(uint64_t a, uint64_t m) { - uint64_t m0 = m; - uint64_t y = 0, x = 1; +u64 mod_inverse(u64 a, u64 m) { + u64 m0 = m; + u64 y = 0, x = 1; if (m == 1) return 0; while (a > 1) { // q is quotient - uint64_t q = a / m; - uint64_t t = m; + u64 q = a / m; + u64 t = m; // m is remainder now m = a % m; diff --git a/rsa.h b/rsa.h index 53bc78d..7a129e6 100644 --- a/rsa.h +++ b/rsa.h @@ -1,4 +1,5 @@ #pragma once +#include "funconfig.h" #include #include @@ -10,7 +11,7 @@ * @param b Second number * @return The greatest common divider */ -uint64_t gcd(uint64_t a, uint64_t b); +u64 gcd(u64 a, u64 b); /** * @brief Computes Euler's Totient function φ(n), which counts the number of @@ -19,7 +20,7 @@ uint64_t gcd(uint64_t a, uint64_t b); * @param n The input number. * @return The number of integers from 1 to n that are coprime to n. */ -int totient(int n); +u64 totient(u64 n); /** * @brief Computes (a * b) % m safely without overflow. @@ -32,7 +33,7 @@ int totient(int n); * @param m The modulus. * @return (a * b) % m computed safely. */ -uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m); +u64 mulmod(u64 a, u64 b, u64 m); /** * @brief Modular exponentiation (a^b) mod m @@ -41,7 +42,7 @@ uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m); * @param b The exponent * @param m The modulus */ -uint64_t modexp(uint64_t a, uint64_t b, uint64_t m); +u64 modexp(u64 a, u64 b, u64 m); /** * @brief Computes the modular inverse of a modulo m. @@ -50,7 +51,7 @@ uint64_t modexp(uint64_t a, uint64_t b, uint64_t m); * @param m The modulus. * @return The modular inverse of a modulo m, or -1 if no inverse exists. */ -uint64_t mod_inverse(uint64_t a, uint64_t m); +u64 mod_inverse(u64 a, u64 m); /** * @brief Generates a random prime number within the given range. @@ -59,7 +60,7 @@ uint64_t mod_inverse(uint64_t a, uint64_t m); * @param max The upper bound (inclusive). * @return A prime number in the range [min, max]. */ -uint64_t gen_prime(uint64_t min, uint64_t max); +u64 gen_prime(u64 min, u64 max); /** * @brief Checks if a number is prime. @@ -67,7 +68,7 @@ uint64_t gen_prime(uint64_t min, uint64_t max); * @param n The number to check. * @return true if n is prime, false otherwise. */ -bool is_prime(int n); +bool is_prime(u64 n); /** * @brief Performs the Miller-Rabin primality test to check if a number is @@ -77,7 +78,7 @@ bool is_prime(int n); * @param k The number of rounds of testing to perform. * @return true if n is probably prime, false if n is composite. */ -bool miller_rabin(uint64_t n, uint64_t k); +bool miller_rabin(u64 n, u64 k); /** * @brief Computes the greatest common divisor (GCD) of two integers a and b @@ -92,4 +93,4 @@ bool miller_rabin(uint64_t n, uint64_t k); * + by = gcd(a, b). * @return The greatest common divisor (gcd) of a and b. */ -int extended_euclid(int a, int b, int *x, int *y); +u64 extended_euclid(u64 a, u64 b, u64 *x, u64 *y); From 120a61eca7109f53fe7220ba5187b98473f25cb3 Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Fri, 14 Feb 2025 04:07:40 +0100 Subject: [PATCH 2/9] Type aliases --- funconfig.h | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/funconfig.h b/funconfig.h index 998cf76..561ab22 100644 --- a/funconfig.h +++ b/funconfig.h @@ -1,7 +1,18 @@ +#include + #ifndef _FUNCONFIG_H #define _FUNCONFIG_H -#define CH32V003 1 +#define CH32V003 1 + +#define NULL ((void *)0) + +typedef int8_t i8; +typedef uint8_t u8; +typedef int16_t i16; +typedef uint32_t u32; +typedef int32_t i32; +typedef int64_t i64; +typedef uint64_t u64; #endif - From 52857830ac317e1030f391f86f88e5c9c5c386b0 Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Fri, 14 Feb 2025 04:13:23 +0100 Subject: [PATCH 3/9] Assert file --- assert.h | 25 +++++++++++++++++++++++++ main.c | 2 ++ 2 files changed, 27 insertions(+) create mode 100644 assert.h diff --git a/assert.h b/assert.h new file mode 100644 index 0000000..f8e9d05 --- /dev/null +++ b/assert.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#define ASSERTold(expr) \ + do { \ + if (!(expr)) { \ + printf("ASSERTION FAILED: %s at %s:%d\n", #expr, __FILE__, \ + __LINE__); \ + while (1); \ + } \ + } while (0) + +#define ASSERT_EQ(expr, expected) \ + do { \ + uint64_t result = (expr); \ + if (result != (expected)) { \ + printf("ASSERTION FAILED: %s at %s:%d\n", #expr, __FILE__, \ + __LINE__); \ + printf("Expected: %lu, Got: %lu\n", (unsigned long)(expected), \ + (unsigned long)result); \ + while (1); \ + } \ + } while (0) diff --git a/main.c b/main.c index 29770df..0ed24e6 100644 --- a/main.c +++ b/main.c @@ -1,8 +1,10 @@ +#include "assert.h" #include #include #include #include #include +#include #define LED_PIN PD6 From 557c06ca22a5c4d585ced55e88ffa0a3589ea43a Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Fri, 14 Feb 2025 04:13:33 +0100 Subject: [PATCH 4/9] Some testing --- main.c | 97 +++++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 82 insertions(+), 15 deletions(-) diff --git a/main.c b/main.c index 0ed24e6..d503654 100644 --- a/main.c +++ b/main.c @@ -7,6 +7,7 @@ #include #define LED_PIN PD6 +#define RANDOM void exit_blink() { for (int i = 0; i < 4; i++) { @@ -26,6 +27,30 @@ void enter_blink() { } } +void test_mulmod() { + ASSERT_EQ(mulmod(3, 2, 4), 2); + ASSERT_EQ((3 * 2) % 4, 2); + + ASSERT_EQ(mulmod(31, 3, 8), 5); + ASSERT_EQ(mulmod((u64)1 << 63, 2, 1000000007ULL), 582344008); +} + +void test_modexp() { + ASSERT_EQ(modexp(3, 2, 4), 1); + ASSERT_EQ((3 ^ 2) % 4, 1); + + ASSERT_EQ(modexp(31, 3, 8), 7); + ASSERT_EQ(modexp((u64)1 << 63, 2, 1000000007ULL), 319908071); +} + +void debug_string(char *str) { + printf("Got string: %s\n", str); + for (int i = 0; i < strlen(str); i++) { + printf("decoded[%d] = '%c' (ASCII: %d)\n", i, str[i], + str[i]); // Print decoded chars and ASCII values + } +} + int main() { SystemInit(); sprand(0); @@ -36,43 +61,79 @@ int main() { enter_blink(); #ifdef RANDOM - uint64_t p = gen_prime(1 << 15, 1 << 16); - uint64_t q = p; +#define W 16 + const int64_t p = gen_prime(1 << (W - 1), 1 << W); - while (p == q) p = gen_prime(1 << 15, 1 << 16); + int64_t qprev = p; + while (p == qprev) qprev = gen_prime(1 << (W - 1), 1 << W); + + const i64 q = qprev; +#undef W #else - uint64_t p = 56857; - uint64_t q = 47963; + int64_t p = 56857; + int64_t q = 47963; #endif - uint64_t n = p * q; - uint64_t phi_n = (p - 1) * (q - 1); + int64_t n = p * q; + int64_t phi_n = (p - 1) * (q - 1); // 'e' is public. E for encrypt. - uint64_t e = 0; - while (gcd(e, phi_n) != 1) e = prand_range(3, phi_n - 1); + int64_t e = 0; + do { + e = prand_range(3, phi_n - 1); + } while (gcd(e, phi_n) != 1); // 'd' is our private key. D as in decrypt - uint64_t d = mod_inverse(e, phi_n); + int64_t d = mod_inverse(e, phi_n); if (d == 0 || d == 1) { printf("Modular inverse not found..."); while (1); } + { + char test = 'o'; + u64 enc = modexp(test, e, n); + char dec = (char)modexp(enc, d, n); + + if (dec != test) { + printf("ERROR: %c != %c => %d != %d\n", test, dec, test, dec); + // while (1); + } + } + { + char test = 'c'; + + u64 p = 3, q = 11; + u64 n = p * q; + u64 e = 7; + u64 d = 3; + u64 enc = modexp(test, e, n); + char dec = (char)modexp(enc, d, n); + + if (dec != test) { + printf("ERROR: %c != %c => %d != %d\n", test, dec, test, dec); + } else + printf("INFO: %c == %c => %d == %d\n", test, dec, test, dec); + } + char msg[] = "Hello"; - uint64_t coded[sizeof(msg)] = {0}; + int64_t coded[sizeof(msg)] = {0}; char decoded[sizeof(msg)] = {0}; // Encode the message - for (int i = 0; i < sizeof(msg); i++) { - coded[i] = (uint64_t)modexp((uint64_t)msg[i], e, n); + for (int i = 0; i < strlen(msg); i++) { + coded[i] = modexp((int64_t)msg[i], e, n); } // Decode the message - for (int i = 0; i < sizeof(msg); i++) { - decoded[i] = (char)modexp(coded[i], d, n); + for (int i = 0; i < strlen(msg); i++) { + int64_t dec = modexp(coded[i], d, n); + decoded[i] = dec & 0xFF; } + test_mulmod(); + test_modexp(); + { printf("P: %u\n", (uint32_t)p); printf("Q: %u\n", (uint32_t)q); @@ -83,6 +144,12 @@ int main() { printf("Message: %s\n", msg); printf("Decoded: %s\n", decoded); + + for (int i = 0; i < strlen(msg); i++) { + printf("coded[%d] = 0x%016lx\n", i, (unsigned long)coded[i]); + } + + debug_string(decoded); } // Exit and hang forever From 8ff7937a889b082eb773c3c6120f555861de851e Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Fri, 14 Feb 2025 04:13:58 +0100 Subject: [PATCH 5/9] Track a specific commit --- Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index ccd6d27..bdf06c6 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ -REV := master +# 'master' or hash +REV := 8ba9981e5 BASE := https://raw.githubusercontent.com/cnlohr/ch32v003fun/$(REV) CURL_FLAGS := -O -\# --fail --location --tlsv1.3 --proto =https --max-time 300 From 70dcda61e81ba07f37678c56572efafb4e630456 Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Fri, 14 Feb 2025 06:09:25 +0100 Subject: [PATCH 6/9] Fix naming --- assert.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assert.h b/assert.h index f8e9d05..e56b318 100644 --- a/assert.h +++ b/assert.h @@ -3,7 +3,7 @@ #include #include -#define ASSERTold(expr) \ +#define ASSERT(expr) \ do { \ if (!(expr)) { \ printf("ASSERTION FAILED: %s at %s:%d\n", #expr, __FILE__, \ From e1ece2358f72e2a1c7caa72d2f19b1c9a3498959 Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Fri, 14 Feb 2025 06:09:54 +0100 Subject: [PATCH 7/9] Better comments in rsa --- rsa.c | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/rsa.c b/rsa.c index ab9ec95..89a3069 100644 --- a/rsa.c +++ b/rsa.c @@ -51,12 +51,13 @@ u64 mulmod(u64 a, u64 b, u64 m) { u64 result = 0; a %= m; + // Perform the multiplication bit by bit (binary multiplication) while (b > 0) { if (b & 1) { result = (result + a) % m; } - a = (a * 2) % m; // Double a, keep within mod - b >>= 1; + a = (a * 2) % m; // Double a, keep it within the modulus + b >>= 1; // Right shift b (divide by 2) } return result; @@ -132,6 +133,7 @@ u64 mod_inverse(u64 a, u64 m) { u64 m0 = m; u64 y = 0, x = 1; + // Modular inverse does not exist when m is 1 if (m == 1) return 0; From 7d20e7f0093ee24d7295594c036e60cceedeeb7b Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Fri, 14 Feb 2025 06:10:06 +0100 Subject: [PATCH 8/9] Define public exponent in header --- rsa.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rsa.h b/rsa.h index 7a129e6..910ae1a 100644 --- a/rsa.h +++ b/rsa.h @@ -3,6 +3,9 @@ #include #include +// Common public exponent, in Fermat prime form +#define PUBEXP ((1 << 16) | 0x1) + /** * @brief Calculates greatest common divider of two integers using the euclidean * algorithm From bfcbb77570ae615e4d5f14a2d8f2af61ef2080ac Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Fri, 14 Feb 2025 06:20:41 +0100 Subject: [PATCH 9/9] Remove junk, add testing and assertions --- main.c | 80 ++++++++++++++++------------------------------------------ 1 file changed, 22 insertions(+), 58 deletions(-) diff --git a/main.c b/main.c index d503654..28a423d 100644 --- a/main.c +++ b/main.c @@ -8,6 +8,7 @@ #define LED_PIN PD6 #define RANDOM +#define W 16 void exit_blink() { for (int i = 0; i < 4; i++) { @@ -60,88 +61,51 @@ int main() { enter_blink(); -#ifdef RANDOM -#define W 16 - const int64_t p = gen_prime(1 << (W - 1), 1 << W); + test_mulmod(); + test_modexp(); - int64_t qprev = p; + const u64 p = gen_prime(1 << (W - 1), 1 << W); + printf("P: %u\n", (u32)p); + + u64 qprev = p; while (p == qprev) qprev = gen_prime(1 << (W - 1), 1 << W); - const i64 q = qprev; -#undef W -#else - int64_t p = 56857; - int64_t q = 47963; -#endif + const u64 q = qprev; + printf("Q: %u\n", (u32)q); - int64_t n = p * q; - int64_t phi_n = (p - 1) * (q - 1); + ASSERT(gcd(p - 1, PUBEXP) == 1); + ASSERT(gcd(q - 1, PUBEXP) == 1); - // 'e' is public. E for encrypt. - int64_t e = 0; - do { - e = prand_range(3, phi_n - 1); - } while (gcd(e, phi_n) != 1); + u64 n = p * q; + printf("N: %u\n", (u32)n); - // 'd' is our private key. D as in decrypt - int64_t d = mod_inverse(e, phi_n); + u64 phi_n = (p - 1) * (q - 1); + printf("Phi_N: %u\n", (u32)phi_n); + + u64 d = mod_inverse(PUBEXP, phi_n); + printf("D: %u\n", (u32)d); if (d == 0 || d == 1) { printf("Modular inverse not found..."); - while (1); } - { - char test = 'o'; - u64 enc = modexp(test, e, n); - char dec = (char)modexp(enc, d, n); - - if (dec != test) { - printf("ERROR: %c != %c => %d != %d\n", test, dec, test, dec); - // while (1); - } - } - { - char test = 'c'; - - u64 p = 3, q = 11; - u64 n = p * q; - u64 e = 7; - u64 d = 3; - u64 enc = modexp(test, e, n); - char dec = (char)modexp(enc, d, n); - - if (dec != test) { - printf("ERROR: %c != %c => %d != %d\n", test, dec, test, dec); - } else - printf("INFO: %c == %c => %d == %d\n", test, dec, test, dec); - } + ASSERT_EQ(mulmod(PUBEXP, d, phi_n), 1); char msg[] = "Hello"; - int64_t coded[sizeof(msg)] = {0}; + u64 coded[sizeof(msg)] = {0}; char decoded[sizeof(msg)] = {0}; // Encode the message for (int i = 0; i < strlen(msg); i++) { - coded[i] = modexp((int64_t)msg[i], e, n); + coded[i] = modexp((u64)msg[i], PUBEXP, n); } // Decode the message for (int i = 0; i < strlen(msg); i++) { - int64_t dec = modexp(coded[i], d, n); + u64 dec = modexp(coded[i], d, n); decoded[i] = dec & 0xFF; } - test_mulmod(); - test_modexp(); - { - printf("P: %u\n", (uint32_t)p); - printf("Q: %u\n", (uint32_t)q); - printf("N: %u\n", (uint32_t)n); - printf("Phi_N: %u\n", (uint32_t)phi_n); - printf("Pubkey (e): %u\n", (uint32_t)e); - printf("Privkey (d): %u\n", (uint32_t)d); - printf("Message: %s\n", msg); printf("Decoded: %s\n", decoded);