From b16c3b098a6ec23ac36bdf21b922472e6d3943a5 Mon Sep 17 00:00:00 2001
From: Imbus <>
Date: Thu, 13 Feb 2025 00:35:52 +0100
Subject: [PATCH] Safe mulmod, used in modexp and friends

---
 rsa.c | 65 +++++++++++++++++++++++++++++++++++------------------------
 rsa.h | 15 +++++++++++++-
 2 files changed, 53 insertions(+), 27 deletions(-)

diff --git a/rsa.c b/rsa.c
index 1320f85..b547c93 100644
--- a/rsa.c
+++ b/rsa.c
@@ -3,28 +3,28 @@
 #include <stdbool.h>
 #include <stdint.h>
 
+#define NULL ((void *)0)
+
 uint64_t gcd(uint64_t a, uint64_t b) {
-    while (b != 0) {
-        uint64_t temp = b;
-        b = a % b;
-        a = temp;
-    }
-    return a;
+    return extended_euclid(a, b, NULL, NULL);
 }
 
 int extended_euclid(int a, int b, int *x, int *y) {
     if (b == 0) {
-        *x = 1;
-        *y = 0;
+        if (x)
+            *x = 1;
+        if (y)
+            *y = 0;
         return a;
     }
 
     int x1, y1;
     int gcd = extended_euclid(b, a % b, &x1, &y1);
 
-    // Update x and y using results from recursive call
-    *x = y1;
-    *y = x1 - (a / b) * y1;
+    if (x)
+        *x = y1;
+    if (y)
+        *y = x1 - (a / b) * y1;
 
     return gcd;
 }
@@ -51,18 +51,31 @@ int totient(int n) {
     return result;
 }
 
-uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) {
-    uint64_t result = 1;
-    a = a % m; // In case a is greater than m
+uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) {
+    uint64_t result = 0;
+    a %= m;
 
     while (b > 0) {
-        // If b is odd, multiply a with result
-        if (b % 2 == 1)
-            result = (result * a) % m;
+        if (b & 1) {
+            result = (result + a) % m; // Avoid overflow
+        }
+        a = (a * 2) % m; // Double a, keep within mod
+        b >>= 1;
+    }
 
-        // b must be even now
-        b = b >> 1;      // b = b // 2
-        a = (a * a) % m; // Change a to a^2
+    return result;
+}
+
+uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) {
+    uint64_t result = 1;
+    a %= m;
+
+    while (b > 0) {
+        if (b & 1) {
+            result = mulmod(result, a, m);
+        }
+        b >>= 1;
+        a = mulmod(a, a, m);
     }
 
     return result;
@@ -70,7 +83,7 @@ uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) {
 
 uint64_t gen_prime(uint64_t min, uint64_t max) {
     uint64_t cand = 0;
-    while (!miller_rabin(cand, 5)) cand = prand_range(min, max);
+    while (!miller_rabin(cand, 10)) cand = prand_range(min, max);
 
     return cand;
 }
@@ -119,17 +132,17 @@ bool miller_rabin(uint64_t n, uint64_t k) {
     return true; // Likely prime
 }
 
-int mod_inverse(int a, int m) {
-    int m0 = m;
-    int y = 0, x = 1;
+uint64_t mod_inverse(uint64_t a, uint64_t m) {
+    uint64_t m0 = m;
+    uint64_t y = 0, x = 1;
 
     if (m == 1)
         return 0;
 
     while (a > 1) {
         // q is quotient
-        int q = a / m;
-        int t = m;
+        uint64_t q = a / m;
+        uint64_t t = m;
 
         // m is remainder now
         m = a % m;
diff --git a/rsa.h b/rsa.h
index e56004d..53bc78d 100644
--- a/rsa.h
+++ b/rsa.h
@@ -21,6 +21,19 @@ uint64_t gcd(uint64_t a, uint64_t b);
  */
 int totient(int n);
 
+/**
+ * @brief Computes (a * b) % m safely without overflow.
+ *
+ * Uses repeated addition and bit shifting to handle large values,
+ * ensuring correctness even on 32-bit microcontrollers.
+ *
+ * @param a The first operand.
+ * @param b The second operand.
+ * @param m The modulus.
+ * @return (a * b) % m computed safely.
+ */
+uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m);
+
 /**
  * @brief Modular exponentiation (a^b) mod m
  *
@@ -37,7 +50,7 @@ uint64_t modexp(uint64_t a, uint64_t b, uint64_t m);
  * @param m The modulus.
  * @return The modular inverse of a modulo m, or -1 if no inverse exists.
  */
-int mod_inverse(int a, int m);
+uint64_t mod_inverse(uint64_t a, uint64_t m);
 
 /**
  * @brief Generates a random prime number within the given range.