# wsgi.py
from io import BytesIO, StringIO
from tempfile import SpooledTemporaryFile

from asgiref.sync import AsyncToSync, sync_to_async
  4. class WsgiToAsgi:
  5. """
  6. Wraps a WSGI application to make it into an ASGI application.
  7. """
  8. def __init__(self, wsgi_application):
  9. self.wsgi_application = wsgi_application
  10. async def __call__(self, scope, receive, send):
  11. """
  12. ASGI application instantiation point.
  13. We return a new WsgiToAsgiInstance here with the WSGI app
  14. and the scope, ready to respond when it is __call__ed.
  15. """
  16. await WsgiToAsgiInstance(self.wsgi_application)(scope, receive, send)
  17. class WsgiToAsgiInstance:
  18. """
  19. Per-socket instance of a wrapped WSGI application
  20. """
  21. def __init__(self, wsgi_application):
  22. self.wsgi_application = wsgi_application
  23. self.response_started = False
  24. self.response_content_length = None
  25. async def __call__(self, scope, receive, send):
  26. if scope["type"] != "http":
  27. raise ValueError("WSGI wrapper received a non-HTTP scope")
  28. self.scope = scope
  29. with SpooledTemporaryFile(max_size=65536) as body:
  30. # Alright, wait for the http.request messages
  31. while True:
  32. message = await receive()
  33. if message["type"] != "http.request":
  34. raise ValueError("WSGI wrapper received a non-HTTP-request message")
  35. body.write(message.get("body", b""))
  36. if not message.get("more_body"):
  37. break
  38. body.seek(0)
  39. # Wrap send so it can be called from the subthread
  40. self.sync_send = AsyncToSync(send)
  41. # Call the WSGI app
  42. await self.run_wsgi_app(body)
  43. def build_environ(self, scope, body):
  44. """
  45. Builds a scope and request body into a WSGI environ object.
  46. """
  47. script_name = scope.get("root_path", "").encode("utf8").decode("latin1")
  48. path_info = scope["path"].encode("utf8").decode("latin1")
  49. if path_info.startswith(script_name):
  50. path_info = path_info[len(script_name) :]
  51. environ = {
  52. "REQUEST_METHOD": scope["method"],
  53. "SCRIPT_NAME": script_name,
  54. "PATH_INFO": path_info,
  55. "QUERY_STRING": scope["query_string"].decode("ascii"),
  56. "SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"],
  57. "wsgi.version": (1, 0),
  58. "wsgi.url_scheme": scope.get("scheme", "http"),
  59. "wsgi.input": body,
  60. "wsgi.errors": BytesIO(),
  61. "wsgi.multithread": True,
  62. "wsgi.multiprocess": True,
  63. "wsgi.run_once": False,
  64. }
  65. # Get server name and port - required in WSGI, not in ASGI
  66. if "server" in scope:
  67. environ["SERVER_NAME"] = scope["server"][0]
  68. environ["SERVER_PORT"] = str(scope["server"][1])
  69. else:
  70. environ["SERVER_NAME"] = "localhost"
  71. environ["SERVER_PORT"] = "80"
  72. if scope.get("client") is not None:
  73. environ["REMOTE_ADDR"] = scope["client"][0]
  74. # Go through headers and make them into environ entries
  75. for name, value in self.scope.get("headers", []):
  76. name = name.decode("latin1")
  77. if name == "content-length":
  78. corrected_name = "CONTENT_LENGTH"
  79. elif name == "content-type":
  80. corrected_name = "CONTENT_TYPE"
  81. else:
  82. corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
  83. # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in case
  84. value = value.decode("latin1")
  85. if corrected_name in environ:
  86. value = environ[corrected_name] + "," + value
  87. environ[corrected_name] = value
  88. return environ
  89. def start_response(self, status, response_headers, exc_info=None):
  90. """
  91. WSGI start_response callable.
  92. """
  93. # Don't allow re-calling once response has begun
  94. if self.response_started:
  95. raise exc_info[1].with_traceback(exc_info[2])
  96. # Don't allow re-calling without exc_info
  97. if hasattr(self, "response_start") and exc_info is None:
  98. raise ValueError(
  99. "You cannot call start_response a second time without exc_info"
  100. )
  101. # Extract status code
  102. status_code, _ = status.split(" ", 1)
  103. status_code = int(status_code)
  104. # Extract headers
  105. headers = [
  106. (name.lower().encode("ascii"), value.encode("ascii"))
  107. for name, value in response_headers
  108. ]
  109. # Extract content-length
  110. self.response_content_length = None
  111. for name, value in response_headers:
  112. if name.lower() == "content-length":
  113. self.response_content_length = int(value)
  114. # Build and send response start message.
  115. self.response_start = {
  116. "type": "http.response.start",
  117. "status": status_code,
  118. "headers": headers,
  119. }
  120. @sync_to_async
  121. def run_wsgi_app(self, body):
  122. """
  123. Called in a subthread to run the WSGI app. We encapsulate like
  124. this so that the start_response callable is called in the same thread.
  125. """
  126. # Translate the scope and incoming request body into a WSGI environ
  127. environ = self.build_environ(self.scope, body)
  128. # Run the WSGI app
  129. bytes_sent = 0
  130. for output in self.wsgi_application(environ, self.start_response):
  131. # If this is the first response, include the response headers
  132. if not self.response_started:
  133. self.response_started = True
  134. self.sync_send(self.response_start)
  135. # If the application supplies a Content-Length header
  136. if self.response_content_length is not None:
  137. # The server should not transmit more bytes to the client than the header allows
  138. bytes_allowed = self.response_content_length - bytes_sent
  139. if len(output) > bytes_allowed:
  140. output = output[:bytes_allowed]
  141. self.sync_send(
  142. {"type": "http.response.body", "body": output, "more_body": True}
  143. )
  144. bytes_sent += len(output)
  145. # The server should stop iterating over the response when enough data has been sent
  146. if bytes_sent == self.response_content_length:
  147. break
  148. # Close connection
  149. if not self.response_started:
  150. self.response_started = True
  151. self.sync_send(self.response_start)
  152. self.sync_send({"type": "http.response.body"})