#include <assert.h>
#include <math.h>
#include <stdbool.h>
#include <stdio.h>

/*
 * Simple linear algebra mat/vec operations
 *
 * TODO: Normalizations
 * TODO: Ensure suitable test coverage
 * TODO: Mat3 adjoint should use a cofactor and a transpose function
 */

/**
 * @brief A 2x2 matrix stored in column-major order.
 *
 * MAT2_AT(m, row, col) reads arr[col * 2 + row], and mat2_mul fills its
 * result in the same order, so the elements are laid out as:
 * [ arr[0] arr[2] ]
 * [ arr[1] arr[3] ]
 */
typedef struct Mat2 {
    float arr[4];
} Mat2;

/**
 * @brief A 3x3 matrix stored in column-major order.
 *
 * MAT3_AT(m, row, col) reads arr[col * 3 + row], and mat3_mul writes its
 * result with the same index, so the elements are laid out as:
 * [ arr[0] arr[3] arr[6] ]
 * [ arr[1] arr[4] arr[7] ]
 * [ arr[2] arr[5] arr[8] ]
 */
typedef struct Mat3 {
    float arr[9];
} Mat3;

/**
 * @brief A 2D vector of single-precision floats.
 */
typedef struct Vec2 {
    float x; /* first component */
    float y; /* second component */
} Vec2;

/**
 * @brief A 3D vector of single-precision floats.
 */
typedef struct Vec3 {
    float x; /* first component */
    float y; /* second component */
    float z; /* third component */
} Vec3;

/**
 * @brief Computes the dot product of two 3D vectors.
 *
 * @param a Pointer to the first vector.
 * @param b Pointer to the second vector.
 * @return The dot product (a * b).
 */
inline float vec3_dot(const Vec3 *a, const Vec3 *b);

/**
 * @brief Computes the cross product of two 3D vectors.
 *
 * @param a Pointer to the first vector.
 * @param b Pointer to the second vector.
 * @return The cross product vector (a x b).
 */
Vec3 vec3_cross(const Vec3 *a, const Vec3 *b);

/**
 * @brief Subtracts one 3D vector from another.
 *
 * @param a Pointer to the minuend vector.
 * @param b Pointer to the subtrahend vector.
 * @return The resulting vector (a - b).
 */
Vec3 vec3_sub(const Vec3 *a, const Vec3 *b);

/**
 * @brief Adds two 3D vectors.
 *
 * @param a Pointer to the first vector.
 * @param b Pointer to the second vector.
 * @return The resulting vector (a + b).
 */
Vec3 vec3_add(const Vec3 *a, const Vec3 *b);

/**
 * @brief Scales a 3D vector by a scalar value.
 *
 * @param a Pointer to the vector to scale.
 * @param scalar The scalar value to multiply with the vector.
 * @return The scaled vector (a * scalar).
 */
Vec3 vec3_scale(const Vec3 *a, const float scalar);

/**
 * @brief Computes the dot product of two 2D vectors.
 *
 * @param a Pointer to the first vector.
 * @param b Pointer to the second vector.
 * @return The dot product (a * b).
 */
inline float vec2_dot(const Vec2 *a, const Vec2 *b);

/**
 * @brief Subtracts one 2D vector from another.
 *
 * @param a Pointer to the minuend vector.
 * @param b Pointer to the subtrahend vector.
 * @return The resulting vector (a - b).
 */
Vec2 vec2_sub(const Vec2 *a, const Vec2 *b);

/**
 * @brief Adds two 2D vectors.
 *
 * @param a Pointer to the first vector.
 * @param b Pointer to the second vector.
 * @return The resulting vector (a + b).
 */
Vec2 vec2_add(const Vec2 *a, const Vec2 *b);

/**
 * @brief Scales a 2D vector by a scalar value.
 *
 * @param a Pointer to the vector to scale.
 * @param scalar The scalar value to multiply with the vector.
 * @return The scaled vector (a * scalar).
 */
Vec2 vec2_scale(const Vec2 *a, const float scalar);
/**
 * @brief Computes the determinant of a 2x2 matrix.
 *
 * @param m Pointer to the matrix.
 * @return The determinant of the matrix.
 */
float mat2_det(const Mat2 *m);

/**
 * @brief Computes the determinant of a 3x3 matrix (independent of storage order).
 *
 * @param m Pointer to the matrix.
 * @return The determinant of the matrix.
 */
float mat3_det(const Mat3 *m);

