kopia lustrzana https://github.com/cirospaciari/socketify.py
292 wiersze
12 KiB
Python
292 wiersze
12 KiB
Python
from socketify import App, CompressOptions, OpCode
|
|
from typing import Union, Callable, Awaitable, Optional
|
|
import inspect
|
|
from queue import SimpleQueue
|
|
|
|
|
|
class SSGIHttpResponse:
    """HTTP response object passed to an SSGI ``http`` middleware.

    Thin wrapper over a socketify response (``res``) and request (``req``)
    exposing the SSGI operations: writing status, headers and body, reading
    the request body (whole or chunk by chunk) and reacting to aborts.
    """

    extensions: Optional[dict] = None  # extensions for http

    def __init__(self, res, req, extensions=None):
        self.res = res
        self.req = req
        self._need_cork = False
        # FIFO of get_chunk() futures still waiting for body data
        # (created lazily on the first get_chunk() call).
        self._received_queue = None
        # FIFO of chunks (and the final None sentinel) that arrived while no
        # get_chunk() future was waiting.
        self._miss_receive_queue = None
        # True once the end of the request body has been observed.
        self._data_ended = False
        self.extensions = extensions

    def send(self, payload: Union[str, bytes, bytearray, memoryview, None], has_more: Optional[bool] = False):
        """Write *payload* to the response.

        With ``has_more`` true the data is written and the response stays
        open; otherwise the response is ended with *payload* (a ``None``
        payload ends the request without a body).
        """
        if has_more:
            self.res.write(payload)
        else:
            self.res.end(payload)

    def send_chunk(self, chunk: Union[bytes, bytearray, memoryview], total_size: int) -> Awaitable:
        """Send one chunk of a body whose full length is *total_size* bytes.

        Delegates to ``res.send_chunk``. This can be used to send with less
        backpressure than ``send``. The connection ends once *total_size*
        bytes have been sent. The awaitable yields a tuple
        ``(sent, ended)``: whether the chunk was sent successfully and
        whether the connection has ended.
        """
        return self.res.send_chunk(chunk, total_size)

    def send_status(self, status_code: Optional[Union[int, str]] = '200 OK'):
        """Write the response status (an int code or a full status string)."""
        self.res.write_status(status_code)

    def send_headers(self, headers):
        """Write every ``(name, value)`` pair of *headers* as a response header."""
        for name, value in headers:
            self.res.write_header(name, value)

    def run_async(self, handler: Awaitable, *arguments) -> Awaitable:
        """Schedule ``handler(*arguments)`` through ``res.run_async``.

        The request headers are read first so that they are preserved.
        """
        self.req.get_headers()  # preserve headers
        return self.res.run_async(handler(*arguments))

    def get_data(self) -> Awaitable:
        """Return an awaitable for the whole request body.

        When a ``content-length`` or ``transfer-encoding`` header is present
        the body is read through ``res.get_data()`` (per the original
        contract, a ``BytesIO``). Otherwise an already-resolved future
        yielding ``None`` is returned.
        """
        # NOTE(review): the headers are looked up on ``res`` rather than
        # ``req`` — confirm the socketify response exposes them this way.
        if self.res.get_header("content-length", False) or self.res.get_header("transfer-encoding", False):
            return self.res.get_data()

        # No body announced: return an empty result.
        future = self.res.loop.create_future()
        future.set_result(None)
        return future

    def get_chunk(self) -> Awaitable:
        """Return an awaitable for the next chunk of the request body.

        The chunk size is decided by the server. The awaitable yields the
        chunk's bytes, or ``None`` once the body has ended. Every call made
        after the end also yields ``None``.
        """
        future = self.res.loop.create_future()

        if self._received_queue is None:
            # First call: create the queues and subscribe to body data.
            self._received_queue = SimpleQueue()
            self._miss_receive_queue = SimpleQueue()
            self._received_queue.put(future, False)
            self.res.on_data(self._on_data)
            return future

        if not self._miss_receive_queue.empty():
            # Data (or the end sentinel) already arrived: hand it out.
            future.set_result(self._miss_receive_queue.get(False))
        elif self._data_ended:
            # Body fully consumed: keep answering "no more chunks".
            future.set_result(None)
        else:
            self._received_queue.put(future, False)
        return future

    def _deliver_chunk(self, chunk):
        """Resolve the oldest waiting get_chunk() future with *chunk*, or buffer it."""
        while not self._received_queue.empty():
            waiting = self._received_queue.get(False)
            # A waiter cancelled by its caller cannot take a result; skip it.
            if not waiting.done():
                waiting.set_result(chunk)
                return
        self._miss_receive_queue.put(chunk, False)

    def _on_data(self, res, chunk, is_end):
        """socketify ``on_data`` callback: route body data to get_chunk() callers."""
        # Empty chunks carry no data; the final call may have an empty chunk.
        if chunk:
            self._deliver_chunk(chunk)
        if is_end:
            self._data_ended = True
            # None marks the end of the body for the consumer.
            self._deliver_chunk(None)

    def on_aborted(self, handler: Union[Awaitable, Callable], *arguments):
        """Register *handler* to run with *arguments* when the connection aborts.

        The socketify response is flagged with ``aborted = True`` first.
        Coroutine functions are scheduled through ``res.run_async``, and
        plain callables are invoked directly.
        """
        def on_aborted(res):
            res.aborted = True
            if inspect.iscoroutinefunction(handler):
                res.run_async(handler(*arguments))
            else:
                handler(*arguments)

        self.res.on_aborted(on_aborted)
