diff --git a/go.mod b/go.mod index 2aff859..431b52f 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/oapi-codegen/fiber-middleware v1.0.1 github.com/oapi-codegen/runtime v1.1.1 github.com/shopspring/decimal v1.4.0 + github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 ) diff --git a/internal/server/server.go b/internal/server/server.go index 8b9963e..059f6b9 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -27,8 +27,41 @@ func NewServer(apiHandlers *APIHandlers, port int, logger *zap.Logger) *Server { } func (s *Server) Run(ctx context.Context) { - sugared := s.logger.Sugar() + middlewares := make([]fiber.Handler, 0) + loggingMiddleware := fiberzap.New(fiberzap.Config{ + Logger: s.logger, + }) + middlewares = append(middlewares, loggingMiddleware) + + app := setupServer(s.apiHandlers, middlewares, s.logger, false) + + go func() { + fiberErr := app.Listen(fmt.Sprintf(":%d", s.port)) + if fiberErr != nil { + panic(fiberErr) + } + }() + + s.logger.Info("HTTP server is now running", zap.Int("port", s.port)) + + <-ctx.Done() + shutdownBegan := time.Now() + + s.logger.Info("Shutting HTTP server down gracefully...") + err := app.ShutdownWithTimeout(10 * time.Second) + if err != nil { + panic(err) + } + + s.logger.Info("HTTP server shut down gracefully.", zap.Duration("duration", time.Since(shutdownBegan))) +} + +func NewTestingServer(apiHandlers *APIHandlers) *fiber.App { + return setupServer(apiHandlers, nil, zap.NewNop(), true) +} + +func setupServer(apiHandlers *APIHandlers, middlewares []fiber.Handler, logger *zap.Logger, testing bool) *fiber.App { config := fiber.Config{ DisableStartupMessage: true, @@ -39,9 +72,9 @@ func (s *Server) Run(ctx context.Context) { app := fiber.New(config) - app.Use(fiberzap.New(fiberzap.Config{ - Logger: s.logger, - })) + for _, mw := range middlewares { + app.Use(mw) + } swaggerConfig := swagger.Config{ BasePath: "/", @@ -49,42 +82,32 @@ func (s *Server) Run(ctx context.Context) { Path: "/swagger", Title: "Swagger API Docs", } + + if testing { + swaggerConfig = swagger.Config{ + BasePath: "/", + FilePath: "../../api/v1/openapi.yaml", + Path: "/swagger", + Title: "Swagger API Docs", + } + } + app.Use(swagger.New(swaggerConfig)) - swagger, err := GetSwagger() + serverSwaggerSpec, err := GetSwagger() if err != nil { - sugared.Fatal("Error getting swagger", zap.Error(err)) + logger.Fatal("Error getting swagger", zap.Error(err)) } // See: https://github.com/deepmap/oapi-codegen/blob/master/examples/petstore-expanded/fiber/petstore.go#L41 // Clear out the servers array in the swagger spec, that skips validating // that server names match. We don't know how this thing will be run. - swagger.Servers = nil + serverSwaggerSpec.Servers = nil - app.Use(middleware.OapiRequestValidator(swagger)) + app.Use(middleware.OapiRequestValidator(serverSwaggerSpec)) - handlers := NewStrictHandler(s.apiHandlers, nil) + handlers := NewStrictHandler(apiHandlers, nil) RegisterHandlers(app, handlers) - go func() { - err := app.Listen(fmt.Sprintf(":%d", s.port)) - if err != nil { - panic(err) - } - }() - - sugared.Infow("HTTP server is now running", "port", s.port) - - <-ctx.Done() - shutdownBegan := time.Now() - - sugared.Infoln("Shutting HTTP server down gracefully...") - err = app.ShutdownWithTimeout(10 * time.Second) - if err != nil { - panic(err) - } - - // Cleanup code here - - sugared.Infow("HTTP server shut down gracefully.", "took", time.Since(shutdownBegan)) + return app } diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..22133a8 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,124 @@ +package server + +import ( + "bytes" + "encoding/json" + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" + "gitlab.mareshq.com/hq/yggdrasil/pkg/training" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func doGet(t *testing.T, app *fiber.App, rawURL string) (*http.Response, error) { + + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("Invalid url: %s", rawURL) + } + + req := httptest.NewRequest("GET", u.RequestURI(), nil) + req.Header.Add("Accept", "application/json") + req.Host = u.Host + + return app.Test(req) +} + +func doPost(t *testing.T, app *fiber.App, rawURL string, jsonBody interface{}) (*http.Response, error) { + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("Invalid url: %s", rawURL) + } + + buf, err := json.Marshal(jsonBody) + if err != nil { + return nil, err + } + req := httptest.NewRequest("POST", u.RequestURI(), bytes.NewReader(buf)) + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/json") + req.Host = u.Host + return app.Test(req) +} + +func doPut(t *testing.T, app *fiber.App, rawURL string, jsonBody interface{}) (*http.Response, error) { + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("Invalid url: %s", rawURL) + } + + buf, err := json.Marshal(jsonBody) + if err != nil { + return nil, err + } + req := httptest.NewRequest("PUT", u.RequestURI(), bytes.NewReader(buf)) + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/json") + req.Host = u.Host + return app.Test(req) +} + +func doDelete(t *testing.T, app *fiber.App, rawURL string) (*http.Response, error) { + u, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("Invalid url: %s", rawURL) + } + + req := httptest.NewRequest("DELETE", u.RequestURI(), nil) + req.Header.Add("Accept", "application/json") + req.Host = u.Host + return app.Test(req) +} + +func TestServer(t *testing.T) { + //var err error + trainingRepo := training.NewInMemoryTrainingRepository() + trainingDateRepo := training.NewInMemoryTrainingDateRepository() + trainingDateAttendeeRepo := training.NewInMemoryTrainingDateAttendeeRepository() + + handlers := NewAPIHandlers(trainingRepo, trainingDateRepo, trainingDateAttendeeRepo) + app := NewTestingServer(handlers) + + t.Run("Add training", func(t *testing.T) { + newTraining := NewTraining{ + Name: "Testing Training", + Description: "This is a test training", + Days: 1, + Pricing: []TrainingPrice{ + //{ + // Amount: decimal.NewFromInt(200), + // Currency: "EUR", + // Type: OPEN, + //}, + //{ + // Amount: decimal.NewFromInt(1000), + // Currency: "EUR", + // Type: CORPORATE, + //}, + //{ + // Amount: decimal.NewFromInt(4900), + // Currency: "CZK", + // Type: OPEN, + //}, + //{ + // Amount: decimal.NewFromInt(24000), + // Currency: "CZK", + // Type: CORPORATE, + //}, + }, + } + + rr, _ := doPost(t, app, "/v1/trainings", newTraining) + assert.Equal(t, http.StatusCreated, rr.StatusCode) + + var resultTraining Training + err := json.NewDecoder(rr.Body).Decode(&resultTraining) + assert.NoError(t, err, "error unmarshalling response") + assert.Equal(t, newTraining.Name, resultTraining.Name) + assert.Equal(t, newTraining.Description, resultTraining.Description) + assert.Equal(t, newTraining.Days, resultTraining.Days) + //assert.Equal(t, newTraining.Pricing, resultTraining.Pricing) + }) +}