1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- """
- A provided CSRF implementation which puts CSRF data in a session.
- This can be used fairly comfortably with many `request.session` type
- objects, including the Werkzeug/Flask session store, Django sessions, and
- potentially other similar objects which use a dict-like API for storing
- session keys.
- The basic concept is that a randomly generated value is stored in the user's
- session, and an HMAC-SHA1 of it (along with an optional expiration time,
- for extra security) is used as the value of the csrf_token. If this token
- validates against the HMAC of the random value + expiration time, and the
- expiration time has not passed, the CSRF validation will pass.
- """
- import hmac
- import os
- from datetime import datetime
- from datetime import timedelta
- from hashlib import sha1
- from ..validators import ValidationError
- from .core import CSRF
- __all__ = ("SessionCSRF",)
class SessionCSRF(CSRF):
    """CSRF protection backed by a session-like mapping.

    A random value is stored under the ``"csrf"`` key of the session
    (``meta.csrf_context.session`` if that attribute exists, otherwise
    ``meta.csrf_context`` itself). The token handed to the client has the
    form ``"<expires>##<mac>"``, where ``<mac>`` is the hex HMAC-SHA1, keyed
    with ``meta.csrf_secret``, of the session value concatenated with
    ``<expires>``. ``<expires>`` is a ``TIME_FORMAT`` timestamp when a time
    limit is active, and the empty string otherwise.
    """

    # Fixed-width, most-significant-field-first layout: lexicographic
    # comparison of two formatted timestamps matches chronological order,
    # which validate_csrf_token relies on.
    TIME_FORMAT = "%Y%m%d%H%M%S"

    def setup_form(self, form):
        """Remember ``form.meta`` for token generation and validation."""
        self.form_meta = form.meta
        return super().setup_form(form)

    def _hmac_hexdigest(self, value):
        """Return the hex HMAC-SHA1 of ``value`` keyed by ``csrf_secret``.

        ``value`` is encoded as UTF-8. A ``str`` secret is also encoded as
        UTF-8, because :func:`hmac.new` only accepts bytes-like keys.
        """
        secret = self.form_meta.csrf_secret
        if isinstance(secret, str):
            secret = secret.encode("utf8")
        return hmac.new(secret, value.encode("utf8"), digestmod=sha1).hexdigest()

    def generate_csrf_token(self, csrf_token_field):
        """Create a token, seeding the session's random value if absent.

        :raises Exception: if ``meta.csrf_secret`` is ``None``.
        :raises TypeError: if ``meta.csrf_context`` is ``None``.
        :returns: ``"<expires>##<hex hmac>"``.
        """
        meta = self.form_meta
        if meta.csrf_secret is None:
            raise Exception(
                "must set `csrf_secret` on class Meta for SessionCSRF to work"
            )
        if meta.csrf_context is None:
            raise TypeError("Must provide a session-like object as csrf context")

        session = self.session
        if "csrf" not in session:
            session["csrf"] = sha1(os.urandom(64)).hexdigest()

        if self.time_limit:
            expires = (self.now() + self.time_limit).strftime(self.TIME_FORMAT)
        else:
            # No expiry: the MAC then covers only the session value.
            expires = ""

        digest = self._hmac_hexdigest(session["csrf"] + expires)
        return f"{expires}##{digest}"

    def validate_csrf_token(self, form, field):
        """Check ``field.data`` against the session value and time limit.

        :raises ValidationError: if the token is missing or malformed, if the
            session holds no CSRF value, if the MAC does not match, or if
            the embedded expiry time has passed.
        """
        token = field.data
        if not token or "##" not in token:
            raise ValidationError(field.gettext("CSRF token missing."))

        expires, hmac_csrf = token.split("##", 1)

        # A session without a stored value (expired, cleared, or never served
        # the form) cannot validate anything. Report a CSRF failure rather
        # than letting a KeyError surface as a server error.
        try:
            session_csrf = self.session["csrf"]
        except KeyError:
            raise ValidationError(field.gettext("CSRF failed.")) from None

        expected = self._hmac_hexdigest(session_csrf + expires)
        # Constant-time comparison avoids leaking how many leading characters
        # matched. compare_digest rejects non-ASCII str arguments, so compare
        # bytes to tolerate arbitrary client-supplied input.
        if not hmac.compare_digest(
            expected.encode("ascii"), hmac_csrf.encode("utf8")
        ):
            raise ValidationError(field.gettext("CSRF failed."))

        if self.time_limit:
            now_formatted = self.now().strftime(self.TIME_FORMAT)
            # String comparison is chronological because of TIME_FORMAT's
            # fixed-width layout. An empty ``expires`` always counts as expired.
            if now_formatted > expires:
                raise ValidationError(field.gettext("CSRF token expired."))

    def now(self):
        """
        Get the current time. Used for test mocking/overriding mainly.
        """
        return datetime.now()

    @property
    def time_limit(self):
        """Token lifetime from ``meta.csrf_time_limit`` (default 30 minutes).

        A falsy value disables expiry entirely.
        """
        return getattr(self.form_meta, "csrf_time_limit", timedelta(minutes=30))

    @property
    def session(self):
        """The session mapping: ``csrf_context.session`` if present, else
        ``csrf_context`` itself."""
        return getattr(
            self.form_meta.csrf_context, "session", self.form_meta.csrf_context
        )
|