1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-19 13:07:52 +00:00

196 lines
5.1 KiB
Go
Raw Normal View History

2020-09-13 11:20:11 +02:00
package cors
import (
"log"
2020-09-13 11:20:11 +02:00
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// AllowOriginsFunc defines a function that will set the 'access-control-allow-origin'
// response header to the 'origin' request header when returned true.
//
// Optional. Default: nil
AllowOriginsFunc func(origin string) bool
2020-09-13 11:20:11 +02:00
// AllowOrigin defines a list of origins that may access the resource.
//
// Optional. Default value "*"
AllowOrigins string
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
//
// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
AllowMethods string
// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
//
// Optional. Default value "".
AllowHeaders string
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
//
// Optional. Default value false.
AllowCredentials bool
// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
//
// Optional. Default value "".
ExposeHeaders string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
//
// Optional. Default value 0.
MaxAge int
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
AllowOriginsFunc: nil,
AllowOrigins: "*",
2020-09-13 11:20:11 +02:00
AllowMethods: strings.Join([]string{
fiber.MethodGet,
fiber.MethodPost,
fiber.MethodHead,
fiber.MethodPut,
fiber.MethodDelete,
fiber.MethodPatch,
}, ","),
AllowHeaders: "",
AllowCredentials: false,
ExposeHeaders: "",
MaxAge: 0,
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if cfg.AllowMethods == "" {
cfg.AllowMethods = ConfigDefault.AllowMethods
}
2020-09-16 15:45:59 +08:00
if cfg.AllowOrigins == "" {
cfg.AllowOrigins = ConfigDefault.AllowOrigins
}
2020-09-13 11:20:11 +02:00
}
// Warning logs if both AllowOrigins and AllowOriginsFunc are set
if cfg.AllowOrigins != ConfigDefault.AllowOrigins && cfg.AllowOriginsFunc != nil {
log.Printf("[Warning] - [CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.\n")
}
2020-09-13 11:20:11 +02:00
// Convert string to slice
allowOrigins := strings.Split(strings.ReplaceAll(cfg.AllowOrigins, " ", ""), ",")
2020-09-13 11:20:11 +02:00
// Strip white spaces
allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "")
allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "")
exposeHeaders := strings.ReplaceAll(cfg.ExposeHeaders, " ", "")
2020-09-13 11:20:11 +02:00
// Convert int to string
maxAge := strconv.Itoa(cfg.MaxAge)
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get origin header
origin := c.Get(fiber.HeaderOrigin)
allowOrigin := ""
// Check allowed origins
for _, o := range allowOrigins {
if o == "*" {
allowOrigin = "*"
2020-09-13 11:20:11 +02:00
break
}
if o == origin {
2020-09-13 11:20:11 +02:00
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if (allowOrigin == "" || allowOrigin == ConfigDefault.AllowOrigins) && cfg.AllowOriginsFunc != nil {
if cfg.AllowOriginsFunc(origin) {
allowOrigin = origin
}
}
2020-09-13 11:20:11 +02:00
// Simple request
if c.Method() != fiber.MethodOptions {
2020-09-13 11:20:11 +02:00
c.Vary(fiber.HeaderOrigin)
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
if cfg.AllowCredentials {
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
if exposeHeaders != "" {
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
}
return c.Next()
}
// Preflight request
c.Vary(fiber.HeaderOrigin)
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
// Set Allow-Credentials if set to true
if cfg.AllowCredentials {
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
// Set Allow-Headers if not empty
if allowHeaders != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
if h != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
}
}
// Set MaxAge is set
if cfg.MaxAge > 0 {
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
}
// Send 204 No Content
return c.SendStatus(fiber.StatusNoContent)
}
}