session.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. """
  2. A provided CSRF implementation which puts CSRF data in a session.
  3. This can be used fairly comfortably with many `request.session` type
  4. objects, including the Werkzeug/Flask session store, Django sessions, and
  5. potentially other similar objects which use a dict-like API for storing
  6. session keys.
  7. The basic concept is a randomly generated value is stored in the user's
  8. session, and an hmac-sha1 of it (along with an optional expiration time,
  9. for extra security) is used as the value of the csrf_token. If the token
  10. matches the HMAC of the random value plus the expiration time, and the
  11. expiration time has not passed, CSRF validation succeeds.
  12. """
  13. import hmac
  14. import os
  15. from datetime import datetime
  16. from datetime import timedelta
  17. from hashlib import sha1
  18. from ..validators import ValidationError
  19. from .core import CSRF
# Public API of this module: only the session-backed CSRF implementation.
__all__ = ("SessionCSRF",)
  21. class SessionCSRF(CSRF):
  22. TIME_FORMAT = "%Y%m%d%H%M%S"
  23. def setup_form(self, form):
  24. self.form_meta = form.meta
  25. return super().setup_form(form)
  26. def generate_csrf_token(self, csrf_token_field):
  27. meta = self.form_meta
  28. if meta.csrf_secret is None:
  29. raise Exception(
  30. "must set `csrf_secret` on class Meta for SessionCSRF to work"
  31. )
  32. if meta.csrf_context is None:
  33. raise TypeError("Must provide a session-like object as csrf context")
  34. session = self.session
  35. if "csrf" not in session:
  36. session["csrf"] = sha1(os.urandom(64)).hexdigest()
  37. if self.time_limit:
  38. expires = (self.now() + self.time_limit).strftime(self.TIME_FORMAT)
  39. csrf_build = "{}{}".format(session["csrf"], expires)
  40. else:
  41. expires = ""
  42. csrf_build = session["csrf"]
  43. hmac_csrf = hmac.new(
  44. meta.csrf_secret, csrf_build.encode("utf8"), digestmod=sha1
  45. )
  46. return f"{expires}##{hmac_csrf.hexdigest()}"
  47. def validate_csrf_token(self, form, field):
  48. meta = self.form_meta
  49. if not field.data or "##" not in field.data:
  50. raise ValidationError(field.gettext("CSRF token missing."))
  51. expires, hmac_csrf = field.data.split("##", 1)
  52. check_val = (self.session["csrf"] + expires).encode("utf8")
  53. hmac_compare = hmac.new(meta.csrf_secret, check_val, digestmod=sha1)
  54. if hmac_compare.hexdigest() != hmac_csrf:
  55. raise ValidationError(field.gettext("CSRF failed."))
  56. if self.time_limit:
  57. now_formatted = self.now().strftime(self.TIME_FORMAT)
  58. if now_formatted > expires:
  59. raise ValidationError(field.gettext("CSRF token expired."))
  60. def now(self):
  61. """
  62. Get the current time. Used for test mocking/overriding mainly.
  63. """
  64. return datetime.now()
  65. @property
  66. def time_limit(self):
  67. return getattr(self.form_meta, "csrf_time_limit", timedelta(minutes=30))
  68. @property
  69. def session(self):
  70. return getattr(
  71. self.form_meta.csrf_context, "session", self.form_meta.csrf_context
  72. )