diff --git a/pkg/training/repository.go b/pkg/training/repository.go index 5fc2313..0fc0c86 100644 --- a/pkg/training/repository.go +++ b/pkg/training/repository.go @@ -2,6 +2,7 @@ package training import ( "sync" + "time" "github.com/shopspring/decimal" "gitlab.mareshq.com/hq/yggdrasil/internal/currency" @@ -248,6 +249,8 @@ type TrainingDateRepository interface { FindByID(id TrainingDateID) (*TrainingDate, error) FindAll() ([]TrainingDate, error) FindAllByTrainingID(trainingID TrainingID) ([]TrainingDate, error) + FindUpcomingByTrainingID(trainingID TrainingID) ([]TrainingDate, error) + FindAllUpcoming() ([]TrainingDate, error) Update(trainingDate *TrainingDate) error Delete(id TrainingDateID) error } @@ -313,6 +316,37 @@ func (r *InMemoryTrainingDateRepository) FindAllByTrainingID(trainingID Training return dates, nil } +func (r *InMemoryTrainingDateRepository) FindUpcomingByTrainingID(trainingID TrainingID) ([]TrainingDate, error) { + r.lock.RLock() + defer r.lock.RUnlock() + + now := time.Now() + var dates []TrainingDate + for _, id := range r.trainingToDates[trainingID] { + date := r.trainingDates[id] + if date.Date.After(now) { + dates = append(dates, date) + } + } + + return dates, nil +} + +func (r *InMemoryTrainingDateRepository) FindAllUpcoming() ([]TrainingDate, error) { + r.lock.RLock() + defer r.lock.RUnlock() + + now := time.Now() + var dates []TrainingDate + for _, date := range r.trainingDates { + if date.Date.After(now) { + dates = append(dates, date) + } + } + + return dates, nil +} + func (r *InMemoryTrainingDateRepository) Update(trainingDate *TrainingDate) error { r.lock.Lock() defer r.lock.Unlock()