Compare commits

...

9 commits

Author SHA1 Message Date
Imbus
bfcbb77570 Remove junk, add testing and assertions 2025-02-14 06:20:41 +01:00
Imbus
7d20e7f009 Define public exponent in header 2025-02-14 06:10:06 +01:00
Imbus
e1ece2358f Better comments in rsa 2025-02-14 06:09:54 +01:00
Imbus
70dcda61e8 Fix naming 2025-02-14 06:09:25 +01:00
Imbus
8ff7937a88 Track a specific commit 2025-02-14 04:13:58 +01:00
Imbus
557c06ca22 Some testing 2025-02-14 04:13:33 +01:00
Imbus
52857830ac Assert file 2025-02-14 04:13:23 +01:00
Imbus
120a61eca7 Type aliases 2025-02-14 04:07:40 +01:00
Imbus
2bad1303dc Type aliases for RSA related functionality 2025-02-14 04:07:03 +01:00
6 changed files with 144 additions and 72 deletions

View file

@ -1,4 +1,5 @@
REV := master # 'master' or hash
REV := 8ba9981e5
BASE := https://raw.githubusercontent.com/cnlohr/ch32v003fun/$(REV) BASE := https://raw.githubusercontent.com/cnlohr/ch32v003fun/$(REV)
CURL_FLAGS := -O -\# --fail --location --tlsv1.3 --proto =https --max-time 300 CURL_FLAGS := -O -\# --fail --location --tlsv1.3 --proto =https --max-time 300

25
assert.h Normal file
View file

@ -0,0 +1,25 @@
#pragma once
#include <stdint.h>
#include <stdio.h>
#define ASSERT(expr) \
do { \
if (!(expr)) { \
printf("ASSERTION FAILED: %s at %s:%d\n", #expr, __FILE__, \
__LINE__); \
while (1); \
} \
} while (0)
#define ASSERT_EQ(expr, expected) \
do { \
uint64_t result = (expr); \
if (result != (expected)) { \
printf("ASSERTION FAILED: %s at %s:%d\n", #expr, __FILE__, \
__LINE__); \
printf("Expected: %lu, Got: %lu\n", (unsigned long)(expected), \
(unsigned long)result); \
while (1); \
} \
} while (0)

View file

@ -1,7 +1,18 @@
#include <stdint.h>
#ifndef _FUNCONFIG_H #ifndef _FUNCONFIG_H
#define _FUNCONFIG_H #define _FUNCONFIG_H
#define CH32V003 1 #define CH32V003 1
#define NULL ((void *)0)
typedef int8_t i8;
typedef uint8_t u8;
typedef int16_t i16;
typedef uint32_t u32;
typedef int32_t i32;
typedef int64_t i64;
typedef uint64_t u64;
#endif #endif

89
main.c
View file

