diff --git a/src/socketify/socketify.py b/src/socketify/socketify.py index c0c66ff..b53134c 100644 --- a/src/socketify/socketify.py +++ b/src/socketify/socketify.py @@ -283,8 +283,7 @@ def uws_missing_server_name(hostname, hostname_length, user_data): handler = app._missing_server_handler if inspect.iscoroutinefunction(handler): - response.grab_aborted_handler() - response.run_async(handler(data)) + app.run_async(handler(data)) else: handler(data) except Exception as err: @@ -298,8 +297,7 @@ def uws_websocket_drain_handler(ws, user_data): ws = WebSocket(ws, app.SSL, app.loop) handler = handlers.drain if inspect.iscoroutinefunction(handler): - response.grab_aborted_handler() - response.run_async(handler(ws)) + app.run_async(handler(ws)) else: handler(ws) except Exception as err: @@ -314,8 +312,7 @@ def uws_websocket_open_handler(ws, user_data): ws = WebSocket(ws, app.SSL, app.loop) handler = handlers.open if inspect.iscoroutinefunction(handler): - response.grab_aborted_handler() - response.run_async(handler(ws)) + app.run_async(handler(ws)) else: handler(ws) except Exception as err: @@ -338,8 +335,7 @@ def uws_websocket_message_handler(ws, message, length, opcode, user_data): handler = handlers.message if inspect.iscoroutinefunction(handler): - response.grab_aborted_handler() - response.run_async(handler(ws, data, opcode)) + app.run_async(handler(ws, data, opcode)) else: handler(ws, data, opcode) @@ -359,8 +355,7 @@ def uws_websocket_pong_handler(ws, message, length, user_data): handler = handlers.pong if inspect.iscoroutinefunction(handler): - response.grab_aborted_handler() - response.run_async(handler(ws, data)) + app.run_async(handler(ws, data)) else: handler(ws, data) @@ -381,8 +376,7 @@ def uws_websocket_ping_handler(ws, message,length, user_data): handler = handlers.ping if inspect.iscoroutinefunction(handler): - response.grab_aborted_handler() - response.run_async(handler(ws, data)) + app.run_async(handler(ws, data)) else: handler(ws, data) @@ -396,7 +390,7 @@ def uws_websocket_close_handler(ws, code, message, length, user_data): try: (handlers, app) = ffi.from_handle(user_data) #pass to free data on WebSocket if needed - ws = WebSocket(ws, app.SSL, app.loop, True) + ws = WebSocket(ws, app.SSL, app.loop) if message == ffi.NULL: data = None @@ -407,11 +401,21 @@ def uws_websocket_close_handler(ws, code, message, length, user_data): if handler is None: return + + if inspect.iscoroutinefunction(handler): - response.grab_aborted_handler() - response.run_async(handler(ws, int(code), data)) + future = app.run_async(handler(ws, int(code), data)) + def when_finished(_): + key = ws.get_user_data_uuid() + if not key is None: + SocketRefs.pop(key, None) + future.add_done_callback(when_finished) else: handler(ws, int(code), data) + key = ws.get_user_data_uuid() + if not key is None: + SocketRefs.pop(key, None) + except Exception as err: print("Uncaught Exception: %s" % str(err)) #just log in console the error to call attention @@ -426,8 +430,7 @@ def uws_websocket_upgrade_handler(res, req, context, user_data): request = AppRequest(req) handler = handlers.upgrade if inspect.iscoroutinefunction(handler): - response.grab_aborted_handler() - response.run_async(handler(response, request, context)) + app.run_async(handler(response, request, context)) else: handler(response, request, context) @@ -589,14 +592,13 @@ class SendStatus(IntEnum): SocketRefs = {} class WebSocket: - def __init__(self, websocket, ssl, loop, free_socket_data=False): + def __init__(self, websocket, ssl, loop): self.ws = websocket self.SSL = ssl self._ptr = ffi.new_handle(self) self.loop = loop self._cork_handler = None self._for_each_topic_handler = None - self.free_socket_data = free_socket_data self.socket_data_id = None self.socket_data = None self.got_socket_data = False @@ -613,12 +615,12 @@ class WebSocket: #uuid for socket data, used to free data after socket closes def get_user_data_uuid(self): - if self.got_socket_data: - return self.socket_data_id - user_data = lib.uws_ws_get_user_data(self.SSL, self.ws) - if user_data == ffi.NULL: - return None try: + if self.got_socket_data: + return self.socket_data_id + user_data = lib.uws_ws_get_user_data(self.SSL, self.ws) + if user_data == ffi.NULL: + return None (data, socket_data_id) = ffi.from_handle(user_data) self.socket_data_id = socket_data_id self.socket_data = data @@ -628,12 +630,12 @@ class WebSocket: return None def get_user_data(self): - if self.got_socket_data: - return self.socket_data - user_data = lib.uws_ws_get_user_data(self.SSL, self.ws) - if user_data == ffi.NULL: - return None try: + if self.got_socket_data: + return self.socket_data + user_data = lib.uws_ws_get_user_data(self.SSL, self.ws) + if user_data == ffi.NULL: + return None (data, socket_data_id) = ffi.from_handle(user_data) self.socket_data_id = socket_data_id self.socket_data = data @@ -838,12 +840,6 @@ class WebSocket: lib.uws_ws_cork(self.SSL, self.ws, uws_ws_cork_handler, self._ptr) def __del__(self): - #free SocketRefs when if needed - if self.free_socket_data: - key = self.get_user_data_uuid() - if not key is None: - SocketRefs.pop(key, None) - self.ws = ffi.NULL self._ptr = ffi.NULL