|
|
|
|
class SSGIWebSocket:
    """WebSocket object passed to an SSGI ``websocket`` middleware.

    Created while handling the HTTP upgrade request. The middleware either
    accepts (``accept``) or rejects (``reject``) the upgrade. Once accepted
    and opened it can ``send`` and ``close``, and receive events through
    ``on_receive`` / ``on_close``.
    """

    status: int = 0  # 0 pending upgrade, 1 rejected, 2 closed, 3 accepted
    extensions: Optional[dict] = None  # extensions for websocket

    def __init__(self, res, req, socket_context, ws, extensions=None):
        self.res = res
        self.req = req
        self.status = 0
        self.extensions = extensions
        # Underlying socketify WebSocket; replaced by the server's open
        # callback once the upgrade completes.
        self.ws = ws
        self._socket_context = socket_context
        # Upgrade handshake headers captured from the request.
        self._key = self.req.get_header("sec-websocket-key")
        self._protocol = self.req.get_header("sec-websocket-protocol")
        self._extensions = self.req.get_header("sec-websocket-extensions")
        self._close_handler = None
        self._receive_handler = None
        self._need_cork = False
        # Future returned by accept(); None until an upgrade is started.
        self._accept_future = None

    def accept(self, protocol: str = None) -> Awaitable:
        """Accept the pending connection upgrade.

        *protocol* selects the sub-protocol. When it is falsy, the request's
        ``sec-websocket-protocol`` header is used, or ``""`` if that is
        absent. Returns the future that the server resolves when the socket
        opens. If the upgrade was already accepted, rejected or started,
        nothing is sent and the returned future is already resolved to
        ``False``.
        """
        if self.status != 0 or self._accept_future is not None:
            # Upgrading twice would write to an already-upgraded response.
            future = self.res.loop.create_future()
            future.set_result(False)
            return future

        self._accept_future = self.res.loop.create_future()
        upgrade_protocol = protocol or self._protocol or ""
        self.res.upgrade(self._key, upgrade_protocol, self._extensions, self._socket_context, self)
        return self._accept_future

    def reject(self, status_code: Optional[int] = 403, payload=None, headers=None) -> Awaitable:
        """Reject the pending upgrade with an HTTP response.

        *status_code*, *payload* and *headers* (an iterable of
        ``(name, value)`` pairs) are all optional. Returns an
        already-resolved future: ``True`` if the rejection was sent,
        ``False`` if the upgrade was already accepted, started or rejected.
        """
        future = self.res.loop.create_future()
        if self.status != 0 or self._accept_future is not None:
            future.set_result(False)
            return future

        self.status = 1
        # uWebSockets emits an implicit "200 OK" on the first header write,
        # so the status must be written before any header.
        self.res.write_status(status_code)
        for name, value in headers or ():
            self.res.write_header(name, value)
        if payload:
            self.res.cork_end(payload)
        else:
            self.res.end_without_body()
        future.set_result(True)
        return future

    def send(self, payload: Union[bytes, bytearray, memoryview]):
        """Send *payload* over the open socket; ignored unless accepted (status 3).

        Uses ``cork_send`` when the socket is being driven from ``run_async``.
        """
        if self.status == 3:
            if self._need_cork:
                self.ws.cork_send(payload)
            else:
                self.ws.send(payload)

    def close(self, code: Optional[int] = 1000):
        """Close the open socket and mark it closed (status 2).

        Returns ``True`` if a close was issued, or ``False`` if the socket
        was not open.
        """
        if self.status == 3:
            # NOTE(review): `code` is not forwarded to the socket — TODO.
            self.ws.close()
            self.status = 2
            return True
        return False

    def run_async(self, handler: Awaitable, *arguments) -> Awaitable:
        """Schedule ``handler(*arguments)`` through ``res.run_async``.

        Request headers are read first so that they are preserved, and
        later sends switch to corked writes.
        """
        self.req.get_headers()
        self._need_cork = True
        return self.res.run_async(handler(*arguments))

    def on_receive(self, handler: Union[Awaitable, Callable], *arguments):
        """Register *handler* for incoming messages.

        It is called as ``handler(ws, message, *arguments)``, where *ws* is
        this ``SSGIWebSocket`` and *message* is
        ``Union[str, bytes, bytearray, memoryview, None]``. The opcode is not
        forwarded. Coroutine functions are scheduled through
        ``res.run_async``.
        """
        def on_receive_handler(ws, message, opcode):
            if inspect.iscoroutinefunction(handler):
                ws.res.run_async(handler(ws, message, *arguments))
            else:
                handler(ws, message, *arguments)

        self._receive_handler = on_receive_handler

    def on_close(self, handler: Union[Awaitable, Callable], *arguments):
        """Register *handler* for when the socket disconnects.

        It is called as ``handler(ws, code, reason, *arguments)``, where
        *ws* is this ``SSGIWebSocket`` and *reason* is ``Optional[str]``.
        Coroutine functions are scheduled through ``res.run_async``.
        """
        def on_close_handler(ws, code, message):
            if inspect.iscoroutinefunction(handler):
                ws.res.run_async(handler(ws, code, message, *arguments))
            else:
                handler(ws, code, message, *arguments)

        self._close_handler = on_close_handler
