diff --git a/greataped/config/env.go b/greataped/config/env.go index ba338c5..e1e819d 100644 --- a/greataped/config/env.go +++ b/greataped/config/env.go @@ -6,11 +6,11 @@ import ( ) var ( - PROTOCOL = getEnv("PROTOCOL", "http") - DOMAIN = getEnv("DOMAIN", "localhost") - PORT = getEnv("PORT", "80") - + PROTOCOL = getEnv("PROTOCOL", "http") + DOMAIN = getEnv("DOMAIN", "localhost") + PORT = getEnv("PORT", "80") SQLITE_DB = getEnv("SQLITE_DB", "db.sqlite") + // TOKENKEY returns the jwt token secret TOKENKEY = getEnv("TOKEN_KEY", "put-your-secure-jwt-secret-key-here") // TOKENEXP returns the jwt token expiration duration. @@ -21,6 +21,8 @@ var ( // Maximum allowed upload file size in megabytes. MAX_UPLOAD_SIZE = getEnv("MAX_UPLOAD_SIZE", "1") UPLOAD_PATH = getEnv("UPLOAD_PATH", "./upload") + CSRF_PROTECTION = getEnv("CSRF_PROTECTION", "false") + RATE_LIMITER = getEnv("RATE_LIMITER", "false") ) func getEnv(name string, fallback string) string { @@ -31,6 +33,14 @@ func getEnv(name string, fallback string) string { return fallback } +func CsrfProtection() bool { + return CSRF_PROTECTION == "true" +} + +func RateLimiter() bool { + return RATE_LIMITER == "true" +} + func BodyLimit() int { maxFileSize, err := strconv.ParseInt(MAX_UPLOAD_SIZE, 10, 32) if err != nil { diff --git a/greataped/server/http_server.go b/greataped/server/http_server.go index b4a332f..c5582b6 100644 --- a/greataped/server/http_server.go +++ b/greataped/server/http_server.go @@ -8,6 +8,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" + "github.com/gofiber/fiber/v2/middleware/csrf" "github.com/gofiber/fiber/v2/middleware/limiter" "github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/recover" @@ -44,12 +45,6 @@ func New() IServer { cors.New(), recover.New(), helmet.New(), - // csrf.New(), - limiter.New(limiter.Config{ - Max: 20, - Expiration: 30 * time.Second, - LimiterMiddleware: limiter.SlidingWindow{}, - }), logger.New(logger.Config{ Next: nil, Format: "[${time}] ${status} - ${latency} ${method} ${path} ${body}\n", @@ -59,6 +54,20 @@ func New() IServer { }), ) + if config.CsrfProtection() { + framework.Use(csrf.New()) + } + + if config.RateLimiter() { + framework.Use( + limiter.New(limiter.Config{ + Max: 20, + Expiration: 30 * time.Second, + LimiterMiddleware: limiter.SlidingWindow{}, + }), + ) + } + framework.Static("/media", config.UPLOAD_PATH) framework.Group("/api/v1/profile").Use(authorize.New()) // framework.Get("/u/:name/inbox").Use(authorize.New())