diff --git a/src/socketify/socketify.py b/src/socketify/socketify.py index caa0eeb..3acbc79 100644 --- a/src/socketify/socketify.py +++ b/src/socketify/socketify.py @@ -161,6 +161,9 @@ typedef void (*uws_listen_handler)(struct us_listen_socket_t *listen_socket, uws typedef void (*uws_method_handler)(uws_res_t *response, uws_req_t *request, void *user_data); typedef void (*uws_filter_handler)(uws_res_t *response, int, void *user_data); typedef void (*uws_missing_server_handler)(const char *hostname, void *user_data); +typedef void (*uws_get_headers_server_handler)(const char *header_name, size_t header_name_size, const char *header_value, size_t header_value_size, void *user_data); + + uws_app_t *uws_create_app(int ssl, struct us_socket_context_options_t options); void uws_app_destroy(int ssl, uws_app_t *app); void uws_app_get(int ssl, uws_app_t *app, const char *pattern, uws_method_handler handler, void *user_data); @@ -179,7 +182,7 @@ void uws_app_run(int ssl, uws_app_t *); void uws_app_listen(int ssl, uws_app_t *app, int port, uws_listen_handler handler, void *user_data); void uws_app_listen_with_config(int ssl, uws_app_t *app, uws_app_listen_config_t config, uws_listen_handler handler, void *user_data); bool uws_constructor_failed(int ssl, uws_app_t *app); -unsigned int uws_num_subscribers(int ssl, uws_app_t *app, const char *topic); +unsigned int uws_num_subscribers(int ssl, uws_app_t *app, const char *topic, size_t topic_length); bool uws_publish(int ssl, uws_app_t *app, const char *topic, size_t topic_length, const char *message, size_t message_length, uws_opcode_t opcode, bool compress); void *uws_get_native_handle(int ssl, uws_app_t *app); void uws_remove_server_name(int ssl, uws_app_t *app, const char *hostname_pattern); @@ -196,6 +199,7 @@ void uws_res_resume(int ssl, uws_res_t *res); void uws_res_write_continue(int ssl, uws_res_t *res); void uws_res_write_status(int ssl, uws_res_t *res, const char *status, size_t length); void uws_res_write_header(int ssl, uws_res_t *res, const char *key, size_t key_length, const char *value, size_t value_length); +void uws_res_override_write_offset(int ssl, uws_res_t *res, uintmax_t offset); void uws_res_write_header_int(int ssl, uws_res_t *res, const char *key, size_t key_length, uint64_t value); void uws_res_end_without_body(int ssl, uws_res_t *res, bool close_connection); @@ -206,7 +210,7 @@ void uws_res_on_writable(int ssl, uws_res_t *res, bool (*handler)(uws_res_t *res void uws_res_on_aborted(int ssl, uws_res_t *res, void (*handler)(uws_res_t *res, void *opcional_data), void *opcional_data); void uws_res_on_data(int ssl, uws_res_t *res, void (*handler)(uws_res_t *res, const char *chunk, size_t chunk_length, bool is_end, void *opcional_data), void *opcional_data); void uws_res_upgrade(int ssl, uws_res_t *res, void *data, const char *sec_web_socket_key, size_t sec_web_socket_key_length, const char *sec_web_socket_protocol, size_t sec_web_socket_protocol_length, const char *sec_web_socket_extensions, size_t sec_web_socket_extensions_length, uws_socket_context_t *ws); -uws_try_end_result_t uws_res_try_end(int ssl, uws_res_t *res, const char *data, size_t length, uintmax_t total_size); +uws_try_end_result_t uws_res_try_end(int ssl, uws_res_t *res, const char *data, size_t length, uintmax_t total_size, bool close_connection); void uws_res_cork(int ssl, uws_res_t *res,void(*callback)(uws_res_t *res, void* user_data) ,void* user_data); bool uws_req_is_ancient(uws_req_t *res); bool uws_req_get_yield(uws_req_t *res); @@ -216,11 +220,27 @@ size_t uws_req_get_method(uws_req_t *res, const char **dest); size_t uws_req_get_header(uws_req_t *res, const char *lower_case_header, size_t lower_case_header_length, const char **dest); size_t uws_req_get_query(uws_req_t *res, const char *key, size_t key_length, const char **dest); size_t uws_req_get_parameter(uws_req_t *res, unsigned short index, const char **dest); +size_t uws_req_get_full_url(uws_req_t *res, const char **dest); +void uws_req_for_each_header(uws_req_t *res, uws_get_headers_server_handler handler, void *user_data); + """) library_path = os.path.join(os.path.dirname(__file__), "libuwebsockets.so") lib = ffi.dlopen(library_path) +@ffi.callback("void(const char *, size_t, const char *, size_t, void *)") +def uws_req_for_each_header_handler(header_name, header_name_size, header_value, header_value_size, user_data): + if not user_data == ffi.NULL: + req = ffi.from_handle(user_data) + try: + + header_name = ffi.unpack(header_name, header_name_size).decode("utf-8") + header_value = ffi.unpack(header_value, header_value_size).decode("utf-8") + + req.trigger_for_each_header_handler(header_name, header_value) + except Exception: #invalid utf-8 + return + @ffi.callback("void(uws_res_t *, uws_req_t *, void *)") def uws_generic_method_handler(res, req, user_data): @@ -312,7 +332,10 @@ class AppRequest: self.req = request self.read_jar = None self.jar_parsed = False + self._for_each_header_handler = None + self._ptr = ffi.new_handle(self) + def get_cookie(self, name): if self.read_jar == None: if self.jar_parsed: @@ -339,6 +362,16 @@ class AppRequest: return ffi.unpack(buffer_address, length).decode("utf-8") except Exception: #invalid utf-8 return None + def get_full_url(self): + buffer = ffi.new("char**") + length = lib.uws_req_get_full_url(self.req, buffer) + buffer_address = ffi.addressof(buffer, 0)[0] + if buffer_address == ffi.NULL: + return None + try: + return ffi.unpack(buffer_address, length).decode("utf-8") + except Exception: #invalid utf-8 + return None def get_method(self): buffer = ffi.new("char**") length = lib.uws_req_get_method(self.req, buffer) @@ -350,6 +383,11 @@ class AppRequest: return ffi.unpack(buffer_address, length).decode("utf-8") except Exception: #invalid utf-8 return None + def for_each_header(self, handler): + self._for_each_header_handler = handler + lib.uws_req_for_each_header(self.req, uws_req_for_each_header_handler, self._ptr) + + def get_header(self, lower_case_header): if isinstance(lower_case_header, str): data = lower_case_header.encode("utf-8") @@ -401,6 +439,19 @@ class AppRequest: return bool(lib.uws_req_get_yield(self.req)) def is_ancient(self): return bool(lib.uws_req_is_ancient(self.req)) + def trigger_for_each_header_handler(self, key, value): + if hasattr(self, "_for_each_header_handler") and hasattr(self._for_each_header_handler, '__call__'): + try: + if inspect.iscoroutinefunction(self._for_each_header_handler): + raise RuntimeError("AppResponse.for_each_header_handler must be synchronous") + self._for_each_header_handler(key, value) + except Exception as err: + print("Error on data handler %s" % str(err)) + + return self + def __del__(self): + self.req = ffi.NULL + self._ptr = ffi.NULL class AppResponse: def __init__(self, response, loop, is_ssl): @@ -581,13 +632,19 @@ class AppResponse: if not self.aborted and not self._grabed_abort_handler_once: self._grabed_abort_handler_once = True lib.uws_res_on_aborted(self.SSL, self.res, uws_generic_aborted_handler, self._ptr) + return self def redirect(self, location, status_code=302): self.write_status(status_code) self.write_header("Location", location) self.end_without_body(False) + return self - def try_end(self, message, total_size): + def write_offset(self, offset): + lib.uws_res_override_write_offset(self.SSL, self.res, ffi.cast("uintmax_t", offset)) + return self + + def try_end(self, message, total_size, close_connection=False): try: if self.aborted: return (False, False) @@ -600,13 +657,14 @@ class AppResponse: data = message else: return (False, False) - result = lib.uws_res_try_end(self.SSL, self.res, data, len(data),ffi.cast("uintmax_t", total_size)) + result = lib.uws_res_try_end(self.SSL, self.res, data, len(data),ffi.cast("uintmax_t", total_size), 1 if end_connection else 0) return (bool(result.ok), bool(result.has_responded)) except: return (False, False) def cork_end(self, message, end_connection=False): self.cork(lambda res: res.end(message, end_connection)) + return self def end(self, message, end_connection=False): try: diff --git a/src/tests.py b/src/tests.py index 3c4af33..db4cae7 100644 --- a/src/tests.py +++ b/src/tests.py @@ -17,76 +17,32 @@ # multipart/form-data -# try_end -# get_full_url -# for_each_header + +# unsigned int uws_num_subscribers(int ssl, uws_app_t *app, const char *topic); +# bool uws_publish(int ssl, uws_app_t *app, const char *topic, size_t topic_length, const char *message, size_t message_length, uws_opcode_t opcode, bool compress); +# void *uws_get_native_handle(int ssl, uws_app_t *app); +# void uws_remove_server_name(int ssl, uws_app_t *app, const char *hostname_pattern); +# void uws_add_server_name(int ssl, uws_app_t *app, const char *hostname_pattern); +# void uws_add_server_name_with_options(int ssl, uws_app_t *app, const char *hostname_pattern, struct us_socket_context_options_t options); +# void uws_missing_server_name(int ssl, uws_app_t *app, uws_missing_server_handler handler, void *user_data); +# void uws_filter(int ssl, uws_app_t *app, uws_filter_handler handler, void *user_data); + # https://github.com/uNetworking/uWebSockets.js/blob/master/examples/VideoStreamer.js from socketify import App # import os import multiprocessing import asyncio -import aiofiles -from aiofiles import os import time import mimetypes mimetypes.init() #need to fix get_data using sel._data etc async def home(res, req): + print("full", req.get_full_url()) + print("normal", req.get_url()) - filename = "./file_example_MP3_5MG.mp3" - - if_modified_since = req.get_header('if-modified-since') - range_header = req.get_header('range') - bytes_range = None - start = 0 - end = -1 - if range_header: - bytes_range = range_header.replace("bytes=", '').split('-') - start = int(bytes_range[0]) - if bytes_range[1]: - end = int(bytes_range[1]) - try: - exists = await os.path.exists(filename) - if not exists: - return res.write_status(404).end(b'Not Found') - stats = await os.stat(filename) - total_size = stats.st_size - last_modified = time.strftime('%a, %d %b %Y %H:%M:%S GMT', time.gmtime(stats.st_mtime)) - if if_modified_since == last_modified: - res.write_status(304).end_without_body() - return - res.write_header(b'Last-Modified', last_modified) - - (content_type, encoding) = mimetypes.guess_type(filename, strict=True) - if content_type and encoding: - res.write_header(b'Content-Type', '%s; %s' % (content_type, encoding)) - elif content_type: - res.write_header(b'Content-Type', content_type) - - async with aiofiles.open(filename, "rb") as fd: - if start > 0 or not end == -1: - if end < 1 or end >= total_size: - end = total_size - total_size = end - start - await fd.seek(start) - res.write_status(206) - else: - end = total_size - res.write_status(200) - - #tells the browser that we support ranges - res.write_header(b'Accept-Ranges', b'bytes') - res.write_header(b'Content-Range', 'bytes %d-%d/%d' % (start, end, total_size)) - - while not res.aborted: - buffer = await fd.read(16384) #16kb chunks - (ok, done) = await res.send_chunk(buffer, total_size) - if not ok or done: #if cannot send probably aborted - break - except Exception as error: - print(str(error)) - res.write_status(500).end("Internal Error") + req.for_each_header(lambda key,value: print("Header %s: %s" % (key, value))) + res.end("Test") def run_app(): app = App() diff --git a/tests/examples/router_and_basics.py b/tests/examples/router_and_basics.py index eee1129..0480e87 100644 --- a/tests/examples/router_and_basics.py +++ b/tests/examples/router_and_basics.py @@ -66,6 +66,9 @@ def json(res, req): async def sleepy_json(res, req): #get parameters, query, headers anything you need here before first await :) user_agent = req.get_header("user-agent") + #get all headers + req.for_each_header(lambda key,value: print("Header %s: %s" % (key, value))) + #req maybe will not be available in direct attached async functions after await #but if you dont care about req info you can do it await asyncio.sleep(2) #do something async