diff --git a/src/socketify/loop.py b/src/socketify/loop.py index 3e67392..0e1ff6f 100644 --- a/src/socketify/loop.py +++ b/src/socketify/loop.py @@ -54,3 +54,12 @@ class Loop: future = asyncio.run_coroutine_threadsafe(task, self.loop) future.add_done_callback(lambda f: future_handler(f, self.loop, self.exception_handler, response)) return future + + + +# if sys.version_info >= (3, 11) +# with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner: +# runner.run(main()) +# else: +# uvloop.install() +# asyncio.run(main()) \ No newline at end of file diff --git a/src/socketify/socketify.py b/src/socketify/socketify.py index aea04d4..9606c13 100644 --- a/src/socketify/socketify.py +++ b/src/socketify/socketify.py @@ -4,6 +4,7 @@ from .loop import Loop from .status_codes import status_codes import json import inspect +import threading ffi = cffi.FFI() ffi.cdef(""" @@ -186,7 +187,7 @@ void uws_res_write_status(int ssl, uws_res_t *res, const char *status, size_t le void uws_res_write_header(int ssl, uws_res_t *res, const char *key, size_t key_length, const char *value, size_t value_length); void uws_res_write_header_int(int ssl, uws_res_t *res, const char *key, size_t key_length, uint64_t value); -void uws_res_end_without_body(int ssl, uws_res_t *res); +void uws_res_end_without_body(int ssl, uws_res_t *res, bool close_connection); bool uws_res_write(int ssl, uws_res_t *res, const char *data, size_t length); uintmax_t uws_res_get_write_offset(int ssl, uws_res_t *res); bool uws_res_has_responded(int ssl, uws_res_t *res); @@ -246,7 +247,6 @@ def uws_generic_ssl_method_handler(res, req, user_data): def uws_generic_listen_handler(listen_socket, config, user_data): if listen_socket == ffi.NULL: raise RuntimeError("Failed to listen on port %d" % int(config.port)) - if not user_data == ffi.NULL: app = ffi.from_handle(user_data) @@ -256,10 +256,9 @@ def uws_generic_listen_handler(listen_socket, config, user_data): app._listen_handler(None if config == ffi.NULL else AppListenOptions(port=int(config.port),host=None if config.host == ffi.NULL else ffi.string(config.host).decode("utf-8"), options=int(config.options))) @ffi.callback("void(uws_res_t *, void*)") -def uws_generic_abord_handler(response, user_data): +def uws_generic_aborted_handler(response, user_data): if not user_data == ffi.NULL: res = ffi.from_handle(user_data) - res.aborted = True res.trigger_aborted() class AppRequest: @@ -309,6 +308,7 @@ class AppRequest: def is_ancient(self): return bool(lib.uws_req_is_ancient(self.req)) + class AppResponse: def __init__(self, response, loop, is_ssl): self.res = response @@ -317,22 +317,28 @@ class AppResponse: self._ptr = ffi.NULL self.loop = loop + def trigger_aborted(self): + self.aborted = True + self.res = ffi.NULL + self._ptr = ffi.NULL + if hasattr(self, "_aborted_handler") and hasattr(self._aborted_handler, '__call__'): + self._aborted_handler() + return self + def run_async(self, task): self.grab_aborted_handler() return self.loop.run_async(task, self) def grab_aborted_handler(self): #only needed if is async - if self._ptr == ffi.NULL: + if self._ptr == ffi.NULL and not self.aborted: self._ptr = ffi.new_handle(self) - lib.uws_res_on_aborted(self.SSL, self.res, uws_generic_abord_handler, self._ptr) + lib.uws_res_on_aborted(self.SSL, self.res, uws_generic_aborted_handler, self._ptr) def redirect(self, location, status_code=302): - if not isinstance(location, str): - raise RuntimeError("Location must be an string") self.write_status(status_code) self.write_header("Location", location) - self.end_without_body() + self.end_without_body(False) def end(self, message, end_connection=False): if not self.aborted: @@ -341,7 +347,7 @@ class AppResponse: elif isinstance(message, bytes): data = message elif message == None: - self.end_without_body() + self.end_without_body(end_connection) return self else: self.write_header("Content-Type", "application/json") @@ -395,9 +401,9 @@ class AppResponse: lib.uws_res_write_header(self.SSL, self.res, key_data, len(key_data), value_data, len(value_data)) return self - def end_without_body(self): + def end_without_body(self, end_connection=False): if not self.aborted: - lib.uws_res_end_without_body(self.SSL, self.res) + lib.uws_res_end_without_body(self.SSL, self.res, 1 if end_connection else 0) return self def write(self, message): @@ -427,11 +433,6 @@ class AppResponse: return False return bool(lib.uws_res_has_responded(self.SSL, self.res, data, len(data))) - def trigger_aborted(self): - if hasattr(self, "_aborted_handler") and hasattr(self._aborted_handler, '__call__'): - self._aborted_handler() - return self - def on_aborted(self, handler): if hasattr(handler, '__call__'): self.grab_aborted_handler() diff --git a/src/socketify/uWebSockets b/src/socketify/uWebSockets index fd78f29..c168734 160000 --- a/src/socketify/uWebSockets +++ b/src/socketify/uWebSockets @@ -1 +1 @@ -Subproject commit fd78f2960ac3c8ac529a11f115ba824db7e60c09 +Subproject commit c168734e80daa0c91123ed44172f193b1ba8e365 diff --git a/tests/examples/router_and_basics.py b/tests/examples/router_and_basics.py index a5d9168..4f2881f 100644 --- a/tests/examples/router_and_basics.py +++ b/tests/examples/router_and_basics.py @@ -56,7 +56,6 @@ def send_in_parts(res, req): res.write("messages") res.end(" in parts!") - def redirect(res, req): #status code is optional default is 302 res.redirect("/redirected", 302) @@ -64,7 +63,6 @@ def redirect(res, req): def redirected(res, req): res.end("You got redirected to here :D") - def not_found(res, req): res.write_status(404).end("Not Found")