package api import ( "net/http" "regexp" "strings" ) // CORS represents CORS configuration type CORS struct { AllowOrigins []string AllowHeaders []string AllowMethods []string ContentType string allowOriginPatterns []string } // AllowRequestsFrom returns handler to verify that requests come from allowedOrigins func AllowRequestsFrom(allowedOrigins ...string) http.HandlerFunc { c := &CORS{ AllowOrigins: allowedOrigins, AllowMethods: []string{"GET", "POST"}, AllowHeaders: []string{"Origin", "Content-Type"}, ContentType: "application/json; charset=utf-8", } c.generatePatterns() return func(w http.ResponseWriter, req *http.Request) { origin := req.Header.Get("Origin") if c.isOriginAllowed(origin) { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Access-Control-Allow-Headers", c.AllowedHeaders()) w.Header().Set("Access-Control-Allow-Methods", c.AllowedMethods()) w.Header().Set("Content-Type", c.ContentType) } } } // isOriginAllowed returns true if origin matches an allowed origin pattern. func (c *CORS) isOriginAllowed(origin string) bool { for _, allowedOriginPattern := range c.allowOriginPatterns { if allowed, _ := regexp.MatchString(allowedOriginPattern, origin); allowed { return true } } return false } // generatePatterns generates regex expression for AllowOrigins func (c *CORS) generatePatterns() { if c.AllowOrigins != nil { for _, origin := range c.AllowOrigins { pattern := regexp.QuoteMeta(origin) pattern = strings.ReplaceAll(pattern, "\\*", ".*") pattern = strings.ReplaceAll(pattern, "\\?", ".") c.allowOriginPatterns = append(c.allowOriginPatterns, "^"+pattern+"$") } } } // AllowedHeaders returns allowed headers in a string func (c *CORS) AllowedHeaders() string { return strings.Join(c.AllowHeaders, ",") } // AllowedMethods returns allowed http methods in a string func (c *CORS) AllowedMethods() string { return strings.Join(c.AllowMethods, ",") }