@ -1,10 +1,14 @@
#include "assert.h"
#include <ch32fun.h> #include <ch32fun.h>
#include <rand.h> #include <rand.h>
#include <rsa.h> #include <rsa.h>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <string.h>
#define LED_PIN PD6 #define LED_PIN PD6
#define RANDOM
#define W 16
void exit_blink() { void exit_blink() {
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
@ -24,6 +28,30 @@ void enter_blink() {
} }
} }
void test_mulmod() {
ASSERT_EQ(mulmod(3, 2, 4), 2);
ASSERT_EQ((3 * 2) % 4, 2);
ASSERT_EQ(mulmod(31, 3, 8), 5);
ASSERT_EQ(mulmod((u64)1 << 63, 2, 1000000007ULL), 582344008);
}
void test_modexp() {
ASSERT_EQ(modexp(3, 2, 4), 1);
ASSERT_EQ((3 ^ 2) % 4, 1);
ASSERT_EQ(modexp(31, 3, 8), 7);
ASSERT_EQ(modexp((u64)1 << 63, 2, 1000000007ULL), 319908071);
}
void debug_string(char *str) {
printf("Got string: %s\n", str);
for (int i = 0; i < strlen(str); i++) {
printf("decoded[%d] = '%c' (ASCII: %d)\n", i, str[i],
str[i]); // Print decoded chars and ASCII values
}
}
int main() { int main() {
SystemInit(); SystemInit();
sprand(0); sprand(0);
@ -33,54 +61,59 @@ int main() {
enter_blink(); enter_blink();
#ifdef RANDOM test_mulmod();
uint64_t p = gen_prime(1 << 15, 1 << 16); test_modexp();
uint64_t q = p;
while (p == q) p = gen_prime(1 << 15, 1 << 16); const u64 p = gen_prime(1 << (W - 1), 1 << W);
#else printf("P: %u\n", (u32)p);
uint64_t p = 56857;
uint64_t q = 47963;
#endif
uint64_t n = p * q; u64 qprev = p;
uint64_t phi_n = (p - 1) * (q - 1); while (p == qprev) qprev = gen_prime(1 << (W - 1), 1 << W);
// 'e' is public. E for encrypt. const u64 q = qprev;
uint64_t e = 0; printf("Q: %u\n", (u32)q);
while (gcd(e, phi_n) != 1) e = prand_range(3, phi_n - 1);
// 'd' is our private key. D as in decrypt ASSERT(gcd(p - 1, PUBEXP) == 1);
uint64_t d = mod_inverse(e, phi_n); ASSERT(gcd(q - 1, PUBEXP) == 1);
u64 n = p * q;
printf("N: %u\n", (u32)n);
u64 phi_n = (p - 1) * (q - 1);
printf("Phi_N: %u\n", (u32)phi_n);
u64 d = mod_inverse(PUBEXP, phi_n);
printf("D: %u\n", (u32)d);
if (d == 0 || d == 1) { if (d == 0 || d == 1) {
printf("Modular inverse not found..."); printf("Modular inverse not found...");
while (1);
} }
ASSERT_EQ(mulmod(PUBEXP, d, phi_n), 1);
char msg[] = "Hello"; char msg[] = "Hello";
uint64_t coded[sizeof(msg)] = {0}; u64 coded[sizeof(msg)] = {0};
char decoded[sizeof(msg)] = {0}; char decoded[sizeof(msg)] = {0};
// Encode the message // Encode the message
for (int i = 0; i < sizeof(msg); i++) { for (int i = 0; i < strlen(msg); i++) {
coded[i] = (uint64_t)modexp((uint64_t)msg[i], e, n); coded[i] = modexp((u64)msg[i], PUBEXP, n);
} }
// Decode the message // Decode the message
for (int i = 0; i < sizeof(msg); i++) { for (int i = 0; i < strlen(msg); i++) {
decoded[i] = (char)modexp(coded[i], d, n); u64 dec = modexp(coded[i], d, n);
decoded[i] = dec & 0xFF;
} }
{ {
printf("P: %u\n", (uint32_t)p);
printf("Q: %u\n", (uint32_t)q);
printf("N: %u\n", (uint32_t)n);
printf("Phi_N: %u\n", (uint32_t)phi_n);
printf("Pubkey (e): %u\n", (uint32_t)e);
printf("Privkey (d): %u\n", (uint32_t)d);
printf("Message: %s\n", msg); printf("Message: %s\n", msg);
printf("Decoded: %s\n", decoded); printf("Decoded: %s\n", decoded);
for (int i = 0; i < strlen(msg); i++) {
printf("coded[%d] = 0x%016lx\n", i, (unsigned long)coded[i]);
}
debug_string(decoded);
} }
// Exit and hang forever // Exit and hang forever

62
rsa.c
View file

@ -1,15 +1,11 @@
#include "rsa.h" #include "rsa.h"
#include "funconfig.h"
#include "rand.h" #include "rand.h"
#include <stdbool.h> #include <stdbool.h>
#include <stdint.h>
#define NULL ((void *)0) u64 gcd(u64 a, u64 b) { return extended_euclid(a, b, NULL, NULL); }
uint64_t gcd(uint64_t a, uint64_t b) { u64 extended_euclid(u64 a, u64 b, u64 *x, u64 *y) {
return extended_euclid(a, b, NULL, NULL);
}
int extended_euclid(int a, int b, int *x, int *y) {
if (b == 0) { if (b == 0) {
if (x) if (x)
*x = 1; *x = 1;
@ -18,8 +14,8 @@ int extended_euclid(int a, int b, int *x, int *y) {
return a; return a;
} }
int x1, y1; u64 x1, y1;
int gcd = extended_euclid(b, a % b, &x1, &y1); u64 gcd = extended_euclid(b, a % b, &x1, &y1);
if (x) if (x)
*x = y1; *x = y1;
@ -29,7 +25,7 @@ int extended_euclid(int a, int b, int *x, int *y) {
return gcd; return gcd;
} }
int totient(int n) { u64 totient(u64 n) {
int result = n; int result = n;
// Check for prime factors // Check for prime factors
@ -51,23 +47,24 @@ int totient(int n) {
return result; return result;
} }
uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m) { u64 mulmod(u64 a, u64 b, u64 m) {
uint64_t result = 0; u64 result = 0;
a %= m; a %= m;
// Perform the multiplication bit by bit (binary multiplication)
while (b > 0) { while (b > 0) {
if (b & 1) { if (b & 1) {
result = (result + a) % m; // Avoid overflow result = (result + a) % m;
} }
a = (a * 2) % m; // Double a, keep within mod a = (a * 2) % m; // Double a, keep it within the modulus
b >>= 1; b >>= 1; // Right shift b (divide by 2)
} }
return result; return result;
} }
uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) { u64 modexp(u64 a, u64 b, u64 m) {
uint64_t result = 1; u64 result = 1;
a %= m; a %= m;
while (b > 0) { while (b > 0) {
@ -81,14 +78,14 @@ uint64_t modexp(uint64_t a, uint64_t b, uint64_t m) {
return result; return result;
} }
uint64_t gen_prime(uint64_t min, uint64_t max) { u64 gen_prime(u64 min, u64 max) {
uint64_t cand = 0; u64 cand = 0;
while (!miller_rabin(cand, 10)) cand = prand_range(min, max); while (!miller_rabin(cand, 10)) cand = prand_range(min, max);
return cand; return cand;
} }
bool is_prime(int n) { bool is_prime(u64 n) {
if (n < 2) if (n < 2)
return false; return false;
@ -100,26 +97,26 @@ bool is_prime(int n) {
return true; return true;
} }
bool miller_rabin(uint64_t n, uint64_t k) { bool miller_rabin(u64 n, u64 k) {
if (n < 2) if (n < 2)
return false; return false;
uint64_t d = n - 1; u64 d = n - 1;
uint64_t s = 0; u64 s = 0;
while (d % 2 == 0) { while (d % 2 == 0) {
d /= 2; d /= 2;
s++; s++;
} }
for (uint64_t i = 0; i < k; i++) { for (u64 i = 0; i < k; i++) {
uint64_t a = prand_range(2, n - 2); u64 a = prand_range(2, n - 2);
uint64_t x = modexp(a, d, n); u64 x = modexp(a, d, n);
if (x == 1 || x == n - 1) if (x == 1 || x == n - 1)
continue; continue;
for (uint64_t r = 1; r < s; r++) { for (u64 r = 1; r < s; r++) {
x = modexp(x, 2, n); x = modexp(x, 2, n);
if (x == n - 1) if (x == n - 1)
break; break;
@ -132,17 +129,18 @@ bool miller_rabin(uint64_t n, uint64_t k) {
return true; // Likely prime return true; // Likely prime
} }
uint64_t mod_inverse(uint64_t a, uint64_t m) { u64 mod_inverse(u64 a, u64 m) {
uint64_t m0 = m; u64 m0 = m;
uint64_t y = 0, x = 1; u64 y = 0, x = 1;
// Modular inverse does not exist when m is 1
if (m == 1) if (m == 1)
return 0; return 0;
while (a > 1) { while (a > 1) {
// q is quotient // q is quotient
uint64_t q = a / m; u64 q = a / m;
uint64_t t = m; u64 t = m;
// m is remainder now // m is remainder now
m = a % m; m = a % m;

22
rsa.h
View file

@ -1,7 +1,11 @@
#pragma once #pragma once
#include "funconfig.h"
#include <stdbool.h> #include <stdbool.h>
#include <stdint.h> #include <stdint.h>
// Common public exponent, in Fermat prime form
#define PUBEXP ((1 << 16) | 0x1)
/** /**
* @brief Calculates greatest common divider of two integers using the euclidean * @brief Calculates greatest common divider of two integers using the euclidean
* algorithm * algorithm
@ -10,7 +14,7 @@
* @param b Second number * @param b Second number
* @return The greatest common divider * @return The greatest common divider
*/ */
uint64_t gcd(uint64_t a, uint64_t b); u64 gcd(u64 a, u64 b);
/** /**
* @brief Computes Euler's Totient function φ(n), which counts the number of * @brief Computes Euler's Totient function φ(n), which counts the number of
@ -19,7 +23,7 @@ uint64_t gcd(uint64_t a, uint64_t b);
* @param n The input number. * @param n The input number.
* @return The number of integers from 1 to n that are coprime to n. * @return The number of integers from 1 to n that are coprime to n.
*/ */
int totient(int n); u64 totient(u64 n);
/** /**
* @brief Computes (a * b) % m safely without overflow. * @brief Computes (a * b) % m safely without overflow.
@ -32,7 +36,7 @@ int totient(int n);
* @param m The modulus. * @param m The modulus.
* @return (a * b) % m computed safely. * @return (a * b) % m computed safely.
*/ */
uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m); u64 mulmod(u64 a, u64 b, u64 m);
/** /**
* @brief Modular exponentiation (a^b) mod m * @brief Modular exponentiation (a^b) mod m
@ -41,7 +45,7 @@ uint64_t mulmod(uint64_t a, uint64_t b, uint64_t m);
* @param b The exponent * @param b The exponent
* @param m The modulus * @param m The modulus
*/ */
uint64_t modexp(uint64_t a, uint64_t b, uint64_t m); u64 modexp(u64 a, u64 b, u64 m);
/** /**
* @brief Computes the modular inverse of a modulo m. * @brief Computes the modular inverse of a modulo m.
@ -50,7 +54,7 @@ uint64_t modexp(uint64_t a, uint64_t b, uint64_t m);
* @param m The modulus. * @param m The modulus.
* @return The modular inverse of a modulo m, or -1 if no inverse exists. * @return The modular inverse of a modulo m, or -1 if no inverse exists.
*/ */
uint64_t mod_inverse(uint64_t a, uint64_t m); u64 mod_inverse(u64 a, u64 m);
/** /**
* @brief Generates a random prime number within the given range. * @brief Generates a random prime number within the given range.
@ -59,7 +63,7 @@ uint64_t mod_inverse(uint64_t a, uint64_t m);
* @param max The upper bound (inclusive). * @param max The upper bound (inclusive).
* @return A prime number in the range [min, max]. * @return A prime number in the range [min, max].
*/ */
uint64_t gen_prime(uint64_t min, uint64_t max); u64 gen_prime(u64 min, u64 max);
/** /**
* @brief Checks if a number is prime. * @brief Checks if a number is prime.
@ -67,7 +71,7 @@ uint64_t gen_prime(uint64_t min, uint64_t max);
* @param n The number to check. * @param n The number to check.
* @return true if n is prime, false otherwise. * @return true if n is prime, false otherwise.
*/ */
bool is_prime(int n); bool is_prime(u64 n);
/** /**
* @brief Performs the Miller-Rabin primality test to check if a number is * @brief Performs the Miller-Rabin primality test to check if a number is
@ -77,7 +81,7 @@ bool is_prime(int n);
* @param k The number of rounds of testing to perform. * @param k The number of rounds of testing to perform.
* @return true if n is probably prime, false if n is composite. * @return true if n is probably prime, false if n is composite.
*/ */
bool miller_rabin(uint64_t n, uint64_t k); bool miller_rabin(u64 n, u64 k);
/** /**
* @brief Computes the greatest common divisor (GCD) of two integers a and b * @brief Computes the greatest common divisor (GCD) of two integers a and b
@ -92,4 +96,4 @@ bool miller_rabin(uint64_t n, uint64_t k);
* + by = gcd(a, b). * + by = gcd(a, b).
* @return The greatest common divisor (gcd) of a and b. * @return The greatest common divisor (gcd) of a and b.
*/ */
int extended_euclid(int a, int b, int *x, int *y); u64 extended_euclid(u64 a, u64 b, u64 *x, u64 *y);