diff --git a/caddytest/integration/reverseproxy_test.go b/caddytest/integration/reverseproxy_test.go new file mode 100644 index 00000000..8000546d --- /dev/null +++ b/caddytest/integration/reverseproxy_test.go @@ -0,0 +1,97 @@ +package integration + +import ( + "fmt" + "io/ioutil" + "net" + "net/http" + "os" + "runtime" + "strings" + "testing" + + "github.com/caddyserver/caddy/v2/caddytest" +) + +func TestReverseProxyHealthCheck(t *testing.T) { + tester := caddytest.NewTester(t) + tester.InitServer(` + { + http_port 9080 + https_port 9443 + } + http://localhost:2020 { + respond "Hello, World!" + } + http://localhost:2021 { + respond "ok" + } + http://localhost:9080 { + reverse_proxy { + to localhost:2020 + + health_path /health + health_port 2021 + health_interval 2s + health_timeout 5s + } + } + `, "caddyfile") + + tester.AssertGetResponse("http://localhost:9080/", 200, "Hello, World!") +} + +func TestReverseProxyHealthCheckUnixSocket(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + tester := caddytest.NewTester(t) + f, err := ioutil.TempFile("", "*.sock") + if err != nil { + t.Errorf("failed to create TempFile: %s", err) + return + } + // a hack to get a file name within a valid path to use as socket + socketName := f.Name() + os.Remove(f.Name()) + + server := http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if strings.HasPrefix(req.URL.Path, "/health") { + w.Write([]byte("ok")) + return + } + w.Write([]byte("Hello, World!")) + }), + } + + unixListener, err := net.Listen("unix", socketName) + if err != nil { + t.Errorf("failed to listen on the socket: %s", err) + return + } + go server.Serve(unixListener) + t.Cleanup(func() { + server.Close() + }) + runtime.Gosched() // Allow other goroutines to run + + tester.InitServer(fmt.Sprintf(` + { + http_port 9080 + https_port 9443 + } + http://localhost:9080 { + reverse_proxy { + to unix/%s + + health_path /health + health_port 2021 + health_interval 2s + health_timeout 5s + } + } + `, socketName), "caddyfile") + + tester.AssertGetResponse("http://localhost:9080/", 200, "Hello, World!") +} diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go index 33cfd82b..410b9d4a 100644 --- a/modules/caddyhttp/reverseproxy/healthchecks.go +++ b/modules/caddyhttp/reverseproxy/healthchecks.go @@ -153,32 +153,27 @@ func (h *Handler) doActiveHealthCheckForAllHosts() { log.Printf("[PANIC] active health check: %v\n%s", err, debug.Stack()) } }() - networkAddr := upstream.Dial - addr, err := caddy.ParseNetworkAddress(networkAddr) - if err != nil { - h.HealthChecks.Active.logger.Error("bad network address", - zap.String("address", networkAddr), - zap.Error(err), - ) - return - } - if addr.PortRangeSize() != 1 { - h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)", - zap.String("address", networkAddr), - ) - return - } - hostAddr := addr.JoinHostPort(0) - if addr.IsUnixNetwork() { + + portStr := strconv.Itoa(upstream.activeHealthCheckPort) + hostAddr := net.JoinHostPort(upstream.networkAddress.Host, portStr) + if upstream.networkAddress.IsUnixNetwork() { // this will be used as the Host portion of a http.Request URL, and // paths to socket files would produce an error when creating URL, // so use a fake Host value instead; unix sockets are usually local hostAddr = "localhost" } - err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: hostAddr}, hostAddr, upstream.Host) + + dialInfo := DialInfo{ + Upstream: upstream, + Network: upstream.networkAddress.Network, + Host: upstream.networkAddress.Host, + Port: portStr, + Address: hostAddr, + } + err := h.doActiveHealthCheck(dialInfo, hostAddr, upstream.Host) if err != nil { h.HealthChecks.Active.logger.Error("active health check failed", - zap.String("address", networkAddr), + zap.String("address", hostAddr), zap.Error(err), ) } diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go index 5870b75f..b7b8c9b5 100644 --- a/modules/caddyhttp/reverseproxy/hosts.go +++ b/modules/caddyhttp/reverseproxy/hosts.go @@ -92,8 +92,10 @@ type Upstream struct { // HeaderAffinity string // IPAffinity string - healthCheckPolicy *PassiveHealthChecks - cb CircuitBreaker + networkAddress caddy.NetworkAddress + activeHealthCheckPort int + healthCheckPolicy *PassiveHealthChecks + cb CircuitBreaker } func (u Upstream) String() string { diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go index dce7b9e8..7e3bb69c 100644 --- a/modules/caddyhttp/reverseproxy/httptransport.go +++ b/modules/caddyhttp/reverseproxy/httptransport.go @@ -182,6 +182,9 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error) if dialInfo, ok := GetDialInfo(ctx); ok { network = dialInfo.Network address = dialInfo.Address + if dialInfo.Upstream.networkAddress.IsUnixNetwork() { + address = dialInfo.Host + } } conn, err := dialer.DialContext(ctx, network, address) if err != nil { diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 910fbfc7..138a3fc8 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -208,9 +208,13 @@ func (h *Handler) Provision(ctx caddy.Context) error { if err != nil { return err } + if addr.PortRangeSize() != 1 { return fmt.Errorf("multiple addresses (upstream must map to only one address): %v", addr) } + + upstream.networkAddress = addr + // create or get the host representation for this upstream var host Host = new(upstreamHost) existingHost, loaded := hosts.LoadOrStore(upstream.String(), host) @@ -267,6 +271,16 @@ func (h *Handler) Provision(ctx caddy.Context) error { Transport: h.Transport, } + for _, upstream := range h.Upstreams { + // if there's an alternative port for health-check provided in the config, + // then use it, otherwise use the port of upstream. + if h.HealthChecks.Active.Port != 0 { + upstream.activeHealthCheckPort = h.HealthChecks.Active.Port + } else { + upstream.activeHealthCheckPort = int(upstream.networkAddress.StartPort) + } + } + if h.HealthChecks.Active.Interval == 0 { h.HealthChecks.Active.Interval = caddy.Duration(30 * time.Second) }