diff --git a/examples/middleware.py b/examples/middleware.py
new file mode 100644
index 0000000..682bc2d
--- /dev/null
+++ b/examples/middleware.py
@@ -0,0 +1,37 @@
+from socketify import App
+
+
+def middleware(*functions):
+ def middleware_route(res, req):
+ data = None
+ #cicle to all middlewares
+ for function in functions:
+ #call middlewares
+ data = function(res, req, data)
+ #stops if returns Falsy
+ if not data:
+ break
+
+ return middleware_route
+
+def get_user(authorization_header):
+ if authorization_header:
+ return { 'greeting': 'Hello, World' }
+ return None
+
+def auth(res, req, data=None):
+ user = get_user(req.get_header('authorization'))
+ if not user:
+ res.write_status(403).end("not authorized")
+ return False
+
+ #returns extra data
+ return user
+
+def home(res, req, user=None):
+ res.end(user.get('greeting', None))
+
+app = App()
+app.get("/", middleware(auth, home))
+app.listen(3000, lambda config: print("Listening on port http://localhost:%d now\n" % config.port))
+app.run()
\ No newline at end of file
diff --git a/examples/middleware_async.py b/examples/middleware_async.py
new file mode 100644
index 0000000..696ab4e
--- /dev/null
+++ b/examples/middleware_async.py
@@ -0,0 +1,46 @@
+from socketify import App
+
+
+async def get_user(authorization):
+ if authorization:
+ #do actually something async here
+ return { 'greeting': 'Hello, World' }
+ return None
+
+def auth(home, queries=[]):
+ #in async query string, arguments and headers are only valid until the first await
+ async def auth_middleware(res, req):
+ #get_headers will preserve headers (and cookies) after await
+ headers = req.get_headers()
+ #get_parameters will preserve all params after await
+ params = req.get_parameters()
+
+ #preserve desired query string data
+ query_data = {}
+ for query in queries:
+ value = req.get_query(query)
+ if value:
+ query_data[query] = value
+
+ user = await get_user(headers.get('authorization', None))
+ if user:
+ return home(res, req, user, query_data)
+
+ return res.write_status(403).cork_end("not authorized")
+
+ return auth_middleware
+
+
+def home(res, req, user=None, query={}):
+ theme = query.get("theme_color", "light")
+ greeting = user.get('greeting', None)
+ user_id = req.get_parameter(0)
+ res.cork_end(f"{greeting}
theme: {theme}
id: {user_id}")
+
+app = App()
+app.get("/user/:id", auth(home, ['theme_color']))
+app.listen(3000, lambda config: print("Listening on port http://localhost:%d now\n" % config.port))
+app.run()
+
+
+#curl --location --request GET 'http://localhost:3000/user/10?theme_color=dark' --header 'Authorization: Bearer 23456789'
\ No newline at end of file
diff --git a/src/socketify/socketify.py b/src/socketify/socketify.py
index c262ddb..3325232 100644
--- a/src/socketify/socketify.py
+++ b/src/socketify/socketify.py
@@ -867,13 +867,20 @@ class AppRequest:
self.jar_parsed = False
self._for_each_header_handler = None
self._ptr = ffi.new_handle(self)
+ self._headers = None
+ self._params = None
def get_cookie(self, name):
if self.read_jar == None:
if self.jar_parsed:
return None
- raw_cookies = self.get_header("cookie")
+
+ if self._headers:
+ raw_cookies = self._headers.get("cookie", None)
+ else:
+ raw_cookies = self.get_header("cookie")
+
if raw_cookies:
self.jar_parsed = True
self.read_jar = cookies.SimpleCookie(raw_cookies)
@@ -925,11 +932,14 @@ class AppRequest:
lib.uws_req_for_each_header(self.req, uws_req_for_each_header_handler, self._ptr)
def get_headers(self):
- headers = {}
+ if not self._headers is None:
+ return self._headers
+
+ self._headers = {}
def copy_headers(key, value):
- headers[key] = value
+ self._headers[key] = value
self.for_each_header(copy_headers)
- return headers
+ return self._headers
def get_header(self, lower_case_header):
if isinstance(lower_case_header, str):
@@ -966,7 +976,28 @@ class AppRequest:
return ffi.unpack(buffer_address, length).decode("utf-8")
except Exception: #invalid utf-8
return None
+
+ def get_parameters(self):
+ if self._params:
+ return self._params
+ self._params = []
+ i = 0
+ while True:
+ value = self.get_parameter(i)
+ if value:
+ self._params.append(value)
+ else:
+ break
+ i = i + 1
+ return self._params
+
def get_parameter(self, index):
+ if self._params:
+ try:
+ return self._params[index]
+ except:
+ return None
+
buffer = ffi.new("char**")
length = lib.uws_req_get_parameter(self.req, ffi.cast("unsigned short", index), buffer)
buffer_address = ffi.addressof(buffer, 0)[0]
@@ -1350,9 +1381,9 @@ class AppResponse:
return 0
def has_responded(self):
- if not self.aborted:
+ if self.aborted:
return False
- return bool(lib.uws_res_has_responded(self.SSL, self.res, data, len(data)))
+ return bool(lib.uws_res_has_responded(self.SSL, self.res))
def on_aborted(self, handler):
if hasattr(handler, '__call__'):