mirror of
https://github.com/gofiber/fiber.git
synced 2025-02-22 22:23:42 +00:00
175 lines
4.2 KiB
Go
175 lines
4.2 KiB
Go
package middleware
|
|
|
|
import (
|
|
"log"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/gofiber/fiber"
|
|
)
|
|
|
|
// CORSConfig ...
|
|
type CORSConfig struct {
|
|
Skip func(*fiber.Ctx) bool
|
|
// Optional. Default value []string{"*"}.
|
|
AllowOrigins []string
|
|
// Optional. Default value []string{"GET","POST","HEAD","PUT","DELETE","PATCH"}
|
|
AllowMethods []string
|
|
// Optional. Default value []string{}.
|
|
AllowHeaders []string
|
|
// Optional. Default value false.
|
|
AllowCredentials bool
|
|
// Optional. Default value []string{}.
|
|
ExposeHeaders []string
|
|
// Optional. Default value 0.
|
|
MaxAge int
|
|
}
|
|
|
|
// CorsConfigDefault is the defaul Cors middleware config.
|
|
var CorsConfigDefault = CORSConfig{
|
|
Skip: nil,
|
|
AllowOrigins: []string{"*"},
|
|
AllowMethods: []string{
|
|
http.MethodGet,
|
|
http.MethodPost,
|
|
http.MethodHead,
|
|
http.MethodPut,
|
|
http.MethodDelete,
|
|
http.MethodPatch,
|
|
},
|
|
}
|
|
|
|
// Cors ...
|
|
func Cors(config ...CORSConfig) func(*fiber.Ctx) {
|
|
log.Println("Warning: middleware.Cors() is deprecated since v1.8.3, please use github.com/gofiber/cors")
|
|
// Init config
|
|
var cfg CORSConfig
|
|
// Set config if provided
|
|
if len(config) > 0 {
|
|
cfg = config[0]
|
|
}
|
|
// Set config default values
|
|
if len(cfg.AllowOrigins) == 0 {
|
|
cfg.AllowOrigins = CorsConfigDefault.AllowOrigins
|
|
}
|
|
if len(cfg.AllowMethods) == 0 {
|
|
cfg.AllowMethods = CorsConfigDefault.AllowMethods
|
|
}
|
|
// Middleware settings
|
|
allowMethods := strings.Join(cfg.AllowMethods, ",")
|
|
allowHeaders := strings.Join(cfg.AllowHeaders, ",")
|
|
exposeHeaders := strings.Join(cfg.ExposeHeaders, ",")
|
|
maxAge := strconv.Itoa(cfg.MaxAge)
|
|
// Middleware function
|
|
return func(c *fiber.Ctx) {
|
|
// Skip middleware if Skip returns true
|
|
if cfg.Skip != nil && cfg.Skip(c) {
|
|
c.Next()
|
|
return
|
|
}
|
|
origin := c.Get(fiber.HeaderOrigin)
|
|
allowOrigin := ""
|
|
// Check allowed origins
|
|
for _, o := range cfg.AllowOrigins {
|
|
if o == "*" && cfg.AllowCredentials {
|
|
allowOrigin = origin
|
|
break
|
|
}
|
|
if o == "*" || o == origin {
|
|
allowOrigin = o
|
|
break
|
|
}
|
|
if matchSubdomain(origin, o) {
|
|
allowOrigin = origin
|
|
break
|
|
}
|
|
}
|
|
// Simple request
|
|
if c.Method() != http.MethodOptions {
|
|
c.Vary(fiber.HeaderOrigin)
|
|
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
|
|
|
if cfg.AllowCredentials {
|
|
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
|
|
}
|
|
if exposeHeaders != "" {
|
|
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
|
|
}
|
|
c.Next()
|
|
return
|
|
}
|
|
// Preflight request
|
|
c.Vary(fiber.HeaderOrigin)
|
|
c.Vary(fiber.HeaderAccessControlRequestMethod)
|
|
c.Vary(fiber.HeaderAccessControlRequestHeaders)
|
|
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
|
|
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
|
|
|
|
if cfg.AllowCredentials {
|
|
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
|
|
}
|
|
if allowHeaders != "" {
|
|
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
|
|
} else {
|
|
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
|
|
if h != "" {
|
|
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
|
|
}
|
|
}
|
|
if cfg.MaxAge > 0 {
|
|
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
|
|
}
|
|
c.SendStatus(204) // No Content
|
|
}
|
|
}
|
|
|
|
func matchScheme(domain, pattern string) bool {
|
|
didx := strings.Index(domain, ":")
|
|
pidx := strings.Index(pattern, ":")
|
|
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
|
|
}
|
|
|
|
// matchSubdomain compares authority with wildcard
|
|
func matchSubdomain(domain, pattern string) bool {
|
|
if !matchScheme(domain, pattern) {
|
|
return false
|
|
}
|
|
didx := strings.Index(domain, "://")
|
|
pidx := strings.Index(pattern, "://")
|
|
if didx == -1 || pidx == -1 {
|
|
return false
|
|
}
|
|
domAuth := domain[didx+3:]
|
|
// to avoid long loop by invalid long domain
|
|
if len(domAuth) > 253 {
|
|
return false
|
|
}
|
|
patAuth := pattern[pidx+3:]
|
|
|
|
domComp := strings.Split(domAuth, ".")
|
|
patComp := strings.Split(patAuth, ".")
|
|
for i := len(domComp)/2 - 1; i >= 0; i-- {
|
|
opp := len(domComp) - 1 - i
|
|
domComp[i], domComp[opp] = domComp[opp], domComp[i]
|
|
}
|
|
for i := len(patComp)/2 - 1; i >= 0; i-- {
|
|
opp := len(patComp) - 1 - i
|
|
patComp[i], patComp[opp] = patComp[opp], patComp[i]
|
|
}
|
|
|
|
for i, v := range domComp {
|
|
if len(patComp) <= i {
|
|
return false
|
|
}
|
|
p := patComp[i]
|
|
if p == "*" {
|
|
return true
|
|
}
|
|
if p != v {
|
|
return false
|
|
}
|
|
}
|
|
return false
|
|
}
|