#include "hashmap.h"
#include <stdlib.h>
#include <string.h>

/*
 * One key/value node in a bucket's singly linked collision chain.
 */
typedef struct hashmap_entry {
    char *key;                  /* heap-allocated copy of the caller's key; freed by the map */
    void *value;                /* caller-supplied pointer; the map never frees it */
    struct hashmap_entry *next; /* next entry in the same bucket, or NULL at chain end */
} hashmap_entry_t;

/*
 * Separate-chaining hash table: buckets[i] heads the chain of entries whose
 * key hash reduces to i modulo bucket_count.
 * NOTE(review): hashmap_t is presumably the typedef for this struct declared
 * in hashmap.h — confirm.
 */
struct hashmap {
    hashmap_entry_t **buckets; /* array of bucket_count chain heads; NULL means empty bucket */
    size_t bucket_count;       /* number of buckets; set at creation, nothing in this file resizes */
};

/*
 * djb2 string hash (Dan Bernstein): h = h * 33 + byte, seeded with 5381.
 *
 * Bytes are read as unsigned char so the result does not depend on whether
 * plain `char` is signed on the target; with a signed char, bytes >= 0x80
 * would otherwise be mixed in as negative values. Overflow of the unsigned
 * accumulator wraps, which is well-defined.
 *
 * @param str  non-NULL, NUL-terminated string
 * @return     hash value (caller reduces it modulo the bucket count)
 */
static unsigned long hash(const char *str) {
    const unsigned char *p = (const unsigned char *)str;
    unsigned long h = 5381;
    unsigned int c;

    while ((c = *p++) != 0) {
        h = ((h << 5) + h) + c; /* h * 33 + c */
    }
    return h;
}

/*
 * Allocate an empty map with `bucket_count` chains.
 *
 * A bucket_count of 0 is treated as 1: every operation reduces the key hash
 * modulo bucket_count, and a modulus of zero is undefined behavior.
 *
 * @return the new map, or NULL if memory could not be allocated. The caller
 *         owns the map and releases it with hashmap_destroy().
 */
hashmap_t *hashmap_create(size_t bucket_count) {
    if (bucket_count == 0) {
        bucket_count = 1;
    }

    hashmap_t *map = malloc(sizeof *map);
    if (map == NULL) {
        return NULL;
    }

    /* calloc zero-fills, so every bucket starts as an empty (NULL) chain. */
    map->buckets = calloc(bucket_count, sizeof *map->buckets);
    if (map->buckets == NULL) {
        free(map);
        return NULL;
    }

    map->bucket_count = bucket_count;
    return map;
}

/*
 * Free the map, every entry in it, and the key copies the entries hold.
 *
 * Stored values are NOT freed: the map only keeps the caller's pointers, so
 * the caller remains responsible for them. Passing NULL is a no-op, mirroring
 * free().
 */
void hashmap_destroy(hashmap_t *map) {
    if (map == NULL) {
        return;
    }

    for (size_t i = 0; i < map->bucket_count; ++i) {
        hashmap_entry_t *entry = map->buckets[i];
        while (entry != NULL) {
            hashmap_entry_t *next = entry->next; /* read before entry is freed */
            free(entry->key);
            free(entry);
            entry = next;
        }
    }

    free(map->buckets);
    free(map);
}

/*
 * Heap-allocate a copy of the NUL-terminated string `s`.
 * Same result as POSIX strdup(), which is not part of ISO C11 and may be
 * undeclared under strict -std=c11. Returns NULL on allocation failure.
 */
static char *hashmap_dup_key(const char *s) {
    size_t len = strlen(s) + 1; /* include the terminating NUL */
    char *copy = malloc(len);
    if (copy != NULL) {
        memcpy(copy, s, len);
    }
    return copy;
}

/*
 * Insert `key` -> `value`, or overwrite the value if `key` is already present.
 *
 * The map stores its own heap copy of `key`. `value` is stored as-is and is
 * never freed by the map. `key` must be a non-NULL, NUL-terminated string.
 *
 * If memory for a new entry cannot be allocated, the map is left unchanged.
 * The function returns void, so this failure cannot be reported to the caller.
 */
void hashmap_put(hashmap_t *map, const char *key, void *value) {
    size_t index = hash(key) % map->bucket_count;

    /* Existing key: replace the value in place, keep the stored key copy. */
    for (hashmap_entry_t *entry = map->buckets[index]; entry != NULL; entry = entry->next) {
        if (strcmp(entry->key, key) == 0) {
            entry->value = value;
            return;
        }
    }

    hashmap_entry_t *new_entry = malloc(sizeof *new_entry);
    if (new_entry == NULL) {
        return;
    }

    new_entry->key = hashmap_dup_key(key);
    if (new_entry->key == NULL) {
        free(new_entry);
        return;
    }

    new_entry->value = value;
    /* Push onto the head of the chain: O(1) insertion. */
    new_entry->next = map->buckets[index];
    map->buckets[index] = new_entry;
}

/*
 * Return the value stored under `key`, or NULL if the key is absent.
 *
 * A stored value may itself be NULL, so a NULL result alone does not tell
 * "absent" apart from "present with a NULL value". `key` must be a non-NULL,
 * NUL-terminated string.
 */
void *hashmap_get(hashmap_t *map, const char *key) {
    const size_t slot = hash(key) % map->bucket_count;

    for (const hashmap_entry_t *node = map->buckets[slot]; node != NULL; node = node->next) {
        if (!strcmp(node->key, key)) {
            return node->value;
        }
    }
    return NULL;
}

/*
 * Remove `key` from the map if present, freeing its entry and key copy.
 * The stored value is not freed. Removing an absent key does nothing.
 */
void hashmap_remove(hashmap_t *map, const char *key) {
    hashmap_entry_t **link = &map->buckets[hash(key) % map->bucket_count];

    /* Advance via the pointer that refers to each node, so unlinking the
     * chain head and an interior node are the same operation. */
    while (*link != NULL && strcmp((*link)->key, key) != 0) {
        link = &(*link)->next;
    }

    if (*link == NULL) {
        return; /* key not found */
    }

    hashmap_entry_t *victim = *link;
    *link = victim->next;
    free(victim->key);
    free(victim);
}