diff --git a/examples/middleware.py b/examples/middleware.py new file mode 100644 index 0000000..682bc2d --- /dev/null +++ b/examples/middleware.py @@ -0,0 +1,37 @@ +from socketify import App + + +def middleware(*functions): + def middleware_route(res, req): + data = None + #cicle to all middlewares + for function in functions: + #call middlewares + data = function(res, req, data) + #stops if returns Falsy + if not data: + break + + return middleware_route + +def get_user(authorization_header): + if authorization_header: + return { 'greeting': 'Hello, World' } + return None + +def auth(res, req, data=None): + user = get_user(req.get_header('authorization')) + if not user: + res.write_status(403).end("not authorized") + return False + + #returns extra data + return user + +def home(res, req, user=None): + res.end(user.get('greeting', None)) + +app = App() +app.get("/", middleware(auth, home)) +app.listen(3000, lambda config: print("Listening on port http://localhost:%d now\n" % config.port)) +app.run() \ No newline at end of file diff --git a/examples/middleware_async.py b/examples/middleware_async.py new file mode 100644 index 0000000..696ab4e --- /dev/null +++ b/examples/middleware_async.py @@ -0,0 +1,46 @@ +from socketify import App + + +async def get_user(authorization): + if authorization: + #do actually something async here + return { 'greeting': 'Hello, World' } + return None + +def auth(home, queries=[]): + #in async query string, arguments and headers are only valid until the first await + async def auth_middleware(res, req): + #get_headers will preserve headers (and cookies) after await + headers = req.get_headers() + #get_parameters will preserve all params after await + params = req.get_parameters() + + #preserve desired query string data + query_data = {} + for query in queries: + value = req.get_query(query) + if value: + query_data[query] = value + + user = await get_user(headers.get('authorization', None)) + if user: + return home(res, req, user, query_data) + + return res.write_status(403).cork_end("not authorized") + + return auth_middleware + + +def home(res, req, user=None, query={}): + theme = query.get("theme_color", "light") + greeting = user.get('greeting', None) + user_id = req.get_parameter(0) + res.cork_end(f"{greeting}
theme: {theme}
id: {user_id}") + +app = App() +app.get("/user/:id", auth(home, ['theme_color'])) +app.listen(3000, lambda config: print("Listening on port http://localhost:%d now\n" % config.port)) +app.run() + + +#curl --location --request GET 'http://localhost:3000/user/10?theme_color=dark' --header 'Authorization: Bearer 23456789' \ No newline at end of file diff --git a/src/socketify/socketify.py b/src/socketify/socketify.py index c262ddb..3325232 100644 --- a/src/socketify/socketify.py +++ b/src/socketify/socketify.py @@ -867,13 +867,20 @@ class AppRequest: self.jar_parsed = False self._for_each_header_handler = None self._ptr = ffi.new_handle(self) + self._headers = None + self._params = None def get_cookie(self, name): if self.read_jar == None: if self.jar_parsed: return None - raw_cookies = self.get_header("cookie") + + if self._headers: + raw_cookies = self._headers.get("cookie", None) + else: + raw_cookies = self.get_header("cookie") + if raw_cookies: self.jar_parsed = True self.read_jar = cookies.SimpleCookie(raw_cookies) @@ -925,11 +932,14 @@ class AppRequest: lib.uws_req_for_each_header(self.req, uws_req_for_each_header_handler, self._ptr) def get_headers(self): - headers = {} + if not self._headers is None: + return self._headers + + self._headers = {} def copy_headers(key, value): - headers[key] = value + self._headers[key] = value self.for_each_header(copy_headers) - return headers + return self._headers def get_header(self, lower_case_header): if isinstance(lower_case_header, str): @@ -966,7 +976,28 @@ class AppRequest: return ffi.unpack(buffer_address, length).decode("utf-8") except Exception: #invalid utf-8 return None + + def get_parameters(self): + if self._params: + return self._params + self._params = [] + i = 0 + while True: + value = self.get_parameter(i) + if value: + self._params.append(value) + else: + break + i = i + 1 + return self._params + def get_parameter(self, index): + if self._params: + try: + return self._params[index] + except: + return None + buffer = ffi.new("char**") length = lib.uws_req_get_parameter(self.req, ffi.cast("unsigned short", index), buffer) buffer_address = ffi.addressof(buffer, 0)[0] @@ -1350,9 +1381,9 @@ class AppResponse: return 0 def has_responded(self): - if not self.aborted: + if self.aborted: return False - return bool(lib.uws_res_has_responded(self.SSL, self.res, data, len(data))) + return bool(lib.uws_res_has_responded(self.SSL, self.res)) def on_aborted(self, handler): if hasattr(handler, '__call__'):