""" A provided CSRF implementation which puts CSRF data in a session. This can be used fairly comfortably with many `request.session` type objects, including the Werkzeug/Flask session store, Django sessions, and potentially other similar objects which use a dict-like API for storing session keys. The basic concept is a randomly generated value is stored in the user's session, and an hmac-sha1 of it (along with an optional expiration time, for extra security) is used as the value of the csrf_token. If this token validates with the hmac of the random value + expiration time, and the expiration time is not passed, the CSRF validation will pass. """ import hmac import os from datetime import datetime from datetime import timedelta from hashlib import sha1 from ..validators import ValidationError from .core import CSRF __all__ = ("SessionCSRF",) class SessionCSRF(CSRF): TIME_FORMAT = "%Y%m%d%H%M%S" def setup_form(self, form): self.form_meta = form.meta return super().setup_form(form) def generate_csrf_token(self, csrf_token_field): meta = self.form_meta if meta.csrf_secret is None: raise Exception( "must set `csrf_secret` on class Meta for SessionCSRF to work" ) if meta.csrf_context is None: raise TypeError("Must provide a session-like object as csrf context") session = self.session if "csrf" not in session: session["csrf"] = sha1(os.urandom(64)).hexdigest() if self.time_limit: expires = (self.now() + self.time_limit).strftime(self.TIME_FORMAT) csrf_build = "{}{}".format(session["csrf"], expires) else: expires = "" csrf_build = session["csrf"] hmac_csrf = hmac.new( meta.csrf_secret, csrf_build.encode("utf8"), digestmod=sha1 ) return f"{expires}##{hmac_csrf.hexdigest()}" def validate_csrf_token(self, form, field): meta = self.form_meta if not field.data or "##" not in field.data: raise ValidationError(field.gettext("CSRF token missing.")) expires, hmac_csrf = field.data.split("##", 1) check_val = (self.session["csrf"] + expires).encode("utf8") hmac_compare = hmac.new(meta.csrf_secret, check_val, digestmod=sha1) if hmac_compare.hexdigest() != hmac_csrf: raise ValidationError(field.gettext("CSRF failed.")) if self.time_limit: now_formatted = self.now().strftime(self.TIME_FORMAT) if now_formatted > expires: raise ValidationError(field.gettext("CSRF token expired.")) def now(self): """ Get the current time. Used for test mocking/overriding mainly. """ return datetime.now() @property def time_limit(self): return getattr(self.form_meta, "csrf_time_limit", timedelta(minutes=30)) @property def session(self): return getattr( self.form_meta.csrf_context, "session", self.form_meta.csrf_context )