socketify.py/src/socketify/ssgi.py

292 wiersze
12 KiB
Python

from socketify import App, CompressOptions, OpCode
from typing import Union, Callable, Awaitable, Optional
import inspect
from queue import SimpleQueue
class SSGIHttpResponse:
extensions: Optional[dict] = None # extensions for http
def __init__(self, res, req, extensions = None):
self.res = res
self.req = req
self._need_cork = False
self._received_queue = None
self._miss_receive_queue = None
self.extensions = extensions
# if payload is None, request ends without body
# if has_more is True, data is written but connection will not end
def send(self, payload: Union[str, bytes, bytearray, memoryview, None], has_more: Optional[bool] = False):
if has_more:
self.res.write(payload)
else:
self.res.end(payload)
# send chunk of data, can be used to perform with less backpressure than using send
# total_size is the sum of all lengths in bytes of all chunks to be sended
# connection will end when total_size is met
# returns tuple(bool, bool) first bool represents if the chunk is succefully sended, the second if the connection has ended
def send_chunk(self, chunk: Union[bytes, bytearray, memoryview], total_size: int) -> Awaitable:
return self.res.send_chunk(chunk, total_size)
# send status code
def send_status(self, status_code: Optional[Union[int, str]] = '200 OK'):
self.res.write_status(status_code)
# send headers to the http response
def send_headers(self, headers):
for name, value in headers:
self.res.write_header(name, value)
# ensure async call for the handler, passing any arguments to it
def run_async(self, handler: Awaitable, *arguments) -> Awaitable:
self.req.get_headers() # preserve headers
return self.res.run_async(handler(*arguments))
# get an all data
# returns an BytesIO() or None if no payload is available
def get_data(self) -> Awaitable:
if self.res.get_header("content-length", False) or self.res.get_header("transfer-encoding", False):
return self.res.get_data()
#return empty result
future = self.res.loop.create_future()
future.set_result(None)
return future
# get an chunk of data (chunk size is decided by the Server implementation)
# returns the bytes or None if no more chunks are sent
def get_chunk(self) -> Awaitable:
if not self._received_queue:
self._miss_receive_queue = SimpleQueue()
self._received_queue = SimpleQueue()
def on_data(res, chunk, is_end):
if not self._received_queue.empty():
future = self._received_queue.get(False)
future.set_result(chunk)
if not self._received_queue.empty() and is_end and chunk:
future = self._received_queue.get(False)
future.set_result(None)
return
else:
self._miss_receive_queue.put(chunk, False)
if is_end and chunk:
self._miss_receive_queue.put(None, False)
future = self.res.loop.create_future()
self._received_queue.put(future, False)
self.res.on_data(on_data)
return future
else:
future = self.res.loop.create_future()
if not self._miss_receive_queue.empty():
future.set_result(self._miss_receive_queue.get(False))
return future
self._received_queue.put(future, False)
return future
# on aborted event, called when the connection abort
def on_aborted(self, handler: Union[Awaitable, Callable], *arguments):
def on_aborted(res):
res.aborted = True
if inspect.iscoroutinefunction(handler):
res.run_async(handler(*arguments))
else:
handler(*arguments)
self.res.on_aborted(on_aborted)
class SSGIWebSocket:
status: int = 0 # 0 pending upgrade, 1 rejected, 2 closed, 3 accepted
extensions: Optional[dict] = None # extensions for websocket
def __init__(self, res, req, socket_context, ws, extensions = None):
self.res = res
self.req = req
self.status = 0
self.extensions = extensions
self._socket_context = socket_context
self._key = self.req.get_header("sec-websocket-key")
self._protocol = self.req.get_header("sec-websocket-protocol")
self._extensions = self.req.get_header("sec-websocket-extensions")
self._close_handler = None
self._receive_handler = None
self._need_cork = False
self._accept_future = None
# accept the connection upgrade
# can pass the protocol to accept if None is informed will use sec-websocket-protocol header if available
def accept(self, protocol: str = None) -> Awaitable:
if self.status == 0:
self._accept_future = self.res.loop.create_future()
upgrade_protocol = protocol if protocol else self._protocol
self.res.upgrade(self._key, upgrade_protocol if upgrade_protocol else "", self._extensions, self._socket_context, self)
return self._accept_future
future = self.res.loop.create_future()
future.set_result(False)
return future
# reject the connection upgrade, you can send status_code, payload and headers if you want, all optional
def reject(self, status_code: Optional[int] = 403, payload = None, headers = None) -> Awaitable:
future = self.res.loop.create_future()
if self.status < 1:
self.status = 1
if headers:
for name, value in headers:
self.res.write_header(name, value)
if not payload:
self.res.write_status(status_code).end_without_body()
else:
self.res.write_status(status_code).cork_end(payload)
future.set_result(True)
else:
future.set_result(False)
return future
# if returns an future, this can be awaited or not
def send(self, payload: Union[bytes, bytearray, memoryview]):
if self.status == 3:
if self._need_cork:
self.ws.cork_send(payload)
else:
self.ws.send(payload)
# close connection
def close(self, code: Optional[int] = 1000):
if self.status == 3:
self.ws.close()
return True
return False
# ensure async call for the handler, passing any arguments to it
def run_async(self, handler: Awaitable, *arguments) -> Awaitable:
self.req.get_headers()
self._need_cork = True
return self.res.run_async(handler(*arguments))
# on receive event, called when the socket disconnect
# passes ws: SSGIWebSocket, msg: Union[str, bytes, bytearray, memoryview, None], *arguments
def on_receive(self, handler: Union[Awaitable, Callable], *arguments):
def on_receive_handler(ws, message, opcode):
if inspect.iscoroutinefunction(handler):
ws.res.run_async(handler(ws, message, *arguments))
else:
handler(ws, message, *arguments)
self._receive_handler = on_receive_handler
# on close event, called when the socket disconnect
# passes ws: SSGIWebSocket, code: int and reason: Optional[str] = None, *arguments
def on_close(self, handler: Union[Awaitable, Callable], *arguments):
def on_close_handler(ws, code, message):
if inspect.iscoroutinefunction(handler):
ws.res.run_async(handler(ws, code, message, *arguments))
else:
handler(ws, code, message, *arguments)
self._close_handler = on_close_handler
class SSGI:
def __init__(self, app, options=None, request_response_factory_max_items=0, websocket_factory_max_itens=0):
self.server = App(options, request_response_factory_max_items, websocket_factory_max_itens)
self.SERVER_PORT = None
self.SERVER_HOST = ''
self.SERVER_SCHEME = 'https' if self.server.options else 'http'
self.SERVER_WS_SCHEME = 'wss' if self.server.options else 'ws'
self.SERVER_ADDRESS = ''
self.app = app
support = app.get_supported({ "ssgi": "1.0" })
http, middleware = support.get('http', (None, None))
websocket, ws_middleware = support.get('websocket', (None, None))
def ssgi(res, req):
response = SSGIHttpResponse(res, req)
PATH_INFO = req.get_url()
# FULL_PATH_INFO = req.get_full_url()
METHOD = req.get_method()
QUERY_STRING = "" #FULL_PATH_INFO[len(PATH_INFO):]
# REMOTE_ADDRESS = res.get_remote_address()
def get_header(name = None):
if name:
return req.get_header(name)
else:
return req.get_headers()
# self.SERVER_SCHEME, self.SERVER_ADDRESS,
# self.SERVER_SCHEME, self.SERVER_ADDRESS,
if inspect.iscoroutinefunction(middleware):
req.get_headers() # preserve
res.run_async(middleware('http', METHOD, PATH_INFO, QUERY_STRING, get_header, response))
else:
middleware('http', METHOD, PATH_INFO, QUERY_STRING, get_header, response)
# if not response._responded:
# res.grab_aborted_handler()
if http == "ssgi" and middleware:
self.server.any("/*", ssgi)
def ws_upgrade(res, req, socket_context):
response = SSGIWebSocket(res, req, socket_context, None)
PATH_INFO = req.get_url()
FULL_PATH_INFO = req.get_full_url()
METHOD = req.get_method()
REMOTE_ADDRESS = req.get_remote_address()
def get_header(name = None):
if name:
return req.get_header(name)
else:
return req.get_headers()
if inspect.iscoroutinefunction(ws_middleware):
req.get_headers() # preserve
res.run_async(ws_middleware('websocket', self.SERVER_SCHEME, self.SERVER_ADDRESS, REMOTE_ADDRESS, METHOD, PATH_INFO, FULL_PATH_INFO[len(PATH_INFO):], get_header, response))
return
else:
ws_middleware('websocket', self.SERVER_WS_SCHEME, self.SERVER_HOST, REMOTE_ADDRESS, METHOD, PATH_INFO, FULL_PATH_INFO[len(PATH_INFO):], get_header, response)
# not accepted (async?)
if response.status == 0 and not response._accept_future:
res.grab_aborted_handler()
if websocket == "ssgi" and ws_middleware:
def ws_open(ws):
res = ws.get_user_data()
res.ws = ws
res.status = 3
res._accept_future.set_result(True)
def ws_message(ws, message, op):
res = ws.get_user_data()
if res._receive_handler:
res._receive_handler(res, message, op)
def ws_close(ws, code, message):
res = ws.get_user_data()
if res._close_handler:
res._close_handler(res, code, message)
self.server.ws("/*", {
"compression": CompressOptions.DISABLED,
"max_payload_length": 16 * 1024 * 1024,
"idle_timeout": 0,
"upgrade": ws_upgrade,
"open": ws_open,
"message": ws_message,
"close": ws_close
})
def listen(self, port_or_options, handler):
self.SERVER_PORT = port_or_options if isinstance(port_or_options, int) else port_or_options.port
self.SERVER_HOST = "0.0.0.0" if isinstance(port_or_options, int) else port_or_options.host
if self.SERVER_PORT:
self.SERVER_ADDRESS = f"{self.SERVER_HOST}:{self.SERVER_PORT}"
else:
self.SERVER_ADDRESS = self.SERVER_HOST
self.server.listen(port_or_options, handler)
return self
def run(self):
self.server.run()
return self