/**
 * @brief Multiplies two 2x2 matrices (column-major storage, see MAT2_AT).
 *
 * @param m1 Pointer to the first matrix.
 * @param m2 Pointer to the second matrix.
 * @return The resulting matrix product (m1 x m2).
 */
Mat2 mat2_mul(const Mat2 *m1, const Mat2 *m2);

/**
 * @brief Multiplies two 3x3 matrices (column-major storage, see MAT3_AT).
 *
 * @param m1 Pointer to the first matrix.
 * @param m2 Pointer to the second matrix.
 * @return The resulting matrix product (m1 x m2).
 */
Mat3 mat3_mul(const Mat3 *m1, const Mat3 *m2);

/**
 * @brief Checks if two 2x2 matrices are approximately equal.
 *
 * @param a Pointer to the first matrix.
 * @param b Pointer to the second matrix.
 * @param epsilon Tolerance for comparison.
 * @return true if all elements are approximately equal within epsilon.
 */
bool mat2_approx_eq(const Mat2 *a, const Mat2 *b, float epsilon);

/**
 * @brief Checks if two 3x3 matrices are approximately equal.
 *
 * @param a Pointer to the first matrix.
 * @param b Pointer to the second matrix.
 * @param epsilon Tolerance for comparison.
 * @return true if all elements are approximately equal within epsilon.
 */
bool mat3_approx_eq(const Mat3 *a, const Mat3 *b, float epsilon);

/**
 * @brief Checks if two 3D vectors are approximately equal.
 *
 * @param a Pointer to the first vector.
 * @param b Pointer to the second vector.
 * @param epsilon Tolerance for comparison.
 * @return true if all elements are approximately equal within epsilon.
 */
bool vec3_approx_eq(const Vec3 *a, const Vec3 *b, float epsilon);

/**
 * @brief Checks if two 2D vectors are approximately equal.
 *
 * @param a Pointer to the first vector.
 * @param b Pointer to the second vector.
 * @param epsilon Tolerance for comparison.
 * @return true if all elements are approximately equal within epsilon.
 */
bool vec2_approx_eq(const Vec2 *a, const Vec2 *b, float epsilon);

/**
 * @brief Computes the adjugate (adjoint) of a 2x2 matrix.
 *
 * The adjugate of a 2x2 matrix is obtained by swapping the diagonal elements
 * and negating the off-diagonal elements.
 *
 * @param m Pointer to the 2x2 matrix.
 * @return The adjugate of the input matrix.
 */
Mat2 mat2_adj(const Mat2 *m);

/**
 * @brief Computes the adjugate (adjoint) of a 3x3 matrix.
 *
 * The adjugate of a 3x3 matrix is the transpose of its cofactor matrix.
 * It is used in computing the inverse of the matrix.
 *
 * @param m Pointer to the 3x3 matrix.
 * @return The adjugate of the input matrix.
 */
Mat3 mat3_adj(const Mat3 *m);

/*
 * Element accessors, usable as lvalues. `m` is a pointer to the matrix.
 * Indexing is column-major: the entries of one column sit in consecutive
 * slots of `arr`, so (row, col) maps to arr[col * N + row].
 */
#define MAT2_AT(m, row, col) ((m)->arr[(row) + (col) * 2])
#define MAT3_AT(m, row, col) ((m)->arr[(row) + (col) * 3])

/* Header end... */

/*
 * Determinant of the 2x2 block
 *   [ p00 p01 ]
 *   [ p10 p11 ]
 * with the four entries passed row by row.
 */
#define MAT2_DET(p00, p01, p10, p11) (((p00) * (p11)) - ((p01) * (p10)))

/*
 * Determinant of a 3x3 matrix by cofactor expansion along the first row.
 * Entries are fetched through MAT3_AT and named a..i, reading row by row:
 *   [ a b c ]
 *   [ d e f ]
 *   [ g h i ]
 */
float mat3_det(const Mat3 *m) {
    const float a = MAT3_AT(m, 0, 0), b = MAT3_AT(m, 0, 1), c = MAT3_AT(m, 0, 2);
    const float d = MAT3_AT(m, 1, 0), e = MAT3_AT(m, 1, 1), f = MAT3_AT(m, 1, 2);
    const float g = MAT3_AT(m, 2, 0), h = MAT3_AT(m, 2, 1), i = MAT3_AT(m, 2, 2);

    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
}

