diff --git a/static/css/style.css b/static/css/style.css index d9a841b..bbf3238 100644 --- a/static/css/style.css +++ b/static/css/style.css @@ -1358,6 +1358,10 @@ table.metadata td .emoji { cursor: pointer; } +.message.error { + background-color: var(--color-bg-error); +} + /* Identity banner */ .identity-banner { diff --git a/takahe/urls.py b/takahe/urls.py index f40eb0c..8ed647e 100644 --- a/takahe/urls.py +++ b/takahe/urls.py @@ -133,6 +133,11 @@ urlpatterns = [ admin.FederationRoot.as_view(), name="admin_federation", ), + path( + "admin/federation/blocklist/", + admin.FederationBlocklist.as_view(), + name="admin_federation_blocklist", + ), path( "admin/federation//", admin.FederationEdit.as_view(), diff --git a/templates/admin/federation.html b/templates/admin/federation.html index 0e687d9..7ca744c 100644 --- a/templates/admin/federation.html +++ b/templates/admin/federation.html @@ -8,6 +8,9 @@ +
+ Import Blocklist +
{% for domain in page_obj %} diff --git a/templates/admin/federation_blocklist.html b/templates/admin/federation_blocklist.html new file mode 100644 index 0000000..d2327d0 --- /dev/null +++ b/templates/admin/federation_blocklist.html @@ -0,0 +1,18 @@ +{% extends "admin/base_main.html" %} + +{% block subtitle %}Federation Blocklist{% endblock %} + +{% block settings_content %} + + {% csrf_token %} +

Import Blocklist

+ +
+ {% include "forms/_field.html" with field=form.blocklist %} +
+
+ Back + +
+ +{% endblock %} diff --git a/tests/users/services/test_domain.py b/tests/users/services/test_domain.py new file mode 100644 index 0000000..d890309 --- /dev/null +++ b/tests/users/services/test_domain.py @@ -0,0 +1,11 @@ +import pytest + +from users.models import Domain +from users.services import DomainService + + +@pytest.mark.django_db +def test_block(): + DomainService.block(["block1.example.com", "block2.example.com"]) + + assert Domain.objects.filter(blocked=True).count() == 2 diff --git a/users/services/__init__.py b/users/services/__init__.py index aec7009..947be6e 100644 --- a/users/services/__init__.py +++ b/users/services/__init__.py @@ -1,3 +1,4 @@ from .announcement import AnnouncementService # noqa +from .domain import DomainService # noqa from .identity import IdentityService # noqa from .user import UserService # noqa diff --git a/users/services/domain.py b/users/services/domain.py new file mode 100644 index 0000000..e0940fa --- /dev/null +++ b/users/services/domain.py @@ -0,0 +1,22 @@ +from users.models import Domain + + +class DomainService: + """ + High-level domain handling methods + """ + + @classmethod + def block(cls, domains: list[str]) -> None: + domains_to_block = Domain.objects.filter(domain__in=domains) + domains_to_block.update(blocked=True) + + already_blocked = domains_to_block.values_list("domain", flat=True) + domains_to_create = [] + for domain in domains: + if domain not in already_blocked: + domains_to_create.append( + Domain(domain=domain, blocked=True, local=False) + ) + + Domain.objects.bulk_create(domains_to_create) diff --git a/users/views/admin/__init__.py b/users/views/admin/__init__.py index 02f6c13..3330b1b 100644 --- a/users/views/admin/__init__.py +++ b/users/views/admin/__init__.py @@ -23,7 +23,11 @@ from users.views.admin.emoji import ( # noqa EmojiEnable, EmojiRoot, ) -from users.views.admin.federation import FederationEdit, FederationRoot # noqa +from users.views.admin.federation import ( # noqa + FederationBlocklist, + FederationEdit, + FederationRoot, +) from users.views.admin.hashtags import HashtagEdit, HashtagEnable, Hashtags # noqa from users.views.admin.identities import IdentitiesRoot, IdentityEdit # noqa from users.views.admin.invites import InviteCreate, InvitesRoot, InviteView # noqa diff --git a/users/views/admin/federation.py b/users/views/admin/federation.py index a480a0e..22105cb 100644 --- a/users/views/admin/federation.py +++ b/users/views/admin/federation.py @@ -1,4 +1,8 @@ +import csv + from django import forms +from django.contrib import messages +from django.core.validators import FileExtensionValidator, ValidationError from django.db import models from django.shortcuts import get_object_or_404, redirect from django.utils.decorators import method_decorator @@ -6,11 +10,12 @@ from django.views.generic import FormView, ListView from users.decorators import admin_required from users.models import Domain +from users.services import DomainService +from users.views.admin.domains import DomainValidator @method_decorator(admin_required, name="dispatch") class FederationRoot(ListView): - template_name = "admin/federation.html" paginate_by = 50 @@ -35,7 +40,6 @@ class FederationRoot(ListView): @method_decorator(admin_required, name="dispatch") class FederationEdit(FormView): - template_name = "admin/federation_edit.html" extra_context = {"section": "federation"} @@ -78,3 +82,61 @@ class FederationEdit(FormView): "blocked": self.domain.blocked, "notes": self.domain.notes, } + + +@method_decorator(admin_required, name="dispatch") +class FederationBlocklist(FormView): + template_name = "admin/federation_blocklist.html" + extra_context = {"section": "federation"} + error_msg = "The uploaded file has an invalid blocklist CSV format." + success_msg = "The blocklist CSV was processed processed with success!" + + class form_class(forms.Form): + blocklist = forms.FileField( + help_text=( + "Blocklist file with one domain per line. " + "Oliphant blocklist format is also supported." + ), + validators=[FileExtensionValidator(allowed_extensions=["txt", "csv"])], + ) + + def form_valid(self, form): + validator = DomainValidator() + domains = [] + + try: + lines = form.cleaned_data["blocklist"].read().decode("utf-8").splitlines() + + if "#domain" in lines[0]: + reader = csv.DictReader(lines) + else: + reader = csv.DictReader(lines, fieldnames=["#domain"]) + + for row in reader: + domain = row["#domain"].strip() + + try: + validator(domain) + except ValidationError: + # skip adding invalid domain + # to the blocklist + continue + + domains.append(domain) + except (TypeError, ValueError): + messages.error(self.request, self.error_msg) + return redirect(".") + + if not domains: + messages.error(self.request, self.error_msg) + return redirect(".") + + DomainService.block(domains) + + messages.success(self.request, self.success_msg) + return redirect("admin_federation") + + def get_context_data(self, **kwargs): + context = super().get_context_data(**kwargs) + context["page"] = self.request.GET.get("page") + return context