From 50e640ea8435848ebc4777b72c4c9d3b2d2957ae Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Thu, 13 Feb 2025 00:35:28 +0100 Subject: [PATCH 1/2] Seed saving to flash with regular intervals --- rand.c | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++++------ rand.h | 25 ++++++++++++++++----- 2 files changed, 84 insertions(+), 12 deletions(-) diff --git a/rand.c b/rand.c index 74cb7ad..062a64a 100644 --- a/rand.c +++ b/rand.c @@ -1,22 +1,79 @@ +#include "rand.h" +#include "ch32fun.h" #include -#include +#define BUILD_SEED \ + ((uint64_t)(__TIME__[0]) * (uint64_t)(__TIME__[1]) * \ + (uint64_t)(__TIME__[3]) * (uint64_t)(__TIME__[4]) * \ + (uint64_t)(__TIME__[6]) * (uint64_t)(__TIME__[7])) -#define BUILD_SEED ((uint64_t)(__TIME__[0]) * (uint64_t)(__TIME__[1]) * \ - (uint64_t)(__TIME__[3]) * (uint64_t)(__TIME__[4]) * \ - (uint64_t)(__TIME__[6]) * (uint64_t)(__TIME__[7])) +#define FLASH_SEED_ADDR ((uintptr_t *)0x08003700) // PRNG state storage +#define PRNG_SAVE_INTERVAL 50 // Save every 1000 calls to prand() static uint64_t seed = BUILD_SEED; -void sprand(uint64_t s) { - seed = s ? s : 1; // Ensure the seed is never 0 -} +// Initialize this to something close to interval +static int prand_counter = PRNG_SAVE_INTERVAL - 10; uint64_t prand() { seed = seed * 6364136223846793005ULL + 1; + + if (++prand_counter >= PRNG_SAVE_INTERVAL) { + rand_save_to_flash(); + prand_counter = 0; + } + return seed; } uint64_t prand_range(uint64_t min, uint64_t max) { return min + (prand() % (max - min + 1)); } + +void sprand(uint64_t s) { + if (s) { + seed = s; + } else { + rand_reseed(); + } +} + +void rand_reseed() { + uint64_t stored_seed = *(volatile uint64_t *)FLASH_SEED_ADDR; + + if (stored_seed == 0 || stored_seed == 0xFFFFFFFFFFFFFFFFULL) { + seed = BUILD_SEED; + } else { + seed = stored_seed; + } +} + +// See: +// https://github.com/cnlohr/ch32v003fun/blob/2ac62072272f2ccd2122e688a9e0566de3976a94/examples/flashtest/flashtest.c +void rand_save_to_flash() { + FLASH->KEYR = 0x45670123; // Unlock flash + FLASH->KEYR = 0xCDEF89AB; + + FLASH->MODEKEYR = 0x45670123; // Unlock programming mode + FLASH->MODEKEYR = 0xCDEF89AB; + + // Erase the flash page + FLASH->CTLR = CR_PAGE_ER; + FLASH->ADDR = (intptr_t)FLASH_SEED_ADDR; + FLASH->CTLR = CR_STRT_Set | CR_PAGE_ER; + while (FLASH->STATR & FLASH_STATR_BSY); // Wait for erase + + // Write new seed + FLASH->CTLR = CR_PAGE_PG; + FLASH->CTLR = CR_BUF_RST | CR_PAGE_PG; + FLASH->ADDR = (intptr_t)FLASH_SEED_ADDR; + + ((uint32_t *)FLASH_SEED_ADDR)[0] = (uint32_t)seed; + ((uint32_t *)FLASH_SEED_ADDR)[1] = (uint32_t)(seed >> 32); + + FLASH->CTLR = CR_PAGE_PG | FLASH_CTLR_BUF_LOAD; + while (FLASH->STATR & FLASH_STATR_BSY); // Wait for completion + + FLASH->CTLR = CR_PAGE_PG | CR_STRT_Set; // Commit write + while (FLASH->STATR & FLASH_STATR_BSY); // Wait for completion +} diff --git a/rand.h b/rand.h index 3722734..26c4194 100644 --- a/rand.h +++ b/rand.h @@ -2,18 +2,18 @@ #include /** - * @brief Sets the seed for the custom random number generator. + * @brief Sets the seed for the PRNG. * - * This function initializes the seed value used by rand_custom(). - * Providing the same seed will produce the same sequence of random numbers. - * - * @param s The seed value (must be nonzero for best results). + * @param s The specific seed value or zero. If zero is passed, it will call + * rand_reseed(). */ void sprand(uint64_t s); /** * @brief Generates a pseudo-random 64-bit number. * + * Saves PRNG state to flash periodically. + * * Uses a simple Linear Congruential Generator (LCG) to produce * a sequence of pseudo-random numbers. * @@ -32,3 +32,18 @@ uint64_t prand(); * @return A random number between min and max. */ uint64_t prand_range(uint64_t min, uint64_t max); + +/** + * @brief Saves the current PRNG seed to flash memory. + * + * This function erases the designated flash page and writes the current seed + * to ensure the PRNG state persists across resets. + */ +void rand_save_to_flash(); + +/** + * @brief Re-seeds the PRNG seed state from either flash or BUILD_SEED. + * + * This function will not write to flash. + */ +void rand_reseed(); From b16c3b098a6ec23ac36bdf21b922472e6d3943a5 Mon Sep 17 00:00:00 2001 From: Imbus <> Date: Thu, 13 Feb 2025 00:35:52 +0100 Subject: [PATCH 2/2] Safe mulmod, used in modexp and friends --- rsa.c | 65 +++++++++++++++++++++++++++++++++++------------------------ rsa.h | 15 +++++++++++++- 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/rsa.c b/rsa.c index 1320f85..b547c93 100644 --- a/rsa.c +++ b/rsa.c @@ -3,28 +3,28 @@ #include #include +#define NULL ((void *)0) + uint64_t gcd(uint64_t a, uint64_t b) { - while (b != 0) { - uint64_t temp = b; - b = a % b; - a = temp; - } - return a; + return extended_euclid(a, b, NULL, NULL); } int extended_euclid(int a, int b, int *x, int *y) { if (b == 0) { - *x = 1; - *y = 0; + if (x) + *x = 1; + if (y) + *y = 0; return a; } int x1, y1; int gcd = extended_euclid(b, a % b, &x1, &y1); - // Update x and y using results from recursive call - *x = y1; - *y = x1 - (a / b) * y1; + if (x) + *x = y1; + if (y) + *y = x1 - (a / b) * y1; return gcd; } @@ -51,18 +51,31 @@ int totient(int n) { return result; } -uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) { - uint64_t result = 1; - a = a % m; // In case a is greater than m +uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) { + uint64_t result = 0; + a %= m; while (b > 0) { - // If b is odd, multiply a with result - if (b % 2 == 1) - result = (result * a) % m; + if (b & 1) { + result = (result + a) % m; // Avoid overflow + } + a = (a * 2) % m; // Double a, keep within mod + b >>= 1; + } - // b must be even now - b = b >> 1; // b = b // 2 - a = (a * a) % m; // Change a to a^2 + return result; +} + +uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) { + uint64_t result = 1; + a %= m; + + while (b > 0) { + if (b & 1) { + result = mulmod(result, a, m); + } + b >>= 1; + a = mulmod(a, a, m); } return result; @@ -70,7 +83,7 @@ uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) { uint64_t gen_prime(uint64_t min, uint64_t max) { uint64_t cand = 0; - while (!miller_rabin(cand, 5)) cand = prand_range(min, max); + while (!miller_rabin(cand, 10)) cand = prand_range(min, max); return cand; } @@ -119,17 +132,17 @@ bool miller_rabin(uint64_t n, uint64_t k) { return true; // Likely prime } -int mod_inverse(int a, int m) { - int m0 = m; - int y = 0, x = 1; +uint64_t mod_inverse(uint64_t a, uint64_t m) { + uint64_t m0 = m; + uint64_t y = 0, x = 1; if (m == 1) return 0; while (a > 1) { // q is quotient - int q = a / m; - int t = m; + uint64_t q = a / m; + uint64_t t = m; // m is remainder now m = a % m; diff --git a/rsa.h b/rsa.h index e56004d..53bc78d 100644 --- a/rsa.h +++ b/rsa.h @@ -21,6 +21,19 @@ uint64_t gcd(uint64_t a, uint64_t b); */ int totient(int n); +/** + * @brief Computes (a * b) % m safely without overflow. + * + * Uses repeated addition and bit shifting to handle large values, + * ensuring correctness even on 32-bit microcontrollers. + * + * @param a The first operand. + * @param b The second operand. + * @param m The modulus. + * @return (a * b) % m computed safely. + */ +uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m); + /** * @brief Modular exponentiation (a^b) mod m * @@ -37,7 +50,7 @@ uint64_t modexp(uint64_t a, uint64_t b, uint64_t m); * @param m The modulus. * @return The modular inverse of a modulo m, or -1 if no inverse exists. */ -int mod_inverse(int a, int m); +uint64_t mod_inverse(uint64_t a, uint64_t m); /** * @brief Generates a random prime number within the given range.