/*
 * Dot product of two 3D vectors: a.x*b.x + a.y*b.y + a.z*b.z.
 *
 * The definition deliberately omits `inline`. When every file-scope
 * declaration of a function carries `inline` without `extern`, C treats the
 * definition as an inline definition only and emits no external symbol, so a
 * call the compiler chooses not to inline fails at link time. Leaving the
 * specifier off here (as vec2_dot already does) makes this the external
 * definition while the earlier `inline` prototype remains a valid hint.
 */
float vec3_dot(const Vec3 *a, const Vec3 *b) {
    return a->x * b->x + a->y * b->y + a->z * b->z;
}

/*
 * Cross product a x b of two 3D vectors.
 *
 * Components follow the cyclic (x -> y -> z -> x) pattern:
 *   x = a.y*b.z - a.z*b.y
 *   y = a.z*b.x - a.x*b.z
 *   z = a.x*b.y - a.y*b.x
 * The y term is easy to get backwards: writing it as a.x*b.z - a.z*b.x
 * flips its sign (e.g. x-hat cross z-hat would come out as +y-hat instead
 * of -y-hat).
 */
Vec3 vec3_cross(const Vec3 *a, const Vec3 *b) {
    Vec3 res = {
        .x = a->y * b->z - a->z * b->y,
        .y = a->z * b->x - a->x * b->z,
        .z = a->x * b->y - a->y * b->x,
    };

    return res;
}

Vec3 vec3_sub(const Vec3 *a, const Vec3 *b) {
    Vec3 res = {
        .x = a->x - b->x,
        .y = a->y - b->y,
        .z = a->z - b->z,
    };

    return res;
}

/* Component-wise sum a + b of two 3D vectors. */
Vec3 vec3_add(const Vec3 *a, const Vec3 *b) {
    Vec3 sum;

    sum.x = a->x + b->x;
    sum.y = a->y + b->y;
    sum.z = a->z + b->z;
    return sum;
}

Vec3 vec3_scale(const Vec3 *a, const float scalar) {
    Vec3 res = {
        .x = a->x * scalar,
        .y = a->y * scalar,
        .z = a->z * scalar,
    };

    return res;
}

/* Dot product of two 2D vectors: a.x*b.x + a.y*b.y. */
float vec2_dot(const Vec2 *a, const Vec2 *b) {
    const float ax = a->x, ay = a->y;
    const float bx = b->x, by = b->y;

    return ax * bx + ay * by;
}

Vec2 vec2_sub(const Vec2 *a, const Vec2 *b) {
    Vec2 res = {
        .x = a->x - b->x,
        .y = a->y - b->y,
    };

    return res;
}

/* Component-wise sum a + b of two 2D vectors. */
Vec2 vec2_add(const Vec2 *a, const Vec2 *b) {
    Vec2 sum;

    sum.x = a->x + b->x;
    sum.y = a->y + b->y;
    return sum;
}

Vec2 vec2_scale(const Vec2 *a, const float scalar) {
    Vec2 res = {
        .x = a->x * scalar,
        .y = a->y * scalar,
    };

    return res;
}

Mat2 mat2_mul(const Mat2 *m1, const Mat2 *m2) {
    Mat2 m3 = {.arr = {
                   MAT2_AT(m1, 0, 0) * MAT2_AT(m2, 0, 0) +
                       MAT2_AT(m1, 0, 1) * MAT2_AT(m2, 1, 0),
                   MAT2_AT(m1, 1, 0) * MAT2_AT(m2, 0, 0) +
                       MAT2_AT(m1, 1, 1) * MAT2_AT(m2, 1, 0),
                   MAT2_AT(m1, 0, 0) * MAT2_AT(m2, 0, 1) +
                       MAT2_AT(m1, 0, 1) * MAT2_AT(m2, 1, 1),
                   MAT2_AT(m1, 1, 0) * MAT2_AT(m2, 0, 1) +
                       MAT2_AT(m1, 1, 1) * MAT2_AT(m2, 1, 1),
               }};

    return m3;
}

Mat3 mat3_mul(const Mat3 *m1, const Mat3 *m2) {
    Mat3 m3;

    for (int col = 0; col < 3; ++col) {
        for (int row = 0; row < 3; ++row) {
            float sum = 0.0f;
            for (int k = 0; k < 3; ++k) {
                sum += MAT3_AT(m1, row, k) * MAT3_AT(m2, k, col);
            }
            m3.arr[col * 3 + row] = sum;
        }
    }

    return m3;
}

