diff --git a/api/api.go b/api/api.go index 36605350..6bc8b3a8 100644 --- a/api/api.go +++ b/api/api.go @@ -106,8 +106,10 @@ func (a *api) CORSHandler(w http.ResponseWriter, req *http.Request) { origin := req.Header.Get("Origin") if a.cors.isOriginAllowed(origin) { w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Allow-Headers", a.cors.AllowedHeaders()) + w.Header().Set("Access-Control-Allow-Methods", a.cors.AllowedMethods()) + w.Header().Set("Content-Type", a.cors.ContentType) } - } func (a *api) SetDebug() { diff --git a/api/cors.go b/api/cors.go index eca27588..530d75d8 100644 --- a/api/cors.go +++ b/api/cors.go @@ -1,24 +1,56 @@ package api +import ( + "regexp" + "strings" +) + type CORS struct { - AllowOrigins []string - AllowHeaders []string // Not yet implemented - AllowMethods []string // ditto + AllowOrigins []string + AllowHeaders []string + AllowMethods []string + ContentType string + allowOriginPatterns []string } func NewCORS(allowedOrigins []string) *CORS { - return &CORS{ + cors := &CORS{ AllowOrigins: allowedOrigins, AllowMethods: []string{"GET", "POST"}, AllowHeaders: []string{"Origin", "Content-Type"}, + ContentType: "application/json; charset=utf-8", + } + + cors.generatePatterns() + + return cors +} + +func (c *CORS) isOriginAllowed(origin string) (allowed bool) { + for _, allowedOriginPattern := range c.allowOriginPatterns { + allowed, _ = regexp.MatchString(allowedOriginPattern, origin) + if allowed { + return + } + } + return +} + +func (c *CORS) generatePatterns() { + if c.AllowOrigins != nil { + for _, origin := range c.AllowOrigins { + pattern := regexp.QuoteMeta(origin) + pattern = strings.Replace(pattern, "\\*", ".*", -1) + pattern = strings.Replace(pattern, "\\?", ".", -1) + c.allowOriginPatterns = append(c.allowOriginPatterns, "^"+pattern+"$") + } } } -func (c *CORS) isOriginAllowed(currentOrigin string) bool { - for _, allowedOrigin := range c.AllowOrigins { - if "*" == allowedOrigin || currentOrigin == allowedOrigin { - return true - } - } - return false +func (c *CORS) AllowedHeaders() string { + return strings.Join(c.AllowHeaders, ",") +} + +func (c *CORS) AllowedMethods() string { + return strings.Join(c.AllowMethods, ",") } diff --git a/api/cors_test.go b/api/cors_test.go index f214974e..fada1832 100644 --- a/api/cors_test.go +++ b/api/cors_test.go @@ -16,11 +16,19 @@ func TestNewCORS(t *testing.T) { } func TestNewCorsSetsProperties(t *testing.T) { - allowedOrigins := []string{"http://server:port"} + allowedOrigins := []string{"http://*server:*", "http://localhost:*"} + allowedMethods := []string{"GET", "POST"} + allowedHeaders := []string{"Origin", "Content-Type"} + contentType := "application/json; charset=utf-8" + allowOriginPatterns := []string{"^http://.*server:.*$", "^http://localhost:.*$"} cors := NewCORS(allowedOrigins) gobot.Assert(t, cors.AllowOrigins, allowedOrigins) + gobot.Assert(t, cors.AllowMethods, allowedMethods) + gobot.Assert(t, cors.AllowHeaders, allowedHeaders) + gobot.Assert(t, cors.ContentType, contentType) + gobot.Assert(t, cors.allowOriginPatterns, allowOriginPatterns) } func TestCORSIsOriginAllowed(t *testing.T) { @@ -32,16 +40,42 @@ func TestCORSIsOriginAllowed(t *testing.T) { gobot.Assert(t, cors.isOriginAllowed("http://server.com"), true) // When one origin is accepted - cors.AllowOrigins = []string{"http://localhost:8000"} + cors = NewCORS([]string{"http://localhost:8000"}) gobot.Assert(t, cors.isOriginAllowed("http://localhost:8000"), true) gobot.Assert(t, cors.isOriginAllowed("http://localhost:3001"), false) gobot.Assert(t, cors.isOriginAllowed("http://server.com"), false) // When several origins are accepted - cors.AllowOrigins = []string{"http://localhost:8000", "http://server.com"} + cors = NewCORS([]string{"http://localhost:*", "http://server.com"}) gobot.Assert(t, cors.isOriginAllowed("http://localhost:8000"), true) - gobot.Assert(t, cors.isOriginAllowed("http://localhost:3001"), false) + gobot.Assert(t, cors.isOriginAllowed("http://localhost:3001"), true) gobot.Assert(t, cors.isOriginAllowed("http://server.com"), true) + + // When several origins are accepted within the same domain + cors = NewCORS([]string{"http://*.server.com"}) + + gobot.Assert(t, cors.isOriginAllowed("http://localhost:8000"), false) + gobot.Assert(t, cors.isOriginAllowed("http://localhost:3001"), false) + gobot.Assert(t, cors.isOriginAllowed("http://foo.server.com"), true) + gobot.Assert(t, cors.isOriginAllowed("http://api.server.com"), true) +} + +func TestCORSAllowedHeaders(t *testing.T) { + cors := NewCORS([]string{"*"}) + + cors.AllowHeaders = []string{"Header1", "Header2"} + + gobot.Assert(t, cors.AllowedHeaders(), "Header1,Header2") +} + +func TestCORSAllowedMethods(t *testing.T) { + cors := NewCORS([]string{"*"}) + + gobot.Assert(t, cors.AllowedMethods(), "GET,POST") + + cors.AllowMethods = []string{"GET", "POST", "PUT"} + + gobot.Assert(t, cors.AllowedMethods(), "GET,POST,PUT") }