import json
import secrets
import urllib.parse

from django import http
from django.db.models import Q
from django.utils import timezone
from drf_spectacular.utils import extend_schema
from oauth2_provider import exceptions as oauth2_exceptions
from oauth2_provider import views as oauth_views
from oauth2_provider.settings import oauth2_settings
from rest_framework import mixins, permissions, response, views, viewsets
from rest_framework.decorators import action

from funkwhale_api.common import throttling

from .. import models
from . import serializers
from .permissions import ScopePermission


# CRUD endpoints for OAuth applications, looked up by ``client_id``.
#
# NOTE: new documentation in this module is written as ``#`` comments rather
# than docstrings so that API descriptions derived from docstrings stay
# unchanged.
class ApplicationViewSet(
    mixins.CreateModelMixin,
    mixins.ListModelMixin,
    mixins.UpdateModelMixin,
    mixins.DestroyModelMixin,
    mixins.RetrieveModelMixin,
    viewsets.GenericViewSet,
):
    anonymous_policy = True
    # Per-action scope requirement; ``None`` is used for "retrieve" and
    # "create". NOTE(review): presumably consumed by the project's scope
    # permission machinery, which is not visible in this file -- confirm.
    required_scope = {
        "retrieve": None,
        "create": None,
        "destroy": "write:security",
        "update": "write:security",
        "partial_update": "write:security",
        "refresh_token": "write:security",
        "list": "read:security",
    }
    lookup_field = "client_id"
    queryset = models.Application.objects.all().order_by("-created")
    serializer_class = serializers.ApplicationSerializer
    throttling_scopes = {
        "create": {
            "anonymous": "anonymous-oauth-app",
            "authenticated": "authenticated-oauth-app",
        }
    }

    # Create an application with a server-generated client secret.
    #
    # The secret is injected into a copy of the request data before
    # validation, and the plain-text value is put back into the response
    # payload, which is returned with HTTP 201.
    def create(self, request, *args, **kwargs):
        request_data = request.data.copy()
        secret = secrets.token_hex(64)
        request_data["client_secret"] = secret
        serializer = self.get_serializer(data=request_data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        data = serializer.data
        # Since the serializer returns a hashed secret, we need to override it
        # for the response.
        data["client_secret"] = secret
        return response.Response(data, status=201, headers=headers)

    # POST requests use the creation serializer; every other method falls
    # back to ``serializer_class``.
    def get_serializer_class(self):
        if self.request.method.lower() == "post":
            return serializers.CreateApplicationSerializer
        return super().get_serializer_class()

    # Save the application as a confidential, authorization-code client.
    # The owner and a fresh token are only attached for authenticated
    # requests; anonymous requests store ``None`` for both.
    def perform_create(self, serializer):
        return serializer.save(
            client_type=models.Application.CLIENT_CONFIDENTIAL,
            authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE,
            user=self.request.user if self.request.user.is_authenticated else None,
            token=models.get_token() if self.request.user.is_authenticated else None,
        )

    # Build the serializer, switching to the creation serializer when the
    # first positional argument is an object whose ``user`` is the
    # requesting user.
    def get_serializer(self, *args, **kwargs):
        serializer_class = self.get_serializer_class()
        try:
            owned = args[0].user == self.request.user
        except (IndexError, AttributeError):
            # No positional argument, or it has no ``user`` attribute.
            owned = False
        if owned:
            serializer_class = serializers.CreateApplicationSerializer
        kwargs["context"] = self.get_serializer_context()
        return serializer_class(*args, **kwargs)

    # For the listed actions, only applications owned by the requesting
    # user are visible; other actions see the full queryset.
    def get_queryset(self):
        qs = super().get_queryset()
        if self.action in [
            "list",
            "destroy",
            "update",
            "partial_update",
            "refresh_token",
        ]:
            qs = qs.filter(user=self.request.user)
        return qs

    @extend_schema(operation_id="refresh_oauth_token")
    @action(
        detail=True,
        methods=["post"],
        url_name="refresh_token",
        url_path="refresh-token",
    )
    def refresh_token(self, request, *args, **kwargs):
        # Replace the application's token with a new one from
        # ``models.get_token()``. Responds 404 when the application has no
        # owner or the owner is not the requesting user.
        app = self.get_object()
        if not app.user_id or request.user != app.user:
            return response.Response(status=404)
        app.token = models.get_token()
        app.save(update_fields=["token"])
        serializer = serializers.CreateApplicationSerializer(app)
        return response.Response(serializer.data, status=200)


class GrantViewSet(
    mixins.RetrieveModelMixin,
    mixins.DestroyModelMixin,
    mixins.ListModelMixin,
    viewsets.GenericViewSet,
):
    """
    This is a viewset that list applications that have access to the request user
    account, to allow revoking tokens easily.
    """

    permission_classes = [permissions.IsAuthenticated, ScopePermission]
    required_scope = "security"
    lookup_field = "client_id"
    queryset = models.Application.objects.all().order_by("-created")
    serializer_class = serializers.ApplicationSerializer
    pagination_class = None

    # Applications for which the requesting user has an unexpired grant,
    # any access token, or a refresh token whose ``revoked`` is unset.
    def get_queryset(self):
        now = timezone.now()
        queryset = super().get_queryset()
        grants = models.Grant.objects.filter(user=self.request.user, expires__gt=now)
        access_tokens = models.AccessToken.objects.filter(user=self.request.user)
        refresh_tokens = models.RefreshToken.objects.filter(
            user=self.request.user, revoked=None
        )
        return queryset.filter(
            Q(pk__in=access_tokens.values("application"))
            | Q(pk__in=refresh_tokens.values("application"))
            | Q(pk__in=grants.values("application"))
        ).distinct()

    # NOTE(review): this class does not include ``CreateModelMixin`` in its
    # bases, so nothing visible in this file calls this hook -- confirm
    # whether it is still needed.
    def perform_create(self, serializer):
        return serializer.save(
            client_type=models.Application.CLIENT_CONFIDENTIAL,
            authorization_grant_type=models.Application.GRANT_AUTHORIZATION_CODE,
        )

    # "Destroying" an application here does not delete it: it revokes the
    # requesting user's access and refresh tokens for that application and
    # deletes the user's grants.
    def perform_destroy(self, instance):
        application = instance
        access_tokens = application.accesstoken_set.filter(user=self.request.user)
        for token in access_tokens:
            token.revoke()
        refresh_tokens = application.refreshtoken_set.filter(user=self.request.user)
        for token in refresh_tokens:
            try:
                token.revoke()
            except models.AccessToken.DoesNotExist:
                # The related access token is gone; mark the refresh token
                # as revoked by hand instead.
                token.access_token = None
                token.revoked = timezone.now()
                token.save(update_fields=["access_token", "revoked"])
        grants = application.grant_set.filter(user=self.request.user)
        grants.delete()


# Authorization endpoint that answers with JSON payloads (built by
# ``json_payload``) instead of templates in the overridden code paths.
class AuthorizeView(views.APIView, oauth_views.AuthorizationView):
    permission_classes = [permissions.IsAuthenticated]
    server_class = oauth2_settings.OAUTH2_SERVER_CLASS
    validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
    oauthlib_backend_class = oauth2_settings.OAUTH2_BACKEND_CLASS
    skip_authorization_completely = False
    oauth2_data = {}

    def form_invalid(self, form):
        """
        Return a JSON response instead of a template one
        """
        errors = form.errors
        return self.json_payload(errors, status_code=400)

    # Apply the "oauth-authorize" throttle before delegating to the parent.
    def post(self, request, *args, **kwargs):
        throttling.check_request(request, "oauth-authorize")
        return super().post(request, *args, **kwargs)

    # An unknown application is reported as a 400 JSON error.
    def form_valid(self, form):
        try:
            return super().form_valid(form)
        except models.Application.DoesNotExist:
            return self.json_payload({"non_field_errors": ["Invalid application"]}, 400)

    # For XMLHttpRequest callers, return the redirect target (and the
    # authorization code extracted from its query string) as JSON with
    # status 200; otherwise defer to the parent's redirect.
    def redirect(self, redirect_to, application):
        if self.request.META.get("HTTP_X_REQUESTED_WITH") == "XMLHttpRequest":
            # Web client need this to be able to redirect the user
            query = urllib.parse.urlparse(redirect_to).query
            # A redirect target without a "code" query parameter used to
            # raise KeyError here (presumably error redirects such as a
            # denied authorization -- verify against oauth2_provider).
            # Report ``None`` so the client still receives the redirect URI.
            code = urllib.parse.parse_qs(query).get("code", [None])[0]
            return self.json_payload(
                {"redirect_uri": redirect_to, "code": code}, status_code=200
            )
        return super().redirect(redirect_to, application)

    # Fatal client errors become a 400 JSON payload carrying the oauthlib
    # error description; anything else is handled by the parent.
    def error_response(self, error, application):
        if isinstance(error, oauth2_exceptions.FatalClientError):
            return self.json_payload({"detail": error.oauthlib_error.description}, 400)
        return super().error_response(error, application)

    # Serialize ``payload`` with ``json.dumps`` into an HttpResponse with
    # the given status and an ``application/json`` content type.
    def json_payload(self, payload, status_code):
        return http.HttpResponse(
            json.dumps(payload), status=status_code, content_type="application/json"
        )

    # Returns a 401 JSON payload.
    def handle_no_permission(self):
        return self.json_payload(
            {"detail": "Authentication credentials were not provided."}, 401
        )


# Token endpoint with the "oauth-token" throttle applied to POST.
class TokenView(oauth_views.TokenView):
    def post(self, request, *args, **kwargs):
        throttling.check_request(request, "oauth-token")
        return super().post(request, *args, **kwargs)


# Token revocation endpoint with the "oauth-revoke-token" throttle applied
# to POST.
class RevokeTokenView(oauth_views.RevokeTokenView):
    def post(self, request, *args, **kwargs):
        throttling.check_request(request, "oauth-revoke-token")
        return super().post(request, *args, **kwargs)