1
0
Fork 0

refactor(training): InMemoryTrainingRepository uses locking internally

This commit is contained in:
Vojtěch Mareš 2024-05-05 15:42:48 +02:00
parent e388b8cb64
commit 6f5a91a92d
Signed by: vojtech.mares
GPG key ID: C6827B976F17240D

View file

@ -1,6 +1,8 @@
package training package training
import ( import (
"sync"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
"gitlab.mareshq.com/hq/yggdrasil/internal/currency" "gitlab.mareshq.com/hq/yggdrasil/internal/currency"
) )
@ -15,6 +17,7 @@ type TrainingRepository interface {
type InMemoryTrainingRepository struct { type InMemoryTrainingRepository struct {
trainings map[TrainingID]Training trainings map[TrainingID]Training
lock sync.RWMutex
} }
func NewInMemoryTrainingRepository() *InMemoryTrainingRepository { func NewInMemoryTrainingRepository() *InMemoryTrainingRepository {
@ -189,12 +192,18 @@ func NewInMemoryTrainingRepository() *InMemoryTrainingRepository {
} }
func (r *InMemoryTrainingRepository) Create(training *Training) error { func (r *InMemoryTrainingRepository) Create(training *Training) error {
r.lock.Lock()
defer r.lock.Unlock()
training.ID = NewTrainingID() training.ID = NewTrainingID()
r.trainings[training.ID] = *training r.trainings[training.ID] = *training
return nil return nil
} }
func (r *InMemoryTrainingRepository) FindByID(id TrainingID) (*Training, error) { func (r *InMemoryTrainingRepository) FindByID(id TrainingID) (*Training, error) {
r.lock.RLock()
defer r.lock.RUnlock()
training, ok := r.trainings[id] training, ok := r.trainings[id]
if !ok { if !ok {
return nil, ErrTrainingNotFound return nil, ErrTrainingNotFound
@ -203,6 +212,8 @@ func (r *InMemoryTrainingRepository) FindByID(id TrainingID) (*Training, error)
} }
func (r *InMemoryTrainingRepository) FindAll() ([]Training, error) { func (r *InMemoryTrainingRepository) FindAll() ([]Training, error) {
r.lock.RLock()
defer r.lock.RUnlock()
trainings := make([]Training, 0, len(r.trainings)) trainings := make([]Training, 0, len(r.trainings))
for _, training := range r.trainings { for _, training := range r.trainings {
@ -212,11 +223,17 @@ func (r *InMemoryTrainingRepository) FindAll() ([]Training, error) {
} }
func (r *InMemoryTrainingRepository) Update(training *Training) error { func (r *InMemoryTrainingRepository) Update(training *Training) error {
r.lock.Lock()
defer r.lock.Unlock()
r.trainings[training.ID] = *training r.trainings[training.ID] = *training
return nil return nil
} }
func (r *InMemoryTrainingRepository) Delete(id TrainingID) error { func (r *InMemoryTrainingRepository) Delete(id TrainingID) error {
r.lock.Lock()
defer r.lock.Unlock()
_, ok := r.trainings[id] _, ok := r.trainings[id]
if !ok { if !ok {
return ErrTrainingNotFound return ErrTrainingNotFound