From d876de61e512db7a31a7ae59723d5134048f283e Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Sat, 11 Jan 2020 11:40:03 -0700 Subject: [PATCH] rewrite: Fix query string logic --- modules/caddyhttp/rewrite/rewrite.go | 48 ++++++++++++----------- modules/caddyhttp/rewrite/rewrite_test.go | 8 ++++ 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/modules/caddyhttp/rewrite/rewrite.go b/modules/caddyhttp/rewrite/rewrite.go index d9464474..c069db9d 100644 --- a/modules/caddyhttp/rewrite/rewrite.go +++ b/modules/caddyhttp/rewrite/rewrite.go @@ -171,23 +171,31 @@ func (rewr Rewrite) rewrite(r *http.Request, repl *caddy.Replacer, logger *zap.L // buildQueryString takes an input query string and // performs replacements on each component, returning -// the resulting query string. +// the resulting query string. This function appends +// duplicate keys rather than replaces. func buildQueryString(qs string, repl *caddy.Replacer) string { var sb strings.Builder - var wroteKey bool + + // first component must be key, which is the same + // as if we just wrote a value in previous iteration + wroteVal := true for len(qs) > 0 { - // determine the end of this component + // determine the end of this component, which will be at + // the next equal sign or ampersand, whichever comes first nextEq, nextAmp := strings.Index(qs, "="), strings.Index(qs, "&") - end := min(nextEq, nextAmp) - if end == -1 { - end = len(qs) // if there is nothing left, go to end of string + ampIsNext := nextAmp >= 0 && (nextAmp < nextEq || nextEq < 0) + end := len(qs) // assume no delimiter remains... + if ampIsNext { + end = nextAmp // ...unless ampersand is first... + } else if nextEq >= 0 && (nextEq < nextAmp || nextAmp < 0) { + end = nextEq // ...or unless equal is first. } // consume the component and write the result comp := qs[:end] comp, _ = repl.ReplaceFunc(comp, func(name, val string) (string, error) { - if name == "http.request.uri.query" { + if name == "http.request.uri.query" && wroteVal { return val, nil // already escaped } return url.QueryEscape(val), nil @@ -197,29 +205,25 @@ func buildQueryString(qs string, repl *caddy.Replacer) string { } qs = qs[end:] - if wroteKey { + // if previous iteration wrote a value, + // that means we are writing a key + if wroteVal { + if sb.Len() > 0 { + sb.WriteRune('&') + } + } else { sb.WriteRune('=') - } else if sb.Len() > 0 { - sb.WriteRune('&') } - - // remember that we just wrote a key, which is if the next - // delimiter is an equals sign or if there is no ampersand - wroteKey = nextEq < nextAmp || nextAmp < 0 - sb.WriteString(comp) + + // remember for the next iteration that we just wrote a value, + // which means the next iteration MUST write a key + wroteVal = ampIsNext } return sb.String() } -func min(a, b int) int { - if b < a { - return b - } - return a -} - // replacer describes a simple and fast substring replacement. type replacer struct { // The substring to find. Supports placeholders. diff --git a/modules/caddyhttp/rewrite/rewrite_test.go b/modules/caddyhttp/rewrite/rewrite_test.go index beff499a..de82d8d8 100644 --- a/modules/caddyhttp/rewrite/rewrite_test.go +++ b/modules/caddyhttp/rewrite/rewrite_test.go @@ -138,6 +138,11 @@ func TestRewrite(t *testing.T) { input: newRequest(t, "GET", "/foo/bar?a=b&c=d"), expect: newRequest(t, "GET", "/foo/bar"), }, + { + rule: Rewrite{URI: "?qs={http.request.uri.query}"}, + input: newRequest(t, "GET", "/foo?a=b&c=d"), + expect: newRequest(t, "GET", "/foo?qs=a%3Db%26c%3Dd"), + }, { rule: Rewrite{URI: "/foo?{http.request.uri.query}#frag"}, input: newRequest(t, "GET", "/foo/bar?a=b"), @@ -216,6 +221,9 @@ func TestRewrite(t *testing.T) { if expected, actual := tc.expect.URL.RequestURI(), tc.input.URL.RequestURI(); expected != actual { t.Errorf("Test %d: Expected URL.RequestURI()='%s' but got '%s'", i, expected, actual) } + if expected, actual := tc.expect.URL.Fragment, tc.input.URL.Fragment; expected != actual { + t.Errorf("Test %d: Expected URL.Fragment='%s' but got '%s'", i, expected, actual) + } } }