diff --git a/api/lock_manager.go b/api/lock_manager.go new file mode 100644 index 00000000..8b6120a3 --- /dev/null +++ b/api/lock_manager.go @@ -0,0 +1,64 @@ +package api + +import ( + "sync" + log "github.com/sirupsen/logrus" +) + +var ( + lockManagerInstance *UserLockManager + once sync.Once +) + +type UserLockManager struct { + locks sync.Map +} + +func GetUserLockManager() *UserLockManager { + once.Do(func() { + lockManagerInstance = &UserLockManager{} + }) + return lockManagerInstance +} + +func (m *UserLockManager) getLock(lockId string) *sync.Mutex { + lock, _ := m.locks.LoadOrStore(lockId, &sync.Mutex{}) + return lock.(*sync.Mutex) +} + +func (m *UserLockManager) TryLock(lockId string) bool { + lock := m.getLock(lockId) + return lock.TryLock() +} + +func (m *UserLockManager) Unlock(userID string) { + lockVal, exists := m.locks.Load(userID) + if !exists { + log.Errorf("unlock called for non-existent lock, user_id: %s", userID) + return + } + + lock := lockVal.(*sync.Mutex) + + // Catch panic from double-unlock + defer func() { + if r := recover(); r != nil { + log.Errorf("double unlock detected for user_id: %s", userID) + } + }() + + lock.Unlock() +} + +// Run a scheduler for periodically cleanup locks for inactive users +func (m *UserLockManager) Cleanup() { + m.locks.Range(func(key, value interface{}) bool { + mutex := value.(*sync.Mutex) + // Only delete if mutex is unlocked (not in use) + if mutex.TryLock() { + mutex.Unlock() + m.locks.Delete(key) + } + return true + }) +} diff --git a/api/submit.go b/api/submit.go index a37f3fde..166b23da 100644 --- a/api/submit.go +++ b/api/submit.go @@ -1,6 +1,7 @@ package api import ( + "fmt" "math" "net/http" "strconv" @@ -96,7 +97,6 @@ func submitFlagHandler(c *gin.Context) { }) return } - chall, err := database.QueryChallengeEntries("id", strconv.Itoa(int(parsedChallId))) if err != nil { c.JSON(http.StatusInternalServerError, HTTPErrorResp{ @@ -105,6 +105,24 @@ func submitFlagHandler(c *gin.Context) { return } + if len(chall) == 0 { + c.JSON(http.StatusBadRequest, HTTPErrorResp{ + Error: "Challenge not found.", + }) + return + } + + lockManager := GetUserLockManager() + lockId := fmt.Sprintf("%d_%s", user.ID, challId) + if !lockManager.TryLock(lockId) { + c.JSON(http.StatusOK, FlagSubmitResp{ + Message: "Request already in process", + Success: false, + }) + return + } + defer lockManager.Unlock(lockId) + challenge := chall[0] if challenge.Status != core.DEPLOY_STATUS["deployed"] { c.JSON(http.StatusOK, FlagSubmitResp{ @@ -116,14 +134,12 @@ func submitFlagHandler(c *gin.Context) { if challenge.PreReqs != "" { preReqsStatus, err := database.CheckPreReqsStatus(challenge, user.ID) - if err != nil { c.JSON(http.StatusInternalServerError, HTTPErrorResp{ Error: "DATABASE ERROR while processing the request.", }) return } - if !preReqsStatus { c.JSON(http.StatusOK, FlagSubmitResp{ Message: "You have not solved the prerequisites of this challenge.", @@ -133,32 +149,6 @@ func submitFlagHandler(c *gin.Context) { } } - if challenge.MaxAttemptLimit > 0 { - previousTries, err := database.GetUserPreviousTries(user.ID, challenge.ID) - if err != nil { - c.JSON(http.StatusInternalServerError, HTTPErrorResp{ - Error: "DATABASE ERROR while processing the request."}) - return - } - - if previousTries >= challenge.MaxAttemptLimit { - c.JSON(http.StatusOK, FlagSubmitResp{ - Message: "You have reached the maximum number of tries for this challenge.", - Success: false, - }) - return - } - } - - // Increase user tries by 1 - err = database.UpdateUserChallengeTries(user.ID, challenge.ID) - - if err != nil { - c.JSON(http.StatusInternalServerError, HTTPErrorResp{ - Error: "DATABASE ERROR while processing the request.", - }) - return - } solved, err := database.CheckPreviousSubmissions(user.ID, challenge.ID) if err != nil { c.JSON(http.StatusInternalServerError, HTTPErrorResp{ @@ -166,7 +156,6 @@ func submitFlagHandler(c *gin.Context) { }) return } - if solved { c.JSON(http.StatusOK, FlagSubmitResp{ Message: "Challenge has already been solved.", @@ -175,6 +164,29 @@ func submitFlagHandler(c *gin.Context) { return } + var previousTries int + if challenge.MaxAttemptLimit > 0 { + previousTries, err = database.GetUserPreviousTries(user.ID, challenge.ID) + if err != nil { + c.JSON(http.StatusInternalServerError, HTTPErrorResp{ + Error: "DATABASE ERROR while processing the request."}) + return + } + if previousTries >= challenge.MaxAttemptLimit { + c.JSON(http.StatusOK, FlagSubmitResp{ + Message: "You have reached the maximum number of tries for this challenge.", + Success: false, + }) + return + } + err = database.UpdateUserChallengeTries(user.ID, challenge.ID) + if err != nil { + c.JSON(http.StatusInternalServerError, HTTPErrorResp{ + Error: "DATABASE ERROR while processing the request.", + }) + return + } + } // If the challenge is dynamic, then the flag is not stored in the database if challenge.DynamicFlag { whereMap := map[string]interface{}{ @@ -191,6 +203,20 @@ func submitFlagHandler(c *gin.Context) { // flag not present in validFlags table if len(validFlags) == 0 { + UserChallengesEntry := database.UserChallenges{ + CreatedAt: time.Now(), + UserID: user.ID, + ChallengeID: challenge.ID, + Solved: false, + Flag: flag, + } + err = database.SaveFlagSubmission(&UserChallengesEntry) + if err != nil { + c.JSON(http.StatusInternalServerError, HTTPErrorResp{ + Error: "DATABASE ERROR while processing the request.", + }) + return + } c.JSON(http.StatusOK, FlagSubmitResp{ Message: "Your flag is incorrect", Success: false, diff --git a/core/database/challenges.go b/core/database/challenges.go index dda0a300..36ec7159 100644 --- a/core/database/challenges.go +++ b/core/database/challenges.go @@ -499,7 +499,17 @@ func SaveFlagSubmission(user_challenges *UserChallenges) error { return fmt.Errorf("error while saving record: %s", tx.Error) } - if err := tx.FirstOrCreate(user_challenges, *user_challenges).Error; err != nil { + var solvedChallenge UserChallenges + querySolved := tx.Where("user_id = ? AND challenge_id = ? AND solved = ?", user_challenges.UserID, user_challenges.ChallengeID, true).First(&solvedChallenge) + if querySolved.Error == nil { + tx.Rollback() + return fmt.Errorf("already solved: user_id %v, challenge_id %v", user_challenges.UserID, user_challenges.ChallengeID) + } else if querySolved.Error != nil && !errors.Is(querySolved.Error, gorm.ErrRecordNotFound) { + tx.Rollback() + return querySolved.Error + } + + if err := tx.Create(user_challenges).Error; err != nil { tx.Rollback() return err }