|
|
|
|
class SSGI:
    """Serve an SSGI application on a socketify ``App``.

    The application's ``get_supported({"ssgi": "1.0"})`` answer decides what
    is routed:
    - an ``http`` middleware on ``/*`` when HTTP is supported as ``"ssgi"``;
    - a ``websocket`` middleware on ``/*`` when WebSocket is supported as
      ``"ssgi"``.
    """

    def __init__(self, app, options=None, request_response_factory_max_items=0, websocket_factory_max_itens=0):
        self.server = App(options, request_response_factory_max_items, websocket_factory_max_itens)
        self.SERVER_PORT = None
        self.SERVER_HOST = ''
        self.SERVER_SCHEME = 'https' if self.server.options else 'http'
        self.SERVER_WS_SCHEME = 'wss' if self.server.options else 'ws'
        self.SERVER_ADDRESS = ''

        self.app = app
        support = app.get_supported({"ssgi": "1.0"})
        http, middleware = support.get('http', (None, None))
        websocket, ws_middleware = support.get('websocket', (None, None))

        def ssgi(res, req):
            """HTTP route handler: hand the request to the SSGI http middleware."""
            response = SSGIHttpResponse(res, req)
            PATH_INFO = req.get_url()
            METHOD = req.get_method()
            # NOTE(review): query string is not extracted yet (always "").
            QUERY_STRING = ""
            get_header = self._header_getter(req)

            if inspect.iscoroutinefunction(middleware):
                req.get_headers()  # preserve
                res.run_async(middleware('http', METHOD, PATH_INFO, QUERY_STRING, get_header, response))
            else:
                middleware('http', METHOD, PATH_INFO, QUERY_STRING, get_header, response)

        if http == "ssgi" and middleware:
            self.server.any("/*", ssgi)

        def ws_upgrade(res, req, socket_context):
            """Upgrade handler: let the SSGI websocket middleware accept/reject."""
            response = SSGIWebSocket(res, req, socket_context, None)
            PATH_INFO = req.get_url()
            FULL_PATH_INFO = req.get_full_url()
            METHOD = req.get_method()
            # The remote address is exposed by the socketify response object.
            REMOTE_ADDRESS = res.get_remote_address()
            get_header = self._header_getter(req)
            # Both the sync and async paths must pass the same arguments.
            arguments = (
                'websocket',
                self.SERVER_WS_SCHEME,
                self.SERVER_ADDRESS,
                REMOTE_ADDRESS,
                METHOD,
                PATH_INFO,
                FULL_PATH_INFO[len(PATH_INFO):],
                get_header,
                response,
            )

            if inspect.iscoroutinefunction(ws_middleware):
                req.get_headers()  # preserve
                res.run_async(ws_middleware(*arguments))
                return

            ws_middleware(*arguments)
            # Neither accepted nor rejected synchronously: presumably the
            # decision comes later, so keep the response alive (confirm).
            if response.status == 0 and not response._accept_future:
                res.grab_aborted_handler()

        if websocket == "ssgi" and ws_middleware:
            def ws_open(ws):
                """Upgrade completed: bind the socket and resolve accept()."""
                res = ws.get_user_data()
                res.ws = ws
                res.status = 3
                accept_future = res._accept_future
                if accept_future is not None and not accept_future.done():
                    accept_future.set_result(True)

            def ws_message(ws, message, op):
                """Forward an incoming message to the registered receive handler."""
                res = ws.get_user_data()
                if res._receive_handler:
                    res._receive_handler(res, message, op)

            def ws_close(ws, code, message):
                """Mark the socket closed and forward to the close handler."""
                res = ws.get_user_data()
                res.status = 2
                if res._close_handler:
                    res._close_handler(res, code, message)

            self.server.ws("/*", {
                "compression": CompressOptions.DISABLED,
                "max_payload_length": 16 * 1024 * 1024,
                "idle_timeout": 0,
                "upgrade": ws_upgrade,
                "open": ws_open,
                "message": ws_message,
                "close": ws_close,
            })

    @staticmethod
    def _header_getter(req):
        """Return ``get_header(name=None)``: one header of *req*, or all of them when *name* is falsy."""
        def get_header(name=None):
            if name:
                return req.get_header(name)
            return req.get_headers()
        return get_header

    def listen(self, port_or_options, handler):
        """Record the listen address and start listening; returns ``self``.

        *port_or_options* is either an int port (host ``"0.0.0.0"``) or an
        object with ``port`` and ``host`` attributes.
        """
        self.SERVER_PORT = port_or_options if isinstance(port_or_options, int) else port_or_options.port
        self.SERVER_HOST = "0.0.0.0" if isinstance(port_or_options, int) else port_or_options.host
        if self.SERVER_PORT:
            self.SERVER_ADDRESS = f"{self.SERVER_HOST}:{self.SERVER_PORT}"
        else:
            self.SERVER_ADDRESS = self.SERVER_HOST

        self.server.listen(port_or_options, handler)
        return self

    def run(self):
        """Run the underlying socketify server; returns ``self``."""
        self.server.run()
        return self