2019-03-25 16:02:51 +00:00
|
|
|
import json
|
|
|
|
import urllib.parse
|
|
|
|
|
|
|
|
from django import http
|
|
|
|
from django.utils import timezone
|
|
|
|
from django.db.models import Q
|
|
|
|
from rest_framework import mixins, permissions, views, viewsets
|
|
|
|
|
|
|
|
from oauth2_provider import exceptions as oauth2_exceptions
|
|
|
|
from oauth2_provider import views as oauth_views
|
|
|
|
from oauth2_provider.settings import oauth2_settings
|
|
|
|
|
2019-09-17 09:23:59 +00:00
|
|
|
from funkwhale_api.common import throttling
|
|
|
|
|
2019-03-25 16:02:51 +00:00
|
|
|
from .. import models
|
|
|
|
from .permissions import ScopePermission
|
|
|
|
from . import serializers
|
|
|
|
|
|
|
|
|
|
|
|
class ApplicationViewSet(
|
|
|
|
mixins.CreateModelMixin,
|
|
|
|
mixins.ListModelMixin,
|
|
|
|
mixins.UpdateModelMixin,
|
|
|
|
mixins.DestroyModelMixin,
|
|
|
|
mixins.RetrieveModelMixin,
|
|
|
|
viewsets.GenericViewSet,
|
|
|
|
):
|
|
|
|
anonymous_policy = True
|
|
|
|
required_scope = {
|
|
|
|
"retrieve": None,
|
|
|
|
"create": None,
|
|
|
|
"destroy": "write:security",
|
|
|
|
"update": "write:security",
|
|
|
|
"partial_update": "write:security",
|
|
|
|
"list": "read:security",
|
|
|
|
}
|
|
|
|
lookup_field = "client_id"
|
|
|
|
queryset = models.Application.objects.all().order_by("-created")
|
|
|
|
serializer_class = serializers.ApplicationSerializer
|
2019-09-17 09:23:59 +00:00
|
|
|
throttling_scopes = {
|
|
|
|
"create": {
|
|
|
|
"anonymous": "anonymous-oauth-app",
|
|
|
|
"authenticated": "authenticated-oauth-app",
|
|
|
|
}
|
|
|
|
}
|
2019-03-25 16:02:51 +00:00
|
|
|
|
|
|
|
def get_serializer_class(self):
|
|
|
|
if self.request.method.lower() == "post":
|
|
|
|
return serializers.CreateApplicationSerializer
|
|
|
|
return super().get_serializer_class()
|
|
|
|
|
|
|
|
def perform_create(self, serializer):
|
|
|
|
return serializer.save(
|
|
|
|
client_type=models.Application.CLIENT_CONFIDENTIAL,
|
|
|
|
authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE,
|
|
|
|
user=self.request.user if self.request.user.is_authenticated else None,
|
|
|
|
)
|
|
|
|
|
|
|
|
def get_serializer(self, *args, **kwargs):
|
|
|
|
serializer_class = self.get_serializer_class()
|
|
|
|
try:
|
|
|
|
owned = args[0].user == self.request.user
|
|
|
|
except (IndexError, AttributeError):
|
|
|
|
owned = False
|
|
|
|
if owned:
|
|
|
|
serializer_class = serializers.CreateApplicationSerializer
|
|
|
|
|
|
|
|
kwargs["context"] = self.get_serializer_context()
|
|
|
|
return serializer_class(*args, **kwargs)
|
|
|
|
|
|
|
|
def get_queryset(self):
|
|
|
|
qs = super().get_queryset()
|
|
|
|
if self.action in ["list", "destroy", "update", "partial_update"]:
|
|
|
|
qs = qs.filter(user=self.request.user)
|
|
|
|
return qs
|
|
|
|
|
|
|
|
|
|
|
|
class GrantViewSet(
|
|
|
|
mixins.RetrieveModelMixin,
|
|
|
|
mixins.DestroyModelMixin,
|
|
|
|
mixins.ListModelMixin,
|
|
|
|
viewsets.GenericViewSet,
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
This is a viewset that list applications that have access to the request user
|
|
|
|
account, to allow revoking tokens easily.
|
|
|
|
"""
|
|
|
|
|
|
|
|
permission_classes = [permissions.IsAuthenticated, ScopePermission]
|
|
|
|
required_scope = "security"
|
|
|
|
lookup_field = "client_id"
|
|
|
|
queryset = models.Application.objects.all().order_by("-created")
|
|
|
|
serializer_class = serializers.ApplicationSerializer
|
|
|
|
pagination_class = None
|
|
|
|
|
|
|
|
def get_queryset(self):
|
|
|
|
now = timezone.now()
|
|
|
|
queryset = super().get_queryset()
|
|
|
|
grants = models.Grant.objects.filter(user=self.request.user, expires__gt=now)
|
|
|
|
access_tokens = models.AccessToken.objects.filter(user=self.request.user)
|
|
|
|
refresh_tokens = models.RefreshToken.objects.filter(
|
|
|
|
user=self.request.user, revoked=None
|
|
|
|
)
|
|
|
|
|
|
|
|
return queryset.filter(
|
|
|
|
Q(pk__in=access_tokens.values("application"))
|
|
|
|
| Q(pk__in=refresh_tokens.values("application"))
|
|
|
|
| Q(pk__in=grants.values("application"))
|
|
|
|
).distinct()
|
|
|
|
|
|
|
|
def perform_create(self, serializer):
|
|
|
|
return serializer.save(
|
|
|
|
client_type=models.Application.CLIENT_CONFIDENTIAL,
|
|
|
|
authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE,
|
|
|
|
)
|
|
|
|
|
|
|
|
def perform_destroy(self, instance):
|
|
|
|
application = instance
|
|
|
|
|
|
|
|
access_tokens = application.accesstoken_set.filter(user=self.request.user)
|
|
|
|
for token in access_tokens:
|
|
|
|
token.revoke()
|
|
|
|
|
|
|
|
refresh_tokens = application.refreshtoken_set.filter(user=self.request.user)
|
|
|
|
for token in refresh_tokens:
|
|
|
|
try:
|
|
|
|
token.revoke()
|
|
|
|
except models.AccessToken.DoesNotExist:
|
|
|
|
token.access_token = None
|
|
|
|
token.revoked = timezone.now()
|
|
|
|
token.save(update_fields=["access_token", "revoked"])
|
|
|
|
grants = application.grant_set.filter(user=self.request.user)
|
|
|
|
grants.delete()
|
|
|
|
|
|
|
|
|
|
|
|
class AuthorizeView(views.APIView, oauth_views.AuthorizationView):
|
|
|
|
permission_classes = [permissions.IsAuthenticated]
|
|
|
|
server_class = oauth2_settings.OAUTH2_SERVER_CLASS
|
|
|
|
validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
|
|
|
|
oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS
|
|
|
|
skip_authorization_completely = False
|
|
|
|
oauth2_data = {}
|
|
|
|
|
|
|
|
def form_invalid(self, form):
|
|
|
|
"""
|
|
|
|
Return a JSON response instead of a template one
|
|
|
|
"""
|
|
|
|
errors = form.errors
|
|
|
|
|
|
|
|
return self.json_payload(errors, status_code=400)
|
|
|
|
|
2019-09-17 09:23:59 +00:00
|
|
|
def post(self, request, *args, **kwargs):
|
|
|
|
throttling.check_request(request, "oauth-authorize")
|
|
|
|
return super().post(request, *args, **kwargs)
|
|
|
|
|
2019-03-25 16:02:51 +00:00
|
|
|
def form_valid(self, form):
|
|
|
|
try:
|
|
|
|
response = super().form_valid(form)
|
|
|
|
|
|
|
|
except models.Application.DoesNotExist:
|
|
|
|
return self.json_payload({"non_field_errors": ["Invalid application"]}, 400)
|
|
|
|
|
|
|
|
if self.request.is_ajax() and response.status_code == 302:
|
|
|
|
# Web client need this to be able to redirect the user
|
|
|
|
query = urllib.parse.urlparse(response["Location"]).query
|
|
|
|
code = urllib.parse.parse_qs(query)["code"][0]
|
|
|
|
return self.json_payload(
|
|
|
|
{"redirect_uri": response["Location"], "code": code}, status_code=200
|
|
|
|
)
|
|
|
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
def error_response(self, error, application):
|
|
|
|
if isinstance(error, oauth2_exceptions.FatalClientError):
|
|
|
|
return self.json_payload({"detail": error.oauthlib_error.description}, 400)
|
|
|
|
return super().error_response(error, application)
|
|
|
|
|
|
|
|
def json_payload(self, payload, status_code):
|
|
|
|
return http.HttpResponse(
|
|
|
|
json.dumps(payload), status=status_code, content_type="application/json"
|
|
|
|
)
|
|
|
|
|
|
|
|
def handle_no_permission(self):
|
|
|
|
return self.json_payload(
|
|
|
|
{"detail": "Authentication credentials were not provided."}, 401
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class TokenView(oauth_views.TokenView):
|
2019-09-17 09:23:59 +00:00
|
|
|
def post(self, request, *args, **kwargs):
|
|
|
|
throttling.check_request(request, "oauth-token")
|
|
|
|
return super().post(request, *args, **kwargs)
|
2019-03-25 16:02:51 +00:00
|
|
|
|
|
|
|
|
|
|
|
class RevokeTokenView(oauth_views.RevokeTokenView):
|
2019-09-17 09:23:59 +00:00
|
|
|
def post(self, request, *args, **kwargs):
|
|
|
|
throttling.check_request(request, "oauth-revoke-token")
|
|
|
|
return super().post(request, *args, **kwargs)
|