mirror of
https://github.com/gofiber/fiber.git
synced 2025-02-21 23:33:18 +00:00
🔥 Feature: Add encrypt cookies middleware (#1343)
* 🔥 Feature: Add encrypt cookies middleware
* Encrypt cookies when error happens
* Improve encrypt cookie middleware
* Fix errors
* Update encryptcookie config doc blocks
* Change `SetCookie` to `SetCookieBytesKV` for invalid cookies
* Update middleware/encryptcookie/config.go
* Update README.md
* Remove `GenerateKey` parameter
* Update README.md
Co-authored-by: hi019 <65871571+hi019@users.noreply.github.com>
This commit is contained in:
parent
d89207831d
commit
bff8843abd
68
middleware/encryptcookie/README.md
Normal file
68
middleware/encryptcookie/README.md
Normal file
@ -0,0 +1,68 @@
|
||||
# Encrypt Cookie Middleware
|
||||
|
||||
Encrypt middleware for [Fiber](https://github.com/gofiber/fiber) which encrypts cookie values. Note: this middleware does not encrypt cookie names.
|
||||
|
||||
|
||||
## Signaures
|
||||
|
||||
```go
|
||||
// Intitializes the middleware
|
||||
func New(config ...Config) fiber.Handler
|
||||
|
||||
// Returns a random 32 character long string
|
||||
func GenerateKey() string
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
First import the middleware from Fiber,
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cache"
|
||||
)
|
||||
|
||||
Then create a Fiber app with `app := fiber.New()`.
|
||||
|
||||
## Minimum Config
|
||||
|
||||
```go
|
||||
// `Key` must be a 32 character string. It's used to encrpyt the values, so make sure it is random and keep it secret.
|
||||
// You can call `encryptcookie.GenerateKey()` to create a random key for you.
|
||||
// Make sure not to set `Key` to `encryptcookie.GenerateKey()` because that will create a new key every run.
|
||||
app.Use(encryptcookie.New(encryptcookie.Config{
|
||||
Key: "secret-thirty-2-character-string",
|
||||
}))
|
||||
```
|
||||
|
||||
## Config
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Array of cookie keys that should not be encrypted.
|
||||
//
|
||||
// Optional. Default: []
|
||||
Except []string
|
||||
|
||||
// Base64 encoded unique key to encode & decode cookies.
|
||||
//
|
||||
// Required. Key length should be 32 characters.
|
||||
// You may use `encryptcookie.GenerateKey()` to generate a new key.
|
||||
Key string
|
||||
|
||||
// Custom function to encrypt cookies.
|
||||
//
|
||||
// Optional. Default: EncryptCookie
|
||||
Encryptor func(decryptedString, key string) (string, error)
|
||||
|
||||
// Custom function to decrypt cookies.
|
||||
//
|
||||
// Optional. Default: DecryptCookie
|
||||
Decryptor func(encryptedString, key string) (string, error)
|
||||
}
|
||||
```
|
76
middleware/encryptcookie/config.go
Normal file
76
middleware/encryptcookie/config.go
Normal file
@ -0,0 +1,76 @@
|
||||
package encryptcookie
|
||||
|
||||
import "github.com/gofiber/fiber/v2"
|
||||
|
||||
// Config defines the config for middleware.
|
||||
type Config struct {
|
||||
// Next defines a function to skip this middleware when returned true.
|
||||
//
|
||||
// Optional. Default: nil
|
||||
Next func(c *fiber.Ctx) bool
|
||||
|
||||
// Array of cookie keys that should not be encrypted.
|
||||
//
|
||||
// Optional. Default: []
|
||||
Except []string
|
||||
|
||||
// Base64 encoded unique key to encode & decode cookies.
|
||||
//
|
||||
// Required. Key length should be 32 characters.
|
||||
// You may use `encryptcookie.GenerateKey()` to generate a new key.
|
||||
Key string
|
||||
|
||||
// Custom function to encrypt cookies.
|
||||
//
|
||||
// Optional. Default: EncryptCookie
|
||||
Encryptor func(decryptedString, key string) (string, error)
|
||||
|
||||
// Custom function to decrypt cookies.
|
||||
//
|
||||
// Optional. Default: DecryptCookie
|
||||
Decryptor func(encryptedString, key string) (string, error)
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
Next: nil,
|
||||
Except: make([]string, 0),
|
||||
Key: "",
|
||||
Encryptor: EncryptCookie,
|
||||
Decryptor: DecryptCookie,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
func configDefault(config ...Config) Config {
|
||||
// Set default config
|
||||
cfg := ConfigDefault
|
||||
|
||||
// Override config if provided
|
||||
if len(config) > 0 {
|
||||
cfg = config[0]
|
||||
|
||||
// Set default values
|
||||
|
||||
if cfg.Next == nil {
|
||||
cfg.Next = ConfigDefault.Next
|
||||
}
|
||||
|
||||
if cfg.Except == nil {
|
||||
cfg.Except = ConfigDefault.Except
|
||||
}
|
||||
|
||||
if cfg.Encryptor == nil {
|
||||
cfg.Encryptor = ConfigDefault.Encryptor
|
||||
}
|
||||
|
||||
if cfg.Decryptor == nil {
|
||||
cfg.Decryptor = ConfigDefault.Decryptor
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Key == "" {
|
||||
panic("fiber: encrypt cookie middleware requires key")
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
56
middleware/encryptcookie/encryptcookie.go
Normal file
56
middleware/encryptcookie/encryptcookie.go
Normal file
@ -0,0 +1,56 @@
|
||||
package encryptcookie
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
cfg := configDefault(config...)
|
||||
|
||||
// Return new handler
|
||||
return func(c *fiber.Ctx) error {
|
||||
// Don't execute middleware if Next returns true
|
||||
if cfg.Next != nil && cfg.Next(c) {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Decrypt request cookies
|
||||
c.Request().Header.VisitAllCookie(func(key, value []byte) {
|
||||
keyString := string(key)
|
||||
if !isDisabled(keyString, cfg.Except) {
|
||||
decryptedValue, err := cfg.Decryptor(string(value), cfg.Key)
|
||||
if err != nil {
|
||||
c.Request().Header.SetCookieBytesKV(key, nil)
|
||||
} else {
|
||||
c.Request().Header.SetCookie(string(key), decryptedValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Continue stack
|
||||
err := c.Next()
|
||||
|
||||
// Encrypt response cookies
|
||||
c.Response().Header.VisitAllCookie(func(key, value []byte) {
|
||||
keyString := string(key)
|
||||
if !isDisabled(keyString, cfg.Except) {
|
||||
cookieValue := fasthttp.Cookie{}
|
||||
cookieValue.SetKeyBytes(key)
|
||||
if c.Response().Header.Cookie(&cookieValue) {
|
||||
encryptedValue, err := cfg.Encryptor(string(cookieValue.Value()), cfg.Key)
|
||||
if err == nil {
|
||||
cookieValue.SetValue(encryptedValue)
|
||||
c.Response().Header.SetCookie(&cookieValue)
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
184
middleware/encryptcookie/encryptcookie_test.go
Normal file
184
middleware/encryptcookie/encryptcookie_test.go
Normal file
@ -0,0 +1,184 @@
|
||||
package encryptcookie
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var testKey = GenerateKey()
|
||||
|
||||
func Test_Middleware_Encrypt_Cookie(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Key: testKey,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("value=" + c.Cookies("test"))
|
||||
})
|
||||
app.Post("/", func(c *fiber.Ctx) error {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "test",
|
||||
Value: "SomeThing",
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
// Test empty cookie
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
|
||||
|
||||
// Test invalid cookie
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetCookie("test", "Invalid")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
|
||||
ctx.Request.Header.SetCookie("test", "ixQURE2XOyZUs0WAOh2ehjWcP7oZb07JvnhWOsmeNUhPsj4+RyI=")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=", string(ctx.Response.Body()))
|
||||
|
||||
// Test valid cookie
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
||||
encryptedCookie := fasthttp.Cookie{}
|
||||
encryptedCookie.SetKey("test")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||
decryptedCookieValue, _ := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
||||
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
|
||||
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=SomeThing", string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
func Test_Encrypt_Cookie_Next(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Key: testKey,
|
||||
Next: func(_ *fiber.Ctx) bool {
|
||||
return true
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "test",
|
||||
Value: "SomeThing",
|
||||
})
|
||||
return nil
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "SomeThing", resp.Cookies()[0].Value)
|
||||
}
|
||||
|
||||
func Test_Encrypt_Cookie_Except(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Key: testKey,
|
||||
Except: []string{
|
||||
"test1",
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "test1",
|
||||
Value: "SomeThing",
|
||||
})
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "test2",
|
||||
Value: "SomeThing",
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
||||
rawCookie := fasthttp.Cookie{}
|
||||
rawCookie.SetKey("test1")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&rawCookie), "Get cookie value")
|
||||
utils.AssertEqual(t, "SomeThing", string(rawCookie.Value()))
|
||||
|
||||
encryptedCookie := fasthttp.Cookie{}
|
||||
encryptedCookie.SetKey("test2")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||
decryptedCookieValue, _ := DecryptCookie(string(encryptedCookie.Value()), testKey)
|
||||
utils.AssertEqual(t, "SomeThing", decryptedCookieValue)
|
||||
}
|
||||
|
||||
func Test_Encrypt_Cookie_Custom_Encryptor(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Key: testKey,
|
||||
Encryptor: func(decryptedString, _ string) (string, error) {
|
||||
return base64.StdEncoding.EncodeToString([]byte(decryptedString)), nil
|
||||
},
|
||||
Decryptor: func(encryptedString, _ string) (string, error) {
|
||||
decodedBytes, err := base64.StdEncoding.DecodeString(encryptedString)
|
||||
return string(decodedBytes), err
|
||||
},
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendString("value=" + c.Cookies("test"))
|
||||
})
|
||||
app.Post("/", func(c *fiber.Ctx) error {
|
||||
c.Cookie(&fiber.Cookie{
|
||||
Name: "test",
|
||||
Value: "SomeThing",
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
|
||||
encryptedCookie := fasthttp.Cookie{}
|
||||
encryptedCookie.SetKey("test")
|
||||
utils.AssertEqual(t, true, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value")
|
||||
decodedBytes, _ := base64.StdEncoding.DecodeString(string(encryptedCookie.Value()))
|
||||
utils.AssertEqual(t, "SomeThing", string(decodedBytes))
|
||||
|
||||
ctx = &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value()))
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "value=SomeThing", string(ctx.Response.Body()))
|
||||
}
|
88
middleware/encryptcookie/utils.go
Normal file
88
middleware/encryptcookie/utils.go
Normal file
@ -0,0 +1,88 @@
|
||||
package encryptcookie
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// EncryptCookie Encrypts a cookie value with specific encryption key
|
||||
func EncryptCookie(value, key string) (string, error) {
|
||||
keyDecoded, _ := base64.StdEncoding.DecodeString(key)
|
||||
plaintext := []byte(value)
|
||||
|
||||
block, err := aes.NewCipher(keyDecoded)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptCookie Decrypts a cookie value with specific encryption key
|
||||
func DecryptCookie(value, key string) (string, error) {
|
||||
keyDecoded, _ := base64.StdEncoding.DecodeString(key)
|
||||
enc, _ := base64.StdEncoding.DecodeString(value)
|
||||
|
||||
block, err := aes.NewCipher(keyDecoded)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
|
||||
if len(enc) < nonceSize {
|
||||
return "", errors.New("encrypted value is not valid")
|
||||
}
|
||||
|
||||
nonce, ciphertext := enc[:nonceSize], enc[nonceSize:]
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// GenerateKey Generates an encryption key
|
||||
func GenerateKey() string {
|
||||
ret := make([]byte, 32)
|
||||
|
||||
if _, err := rand.Read(ret); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ret)
|
||||
}
|
||||
|
||||
// Check given cookie key is disabled for encryption or not
|
||||
func isDisabled(key string, except []string) bool {
|
||||
for _, k := range except {
|
||||
if key == k {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user