diff --git a/api/views/accounts.py b/api/views/accounts.py index dc68488..007ac8b 100644 --- a/api/views/accounts.py +++ b/api/views/accounts.py @@ -3,6 +3,7 @@ from typing import Any from django.core.files import File from django.http import HttpRequest from django.shortcuts import get_object_or_404 +from hatchway import ApiResponse, QueryOrBody, api_view from activities.models import Post from activities.services import SearchService @@ -10,7 +11,6 @@ from api import schemas from api.decorators import identity_required from api.pagination import MastodonPaginator, PaginatingApiResponse, PaginationResult from core.models import Config -from hatchway import ApiResponse, QueryOrBody, api_view from users.models import Identity from users.services import IdentityService from users.shortcuts import by_handle_or_404 @@ -224,8 +224,8 @@ def account_follow(request, id: str, reblogs: bool = True) -> schemas.Relationsh identity = get_object_or_404( Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id ) - service = IdentityService(identity) - service.follow_from(request.identity, boosts=reblogs) + service = IdentityService(request.identity) + service.follow(identity, boosts=reblogs) return schemas.Relationship.from_identity_pair(identity, request.identity) @@ -235,8 +235,8 @@ def account_unfollow(request, id: str) -> schemas.Relationship: identity = get_object_or_404( Identity.objects.exclude(restriction=Identity.Restriction.blocked), pk=id ) - service = IdentityService(identity) - service.unfollow_from(request.identity) + service = IdentityService(request.identity) + service.unfollow(identity) return schemas.Relationship.from_identity_pair(identity, request.identity) @@ -244,8 +244,8 @@ def account_unfollow(request, id: str) -> schemas.Relationship: @identity_required def account_block(request, id: str) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) - service = IdentityService(identity) - service.block_from(request.identity) + service = IdentityService(request.identity) + service.block(identity) return schemas.Relationship.from_identity_pair(identity, request.identity) @@ -253,8 +253,8 @@ def account_block(request, id: str) -> schemas.Relationship: @identity_required def account_unblock(request, id: str) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) - service = IdentityService(identity) - service.unblock_from(request.identity) + service = IdentityService(request.identity) + service.unblock(identity) return schemas.Relationship.from_identity_pair(identity, request.identity) @@ -267,9 +267,9 @@ def account_mute( duration: QueryOrBody[int] = 0, ) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) - service = IdentityService(identity) - service.mute_from( - request.identity, + service = IdentityService(request.identity) + service.mute( + identity, duration=duration, include_notifications=notifications, ) @@ -280,8 +280,8 @@ def account_mute( @api_view.post def account_unmute(request, id: str) -> schemas.Relationship: identity = get_object_or_404(Identity, pk=id) - service = IdentityService(identity) - service.unmute_from(request.identity) + service = IdentityService(request.identity) + service.unmute(identity) return schemas.Relationship.from_identity_pair(identity, request.identity) diff --git a/takahe/urls.py b/takahe/urls.py index 456636a..9cb157d 100644 --- a/takahe/urls.py +++ b/takahe/urls.py @@ -55,6 +55,21 @@ urlpatterns = [ settings.InterfacePage.as_view(), name="settings_interface", ), + path( + "settings/import_export/", + settings.ImportExportPage.as_view(), + name="settings_import_export", + ), + path( + "settings/import_export/following.csv", + settings.CsvFollowing.as_view(), + name="settings_export_following_csv", + ), + path( + "settings/import_export/followers.csv", + settings.CsvFollowers.as_view(), + name="settings_export_followers_csv", + ), path( "admin/", admin.AdminRoot.as_view(), diff --git a/templates/settings/_menu.html b/templates/settings/_menu.html index b038e61..2cc01b0 100644 --- a/templates/settings/_menu.html +++ b/templates/settings/_menu.html @@ -6,6 +6,9 @@ Interface + + Import/Export +

Account

Login & Security diff --git a/templates/settings/import_export.html b/templates/settings/import_export.html new file mode 100644 index 0000000..d0d5de9 --- /dev/null +++ b/templates/settings/import_export.html @@ -0,0 +1,68 @@ +{% extends "settings/base.html" %} + +{% block subtitle %}Import/Export{% endblock %} + +{% block content %} +
+ {% csrf_token %} + +
+ Import + {% if bad_format %} +
Error: The file you uploaded was not a valid {{ bad_format }} CSV.
+ {% endif %} + {% if success %} +
Your {{ success }} CSV import was received. It will be processed in the background.
+ {% endif %} + {% include "forms/_field.html" with field=form.csv %} + {% include "forms/_field.html" with field=form.import_type %} + {% include "forms/_field.html" with field=form.replace %} +
+ +
+ +
+ +
+ Export + + + + + + + + + + + + + + + + + +
+ Following list + {{ numbers.outbound_follows }} {{ numbers.outbound_follows|pluralize:"follow,follows" }} + + Download CSV +
+ Followers list + {{ numbers.inbound_follows }} {{ numbers.inbound_follows|pluralize:"follower,followers" }} + + Download CSV +
+ Individual blocks + {{ numbers.blocks }} {{ numbers.blocks|pluralize:"people,people" }} + + +
+ Individual mutes + {{ numbers.mutes }} {{ numbers.mutes|pluralize:"people,people" }} + + +
+
+
+{% endblock %} diff --git a/tests/activities/models/test_timeline_event.py b/tests/activities/models/test_timeline_event.py index b0fff39..704e405 100644 --- a/tests/activities/models/test_timeline_event.py +++ b/tests/activities/models/test_timeline_event.py @@ -207,8 +207,8 @@ def test_clear_timeline( Ensures that timeline clearing works as expected. """ # Follow the remote user - service = IdentityService(remote_identity) - service.follow_from(identity) + service = IdentityService(identity) + service.follow(remote_identity) # Create an inbound new post message mentioning us message = { "id": "test", @@ -243,9 +243,9 @@ def test_clear_timeline( # Now, submit either a user block (for full clear) or unfollow (for post clear) if full: - service.block_from(identity) + service.block(remote_identity) else: - service.unfollow_from(identity) + service.unfollow(remote_identity) # Run stator once to process the timeline clear message stator.run_single_cycle_sync() diff --git a/tests/users/models/test_follow.py b/tests/users/models/test_follow.py index c02ce15..d23dc1e 100644 --- a/tests/users/models/test_follow.py +++ b/tests/users/models/test_follow.py @@ -20,7 +20,7 @@ def test_follow( Ensures that follow sending and acceptance works """ # Make the follow - follow = IdentityService(remote_identity).follow_from(identity) + follow = IdentityService(identity).follow(remote_identity) assert Follow.objects.get(pk=follow.pk).state == FollowStates.unrequested # Run stator to make it try and send out the remote request httpx_mock.add_response( diff --git a/users/models/inbox_message.py b/users/models/inbox_message.py index d65495d..82a166b 100644 --- a/users/models/inbox_message.py +++ b/users/models/inbox_message.py @@ -16,6 +16,7 @@ class InboxMessageStates(StateGraph): async def handle_received(cls, instance: "InboxMessage"): from activities.models import Post, PostInteraction, TimelineEvent from users.models import Block, Follow, Identity, Report + from users.services import IdentityService match instance.message_type: case "follow": @@ -154,6 +155,10 @@ class InboxMessageStates(StateGraph): await sync_to_async(TimelineEvent.handle_clear_timeline)( instance.message["object"] ) + case "addfollow": + await sync_to_async(IdentityService.handle_internal_add_follow)( + instance.message["object"] + ) case unknown: raise ValueError( f"Cannot handle activity of type __internal__.{unknown}" diff --git a/users/services/identity.py b/users/services/identity.py index 049ab40..5a4b741 100644 --- a/users/services/identity.py +++ b/users/services/identity.py @@ -58,8 +58,10 @@ class IdentityService: def following(self) -> models.QuerySet[Identity]: return ( - Identity.objects.active() - .filter(inbound_follows__source=self.identity) + Identity.objects.filter( + inbound_follows__source=self.identity, + inbound_follows__state__in=FollowStates.group_active(), + ) .not_deleted() .order_by("username") .select_related("domain") @@ -67,91 +69,94 @@ class IdentityService: def followers(self) -> models.QuerySet[Identity]: return ( - Identity.objects.filter(outbound_follows__target=self.identity) + Identity.objects.filter( + outbound_follows__target=self.identity, + inbound_follows__state__in=FollowStates.group_active(), + ) .not_deleted() .order_by("username") .select_related("domain") ) - def follow_from(self, from_identity: Identity, boosts=True) -> Follow: + def follow(self, target_identity: Identity, boosts=True) -> Follow: """ Follows a user (or does nothing if already followed). Returns the follow. """ - if from_identity == self.identity: + if target_identity == self.identity: raise ValueError("You cannot follow yourself") - return Follow.create_local(from_identity, self.identity, boosts=boosts) + return Follow.create_local(self.identity, target_identity, boosts=boosts) - def unfollow_from(self, from_identity: Identity): + def unfollow(self, target_identity: Identity): """ Unfollows a user (or does nothing if not followed). """ - if from_identity == self.identity: + if target_identity == self.identity: raise ValueError("You cannot unfollow yourself") - existing_follow = Follow.maybe_get(from_identity, self.identity) + existing_follow = Follow.maybe_get(self.identity, target_identity) if existing_follow: existing_follow.transition_perform(FollowStates.undone) InboxMessage.create_internal( { "type": "ClearTimeline", - "actor": from_identity.pk, - "object": self.identity.pk, + "object": target_identity.pk, + "actor": self.identity.pk, } ) - def block_from(self, from_identity: Identity) -> Block: + def block(self, target_identity: Identity) -> Block: """ Blocks a user. """ - if from_identity == self.identity: + if target_identity == self.identity: raise ValueError("You cannot block yourself") - self.unfollow_from(from_identity) - block = Block.create_local_block(from_identity, self.identity) + self.unfollow(target_identity) + block = Block.create_local_block(self.identity, target_identity) InboxMessage.create_internal( { "type": "ClearTimeline", - "actor": from_identity.pk, - "object": self.identity.pk, + "actor": self.identity.pk, + "object": target_identity.pk, "fullErase": True, } ) return block - def unblock_from(self, from_identity: Identity): + def unblock(self, target_identity: Identity): """ Unlocks a user """ - if from_identity == self.identity: + if target_identity == self.identity: raise ValueError("You cannot unblock yourself") - existing_block = Block.maybe_get(from_identity, self.identity, mute=False) + existing_block = Block.maybe_get(self.identity, target_identity, mute=False) if existing_block and existing_block.active: existing_block.transition_perform(BlockStates.undone) - def mute_from( + def mute( self, - from_identity: Identity, + target_identity: Identity, duration: int = 0, include_notifications: bool = False, ) -> Block: """ Mutes a user. """ - if from_identity == self.identity: + if target_identity == self.identity: raise ValueError("You cannot mute yourself") return Block.create_local_mute( - from_identity, self.identity, + target_identity, duration=duration or None, include_notifications=include_notifications, ) - def unmute_from(self, from_identity: Identity): + def unmute(self, target_identity: Identity): """ Unmutes a user """ - if from_identity == self.identity: + if target_identity == self.identity: raise ValueError("You cannot unmute yourself") - existing_block = Block.maybe_get(from_identity, self.identity, mute=True) + existing_block = Block.maybe_get(self.identity, target_identity, mute=True) if existing_block and existing_block.active: existing_block.transition_perform(BlockStates.undone) @@ -234,3 +239,26 @@ class IdentityService: file.name, resize_image(file, size=(1500, 500)), ) + + @classmethod + def handle_internal_add_follow(cls, payload): + """ + Handles an inbox message saying we need to follow a handle + + Message format: + { + "type": "AddFollow", + "source": "90310938129083", + "target_handle": "andrew@aeracode.org", + "boosts": true, + } + """ + # Retrieve ourselves + self = cls(Identity.objects.get(pk=payload["source"])) + # Get the remote end (may need a fetch) + username, domain = payload["target_handle"].split("@") + target_identity = Identity.by_username_and_domain(username, domain, fetch=True) + if target_identity is None: + raise ValueError(f"Cannot find identity to follow: {target_identity}") + # Follow! + self.follow(target_identity=target_identity, boosts=payload.get("boosts", True)) diff --git a/users/views/identity.py b/users/views/identity.py index 2cce29f..84664f2 100644 --- a/users/views/identity.py +++ b/users/views/identity.py @@ -249,21 +249,21 @@ class ActionIdentity(View): # See what action we should perform action = self.request.POST["action"] if action == "follow": - IdentityService(identity).follow_from(self.request.identity) + IdentityService(request.identity).follow(identity) elif action == "unfollow": - IdentityService(identity).unfollow_from(self.request.identity) + IdentityService(request.identity).unfollow(identity) elif action == "block": - IdentityService(identity).block_from(self.request.identity) + IdentityService(request.identity).block(identity) elif action == "unblock": - IdentityService(identity).unblock_from(self.request.identity) + IdentityService(request.identity).unblock(identity) elif action == "mute": - IdentityService(identity).mute_from(self.request.identity) + IdentityService(request.identity).mute(identity) elif action == "unmute": - IdentityService(identity).unmute_from(self.request.identity) + IdentityService(request.identity).unmute(identity) elif action == "hide_boosts": - IdentityService(identity).follow_from(self.request.identity, boosts=False) + IdentityService(request.identity).follow(identity, boosts=False) elif action == "show_boosts": - IdentityService(identity).follow_from(self.request.identity, boosts=True) + IdentityService(request.identity).follow(identity, boosts=True) else: raise ValueError(f"Cannot handle identity action {action}") return redirect(identity.urls.view) diff --git a/users/views/settings/__init__.py b/users/views/settings/__init__.py index f3332d3..e58c4bf 100644 --- a/users/views/settings/__init__.py +++ b/users/views/settings/__init__.py @@ -2,6 +2,11 @@ from django.utils.decorators import method_decorator from django.views.generic import RedirectView from users.decorators import identity_required +from users.views.settings.import_export import ( # noqa + CsvFollowers, + CsvFollowing, + ImportExportPage, +) from users.views.settings.interface import InterfacePage # noqa from users.views.settings.profile import ProfilePage # noqa from users.views.settings.security import SecurityPage # noqa diff --git a/users/views/settings/import_export.py b/users/views/settings/import_export.py new file mode 100644 index 0000000..b165503 --- /dev/null +++ b/users/views/settings/import_export.py @@ -0,0 +1,154 @@ +import csv + +from django import forms +from django.http import HttpResponse +from django.shortcuts import redirect +from django.utils.decorators import method_decorator +from django.views.generic import FormView, View + +from users.decorators import identity_required +from users.models import Follow, InboxMessage + + +@method_decorator(identity_required, name="dispatch") +class ImportExportPage(FormView): + """ + Lets the identity's profile be edited + """ + + template_name = "settings/import_export.html" + extra_context = {"section": "importexport"} + + class form_class(forms.Form): + csv = forms.FileField(help_text="The CSV file you want to import") + import_type = forms.ChoiceField( + help_text="The type of data you wish to import", + choices=[("following", "Following list")], + ) + + def form_valid(self, form): + # Load CSV (we don't touch the DB till the whole file comes in clean) + try: + lines = form.cleaned_data["csv"].read().decode("utf-8").splitlines() + reader = csv.DictReader(lines) + prepared_data = [] + for row in reader: + entry = { + "handle": row["Account address"], + "boosts": not (row["Show boosts"].lower().strip()[0] == "f"), + } + if len(entry["handle"].split("@")) != 2: + raise ValueError("Handle looks wrong") + prepared_data.append(entry) + except (TypeError, ValueError): + return redirect(".?bad_format=following") + # For each one, add an inbox message to create that follow + # We can't do them all inline here as the identity fetch might take ages + for entry in prepared_data: + InboxMessage.create_internal( + { + "type": "AddFollow", + "source": self.request.identity.pk, + "target_handle": entry["handle"], + "boosts": entry["boosts"], + } + ) + return redirect(".?success=following") + + def get_context_data(self, **kwargs): + context = super().get_context_data(**kwargs) + context["numbers"] = { + "outbound_follows": self.request.identity.outbound_follows.active().count(), + "inbound_follows": self.request.identity.inbound_follows.active().count(), + "blocks": self.request.identity.outbound_blocks.active() + .filter(mute=False) + .count(), + "mutes": self.request.identity.outbound_blocks.active() + .filter(mute=True) + .count(), + } + context["bad_format"] = self.request.GET.get("bad_format") + context["success"] = self.request.GET.get("success") + return context + + +class CsvView(View): + """ + Generic view that exports a queryset as a CSV + """ + + # Mapping of CSV column title to method or model attribute name + # We rely on the fact that python dicts are stably ordered! + columns: dict[str, str] + + # Filename to download as + filename: str = "export.csv" + + def get_queryset(self): + raise NotImplementedError() + + def get(self, request): + response = HttpResponse( + content_type="text/csv", + headers={"Content-Disposition": f'attachment; filename="{self.filename}"'}, + ) + writer = csv.writer(response) + writer.writerow(self.columns.keys()) + for item in self.get_queryset(request): + row = [] + for attrname in self.columns.values(): + # Get value + getter = getattr(self, attrname, None) + if getter: + value = getter(item) + elif hasattr(item, attrname): + value = getattr(item, attrname) + else: + raise ValueError(f"Cannot export attribute {attrname}") + # Make it into CSV format + if type(value) == bool: + value = "true" if value else "false" + elif type(value) == int: + value = str(value) + row.append(value) + writer.writerow(row) + return response + + +class CsvFollowing(CsvView): + + columns = { + "Account address": "get_handle", + "Show boosts": "boosts", + "Notify on new posts": "get_notify", + "Languages": "get_languages", + } + + filename = "following.csv" + + def get_queryset(self, request): + return self.request.identity.outbound_follows.active() + + def get_handle(self, follow: Follow): + return follow.target.handle + + def get_notify(self, follow: Follow): + return False + + def get_languages(self, follow: Follow): + return "" + + +class CsvFollowers(CsvView): + + columns = { + "Account address": "get_handle", + } + + filename = "followers.csv" + + def get_queryset(self, request): + return self.request.identity.inbound_follows.active() + + def get_handle(self, follow: Follow): + return follow.target.handle