hybridgroup.gobot/api/cors.go

71 lines
2.0 KiB
Go

package api
import (
"net/http"
"regexp"
"strings"
)
// CORS represents CORS configuration
type CORS struct {
AllowOrigins []string
AllowHeaders []string
AllowMethods []string
ContentType string
allowOriginPatterns []string
}
// AllowRequestsFrom returns handler to verify that requests come from allowedOrigins
func AllowRequestsFrom(allowedOrigins ...string) http.HandlerFunc {
c := &CORS{
AllowOrigins: allowedOrigins,
AllowMethods: []string{"GET", "POST"},
AllowHeaders: []string{"Origin", "Content-Type"},
ContentType: "application/json; charset=utf-8",
}
c.generatePatterns()
return func(w http.ResponseWriter, req *http.Request) {
origin := req.Header.Get("Origin")
if c.isOriginAllowed(origin) {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Headers", c.AllowedHeaders())
w.Header().Set("Access-Control-Allow-Methods", c.AllowedMethods())
w.Header().Set("Content-Type", c.ContentType)
}
}
}
// isOriginAllowed returns true if origin matches an allowed origin pattern.
func (c *CORS) isOriginAllowed(origin string) bool {
for _, allowedOriginPattern := range c.allowOriginPatterns {
if allowed, _ := regexp.MatchString(allowedOriginPattern, origin); allowed {
return true
}
}
return false
}
// generatePatterns generates regex expression for AllowOrigins
func (c *CORS) generatePatterns() {
if c.AllowOrigins != nil {
for _, origin := range c.AllowOrigins {
pattern := regexp.QuoteMeta(origin)
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
pattern = strings.ReplaceAll(pattern, "\\?", ".")
c.allowOriginPatterns = append(c.allowOriginPatterns, "^"+pattern+"$")
}
}
}
// AllowedHeaders returns allowed headers in a string
func (c *CORS) AllowedHeaders() string {
return strings.Join(c.AllowHeaders, ",")
}
// AllowedMethods returns allowed http methods in a string
func (c *CORS) AllowedMethods() string {
return strings.Join(c.AllowMethods, ",")
}