72 lines
2.0 KiB
Go
72 lines
2.0 KiB
Go
package api
|
|
|
|
import (
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
)
|
|
|
|
// CORS represents CORS configuration
|
|
type CORS struct {
|
|
AllowOrigins []string
|
|
AllowHeaders []string
|
|
AllowMethods []string
|
|
ContentType string
|
|
allowOriginPatterns []string
|
|
}
|
|
|
|
// AllowRequestsFrom returns handler to verify that requests come from allowedOrigins
|
|
func AllowRequestsFrom(allowedOrigins ...string) http.HandlerFunc {
|
|
c := &CORS{
|
|
AllowOrigins: allowedOrigins,
|
|
AllowMethods: []string{"GET", "POST"},
|
|
AllowHeaders: []string{"Origin", "Content-Type"},
|
|
ContentType: "application/json; charset=utf-8",
|
|
}
|
|
|
|
c.generatePatterns()
|
|
|
|
return func(w http.ResponseWriter, req *http.Request) {
|
|
origin := req.Header.Get("Origin")
|
|
if c.isOriginAllowed(origin) {
|
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
w.Header().Set("Access-Control-Allow-Headers", c.AllowedHeaders())
|
|
w.Header().Set("Access-Control-Allow-Methods", c.AllowedMethods())
|
|
w.Header().Set("Content-Type", c.ContentType)
|
|
}
|
|
}
|
|
}
|
|
|
|
// isOriginAllowed returns true if origin matches an allowed origin pattern.
|
|
func (c *CORS) isOriginAllowed(origin string) (allowed bool) {
|
|
for _, allowedOriginPattern := range c.allowOriginPatterns {
|
|
allowed, _ = regexp.MatchString(allowedOriginPattern, origin)
|
|
if allowed {
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// generatePatterns generates regex expression for AllowOrigins
|
|
func (c *CORS) generatePatterns() {
|
|
if c.AllowOrigins != nil {
|
|
for _, origin := range c.AllowOrigins {
|
|
pattern := regexp.QuoteMeta(origin)
|
|
pattern = strings.Replace(pattern, "\\*", ".*", -1)
|
|
pattern = strings.Replace(pattern, "\\?", ".", -1)
|
|
c.allowOriginPatterns = append(c.allowOriginPatterns, "^"+pattern+"$")
|
|
}
|
|
}
|
|
}
|
|
|
|
// AllowedHeaders returns allowed headers in a string
|
|
func (c *CORS) AllowedHeaders() string {
|
|
return strings.Join(c.AllowHeaders, ",")
|
|
}
|
|
|
|
// AllowedMethods returns allowed http methods in a string
|
|
func (c *CORS) AllowedMethods() string {
|
|
return strings.Join(c.AllowMethods, ",")
|
|
}
|