kopia lustrzana https://github.com/jointakahe/takahe
Remove hatchway's internal copy
rodzic
43ecf19cd1
commit
04ad97c69b
|
@ -1,5 +0,0 @@
|
||||||
from .http import ApiError, ApiResponse # noqa
|
|
||||||
from .schema import Field, Schema # noqa
|
|
||||||
from .types import Body, BodyDirect, Path, Query, QueryOrBody # noqa
|
|
||||||
from .urls import methods # noqa
|
|
||||||
from .view import api_view # noqa
|
|
|
@ -1,10 +0,0 @@
|
||||||
import enum
|
|
||||||
|
|
||||||
|
|
||||||
class InputSource(str, enum.Enum):
|
|
||||||
path = "path"
|
|
||||||
query = "query"
|
|
||||||
body = "body"
|
|
||||||
body_direct = "body_direct"
|
|
||||||
query_and_body_direct = "query_and_body_direct"
|
|
||||||
file = "file"
|
|
|
@ -1,47 +0,0 @@
|
||||||
import json
|
|
||||||
from typing import Generic, TypeVar
|
|
||||||
|
|
||||||
from django.core.serializers.json import DjangoJSONEncoder
|
|
||||||
from django.http import HttpResponse
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class ApiResponse(Generic[T], HttpResponse):
|
|
||||||
"""
|
|
||||||
A way to return extra information with a response if you want
|
|
||||||
headers, etc.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
data: T,
|
|
||||||
encoder=DjangoJSONEncoder,
|
|
||||||
json_dumps_params: dict[str, object] | None = None,
|
|
||||||
finalize: bool = False,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
self.data = data
|
|
||||||
self.encoder = encoder
|
|
||||||
self.json_dumps_params = json_dumps_params or {}
|
|
||||||
kwargs.setdefault("content_type", "application/json")
|
|
||||||
super().__init__(content=b"(unfinalised)", **kwargs)
|
|
||||||
if finalize:
|
|
||||||
self.finalize()
|
|
||||||
|
|
||||||
def finalize(self):
|
|
||||||
"""
|
|
||||||
Converts whatever our current data is into HttpResponse content
|
|
||||||
"""
|
|
||||||
# TODO: Automatically call this when we're asked to write output?
|
|
||||||
self.content = json.dumps(self.data, cls=self.encoder, **self.json_dumps_params)
|
|
||||||
|
|
||||||
|
|
||||||
class ApiError(BaseException):
|
|
||||||
"""
|
|
||||||
A handy way to raise an error with JSONable contents
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, status: int, error: str):
|
|
||||||
self.status = status
|
|
||||||
self.error = error
|
|
|
@ -1,52 +0,0 @@
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from django.db.models import Manager, QuerySet
|
|
||||||
from django.db.models.fields.files import FieldFile
|
|
||||||
from django.template import Variable, VariableDoesNotExist
|
|
||||||
from pydantic.fields import Field # noqa
|
|
||||||
from pydantic.main import BaseModel
|
|
||||||
from pydantic.utils import GetterDict
|
|
||||||
|
|
||||||
|
|
||||||
class DjangoGetterDict(GetterDict):
|
|
||||||
def __init__(self, obj: Any):
|
|
||||||
self._obj = obj
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> Any:
|
|
||||||
try:
|
|
||||||
item = getattr(self._obj, key)
|
|
||||||
except AttributeError:
|
|
||||||
try:
|
|
||||||
item = Variable(key).resolve(self._obj)
|
|
||||||
except VariableDoesNotExist as e:
|
|
||||||
raise KeyError(key) from e
|
|
||||||
return self._convert_result(item)
|
|
||||||
|
|
||||||
def get(self, key: Any, default: Any = None) -> Any:
|
|
||||||
try:
|
|
||||||
return self[key]
|
|
||||||
except KeyError:
|
|
||||||
return default
|
|
||||||
|
|
||||||
def _convert_result(self, result: Any) -> Any:
|
|
||||||
if isinstance(result, Manager):
|
|
||||||
return list(result.all())
|
|
||||||
|
|
||||||
elif isinstance(result, getattr(QuerySet, "__origin__", QuerySet)):
|
|
||||||
return list(result)
|
|
||||||
|
|
||||||
if callable(result):
|
|
||||||
return result()
|
|
||||||
|
|
||||||
elif isinstance(result, FieldFile):
|
|
||||||
if not result:
|
|
||||||
return None
|
|
||||||
return result.url
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class Schema(BaseModel):
|
|
||||||
class Config:
|
|
||||||
orm_mode = True
|
|
||||||
getter_dict = DjangoGetterDict
|
|
|
@ -1,63 +0,0 @@
|
||||||
from typing import Literal, Optional, Union
|
|
||||||
|
|
||||||
from django.core.files import File
|
|
||||||
|
|
||||||
from hatchway.http import ApiResponse
|
|
||||||
from hatchway.types import (
|
|
||||||
Query,
|
|
||||||
QueryType,
|
|
||||||
acceptable_input,
|
|
||||||
extract_output_type,
|
|
||||||
extract_signifier,
|
|
||||||
is_optional,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_optional():
|
|
||||||
|
|
||||||
assert is_optional(Optional[int]) == (True, int)
|
|
||||||
assert is_optional(Union[int, None]) == (True, int)
|
|
||||||
assert is_optional(Union[None, int]) == (True, int)
|
|
||||||
assert is_optional(int | None) == (True, int)
|
|
||||||
assert is_optional(None | int) == (True, int)
|
|
||||||
assert is_optional(int) == (False, int)
|
|
||||||
assert is_optional(Query[int]) == (False, Query[int])
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_signifier():
|
|
||||||
|
|
||||||
assert extract_signifier(int) == (None, int)
|
|
||||||
assert extract_signifier(Query[int]) == (QueryType, int)
|
|
||||||
assert extract_signifier(Query[Optional[int]]) == ( # type:ignore
|
|
||||||
QueryType,
|
|
||||||
Optional[int],
|
|
||||||
)
|
|
||||||
assert extract_signifier(Query[int | None]) == ( # type:ignore
|
|
||||||
QueryType,
|
|
||||||
Optional[int],
|
|
||||||
)
|
|
||||||
assert extract_signifier(Optional[Query[int]]) == (QueryType, Optional[int])
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_output_type():
|
|
||||||
|
|
||||||
assert extract_output_type(int) == int
|
|
||||||
assert extract_output_type(ApiResponse[int]) == int
|
|
||||||
assert extract_output_type(ApiResponse[int | str]) == int | str
|
|
||||||
|
|
||||||
|
|
||||||
def test_acceptable_input():
|
|
||||||
|
|
||||||
assert acceptable_input(str) is True
|
|
||||||
assert acceptable_input(int) is True
|
|
||||||
assert acceptable_input(Query[int]) is True
|
|
||||||
assert acceptable_input(Optional[int]) is True
|
|
||||||
assert acceptable_input(int | None) is True
|
|
||||||
assert acceptable_input(int | str | None) is True
|
|
||||||
assert acceptable_input(Query[int | None]) is True # type: ignore
|
|
||||||
assert acceptable_input(File) is True
|
|
||||||
assert acceptable_input(list[str]) is True
|
|
||||||
assert acceptable_input(dict[str, int]) is True
|
|
||||||
assert acceptable_input(Literal["a", "b"]) is True
|
|
||||||
assert acceptable_input(frozenset) is False
|
|
||||||
assert acceptable_input(dict[str, frozenset]) is False
|
|
|
@ -1,244 +0,0 @@
|
||||||
import json
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from django.core import files
|
|
||||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
|
||||||
from django.http import QueryDict
|
|
||||||
from django.test import RequestFactory
|
|
||||||
from django.test.client import MULTIPART_CONTENT
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from hatchway import ApiError, Body, QueryOrBody, api_view
|
|
||||||
from hatchway.view import ApiView
|
|
||||||
|
|
||||||
|
|
||||||
def test_basic_view():
|
|
||||||
"""
|
|
||||||
Tests that a view with simple types works correctly
|
|
||||||
"""
|
|
||||||
|
|
||||||
@api_view
|
|
||||||
def test_view(
|
|
||||||
request,
|
|
||||||
a: int,
|
|
||||||
b: QueryOrBody[int | None] = None,
|
|
||||||
c: str = "x",
|
|
||||||
) -> str:
|
|
||||||
if b is None:
|
|
||||||
return c * a
|
|
||||||
else:
|
|
||||||
return c * (a - b)
|
|
||||||
|
|
||||||
# Call it with a few different patterns to verify it's type coercing right
|
|
||||||
factory = RequestFactory()
|
|
||||||
|
|
||||||
# Implicit query param
|
|
||||||
response = test_view(factory.get("/test/?a=4"))
|
|
||||||
assert json.loads(response.content) == "xxxx"
|
|
||||||
|
|
||||||
# QueryOrBody pulling from query
|
|
||||||
response = test_view(factory.get("/test/?a=4&b=2"))
|
|
||||||
assert json.loads(response.content) == "xx"
|
|
||||||
|
|
||||||
# QueryOrBody pulling from formdata body
|
|
||||||
response = test_view(factory.post("/test/?a=4", {"b": "3"}))
|
|
||||||
assert json.loads(response.content) == "x"
|
|
||||||
|
|
||||||
# QueryOrBody pulling from JSON body
|
|
||||||
response = test_view(
|
|
||||||
factory.post(
|
|
||||||
"/test/?a=4", json.dumps({"b": 3}), content_type="application/json"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert json.loads(response.content) == "x"
|
|
||||||
|
|
||||||
# Implicit Query not pulling from body
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
test_view(factory.post("/test/", {"a": 4, "b": 3}))
|
|
||||||
|
|
||||||
|
|
||||||
def test_body_direct():
|
|
||||||
"""
|
|
||||||
Tests that a Pydantic model with BodyDirect gets its fields from the top level
|
|
||||||
"""
|
|
||||||
|
|
||||||
class TestModel(BaseModel):
|
|
||||||
number: int
|
|
||||||
name: str
|
|
||||||
|
|
||||||
@api_view
|
|
||||||
def test_view(request, data: TestModel) -> int:
|
|
||||||
return data.number
|
|
||||||
|
|
||||||
factory = RequestFactory()
|
|
||||||
|
|
||||||
# formdata version
|
|
||||||
response = test_view(factory.post("/test/", {"number": "123", "name": "Andrew"}))
|
|
||||||
assert json.loads(response.content) == 123
|
|
||||||
|
|
||||||
# JSON body version
|
|
||||||
response = test_view(
|
|
||||||
factory.post(
|
|
||||||
"/test/",
|
|
||||||
json.dumps({"number": "123", "name": "Andrew"}),
|
|
||||||
content_type="application/json",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert json.loads(response.content) == 123
|
|
||||||
|
|
||||||
|
|
||||||
def test_list_response():
|
|
||||||
"""
|
|
||||||
Tests that a view with a list response type works correctly with both
|
|
||||||
dicts and pydantic model instances.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class TestModel(BaseModel):
|
|
||||||
number: int
|
|
||||||
name: str
|
|
||||||
|
|
||||||
@api_view
|
|
||||||
def test_view_dict(request) -> list[TestModel]:
|
|
||||||
return [
|
|
||||||
{"name": "Andrew", "number": 1}, # type:ignore
|
|
||||||
{"name": "Alice", "number": 0}, # type:ignore
|
|
||||||
]
|
|
||||||
|
|
||||||
@api_view
|
|
||||||
def test_view_model(request) -> list[TestModel]:
|
|
||||||
return [TestModel(name="Andrew", number=1), TestModel(name="Alice", number=0)]
|
|
||||||
|
|
||||||
response = test_view_dict(RequestFactory().get("/test/"))
|
|
||||||
assert json.loads(response.content) == [
|
|
||||||
{"name": "Andrew", "number": 1},
|
|
||||||
{"name": "Alice", "number": 0},
|
|
||||||
]
|
|
||||||
|
|
||||||
response = test_view_model(RequestFactory().get("/test/"))
|
|
||||||
assert json.loads(response.content) == [
|
|
||||||
{"name": "Andrew", "number": 1},
|
|
||||||
{"name": "Alice", "number": 0},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_patch_body():
|
|
||||||
"""
|
|
||||||
Tests that PATCH also gets its body parsed
|
|
||||||
"""
|
|
||||||
|
|
||||||
@api_view.patch
|
|
||||||
def test_view(request, a: Body[int]):
|
|
||||||
return a
|
|
||||||
|
|
||||||
factory = RequestFactory()
|
|
||||||
response = test_view(
|
|
||||||
factory.patch(
|
|
||||||
"/test/",
|
|
||||||
content_type=MULTIPART_CONTENT,
|
|
||||||
data=factory._encode_data({"a": "42"}, MULTIPART_CONTENT),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert json.loads(response.content) == 42
|
|
||||||
|
|
||||||
|
|
||||||
def test_file_body():
|
|
||||||
"""
|
|
||||||
Tests that file uploads work right
|
|
||||||
"""
|
|
||||||
|
|
||||||
@api_view.post
|
|
||||||
def test_view(request, a: Body[int], b: files.File) -> str:
|
|
||||||
return str(a) + b.read().decode("ascii")
|
|
||||||
|
|
||||||
factory = RequestFactory()
|
|
||||||
uploaded_file = SimpleUploadedFile(
|
|
||||||
"file.txt",
|
|
||||||
b"MY FILE IS AMAZING",
|
|
||||||
content_type="text/plain",
|
|
||||||
)
|
|
||||||
response = test_view(
|
|
||||||
factory.post(
|
|
||||||
"/test/",
|
|
||||||
data={"a": 42, "b": uploaded_file},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
assert json.loads(response.content) == "42MY FILE IS AMAZING"
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_response():
|
|
||||||
"""
|
|
||||||
Tests that a view with no response type returns the contents verbatim
|
|
||||||
"""
|
|
||||||
|
|
||||||
@api_view
|
|
||||||
def test_view(request):
|
|
||||||
return [1, "woooooo"]
|
|
||||||
|
|
||||||
response = test_view(RequestFactory().get("/test/"))
|
|
||||||
assert json.loads(response.content) == [1, "woooooo"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_wrong_method():
|
|
||||||
"""
|
|
||||||
Tests that a view with a method limiter works
|
|
||||||
"""
|
|
||||||
|
|
||||||
@api_view.get
|
|
||||||
def test_view(request):
|
|
||||||
return "yay"
|
|
||||||
|
|
||||||
response = test_view(RequestFactory().get("/test/"))
|
|
||||||
assert json.loads(response.content) == "yay"
|
|
||||||
|
|
||||||
response = test_view(RequestFactory().post("/test/"))
|
|
||||||
assert response.status_code == 405
|
|
||||||
|
|
||||||
|
|
||||||
def test_api_error():
|
|
||||||
"""
|
|
||||||
Tests that ApiError propagates right
|
|
||||||
"""
|
|
||||||
|
|
||||||
@api_view.get
|
|
||||||
def test_view(request):
|
|
||||||
raise ApiError(401, "you did a bad thing")
|
|
||||||
|
|
||||||
response = test_view(RequestFactory().get("/test/"))
|
|
||||||
assert json.loads(response.content) == {"error": "you did a bad thing"}
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
def test_unusable_type():
|
|
||||||
"""
|
|
||||||
Tests that you get a nice error when you use a type on an input that
|
|
||||||
Pydantic doesn't understand.
|
|
||||||
"""
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
|
|
||||||
@api_view.get
|
|
||||||
def test_view(request, a: RequestFactory):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_values():
|
|
||||||
"""
|
|
||||||
Tests that ApiView.get_values correctly handles lists
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert ApiView.get_values({"a": 2, "b": [3, 4]}) == {"a": 2, "b": [3, 4]}
|
|
||||||
assert ApiView.get_values({"a": 2, "b[]": [3, 4]}) == {"a": 2, "b": [3, 4]}
|
|
||||||
assert ApiView.get_values(QueryDict("a=2&b=3&b=4")) == {"a": "2", "b": ["3", "4"]}
|
|
||||||
assert ApiView.get_values(QueryDict("a=2&b[]=3&b[]=4")) == {
|
|
||||||
"a": "2",
|
|
||||||
"b": ["3", "4"],
|
|
||||||
}
|
|
||||||
assert ApiView.get_values(QueryDict("a=2&b=3")) == {"a": "2", "b": "3"}
|
|
||||||
assert ApiView.get_values(QueryDict("a=2&b[]=3")) == {"a": "2", "b": ["3"]}
|
|
||||||
assert ApiView.get_values(QueryDict("a[b]=1")) == {"a": {"b": "1"}}
|
|
||||||
assert ApiView.get_values(QueryDict("a[b]=1&a[c]=2")) == {"a": {"b": "1", "c": "2"}}
|
|
||||||
assert ApiView.get_values(QueryDict("a[b][c]=1")) == {"a": {"b": {"c": "1"}}}
|
|
||||||
assert ApiView.get_values(QueryDict("a[b][c][]=1")) == {"a": {"b": {"c": ["1"]}}}
|
|
||||||
assert ApiView.get_values(QueryDict("a[b][]=1&a[b][]=2")) == {
|
|
||||||
"a": {"b": ["1", "2"]}
|
|
||||||
}
|
|
|
@ -1,145 +0,0 @@
|
||||||
from types import NoneType, UnionType
|
|
||||||
from typing import ( # type: ignore[attr-defined]
|
|
||||||
Annotated,
|
|
||||||
Any,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
_AnnotatedAlias,
|
|
||||||
_GenericAlias,
|
|
||||||
get_args,
|
|
||||||
get_origin,
|
|
||||||
)
|
|
||||||
|
|
||||||
from django.core import files
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from .http import ApiResponse
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class PathType:
|
|
||||||
"""
|
|
||||||
An input pulled from the path (url resolver kwargs)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class QueryType:
|
|
||||||
"""
|
|
||||||
An input pulled from the query parameters (request.GET)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class BodyType:
|
|
||||||
"""
|
|
||||||
An input pulled from the POST body (request.POST or a JSON body)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class FileType:
|
|
||||||
"""
|
|
||||||
An input pulled from the POST body (request.POST or a JSON body)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class BodyDirectType:
|
|
||||||
"""
|
|
||||||
A Pydantic model whose keys are all looked for in the top-level
|
|
||||||
POST data, rather than in a dict under a key named after the input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class QueryOrBodyType:
|
|
||||||
"""
|
|
||||||
An input pulled from either query parameters or post data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
Path = Annotated[T, PathType]
|
|
||||||
Query = Annotated[T, QueryType]
|
|
||||||
Body = Annotated[T, BodyType]
|
|
||||||
File = Annotated[T, FileType]
|
|
||||||
BodyDirect = Annotated[T, BodyDirectType]
|
|
||||||
QueryOrBody = Annotated[T, QueryOrBodyType]
|
|
||||||
|
|
||||||
|
|
||||||
def is_optional(annotation) -> tuple[bool, Any]:
|
|
||||||
"""
|
|
||||||
If an annotation is Optional or | None, returns (True, internal type).
|
|
||||||
Returns (False, annotation) otherwise.
|
|
||||||
"""
|
|
||||||
if (isinstance(annotation, _GenericAlias) and annotation.__origin__ is Union) or (
|
|
||||||
isinstance(annotation, UnionType)
|
|
||||||
):
|
|
||||||
args = get_args(annotation)
|
|
||||||
if len(args) > 2:
|
|
||||||
return False, annotation
|
|
||||||
if args[0] is NoneType:
|
|
||||||
return True, args[1]
|
|
||||||
if args[1] is NoneType:
|
|
||||||
return True, args[0]
|
|
||||||
return False, annotation
|
|
||||||
return False, annotation
|
|
||||||
|
|
||||||
|
|
||||||
def extract_signifier(annotation) -> tuple[Any, Any]:
|
|
||||||
"""
|
|
||||||
Given a type annotation, looks to see if it can find a input source
|
|
||||||
signifier (Path, Query, etc.)
|
|
||||||
|
|
||||||
If it can, returns (signifier, annotation_without_signifier)
|
|
||||||
If not, returns (None, annotation)
|
|
||||||
"""
|
|
||||||
our_generics = {
|
|
||||||
PathType,
|
|
||||||
QueryType,
|
|
||||||
BodyType,
|
|
||||||
FileType,
|
|
||||||
BodyDirectType,
|
|
||||||
QueryOrBodyType,
|
|
||||||
}
|
|
||||||
# Remove any optional-style wrapper
|
|
||||||
optional, internal_annotation = is_optional(annotation)
|
|
||||||
# Is it an annotation?
|
|
||||||
if isinstance(internal_annotation, _AnnotatedAlias):
|
|
||||||
args = get_args(internal_annotation)
|
|
||||||
for arg in args[1:]:
|
|
||||||
if arg in our_generics:
|
|
||||||
if optional:
|
|
||||||
return (arg, Optional[args[0]])
|
|
||||||
else:
|
|
||||||
return (arg, args[0])
|
|
||||||
return None, annotation
|
|
||||||
|
|
||||||
|
|
||||||
def extract_output_type(annotation):
|
|
||||||
"""
|
|
||||||
Returns the right response type for a function
|
|
||||||
"""
|
|
||||||
# If the type is ApiResponse, we want to pull out its inside
|
|
||||||
if isinstance(annotation, _GenericAlias):
|
|
||||||
if get_origin(annotation) == ApiResponse:
|
|
||||||
return get_args(annotation)[0]
|
|
||||||
return annotation
|
|
||||||
|
|
||||||
|
|
||||||
def acceptable_input(annotation) -> bool:
|
|
||||||
"""
|
|
||||||
Returns if this annotation is something we think we can accept as input
|
|
||||||
"""
|
|
||||||
_, inner_type = extract_signifier(annotation)
|
|
||||||
try:
|
|
||||||
if issubclass(inner_type, BaseModel):
|
|
||||||
return True
|
|
||||||
except TypeError:
|
|
||||||
pass
|
|
||||||
if inner_type in [str, int, list, tuple, bool, Any, files.File, type(None)]:
|
|
||||||
return True
|
|
||||||
origin = get_origin(inner_type)
|
|
||||||
if origin == Literal:
|
|
||||||
return True
|
|
||||||
if origin in [Union, UnionType, dict, list, tuple]:
|
|
||||||
return all(acceptable_input(a) for a in get_args(inner_type))
|
|
||||||
return False
|
|
|
@ -1,32 +0,0 @@
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from django.http import HttpResponseNotAllowed
|
|
||||||
|
|
||||||
|
|
||||||
class Methods:
|
|
||||||
"""
|
|
||||||
Allows easy multi-method dispatch to different functions
|
|
||||||
"""
|
|
||||||
|
|
||||||
csrf_exempt = True
|
|
||||||
|
|
||||||
def __init__(self, **callables: Callable):
|
|
||||||
self.callables = {
|
|
||||||
method.lower(): callable for method, callable in callables.items()
|
|
||||||
}
|
|
||||||
unknown_methods = set(self.callables.keys()).difference(
|
|
||||||
{"get", "post", "patch", "put", "delete"}
|
|
||||||
)
|
|
||||||
if unknown_methods:
|
|
||||||
raise ValueError(f"Cannot route methods: {unknown_methods}")
|
|
||||||
|
|
||||||
def __call__(self, request, *args, **kwargs) -> Any:
|
|
||||||
method = request.method.lower()
|
|
||||||
if method in self.callables:
|
|
||||||
return self.callables[method](request, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
return HttpResponseNotAllowed(self.callables.keys())
|
|
||||||
|
|
||||||
|
|
||||||
methods = Methods
|
|
297
hatchway/view.py
297
hatchway/view.py
|
@ -1,297 +0,0 @@
|
||||||
import json
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any, Optional, get_type_hints
|
|
||||||
|
|
||||||
from django.core import files
|
|
||||||
from django.http import HttpRequest, HttpResponseNotAllowed, QueryDict
|
|
||||||
from django.http.multipartparser import MultiPartParser
|
|
||||||
from pydantic import BaseModel, create_model
|
|
||||||
|
|
||||||
from .constants import InputSource
|
|
||||||
from .http import ApiError, ApiResponse
|
|
||||||
from .types import (
|
|
||||||
BodyDirectType,
|
|
||||||
BodyType,
|
|
||||||
FileType,
|
|
||||||
PathType,
|
|
||||||
QueryOrBodyType,
|
|
||||||
QueryType,
|
|
||||||
acceptable_input,
|
|
||||||
extract_output_type,
|
|
||||||
extract_signifier,
|
|
||||||
is_optional,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ApiView:
|
|
||||||
"""
|
|
||||||
A view 'wrapper' object that replaces the API view for anything further
|
|
||||||
up the stack.
|
|
||||||
|
|
||||||
Unlike Django's class-based views, we don't need an as_view pattern
|
|
||||||
as we are careful never to write anything per-request to self.
|
|
||||||
"""
|
|
||||||
|
|
||||||
csrf_exempt = True
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
view: Callable,
|
|
||||||
input_types: dict[str, Any] | None = None,
|
|
||||||
output_type: Any = None,
|
|
||||||
implicit_lists: bool = True,
|
|
||||||
method: str | None = None,
|
|
||||||
):
|
|
||||||
self.view = view
|
|
||||||
self.implicit_lists = implicit_lists
|
|
||||||
self.view_name = getattr(view, "__name__", "unknown_view")
|
|
||||||
self.method = method
|
|
||||||
# Extract input/output types from view annotations if we need to
|
|
||||||
self.input_types = input_types
|
|
||||||
if self.input_types is None:
|
|
||||||
self.input_types = get_type_hints(view, include_extras=True)
|
|
||||||
if "return" in self.input_types:
|
|
||||||
del self.input_types["return"]
|
|
||||||
self.output_type = output_type
|
|
||||||
if self.output_type is None:
|
|
||||||
try:
|
|
||||||
self.output_type = extract_output_type(
|
|
||||||
get_type_hints(view, include_extras=True)["return"]
|
|
||||||
)
|
|
||||||
except KeyError:
|
|
||||||
self.output_type = None
|
|
||||||
self.compile()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls, view: Callable):
|
|
||||||
return cls(view=view, method="get")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def post(cls, view: Callable):
|
|
||||||
return cls(view=view, method="post")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def put(cls, view: Callable):
|
|
||||||
return cls(view=view, method="put")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def patch(cls, view: Callable):
|
|
||||||
return cls(view=view, method="patch")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def delete(cls, view: Callable):
|
|
||||||
return cls(view=view, method="delete")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def sources_for_input(cls, input_type) -> tuple[list[InputSource], Any]:
|
|
||||||
"""
|
|
||||||
Given a type that can appear as a request parameter type, returns
|
|
||||||
what sources it can come from, and what its type is as understood
|
|
||||||
by Pydantic.
|
|
||||||
"""
|
|
||||||
signifier, input_type = extract_signifier(input_type)
|
|
||||||
if signifier is QueryType:
|
|
||||||
return ([InputSource.query], input_type)
|
|
||||||
elif signifier is BodyType:
|
|
||||||
return ([InputSource.body], input_type)
|
|
||||||
elif signifier is BodyDirectType:
|
|
||||||
if not issubclass(input_type, BaseModel):
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot use BodyDirect on something that is not a Pydantic model"
|
|
||||||
)
|
|
||||||
return ([InputSource.body_direct], input_type)
|
|
||||||
elif signifier is PathType:
|
|
||||||
return ([InputSource.path], input_type)
|
|
||||||
elif (
|
|
||||||
signifier is FileType
|
|
||||||
or input_type is files.File
|
|
||||||
or is_optional(input_type)[1] is files.File
|
|
||||||
):
|
|
||||||
return ([InputSource.file], input_type)
|
|
||||||
elif signifier is QueryOrBodyType:
|
|
||||||
return ([InputSource.query, InputSource.body], input_type)
|
|
||||||
# Is it a Pydantic model, which means it's implicitly body?
|
|
||||||
elif isinstance(input_type, type) and issubclass(input_type, BaseModel):
|
|
||||||
return ([InputSource.body], input_type)
|
|
||||||
# Otherwise, we look in the path first and then the query
|
|
||||||
else:
|
|
||||||
return ([InputSource.path, InputSource.query], input_type)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_values(cls, data, use_square_brackets=True) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Given a QueryDict or normal dict, returns data taking into account
|
|
||||||
lists made by repeated values or by suffixing names with [].
|
|
||||||
"""
|
|
||||||
result: dict[str, Any] = {}
|
|
||||||
for key, value in data.items():
|
|
||||||
# If it's a query dict with multiple values, make it a list
|
|
||||||
if isinstance(data, QueryDict):
|
|
||||||
values = data.getlist(key)
|
|
||||||
if len(values) > 1:
|
|
||||||
value = values
|
|
||||||
# If it is in dict-ish/list-ish syntax, adhere to that
|
|
||||||
# TODO: Make this better handle badly formed keys
|
|
||||||
if "[" in key and use_square_brackets:
|
|
||||||
parts = key.split("[")
|
|
||||||
target = result
|
|
||||||
last_key = parts[0]
|
|
||||||
for part in parts[1:]:
|
|
||||||
part = part.rstrip("]")
|
|
||||||
if not part:
|
|
||||||
target = target.setdefault(last_key, [])
|
|
||||||
else:
|
|
||||||
target = target.setdefault(last_key, {})
|
|
||||||
last_key = part
|
|
||||||
if isinstance(target, list):
|
|
||||||
if isinstance(value, list):
|
|
||||||
target.extend(value)
|
|
||||||
else:
|
|
||||||
target.append(value)
|
|
||||||
else:
|
|
||||||
target[last_key] = value
|
|
||||||
else:
|
|
||||||
result[key] = value
|
|
||||||
return result
|
|
||||||
|
|
||||||
def compile(self):
|
|
||||||
self.sources: dict[str, list[InputSource]] = {}
|
|
||||||
amount_from_body = 0
|
|
||||||
pydantic_model_dict = {}
|
|
||||||
self.input_files = set()
|
|
||||||
last_body_type = None
|
|
||||||
# For each input item, work out where to pull it from
|
|
||||||
for name, input_type in self.input_types.items():
|
|
||||||
# Do some basic typechecking to stop things that aren't allowed
|
|
||||||
if isinstance(input_type, type) and issubclass(input_type, HttpRequest):
|
|
||||||
continue
|
|
||||||
if not acceptable_input(input_type):
|
|
||||||
# Strip away any singifiers for the error
|
|
||||||
_, inner_type = extract_signifier(input_type)
|
|
||||||
raise ValueError(
|
|
||||||
f"Input argument {name} has an unsupported type {inner_type}"
|
|
||||||
)
|
|
||||||
sources, pydantic_type = self.sources_for_input(input_type)
|
|
||||||
self.sources[name] = sources
|
|
||||||
# Keep count of how many are pulling from the body
|
|
||||||
if InputSource.body in sources:
|
|
||||||
amount_from_body += 1
|
|
||||||
last_body_type = pydantic_type
|
|
||||||
if InputSource.file in sources:
|
|
||||||
self.input_files.add(name)
|
|
||||||
else:
|
|
||||||
pydantic_model_dict[name] = (Optional[pydantic_type], ...)
|
|
||||||
# If there is just one thing pulling from the body and it's a BaseModel,
|
|
||||||
# signify that it's actually pulling from the body keys directly and
|
|
||||||
# not a sub-dict
|
|
||||||
if amount_from_body == 1:
|
|
||||||
for name, sources in self.sources.items():
|
|
||||||
if (
|
|
||||||
InputSource.body in sources
|
|
||||||
and isinstance(last_body_type, type)
|
|
||||||
and issubclass(last_body_type, BaseModel)
|
|
||||||
):
|
|
||||||
self.sources[name] = [
|
|
||||||
x for x in sources if x != InputSource.body
|
|
||||||
] + [InputSource.body_direct]
|
|
||||||
# Turn all the main arguments into Pydantic parsing models
|
|
||||||
try:
|
|
||||||
self.input_model = create_model(
|
|
||||||
f"{self.view_name}_input", **pydantic_model_dict
|
|
||||||
)
|
|
||||||
except RuntimeError:
|
|
||||||
raise ValueError(
|
|
||||||
f"One or more inputs on view {self.view_name} have a bad configuration"
|
|
||||||
)
|
|
||||||
if self.output_type is not None:
|
|
||||||
self.output_model = create_model(
|
|
||||||
f"{self.view_name}_output", value=(self.output_type, ...)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, request: HttpRequest, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Entrypoint when this is called as a view.
|
|
||||||
"""
|
|
||||||
# Do a method check if we have one set
|
|
||||||
if self.method and self.method.upper() != request.method:
|
|
||||||
return HttpResponseNotAllowed([self.method])
|
|
||||||
# For each item we can source, go find it if we can
|
|
||||||
query_values = self.get_values(request.GET)
|
|
||||||
body_values = self.get_values(request.POST)
|
|
||||||
files_values = self.get_values(request.FILES)
|
|
||||||
# If it's a PUT or PATCH method, work around Django not handling FILES
|
|
||||||
# or POST on those requests
|
|
||||||
if request.method in ["PATCH", "PUT"]:
|
|
||||||
if request.content_type == "multipart/form-data":
|
|
||||||
POST, FILES = MultiPartParser(
|
|
||||||
request.META, request, request.upload_handlers, request.encoding
|
|
||||||
).parse()
|
|
||||||
body_values = self.get_values(POST)
|
|
||||||
files_values = self.get_values(FILES)
|
|
||||||
elif request.content_type == "application/x-www-form-urlencoded":
|
|
||||||
POST = QueryDict(request.body, encoding=request._encoding)
|
|
||||||
body_values = self.get_values(POST)
|
|
||||||
# If there was a JSON body, go load that
|
|
||||||
if request.content_type == "application/json" and request.body.strip():
|
|
||||||
body_values.update(self.get_values(json.loads(request.body)))
|
|
||||||
values = {}
|
|
||||||
for name, sources in self.sources.items():
|
|
||||||
for source in sources:
|
|
||||||
if source == InputSource.path:
|
|
||||||
if name in kwargs:
|
|
||||||
values[name] = kwargs[name]
|
|
||||||
break
|
|
||||||
elif source == InputSource.query:
|
|
||||||
if name in query_values:
|
|
||||||
values[name] = query_values[name]
|
|
||||||
break
|
|
||||||
elif source == InputSource.body:
|
|
||||||
if name in body_values:
|
|
||||||
values[name] = body_values[name]
|
|
||||||
break
|
|
||||||
elif source == InputSource.file:
|
|
||||||
if name in files_values:
|
|
||||||
values[name] = files_values[name]
|
|
||||||
break
|
|
||||||
elif source == InputSource.body_direct:
|
|
||||||
values[name] = body_values
|
|
||||||
break
|
|
||||||
elif source == InputSource.query_and_body_direct:
|
|
||||||
values[name] = dict(query_values)
|
|
||||||
values[name].update(body_values)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown source {source}")
|
|
||||||
else:
|
|
||||||
values[name] = None
|
|
||||||
# Give that to the Pydantic model to make it handle stuff
|
|
||||||
model_instance = self.input_model(**values)
|
|
||||||
kwargs = {
|
|
||||||
name: getattr(model_instance, name)
|
|
||||||
for name in model_instance.__fields__
|
|
||||||
if values[name] is not None # Trim out missing fields
|
|
||||||
}
|
|
||||||
# Add in any files
|
|
||||||
# TODO: HTTP error if file is not optional
|
|
||||||
for name in self.input_files:
|
|
||||||
kwargs[name] = files_values.get(name, None)
|
|
||||||
# Call the view with those as kwargs
|
|
||||||
try:
|
|
||||||
response = self.view(request, **kwargs)
|
|
||||||
except ApiError as error:
|
|
||||||
return ApiResponse(
|
|
||||||
{"error": error.error}, status=error.status, finalize=True
|
|
||||||
)
|
|
||||||
# If it's not an ApiResponse, make it one
|
|
||||||
if not isinstance(response, ApiResponse):
|
|
||||||
response = ApiResponse(response)
|
|
||||||
# Get pydantic to coerce the output response
|
|
||||||
if self.output_type is not None:
|
|
||||||
response.data = self.output_model(value=response.data).dict()["value"]
|
|
||||||
elif isinstance(response.data, BaseModel):
|
|
||||||
response.data = response.data.dict()
|
|
||||||
response.finalize()
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
api_view = ApiView
|
|
|
@ -5,6 +5,7 @@ dj_database_url~=1.0.0
|
||||||
django-cache-url~=3.4.2
|
django-cache-url~=3.4.2
|
||||||
django-cors-headers~=3.13.0
|
django-cors-headers~=3.13.0
|
||||||
django-debug-toolbar~=3.8.1
|
django-debug-toolbar~=3.8.1
|
||||||
|
django-hatchway~=0.5.0
|
||||||
django-htmx~=1.13.0
|
django-htmx~=1.13.0
|
||||||
django-oauth-toolkit~=2.2.0
|
django-oauth-toolkit~=2.2.0
|
||||||
django-storages[google,boto3]~=1.13.1
|
django-storages[google,boto3]~=1.13.1
|
||||||
|
|
|
@ -196,6 +196,7 @@ INSTALLED_APPS = [
|
||||||
"django.contrib.staticfiles",
|
"django.contrib.staticfiles",
|
||||||
"corsheaders",
|
"corsheaders",
|
||||||
"django_htmx",
|
"django_htmx",
|
||||||
|
"hatchway",
|
||||||
"core",
|
"core",
|
||||||
"activities",
|
"activities",
|
||||||
"api",
|
"api",
|
||||||
|
|
Ładowanie…
Reference in New Issue