2024-03-29 14:37:22 +01:00
|
|
|
package database
|
|
|
|
|
2024-04-14 09:21:34 +02:00
|
|
|
import (
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
|
|
"github.com/gofiber/fiber/v2/log"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
|
|
)
|
2024-03-29 14:37:22 +01:00
|
|
|
|
2024-04-14 09:21:34 +02:00
|
|
|
// Simple middleware that provides a transaction as a local key "db"
|
|
|
|
func DbMiddleware(db *sqlx.DB) func(c *fiber.Ctx) error {
|
2024-03-29 14:37:22 +01:00
|
|
|
return func(c *fiber.Ctx) error {
|
2024-04-14 09:21:34 +02:00
|
|
|
tx := db.MustBegin()
|
|
|
|
|
|
|
|
defer func() {
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
|
|
if err = tx.Rollback(); err != nil {
|
|
|
|
log.Error("Failed to rollback transaction: ", err)
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
var db_iface Database = &Db{tx}
|
|
|
|
|
|
|
|
c.Locals("db", &db_iface)
|
2024-03-29 14:37:22 +01:00
|
|
|
return c.Next()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Helper function to get the database from the context, without fiddling with casts
|
|
|
|
func GetDb(c *fiber.Ctx) Database {
|
|
|
|
// Dereference a pointer to a local, casted to a pointer to a Database
|
|
|
|
return *c.Locals("db").(*Database)
|
|
|
|
}
|