mirror of
https://github.com/H0llyW00dzZ/fiber2fa.git
synced 2025-02-06 11:02:32 +00:00
H0llyW00dzZ
5acecf711a
* Fix Default Time - [+] refactor(tests): update tests to use default TimeSource and Hash from config - [+] fix(config): change default TimeSource to use UTC time - [+] fix(otpverifier): update TOTPTime to return UTC time - [+] test(otpverifier): add test for AdjustSyncWindow function - [+] refactor(middleware): remove unused options from TOTP verifier config * Docs [pkg.go.dev] Update Documentation - [+] docs(otpverifier): clarify TOTPTime function documentation - [+] The TOTPTime function documentation is updated to provide a clearer explanation of the time zone used and the format of the returned time. It now specifies that the South Pole time zone is used and links to the relevant Wikipedia article for more information. The note is also updated to clarify that the returned time is always expressed in UTC to avoid ambiguity.
1579 lines
44 KiB
Go
1579 lines
44 KiB
Go
// Copyright (c) 2024 H0llyW00dz All rights reserved.
|
|
//
|
|
// License: BSD 3-Clause License
|
|
|
|
package twofa_test
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"encoding/json"
|
|
"fmt"
|
|
"image"
|
|
"image/color"
|
|
"image/png"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
twofa "github.com/H0llyW00dzZ/fiber2fa"
|
|
otp "github.com/H0llyW00dzZ/fiber2fa/internal/otpverifier"
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/gofiber/storage/memory/v2"
|
|
"github.com/google/uuid"
|
|
"github.com/skip2/go-qrcode"
|
|
"github.com/xlzd/gotp"
|
|
)
|
|
|
|
// TestInfo_GetSecret tests the GetSecret method of the Info struct.
|
|
func TestInfo_GetSecret(t *testing.T) {
|
|
secret := gotp.RandomSecret(16)
|
|
info := twofa.Info{
|
|
Secret: secret,
|
|
}
|
|
|
|
if got := info.GetSecret(); got != secret {
|
|
t.Errorf("Info.GetSecret() = %v, want %v", got, secret)
|
|
}
|
|
}
|
|
|
|
// TestInfo_GetSetCookieValue tests the GetCookieValue and SetCookieValue methods of the Info struct.
|
|
func TestInfo_GetSetCookieValue(t *testing.T) {
|
|
cookieValue := "cookie_value"
|
|
info := twofa.Info{}
|
|
info.SetCookieValue(cookieValue)
|
|
|
|
if got := info.GetCookieValue(); got != cookieValue {
|
|
t.Errorf("Info.GetCookieValue() = %v, want %v", got, cookieValue)
|
|
}
|
|
}
|
|
|
|
// TestInfo_GetSetExpirationTime tests the GetExpirationTime and SetExpirationTime methods of the Info struct.
|
|
func TestInfo_GetSetExpirationTime(t *testing.T) {
|
|
expirationTime := time.Now().Add(24 * time.Hour)
|
|
info := twofa.Info{}
|
|
info.SetExpirationTime(expirationTime)
|
|
|
|
if got := info.GetExpirationTime(); !got.Equal(expirationTime) {
|
|
t.Errorf("Info.GetExpirationTime() = %v, want %v", got, expirationTime)
|
|
}
|
|
}
|
|
|
|
// TestInfo_IsSetRegistered tests the IsRegistered and SetRegistered methods of the Info struct.
|
|
func TestInfo_IsSetRegistered(t *testing.T) {
|
|
info := twofa.Info{}
|
|
|
|
// Test initial registration status
|
|
if info.IsRegistered() {
|
|
t.Error("Info.IsRegistered() = true, want false")
|
|
}
|
|
|
|
// Set registration status to true
|
|
info.SetRegistered(true)
|
|
if !info.IsRegistered() {
|
|
t.Error("Info.IsRegistered() = false, want true")
|
|
}
|
|
|
|
// Set registration status back to false
|
|
info.SetRegistered(false)
|
|
if info.IsRegistered() {
|
|
t.Error("Info.IsRegistered() = true, want false")
|
|
}
|
|
}
|
|
|
|
// TestInfo_SetContextKey tests the SetContextKey method of the Info struct.
|
|
func TestInfo_SetContextKey(t *testing.T) {
|
|
info := twofa.Info{}
|
|
contextKey := "user_id"
|
|
|
|
info.SetContextKey(contextKey)
|
|
|
|
if info.ContextKey != contextKey {
|
|
t.Errorf("Info.ContextKey = %v, want %v", info.ContextKey, contextKey)
|
|
}
|
|
}
|
|
|
|
// TestInfo_SetSecret tests the SetSecret method of the Info struct.
|
|
func TestInfo_SetSecret(t *testing.T) {
|
|
info := twofa.Info{}
|
|
secret := gotp.RandomSecret(16)
|
|
|
|
info.SetSecret(secret)
|
|
|
|
if info.Secret != secret {
|
|
t.Errorf("Info.Secret = %v, want %v", info.Secret, secret)
|
|
}
|
|
}
|
|
|
|
// TestInfo_SetIdentifier_ValidIdentifier tests setting a valid identifier value.
|
|
func TestInfo_SetIdentifier_ValidIdentifier(t *testing.T) {
|
|
info := twofa.Info{}
|
|
validIdentifier := "123e4567-e89b-12d3-a456-426655440000"
|
|
|
|
info.SetIdentifier(validIdentifier)
|
|
|
|
if info.GetIdentifier() != validIdentifier {
|
|
t.Errorf("Info.GetIdentifier() = %v, want %v", info.GetIdentifier(), validIdentifier)
|
|
}
|
|
}
|
|
|
|
// TestInfo_SetIdentifier_GoogleUUID tests setting an identifier generated from the google/uuid package.
|
|
func TestInfo_SetIdentifier_GoogleUUID(t *testing.T) {
|
|
info := twofa.Info{}
|
|
googleUUID := uuid.New().String()
|
|
|
|
info.SetIdentifier(googleUUID)
|
|
|
|
if info.GetIdentifier() != googleUUID {
|
|
t.Errorf("Info.GetIdentifier() = %v, want %v", info.GetIdentifier(), googleUUID)
|
|
}
|
|
}
|
|
|
|
// TestInfo_SetIdentifier_EmptyIdentifier tests setting an empty identifier value.
|
|
func TestInfo_SetIdentifier_EmptyIdentifier(t *testing.T) {
|
|
info := twofa.Info{}
|
|
emptyIdentifier := ""
|
|
|
|
info.SetIdentifier(emptyIdentifier)
|
|
|
|
if info.GetIdentifier() != emptyIdentifier {
|
|
t.Errorf("Info.GetIdentifier() = %v, want %v", info.GetIdentifier(), emptyIdentifier)
|
|
}
|
|
}
|
|
|
|
// TestInfo_SetIdentifier_OverwriteIdentifier tests overwriting an existing identifier value.
|
|
func TestInfo_SetIdentifier_OverwriteIdentifier(t *testing.T) {
|
|
info := twofa.Info{}
|
|
initialIdentifier := "123e4567-e89b-12d3-a456-426655440000"
|
|
newIdentifier := uuid.New().String()
|
|
|
|
info.SetIdentifier(initialIdentifier)
|
|
info.SetIdentifier(newIdentifier)
|
|
|
|
if info.GetIdentifier() != newIdentifier {
|
|
t.Errorf("Info.GetIdentifier() = %v, want %v", info.GetIdentifier(), newIdentifier)
|
|
}
|
|
}
|
|
|
|
// TestInfo_GetIdentifier_InitialValue tests getting the initial identifier value.
|
|
func TestInfo_GetIdentifier_InitialValue(t *testing.T) {
|
|
info := twofa.Info{}
|
|
|
|
if info.GetIdentifier() != "" {
|
|
t.Errorf("Info.GetIdentifier() = %v, want empty string", info.GetIdentifier())
|
|
}
|
|
}
|
|
|
|
// TestInfo_SetGetQRCodeData tests setting and getting the QRCode data image.
|
|
func TestInfo_SetGetQRCodeData(t *testing.T) {
|
|
info := twofa.Info{}
|
|
|
|
// Create a sample QRCode image
|
|
qrCodeImage := createSampleQRCodeImage()
|
|
|
|
// Encode the QRCode image to PNG format
|
|
var buf bytes.Buffer
|
|
err := png.Encode(&buf, qrCodeImage)
|
|
if err != nil {
|
|
t.Fatalf("Failed to encode QRCode image: %v", err)
|
|
}
|
|
qrCodeData := buf.Bytes()
|
|
|
|
// Set the QRCode data
|
|
info.SetQRCodeData(qrCodeData)
|
|
|
|
// Get the QRCode data
|
|
retrievedQRCodeData := info.GetQRCodeData()
|
|
|
|
// Compare the retrieved QRCode data with the original data
|
|
if !bytes.Equal(retrievedQRCodeData, qrCodeData) {
|
|
t.Error("Retrieved QRCode data does not match the original data")
|
|
}
|
|
|
|
// Decode the retrieved QRCode data back to an image
|
|
retrievedQRCodeImage, _, err := image.Decode(bytes.NewReader(retrievedQRCodeData))
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode retrieved QRCode data: %v", err)
|
|
}
|
|
|
|
// Compare the dimensions of the retrieved QRCode image with the original image
|
|
if retrievedQRCodeImage.Bounds() != qrCodeImage.Bounds() {
|
|
t.Error("Retrieved QRCode image dimensions do not match the original image")
|
|
}
|
|
}
|
|
|
|
// createSampleQRCodeImage returns a 200x200 RGBA image in which every pixel is
// set to black. It serves as a stand-in for a QR code image in tests.
func createSampleQRCodeImage() image.Image {
	const side = 200

	canvas := image.NewRGBA(image.Rect(0, 0, side, side))

	// Paint every pixel inside the image bounds.
	bounds := canvas.Bounds()
	for x := bounds.Min.X; x < bounds.Max.X; x++ {
		for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
			canvas.Set(x, y, image.Black)
		}
	}

	return canvas
}
|
|
|
|
func TestMiddleware_Handle(t *testing.T) {
|
|
// Set up the storage with an in-memory store for simplicity
|
|
store := memory.New()
|
|
secret := gotp.RandomSecret(16)
|
|
|
|
// Create a default Info struct and store it for Simulate State
|
|
info := twofa.Info{
|
|
Secret: secret,
|
|
CookieValue: "",
|
|
ExpirationTime: time.Time{},
|
|
}
|
|
infoJSON, _ := json.Marshal(info)
|
|
_ = store.Set("user123", infoJSON, 0) // Ignoring error for brevity
|
|
|
|
// Define a middleware instance with default configuration
|
|
middleware := twofa.New(twofa.Config{
|
|
Secret: secret,
|
|
Storage: store,
|
|
TimeSource: twofa.DefaultConfig.TimeSource,
|
|
Hash: twofa.DefaultConfig.Hash,
|
|
ContextKey: "user123",
|
|
RedirectURL: "/2fa",
|
|
CookieMaxAge: 86400,
|
|
CookieName: "twofa_cookie",
|
|
TokenLookup: "header:Authorization,query:token,form:token,param:token,cookie:token",
|
|
})
|
|
|
|
// Create a new Fiber app and register the middleware
|
|
app := fiber.New()
|
|
app.Use(func(c *fiber.Ctx) error {
|
|
c.Locals("user123", "user123")
|
|
return c.Next()
|
|
})
|
|
app.Use(middleware)
|
|
|
|
// Define routes that will be used for testing
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
app.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
// Generate a valid 2FA token
|
|
totp := otp.NewTOTPVerifier(otp.Config{
|
|
Secret: info.Secret,
|
|
TimeSource: twofa.DefaultConfig.TimeSource,
|
|
Hash: twofa.DefaultConfig.Hash,
|
|
})
|
|
|
|
validToken := totp.GenerateToken()
|
|
totp.Verify(validToken)
|
|
|
|
// Create a separate instance of the Middleware struct for testing
|
|
testMiddleware := &twofa.Middleware{
|
|
Config: &twofa.Config{
|
|
Secret: secret,
|
|
TimeSource: twofa.DefaultConfig.TimeSource,
|
|
Hash: twofa.DefaultConfig.Hash,
|
|
},
|
|
}
|
|
|
|
// Define test cases
|
|
testCases := []struct {
|
|
name string
|
|
requestURL string
|
|
requestMethod string
|
|
requestBody io.Reader
|
|
requestHeaders map[string]string
|
|
requestCookies []*http.Cookie
|
|
expectedStatus int
|
|
expectedLocation string
|
|
expectedBody string
|
|
setupFunc func()
|
|
}{
|
|
{
|
|
name: "GET request without token",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusFound,
|
|
expectedLocation: "/2fa",
|
|
},
|
|
{
|
|
name: "GET request with valid token in query parameter",
|
|
requestURL: fmt.Sprintf("https://hack/?token=%s", validToken),
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusOK,
|
|
},
|
|
{
|
|
name: "GET request with valid token in header",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: map[string]string{
|
|
"Authorization": fmt.Sprintf("Bearer %s", validToken),
|
|
},
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusOK,
|
|
},
|
|
{
|
|
name: "POST request with valid token in form data",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "POST",
|
|
requestBody: strings.NewReader(fmt.Sprintf("token=%s", validToken)),
|
|
requestHeaders: map[string]string{"Content-Type": "application/x-www-form-urlencoded"},
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusOK,
|
|
},
|
|
{
|
|
name: "GET request with valid token in cookie",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: []*http.Cookie{{Name: "token", Value: validToken}},
|
|
expectedStatus: fiber.StatusOK,
|
|
},
|
|
{
|
|
name: "GET request with valid cookie",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: []*http.Cookie{
|
|
{
|
|
Name: "twofa_cookie",
|
|
Value: testMiddleware.GenerateCookieValue(time.Now().Add(time.Hour)),
|
|
},
|
|
},
|
|
expectedStatus: fiber.StatusOK,
|
|
},
|
|
{
|
|
name: "GET request with invalid cookie",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: []*http.Cookie{
|
|
{
|
|
Name: "twofa_cookie",
|
|
Value: "invalid_cookie_value",
|
|
},
|
|
},
|
|
expectedStatus: fiber.StatusFound,
|
|
expectedLocation: "/2fa",
|
|
},
|
|
{
|
|
name: "Invalid 2FA token",
|
|
requestURL: "https://hack/?token=invalid_token",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusUnauthorized,
|
|
expectedBody: "Invalid 2FA token",
|
|
},
|
|
// Add more test cases as needed
|
|
}
|
|
|
|
// Run subtests in parallel
|
|
for _, tc := range testCases {
|
|
tc := tc // Capture range variable
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Create a new HTTP request
|
|
req := httptest.NewRequest(tc.requestMethod, tc.requestURL, tc.requestBody)
|
|
for key, value := range tc.requestHeaders {
|
|
req.Header.Set(key, value)
|
|
}
|
|
for _, cookie := range tc.requestCookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Perform the request
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to perform request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Check the response status code
|
|
if resp.StatusCode != tc.expectedStatus {
|
|
t.Logf("Request: %s %s", req.Method, req.URL)
|
|
t.Logf("Response: %d", resp.StatusCode)
|
|
t.Errorf("Expected status code %d, but got %d", tc.expectedStatus, resp.StatusCode)
|
|
}
|
|
|
|
// Check the response location header if expectedLocation is set
|
|
if tc.expectedLocation != "" {
|
|
location := resp.Header.Get("Location")
|
|
if location != tc.expectedLocation {
|
|
t.Errorf("Expected location header %q, but got %q", tc.expectedLocation, location)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func customLogger(t *testing.T) fiber.Handler {
|
|
return func(c *fiber.Ctx) error {
|
|
// Log the request
|
|
t.Logf("Request: %s %s", c.Method(), c.OriginalURL())
|
|
|
|
// Continue with the middleware chain
|
|
err := c.Next()
|
|
|
|
// After continuing with the middleware, log the response
|
|
t.Logf("Response: %d", c.Response().StatusCode())
|
|
|
|
return err
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_SkipNext(t *testing.T) {
|
|
middleware := twofa.New(twofa.Config{
|
|
Next: func(c *fiber.Ctx) bool {
|
|
return true // Always skip the middleware
|
|
},
|
|
})
|
|
|
|
app := fiber.New()
|
|
app.Use(middleware)
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.SendString("Middleware skipped")
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "https://hack/", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to send request: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != fiber.StatusOK {
|
|
t.Errorf("Expected status code %d, got %d", fiber.StatusOK, resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read response body: %v", err)
|
|
}
|
|
if string(body) != "Middleware skipped" {
|
|
t.Errorf("Expected body to be 'Middleware skipped', got '%s'", string(body))
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_SkipCookies(t *testing.T) {
|
|
middleware := twofa.New(twofa.Config{
|
|
SkipCookies: []string{"/skip"},
|
|
})
|
|
|
|
app := fiber.New()
|
|
app.Use(middleware)
|
|
app.Get("/skip", func(c *fiber.Ctx) error {
|
|
return c.SendString("Path skipped")
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "https://hack/skip", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to send request: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != fiber.StatusOK {
|
|
t.Errorf("Expected status code %d, got %d", fiber.StatusOK, resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read response body: %v", err)
|
|
}
|
|
if string(body) != "Path skipped" {
|
|
t.Errorf("Expected body to be 'Path skipped', got '%s'", string(body))
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_Handle_StorageGetFail(t *testing.T) {
|
|
secret := gotp.RandomSecret(16)
|
|
// Use a custom storage that fails on Get operation
|
|
store := &failingStorage{
|
|
Storage: memory.New(),
|
|
}
|
|
|
|
middleware := twofa.New(twofa.Config{
|
|
Storage: store,
|
|
ContextKey: "user123x",
|
|
Secret: secret,
|
|
})
|
|
|
|
app := fiber.New()
|
|
app.Use(func(c *fiber.Ctx) error {
|
|
c.Locals("user123x", "user123")
|
|
return c.Next()
|
|
})
|
|
app.Use(middleware)
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.SendString("Should not get here")
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "https://hack/", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != fiber.StatusInternalServerError {
|
|
t.Errorf("Expected status code %d, got %d", fiber.StatusInternalServerError, resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
if string(body) != "failed to retrieve 2FA information" {
|
|
t.Errorf("Expected error message 'failed to retrieve 2FA information', got '%s'", string(body))
|
|
}
|
|
}
|
|
|
|
// failingStorage is a custom storage that fails on Get operation
|
|
type failingStorage struct {
|
|
*memory.Storage
|
|
}
|
|
|
|
func (s *failingStorage) Get(key string) ([]byte, error) {
|
|
return nil, fmt.Errorf("storage get error")
|
|
}
|
|
|
|
func TestMiddleware_Handle_InfoNotFoundInStorage(t *testing.T) {
|
|
store := memory.New()
|
|
middleware := twofa.New(twofa.Config{
|
|
Storage: store,
|
|
ContextKey: "user123x",
|
|
RedirectURL: "/2fa",
|
|
CookieName: "twofa_cookie",
|
|
})
|
|
|
|
app := fiber.New()
|
|
app.Use(func(c *fiber.Ctx) error {
|
|
c.Locals("user123x", "user123")
|
|
return c.Next()
|
|
})
|
|
app.Use(middleware)
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.SendString("Should not get here")
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "https://hack/", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != fiber.StatusUnauthorized {
|
|
t.Errorf("Expected status code %d for missing 2FA info, but got %d", fiber.StatusUnauthorized, resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
if string(body) != "2FA information not found" {
|
|
t.Errorf("Expected error message '2FA information not found', got '%s'", string(body))
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_GenerateQRcodePath(t *testing.T) {
|
|
secret := gotp.RandomSecret(16)
|
|
// Create a new Fiber app
|
|
app := fiber.New()
|
|
|
|
// Create an in-memory storage
|
|
storage := memory.New()
|
|
|
|
// Create a new Middleware instance with a custom ContextKey, Issuer, and JSONUnmarshal
|
|
middleware := twofa.New(twofa.Config{
|
|
ContextKey: "accountName",
|
|
Issuer: "MyApp",
|
|
Secret: secret,
|
|
Storage: storage,
|
|
JSONMarshal: json.Marshal, // Set the JSONMarshal field
|
|
JSONUnmarshal: json.Unmarshal, // Set the JSONUnmarshal field
|
|
Encode: twofa.EncodeConfig{
|
|
Level: qrcode.Medium,
|
|
Size: 256,
|
|
},
|
|
QRCode: twofa.QRCodeConfig{
|
|
Content: "otpauth://totp/%s:%s?secret=%s&issuer=%s",
|
|
PathTemplate: "/test",
|
|
},
|
|
})
|
|
|
|
// Store the 2FA information in the storage for the test account
|
|
info := &twofa.Info{
|
|
Secret: secret,
|
|
}
|
|
rawInfo, _ := json.Marshal(info)
|
|
storage.Set("gopher@example.com", rawInfo, 0)
|
|
|
|
app.Use("/test", func(c *fiber.Ctx) error {
|
|
c.Locals("accountName", "gopher@example.com")
|
|
return c.Next()
|
|
})
|
|
|
|
// Use the 2FA middleware
|
|
app.Use(middleware)
|
|
|
|
// Send a test request to the "/test" route
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
// Check if the response status code is 200 OK
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
|
}
|
|
|
|
// Check if the response content type is "image/png"
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if contentType != "image/png" {
|
|
t.Errorf("Expected content type 'image/png', got '%s'", contentType)
|
|
}
|
|
|
|
// Read the response body
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
|
|
// Decode the response body as a PNG image
|
|
img, err := png.Decode(bytes.NewReader(body))
|
|
if err != nil {
|
|
t.Errorf("Error decoding response body as PNG: %v", err)
|
|
}
|
|
|
|
// Check if the decoded image has the expected dimensions
|
|
expectedWidth := 256
|
|
expectedHeight := 256
|
|
if img.Bounds().Dx() != expectedWidth || img.Bounds().Dy() != expectedHeight {
|
|
t.Errorf("Expected image dimensions %dx%d, got %dx%d", expectedWidth, expectedHeight, img.Bounds().Dx(), img.Bounds().Dy())
|
|
}
|
|
}
|
|
func TestMiddleware_GenerateQRcodePathAlreadyRegistered(t *testing.T) {
|
|
secret := gotp.RandomSecret(16)
|
|
// Create a new Fiber app
|
|
app := fiber.New()
|
|
|
|
// Create an in-memory storage
|
|
storage := memory.New()
|
|
|
|
// Create a new Middleware instance with a custom ContextKey, Issuer, and JSONUnmarshal
|
|
middleware := &twofa.Middleware{
|
|
Config: &twofa.Config{
|
|
ContextKey: "accountName",
|
|
Issuer: "MyApp",
|
|
Secret: secret,
|
|
Storage: storage,
|
|
JSONMarshal: json.Marshal, // Set the JSONMarshal field
|
|
JSONUnmarshal: json.Unmarshal, // Set the JSONUnmarshal field
|
|
Encode: twofa.EncodeConfig{
|
|
Level: qrcode.Medium,
|
|
Size: 256,
|
|
},
|
|
QRCode: twofa.QRCodeConfig{
|
|
Content: "otpauth://totp/%s:%s?secret=%s&issuer=%s",
|
|
},
|
|
},
|
|
}
|
|
|
|
// Store the 2FA information in the storage for the test account
|
|
info := &twofa.Info{
|
|
Secret: secret,
|
|
Registered: true,
|
|
}
|
|
rawInfo, _ := middleware.Config.JSONMarshal(info)
|
|
storage.Set("gopher@example.com", rawInfo, 0)
|
|
|
|
// Define a test handler that sets the account name in c.Locals and calls GenerateQRcodePath
|
|
app.Get("/test", func(c *fiber.Ctx) error {
|
|
c.Locals("accountName", "gopher@example.com")
|
|
return middleware.GenerateQRcodePath(c)
|
|
})
|
|
|
|
// Send a test request to the "/test" route
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
// Check if the response status code is 200 OK
|
|
if resp.StatusCode != http.StatusNotFound {
|
|
t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_GenerateQRcodePathWithVersion(t *testing.T) {
|
|
secret := gotp.RandomSecret(16)
|
|
// Create a new Fiber app
|
|
app := fiber.New()
|
|
|
|
// Create an in-memory storage
|
|
storage := memory.New()
|
|
|
|
// Create a new Middleware instance with a custom ContextKey, Issuer, and JSONUnmarshal
|
|
middleware := &twofa.Middleware{
|
|
Config: &twofa.Config{
|
|
ContextKey: "accountName",
|
|
Issuer: "MyApp",
|
|
Secret: secret,
|
|
Storage: storage,
|
|
JSONMarshal: json.Marshal, // Set the JSONMarshal field
|
|
JSONUnmarshal: json.Unmarshal, // Set the JSONUnmarshal field
|
|
Encode: twofa.EncodeConfig{
|
|
Level: qrcode.Medium,
|
|
Size: 256,
|
|
VersionNumber: 5, // Set the desired version number
|
|
},
|
|
QRCode: twofa.QRCodeConfig{
|
|
Content: "otpauth://totp/%s:%s?secret=%s&issuer=%s",
|
|
},
|
|
},
|
|
}
|
|
|
|
// Store the 2FA information in the storage for the test account
|
|
info := &twofa.Info{
|
|
Secret: secret,
|
|
}
|
|
rawInfo, _ := middleware.Config.JSONMarshal(info)
|
|
storage.Set("gopher@example.com", rawInfo, 0)
|
|
|
|
// Define a test handler that sets the account name in c.Locals and calls GenerateQRcodePath
|
|
app.Get("/test", func(c *fiber.Ctx) error {
|
|
c.Locals("accountName", "gopher@example.com")
|
|
return middleware.GenerateQRcodePath(c)
|
|
})
|
|
|
|
// Send a test request to the "/test" route
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
// Check if the response status code is 200 OK
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
|
}
|
|
|
|
// Check if the response content type is "image/png"
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if contentType != "image/png" {
|
|
t.Errorf("Expected content type 'image/png', got '%s'", contentType)
|
|
}
|
|
|
|
// Read the response body
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
|
|
// Decode the response body as a PNG image
|
|
img, err := png.Decode(bytes.NewReader(body))
|
|
if err != nil {
|
|
t.Errorf("Error decoding response body as PNG: %v", err)
|
|
}
|
|
|
|
// Check if the decoded image has the expected dimensions
|
|
expectedWidth := 256
|
|
expectedHeight := 256
|
|
if img.Bounds().Dx() != expectedWidth || img.Bounds().Dy() != expectedHeight {
|
|
t.Errorf("Expected image dimensions %dx%d, got %dx%d", expectedWidth, expectedHeight, img.Bounds().Dx(), img.Bounds().Dy())
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_GenerateQRcodePath_CustomImage(t *testing.T) {
|
|
secret := gotp.RandomSecret(16)
|
|
// Create a new Fiber app
|
|
app := fiber.New()
|
|
|
|
// Create an in-memory storage
|
|
storage := memory.New()
|
|
|
|
// Create a custom QR code image
|
|
customImage := image.NewRGBA(image.Rect(0, 0, 100, 100))
|
|
// Fill the custom image with some color (e.g., red)
|
|
for i := 0; i < 100; i++ {
|
|
for j := 0; j < 100; j++ {
|
|
customImage.Set(i, j, color.RGBA{255, 0, 0, 255})
|
|
}
|
|
}
|
|
|
|
// Create a new Middleware instance with a custom ContextKey, Issuer, JSONMarshal, JSONUnmarshal, and QRcodeImage
|
|
middleware := &twofa.Middleware{
|
|
Config: &twofa.Config{
|
|
ContextKey: "accountName",
|
|
Issuer: "MyApp",
|
|
Secret: secret,
|
|
Storage: storage,
|
|
JSONMarshal: json.Marshal, // Set the JSONMarshal field
|
|
JSONUnmarshal: json.Unmarshal, // Set the JSONUnmarshal field
|
|
QRCode: twofa.QRCodeConfig{
|
|
Image: customImage, // Set the custom QR code image
|
|
Content: "otpauth://totp/%s:%s?secret=%s&issuer=%s",
|
|
},
|
|
},
|
|
}
|
|
|
|
// Store the 2FA information in the storage for the test account
|
|
// Note: This info manager is useful for writing tests during open-source development because Go has a rich ecosystem and tooling unlike other language mostly is poor.
|
|
// It eliminates the need to spend money on renting a database solely for testing purposes.
|
|
info := &twofa.Info{
|
|
Secret: secret,
|
|
}
|
|
rawInfo, err := middleware.Config.JSONMarshal(info)
|
|
if err != nil {
|
|
t.Fatalf("Error marshaling 2FA information: %v", err)
|
|
}
|
|
storage.Set("gopher@example.com", rawInfo, 0)
|
|
|
|
// Define a test handler that sets the account name in c.Locals and calls GenerateQRcodePath
|
|
app.Get("/test", func(c *fiber.Ctx) error {
|
|
c.Locals("accountName", "gopher@example.com")
|
|
return middleware.GenerateQRcodePath(c)
|
|
})
|
|
|
|
// Send a test request to the "/test" route
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
// Check if the response status code is 200 OK
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
|
}
|
|
|
|
// Check if the response content type is "image/png"
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if contentType != "image/png" {
|
|
t.Errorf("Expected content type 'image/png', got '%s'", contentType)
|
|
}
|
|
|
|
// Read the response body
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
|
|
// Decode the response body as a PNG image
|
|
img, err := png.Decode(bytes.NewReader(body))
|
|
if err != nil {
|
|
t.Errorf("Error decoding response body as PNG: %v", err)
|
|
}
|
|
|
|
// Check if the decoded image matches the custom QR code image
|
|
if !reflect.DeepEqual(img, customImage) {
|
|
t.Error("Decoded image does not match the custom QR code image")
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_SendInternalErrorResponse(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
responseMIME string
|
|
expectedBody string
|
|
}{
|
|
{
|
|
name: "Plain text response",
|
|
responseMIME: fiber.MIMETextPlainCharsetUTF8,
|
|
expectedBody: "ContextKey is not set",
|
|
},
|
|
{
|
|
name: "JSON response",
|
|
responseMIME: fiber.MIMEApplicationJSON,
|
|
expectedBody: "{\"error\":\"ContextKey is not set\"}",
|
|
},
|
|
{
|
|
name: "XML response",
|
|
responseMIME: fiber.MIMEApplicationXML,
|
|
expectedBody: "<error><message>ContextKey is not set</message></error>",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
config := twofa.Config{
|
|
ResponseMIME: tc.responseMIME,
|
|
}
|
|
middleware := twofa.New(config)
|
|
|
|
app := fiber.New()
|
|
app.Use(middleware)
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.Status(fiber.StatusInternalServerError).SendString("ContextKey is not set")
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != fiber.StatusInternalServerError {
|
|
t.Errorf("Expected status code %d, got %d", fiber.StatusInternalServerError, resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
|
|
if strings.TrimSpace(string(body)) != tc.expectedBody {
|
|
t.Errorf("Expected response body '%s', got '%s'", tc.expectedBody, string(body))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_SendUnauthorizedResponse(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
responseMIME string
|
|
expectedBody string
|
|
}{
|
|
{
|
|
name: "Plain text response",
|
|
responseMIME: fiber.MIMETextPlainCharsetUTF8,
|
|
expectedBody: "2FA information not found",
|
|
},
|
|
{
|
|
name: "JSON response",
|
|
responseMIME: fiber.MIMEApplicationJSON,
|
|
expectedBody: "{\"error\":\"2FA information not found\"}",
|
|
},
|
|
{
|
|
name: "XML response",
|
|
responseMIME: fiber.MIMEApplicationXML,
|
|
expectedBody: "<error><message>2FA information not found</message></error>",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
store := memory.New()
|
|
secret := gotp.RandomSecret(16)
|
|
config := twofa.Config{
|
|
ResponseMIME: tc.responseMIME,
|
|
Secret: secret,
|
|
ContextKey: "gopher_testing",
|
|
Storage: store,
|
|
}
|
|
middleware := twofa.New(config)
|
|
|
|
app := fiber.New()
|
|
app.Use(func(c *fiber.Ctx) error {
|
|
c.Locals(config.ContextKey, "test_context_key")
|
|
return c.Next()
|
|
})
|
|
app.Use(middleware)
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.SendString("OK")
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != fiber.StatusUnauthorized {
|
|
t.Errorf("Expected status code %d, got %d", fiber.StatusUnauthorized, resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
|
|
if strings.TrimSpace(string(body)) != tc.expectedBody {
|
|
t.Errorf("Expected response body '%s', got '%s'", tc.expectedBody, string(body))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_CustomUnauthorizedHandler(t *testing.T) {
|
|
store := memory.New()
|
|
secret := gotp.RandomSecret(16)
|
|
config := twofa.Config{
|
|
Secret: secret,
|
|
ContextKey: "gopher_testing",
|
|
Storage: store,
|
|
UnauthorizedHandler: func(c *fiber.Ctx, err error) error {
|
|
return c.Status(fiber.StatusUnauthorized).SendString("Custom unauthorized handler")
|
|
},
|
|
}
|
|
middleware := twofa.New(config)
|
|
|
|
app := fiber.New()
|
|
app.Use(func(c *fiber.Ctx) error {
|
|
c.Locals(config.ContextKey, "test_context_key")
|
|
return c.Next()
|
|
})
|
|
app.Use(middleware)
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.SendString("OK")
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != fiber.StatusUnauthorized {
|
|
t.Errorf("Expected status code %d, got %d", fiber.StatusUnauthorized, resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
|
|
expectedBody := "Custom unauthorized handler"
|
|
if string(body) != expectedBody {
|
|
t.Errorf("Expected response body '%s', got '%s'", expectedBody, string(body))
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_CustomInternalErrorHandler(t *testing.T) {
|
|
store := memory.New()
|
|
secret := gotp.RandomSecret(16)
|
|
config := twofa.Config{
|
|
Secret: secret,
|
|
ContextKey: "gopher_testing",
|
|
Storage: store,
|
|
InternalErrorHandler: func(c *fiber.Ctx, err error) error {
|
|
return c.Status(fiber.StatusInternalServerError).SendString("Custom internal error handler")
|
|
},
|
|
}
|
|
middleware := twofa.New(config)
|
|
|
|
app := fiber.New()
|
|
app.Use(middleware)
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return fiber.NewError(fiber.StatusInternalServerError, "Internal server error")
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != fiber.StatusInternalServerError {
|
|
t.Errorf("Expected status code %d, got %d", fiber.StatusInternalServerError, resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
|
|
expectedBody := "Custom internal error handler"
|
|
if string(body) != expectedBody {
|
|
t.Errorf("Expected response body '%s', got '%s'", expectedBody, string(body))
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_GetContextKey(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
contextKey string
|
|
contextValue any
|
|
expectedKey string
|
|
expectedError string
|
|
}{
|
|
{
|
|
name: "Valid context key",
|
|
contextKey: "user_id",
|
|
contextValue: "123",
|
|
expectedKey: "123",
|
|
expectedError: "",
|
|
},
|
|
{
|
|
name: "Empty context key",
|
|
contextKey: "",
|
|
contextValue: nil,
|
|
expectedKey: "",
|
|
expectedError: "ContextKey is not set",
|
|
},
|
|
{
|
|
name: "Context key not set",
|
|
contextKey: "user_id",
|
|
contextValue: nil,
|
|
expectedKey: "",
|
|
expectedError: "ContextKey is not set",
|
|
},
|
|
{
|
|
name: "Invalid context value type",
|
|
contextKey: "user_id",
|
|
contextValue: 123,
|
|
expectedKey: "",
|
|
expectedError: "failed to retrieve context key",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
config := twofa.Config{
|
|
ContextKey: tc.contextKey,
|
|
}
|
|
middleware := twofa.New(config)
|
|
|
|
app := fiber.New()
|
|
app.Use(func(c *fiber.Ctx) error {
|
|
if tc.contextValue != nil {
|
|
c.Locals(tc.contextKey, tc.contextValue)
|
|
}
|
|
return c.Next()
|
|
})
|
|
app.Use(middleware)
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.SendString("OK")
|
|
})
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
if resp.StatusCode != fiber.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if tc.expectedError != "" && !strings.Contains(string(body), tc.expectedError) {
|
|
t.Errorf("Expected error '%s', got '%s'", tc.expectedError, string(body))
|
|
}
|
|
} else {
|
|
if tc.expectedError != "" {
|
|
t.Errorf("Expected error '%s', but got no error", tc.expectedError)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMiddleware_GenerateQRcodePath_Error(t *testing.T) {
|
|
secret := gotp.RandomSecret(16)
|
|
// Create a new Fiber app
|
|
app := fiber.New()
|
|
|
|
// Create an in-memory storage
|
|
storage := memory.New()
|
|
|
|
// Create a new Middleware instance with a custom ContextKey, Issuer, and JSONUnmarshal
|
|
middleware := &twofa.Middleware{
|
|
Config: &twofa.Config{
|
|
ContextKey: "accountName",
|
|
Issuer: "MyApp",
|
|
Secret: secret,
|
|
Storage: storage,
|
|
JSONMarshal: json.Marshal, // Set the JSONMarshal field
|
|
JSONUnmarshal: json.Unmarshal, // Set the JSONUnmarshal field
|
|
Encode: twofa.EncodeConfig{
|
|
Level: qrcode.Medium,
|
|
Size: 256,
|
|
},
|
|
QRCode: twofa.QRCodeConfig{
|
|
Content: "otpauth://totp/%s:%s?secret=%s&issuer=%s",
|
|
},
|
|
},
|
|
}
|
|
|
|
// Define a test handler that sets an invalid account name in c.Locals and calls GenerateQRcodePath
|
|
app.Get("/test", func(c *fiber.Ctx) error {
|
|
c.Locals("accountName", "invalid@example.com")
|
|
return middleware.GenerateQRcodePath(c)
|
|
})
|
|
|
|
// Send a test request to the "/test" route
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Error when sending request to the app: %v", err)
|
|
}
|
|
|
|
// Check if the response status code is 401 Unauthorized
|
|
if resp.StatusCode != http.StatusUnauthorized {
|
|
t.Errorf("Expected status code %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
|
}
|
|
|
|
// Read the response body
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatalf("Error reading response body: %v", err)
|
|
}
|
|
|
|
// Check if the response body contains the expected error message
|
|
expectedErrorMessage := "2FA information not found"
|
|
if !strings.Contains(string(body), expectedErrorMessage) {
|
|
t.Errorf("Expected error message '%s', got '%s'", expectedErrorMessage, string(body))
|
|
}
|
|
}
|
|
|
|
func TestMiddlewareUUIDContextKey_Handle(t *testing.T) {
|
|
// Set up the storage with an in-memory store for simplicity
|
|
store := memory.New()
|
|
secret := gotp.RandomSecret(16)
|
|
|
|
// Generate a UUID for the context key
|
|
contextKey := uuid.New().String()
|
|
|
|
// Create a default Info struct and store it for Simulate State
|
|
info := twofa.Info{
|
|
Secret: secret,
|
|
CookieValue: "",
|
|
ExpirationTime: time.Time{},
|
|
}
|
|
infoJSON, _ := json.Marshal(info)
|
|
_ = store.Set("gopher@example.com", infoJSON, 0) // Ignoring error for brevity
|
|
|
|
// Define a middleware instance with default configuration
|
|
middleware := twofa.New(twofa.Config{
|
|
Secret: secret,
|
|
Storage: store,
|
|
TimeSource: twofa.DefaultConfig.TimeSource,
|
|
Hash: twofa.DefaultConfig.Hash,
|
|
ContextKey: contextKey,
|
|
RedirectURL: "/2fa",
|
|
CookieMaxAge: 86400,
|
|
CookieName: "twofa_cookie",
|
|
TokenLookup: "header:Authorization,query:token,form:token,param:token,cookie:token",
|
|
})
|
|
|
|
// Create a new Fiber app and register the middleware
|
|
app := fiber.New()
|
|
app.Use(func(c *fiber.Ctx) error {
|
|
c.Locals(contextKey, "gopher@example.com")
|
|
return c.Next()
|
|
})
|
|
app.Use(middleware)
|
|
|
|
// Define routes that will be used for testing
|
|
app.Get("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
app.Post("/", func(c *fiber.Ctx) error {
|
|
return c.SendStatus(fiber.StatusOK)
|
|
})
|
|
|
|
// Generate a valid 2FA token
|
|
totp := otp.NewTOTPVerifier(otp.Config{
|
|
Secret: info.Secret,
|
|
TimeSource: twofa.DefaultConfig.TimeSource,
|
|
Hash: twofa.DefaultConfig.Hash,
|
|
})
|
|
|
|
validToken := totp.GenerateToken()
|
|
totp.Verify(validToken)
|
|
|
|
// Create a separate instance of the Middleware struct for testing
|
|
testMiddleware := &twofa.Middleware{
|
|
Config: &twofa.Config{
|
|
Secret: secret,
|
|
TimeSource: twofa.DefaultConfig.TimeSource,
|
|
Hash: twofa.DefaultConfig.Hash,
|
|
},
|
|
}
|
|
|
|
// Generate cryptographically secure random data
|
|
randomData := make([]byte, 16)
|
|
_, err := rand.Read(randomData)
|
|
if err != nil {
|
|
t.Fatalf("Failed to generate random data: %v", err)
|
|
}
|
|
|
|
// Create a UUID using the random data
|
|
randomUUID, err := uuid.FromBytes(randomData)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create UUID from random data: %v", err)
|
|
}
|
|
|
|
// Create a separate instance of the Middleware struct for testing
|
|
testMiddlewareRandomUUID := &twofa.Middleware{
|
|
Config: &twofa.Config{
|
|
Secret: secret,
|
|
ContextKey: randomUUID.String(),
|
|
},
|
|
}
|
|
|
|
// Define test cases
|
|
testCases := []struct {
|
|
name string
|
|
requestURL string
|
|
requestMethod string
|
|
requestBody io.Reader
|
|
requestHeaders map[string]string
|
|
requestCookies []*http.Cookie
|
|
expectedStatus int
|
|
expectedLocation string
|
|
expectedBody string
|
|
setupFunc func()
|
|
}{
|
|
{
|
|
name: "GET request without token",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusFound,
|
|
expectedLocation: "/2fa",
|
|
},
|
|
{
|
|
name: "GET request with valid token in query parameter",
|
|
requestURL: fmt.Sprintf("https://hack/?token=%s", validToken),
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusOK,
|
|
},
|
|
{
|
|
name: "GET request with valid token in header",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: map[string]string{
|
|
"Authorization": fmt.Sprintf("Bearer %s", validToken),
|
|
},
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusOK,
|
|
},
|
|
{
|
|
name: "POST request with valid token in form data",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "POST",
|
|
requestBody: strings.NewReader(fmt.Sprintf("token=%s", validToken)),
|
|
requestHeaders: map[string]string{"Content-Type": "application/x-www-form-urlencoded"},
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusOK,
|
|
},
|
|
{
|
|
name: "GET request with valid token in cookie",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: []*http.Cookie{{Name: "token", Value: validToken}},
|
|
expectedStatus: fiber.StatusOK,
|
|
},
|
|
{
|
|
name: "GET request with valid cookie",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: []*http.Cookie{
|
|
{
|
|
Name: "twofa_cookie",
|
|
Value: testMiddleware.GenerateCookieValue(time.Now().Add(time.Hour)),
|
|
},
|
|
},
|
|
expectedStatus: fiber.StatusOK,
|
|
},
|
|
{
|
|
name: "GET request with invalid cookie",
|
|
requestURL: "https://hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: []*http.Cookie{
|
|
{
|
|
Name: "twofa_cookie",
|
|
Value: "invalid_cookie_value",
|
|
},
|
|
},
|
|
expectedStatus: fiber.StatusFound,
|
|
expectedLocation: "/2fa",
|
|
},
|
|
{
|
|
name: "Invalid 2FA token",
|
|
requestURL: "https://hack/?token=invalid_token",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusUnauthorized,
|
|
expectedBody: "Invalid 2FA token",
|
|
},
|
|
|
|
{
|
|
name: "GET request with Random UUID value",
|
|
requestURL: fmt.Sprintf("https://rand.uuid.hack/?token=%s", validToken),
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusOK,
|
|
setupFunc: func() {
|
|
// Update the middleware configuration with the UUID value
|
|
testMiddlewareRandomUUID.Config.ContextKey = randomUUID.String()
|
|
},
|
|
},
|
|
{
|
|
name: "GET request with valid token in header and Random UUID value",
|
|
requestURL: "https://rand.uuid.hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: map[string]string{
|
|
"Authorization": fmt.Sprintf("Bearer %s", validToken),
|
|
},
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusOK,
|
|
setupFunc: func() {
|
|
// Update the middleware configuration with the UUID value
|
|
testMiddlewareRandomUUID.Config.ContextKey = randomUUID.String()
|
|
},
|
|
},
|
|
{
|
|
name: "POST request with valid token in form data and Random UUID value",
|
|
requestURL: "https://rand.uuid.hack/",
|
|
requestMethod: "POST",
|
|
requestBody: strings.NewReader(fmt.Sprintf("token=%s", validToken)),
|
|
requestHeaders: map[string]string{"Content-Type": "application/x-www-form-urlencoded"},
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusOK,
|
|
setupFunc: func() {
|
|
// Update the middleware configuration with the UUID value
|
|
testMiddlewareRandomUUID.Config.ContextKey = randomUUID.String()
|
|
},
|
|
},
|
|
{
|
|
name: "GET request with valid token in cookie and Random UUID value",
|
|
requestURL: "https://rand.uuid.hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: []*http.Cookie{{Name: "token", Value: validToken}},
|
|
expectedStatus: fiber.StatusOK,
|
|
setupFunc: func() {
|
|
// Update the middleware configuration with the UUID value
|
|
testMiddlewareRandomUUID.Config.ContextKey = randomUUID.String()
|
|
},
|
|
},
|
|
{
|
|
name: "GET request with valid cookie and Random UUID value",
|
|
requestURL: "https://rand.uuid.hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: []*http.Cookie{
|
|
{
|
|
Name: "twofa_cookie",
|
|
Value: testMiddleware.GenerateCookieValue(time.Now().Add(time.Hour)),
|
|
},
|
|
},
|
|
expectedStatus: fiber.StatusOK,
|
|
setupFunc: func() {
|
|
// Update the middleware configuration with the UUID value
|
|
testMiddlewareRandomUUID.Config.ContextKey = randomUUID.String()
|
|
},
|
|
},
|
|
{
|
|
name: "GET request with invalid cookie and Random UUID value",
|
|
requestURL: "https://rand.uuid.hack/",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: []*http.Cookie{
|
|
{
|
|
Name: "twofa_cookie",
|
|
Value: "invalid_cookie_value",
|
|
},
|
|
},
|
|
expectedStatus: fiber.StatusFound,
|
|
expectedLocation: "/2fa",
|
|
setupFunc: func() {
|
|
// Update the middleware configuration with the UUID value
|
|
testMiddlewareRandomUUID.Config.ContextKey = randomUUID.String()
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid 2FA token with Random UUID value",
|
|
requestURL: "https://rand.uuid.hack/?token=invalid_token",
|
|
requestMethod: "GET",
|
|
requestBody: nil,
|
|
requestHeaders: nil,
|
|
requestCookies: nil,
|
|
expectedStatus: fiber.StatusUnauthorized,
|
|
expectedBody: "Invalid 2FA token",
|
|
setupFunc: func() {
|
|
// Update the middleware configuration with the UUID value
|
|
testMiddlewareRandomUUID.Config.ContextKey = randomUUID.String()
|
|
},
|
|
},
|
|
|
|
// Add more test cases as needed
|
|
}
|
|
|
|
// Run subtests in parallel
|
|
for _, tc := range testCases {
|
|
tc := tc // Capture range variable
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Run the setup function if provided
|
|
if tc.setupFunc != nil {
|
|
tc.setupFunc()
|
|
}
|
|
|
|
// Create a new HTTP request
|
|
req := httptest.NewRequest(tc.requestMethod, tc.requestURL, tc.requestBody)
|
|
for key, value := range tc.requestHeaders {
|
|
req.Header.Set(key, value)
|
|
}
|
|
for _, cookie := range tc.requestCookies {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Perform the request
|
|
resp, err := app.Test(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to perform request: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Check the response status code
|
|
if resp.StatusCode != tc.expectedStatus {
|
|
t.Logf("Request: %s %s", req.Method, req.URL)
|
|
t.Logf("Response: %d", resp.StatusCode)
|
|
t.Errorf("Expected status code %d, but got %d", tc.expectedStatus, resp.StatusCode)
|
|
}
|
|
|
|
// Check the response location header if expectedLocation is set
|
|
if tc.expectedLocation != "" {
|
|
location := resp.Header.Get("Location")
|
|
if location != tc.expectedLocation {
|
|
t.Errorf("Expected location header %q, but got %q", tc.expectedLocation, location)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|