import sys import unittest from contextlib import contextmanager from functools import wraps from pathlib import Path from django.conf import settings from django.test import LiveServerTestCase, override_settings, tag from django.utils.functional import classproperty from django.utils.module_loading import import_string from django.utils.text import capfirst class SeleniumTestCaseBase(type(LiveServerTestCase)): # List of browsers to dynamically create test classes for. browsers = [] # A selenium hub URL to test against. selenium_hub = None # The external host Selenium Hub can reach. external_host = None # Sentinel value to differentiate browser-specific instances. browser = None # Run browsers in headless mode. headless = False def __new__(cls, name, bases, attrs): """ Dynamically create new classes and add them to the test module when multiple browsers specs are provided (e.g. --selenium=firefox,chrome). """ test_class = super().__new__(cls, name, bases, attrs) # If the test class is either browser-specific or a test base, return it. if test_class.browser or not any( name.startswith("test") and callable(value) for name, value in attrs.items() ): return test_class elif test_class.browsers: # Reuse the created test class to make it browser-specific. # We can't rename it to include the browser name or create a # subclass like we do with the remaining browsers as it would # either duplicate tests or prevent pickling of its instances. first_browser = test_class.browsers[0] test_class.browser = first_browser # Listen on an external interface if using a selenium hub. host = test_class.host if not test_class.selenium_hub else "0.0.0.0" test_class.host = host test_class.external_host = cls.external_host # Create subclasses for each of the remaining browsers and expose # them through the test's module namespace. module = sys.modules[test_class.__module__] for browser in test_class.browsers[1:]: browser_test_class = cls.__new__( cls, "%s%s" % (capfirst(browser), name), (test_class,), { "browser": browser, "host": host, "external_host": cls.external_host, "__module__": test_class.__module__, }, ) setattr(module, browser_test_class.__name__, browser_test_class) return test_class # If no browsers were specified, skip this class (it'll still be discovered). return unittest.skip("No browsers specified.")(test_class) @classmethod def import_webdriver(cls, browser): return import_string("selenium.webdriver.%s.webdriver.WebDriver" % browser) @classmethod def import_options(cls, browser): return import_string("selenium.webdriver.%s.options.Options" % browser) @classmethod def get_capability(cls, browser): from selenium.webdriver.common.desired_capabilities import DesiredCapabilities return getattr(DesiredCapabilities, browser.upper()) def create_options(self): options = self.import_options(self.browser)() if self.headless: match self.browser: case "chrome" | "edge": options.add_argument("--headless=new") case "firefox": options.add_argument("-headless") return options def create_webdriver(self): options = self.create_options() if self.selenium_hub: from selenium import webdriver for key, value in self.get_capability(self.browser).items(): options.set_capability(key, value) return webdriver.Remote(command_executor=self.selenium_hub, options=options) return self.import_webdriver(self.browser)(options=options) class ChangeWindowSize: def __init__(self, width, height, selenium): self.selenium = selenium self.new_size = (width, height) def __enter__(self): self.old_size = self.selenium.get_window_size() self.selenium.set_window_size(*self.new_size) return self def __exit__(self, exc_type, exc_value, traceback): self.selenium.set_window_size(self.old_size["width"], self.old_size["height"]) @tag("selenium") class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase): implicit_wait = 10 external_host = None screenshots = False @classmethod def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if not cls.screenshots: return for name, func in list(cls.__dict__.items()): if not hasattr(func, "_screenshot_cases"): continue # Remove the main test. delattr(cls, name) # Add separate tests for each screenshot type. for screenshot_case in getattr(func, "_screenshot_cases"): @wraps(func) def test(self, *args, _func=func, _case=screenshot_case, **kwargs): with getattr(self, _case)(): return _func(self, *args, **kwargs) test.__name__ = f"{name}_{screenshot_case}" test.__qualname__ = f"{test.__qualname__}_{screenshot_case}" test._screenshot_name = name test._screenshot_case = screenshot_case setattr(cls, test.__name__, test) @classproperty def live_server_url(cls): return "http://%s:%s" % (cls.external_host or cls.host, cls.server_thread.port) @classproperty def allowed_host(cls): return cls.external_host or cls.host @classmethod def setUpClass(cls): cls.selenium = cls.create_webdriver() cls.selenium.implicitly_wait(cls.implicit_wait) super().setUpClass() cls.addClassCleanup(cls._quit_selenium) @contextmanager def desktop_size(self): with ChangeWindowSize(1280, 720, self.selenium): yield @contextmanager def small_screen_size(self): with ChangeWindowSize(1024, 768, self.selenium): yield @contextmanager def mobile_size(self): with ChangeWindowSize(360, 800, self.selenium): yield @contextmanager def rtl(self): with self.desktop_size(): with override_settings(LANGUAGE_CODE=settings.LANGUAGES_BIDI[-1]): yield @contextmanager def dark(self): # Navigate to a page before executing a script. self.selenium.get(self.live_server_url) self.selenium.execute_script("localStorage.setItem('theme', 'dark');") with self.desktop_size(): try: yield finally: self.selenium.execute_script("localStorage.removeItem('theme');") def set_emulated_media(self, *, media=None, features=None): if self.browser not in {"chrome", "edge"}: self.skipTest( "Emulation.setEmulatedMedia is only supported on Chromium and " "Chrome-based browsers. See https://chromedevtools.github.io/devtools-" "protocol/1-3/Emulation/#method-setEmulatedMedia for more details." ) params = {} if media is not None: params["media"] = media if features is not None: params["features"] = features # Not using .execute_cdp_cmd() as it isn't supported by the remote web driver # when using --selenium-hub. self.selenium.execute( driver_command="executeCdpCommand", params={"cmd": "Emulation.setEmulatedMedia", "params": params}, ) @contextmanager def high_contrast(self): self.set_emulated_media(features=[{"name": "forced-colors", "value": "active"}]) with self.desktop_size(): try: yield finally: self.set_emulated_media( features=[{"name": "forced-colors", "value": "none"}] ) def take_screenshot(self, name): if not self.screenshots: return test = getattr(self, self._testMethodName) filename = f"{test._screenshot_name}--{name}--{test._screenshot_case}.png" path = Path.cwd() / "screenshots" / filename path.parent.mkdir(exist_ok=True, parents=True) self.selenium.save_screenshot(path) @classmethod def _quit_selenium(cls): # quit() the WebDriver before attempting to terminate and join the # single-threaded LiveServerThread to avoid a dead lock if the browser # kept a connection alive. if hasattr(cls, "selenium"): cls.selenium.quit() @contextmanager def disable_implicit_wait(self): """Disable the default implicit wait.""" self.selenium.implicitly_wait(0) try: yield finally: self.selenium.implicitly_wait(self.implicit_wait) def screenshot_cases(method_names): if isinstance(method_names, str): method_names = method_names.split(",") def wrapper(func): func._screenshot_cases = method_names setattr(func, "tags", {"screenshot"}.union(getattr(func, "tags", set()))) return func return wrapper