diff --git a/backend/internal/database/db.go b/backend/internal/database/db.go index 5cbb13f..82a3551 100644 --- a/backend/internal/database/db.go +++ b/backend/internal/database/db.go @@ -32,6 +32,7 @@ type Database interface { GetUserRole(username string, projectname string) (string, error) GetWeeklyReport(username string, projectName string, week int) (types.WeeklyReport, error) SignWeeklyReport(reportId int, projectManagerId int) error + IsSiteAdmin(username string) (bool, error) } // This struct is a wrapper type that holds the database connection @@ -313,6 +314,26 @@ func (d *Db) SignWeeklyReport(reportId int, projectManagerId int) error { return err } +// IsSiteAdmin checks if a given username is a site admin +func (d *Db) IsSiteAdmin(username string) (bool, error) { + // Define the SQL query to check if the user is a site admin + query := ` + SELECT COUNT(*) FROM site_admin + JOIN users ON site_admin.admin_id = users.id + WHERE users.username = ? + ` + + // Execute the query + var count int + err := d.Get(&count, query, username) + if err != nil { + return false, err + } + + // If count is greater than 0, the user is a site admin + return count > 0, nil +} + // Reads a directory of migration files and applies them to the database. // This will eventually be used on an embedded directory func (d *Db) Migrate() error { diff --git a/backend/internal/database/db_test.go b/backend/internal/database/db_test.go index 09de45b..2378b3d 100644 --- a/backend/internal/database/db_test.go +++ b/backend/internal/database/db_test.go @@ -536,3 +536,46 @@ func TestSignWeeklyReportByAnotherProjectManager(t *testing.T) { t.Error("Expected SignWeeklyReport to fail with a project manager who is not in the project, but it didn't") } } + +func TestIsSiteAdmin(t *testing.T) { + db, err := setupState() + if err != nil { + t.Error("setupState failed:", err) + } + + // Add a site admin + err = db.AddUser("admin", "password") + if err != nil { + t.Error("AddUser failed:", err) + } + + // Promote the user to site admin + err = db.PromoteToAdmin("admin") + if err != nil { + t.Error("PromoteToAdmin failed:", err) + } + + // Check if the user is a site admin + isAdmin, err := db.IsSiteAdmin("admin") + if err != nil { + t.Error("IsSiteAdmin failed:", err) + } + if !isAdmin { + t.Error("IsSiteAdmin failed: expected true, got false") + } + + // Add a regular user + err = db.AddUser("regularuser", "password") + if err != nil { + t.Error("AddUser failed:", err) + } + + // Check if the regular user is not a site admin + isRegularUserAdmin, err := db.IsSiteAdmin("regularuser") + if err != nil { + t.Error("IsSiteAdmin failed:", err) + } + if isRegularUserAdmin { + t.Error("IsSiteAdmin failed: expected false, got true") + } +}