123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329 |
- import hashlib
- import hmac
- import logging
- import os
- from urllib.parse import urlparse
- from flask import Blueprint
- from flask import current_app
- from flask import g
- from flask import request
- from flask import session
- from itsdangerous import BadData
- from itsdangerous import SignatureExpired
- from itsdangerous import URLSafeTimedSerializer
- from werkzeug.exceptions import BadRequest
- from wtforms import ValidationError
- from wtforms.csrf.core import CSRF
# Names exported by ``from <this module> import *``.
__all__ = (
    "generate_csrf",
    "validate_csrf",
    "CSRFProtect",
)

# Module logger; CSRF validation failures elsewhere in this module are logged here.
logger = logging.getLogger(name=__name__)
def generate_csrf(secret_key=None, token_key=None):
    """Return a signed CSRF token for the current request.

    The signed token is computed once per request and cached on ``g``, so
    repeated calls within the same request return the same value. On first
    use a random raw token is stored in the session; the returned value is
    that raw token signed with *secret_key*.

    When testing, the signed token is available on ``g`` (e.g.
    ``g.csrf_token``) and the raw one in the session (e.g.
    ``session['csrf_token']``).

    :param secret_key: Key used to sign the token. Falls back to
        ``WTF_CSRF_SECRET_KEY``, then the app's ``secret_key``.
    :param token_key: Session key under which the raw token is stored. Falls
        back to ``WTF_CSRF_FIELD_NAME``, then ``'csrf_token'``.
    """
    secret_key = _get_config(
        secret_key,
        "WTF_CSRF_SECRET_KEY",
        default=current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    field_name = _get_config(
        token_key,
        "WTF_CSRF_FIELD_NAME",
        default="csrf_token",
        message="A field name is required to use CSRF.",
    )

    # Already signed earlier in this request: reuse the cached value.
    if field_name in g:
        return g.get(field_name)

    def new_raw_token():
        return hashlib.sha1(os.urandom(64)).hexdigest()

    serializer = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

    if field_name not in session:
        session[field_name] = new_raw_token()

    try:
        signed_token = serializer.dumps(session[field_name])
    except TypeError:
        # The stored session value could not be serialized; replace it with
        # a fresh raw token and sign that instead.
        session[field_name] = new_raw_token()
        signed_token = serializer.dumps(session[field_name])

    setattr(g, field_name, signed_token)
    return g.get(field_name)
def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    """Verify that *data* is a valid signed CSRF token for the current session.

    The checks run in this order:

    1. A token was submitted and the session holds a raw token.
    2. The signature and age of *data* are verified.
    3. The unsigned value is compared, in constant time, with the session's
       raw token.

    :param data: Signed CSRF token submitted by the client.
    :param secret_key: Key the token was signed with. Falls back to
        ``WTF_CSRF_SECRET_KEY``, then the app's ``secret_key``.
    :param time_limit: Maximum token age in seconds. Falls back to
        ``WTF_CSRF_TIME_LIMIT``, then 3600.
    :param token_key: Session key holding the raw token. Falls back to
        ``WTF_CSRF_FIELD_NAME``, then ``'csrf_token'``.
    :raises ValidationError: With a message describing why validation failed.

    .. versionchanged:: 0.14
        Raises ``ValidationError`` with a specific error message rather than
        returning ``True`` or ``False``.
    """
    secret_key = _get_config(
        secret_key,
        "WTF_CSRF_SECRET_KEY",
        default=current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    field_name = _get_config(
        token_key,
        "WTF_CSRF_FIELD_NAME",
        default="csrf_token",
        message="A field name is required to use CSRF.",
    )
    time_limit = _get_config(
        time_limit, "WTF_CSRF_TIME_LIMIT", default=3600, required=False
    )

    if not data:
        raise ValidationError("The CSRF token is missing.")
    if field_name not in session:
        raise ValidationError("The CSRF session token is missing.")

    serializer = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")
    try:
        unsigned_token = serializer.loads(data, max_age=time_limit)
    # SignatureExpired is a BadData subclass, so it must be handled first.
    except SignatureExpired as exc:
        raise ValidationError("The CSRF token has expired.") from exc
    except BadData as exc:
        raise ValidationError("The CSRF token is invalid.") from exc

    tokens_match = hmac.compare_digest(session[field_name], unsigned_token)
    if not tokens_match:
        raise ValidationError("The CSRF tokens do not match.")
- def _get_config(
- value, config_name, default=None, required=True, message="CSRF is not configured."
- ):
- """Find config value based on provided value, Flask config, and default
- value.
- :param value: already provided config value
- :param config_name: Flask ``config`` key
- :param default: default value if not provided or configured
- :param required: whether the value must not be ``None``
- :param message: error message if required config is not found
- :raises KeyError: if required config is not found
- """
- if value is None:
- value = current_app.config.get(config_name, default)
- if required and value is None:
- raise RuntimeError(message)
- return value
class _FlaskFormCSRF(CSRF):
    """WTForms CSRF implementation backed by this module's session tokens.

    Token generation and validation delegate to :func:`generate_csrf` and
    :func:`validate_csrf`, configured from the form's ``meta`` options
    (``csrf_secret``, ``csrf_time_limit``, ``csrf_field_name``).
    """

    def setup_form(self, form):
        # Remember the form's meta so the token hooks can read csrf_* options.
        self.meta = form.meta
        return super().setup_form(form)

    def generate_csrf_token(self, csrf_token_field):
        meta = self.meta
        return generate_csrf(
            secret_key=meta.csrf_secret,
            token_key=meta.csrf_field_name,
        )

    def validate_csrf_token(self, form, field):
        # CSRFProtect.protect() sets g.csrf_valid after a successful check,
        # so the token only needs checking here when that has not happened.
        if not g.get("csrf_valid", False):
            meta = self.meta
            try:
                validate_csrf(
                    field.data,
                    secret_key=meta.csrf_secret,
                    time_limit=meta.csrf_time_limit,
                    token_key=meta.csrf_field_name,
                )
            except ValidationError as exc:
                logger.info(exc.args[0])
                raise
class CSRFProtect:
    """Enable CSRF protection globally for a Flask app.

    ::

        app = Flask(__name__)
        csrf = CSRFProtect(app)

    Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
    header sent with JavaScript requests. Render the token in templates using
    ``{{ csrf_token() }}``.

    See the :ref:`csrf` documentation.
    """

    def __init__(self, app=None):
        # Dotted "module.name" strings identifying exempted view functions.
        self._exempt_views = set()
        # Blueprint objects whose views are exempt.
        self._exempt_blueprints = set()

        if app:
            self.init_app(app)

    def init_app(self, app):
        """Install CSRF config defaults, template helpers and the request hook.

        :param app: The Flask application to protect.
        """
        app.extensions["csrf"] = self

        app.config.setdefault("WTF_CSRF_ENABLED", True)
        app.config.setdefault("WTF_CSRF_CHECK_DEFAULT", True)
        # Stored as a set so the per-request membership test is cheap.
        app.config["WTF_CSRF_METHODS"] = set(
            app.config.get("WTF_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
        )
        app.config.setdefault("WTF_CSRF_FIELD_NAME", "csrf_token")
        app.config.setdefault("WTF_CSRF_HEADERS", ["X-CSRFToken", "X-CSRF-Token"])
        app.config.setdefault("WTF_CSRF_TIME_LIMIT", 3600)
        app.config.setdefault("WTF_CSRF_SSL_STRICT", True)

        app.jinja_env.globals["csrf_token"] = generate_csrf
        app.context_processor(lambda: {"csrf_token": generate_csrf})

        @app.before_request
        def csrf_protect():
            if not app.config["WTF_CSRF_ENABLED"]:
                return

            if not app.config["WTF_CSRF_CHECK_DEFAULT"]:
                return

            if request.method not in app.config["WTF_CSRF_METHODS"]:
                return

            if not request.endpoint:
                return

            if app.blueprints.get(request.blueprint) in self._exempt_blueprints:
                return

            view = app.view_functions.get(request.endpoint)
            dest = self._view_location(view)

            # A view with no resolvable "module.name" (e.g. a functools.partial
            # or a callable object) has nothing to match against the exempt
            # list; protect it rather than crashing with AttributeError.
            if dest is not None and dest in self._exempt_views:
                return

            self.protect()

    @staticmethod
    def _view_location(view):
        """Return ``"module.name"`` for *view*, or ``None`` if it has no such identity.

        This is the same string :meth:`exempt` records for a view function.
        """
        module = getattr(view, "__module__", None)
        name = getattr(view, "__name__", None)
        if module is None or name is None:
            return None
        return f"{module}.{name}"

    def _get_csrf_token(self):
        """Return the submitted CSRF token, or ``None`` if none was sent.

        Sources are checked in this order:

        1. The configured form field.
        2. Any form field whose name ends with the configured name (forms
           that use a prefix).
        3. The configured headers.
        """
        # find the token in the form data
        field_name = current_app.config["WTF_CSRF_FIELD_NAME"]
        base_token = request.form.get(field_name)

        if base_token:
            return base_token

        # if the form has a prefix, the name will be {prefix}-csrf_token
        for key in request.form:
            if key.endswith(field_name):
                csrf_token = request.form[key]

                if csrf_token:
                    return csrf_token

        # find the token in the headers
        for header_name in current_app.config["WTF_CSRF_HEADERS"]:
            csrf_token = request.headers.get(header_name)

            if csrf_token:
                return csrf_token

        return None

    def protect(self):
        """Check CSRF for the current request.

        Requests whose method is not in ``WTF_CSRF_METHODS`` are skipped. On
        HTTPS with ``WTF_CSRF_SSL_STRICT`` enabled, the ``Referer`` must also
        be present and same-origin with the request host. On success
        ``g.csrf_valid`` is set, so form-level validation can be skipped.

        :raises CSRFError: If any check fails.
        """
        if request.method not in current_app.config["WTF_CSRF_METHODS"]:
            return

        try:
            validate_csrf(self._get_csrf_token())
        except ValidationError as e:
            logger.info(e.args[0])
            self._error_response(e.args[0])

        if request.is_secure and current_app.config["WTF_CSRF_SSL_STRICT"]:
            if not request.referrer:
                self._error_response("The referrer header is missing.")

            good_referrer = f"https://{request.host}/"

            if not same_origin(request.referrer, good_referrer):
                self._error_response("The referrer does not match the host.")

        g.csrf_valid = True  # mark this request as CSRF valid

    def exempt(self, view):
        """Mark a view or blueprint to be excluded from CSRF protection.

        ::

            @app.route('/some-view', methods=['POST'])
            @csrf.exempt
            def some_view():
                ...

        ::

            bp = Blueprint(...)
            csrf.exempt(bp)

        A ``"module.name"`` string may also be passed to exempt a view by its
        dotted location.

        :returns: *view* unchanged, so this can be used as a decorator.
        """
        if isinstance(view, Blueprint):
            self._exempt_blueprints.add(view)
            return view

        if isinstance(view, str):
            view_location = view
        else:
            view_location = ".".join((view.__module__, view.__name__))

        self._exempt_views.add(view_location)
        return view

    def _error_response(self, reason):
        """Abort the request by raising :class:`CSRFError` with *reason*."""
        raise CSRFError(reason)
class CSRFError(BadRequest):
    """Raised when a request fails CSRF validation.

    Being a ``BadRequest`` subclass, it yields a 400 response that carries the
    failure reason by default. To customize that response, register a handler
    with :meth:`flask.Flask.errorhandler`.
    """

    # Class-level default description for this error.
    description = "CSRF validation failed."
def same_origin(current_uri, compare_uri):
    """Return ``True`` if both URIs have the same scheme, hostname and port.

    A missing port is normalized to the scheme's default (80 for ``http``,
    443 for ``https``). As a result, ``https://example.com/`` and
    ``https://example.com:443/`` are the same origin.

    *current_uri* usually comes from the client-supplied ``Referer`` header.
    A URI that cannot be parsed, or whose port is malformed or out of range,
    is reported as not same-origin instead of raising ``ValueError``.

    :param current_uri: URI to check (e.g. the request referrer).
    :param compare_uri: Trusted URI to compare against.
    """
    default_ports = {"http": 80, "https": 443}

    def origin_of(uri):
        # Returns (scheme, hostname, port), or None when the URI is unusable.
        try:
            parsed = urlparse(uri)
            port = parsed.port  # raises ValueError for bad/out-of-range ports
        except ValueError:
            return None
        if port is None:
            port = default_ports.get(parsed.scheme)
        return parsed.scheme, parsed.hostname, port

    current = origin_of(current_uri)
    compare = origin_of(compare_uri)
    return current is not None and current == compare
|