diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py index 27df3f7a8..162a700cb 100644 --- a/tornado/test/web_test.py +++ b/tornado/test/web_test.py @@ -3130,6 +3130,44 @@ def test_xsrf_httponly(self): self.assertTrue(abs((expires - header_expires).total_seconds()) < 10) +class CheckSameOriginTest(SimpleHandlerTestCase): + class Handler(RequestHandler): + def post(self): + self.write("ok") + + def get_app_kwargs(self): + return dict(check_fetch_header=True, check_origin=self.get_url("")) + + def _post(self, headers): + return self.fetch("/", method="POST", body="x=1", headers=headers) + + def test_sec_fetch_site_success(self): + response = self._post({"Sec-Fetch-Site": "same-origin"}) + self.assertEqual(response.code, 200) + + def test_sec_fetch_site_fail(self): + with ExpectLog(gen_log, ".*Cross-origin request"): + response = self._post({"Sec-Fetch-Site": "cross-site"}) + self.assertEqual(response.code, 403) + + def test_fallback_success(self): + response = self._post({"Origin": self.get_url("")}) + self.assertEqual(response.code, 200) + + def test_fallback_referrer_success(self): + response = self._post({"Referrer": self.get_url("/foo/bar")}) + self.assertEqual(response.code, 200) + + def test_fallback_fail(self): + with ExpectLog(gen_log, ".*Cross-origin request"): + response = self._post({"Origin": "https://evil.example.com/"}) + self.assertEqual(response.code, 403) + + def test_fallback_no_origin(self): + response = self._post({}) + self.assertEqual(response.code, 200) + + class FinishExceptionTest(SimpleHandlerTestCase): class Handler(RequestHandler): def get(self): diff --git a/tornado/web.py b/tornado/web.py index 2351afdbe..7c70566de 100644 --- a/tornado/web.py +++ b/tornado/web.py @@ -1690,6 +1690,28 @@ def xsrf_form_html(self) -> str: + '"/>' ) + def check_fetch_header(self) -> bool: + """Verify that non-safe methods come from a same-origin request""" + if (sfs := self.request.headers.get("Sec-Fetch-Site")) is not None: + # All major browsers send the Sec-Fetch-Site header since ~2023 + # for 'potentially trustworthy' URLs (roughly, HTTPS or localhost) + if sfs not in ("same-origin", "none"): + raise HTTPError(403, "Cross-origin request with unsafe method") + return True + return False + + def check_request_origin(self) -> None: + # Fallback: The Origin or Referrer header gives the domain + # the request came from, Host should tell us where we're running. + headers = self.request.headers + src_origin = headers.get("Origin") or headers.get("Referrer") + if src_origin is None: + return # Probably non-browser request + src_scheme, src_netloc = urllib.parse.urlsplit(src_origin)[:2] + target_origin = self.application.settings["check_origin"] + if f"{src_scheme}://{src_netloc}" != target_origin: + raise HTTPError(403, "Cross-origin request with unsafe method") + def static_url( self, path: str, include_host: Optional[bool] = None, **kwargs: Any ) -> str: @@ -1826,12 +1848,13 @@ async def _execute( } # If XSRF cookies are turned on, reject form submissions without # the proper cookie - if self.request.method not in ( - "GET", - "HEAD", - "OPTIONS", - ) and self.application.settings.get("xsrf_cookies"): - self.check_xsrf_cookie() + if self.request.method not in ("GET", "HEAD", "OPTIONS"): + if self.application.settings.get("xsrf_cookies"): + self.check_xsrf_cookie() + if self.application.settings.get("check_fetch_header"): + checked = self.check_fetch_header() + if not checked and self.application.settings.get("check_origin"): + self.check_request_origin() result = self.prepare() if result is not None: