csrf.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. import hashlib
  2. import hmac
  3. import logging
  4. import os
  5. from urllib.parse import urlparse
  6. from flask import Blueprint
  7. from flask import current_app
  8. from flask import g
  9. from flask import request
  10. from flask import session
  11. from itsdangerous import BadData
  12. from itsdangerous import SignatureExpired
  13. from itsdangerous import URLSafeTimedSerializer
  14. from werkzeug.exceptions import BadRequest
  15. from wtforms import ValidationError
  16. from wtforms.csrf.core import CSRF
  17. __all__ = ("generate_csrf", "validate_csrf", "CSRFProtect")
  18. logger = logging.getLogger(__name__)
  19. def generate_csrf(secret_key=None, token_key=None):
  20. """Generate a CSRF token. The token is cached for a request, so multiple
  21. calls to this function will generate the same token.
  22. During testing, it might be useful to access the signed token in
  23. ``g.csrf_token`` and the raw token in ``session['csrf_token']``.
  24. :param secret_key: Used to securely sign the token. Default is
  25. ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
  26. :param token_key: Key where token is stored in session for comparison.
  27. Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
  28. """
  29. secret_key = _get_config(
  30. secret_key,
  31. "WTF_CSRF_SECRET_KEY",
  32. current_app.secret_key,
  33. message="A secret key is required to use CSRF.",
  34. )
  35. field_name = _get_config(
  36. token_key,
  37. "WTF_CSRF_FIELD_NAME",
  38. "csrf_token",
  39. message="A field name is required to use CSRF.",
  40. )
  41. if field_name not in g:
  42. s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")
  43. if field_name not in session:
  44. session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
  45. try:
  46. token = s.dumps(session[field_name])
  47. except TypeError:
  48. session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
  49. token = s.dumps(session[field_name])
  50. setattr(g, field_name, token)
  51. return g.get(field_name)
  52. def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
  53. """Check if the given data is a valid CSRF token. This compares the given
  54. signed token to the one stored in the session.
  55. :param data: The signed CSRF token to be checked.
  56. :param secret_key: Used to securely sign the token. Default is
  57. ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
  58. :param time_limit: Number of seconds that the token is valid. Default is
  59. ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
  60. :param token_key: Key where token is stored in session for comparison.
  61. Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
  62. :raises ValidationError: Contains the reason that validation failed.
  63. .. versionchanged:: 0.14
  64. Raises ``ValidationError`` with a specific error message rather than
  65. returning ``True`` or ``False``.
  66. """
  67. secret_key = _get_config(
  68. secret_key,
  69. "WTF_CSRF_SECRET_KEY",
  70. current_app.secret_key,
  71. message="A secret key is required to use CSRF.",
  72. )
  73. field_name = _get_config(
  74. token_key,
  75. "WTF_CSRF_FIELD_NAME",
  76. "csrf_token",
  77. message="A field name is required to use CSRF.",
  78. )
  79. time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)
  80. if not data:
  81. raise ValidationError("The CSRF token is missing.")
  82. if field_name not in session:
  83. raise ValidationError("The CSRF session token is missing.")
  84. s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")
  85. try:
  86. token = s.loads(data, max_age=time_limit)
  87. except SignatureExpired as e:
  88. raise ValidationError("The CSRF token has expired.") from e
  89. except BadData as e:
  90. raise ValidationError("The CSRF token is invalid.") from e
  91. if not hmac.compare_digest(session[field_name], token):
  92. raise ValidationError("The CSRF tokens do not match.")
  93. def _get_config(
  94. value, config_name, default=None, required=True, message="CSRF is not configured."
  95. ):
  96. """Find config value based on provided value, Flask config, and default
  97. value.
  98. :param value: already provided config value
  99. :param config_name: Flask ``config`` key
  100. :param default: default value if not provided or configured
  101. :param required: whether the value must not be ``None``
  102. :param message: error message if required config is not found
  103. :raises KeyError: if required config is not found
  104. """
  105. if value is None:
  106. value = current_app.config.get(config_name, default)
  107. if required and value is None:
  108. raise RuntimeError(message)
  109. return value
  110. class _FlaskFormCSRF(CSRF):
  111. def setup_form(self, form):
  112. self.meta = form.meta
  113. return super().setup_form(form)
  114. def generate_csrf_token(self, csrf_token_field):
  115. return generate_csrf(
  116. secret_key=self.meta.csrf_secret, token_key=self.meta.csrf_field_name
  117. )
  118. def validate_csrf_token(self, form, field):
  119. if g.get("csrf_valid", False):
  120. # already validated by CSRFProtect
  121. return
  122. try:
  123. validate_csrf(
  124. field.data,
  125. self.meta.csrf_secret,
  126. self.meta.csrf_time_limit,
  127. self.meta.csrf_field_name,
  128. )
  129. except ValidationError as e:
  130. logger.info(e.args[0])
  131. raise
  132. class CSRFProtect:
  133. """Enable CSRF protection globally for a Flask app.
  134. ::
  135. app = Flask(__name__)
  136. csrf = CSRFProtect(app)
  137. Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
  138. header sent with JavaScript requests. Render the token in templates using
  139. ``{{ csrf_token() }}``.
  140. See the :ref:`csrf` documentation.
  141. """
  142. def __init__(self, app=None):
  143. self._exempt_views = set()
  144. self._exempt_blueprints = set()
  145. if app:
  146. self.init_app(app)
  147. def init_app(self, app):
  148. app.extensions["csrf"] = self
  149. app.config.setdefault("WTF_CSRF_ENABLED", True)
  150. app.config.setdefault("WTF_CSRF_CHECK_DEFAULT", True)
  151. app.config["WTF_CSRF_METHODS"] = set(
  152. app.config.get("WTF_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
  153. )
  154. app.config.setdefault("WTF_CSRF_FIELD_NAME", "csrf_token")
  155. app.config.setdefault("WTF_CSRF_HEADERS", ["X-CSRFToken", "X-CSRF-Token"])
  156. app.config.setdefault("WTF_CSRF_TIME_LIMIT", 3600)
  157. app.config.setdefault("WTF_CSRF_SSL_STRICT", True)
  158. app.jinja_env.globals["csrf_token"] = generate_csrf
  159. app.context_processor(lambda: {"csrf_token": generate_csrf})
  160. @app.before_request
  161. def csrf_protect():
  162. if not app.config["WTF_CSRF_ENABLED"]:
  163. return
  164. if not app.config["WTF_CSRF_CHECK_DEFAULT"]:
  165. return
  166. if request.method not in app.config["WTF_CSRF_METHODS"]:
  167. return
  168. if not request.endpoint:
  169. return
  170. if app.blueprints.get(request.blueprint) in self._exempt_blueprints:
  171. return
  172. view = app.view_functions.get(request.endpoint)
  173. dest = f"{view.__module__}.{view.__name__}"
  174. if dest in self._exempt_views:
  175. return
  176. self.protect()
  177. def _get_csrf_token(self):
  178. # find the token in the form data
  179. field_name = current_app.config["WTF_CSRF_FIELD_NAME"]
  180. base_token = request.form.get(field_name)
  181. if base_token:
  182. return base_token
  183. # if the form has a prefix, the name will be {prefix}-csrf_token
  184. for key in request.form:
  185. if key.endswith(field_name):
  186. csrf_token = request.form[key]
  187. if csrf_token:
  188. return csrf_token
  189. # find the token in the headers
  190. for header_name in current_app.config["WTF_CSRF_HEADERS"]:
  191. csrf_token = request.headers.get(header_name)
  192. if csrf_token:
  193. return csrf_token
  194. return None
  195. def protect(self):
  196. if request.method not in current_app.config["WTF_CSRF_METHODS"]:
  197. return
  198. try:
  199. validate_csrf(self._get_csrf_token())
  200. except ValidationError as e:
  201. logger.info(e.args[0])
  202. self._error_response(e.args[0])
  203. if request.is_secure and current_app.config["WTF_CSRF_SSL_STRICT"]:
  204. if not request.referrer:
  205. self._error_response("The referrer header is missing.")
  206. good_referrer = f"https://{request.host}/"
  207. if not same_origin(request.referrer, good_referrer):
  208. self._error_response("The referrer does not match the host.")
  209. g.csrf_valid = True # mark this request as CSRF valid
  210. def exempt(self, view):
  211. """Mark a view or blueprint to be excluded from CSRF protection.
  212. ::
  213. @app.route('/some-view', methods=['POST'])
  214. @csrf.exempt
  215. def some_view():
  216. ...
  217. ::
  218. bp = Blueprint(...)
  219. csrf.exempt(bp)
  220. """
  221. if isinstance(view, Blueprint):
  222. self._exempt_blueprints.add(view)
  223. return view
  224. if isinstance(view, str):
  225. view_location = view
  226. else:
  227. view_location = ".".join((view.__module__, view.__name__))
  228. self._exempt_views.add(view_location)
  229. return view
  230. def _error_response(self, reason):
  231. raise CSRFError(reason)
  232. class CSRFError(BadRequest):
  233. """Raise if the client sends invalid CSRF data with the request.
  234. Generates a 400 Bad Request response with the failure reason by default.
  235. Customize the response by registering a handler with
  236. :meth:`flask.Flask.errorhandler`.
  237. """
  238. description = "CSRF validation failed."
  239. def same_origin(current_uri, compare_uri):
  240. current = urlparse(current_uri)
  241. compare = urlparse(compare_uri)
  242. return (
  243. current.scheme == compare.scheme
  244. and current.hostname == compare.hostname
  245. and current.port == compare.port
  246. )