/*
 * Prints a 3x3 matrix to stdout, one row per line, each entry formatted
 * with "%8.3f" and the row framed by '|' characters.
 */
void mat3_print(const Mat3 *m) {
    for (int r = 0; r < 3; ++r) {
        printf("| %8.3f %8.3f %8.3f |\n",
               MAT3_AT(m, r, 0), MAT3_AT(m, r, 1), MAT3_AT(m, r, 2));
    }
}

/*
 * Element-wise comparison of two 2x2 matrices with an absolute tolerance.
 *
 * Returns true only when |a->arr[i] - b->arr[i]| <= epsilon for all four
 * elements. The check is written as a negated "<=" so that a NaN difference
 * (a NaN element, or inf - inf) counts as a mismatch, which is the outcome
 * vec2_approx_eq gives for the same inputs. A plain "> epsilon" test is
 * false for NaN and would report such matrices as equal.
 */
bool mat2_approx_eq(const Mat2 *a, const Mat2 *b, float epsilon) {
    for (int i = 0; i < 4; ++i) {
        if (!(fabsf(a->arr[i] - b->arr[i]) <= epsilon)) {
            return false;
        }
    }
    return true;
}

/*
 * Element-wise comparison of two 3x3 matrices with an absolute tolerance.
 *
 * Returns true only when |a->arr[i] - b->arr[i]| <= epsilon for all nine
 * elements. The check is written as a negated "<=" so that a NaN difference
 * (a NaN element, or inf - inf) counts as a mismatch, which is the outcome
 * vec3_approx_eq gives for the same inputs. A plain "> epsilon" test is
 * false for NaN and would report such matrices as equal.
 */
bool mat3_approx_eq(const Mat3 *a, const Mat3 *b, float epsilon) {
    for (int i = 0; i < 9; ++i) {
        if (!(fabsf(a->arr[i] - b->arr[i]) <= epsilon)) {
            return false;
        }
    }
    return true;
}

/*
 * True when both components of a and b differ by at most epsilon
 * (absolute difference). A NaN difference fails the "<=" test.
 */
bool vec2_approx_eq(const Vec2 *a, const Vec2 *b, float epsilon) {
    if (!(fabsf(a->x - b->x) <= epsilon)) {
        return false;
    }
    return fabsf(a->y - b->y) <= epsilon;
}

/*
 * True when all three components of a and b differ by at most epsilon
 * (absolute difference). A NaN difference fails the "<=" test.
 */
bool vec3_approx_eq(const Vec3 *a, const Vec3 *b, float epsilon) {
    if (!(fabsf(a->x - b->x) <= epsilon)) {
        return false;
    }
    if (!(fabsf(a->y - b->y) <= epsilon)) {
        return false;
    }
    return fabsf(a->z - b->z) <= epsilon;
}

/*
 * Determinant of a 2x2 matrix: m(0,0)*m(1,1) - m(0,1)*m(1,0).
 *
 * mat2_det is declared in the header section of this file but had no
 * definition, so any call would fail at link time. It lives next to
 * mat2_adj because the two are used together (inverse = adj / det).
 */
float mat2_det(const Mat2 *m) {
    return MAT2_AT(m, 0, 0) * MAT2_AT(m, 1, 1) -
           MAT2_AT(m, 0, 1) * MAT2_AT(m, 1, 0);
}

/*
 * Adjugate of a 2x2 matrix: the diagonal entries trade places and the
 * off-diagonal entries are negated.
 *
 *   [ a b ]        [  d -b ]
 *   [ c d ]  --->  [ -c  a ]
 */
Mat2 mat2_adj(const Mat2 *m) {
    Mat2 res;

    MAT2_AT(&res, 0, 0) = MAT2_AT(m, 1, 1);
    MAT2_AT(&res, 0, 1) = -MAT2_AT(m, 0, 1);
    MAT2_AT(&res, 1, 0) = -MAT2_AT(m, 1, 0);
    MAT2_AT(&res, 1, 1) = MAT2_AT(m, 0, 0);

    return res;
}

/*
 * Adjugate of a 3x3 matrix, i.e. the transpose of its cofactor matrix:
 * adj(r, c) is the signed 2x2 minor obtained by deleting row c and
 * column r of m. Each minor is evaluated with MAT2_DET.
 */
