import hashlib
import hmac
import logging
import os
from urllib.parse import urlparse

from flask import Blueprint
from flask import current_app
from flask import g
from flask import request
from flask import session
from itsdangerous import BadData
from itsdangerous import SignatureExpired
from itsdangerous import URLSafeTimedSerializer
from werkzeug.exceptions import BadRequest
from wtforms import ValidationError
from wtforms.csrf.core import CSRF

__all__ = ("generate_csrf", "validate_csrf", "CSRFProtect")

logger = logging.getLogger(__name__)


def generate_csrf(secret_key=None, token_key=None):
    """Generate a CSRF token. The token is cached for a request, so multiple
    calls to this function will generate the same token.

    During testing, it might be useful to access the signed token in
    ``g.csrf_token`` and the raw token in ``session['csrf_token']``.

    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param token_key: Key where token is stored in session for comparison.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
    :returns: The signed token, also cached on ``g`` under the field name.
    :raises RuntimeError: If no secret key or field name can be resolved.
    """
    secret_key = _get_config(
        secret_key,
        "WTF_CSRF_SECRET_KEY",
        current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    field_name = _get_config(
        token_key,
        "WTF_CSRF_FIELD_NAME",
        "csrf_token",
        message="A field name is required to use CSRF.",
    )

    # Sign at most once per request; later calls return the cached value.
    if field_name not in g:
        s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

        # The raw (unsigned) token is kept in the session; the signed form
        # of it is what gets rendered into pages and submitted back.
        if field_name not in session:
            session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()

        try:
            token = s.dumps(session[field_name])
        except TypeError:
            # The stored session value could not be serialized; replace it
            # with a fresh random token and sign that instead.
            session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
            token = s.dumps(session[field_name])

        setattr(g, field_name, token)

    return g.get(field_name)


def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    """Check if the given data is a valid CSRF token. This compares the given
    signed token to the one stored in the session.

    :param data: The signed CSRF token to be checked.
    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param time_limit: Number of seconds that the token is valid. Default is
        ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
    :param token_key: Key where token is stored in session for comparison.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.

    :raises ValidationError: Contains the reason that validation failed.

    .. versionchanged:: 0.14
        Raises ``ValidationError`` with a specific error message rather than
        returning ``True`` or ``False``.
    """
    secret_key = _get_config(
        secret_key,
        "WTF_CSRF_SECRET_KEY",
        current_app.secret_key,
        message="A secret key is required to use CSRF.",
    )
    field_name = _get_config(
        token_key,
        "WTF_CSRF_FIELD_NAME",
        "csrf_token",
        message="A field name is required to use CSRF.",
    )
    # A time limit of None is allowed (required=False).
    time_limit = _get_config(time_limit, "WTF_CSRF_TIME_LIMIT", 3600, required=False)

    if not data:
        raise ValidationError("The CSRF token is missing.")

    if field_name not in session:
        raise ValidationError("The CSRF session token is missing.")

    s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token")

    try:
        token = s.loads(data, max_age=time_limit)
    except SignatureExpired as e:
        raise ValidationError("The CSRF token has expired.") from e
    except BadData as e:
        raise ValidationError("The CSRF token is invalid.") from e

    # Constant-time comparison of the unsigned token against the session copy.
    if not hmac.compare_digest(session[field_name], token):
        raise ValidationError("The CSRF tokens do not match.")


def _get_config(
    value, config_name, default=None, required=True, message="CSRF is not configured."
):
    """Find config value based on provided value, Flask config, and default
    value.

    :param value: already provided config value
    :param config_name: Flask ``config`` key
    :param default: default value if not provided or configured
    :param required: whether the value must not be ``None``
    :param message: error message if required config is not found
    :raises RuntimeError: if required config is not found
    """
    # An explicitly passed value wins; otherwise consult app config, then default.
    if value is None:
        value = current_app.config.get(config_name, default)

    if required and value is None:
        raise RuntimeError(message)

    return value


class _FlaskFormCSRF(CSRF):
    """WTForms CSRF implementation that delegates to :func:`generate_csrf`
    and :func:`validate_csrf`, using the form's ``meta`` CSRF settings.
    """

    def setup_form(self, form):
        # Keep the form's meta so the csrf_* options can be read later.
        self.meta = form.meta
        return super().setup_form(form)

    def generate_csrf_token(self, csrf_token_field):
        return generate_csrf(
            secret_key=self.meta.csrf_secret, token_key=self.meta.csrf_field_name
        )

    def validate_csrf_token(self, form, field):
        if g.get("csrf_valid", False):
            # already validated by CSRFProtect
            return

        try:
            validate_csrf(
                field.data,
                self.meta.csrf_secret,
                self.meta.csrf_time_limit,
                self.meta.csrf_field_name,
            )
        except ValidationError as e:
            logger.info(e.args[0])
            raise


class CSRFProtect:
    """Enable CSRF protection globally for a Flask app.

    ::

        app = Flask(__name__)
        csrf = CSRFProtect(app)

    Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
    header sent with JavaScript requests. Render the token in templates using
    ``{{ csrf_token() }}``.

    See the :ref:`csrf` documentation.
    """

    def __init__(self, app=None):
        # Views are stored as "module.name" strings; blueprints as objects.
        self._exempt_views = set()
        self._exempt_blueprints = set()

        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Register config defaults, the ``csrf_token`` template helper and
        a ``before_request`` hook that calls :meth:`protect`.
        """
        app.extensions["csrf"] = self

        app.config.setdefault("WTF_CSRF_ENABLED", True)
        app.config.setdefault("WTF_CSRF_CHECK_DEFAULT", True)
        app.config["WTF_CSRF_METHODS"] = set(
            app.config.get("WTF_CSRF_METHODS", ["POST", "PUT", "PATCH", "DELETE"])
        )
        app.config.setdefault("WTF_CSRF_FIELD_NAME", "csrf_token")
        app.config.setdefault("WTF_CSRF_HEADERS", ["X-CSRFToken", "X-CSRF-Token"])
        app.config.setdefault("WTF_CSRF_TIME_LIMIT", 3600)
        app.config.setdefault("WTF_CSRF_SSL_STRICT", True)

        app.jinja_env.globals["csrf_token"] = generate_csrf
        app.context_processor(lambda: {"csrf_token": generate_csrf})

        @app.before_request
        def csrf_protect():
            if not app.config["WTF_CSRF_ENABLED"]:
                return

            if not app.config["WTF_CSRF_CHECK_DEFAULT"]:
                return

            if request.method not in app.config["WTF_CSRF_METHODS"]:
                return

            if not request.endpoint:
                return

            if app.blueprints.get(request.blueprint) in self._exempt_blueprints:
                return

            # Only callables with a __name__ can have been exempted by
            # exempt(); for anything else (e.g. functools.partial or a
            # missing view) skip the lookup instead of crashing, and still
            # enforce protection.
            view = app.view_functions.get(request.endpoint)
            view_name = getattr(view, "__name__", None)
            if view_name is not None:
                dest = f"{view.__module__}.{view_name}"
                if dest in self._exempt_views:
                    return

            self.protect()

    def _get_csrf_token(self):
        """Return the submitted CSRF token from form data or headers, or
        ``None`` if none was sent.
        """
        # find the token in the form data
        field_name = current_app.config["WTF_CSRF_FIELD_NAME"]
        base_token = request.form.get(field_name)

        if base_token:
            return base_token

        # if the form has a prefix, the name will be {prefix}-csrf_token
        for key in request.form:
            if key.endswith(field_name):
                csrf_token = request.form[key]

                if csrf_token:
                    return csrf_token

        # find the token in the headers
        for header_name in current_app.config["WTF_CSRF_HEADERS"]:
            csrf_token = request.headers.get(header_name)

            if csrf_token:
                return csrf_token

        return None

    def protect(self):
        """Validate the current request's CSRF token, raising
        :class:`CSRFError` on failure and setting ``g.csrf_valid`` on success.
        """
        if request.method not in current_app.config["WTF_CSRF_METHODS"]:
            return

        try:
            validate_csrf(self._get_csrf_token())
        except ValidationError as e:
            logger.info(e.args[0])
            self._error_response(e.args[0])

        # Over HTTPS, additionally require a same-origin Referer header.
        if request.is_secure and current_app.config["WTF_CSRF_SSL_STRICT"]:
            if not request.referrer:
                self._error_response("The referrer header is missing.")

            good_referrer = f"https://{request.host}/"

            if not same_origin(request.referrer, good_referrer):
                self._error_response("The referrer does not match the host.")

        g.csrf_valid = True  # mark this request as CSRF valid

    def exempt(self, view):
        """Mark a view or blueprint to be excluded from CSRF protection.

        ::

            @app.route('/some-view', methods=['POST'])
            @csrf.exempt
            def some_view():
                ...

        ::

            bp = Blueprint(...)
            csrf.exempt(bp)

        A string of the form ``"module.function_name"`` is also accepted.
        """

        if isinstance(view, Blueprint):
            self._exempt_blueprints.add(view)
            return view

        if isinstance(view, str):
            view_location = view
        else:
            view_location = ".".join((view.__module__, view.__name__))

        self._exempt_views.add(view_location)
        return view

    def _error_response(self, reason):
        raise CSRFError(reason)


class CSRFError(BadRequest):
    """Raise if the client sends invalid CSRF data with the request.

    Generates a 400 Bad Request response with the failure reason by default.
    Customize the response by registering a handler with
    :meth:`flask.Flask.errorhandler`.
    """

    description = "CSRF validation failed."


def same_origin(current_uri, compare_uri):
    """Return whether two URIs share the same scheme, hostname and port.

    ``current_uri`` may come straight from a client header (``protect``
    passes ``request.referrer``), so malformed input must not raise: a URI
    that ``urlparse`` rejects, or whose port is non-numeric or out of
    range, is treated as not matching.
    """
    try:
        current = urlparse(current_uri)
        compare = urlparse(compare_uri)
        # .port raises ValueError for invalid ports; read it inside the try.
        current_port = current.port
        compare_port = compare.port
    except ValueError:
        return False

    return (
        current.scheme == compare.scheme
        and current.hostname == compare.hostname
        and current_port == compare_port
    )