From 550dbe46cc7d5fe2b239733bcd00e2a214b5ab11 Mon Sep 17 00:00:00 2001
From: Agate
", 1) css = "".format(css) tail = body + "\n" + css + "\n" + tail - return http.HttpResponse(head + tail) + + # set a csrf token so that visitor can login / query API if needed + token = csrf.get_token(request) + response = http.HttpResponse(head + tail) + response.set_cookie("csrftoken", token, max_age=None) + return response MANIFEST_LINK_REGEX = re.compile(r"]*rel=(?:'|\")?manifest(?:'|\")?[^>]*>") diff --git a/api/funkwhale_api/users/api_urls.py b/api/funkwhale_api/users/api_urls.py index 89930f57b..1c39797f2 100644 --- a/api/funkwhale_api/users/api_urls.py +++ b/api/funkwhale_api/users/api_urls.py @@ -1,8 +1,11 @@ +from django.conf.urls import url from funkwhale_api.common import routers - from . import views router = routers.OptionalSlashRouter() router.register(r"users", views.UserViewSet, "users") -urlpatterns = router.urls +urlpatterns = [ + url(r"^users/login/?$", views.login, name="login"), + url(r"^users/logout/?$", views.logout, name="logout"), +] + router.urls diff --git a/api/funkwhale_api/users/serializers.py b/api/funkwhale_api/users/serializers.py index 542f6e58a..8646d3b4a 100644 --- a/api/funkwhale_api/users/serializers.py +++ b/api/funkwhale_api/users/serializers.py @@ -4,6 +4,8 @@ from django.core import validators from django.utils.deconstruct import deconstructible from django.utils.translation import gettext_lazy as _ +from django.contrib import auth + from rest_auth.serializers import PasswordResetSerializer as PRS from rest_auth.registration.serializers import RegisterSerializer as RS, get_adapter from rest_framework import serializers @@ -265,3 +267,23 @@ class UserDeleteSerializer(serializers.Serializer): if not value: raise serializers.ValidationError("Please confirm deletion") return value + + +class LoginSerializer(serializers.Serializer): + username = serializers.CharField() + password = serializers.CharField() + + def validate(self, data): + user = auth.authenticate(request=self.context.get("request"), **data) + if not user: + raise serializers.ValidationError( + "Unable to log in with provided credentials" + ) + + if not user.is_active: + raise serializers.ValidationError("This account was disabled") + + return user + + def save(self, request): + return auth.login(request, self.validated_data) diff --git a/api/funkwhale_api/users/views.py b/api/funkwhale_api/users/views.py index 848bc7e6b..a143c4fd2 100644 --- a/api/funkwhale_api/users/views.py +++ b/api/funkwhale_api/users/views.py @@ -1,12 +1,20 @@ +import json + +from django import http +from django.contrib import auth +from django.middleware import csrf + from allauth.account.adapter import get_adapter from rest_auth import views as rest_auth_views from rest_auth.registration import views as registration_views -from rest_framework import mixins, viewsets +from rest_framework import mixins +from rest_framework import viewsets from rest_framework.decorators import action from rest_framework.response import Response from funkwhale_api.common import authentication from funkwhale_api.common import preferences +from funkwhale_api.common import throttling from . import models, serializers, tasks @@ -105,3 +113,26 @@ class UserViewSet(mixins.UpdateModelMixin, viewsets.GenericViewSet): if not self.request.user.username == kwargs.get("username"): return Response(status=403) return super().partial_update(request, *args, **kwargs) + + +def login(request): + throttling.check_request(request, "login") + if request.method != "POST": + return http.HttpResponse(status=405) + serializer = serializers.LoginSerializer( + data=request.POST, context={"request": request} + ) + if not serializer.is_valid(): + return http.HttpResponse( + json.dumps(serializer.errors), status=400, content_type="application/json" + ) + serializer.save(request) + csrf.rotate_token(request) + return http.HttpResponse(status=200) + + +def logout(request): + if request.method != "POST": + return http.HttpResponse(status=405) + auth.logout(request) + return http.HttpResponse(status=200) diff --git a/api/tests/common/test_middleware.py b/api/tests/common/test_middleware.py index 8f04ba318..b5d4d02f1 100644 --- a/api/tests/common/test_middleware.py +++ b/api/tests/common/test_middleware.py @@ -14,7 +14,7 @@ from funkwhale_api.common import utils def test_spa_fallback_middleware_no_404(mocker): get_response = mocker.Mock() get_response.return_value = mocker.Mock(status_code=200) - request = mocker.Mock(path="/") + request = mocker.Mock(path="/", META={}) m = middleware.SPAFallbackMiddleware(get_response) assert m(request) == get_response.return_value @@ -26,7 +26,7 @@ def test_spa_middleware_calls_should_fallback_false(mocker): should_falback = mocker.patch.object( middleware, "should_fallback_to_spa", return_value=False ) - request = mocker.Mock(path="/") + request = mocker.Mock(path="/", META={}) m = middleware.SPAFallbackMiddleware(get_response) @@ -37,7 +37,7 @@ def test_spa_middleware_calls_should_fallback_false(mocker): def test_spa_middleware_should_fallback_true(mocker): get_response = mocker.Mock() get_response.return_value = mocker.Mock(status_code=404) - request = mocker.Mock(path="/") + request = mocker.Mock(path="/", META={}) mocker.patch.object(middleware, "should_fallback_to_spa", return_value=True) serve_spa = mocker.patch.object(middleware, "serve_spa") m = middleware.SPAFallbackMiddleware(get_response) @@ -56,7 +56,7 @@ def test_should_fallback(path, expected, mocker): def test_serve_spa_from_cache(mocker, settings, preferences, no_api_auth): preferences["instance__name"] = 'Best Funkwhale "pod"' - request = mocker.Mock(path="/") + request = mocker.Mock(path="/", META={}) get_spa_html = mocker.patch.object( middleware, "get_spa_html", @@ -155,7 +155,7 @@ def test_get_route_head_tags(mocker, settings): def test_serve_spa_includes_custom_css(mocker, no_api_auth): - request = mocker.Mock(path="/") + request = mocker.Mock(path="/", META={}) mocker.patch.object( middleware, "get_spa_html", @@ -178,6 +178,23 @@ def test_serve_spa_includes_custom_css(mocker, no_api_auth): assert response.content == "\n".join(expected).encode() +def test_serve_spa_sets_csrf_token(mocker, no_api_auth): + request = mocker.Mock(path="/", META={}) + get_token = mocker.patch.object(middleware.csrf, "get_token", return_value="test") + mocker.patch.object( + middleware, + "get_spa_html", + return_value="