rewrite: Fix query string logic

This commit is contained in:
Matthew Holt 2020-01-11 11:40:03 -07:00
parent 8be1f0ea66
commit d876de61e5
No known key found for this signature in database
GPG Key ID: 2A349DD577D586A5
2 changed files with 34 additions and 22 deletions

View File

@ -171,23 +171,31 @@ func (rewr Rewrite) rewrite(r *http.Request, repl *caddy.Replacer, logger *zap.L
// buildQueryString takes an input query string and
// performs replacements on each component, returning
// the resulting query string.
// the resulting query string. This function appends
// duplicate keys rather than replaces.
func buildQueryString(qs string, repl *caddy.Replacer) string {
var sb strings.Builder
var wroteKey bool
// first component must be key, which is the same
// as if we just wrote a value in previous iteration
wroteVal := true
for len(qs) > 0 {
// determine the end of this component
// determine the end of this component, which will be at
// the next equal sign or ampersand, whichever comes first
nextEq, nextAmp := strings.Index(qs, "="), strings.Index(qs, "&")
end := min(nextEq, nextAmp)
if end == -1 {
end = len(qs) // if there is nothing left, go to end of string
ampIsNext := nextAmp >= 0 && (nextAmp < nextEq || nextEq < 0)
end := len(qs) // assume no delimiter remains...
if ampIsNext {
end = nextAmp // ...unless ampersand is first...
} else if nextEq >= 0 && (nextEq < nextAmp || nextAmp < 0) {
end = nextEq // ...or unless equal is first.
}
// consume the component and write the result
comp := qs[:end]
comp, _ = repl.ReplaceFunc(comp, func(name, val string) (string, error) {
if name == "http.request.uri.query" {
if name == "http.request.uri.query" && wroteVal {
return val, nil // already escaped
}
return url.QueryEscape(val), nil
@ -197,29 +205,25 @@ func buildQueryString(qs string, repl *caddy.Replacer) string {
}
qs = qs[end:]
if wroteKey {
// if previous iteration wrote a value,
// that means we are writing a key
if wroteVal {
if sb.Len() > 0 {
sb.WriteRune('&')
}
} else {
sb.WriteRune('=')
} else if sb.Len() > 0 {
sb.WriteRune('&')
}
// remember that we just wrote a key, which is if the next
// delimiter is an equals sign or if there is no ampersand
wroteKey = nextEq < nextAmp || nextAmp < 0
sb.WriteString(comp)
// remember for the next iteration that we just wrote a value,
// which means the next iteration MUST write a key
wroteVal = ampIsNext
}
return sb.String()
}
func min(a, b int) int {
if b < a {
return b
}
return a
}
// replacer describes a simple and fast substring replacement.
type replacer struct {
// The substring to find. Supports placeholders.

View File

@ -138,6 +138,11 @@ func TestRewrite(t *testing.T) {
input: newRequest(t, "GET", "/foo/bar?a=b&c=d"),
expect: newRequest(t, "GET", "/foo/bar"),
},
{
rule: Rewrite{URI: "?qs={http.request.uri.query}"},
input: newRequest(t, "GET", "/foo?a=b&c=d"),
expect: newRequest(t, "GET", "/foo?qs=a%3Db%26c%3Dd"),
},
{
rule: Rewrite{URI: "/foo?{http.request.uri.query}#frag"},
input: newRequest(t, "GET", "/foo/bar?a=b"),
@ -216,6 +221,9 @@ func TestRewrite(t *testing.T) {
if expected, actual := tc.expect.URL.RequestURI(), tc.input.URL.RequestURI(); expected != actual {
t.Errorf("Test %d: Expected URL.RequestURI()='%s' but got '%s'", i, expected, actual)
}
if expected, actual := tc.expect.URL.Fragment, tc.input.URL.Fragment; expected != actual {
t.Errorf("Test %d: Expected URL.Fragment='%s' but got '%s'", i, expected, actual)
}
}
}