diff --git a/middleware/csrf.go b/middleware/csrf.go index f9d3293b0..1a35da63c 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -13,6 +13,15 @@ import ( "github.com/labstack/echo/v4" ) +// CSRFUsingSecFetchSite is a context key for CSRF middleware what is set when the client browser is using Sec-Fetch-Site +// header and the request is deemed safe. +// It is a dummy token value that can be used to render CSRF token for form by handlers. +// +// We know that the client is using a browser that supports Sec-Fetch-Site header, so when the form is submitted in +// the future with this dummy token value it is OK. Although the request is safe, the template rendered by the +// handler may need this value to render CSRF token for form. +const CSRFUsingSecFetchSite = "_echo_csrf_using_sec_fetch_site_" + // CSRFConfig defines the config for CSRF middleware. type CSRFConfig struct { // Skipper defines a function to skip middleware. @@ -83,6 +92,8 @@ type CSRFConfig struct { // ErrorHandler defines a function which is executed for returning custom errors. ErrorHandler CSRFErrorHandler + + generator func(length uint8) string } // CSRFErrorHandler is a function which is executed for creating custom errors. @@ -145,6 +156,10 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...) } + tokenGenerator := randomString + if config.generator != nil { + tokenGenerator = config.generator + } extractors, cErr := CreateExtractors(config.TokenLookup) if cErr != nil { @@ -170,7 +185,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { token := "" if k, err := c.Cookie(config.CookieName); err != nil { - token = randomString(config.TokenLength) + token = tokenGenerator(config.TokenLength) } else { token = k.Value // Reuse token } @@ -287,6 +302,11 @@ func (config CSRFConfig) checkSecFetchSiteRequest(c echo.Context) (bool, error) } if isSafe { + // This helps handlers that support older token-based CSRF protection. + // We know that the client is using a browser that supports Sec-Fetch-Site header, so when the form is submitted in + // the future with this dummy token value it is OK. Although the request is safe, the template rendered by the + // handler may need this value to render CSRF token for form. + c.Set(config.ContextKey, CSRFUsingSecFetchSite) return true, nil } // we are here when request is state-changing and `cross-site` or `same-site` diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 85b7f1077..0b3210f07 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -86,7 +86,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { }, }, { - name: "ok, token from POST header, second token passes", + name: "nok, token from POST header, tokens limited to 1, second token would pass", whenTokenLookup: "header:" + echo.HeaderXCSRFToken, givenCSRFCookie: "token", givenMethod: http.MethodPost, @@ -122,7 +122,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { }, }, { - name: "ok, token from PUT query form, second token passes", + name: "nok, token from PUT query form, second token would pass", whenTokenLookup: "query:csrf", givenCSRFCookie: "token", givenMethod: http.MethodPut, @@ -235,12 +235,14 @@ func TestCSRFWithConfig(t *testing.T) { expectEmptyBody bool expectMWError string expectCookieContains string + expectTokenInContext string expectErr string }{ { name: "ok, GET", whenMethod: http.MethodGet, expectCookieContains: "_csrf", + expectTokenInContext: "TESTTOKEN", }, { name: "ok, POST valid token", @@ -250,6 +252,7 @@ func TestCSRFWithConfig(t *testing.T) { }, whenMethod: http.MethodPost, expectCookieContains: "_csrf", + expectTokenInContext: token, }, { name: "nok, POST without token", @@ -278,13 +281,23 @@ func TestCSRFWithConfig(t *testing.T) { }, whenMethod: http.MethodGet, expectCookieContains: "_csrf", + expectTokenInContext: "TESTTOKEN", }, { name: "ok, unsafe method + SecFetchSite=same-origin passes", whenHeaders: map[string]string{ echo.HeaderSecFetchSite: "same-origin", }, - whenMethod: http.MethodPost, + whenMethod: http.MethodPost, + expectTokenInContext: "_echo_csrf_using_sec_fetch_site_", + }, + { + name: "ok, safe method + SecFetchSite=same-origin passes", + whenHeaders: map[string]string{ + echo.HeaderSecFetchSite: "same-origin", + }, + whenMethod: http.MethodGet, + expectTokenInContext: "_echo_csrf_using_sec_fetch_site_", }, { name: "nok, unsafe method + SecFetchSite=same-cross blocked", @@ -312,6 +325,11 @@ func TestCSRFWithConfig(t *testing.T) { if tc.givenConfig != nil { config = *tc.givenConfig } + if config.generator == nil { + config.generator = func(_ uint8) string { + return "TESTTOKEN" + } + } mw, err := config.ToMiddleware() if tc.expectMWError != "" { assert.EqualError(t, err, tc.expectMWError) @@ -320,6 +338,8 @@ func TestCSRFWithConfig(t *testing.T) { assert.NoError(t, err) h := mw(func(c echo.Context) error { + cToken := c.Get(cmp.Or(config.ContextKey, DefaultCSRFConfig.ContextKey)) + assert.Equal(t, tc.expectTokenInContext, cToken) return c.String(http.StatusOK, "test") }) @@ -559,7 +579,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { whenMethod: http.MethodPost, whenSecFetchSite: "same-site", expectAllow: false, - expectErr: ``, }, { name: "ok, unsafe POST + same-origin passes", @@ -617,7 +636,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { whenMethod: http.MethodPut, whenSecFetchSite: "same-site", expectAllow: false, - expectErr: ``, }, { name: "nok, unsafe DELETE + cross-site is blocked", @@ -633,7 +651,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { whenMethod: http.MethodDelete, whenSecFetchSite: "same-site", expectAllow: false, - expectErr: ``, }, { name: "nok, unsafe PATCH + cross-site is blocked",