From abf5ab340ed76792214ae80c62df7abe0ad1b8a8 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Tue, 15 Oct 2019 14:07:10 -0600 Subject: [PATCH] caddyhttp: Improve ResponseRecorder to buffer headers --- modules/caddyhttp/caddyhttp.go | 13 +++ modules/caddyhttp/httpcache/httpcache.go | 5 +- modules/caddyhttp/markdown/markdown.go | 6 +- modules/caddyhttp/responsewriter.go | 119 ++++++++++++++++------ modules/caddyhttp/templates/templates.go | 18 ++-- modules/caddyhttp/templates/tplcontext.go | 2 +- 6 files changed, 116 insertions(+), 47 deletions(-) diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go index 5631b300..29a5ab07 100644 --- a/modules/caddyhttp/caddyhttp.go +++ b/modules/caddyhttp/caddyhttp.go @@ -593,6 +593,19 @@ func (ws WeakString) String() string { return string(ws) } +// CopyHeader copies HTTP headers by completely +// replacing dest with src. (This allows deletions +// to be propagated, assuming src started as a +// consistent copy of dest.) +func CopyHeader(dest, src http.Header) { + for field := range dest { + delete(dest, field) + } + for field, val := range src { + dest[field] = val + } +} + // StatusCodeMatches returns true if a real HTTP status code matches // the configured status code, which may be either a real HTTP status // code or an integer representing a class of codes (e.g. 4 for all diff --git a/modules/caddyhttp/httpcache/httpcache.go b/modules/caddyhttp/httpcache/httpcache.go index 1b2cfd2e..0b49c7ee 100644 --- a/modules/caddyhttp/httpcache/httpcache.go +++ b/modules/caddyhttp/httpcache/httpcache.go @@ -130,8 +130,7 @@ func (c *Cache) getter(ctx groupcache.Context, key string, dest groupcache.Sink) // we need to record the response if we are to cache it; only cache if // request is successful (TODO: there's probably much more nuance needed here) - var rr caddyhttp.ResponseRecorder - rr = caddyhttp.NewResponseRecorder(combo.rw, buf, func(status int) bool { + rr := caddyhttp.NewResponseRecorder(combo.rw, buf, func(status int, header http.Header) bool { shouldBuf := status < 300 if shouldBuf { @@ -141,7 +140,7 @@ func (c *Cache) getter(ctx groupcache.Context, key string, dest groupcache.Sink) // the rest will be the body, which will be written // implicitly for us by the recorder err := gob.NewEncoder(buf).Encode(headerAndStatus{ - Header: rr.Header(), + Header: header, Status: status, }) if err != nil { diff --git a/modules/caddyhttp/markdown/markdown.go b/modules/caddyhttp/markdown/markdown.go index 122aad6e..5ff18b88 100644 --- a/modules/caddyhttp/markdown/markdown.go +++ b/modules/caddyhttp/markdown/markdown.go @@ -48,8 +48,8 @@ func (m Markdown) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht buf.Reset() defer bufPool.Put(buf) - shouldBuf := func(status int) bool { - return strings.HasPrefix(w.Header().Get("Content-Type"), "text/") + shouldBuf := func(status int, header http.Header) bool { + return strings.HasPrefix(header.Get("Content-Type"), "text/") } rec := caddyhttp.NewResponseRecorder(w, buf, shouldBuf) @@ -62,6 +62,8 @@ func (m Markdown) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht return nil } + caddyhttp.CopyHeader(w.Header(), rec.Header()) + output := blackfriday.Run(buf.Bytes()) w.Header().Set("Content-Length", strconv.Itoa(len(output))) diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index 344298f2..5beb40ea 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -18,6 +18,7 @@ import ( "bufio" "bytes" "fmt" + "io" "net" "net/http" ) @@ -78,52 +79,89 @@ type responseRecorder struct { wroteHeader bool statusCode int buf *bytes.Buffer - shouldBuffer func(status int) bool + shouldBuffer ShouldBufferFunc stream bool size int + header http.Header } // NewResponseRecorder returns a new ResponseRecorder that can be -// used instead of a real http.ResponseWriter. The recorder is useful -// for middlewares which need to buffer a responder's response and -// process it in its entirety before actually allowing the response to -// be written. Of course, this has a performance overhead, but -// sometimes there is no way to avoid buffering the whole response. -// Still, if at all practical, middlewares should strive to stream +// used instead of a standard http.ResponseWriter. The recorder is +// useful for middlewares which need to buffer a response and +// potentially process its entire body before actually writing the +// response to the underlying writer. Of course, buffering the entire +// body has a memory overhead, but sometimes there is no way to avoid +// buffering the whole response, hence the existence of this type. +// Still, if at all practical, handlers should strive to stream // responses by wrapping Write and WriteHeader methods instead of // buffering whole response bodies. // -// Recorders optionally buffer the response. When the headers are -// to be written, shouldBuffer will be called with the status -// code that is being written. The rest of the headers can be read -// from w.Header(). If shouldBuffer returns true, the response -// will be buffered. You can know the response was buffered if -// the Buffered() method returns true. If the response was not -// buffered, Buffered() will return false and that means the -// response bypassed the recorder and was written directly to the -// underlying writer. If shouldBuffer is nil, the response will -// never be buffered (it will always be streamed directly), and -// buf can also safely be nil. +// Buffering is actually optional. The shouldBuffer function will +// be called just before the headers are written. If it returns +// true, the headers and body will be buffered by this recorder +// and not written to the underlying writer; if false, the headers +// will be written immediately and the body will be streamed out +// directly to the underlying writer. If shouldBuffer is nil, +// the response will never be buffered and will always be streamed +// directly to the writer. // -// Before calling this function in a middleware handler, make a -// new buffer or obtain one from a pool (use the sync.Pool) type. -// Using a pool is generally recommended for performance gains; -// do profiling to ensure this is the case. If using a pool, be -// sure to reset the buffer before using it. +// You can know if shouldBuffer returned true by calling Buffered(). // -// The returned recorder can be used in place of w when calling -// the next handler in the chain. When that handler returns, you -// can read the status code from the recorder's Status() method. -// The response body fills buf if it was buffered, and the headers -// are available via w.Header(). -func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer func(status int) bool) ResponseRecorder { +// The provided buffer buf should be obtained from a pool for best +// performance (see the sync.Pool type). +// +// Proper usage of a recorder looks like this: +// +// rec := caddyhttp.NewResponseRecorder(w, buf, shouldBuffer) +// err := next.ServeHTTP(rec, req) +// if err != nil { +// return err +// } +// if !rec.Buffered() { +// return nil +// } +// // process the buffered response here +// +// After a response has been buffered, remember that any upstream header +// manipulations are only manifest in the recorder's Header(), not the +// Header() of the underlying ResponseWriter. Thus if you wish to inspect +// or change response headers, you either need to use rec.Header(), or +// copy rec.Header() into w.Header() first (see caddyhttp.CopyHeader). +// +// Once you are ready to write the response, there are two ways you can do +// it. The easier way is to have the recorder do it: +// +// rec.WriteResponse() +// +// This writes the recorded response headers as well as the buffered body. +// Or, you may wish to do it yourself, especially if you manipulated the +// buffered body. First you will need to copy the recorded headers, then +// write the headers with the recorded status code, then write the body +// (this example writes the recorder's body buffer, but you might have +// your own body to write instead): +// +// caddyhttp.CopyHeader(w.Header(), rec.Header()) +// w.WriteHeader(rec.Status()) +// io.Copy(w, rec.Buffer()) +// +func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer ShouldBufferFunc) ResponseRecorder { + // copy the current response header into this buffer so + // that any header manipulations on the buffered header + // are consistent with what would be written out + hdr := make(http.Header) + CopyHeader(hdr, w.Header()) return &responseRecorder{ ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: w}, buf: buf, shouldBuffer: shouldBuffer, + header: hdr, } } +func (rr *responseRecorder) Header() http.Header { + return rr.header +} + func (rr *responseRecorder) WriteHeader(statusCode int) { if rr.wroteHeader { return @@ -135,9 +173,12 @@ func (rr *responseRecorder) WriteHeader(statusCode int) { if rr.shouldBuffer == nil { rr.stream = true } else { - rr.stream = !rr.shouldBuffer(rr.statusCode) + rr.stream = !rr.shouldBuffer(rr.statusCode, rr.header) } + + // if not buffered, immediately write header if rr.stream { + CopyHeader(rr.ResponseWriterWrapper.Header(), rr.header) rr.ResponseWriterWrapper.WriteHeader(rr.statusCode) } } @@ -179,16 +220,32 @@ func (rr *responseRecorder) Buffered() bool { return !rr.stream } +func (rr *responseRecorder) WriteResponse() error { + if rr.stream { + return nil + } + CopyHeader(rr.ResponseWriterWrapper.Header(), rr.header) + rr.ResponseWriterWrapper.WriteHeader(rr.statusCode) + _, err := io.Copy(rr.ResponseWriterWrapper, rr.buf) + return err +} // ResponseRecorder is a http.ResponseWriter that records -// responses instead of writing them to the client. +// responses instead of writing them to the client. See +// docs for NewResponseRecorder for proper usage. type ResponseRecorder interface { HTTPInterfaces Status() int Buffer() *bytes.Buffer Buffered() bool Size() int + WriteResponse() error } +// ShouldBufferFunc is a function that returns true if the +// response should be buffered, given the pending HTTP status +// code and response headers. +type ShouldBufferFunc func(status int, header http.Header) bool + // Interface guards var ( _ HTTPInterfaces = (*ResponseWriterWrapper)(nil) diff --git a/modules/caddyhttp/templates/templates.go b/modules/caddyhttp/templates/templates.go index 05a2f633..e9c1da81 100644 --- a/modules/caddyhttp/templates/templates.go +++ b/modules/caddyhttp/templates/templates.go @@ -17,7 +17,6 @@ package templates import ( "bytes" "fmt" - "io" "net/http" "strconv" "strings" @@ -71,8 +70,8 @@ func (t *Templates) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy // shouldBuf determines whether to execute templates on this response, // since generally we will not want to execute for images or CSS, etc. - shouldBuf := func(status int) bool { - ct := w.Header().Get("Content-Type") + shouldBuf := func(status int, header http.Header) bool { + ct := header.Get("Content-Type") for _, mt := range t.MIMETypes { if strings.Contains(ct, mt) { return true @@ -96,18 +95,17 @@ func (t *Templates) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy return err } - w.Header().Set("Content-Length", strconv.Itoa(buf.Len())) - w.Header().Del("Accept-Ranges") // we don't know ranges for dynamically-created content - w.Header().Del("Last-Modified") // useless for dynamic content since it's always changing + rec.Header().Set("Content-Length", strconv.Itoa(buf.Len())) + rec.Header().Del("Accept-Ranges") // we don't know ranges for dynamically-created content + rec.Header().Del("Last-Modified") // useless for dynamic content since it's always changing // we don't know a way to guickly generate etag for dynamic content, // but we can convert this to a weak etag to kind of indicate that - if etag := w.Header().Get("ETag"); etag != "" { - w.Header().Set("ETag", "W/"+etag) + if etag := rec.Header().Get("Etag"); etag != "" { + rec.Header().Set("Etag", "W/"+etag) } - w.WriteHeader(rec.Status()) - io.Copy(w, buf) + rec.WriteResponse() return nil } diff --git a/modules/caddyhttp/templates/tplcontext.go b/modules/caddyhttp/templates/tplcontext.go index 5b74623e..40d13707 100644 --- a/modules/caddyhttp/templates/tplcontext.go +++ b/modules/caddyhttp/templates/tplcontext.go @@ -80,7 +80,7 @@ func (c templateContext) Include(filename string, args ...interface{}) (template // If it is not trusted, be sure to use escaping functions yourself. func (c templateContext) HTTPInclude(uri string) (template.HTML, error) { if c.Req.Header.Get(recursionPreventionHeader) == "1" { - return "", fmt.Errorf("virtual include cycle") + return "", fmt.Errorf("virtual request cycle") } buf := bufPool.Get().(*bytes.Buffer)