diff --git a/src/socketify/socketify.py b/src/socketify/socketify.py index 725db10..7e38d6e 100644 --- a/src/socketify/socketify.py +++ b/src/socketify/socketify.py @@ -17,6 +17,8 @@ from .native import ffi, lib from .loop import Loop from .status_codes import status_codes from .helpers import static_route +from dataclasses import dataclass +from types import NoneType mimetypes.init() @@ -967,7 +969,7 @@ class WebSocketFactory: def get(self, app, ws): if len(self.factory_queue) == 0: response = WebSocket(ws, app.SSL, app.loop) - return (response, False) + return response, False instances = self.factory_queue.pop() (websocket, _) = instances @@ -997,7 +999,7 @@ class RequestResponseFactory: if len(self.factory_queue) == 0: response = AppResponse(res, app.loop, app.SSL, app._template) request = AppRequest(req) - return (response, request, False) + return response, request, False instances = self.factory_queue.pop() (response, request, _) = instances @@ -1465,7 +1467,7 @@ class AppResponse: def try_end(self, message, total_size, end_connection=False): try: if self.aborted: - return (False, True) + return False, True if self._write_jar is not None: self.write_header("Set-Cookie", self._write_jar.output(header="")) self._write_jar = None @@ -1474,7 +1476,7 @@ class AppResponse: elif isinstance(message, bytes): data = message else: - return (False, True) + return False, True result = lib.uws_res_try_end( self.SSL, self.res, @@ -1483,9 +1485,9 @@ class AppResponse: ffi.cast("uintmax_t", total_size), 1 if end_connection else 0, ) - return (bool(result.ok), bool(result.has_responded)) + return bool(result.ok), bool(result.has_responded) except: - return (False, True) + return False, True def cork_end(self, message, end_connection=False): self.cork(lambda res: res.end(message, end_connection)) @@ -1757,11 +1759,6 @@ class App: if options.key_file_name is None else ffi.new("char[]", options.key_file_name.encode("utf-8")) ) - socket_options.key_file_name = ( - ffi.NULL - if options.key_file_name is None - else ffi.new("char[]", options.key_file_name.encode("utf-8")) - ) socket_options.cert_file_name = ( ffi.NULL if options.cert_file_name is None @@ -2259,7 +2256,7 @@ class App: def close(self): if hasattr(self, "socket"): - if not self.socket == ffi.NULL: + if self.socket != ffi.NULL: lib.us_listen_socket_close(self.SSL, self.socket) self.loop.stop() return self @@ -2302,49 +2299,43 @@ class App: lib.uws_app_destroy(self.SSL, self.app) +@dataclass class AppListenOptions: - def __init__(self, port=0, host=None, options=0): - if not isinstance(port, int): + port: int = 0 + host: str = None + options: int = 0 + + def __post_init__(self): + if not isinstance(self.port, int): raise RuntimeError("port must be an int") - if host is not None and not isinstance(host, str): - raise RuntimeError("host must be an String or None") - if not isinstance(options, int): + if not isinstance(self.host, (NoneType, str)): + raise RuntimeError("host must be a str if specified") + if not isinstance(self.options, int): raise RuntimeError("options must be an int") - self.port = port - self.host = host - self.options = options +@dataclass class AppOptions: - def __init__( - self, - key_file_name=None, - cert_file_name=None, - passphrase=None, - dh_params_file_name=None, - ca_file_name=None, - ssl_ciphers=None, - ssl_prefer_low_memory_usage=0 - ): - if key_file_name is not None and not isinstance(key_file_name, str): - raise RuntimeError("key_file_name must be an String or None") - if cert_file_name is not None and not isinstance(cert_file_name, str): - raise RuntimeError("cert_file_name must be an String or None") - if passphrase is not None and not isinstance(passphrase, str): - raise RuntimeError("passphrase must be an String or None") - if dh_params_file_name is not None and not isinstance(dh_params_file_name, str): - raise RuntimeError("dh_params_file_name must be an String or None") - if ca_file_name is not None and not isinstance(ca_file_name, str): - raise RuntimeError("ca_file_name must be an String or None") - if ssl_ciphers is not None and not isinstance(ssl_ciphers, str): - raise RuntimeError("ssl_ciphers must be an String or None") - if not isinstance(ssl_prefer_low_memory_usage, int): - raise RuntimeError("ssl_prefer_low_memory_usage must be an int") + key_file_name: str = None, + cert_file_name: str = None, + passphrase: str = None, + dh_params_file_name: str = None, + ca_file_name: str = None, + ssl_ciphers: str = None, + ssl_prefer_low_memory_usage: int = 0 - self.key_file_name = key_file_name - self.cert_file_name = cert_file_name - self.passphrase = passphrase - self.dh_params_file_name = dh_params_file_name - self.ca_file_name = ca_file_name - self.ssl_ciphers = ssl_ciphers - self.ssl_prefer_low_memory_usage = ssl_prefer_low_memory_usage + def __post_init__(self): + if not isinstance(self.key_file_name, (NoneType, str)): + raise RuntimeError("key_file_name must be a str if specified") + if not isinstance(self.cert_file_name, (NoneType, str)): + raise RuntimeError("cert_file_name must be a str if specified") + if not isinstance(self.passphrase, (NoneType, str)): + raise RuntimeError("passphrase must be a str if specified") + if not isinstance(self.dh_params_file_name, (NoneType, str)): + raise RuntimeError("dh_params_file_name must be a str if specified") + if not isinstance(self.ca_file_name, (NoneType, str)): + raise RuntimeError("ca_file_name must be a str if specified") + if not isinstance(self.ssl_ciphers, (NoneType, str)): + raise RuntimeError("ssl_ciphers must be a str if specified") + if not isinstance(self.ssl_prefer_low_memory_usage, int): + raise RuntimeError("ssl_prefer_low_memory_usage must be an int")