From 6ae34c3b770a03cc2e21d7c07848c600c8029543 Mon Sep 17 00:00:00 2001 From: Ciro Date: Fri, 6 Jan 2023 16:11:19 -0300 Subject: [PATCH] fixes, json serializer, wip extensions --- README.md | 9 +- bench/asgi_wsgi/falcon-asgi.py | 7 +- docs/README.md | 3 +- docs/_sidebar.md | 3 +- docs/api.md | 7 +- docs/basics.md | 16 + docs/extensions.md | 55 ++ docs/graphiql.md | 2 +- docs/websockets-backpressure.md | 2 +- examples/custom_json_serializer.py | 16 + pyproject.toml | 2 +- setup.py | 2 +- src/socketify/__init__.py | 3 +- src/socketify/asgi.py | 3 - src/socketify/socketify.py | 1369 +++++++++++++++++----------- src/socketify/wsgi.py | 3 - src/tests.py | 31 +- 17 files changed, 981 insertions(+), 552 deletions(-) create mode 100644 docs/extensions.md create mode 100644 examples/custom_json_serializer.py diff --git a/README.md b/README.md index 6faf968..65a92d4 100644 --- a/README.md +++ b/README.md @@ -34,10 +34,11 @@ - Max Backpressure, Max Timeout, Max Payload and Idle Timeout Support - Automatic Ping / Pong Support - Per Socket Data -- Middlewares -- Templates Support (examples with [`Mako`](https://github.com/cirospaciari/socketify.py/tree/main/examples/template_mako.py) and [`Jinja2`](https://github.com/cirospaciari/socketify.py/tree/main/examples/template_jinja2.py)) -- ASGI Server with pub/sub extension for Falcon -- WSGI Server +- [`Middlewares`](https://docs.socketify.dev/middlewares.html) +- [`Templates`](https://docs.socketify.dev/templates.html) Support (examples with [`Mako`](https://github.com/cirospaciari/socketify.py/tree/main/examples/template_mako.py) and [`Jinja2`](https://github.com/cirospaciari/socketify.py/tree/main/examples/template_jinja2.py)) +- [`ASGI Server`](https://docs.socketify.dev/cli.html) +- [`WSGI Server`](https://docs.socketify.dev/cli.html) +- [`Plugins/Extensions`](https://docs.socketify.dev/extensions.html) ## :mag_right: Upcoming Features - In-Memory Cache Tools diff --git a/bench/asgi_wsgi/falcon-asgi.py b/bench/asgi_wsgi/falcon-asgi.py index b41ae44..71ab93b 100644 --- a/bench/asgi_wsgi/falcon-asgi.py +++ b/bench/asgi_wsgi/falcon-asgi.py @@ -8,12 +8,11 @@ class Home: resp.content_type = falcon.MEDIA_TEXT # Default is JSON, so override resp.text = "Hello, World!" async def on_post(self, req, resp): - # curl -d '{"key1":"value1", "key2":"value2"}' -H "Content-Type: application/json" -X POST http://localhost:8000/ - raw_data = await req.stream.read() - print("data", raw_data) + # curl -d '{"name":"test"}' -H "Content-Type: application/json" -X POST http://localhost:8000/ + json = await req.media resp.status = falcon.HTTP_200 # This is the default status resp.content_type = falcon.MEDIA_TEXT # Default is JSON, so override - resp.text = raw_data + resp.text = json.get("name", "") diff --git a/docs/README.md b/docs/README.md index 567e4cc..8912547 100644 --- a/docs/README.md +++ b/docs/README.md @@ -35,5 +35,6 @@ With no precedents websocket performance and an really fast HTTP server that can - [GraphiQL](graphiql.md) - [WebSockets and Backpressure](websockets-backpressure.md) - [SSL](ssl.md) -- [CLI Reference](cli.md) +- [Plugins / Extensions](extensions.md) +- [CLI, ASGI and WSGI](cli.md) - [API Reference](api.md) diff --git a/docs/_sidebar.md b/docs/_sidebar.md index 5343c67..b04f22d 100644 --- a/docs/_sidebar.md +++ b/docs/_sidebar.md @@ -13,5 +13,6 @@ - [GraphiQL](graphiql.md) - [WebSockets and Backpressure](websockets-backpressure.md) - [SSL](ssl.md) -- [CLI Reference](cli.md) +- [Plugins / Extensions](extensions.md) +- [CLI, ASGI and WSGI](cli.md) - [API Reference](api.md) diff --git a/docs/api.md b/docs/api.md index 714f820..8254ff4 100644 --- a/docs/api.md +++ b/docs/api.md @@ -4,6 +4,7 @@ class App: def __init__(self, options=None): def template(self, template_engine): + def json_serializer(self, json_serializer): def static(self, route, directory): def get(self, path, handler): def post(self, path, handler): @@ -34,7 +35,7 @@ class App: ## AppResponse ```python class AppResponse: - def __init__(self, response, loop, ssl, render=None): + def __init__(self, response, app): def cork(self, callback): def set_cookie(self, name, value, options={}): def run_async(self, task): @@ -81,7 +82,7 @@ class AppResponse: ## AppRequest ```python class AppRequest: - def __init__(self, request): + def __init__(self, request, app): def get_cookie(self, name): def get_url(self): def get_full_url(self): @@ -123,7 +124,7 @@ class AppOptions: ```python class WebSocket: - def __init__(self, websocket, ssl, loop): + def __init__(self, websocket, app): # uuid for socket data, used to free data after socket closes def get_user_data_uuid(self): diff --git a/docs/basics.md b/docs/basics.md index 22d96a8..02a0864 100644 --- a/docs/basics.md +++ b/docs/basics.md @@ -182,6 +182,22 @@ def route_handler(res, req): res.run_async(sendfile(res, req, "my_text")) ``` + +## Using ujson, orjson or any custom JSON serializer +socketify by default uses built in `json` module with have great performance on PyPy, but if you wanna to use another module instead of the default you can just register using `app.json_serializer(module)` + +```python +from socketify import App +import ujson +app = App() + +# set json serializer to ujson +# json serializer must have dumps and loads functions +app.json_serializer(ujson) + +app.get("/", lambda res, req: res.end({"Hello":"World!"})) +``` + ## Raw socket pointer If for some reason you need the raw socket pointer you can use `res.get_native_handle()` and will get an CFFI handler. diff --git a/docs/extensions.md b/docs/extensions.md new file mode 100644 index 0000000..709eb85 --- /dev/null +++ b/docs/extensions.md @@ -0,0 +1,55 @@ + +# Plugins / Extensions + +You can add more functionality to request, response, and websocket objects, for this you can use `app.register(extension)` to register an extension. +Be aware that using extensions can have a performance impact and using it with `request_response_factory_max_items`, `websocket_factory_max_items` +or the equivalent on CLI `--req-res-factory-maxitems`, `--ws-factory-maxitems` will reduce this performance impact. + +Extensions must follow the signature `def extension(request, response, ws)`, request, response, and ws objects contain `method` decorator that binds a method to an instance, +and also a `property(name: str, default_value: any = None)` that dynamic adds an property to the instance. + +```python +from socketify import App, OpCode + +app = App() + +def extension(request, response, ws): + @request.method + async def get_user(self): + token = self.get_header("token") + return { "name": "Test" } if token else { "name", "Anonymous" } + + @response.method + def msgpack(self, value: any): + self.write_header(b'Content-Type', b'application/msgpack') + data = msgpack.packb(value, default=encode_datetime, use_bin_type=True) + return self.end(data) + + @ws.method + def send_pm(self, to_username: str, message: str): + user_data = self.get_user_data() + pm_topic = f"pm-{to_username}+{user_data.username}" + + # if topic exists just send the message + if app.num_subscribers(pm_topic) > 0: + # send private message + return self.publish(pm_topic, message, OpCode.TEXT) + + # if the topic not exists create it and signal the user + # subscribe to the conversation + self.subscribe(pm_topic) + # signal user that you want to talk and create an pm room + # all users must subscribe to signal-{username} + self.publish(f"signal-{to_username}", { + "type": "pm", + "username": user_data.username, + "message": message + }, OpCode.TEXT) + # this property can be used on extension methods and/or middlewares + request.property("cart", []) + +# extensions must be registered before routes +app.register(extension) +``` + +### Next [CLI, ASGI and WSGI](cli.md) \ No newline at end of file diff --git a/docs/graphiql.md b/docs/graphiql.md index 3bf485f..f41803c 100644 --- a/docs/graphiql.md +++ b/docs/graphiql.md @@ -1,5 +1,5 @@ ## GraphiQL Support -In /src/examples/helper/graphiql.py we implemented an helper for using graphiQL with strawberry. +In [`/src/examples/helper/graphiql.py`](https://github.com/cirospaciari/socketify.py/blob/main/examples/graphiql.py) we implemented an helper for using graphiQL with strawberry. ### Usage ```python diff --git a/docs/websockets-backpressure.md b/docs/websockets-backpressure.md index e06a49b..38ae1f5 100644 --- a/docs/websockets-backpressure.md +++ b/docs/websockets-backpressure.md @@ -49,4 +49,4 @@ You probably want shared compressor if dealing with larger JSON messages, or 4kb idle_timeout is roughly the amount of seconds that may pass between messages. Being idle for more than this, and the connection is severed. This means you should make your clients send small ping messages every now and then, to keep the connection alive. The server will automatically send pings in case it needs to. -### Next [SSL](ssl.md) \ No newline at end of file +### Next [Plugins / Extensions](extensions.md) \ No newline at end of file diff --git a/examples/custom_json_serializer.py b/examples/custom_json_serializer.py new file mode 100644 index 0000000..e55bd89 --- /dev/null +++ b/examples/custom_json_serializer.py @@ -0,0 +1,16 @@ +from socketify import App +import ujson + +app = App() + + +# set json serializer to ujson +# json serializer must have dumps and loads functions +app.json_serializer(ujson) + +app.get("/", lambda res, req: res.end({"Hello":"World!"})) +app.listen( + 3000, + lambda config: print("Listening on port http://localhost:%d now\n" % config.port), +) +app.run() diff --git a/pyproject.toml b/pyproject.toml index 92c0010..0867bcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "socketify" -version = "0.0.3" +version = "0.0.4" authors = [ { name="Ciro Spaciari", email="ciro.spaciari@gmail.com" }, ] diff --git a/setup.py b/setup.py index 76f3724..98f2d59 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ with open("README.md", "r", encoding="utf-8") as fh: setuptools.setup( name="socketify", - version="0.0.3", + version="0.0.4", platforms=["any"], author="Ciro Spaciari", author_email="ciro.spaciari@gmail.com", diff --git a/src/socketify/__init__.py b/src/socketify/__init__.py index 904eee1..6d7a01e 100644 --- a/src/socketify/__init__.py +++ b/src/socketify/__init__.py @@ -5,7 +5,8 @@ from .socketify import ( OpCode, SendStatus, CompressOptions, - Loop + Loop, + AppExtension ) from .asgi import ( ASGI diff --git a/src/socketify/asgi.py b/src/socketify/asgi.py index 2c70a04..85acaab 100644 --- a/src/socketify/asgi.py +++ b/src/socketify/asgi.py @@ -374,7 +374,6 @@ def uws_asgi_corked_response_start_handler(res, user_data): lib.socketify_res_write_int_status(ssl, res, int(status)) for name, value in headers: write_header(ssl, res, name, value) - write_header(ssl, res, b"Server", b"socketify.py") @ffi.callback("void(uws_res_t*, void*)") @@ -384,7 +383,6 @@ def uws_asgi_corked_accept_handler(res, user_data): lib.socketify_res_write_int_status(ssl, res, int(status)) for name, value in headers: write_header(ssl, res, name, value) - write_header(ssl, res, b"Server", b"socketify.py") @ffi.callback("void(uws_res_t*, void*)") @@ -392,7 +390,6 @@ def uws_asgi_corked_ws_accept_handler(res, user_data): (ssl, headers) = ffi.from_handle(user_data) for name, value in headers: write_header(ssl, res, name, value) - write_header(ssl, res, b"Server", b"socketify.py") @ffi.callback("void(uws_res_t*, void*)") diff --git a/src/socketify/socketify.py b/src/socketify/socketify.py index 688015e..5742170 100644 --- a/src/socketify/socketify.py +++ b/src/socketify/socketify.py @@ -1,4 +1,3 @@ -import cffi from datetime import datetime from enum import IntEnum from http import cookies @@ -529,12 +528,38 @@ def uws_websocket_factory_upgrade_handler(res, req, context, user_data): if dispose: app._factory.dispose(instances) +@ffi.callback("void(uws_res_t*, uws_req_t*, uws_socket_context_t*, void*)") +def uws_websocket_upgrade_handler_with_extension(res, req, context, user_data): + if user_data != ffi.NULL: + handlers, app = ffi.from_handle(user_data) + response = AppResponse(res, app) + # set default value in properties + app._response_extension.set_properties(response) + # bind methods to response + app._response_extension.bind_methods(response) + request = AppRequest(req, app) + # set default value in properties + app._request_extension.set_properties(request) + # bind methods to request + app._request_extension.bind_methods(request) + + try: + handler = handlers.upgrade + if inspect.iscoroutinefunction(handler): + response.run_async(handler(response, request, context)) + else: + handler(response, request, context) + + except Exception as err: + response.grab_aborted_handler() + app.trigger_error(err, response, request) + @ffi.callback("void(uws_res_t*, uws_req_t*, uws_socket_context_t*, void*)") def uws_websocket_upgrade_handler(res, req, context, user_data): if user_data != ffi.NULL: handlers, app = ffi.from_handle(user_data) - response = AppResponse(res, app.loop, app.SSL, app._template, app._socket_refs) - request = AppRequest(req) + response = AppResponse(res, app) + request = AppRequest(req, app) try: handler = handlers.upgrade if inspect.iscoroutinefunction(handler): @@ -603,13 +628,38 @@ def uws_generic_factory_method_handler(res, req, user_data): app.trigger_error(err, response, request) if dispose: app._factory.dispose(instances) + +@ffi.callback("void(uws_res_t*, uws_req_t*, void*)") +def uws_generic_method_handler_with_extension(res, req, user_data): + if user_data != ffi.NULL: + (handler, app) = ffi.from_handle(user_data) + response = AppResponse(res, app) + # set default value in properties + app._response_extension.set_properties(response) + # bind methods to response + app._response_extension.bind_methods(response) + request = AppRequest(req, app) + # set default value in properties + app._request_extension.set_properties(request) + # bind methods to request + app._request_extension.bind_methods(request) + try: + if inspect.iscoroutinefunction(handler): + response.grab_aborted_handler() + response.run_async(handler(response, request)) + else: + handler(response, request) + except Exception as err: + response.grab_aborted_handler() + app.trigger_error(err, response, request) + @ffi.callback("void(uws_res_t*, uws_req_t*, void*)") def uws_generic_method_handler(res, req, user_data): if user_data != ffi.NULL: (handler, app) = ffi.from_handle(user_data) - response = AppResponse(res, app.loop, app.SSL, app._template, app._socket_refs) - request = AppRequest(req) + response = AppResponse(res, app) + request = AppRequest(req, app) try: if inspect.iscoroutinefunction(handler): @@ -889,7 +939,7 @@ class WebSocket: # elif message is None: # data = b"" # else: - # data = json.dumps(message).encode("utf-8") + # data = self.app._json_serializer.dumps(message).encode("utf-8") # return bool( # lib.uws_ws_publish_with_options( @@ -953,7 +1003,7 @@ class WebSocket: lib.uws_ws_send_fragment(self.app.SSL, self.ws, b"", 0, compress) return self else: - data = json.dumps(message).encode("utf-8") + data = self.app._json_serializer.dumps(message).encode("utf-8") return SendStatus( lib.uws_ws_send_fragment(self.app.SSL, self.ws, data, len(data), compress) @@ -971,7 +1021,7 @@ class WebSocket: lib.uws_ws_send_last_fragment(self.app.SSL, self.ws, b"", 0, compress) return self else: - data = json.dumps(message).encode("utf-8") + data = self.app._json_serializer.dumps(message).encode("utf-8") return SendStatus( lib.uws_ws_send_last_fragment( @@ -993,7 +1043,7 @@ class WebSocket: ) return self else: - data = json.dumps(message).encode("utf-8") + data = self.app._json_serializer.dumps(message).encode("utf-8") return SendStatus( lib.uws_ws_send_first_fragment_with_opcode( @@ -1019,7 +1069,7 @@ class WebSocket: ) return self else: - data = json.dumps(message).encode("utf-8") + data = self.app._json_serializer.dumps(message).encode("utf-8") return SendStatus( lib.uws_ws_send_with_options( @@ -1045,7 +1095,7 @@ class WebSocket: lib.uws_ws_end(self.app.SSL, self.ws, b"", 0) return self else: - data = json.dumps(message).encode("utf-8") + data = self.app._json_serializer.dumps(message).encode("utf-8") lib.uws_ws_end(self.app.SSL, self.ws, code, data, len(data)) finally: @@ -1063,6 +1113,495 @@ class WebSocket: self.ws = ffi.NULL self._ptr = ffi.NULL +class AppResponse: + def __init__(self, response, app): + self.res = response + self.app = app + self.aborted = False + self._aborted_handler = None + self._writable_handler = None + self._data_handler = None + self._ptr = ffi.new_handle(self) + self._grabbed_abort_handler_once = False + self._write_jar = None + self._cork_handler = None + self._lastChunkOffset = 0 + self._chunkFuture = None + self._dataFuture = None + self._data = None + + def cork(self, callback): + if not self.aborted: + self.grab_aborted_handler() + self._cork_handler = callback + lib.uws_res_cork(self.app.SSL, self.res, uws_generic_cork_handler, self._ptr) + + def set_cookie(self, name, value, options): + if options is None: + options = {} + if self._write_jar is None: + self._write_jar = cookies.SimpleCookie() + self._write_jar[name] = quote_plus(value) + if isinstance(options, dict): + for key in options: + if key == "expires" and isinstance(options[key], datetime): + self._write_jar[name][key] = options[key].strftime( + "%a, %d %b %Y %H:%M:%S GMT" + ) + else: + self._write_jar[name][key] = options[key] + + def trigger_aborted(self): + self.aborted = True + self._ptr = ffi.NULL + self.res = ffi.NULL + if hasattr(self, "_aborted_handler") and hasattr( + self._aborted_handler, "__call__" + ): + try: + if inspect.iscoroutinefunction(self._aborted_handler): + self.run_async(self._aborted_handler(self)) + else: + self._aborted_handler(self) + except Exception as err: + logging.error("Error on abort handler %s" % str(err)) + return self + + def trigger_data_handler(self, data, is_end): + if self.aborted: + return self + if hasattr(self, "_data_handler") and hasattr(self._data_handler, "__call__"): + try: + if inspect.iscoroutinefunction(self._data_handler): + self.run_async(self._data_handler(self, data, is_end)) + else: + self._data_handler(self, data, is_end) + except Exception as err: + logging.error("Error on data handler %s" % str(err)) + + return self + + def trigger_writable_handler(self, offset): + if self.aborted: + return False + if hasattr(self, "_writable_handler") and hasattr( + self._writable_handler, "__call__" + ): + try: + if inspect.iscoroutinefunction(self._writable_handler): + raise RuntimeError("AppResponse.on_writable must be synchronous") + return self._writable_handler(self, offset) + except Exception as err: + logging.error("Error on writable handler %s" % str(err)) + return False + return False + + def run_async(self, task): + self.grab_aborted_handler() + return self.app.loop.run_async(task, self) + + async def get_form_urlencoded(self, encoding="utf-8"): + data = await self.get_data() + try: + # decode and unquote all + result = {} + parsed = parse_qs(data.getvalue(), encoding=encoding) + has_value = False + for key in parsed: + has_value = True + try: + value = parsed[key] + new_key = key.decode(encoding) + last_value = value[len(value) - 1] + + result[new_key] = unquote_plus(last_value.decode(encoding)) + except Exception as error: + pass + return result if has_value else None + except Exception as error: + return None # invalid encoding + + async def get_text(self, encoding="utf-8"): + data = await self.get_data() + try: + return data.getvalue().decode(encoding) + except Exception: + return None # invalid encoding + + async def get_json(self): + data = await self.get_data() + try: + return self.app._json_serializer.loads(data.getvalue().decode("utf-8")) + except Exception: + return None # invalid json + + def send_chunk(self, buffer, total_size): + self._chunkFuture = self.app.loop.create_future() + self._lastChunkOffset = 0 + + def is_aborted(self): + self.aborted = True + try: + if not self._chunkFuture.done(): + self._chunkFuture.set_result( + (False, True) + ) # if aborted set to done True and ok False + except: + pass + + def on_writeble(self, offset): + # Here the timeout is off, we can spend as much time before calling try_end we want to + (ok, done) = self.try_end( + buffer[offset - self._lastChunkOffset : :], total_size + ) + if ok: + self._chunkFuture.set_result((ok, done)) + return ok + + self.on_writable(on_writeble) + self.on_aborted(is_aborted) + + if self.aborted: + self._chunkFuture.set_result( + (False, True) + ) # if aborted set to done True and ok False + return self._chunkFuture + + (ok, done) = self.try_end(buffer, total_size) + if ok: + self._chunkFuture.set_result((ok, done)) + return self._chunkFuture + # failed to send chunk + self._lastChunkOffset = self.get_write_offset() + + return self._chunkFuture + + def get_data(self): + self._dataFuture = self.app.loop.create_future() + self._data = BytesIO() + + def is_aborted(self): + self.aborted = True + try: + if not self._dataFuture.done(): + self._dataFuture.set_result(self._data) + except: + pass + + def get_chunks(self, chunk, is_end): + self._data.write(chunk) + if is_end: + self._dataFuture.set_result(self._data) + self._data = None + + self.on_aborted(is_aborted) + self.on_data(get_chunks) + return self._dataFuture + + def grab_aborted_handler(self): + # only needed if is async + if not self.aborted and not self._grabbed_abort_handler_once: + self._grabbed_abort_handler_once = True + lib.uws_res_on_aborted( + self.app.SSL, self.res, uws_generic_aborted_handler, self._ptr + ) + return self + + def redirect(self, location, status_code=302): + self.write_status(status_code) + self.write_header("Location", location) + self.end_without_body(False) + return self + + def write_offset(self, offset): + lib.uws_res_override_write_offset( + self.app.SSL, self.res, ffi.cast("uintmax_t", offset) + ) + return self + + def try_end(self, message, total_size, end_connection=False): + try: + if self.aborted: + return False, True + if self._write_jar is not None: + self.write_header("Set-Cookie", self._write_jar.output(header="")) + self._write_jar = None + if isinstance(message, str): + data = message.encode("utf-8") + elif isinstance(message, bytes): + data = message + else: + return False, True + result = lib.uws_res_try_end( + self.app.SSL, + self.res, + data, + len(data), + ffi.cast("uintmax_t", total_size), + 1 if end_connection else 0, + ) + return bool(result.ok), bool(result.has_responded) + except: + return False, True + + def cork_end(self, message, end_connection=False): + self.cork(lambda res: res.end(message, end_connection)) + return self + + def render(self, *args, **kwargs): + if self.app._template: + def render(res): + res.write_header(b'Content-Type', b'text/html') + res.end(self.app._template.render(*args, **kwargs)) + self.cork(render) + return self + raise RuntimeError("No registered templated engine") + + def get_remote_address_bytes(self): + buffer = ffi.new("char**") + length = lib.uws_res_get_remote_address(self.app.SSL, self.res, buffer) + buffer_address = ffi.addressof(buffer, 0)[0] + if buffer_address == ffi.NULL: + return None + try: + return ffi.unpack(buffer_address, length) + except Exception: # invalid + return None + + def get_remote_address(self): + buffer = ffi.new("char**") + length = lib.uws_res_get_remote_address_as_text(self.app.SSL, self.res, buffer) + buffer_address = ffi.addressof(buffer, 0)[0] + if buffer_address == ffi.NULL: + return None + try: + return ffi.unpack(buffer_address, length).decode("utf-8") + except Exception: # invalid utf-8 + return None + + def get_proxied_remote_address_bytes(self): + buffer = ffi.new("char**") + length = lib.uws_res_get_proxied_remote_address(self.app.SSL, self.res, buffer) + buffer_address = ffi.addressof(buffer, 0)[0] + if buffer_address == ffi.NULL: + return None + try: + return ffi.unpack(buffer_address, length) + except Exception: # invalid + return None + + def get_proxied_remote_address(self): + buffer = ffi.new("char**") + length = lib.uws_res_get_proxied_remote_address_as_text( + self.app.SSL, self.res, buffer + ) + buffer_address = ffi.addressof(buffer, 0)[0] + if buffer_address == ffi.NULL: + return None + try: + return ffi.unpack(buffer_address, length).decode("utf-8") + except Exception: # invalid utf-8 + return None + + def end(self, message, end_connection=False): + try: + if self.aborted: + return self + if self._write_jar is not None: + self.write_header("Set-Cookie", self._write_jar.output(header="")) + self._write_jar = None + if isinstance(message, str): + data = message.encode("utf-8") + elif isinstance(message, bytes): + data = message + elif message is None: + self.end_without_body(end_connection) + return self + else: + self.write_header(b"Content-Type", b"application/json") + data = self.app._json_serializer.dumps(message).encode("utf-8") + lib.uws_res_end( + self.app.SSL, self.res, data, len(data), 1 if end_connection else 0 + ) + finally: + return self + + def pause(self): + if not self.aborted: + lib.uws_res_pause(self.app.SSL, self.res) + return self + + def resume(self): + if not self.aborted: + lib.uws_res_resume(self.app.SSL, self.res) + return self + + def write_continue(self): + if not self.aborted: + lib.uws_res_write_continue(self.app.SSL, self.res) + return self + + def write_status(self, status_or_status_text): + if not self.aborted: + if isinstance(status_or_status_text, int): + if bool(lib.socketify_res_write_int_status(self.app.SSL, self.res, status_or_status_text)): + return self + raise RuntimeError( + '"%d" Is not an valid Status Code' % status_or_status_text + ) + + elif isinstance(status_or_status_text, str): + data = status_or_status_text.encode("utf-8") + elif isinstance(status_or_status_text, bytes): + data = status_or_status_text + else: + data = self.app._json_serializer.dumps(status_or_status_text).encode("utf-8") + + lib.uws_res_write_status(self.app.SSL, self.res, data, len(data)) + return self + + def write_header(self, key, value): + if not self.aborted: + if isinstance(key, str): + key_data = key.encode("utf-8") + elif isinstance(key, bytes): + key_data = key + else: + key_data = self.app._json_serializer.dumps(key).encode("utf-8") + + if isinstance(value, int): + lib.uws_res_write_header_int( + self.app.SSL, + self.res, + key_data, + len(key_data), + ffi.cast("uint64_t", value), + ) + elif isinstance(value, str): + value_data = value.encode("utf-8") + elif isinstance(value, bytes): + value_data = value + else: + value_data = self.app._json_serializer.dumps(value).encode("utf-8") + lib.uws_res_write_header( + self.app.SSL, self.res, key_data, len(key_data), value_data, len(value_data) + ) + return self + + def end_without_body(self, end_connection=False): + if not self.aborted: + if self._write_jar is not None: + self.write_header("Set-Cookie", self._write_jar.output(header="")) + lib.uws_res_end_without_body(self.app.SSL, self.res, 1 if end_connection else 0) + return self + + def write(self, message): + if not self.aborted: + if isinstance(message, str): + data = message.encode("utf-8") + elif isinstance(message, bytes): + data = message + else: + data = self.app._json_serializer.dumps(message).encode("utf-8") + lib.uws_res_write(self.app.SSL, self.res, data, len(data)) + return self + + def get_write_offset(self): + if not self.aborted: + return int(lib.uws_res_get_write_offset(self.app.SSL, self.res)) + return 0 + + def has_responded(self): + if self.aborted: + return False + return bool(lib.uws_res_has_responded(self.app.SSL, self.res)) + + def on_aborted(self, handler): + if hasattr(handler, "__call__"): + self._aborted_handler = handler + self.grab_aborted_handler() + return self + + def on_data(self, handler): + if not self.aborted: + if hasattr(handler, "__call__"): + self._data_handler = handler + self.grab_aborted_handler() + lib.uws_res_on_data( + self.app.SSL, self.res, uws_generic_on_data_handler, self._ptr + ) + return self + + def upgrade( + self, + sec_web_socket_key, + sec_web_socket_protocol, + sec_web_socket_extensions, + socket_context, + user_data=None, + ): + if self.aborted: + return False + + if isinstance(sec_web_socket_key, str): + sec_web_socket_key_data = sec_web_socket_key.encode("utf-8") + elif isinstance(sec_web_socket_key, bytes): + sec_web_socket_key_data = sec_web_socket_key + else: + sec_web_socket_key_data = b"" + + if isinstance(sec_web_socket_protocol, str): + sec_web_socket_protocol_data = sec_web_socket_protocol.encode("utf-8") + elif isinstance(sec_web_socket_protocol, bytes): + sec_web_socket_protocol_data = sec_web_socket_protocol + else: + sec_web_socket_protocol_data = b"" + + if isinstance(sec_web_socket_extensions, str): + sec_web_socket_extensions_data = sec_web_socket_extensions.encode("utf-8") + elif isinstance(sec_web_socket_extensions, bytes): + sec_web_socket_extensions_data = sec_web_socket_extensions + else: + sec_web_socket_extensions_data = b"" + + user_data_ptr = ffi.NULL + if user_data is not None: + _id = uuid.uuid4() + user_data_ptr = ffi.new_handle((user_data, _id)) + # keep alive data + self.app._socket_refs[_id] = user_data_ptr + + lib.uws_res_upgrade( + self.app.SSL, + self.res, + user_data_ptr, + sec_web_socket_key_data, + len(sec_web_socket_key_data), + sec_web_socket_protocol_data, + len(sec_web_socket_protocol_data), + sec_web_socket_extensions_data, + len(sec_web_socket_extensions_data), + socket_context, + ) + return True + + def on_writable(self, handler): + if not self.aborted: + if hasattr(handler, "__call__"): + self._writable_handler = handler + self.grab_aborted_handler() + lib.uws_res_on_writable( + self.app.SSL, self.res, uws_generic_on_writable_handler, self._ptr + ) + return self + + def get_native_handle(self): + return lib.uws_res_get_native_handle(self.app.SSL, self.res) + + def __del__(self): + self.res = ffi.NULL + self._ptr = ffi.NULL + class WSBehaviorHandlers: def __init__(self): @@ -1079,10 +1618,51 @@ class WSBehaviorHandlers: class WebSocketFactory: def __init__(self, app, max_size): self.factory_queue = [] - for _ in range(0, max_size): - websocket = WebSocket(None, app) + self.app = app + self.max_size = max_size + self.dispose = self._dispose + self.populate = self._populate + self.get = self._get + + def update_extensions(self): + self.populate = self._populate_with_extension + self.get = self._get_with_extension + if len(self.app._ws_extension.properties) > 0: + self.dispose = self._dispose_with_extension + + def _populate_with_extension(self): + self.factory_queue = [] + for _ in range(0, self.max_size): + websocket = WebSocket(None, self.app) + # bind methods to websocket + self.app._ws_extension.set_properties(websocket) + # set default value in properties + self.app._ws_extension.bind_methods(websocket) self.factory_queue.append((websocket, True)) - def get(self, app, ws): + + def _populate(self): + self.factory_queue = [] + for _ in range(0, self.max_size): + websocket = WebSocket(None, self.app) + self.factory_queue.append((websocket, True)) + + + def _get_with_extension(self, app, ws): + if len(self.factory_queue) == 0: + websocket = WebSocket(ws, app) + # bind methods to websocket + self.app._ws_extension.set_properties(websocket) + # set default value in properties + self.app._ws_extension.bind_methods(websocket) + return websocket, False + + instances = self.factory_queue.pop() + (websocket, _) = instances + websocket.ws = ws + return instances + + + def _get(self, app, ws): if len(self.factory_queue) == 0: response = WebSocket(ws, app) return response, False @@ -1092,7 +1672,20 @@ class WebSocketFactory: websocket.ws = ws return instances - def dispose(self, instances): + def _dispose_with_extension(self, instances): + (websocket, _) = instances + #dispose ws + websocket.ws = None + websocket._cork_handler = None + websocket._for_each_topic_handler = None + websocket.socket_data_id = None + websocket.socket_data = None + websocket.got_socket_data = False + # set default value in properties + self.app._ws_extension.set_properties(websocket) + self.factory_queue.append(instances) + + def _dispose(self, instances): (websocket, _) = instances #dispose ws websocket.ws = None @@ -1106,25 +1699,104 @@ class WebSocketFactory: class RequestResponseFactory: def __init__(self, app, max_size): self.factory_queue = [] - for _ in range(0, max_size): - response = AppResponse(None, app.loop, app.SSL, app._template, app._socket_refs) - request = AppRequest(None) + self.app = app + self.max_size = max_size + self.dispose = self._dispose + self.populate = self._populate + self.get = self._get + + def update_extensions(self): + self.dispose = self._dispose_with_extension + self.populate = self._populate_with_extension + self.get = self._get_with_extension + + def _populate_with_extension(self): + self.factory_queue = [] + for _ in range(0, self.max_size): + response = AppResponse(None, self.app) + # set default value in properties + self.app._response_extension.set_properties(response) + # bind methods to response + self.app._response_extension.bind_methods(response) + request = AppRequest(None, self.app) + # set default value in properties + self.app._request_extension.set_properties(request) + # bind methods to request + self.app._request_extension.bind_methods(request) self.factory_queue.append((response, request, True)) - def get(self, app, res, req): + def _populate(self): + self.factory_queue = [] + for _ in range(0, self.max_size): + response = AppResponse(None, self.app) + request = AppRequest(None, self.app) + self.factory_queue.append((response, request, True)) + + def _get_with_extension(self, app, res, req): if len(self.factory_queue) == 0: - response = AppResponse(res, app.loop, app.SSL, app._template, app._socket_refs) - request = AppRequest(req) + response = AppResponse(res, app) + # set default value in properties + self.app._response_extension.set_properties(response) + # bind methods to response + self.app._response_extension.bind_methods(response) + request = AppRequest(req, app) + # set default value in properties + self.app._request_extension.set_properties(request) + # bind methods to request + self.app._request_extension.bind_methods(request) return response, request, False instances = self.factory_queue.pop() (response, request, _) = instances response.res = res - response._render = app._template request.req = req return instances - def dispose(self, instances): + def _get(self, app, res, req): + if len(self.factory_queue) == 0: + response = AppResponse(res, app) + request = AppRequest(req, app) + return response, request, False + + instances = self.factory_queue.pop() + (response, request, _) = instances + response.res = res + request.req = req + return instances + + def _dispose_with_extension(self, instances): + (res, req, _) = instances + #dispose res + res.res = None + res.aborted = False + res._aborted_handler = None + res._writable_handler = None + res._data_handler = None + res._grabbed_abort_handler_once = False + res._write_jar = None + res._cork_handler = None + res._lastChunkOffset = 0 + res._chunkFuture = None + res._dataFuture = None + res._data = None + # set default value in properties + self.app._response_extension.set_properties(res) + #dispose req + req.req = None + req.read_jar = None + req.jar_parsed = False + req._for_each_header_handler = None + req._headers = None + req._params = None + req._query = None + req._url = None + req._full_url = None + req._method = None + # set default value in properties + self.app._request_extension.set_properties(req) + self.factory_queue.append(instances) + + def _dispose(self, instances): (res, req, _) = instances #dispose res res.res = None @@ -1139,7 +1811,6 @@ class RequestResponseFactory: res._chunkFuture = None res._dataFuture = None res._data = None - res._render = None #dispose req req.req = None req.read_jar = None @@ -1154,8 +1825,9 @@ class RequestResponseFactory: self.factory_queue.append(instances) class AppRequest: - def __init__(self, request): + def __init__(self, request, app): self.req = request + self.app = app self.read_jar = None self.jar_parsed = False self._for_each_header_handler = None @@ -1259,7 +1931,7 @@ class AppRequest: elif isinstance(lower_case_header, bytes): data = lower_case_header else: - data = json.dumps(lower_case_header).encode("utf-8") + data = self.app._json_serializer.dumps(lower_case_header).encode("utf-8") buffer = ffi.new("char**") length = lib.uws_req_get_header(self.req, data, len(data), buffer) @@ -1296,7 +1968,7 @@ class AppRequest: elif isinstance(key, bytes): key_data = key else: - key_data = json.dumps(key).encode("utf-8") + key_data = self.app._json_serializer.dumps(key).encode("utf-8") length = lib.uws_req_get_query(self.req, key_data, len(key_data), buffer) buffer_address = ffi.addressof(buffer, 0)[0] @@ -1376,498 +2048,35 @@ class AppRequest: self.req = ffi.NULL self._ptr = ffi.NULL +class AppExtension: + def __init__(self): + self.properties = [] + self.methods = [] + self.empty = True -class AppResponse: - def __init__(self, response, loop, ssl, render, socket_refs): - self.res = response - self.SSL = ssl - self._socket_refs = socket_refs - self.aborted = False - self.loop = loop - self._aborted_handler = None - self._writable_handler = None - self._data_handler = None - self._ptr = ffi.new_handle(self) - self._grabbed_abort_handler_once = False - self._write_jar = None - self._cork_handler = None - self._lastChunkOffset = 0 - self._chunkFuture = None - self._dataFuture = None - self._data = None - self._render = render + def bind_methods(self, instance: any): + for (name, method) in self.methods: + """ + Bind the function *func* to *instance*, with either provided name *as_name* + or the existing name of *func*. The provided *func* should accept the + instance as the first argument, i.e. "self". + """ + bound_method = method.__get__(instance, instance.__class__) + setattr(instance, name, bound_method) - def cork(self, callback): - if not self.aborted: - self.grab_aborted_handler() - self._cork_handler = callback - lib.uws_res_cork(self.SSL, self.res, uws_generic_cork_handler, self._ptr) + def set_properties(self, instance: any): + for (name, property) in self.properties: + setattr(instance, name, property) - def set_cookie(self, name, value, options): - if options is None: - options = {} - if self._write_jar is None: - self._write_jar = cookies.SimpleCookie() - self._write_jar[name] = quote_plus(value) - if isinstance(options, dict): - for key in options: - if key == "expires" and isinstance(options[key], datetime): - self._write_jar[name][key] = options[key].strftime( - "%a, %d %b %Y %H:%M:%S GMT" - ) - else: - self._write_jar[name][key] = options[key] - def trigger_aborted(self): - self.aborted = True - self._ptr = ffi.NULL - self.res = ffi.NULL - if hasattr(self, "_aborted_handler") and hasattr( - self._aborted_handler, "__call__" - ): - try: - if inspect.iscoroutinefunction(self._aborted_handler): - self.run_async(self._aborted_handler(self)) - else: - self._aborted_handler(self) - except Exception as err: - logging.error("Error on abort handler %s" % str(err)) - return self + def method(self, method: callable): + self.empty = False + self.methods.append((method.__name__, method)) + return method - def trigger_data_handler(self, data, is_end): - if self.aborted: - return self - if hasattr(self, "_data_handler") and hasattr(self._data_handler, "__call__"): - try: - if inspect.iscoroutinefunction(self._data_handler): - self.run_async(self._data_handler(self, data, is_end)) - else: - self._data_handler(self, data, is_end) - except Exception as err: - logging.error("Error on data handler %s" % str(err)) - - return self - - def trigger_writable_handler(self, offset): - if self.aborted: - return False - if hasattr(self, "_writable_handler") and hasattr( - self._writable_handler, "__call__" - ): - try: - if inspect.iscoroutinefunction(self._writable_handler): - raise RuntimeError("AppResponse.on_writable must be synchronous") - return self._writable_handler(self, offset) - except Exception as err: - logging.error("Error on writable handler %s" % str(err)) - return False - return False - - def run_async(self, task): - self.grab_aborted_handler() - return self.loop.run_async(task, self) - - async def get_form_urlencoded(self, encoding="utf-8"): - data = await self.get_data() - try: - # decode and unquote all - result = {} - parsed = parse_qs(data.getvalue(), encoding=encoding) - has_value = False - for key in parsed: - has_value = True - try: - value = parsed[key] - new_key = key.decode(encoding) - last_value = value[len(value) - 1] - - result[new_key] = unquote_plus(last_value.decode(encoding)) - except Exception as error: - pass - return result if has_value else None - except Exception as error: - return None # invalid encoding - - async def get_text(self, encoding="utf-8"): - data = await self.get_data() - try: - return data.getvalue().decode(encoding) - except Exception: - return None # invalid encoding - - async def get_json(self): - data = await self.get_data() - try: - return json.loads(data.getvalue().decode("utf-8")) - except Exception: - return None # invalid json - - def send_chunk(self, buffer, total_size): - self._chunkFuture = self.loop.create_future() - self._lastChunkOffset = 0 - - def is_aborted(self): - self.aborted = True - try: - if not self._chunkFuture.done(): - self._chunkFuture.set_result( - (False, True) - ) # if aborted set to done True and ok False - except: - pass - - def on_writeble(self, offset): - # Here the timeout is off, we can spend as much time before calling try_end we want to - (ok, done) = self.try_end( - buffer[offset - self._lastChunkOffset : :], total_size - ) - if ok: - self._chunkFuture.set_result((ok, done)) - return ok - - self.on_writable(on_writeble) - self.on_aborted(is_aborted) - - if self.aborted: - self._chunkFuture.set_result( - (False, True) - ) # if aborted set to done True and ok False - return self._chunkFuture - - (ok, done) = self.try_end(buffer, total_size) - if ok: - self._chunkFuture.set_result((ok, done)) - return self._chunkFuture - # failed to send chunk - self._lastChunkOffset = self.get_write_offset() - - return self._chunkFuture - - def get_data(self): - self._dataFuture = self.loop.create_future() - self._data = BytesIO() - - def is_aborted(self): - self.aborted = True - try: - if not self._dataFuture.done(): - self._dataFuture.set_result(self._data) - except: - pass - - def get_chunks(self, chunk, is_end): - self._data.write(chunk) - if is_end: - self._dataFuture.set_result(self._data) - self._data = None - - self.on_aborted(is_aborted) - self.on_data(get_chunks) - return self._dataFuture - - def grab_aborted_handler(self): - # only needed if is async - if not self.aborted and not self._grabbed_abort_handler_once: - self._grabbed_abort_handler_once = True - lib.uws_res_on_aborted( - self.SSL, self.res, uws_generic_aborted_handler, self._ptr - ) - return self - - def redirect(self, location, status_code=302): - self.write_status(status_code) - self.write_header("Location", location) - self.end_without_body(False) - return self - - def write_offset(self, offset): - lib.uws_res_override_write_offset( - self.SSL, self.res, ffi.cast("uintmax_t", offset) - ) - return self - - def try_end(self, message, total_size, end_connection=False): - try: - if self.aborted: - return False, True - if self._write_jar is not None: - self.write_header("Set-Cookie", self._write_jar.output(header="")) - self._write_jar = None - if isinstance(message, str): - data = message.encode("utf-8") - elif isinstance(message, bytes): - data = message - else: - return False, True - result = lib.uws_res_try_end( - self.SSL, - self.res, - data, - len(data), - ffi.cast("uintmax_t", total_size), - 1 if end_connection else 0, - ) - return bool(result.ok), bool(result.has_responded) - except: - return False, True - - def cork_end(self, message, end_connection=False): - self.cork(lambda res: res.end(message, end_connection)) - return self - - def render(self, *args, **kwargs): - if self._render: - def render(res): - res.write_header(b'Content-Type', b'text/html') - res.end(self._render.render(*args, **kwargs)) - self.cork(render) - return self - raise RuntimeError("No registered templated engine") - - def get_remote_address_bytes(self): - buffer = ffi.new("char**") - length = lib.uws_res_get_remote_address(self.SSL, self.res, buffer) - buffer_address = ffi.addressof(buffer, 0)[0] - if buffer_address == ffi.NULL: - return None - try: - return ffi.unpack(buffer_address, length) - except Exception: # invalid - return None - - def get_remote_address(self): - buffer = ffi.new("char**") - length = lib.uws_res_get_remote_address_as_text(self.SSL, self.res, buffer) - buffer_address = ffi.addressof(buffer, 0)[0] - if buffer_address == ffi.NULL: - return None - try: - return ffi.unpack(buffer_address, length).decode("utf-8") - except Exception: # invalid utf-8 - return None - - def get_proxied_remote_address_bytes(self): - buffer = ffi.new("char**") - length = lib.uws_res_get_proxied_remote_address(self.SSL, self.res, buffer) - buffer_address = ffi.addressof(buffer, 0)[0] - if buffer_address == ffi.NULL: - return None - try: - return ffi.unpack(buffer_address, length) - except Exception: # invalid - return None - - def get_proxied_remote_address(self): - buffer = ffi.new("char**") - length = lib.uws_res_get_proxied_remote_address_as_text( - self.SSL, self.res, buffer - ) - buffer_address = ffi.addressof(buffer, 0)[0] - if buffer_address == ffi.NULL: - return None - try: - return ffi.unpack(buffer_address, length).decode("utf-8") - except Exception: # invalid utf-8 - return None - - def end(self, message, end_connection=False): - try: - if self.aborted: - return self - if self._write_jar is not None: - self.write_header("Set-Cookie", self._write_jar.output(header="")) - self._write_jar = None - if isinstance(message, str): - data = message.encode("utf-8") - elif isinstance(message, bytes): - data = message - elif message is None: - self.end_without_body(end_connection) - return self - else: - self.write_header(b"Content-Type", b"application/json") - data = json.dumps(message).encode("utf-8") - lib.uws_res_end( - self.SSL, self.res, data, len(data), 1 if end_connection else 0 - ) - finally: - return self - - def pause(self): - if not self.aborted: - lib.uws_res_pause(self.SSL, self.res) - return self - - def resume(self): - if not self.aborted: - lib.uws_res_resume(self.SSL, self.res) - return self - - def write_continue(self): - if not self.aborted: - lib.uws_res_write_continue(self.SSL, self.res) - return self - - def write_status(self, status_or_status_text): - if not self.aborted: - if isinstance(status_or_status_text, int): - if bool(lib.socketify_res_write_int_status(self.SSL, self.res, status_or_status_text)): - return self - raise RuntimeError( - '"%d" Is not an valid Status Code' % status_or_status_text - ) - - elif isinstance(status_or_status_text, str): - data = status_or_status_text.encode("utf-8") - elif isinstance(status_or_status_text, bytes): - data = status_or_status_text - else: - data = json.dumps(status_or_status_text).encode("utf-8") - - lib.uws_res_write_status(self.SSL, self.res, data, len(data)) - return self - - def write_header(self, key, value): - if not self.aborted: - if isinstance(key, str): - key_data = key.encode("utf-8") - elif isinstance(key, bytes): - key_data = key - else: - key_data = json.dumps(key).encode("utf-8") - - if isinstance(value, int): - lib.uws_res_write_header_int( - self.SSL, - self.res, - key_data, - len(key_data), - ffi.cast("uint64_t", value), - ) - elif isinstance(value, str): - value_data = value.encode("utf-8") - elif isinstance(value, bytes): - value_data = value - else: - value_data = json.dumps(value).encode("utf-8") - lib.uws_res_write_header( - self.SSL, self.res, key_data, len(key_data), value_data, len(value_data) - ) - return self - - def end_without_body(self, end_connection=False): - if not self.aborted: - if self._write_jar is not None: - self.write_header("Set-Cookie", self._write_jar.output(header="")) - lib.uws_res_end_without_body(self.SSL, self.res, 1 if end_connection else 0) - return self - - def write(self, message): - if not self.aborted: - if isinstance(message, str): - data = message.encode("utf-8") - elif isinstance(message, bytes): - data = message - else: - data = json.dumps(message).encode("utf-8") - lib.uws_res_write(self.SSL, self.res, data, len(data)) - return self - - def get_write_offset(self): - if not self.aborted: - return int(lib.uws_res_get_write_offset(self.SSL, self.res)) - return 0 - - def has_responded(self): - if self.aborted: - return False - return bool(lib.uws_res_has_responded(self.SSL, self.res)) - - def on_aborted(self, handler): - if hasattr(handler, "__call__"): - self._aborted_handler = handler - self.grab_aborted_handler() - return self - - def on_data(self, handler): - if not self.aborted: - if hasattr(handler, "__call__"): - self._data_handler = handler - self.grab_aborted_handler() - lib.uws_res_on_data( - self.SSL, self.res, uws_generic_on_data_handler, self._ptr - ) - return self - - def upgrade( - self, - sec_web_socket_key, - sec_web_socket_protocol, - sec_web_socket_extensions, - socket_context, - user_data=None, - ): - if self.aborted: - return False - - if isinstance(sec_web_socket_key, str): - sec_web_socket_key_data = sec_web_socket_key.encode("utf-8") - elif isinstance(sec_web_socket_key, bytes): - sec_web_socket_key_data = sec_web_socket_key - else: - sec_web_socket_key_data = b"" - - if isinstance(sec_web_socket_protocol, str): - sec_web_socket_protocol_data = sec_web_socket_protocol.encode("utf-8") - elif isinstance(sec_web_socket_protocol, bytes): - sec_web_socket_protocol_data = sec_web_socket_protocol - else: - sec_web_socket_protocol_data = b"" - - if isinstance(sec_web_socket_extensions, str): - sec_web_socket_extensions_data = sec_web_socket_extensions.encode("utf-8") - elif isinstance(sec_web_socket_extensions, bytes): - sec_web_socket_extensions_data = sec_web_socket_extensions - else: - sec_web_socket_extensions_data = b"" - - user_data_ptr = ffi.NULL - if user_data is not None: - _id = uuid.uuid4() - user_data_ptr = ffi.new_handle((user_data, _id)) - # keep alive data - self._socket_refs[_id] = user_data_ptr - - lib.uws_res_upgrade( - self.SSL, - self.res, - user_data_ptr, - sec_web_socket_key_data, - len(sec_web_socket_key_data), - sec_web_socket_protocol_data, - len(sec_web_socket_protocol_data), - sec_web_socket_extensions_data, - len(sec_web_socket_extensions_data), - socket_context, - ) - return True - - def on_writable(self, handler): - if not self.aborted: - if hasattr(handler, "__call__"): - self._writable_handler = handler - self.grab_aborted_handler() - lib.uws_res_on_writable( - self.SSL, self.res, uws_generic_on_writable_handler, self._ptr - ) - return self - - def get_native_handle(self): - return lib.uws_res_get_native_handle(self.SSL, self.res) - - def __del__(self): - self.res = ffi.NULL - self._ptr = ffi.NULL + def property(self, name: str, default_value: any=None): + self.empty = False + self.properties.append((name, default_value)) class App: @@ -1945,9 +2154,31 @@ class App: self._ws_factory = WebSocketFactory(self, websocket_factory_max_items) else: self._ws_factory = None + self._json_serializer = json + self._request_extension = None + self._response_extension = None + self._ws_extension = None + def register(self, extension): + if self._request_extension is None: + self._request_extension = AppExtension() + if self._response_extension is None: + self._response_extension = AppExtension() + if self._ws_extension is None: + self._ws_extension = AppExtension() + + extension(self._request_extension, self._response_extension, self._ws_extension) + + if self._factory is not None and (not self._request_extension.empty or not self._response_extension.empty): + self._factory.update_extensions() + if self._ws_factory is not None and not self._ws_extension.empty: + self._ws_factory.update_extensions() + def template(self, template_engine): self._template = template_engine + + def json_serializer(self, json_serializer): + self._json_serializer = json_serializer def static(self, route, directory): static_route(self, route, directory) @@ -1956,11 +2187,18 @@ class App: def get(self, path, handler): user_data = ffi.new_handle((handler, self)) self.handlers.append(user_data) # Keep alive handler + if self._factory: + handler = uws_generic_factory_method_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + handler = uws_generic_method_handler_with_extension + else: + handler = uws_generic_method_handler + lib.uws_app_get( self.SSL, self.app, path.encode("utf-8"), - uws_generic_factory_method_handler if self._factory else uws_generic_method_handler, + handler, user_data, ) return self @@ -1968,11 +2206,18 @@ class App: def post(self, path, handler): user_data = ffi.new_handle((handler, self)) self.handlers.append(user_data) # Keep alive handler + if self._factory: + handler = uws_generic_factory_method_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + handler = uws_generic_method_handler_with_extension + else: + handler = uws_generic_method_handler + lib.uws_app_post( self.SSL, self.app, path.encode("utf-8"), - uws_generic_factory_method_handler if self._factory else uws_generic_method_handler, + handler, user_data, ) return self @@ -1980,11 +2225,18 @@ class App: def options(self, path, handler): user_data = ffi.new_handle((handler, self)) self.handlers.append(user_data) # Keep alive handler + if self._factory: + handler = uws_generic_factory_method_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + handler = uws_generic_method_handler_with_extension + else: + handler = uws_generic_method_handler + lib.uws_app_options( self.SSL, self.app, path.encode("utf-8"), - uws_generic_factory_method_handler if self._factory else uws_generic_method_handler, + handler, user_data, ) return self @@ -1992,11 +2244,18 @@ class App: def delete(self, path, handler): user_data = ffi.new_handle((handler, self)) self.handlers.append(user_data) # Keep alive handler + if self._factory: + handler = uws_generic_factory_method_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + handler = uws_generic_method_handler_with_extension + else: + handler = uws_generic_method_handler + lib.uws_app_delete( self.SSL, self.app, path.encode("utf-8"), - uws_generic_factory_method_handler if self._factory else uws_generic_method_handler, + handler, user_data, ) return self @@ -2004,11 +2263,18 @@ class App: def patch(self, path, handler): user_data = ffi.new_handle((handler, self)) self.handlers.append(user_data) # Keep alive handler + if self._factory: + handler = uws_generic_factory_method_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + handler = uws_generic_method_handler_with_extension + else: + handler = uws_generic_method_handler + lib.uws_app_patch( self.SSL, self.app, path.encode("utf-8"), - uws_generic_factory_method_handler if self._factory else uws_generic_method_handler, + handler, user_data, ) return self @@ -2016,11 +2282,19 @@ class App: def put(self, path, handler): user_data = ffi.new_handle((handler, self)) self.handlers.append(user_data) # Keep alive handler + if self._factory: + handler = uws_generic_factory_method_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + handler = uws_generic_method_handler_with_extension + else: + handler = uws_generic_method_handler + + lib.uws_app_put( self.SSL, self.app, path.encode("utf-8"), - uws_generic_factory_method_handler if self._factory else uws_generic_method_handler, + handler, user_data, ) return self @@ -2028,11 +2302,19 @@ class App: def head(self, path, handler): user_data = ffi.new_handle((handler, self)) self.handlers.append(user_data) # Keep alive handler + if self._factory: + handler = uws_generic_factory_method_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + handler = uws_generic_method_handler_with_extension + else: + handler = uws_generic_method_handler + + lib.uws_app_head( self.SSL, self.app, path.encode("utf-8"), - uws_generic_factory_method_handler if self._factory else uws_generic_method_handler, + handler, user_data, ) return self @@ -2040,11 +2322,19 @@ class App: def connect(self, path, handler): user_data = ffi.new_handle((handler, self)) self.handlers.append(user_data) # Keep alive handler + if self._factory: + handler = uws_generic_factory_method_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + handler = uws_generic_method_handler_with_extension + else: + handler = uws_generic_method_handler + + lib.uws_app_connect( self.SSL, self.app, path.encode("utf-8"), - uws_generic_factory_method_handler if self._factory else uws_generic_method_handler, + handler, user_data, ) return self @@ -2052,11 +2342,18 @@ class App: def trace(self, path, handler): user_data = ffi.new_handle((handler, self)) self.handlers.append(user_data) # Keep alive handler + if self._factory: + handler = uws_generic_factory_method_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + handler = uws_generic_method_handler_with_extension + else: + handler = uws_generic_method_handler + lib.uws_app_trace( self.SSL, self.app, path.encode("utf-8"), - uws_generic_factory_method_handler if self._factory else uws_generic_method_handler, + handler, user_data, ) return self @@ -2064,11 +2361,19 @@ class App: def any(self, path, handler): user_data = ffi.new_handle((handler, self)) self.handlers.append(user_data) # Keep alive handler + if self._factory: + handler = uws_generic_factory_method_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + handler = uws_generic_method_handler_with_extension + else: + handler = uws_generic_method_handler + + lib.uws_app_any( self.SSL, self.app, path.encode("utf-8"), - uws_generic_factory_method_handler if self._factory else uws_generic_method_handler, + handler, user_data, ) return self @@ -2103,7 +2408,7 @@ class App: elif message is None: message_data = b"" else: - message_data = json.dumps(message).encode("utf-8") + message_data = self._json_serializer.dumps(message).encode("utf-8") return bool( lib.uws_publish( @@ -2271,7 +2576,13 @@ class App: handlers = WSBehaviorHandlers() if upgrade_handler: handlers.upgrade = upgrade_handler - native_behavior.upgrade = uws_websocket_factory_upgrade_handler if self._factory else uws_websocket_upgrade_handler + + if self._factory: + native_behavior.upgrade = uws_websocket_factory_upgrade_handler + elif self._response_extension and (not self._response_extension.empty or not self._request_extension.empty): + native_behavior.upgrade = uws_websocket_upgrade_handler_with_extension + else: + native_behavior.upgrade = uws_websocket_upgrade_handler else: native_behavior.upgrade = ffi.NULL @@ -2391,6 +2702,12 @@ class App: return self.loop.run_async(task, response) def run(self): + # populate factories + if self._factory is not None: + self._factory.populate() + if self._ws_factory is not None: + self._ws_factory.populate() + signal.signal(signal.SIGINT, lambda sig, frame: self.close()) self.loop.run() return self diff --git a/src/socketify/wsgi.py b/src/socketify/wsgi.py index 9732305..ab89469 100644 --- a/src/socketify/wsgi.py +++ b/src/socketify/wsgi.py @@ -201,9 +201,6 @@ def wsgi(ssl, response, info, user_data, aborted): lib.uws_res_write_header( ssl, response, key_data, len(key_data), value_data, len(value_data) ) - lib.uws_res_write_header( - ssl, response, b'Server', 6, b'socketify.py', 12 - ) # check for body if bool(info.has_content): diff --git a/src/tests.py b/src/tests.py index af60a13..eb41c7a 100644 --- a/src/tests.py +++ b/src/tests.py @@ -14,13 +14,40 @@ def extension(request, response, ws): async def get_cart(self): return [{ "quantity": 10, "name": "T-Shirt" }] - request.property("token", None) + @response.method + def send(self, content: any, content_type: str = b'text/plain', status=200): + self.write_header(b'Content-Type', content_type) + self.write_status(status) + self.end(content) + request.property("token", "testing") + +# extensions must be registered before routes app.register(extension) -app.get("/", lambda res, req: res.end("Hello World!")) +async def home(res, req): + print("token", req.token) + cart = await req.get_cart() + print("cart", cart) + user = await req.get_user() + print("user", user) + print("token", req.token) + res.send("Hello World!") + +app.get("/", home) app.listen( 3000, lambda config: print("Listening on port http://localhost:%d now\n" % config.port), ) app.run() + +# uws_websocket_upgrade_handler +# uws_generic_method_handler +# uws_websocket_drain_handler +# uws_websocket_subscription_handler +# uws_websocket_open_handler +# uws_websocket_message_handler +# uws_websocket_pong_handler +# uws_websocket_ping_handler +# uws_websocket_close_handler +# uws_websocket_subscription_handler \ No newline at end of file