diff --git a/middleware/proxy/README.md b/middleware/proxy/README.md index a4be4b96..c3ceee9d 100644 --- a/middleware/proxy/README.md +++ b/middleware/proxy/README.md @@ -31,6 +31,12 @@ import ( After you initiate your Fiber app, you can use the following possibilities: ```go +// if target https site uses a self-signed certificate, you should +// call WithTlsConfig before Do and Forward +proxy.WithTlsConfig(&tls.Config{ + InsecureSkipVerify: true, +}) + // Forward to url app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif")) @@ -113,6 +119,9 @@ type Config struct { // Per-connection buffer size for responses' writing. WriteBufferSize int + + // tls config for the http client + TlsConfig *tls.Config } ``` @@ -121,6 +130,9 @@ type Config struct { ```go // ConfigDefault is the default config var ConfigDefault = Config{ - Next: nil, + Next: nil, + ModifyRequest: nil, + ModifyResponse: nil, + Timeout: fasthttp.DefaultLBClientTimeout, } ``` diff --git a/middleware/proxy/config.go b/middleware/proxy/config.go index 73d39c59..80b64725 100644 --- a/middleware/proxy/config.go +++ b/middleware/proxy/config.go @@ -1,6 +1,7 @@ package proxy import ( + "crypto/tls" "time" "github.com/gofiber/fiber/v2" @@ -45,6 +46,9 @@ type Config struct { // Per-connection buffer size for responses' writing. WriteBufferSize int + + // tls config for the http client + TlsConfig *tls.Config } // ConfigDefault is the default config diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 46c96436..09d91d94 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -1,13 +1,13 @@ package proxy import ( + "crypto/tls" "fmt" - "net/url" - "strings" - "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/utils" "github.com/valyala/fasthttp" + "net/url" + "strings" ) // New is deprecated @@ -45,6 +45,8 @@ func Balancer(config Config) fiber.Handler { ReadBufferSize: config.ReadBufferSize, WriteBufferSize: config.WriteBufferSize, + + TLSConfig: config.TlsConfig, } lbc.Clients = append(lbc.Clients, client) @@ -98,6 +100,12 @@ var client = fasthttp.Client{ DisablePathNormalizing: true, } +// WithTlsConfig update http client with a user specified tls.config +// This function should be called before Do and Forward. +func WithTlsConfig(tlsConfig *tls.Config) { + client.TLSConfig = tlsConfig +} + // Forward performs the given http request and fills the given http response. // This method will return an fiber.Handler func Forward(addr string) fiber.Handler { diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index c93da036..c998f302 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -1,6 +1,8 @@ package proxy import ( + "crypto/tls" + "github.com/gofiber/fiber/v2/internal/tlstest" "io/ioutil" "net" "net/http/httptest" @@ -82,6 +84,42 @@ func Test_Proxy(t *testing.T) { utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) } +// go test -run Test_Proxy_Balancer_WithTlsConfig +func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + utils.AssertEqual(t, nil, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + utils.AssertEqual(t, nil, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + + app.Get("/tlsbalaner", func(c *fiber.Ctx) error { + return c.SendString("tls balancer") + }) + + addr := ln.Addr().String() + clientTLSConf = &tls.Config{InsecureSkipVerify: true} + + // disable certificate verification in Balancer + app.Use(Balancer(Config{ + Servers: []string{addr}, + TlsConfig: clientTLSConf, + })) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + code, body, errs := fiber.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String() + + utils.AssertEqual(t, 0, len(errs)) + utils.AssertEqual(t, fiber.StatusOK, code) + utils.AssertEqual(t, "tls balancer", body) +} + // go test -run Test_Proxy_Forward func Test_Proxy_Forward(t *testing.T) { t.Parallel() @@ -103,6 +141,40 @@ func Test_Proxy_Forward(t *testing.T) { utils.AssertEqual(t, "forwarded", string(b)) } +// go test -run Test_Proxy_Forward_WithTlsConfig +func Test_Proxy_Forward_WithTlsConfig(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + utils.AssertEqual(t, nil, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + utils.AssertEqual(t, nil, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + + app.Get("/tlsfwd", func(c *fiber.Ctx) error { + return c.SendString("tls forward") + }) + + addr := ln.Addr().String() + clientTLSConf = &tls.Config{InsecureSkipVerify: true} + + // disable certificate verification + WithTlsConfig(clientTLSConf) + app.Use(Forward("https://" + addr + "/tlsfwd")) + + go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }() + + code, body, errs := fiber.Get("https://" + addr).TLSConfig(clientTLSConf).String() + + utils.AssertEqual(t, 0, len(errs)) + utils.AssertEqual(t, fiber.StatusOK, code) + utils.AssertEqual(t, "tls forward", body) +} + // go test -run Test_Proxy_Modify_Response func Test_Proxy_Modify_Response(t *testing.T) { t.Parallel()