71 lines
2.0 KiB
Go
71 lines
2.0 KiB
Go
package api
|
|
|
|
import (
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
)
|
|
|
|
// CORS holds the settings used to build CORS response headers.
type CORS struct {
	// AllowOrigins lists the origins that may receive CORS headers,
	// AllowHeaders the request headers advertised as allowed, and
	// AllowMethods the HTTP methods advertised as allowed.
	AllowOrigins, AllowHeaders, AllowMethods []string

	// ContentType is the value written to the Content-Type response header.
	ContentType string

	// allowOriginPatterns holds regular-expression strings used to match
	// a request's Origin header.
	allowOriginPatterns []string
}
|
|
|
|
// AllowRequestsFrom returns handler to verify that requests come from allowedOrigins
|
|
func AllowRequestsFrom(allowedOrigins ...string) http.HandlerFunc {
|
|
c := &CORS{
|
|
AllowOrigins: allowedOrigins,
|
|
AllowMethods: []string{"GET", "POST"},
|
|
AllowHeaders: []string{"Origin", "Content-Type"},
|
|
ContentType: "application/json; charset=utf-8",
|
|
}
|
|
|
|
c.generatePatterns()
|
|
|
|
return func(w http.ResponseWriter, req *http.Request) {
|
|
origin := req.Header.Get("Origin")
|
|
if c.isOriginAllowed(origin) {
|
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
w.Header().Set("Access-Control-Allow-Headers", c.AllowedHeaders())
|
|
w.Header().Set("Access-Control-Allow-Methods", c.AllowedMethods())
|
|
w.Header().Set("Content-Type", c.ContentType)
|
|
}
|
|
}
|
|
}
|
|
|
|
// isOriginAllowed returns true if origin matches an allowed origin pattern.
|
|
func (c *CORS) isOriginAllowed(origin string) bool {
|
|
for _, allowedOriginPattern := range c.allowOriginPatterns {
|
|
if allowed, _ := regexp.MatchString(allowedOriginPattern, origin); allowed {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// generatePatterns generates regex expression for AllowOrigins
|
|
func (c *CORS) generatePatterns() {
|
|
if c.AllowOrigins != nil {
|
|
for _, origin := range c.AllowOrigins {
|
|
pattern := regexp.QuoteMeta(origin)
|
|
pattern = strings.ReplaceAll(pattern, "\\*", ".*")
|
|
pattern = strings.ReplaceAll(pattern, "\\?", ".")
|
|
c.allowOriginPatterns = append(c.allowOriginPatterns, "^"+pattern+"$")
|
|
}
|
|
}
|
|
}
|
|
|
|
// AllowedHeaders returns allowed headers in a string
|
|
func (c *CORS) AllowedHeaders() string {
|
|
return strings.Join(c.AllowHeaders, ",")
|
|
}
|
|
|
|
// AllowedMethods returns allowed http methods in a string
|
|
func (c *CORS) AllowedMethods() string {
|
|
return strings.Join(c.AllowMethods, ",")
|
|
}
|