Mat3 mat3_adj(const Mat3 *m) {
    const float m00 = MAT3_AT(m, 0, 0), m01 = MAT3_AT(m, 0, 1), m02 = MAT3_AT(m, 0, 2);
    const float m10 = MAT3_AT(m, 1, 0), m11 = MAT3_AT(m, 1, 1), m12 = MAT3_AT(m, 1, 2);
    const float m20 = MAT3_AT(m, 2, 0), m21 = MAT3_AT(m, 2, 1), m22 = MAT3_AT(m, 2, 2);
    Mat3 adj;

    /* Row 0 of the adjugate: cofactors of column 0 of m. */
    MAT3_AT(&adj, 0, 0) = MAT2_DET(m11, m12, m21, m22);
    MAT3_AT(&adj, 0, 1) = -MAT2_DET(m01, m02, m21, m22);
    MAT3_AT(&adj, 0, 2) = MAT2_DET(m01, m02, m11, m12);

    /* Row 1 of the adjugate: cofactors of column 1 of m. */
    MAT3_AT(&adj, 1, 0) = -MAT2_DET(m10, m12, m20, m22);
    MAT3_AT(&adj, 1, 1) = MAT2_DET(m00, m02, m20, m22);
    MAT3_AT(&adj, 1, 2) = -MAT2_DET(m00, m02, m10, m12);

    /* Row 2 of the adjugate: cofactors of column 2 of m. */
    MAT3_AT(&adj, 2, 0) = MAT2_DET(m10, m11, m20, m21);
    MAT3_AT(&adj, 2, 1) = -MAT2_DET(m00, m01, m20, m21);
    MAT3_AT(&adj, 2, 2) = MAT2_DET(m00, m01, m10, m11);

    return adj;
}

/* Implem end */

/*
 * Smoke tests for the matrix and vector helpers: asserts where an expected
 * value is stated, prints the result otherwise.
 */
int main(void) {
    /* Identity times identity must give the identity back. */
    {
        Mat3 identity = {{1, 0, 0, 0, 1, 0, 0, 0, 1}};
        Mat3 squared = mat3_mul(&identity, &identity);

        assert(mat3_approx_eq(&squared, &identity, 0.01f));
    }

    /* Determinant of the identity, then with the top-left entry set to 2. */
    {
        Mat3 mat = {{1, 0, 0, 0, 1, 0, 0, 0, 1}};
        float det = mat3_det(&mat);

        printf("Determinant: %f\n", det);

        MAT3_AT(&mat, 0, 0) = 2;
        det = mat3_det(&mat);
        printf("Determinant: %f\n", det);
    }

    /* Vector tests for addition, subtraction and multiplication (scale) */
    {
        const float tol = 0.01f;

        /* Vec3 */
        Vec3 twos3 = {2, 2, 2};
        Vec3 ones3 = {1, 1, 1};
        Vec3 threes3 = {3, 3, 3};

        Vec3 diff3 = vec3_sub(&twos3, &ones3);
        Vec3 sum3 = vec3_add(&twos3, &ones3);
        Vec3 scaled3 = vec3_scale(&twos3, 1.5f);

        assert(vec3_approx_eq(&diff3, &ones3, tol));
        assert(vec3_approx_eq(&sum3, &threes3, tol));
        assert(vec3_approx_eq(&scaled3, &threes3, tol));

        /* Vec2 */
        Vec2 twos2 = {2, 2};
        Vec2 ones2 = {1, 1};
        Vec2 threes2 = {3, 3};

        Vec2 diff2 = vec2_sub(&twos2, &ones2);
        Vec2 sum2 = vec2_add(&twos2, &ones2);
        Vec2 scaled2 = vec2_scale(&twos2, 1.5f);

        assert(vec2_approx_eq(&diff2, &ones2, tol));
        assert(vec2_approx_eq(&sum2, &threes2, tol));
        assert(vec2_approx_eq(&scaled2, &threes2, tol));
    }

    /* Cross product of two parallel vectors. */
    {
        Vec3 lhs = {10, 10, 10};
        Vec3 rhs = {5, 5, 5};
        Vec3 cross = vec3_cross(&lhs, &rhs);

        printf("{ Vec3: %f, %f, %f }\n", cross.x, cross.y, cross.z);
    }

    /* Cross product of the y and z unit vectors. */
    {
        Vec3 y_axis = {0, 1, 0};
        Vec3 z_axis = {0, 0, 1};
        Vec3 cross = vec3_cross(&y_axis, &z_axis);

        printf("{ Vec3: %f, %f, %f }\n", cross.x, cross.y, cross.z);
    }

    return 0;
}