1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-22 22:23:42 +00:00
fiber/middleware/cors.go
2020-03-14 12:30:21 +01:00

175 lines
4.2 KiB
Go

package middleware
import (
"log"
"net/http"
"strconv"
"strings"
"github.com/gofiber/fiber"
)
// CORSConfig ...
type CORSConfig struct {
Skip func(*fiber.Ctx) bool
// Optional. Default value []string{"*"}.
AllowOrigins []string
// Optional. Default value []string{"GET","POST","HEAD","PUT","DELETE","PATCH"}
AllowMethods []string
// Optional. Default value []string{}.
AllowHeaders []string
// Optional. Default value false.
AllowCredentials bool
// Optional. Default value []string{}.
ExposeHeaders []string
// Optional. Default value 0.
MaxAge int
}
// CorsConfigDefault is the defaul Cors middleware config.
var CorsConfigDefault = CORSConfig{
Skip: nil,
AllowOrigins: []string{"*"},
AllowMethods: []string{
http.MethodGet,
http.MethodPost,
http.MethodHead,
http.MethodPut,
http.MethodDelete,
http.MethodPatch,
},
}
// Cors ...
func Cors(config ...CORSConfig) func(*fiber.Ctx) {
log.Println("Warning: middleware.Cors() is deprecated since v1.8.3, please use github.com/gofiber/cors")
// Init config
var cfg CORSConfig
// Set config if provided
if len(config) > 0 {
cfg = config[0]
}
// Set config default values
if len(cfg.AllowOrigins) == 0 {
cfg.AllowOrigins = CorsConfigDefault.AllowOrigins
}
if len(cfg.AllowMethods) == 0 {
cfg.AllowMethods = CorsConfigDefault.AllowMethods
}
// Middleware settings
allowMethods := strings.Join(cfg.AllowMethods, ",")
allowHeaders := strings.Join(cfg.AllowHeaders, ",")
exposeHeaders := strings.Join(cfg.ExposeHeaders, ",")
maxAge := strconv.Itoa(cfg.MaxAge)
// Middleware function
return func(c *fiber.Ctx) {
// Skip middleware if Skip returns true
if cfg.Skip != nil && cfg.Skip(c) {
c.Next()
return
}
origin := c.Get(fiber.HeaderOrigin)
allowOrigin := ""
// Check allowed origins
for _, o := range cfg.AllowOrigins {
if o == "*" && cfg.AllowCredentials {
allowOrigin = origin
break
}
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
// Simple request
if c.Method() != http.MethodOptions {
c.Vary(fiber.HeaderOrigin)
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
if cfg.AllowCredentials {
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
if exposeHeaders != "" {
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
}
c.Next()
return
}
// Preflight request
c.Vary(fiber.HeaderOrigin)
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
if cfg.AllowCredentials {
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
if allowHeaders != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
if h != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
}
}
if cfg.MaxAge > 0 {
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
}
c.SendStatus(204) // No Content
}
}
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}
// matchSubdomain compares authority with wildcard
func matchSubdomain(domain, pattern string) bool {
if !matchScheme(domain, pattern) {
return false
}
didx := strings.Index(domain, "://")
pidx := strings.Index(pattern, "://")
if didx == -1 || pidx == -1 {
return false
}
domAuth := domain[didx+3:]
// to avoid long loop by invalid long domain
if len(domAuth) > 253 {
return false
}
patAuth := pattern[pidx+3:]
domComp := strings.Split(domAuth, ".")
patComp := strings.Split(patAuth, ".")
for i := len(domComp)/2 - 1; i >= 0; i-- {
opp := len(domComp) - 1 - i
domComp[i], domComp[opp] = domComp[opp], domComp[i]
}
for i := len(patComp)/2 - 1; i >= 0; i-- {
opp := len(patComp) - 1 - i
patComp[i], patComp[opp] = patComp[opp], patComp[i]
}
for i, v := range domComp {
if len(patComp) <= i {
return false
}
p := patComp[i]
if p == "*" {
return true
}
if p != v {
return false
}
}
return false
}