diff --git a/pkg/training/repository.go b/pkg/training/repository.go index 642b8a3..2e4fce0 100644 --- a/pkg/training/repository.go +++ b/pkg/training/repository.go @@ -1,6 +1,8 @@ package training import ( + "sync" + "github.com/shopspring/decimal" "gitlab.mareshq.com/hq/yggdrasil/internal/currency" ) @@ -15,6 +17,7 @@ type TrainingRepository interface { type InMemoryTrainingRepository struct { trainings map[TrainingID]Training + lock sync.RWMutex } func NewInMemoryTrainingRepository() *InMemoryTrainingRepository { @@ -189,12 +192,18 @@ func NewInMemoryTrainingRepository() *InMemoryTrainingRepository { } func (r *InMemoryTrainingRepository) Create(training *Training) error { + r.lock.Lock() + defer r.lock.Unlock() + training.ID = NewTrainingID() r.trainings[training.ID] = *training return nil } func (r *InMemoryTrainingRepository) FindByID(id TrainingID) (*Training, error) { + r.lock.RLock() + defer r.lock.RUnlock() + training, ok := r.trainings[id] if !ok { return nil, ErrTrainingNotFound @@ -203,6 +212,8 @@ func (r *InMemoryTrainingRepository) FindByID(id TrainingID) (*Training, error) } func (r *InMemoryTrainingRepository) FindAll() ([]Training, error) { + r.lock.RLock() + defer r.lock.RUnlock() trainings := make([]Training, 0, len(r.trainings)) for _, training := range r.trainings { @@ -212,11 +223,17 @@ func (r *InMemoryTrainingRepository) FindAll() ([]Training, error) { } func (r *InMemoryTrainingRepository) Update(training *Training) error { + r.lock.Lock() + defer r.lock.Unlock() + r.trainings[training.ID] = *training return nil } func (r *InMemoryTrainingRepository) Delete(id TrainingID) error { + r.lock.Lock() + defer r.lock.Unlock() + _, ok := r.trainings[id] if !ok { return ErrTrainingNotFound