Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion middleware/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ import (
"github.com/labstack/echo/v4"
)

// CSRFUsingSecFetchSite is a context key for CSRF middleware what is set when the client browser is using Sec-Fetch-Site
// header and the request is deemed safe.
// It is a dummy token value that can be used to render CSRF token for form by handlers.
//
// We know that the client is using a browser that supports Sec-Fetch-Site header, so when the form is submitted in
// the future with this dummy token value it is OK. Although the request is safe, the template rendered by the
// handler may need this value to render CSRF token for form.
const CSRFUsingSecFetchSite = "_echo_csrf_using_sec_fetch_site_"

// CSRFConfig defines the config for CSRF middleware.
type CSRFConfig struct {
// Skipper defines a function to skip middleware.
Expand Down Expand Up @@ -83,6 +92,8 @@ type CSRFConfig struct {

// ErrorHandler defines a function which is executed for returning custom errors.
ErrorHandler CSRFErrorHandler

generator func(length uint8) string
}

// CSRFErrorHandler is a function which is executed for creating custom errors.
Expand Down Expand Up @@ -145,6 +156,10 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
}
config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...)
}
tokenGenerator := randomString
if config.generator != nil {
tokenGenerator = config.generator
}

extractors, cErr := CreateExtractors(config.TokenLookup)
if cErr != nil {
Expand All @@ -170,7 +185,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {

token := ""
if k, err := c.Cookie(config.CookieName); err != nil {
token = randomString(config.TokenLength)
token = tokenGenerator(config.TokenLength)
} else {
token = k.Value // Reuse token
}
Expand Down Expand Up @@ -287,6 +302,11 @@ func (config CSRFConfig) checkSecFetchSiteRequest(c echo.Context) (bool, error)
}

if isSafe {
// This helps handlers that support older token-based CSRF protection.
// We know that the client is using a browser that supports Sec-Fetch-Site header, so when the form is submitted in
// the future with this dummy token value it is OK. Although the request is safe, the template rendered by the
// handler may need this value to render CSRF token for form.
c.Set(config.ContextKey, CSRFUsingSecFetchSite)
return true, nil
}
// we are here when request is state-changing and `cross-site` or `same-site`
Expand Down
29 changes: 23 additions & 6 deletions middleware/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
},
},
{
name: "ok, token from POST header, second token passes",
name: "nok, token from POST header, tokens limited to 1, second token would pass",
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
Expand Down Expand Up @@ -122,7 +122,7 @@ func TestCSRF_tokenExtractors(t *testing.T) {
},
},
{
name: "ok, token from PUT query form, second token passes",
name: "nok, token from PUT query form, second token would pass",
whenTokenLookup: "query:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
Expand Down Expand Up @@ -235,12 +235,14 @@ func TestCSRFWithConfig(t *testing.T) {
expectEmptyBody bool
expectMWError string
expectCookieContains string
expectTokenInContext string
expectErr string
}{
{
name: "ok, GET",
whenMethod: http.MethodGet,
expectCookieContains: "_csrf",
expectTokenInContext: "TESTTOKEN",
},
{
name: "ok, POST valid token",
Expand All @@ -250,6 +252,7 @@ func TestCSRFWithConfig(t *testing.T) {
},
whenMethod: http.MethodPost,
expectCookieContains: "_csrf",
expectTokenInContext: token,
},
{
name: "nok, POST without token",
Expand Down Expand Up @@ -278,13 +281,23 @@ func TestCSRFWithConfig(t *testing.T) {
},
whenMethod: http.MethodGet,
expectCookieContains: "_csrf",
expectTokenInContext: "TESTTOKEN",
},
{
name: "ok, unsafe method + SecFetchSite=same-origin passes",
whenHeaders: map[string]string{
echo.HeaderSecFetchSite: "same-origin",
},
whenMethod: http.MethodPost,
whenMethod: http.MethodPost,
expectTokenInContext: "_echo_csrf_using_sec_fetch_site_",
},
{
name: "ok, safe method + SecFetchSite=same-origin passes",
whenHeaders: map[string]string{
echo.HeaderSecFetchSite: "same-origin",
},
whenMethod: http.MethodGet,
expectTokenInContext: "_echo_csrf_using_sec_fetch_site_",
},
{
name: "nok, unsafe method + SecFetchSite=same-cross blocked",
Expand Down Expand Up @@ -312,6 +325,11 @@ func TestCSRFWithConfig(t *testing.T) {
if tc.givenConfig != nil {
config = *tc.givenConfig
}
if config.generator == nil {
config.generator = func(_ uint8) string {
return "TESTTOKEN"
}
}
mw, err := config.ToMiddleware()
if tc.expectMWError != "" {
assert.EqualError(t, err, tc.expectMWError)
Expand All @@ -320,6 +338,8 @@ func TestCSRFWithConfig(t *testing.T) {
assert.NoError(t, err)

h := mw(func(c echo.Context) error {
cToken := c.Get(cmp.Or(config.ContextKey, DefaultCSRFConfig.ContextKey))
assert.Equal(t, tc.expectTokenInContext, cToken)
return c.String(http.StatusOK, "test")
})

Expand Down Expand Up @@ -559,7 +579,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
whenMethod: http.MethodPost,
whenSecFetchSite: "same-site",
expectAllow: false,
expectErr: ``,
},
{
name: "ok, unsafe POST + same-origin passes",
Expand Down Expand Up @@ -617,7 +636,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
whenMethod: http.MethodPut,
whenSecFetchSite: "same-site",
expectAllow: false,
expectErr: ``,
},
{
name: "nok, unsafe DELETE + cross-site is blocked",
Expand All @@ -633,7 +651,6 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
whenMethod: http.MethodDelete,
whenSecFetchSite: "same-site",
expectAllow: false,
expectErr: ``,
},
{
name: "nok, unsafe PATCH + cross-site is blocked",
Expand Down