Cache radio queryset. New api endpoint for radio tracks : api/v2/radios/sessions/$sessionid/tracks?count=$count

environments/review-docs-colle-cwr3tw/deployments/18479
petitminion 2023-09-25 22:28:11 +00:00
rodzic 04acd056e6
commit 4ad806b8e9
13 zmienionych plików z 845 dodań i 77 usunięć

Wyświetl plik

@ -10,6 +10,10 @@ v2_patterns += [
r"^instance/",
include(("funkwhale_api.instance.urls", "instance"), namespace="instance"),
),
url(
r"^radios/",
include(("funkwhale_api.radios.urls_v2", "radios"), namespace="radios"),
),
]
urlpatterns = [url("", include((v2_patterns, "v2"), namespace="v2"))]

Wyświetl plik

@ -54,10 +54,6 @@ class RadioSession(models.Model):
CONFIG_VERSION = 0
config = JSONField(encoder=DjangoJSONEncoder, blank=True, null=True)
def save(self, **kwargs):
self.radio.clean(self)
super().save(**kwargs)
@property
def next_position(self):
next_position = 1
@ -68,16 +64,24 @@ class RadioSession(models.Model):
return next_position
def add(self, track):
new_session_track = RadioSessionTrack.objects.create(
track=track, session=self, position=self.next_position
)
def add(self, tracks):
next_position = self.next_position
radio_session_tracks = []
for i, track in enumerate(tracks):
radio_session_track = RadioSessionTrack(
track=track, session=self, position=next_position + i
)
radio_session_tracks.append(radio_session_track)
return new_session_track
new_session_tracks = RadioSessionTrack.objects.bulk_create(radio_session_tracks)
@property
def radio(self):
from .registries import registry
return new_session_tracks
def radio(self, api_version):
if api_version == 2:
from .registries_v2 import registry
else:
from .registries import registry
return registry[self.radio_type](session=self)

Wyświetl plik

@ -13,10 +13,9 @@ from funkwhale_api.federation import fields as federation_fields
from funkwhale_api.federation import models as federation_models
from funkwhale_api.moderation import filters as moderation_filters
from funkwhale_api.music.models import Artist, Library, Track, Upload
from funkwhale_api.radios import lb_recommendations
from funkwhale_api.tags.models import Tag
from . import filters, models
from . import filters, lb_recommendations, models
from .registries import registry
logger = logging.getLogger(__name__)
@ -63,11 +62,19 @@ class SessionRadio(SimpleRadio):
return self.session
def get_queryset(self, **kwargs):
qs = Track.objects.all()
if not self.session:
return qs
if not self.session.user:
return qs
if not self.session or not self.session.user:
return (
Track.objects.all()
.with_playable_uploads(actor=None)
.select_related("artist", "album__artist", "attributed_to")
)
else:
qs = (
Track.objects.all()
.with_playable_uploads(self.session.user.actor)
.select_related("artist", "album__artist", "attributed_to")
)
query = moderation_filters.get_filtered_content_query(
config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
user=self.session.user,
@ -77,6 +84,16 @@ class SessionRadio(SimpleRadio):
def get_queryset_kwargs(self):
return {}
def filter_queryset(self, queryset):
return queryset
def filter_from_session(self, queryset):
already_played = self.session.session_tracks.all().values_list(
"track", flat=True
)
queryset = queryset.exclude(pk__in=already_played)
return queryset
def get_choices(self, **kwargs):
kwargs.update(self.get_queryset_kwargs())
queryset = self.get_queryset(**kwargs)
@ -89,16 +106,6 @@ class SessionRadio(SimpleRadio):
queryset = self.filter_queryset(queryset)
return queryset
def filter_queryset(self, queryset):
return queryset
def filter_from_session(self, queryset):
already_played = self.session.session_tracks.all().values_list(
"track", flat=True
)
queryset = queryset.exclude(pk__in=already_played)
return queryset
def pick(self, **kwargs):
return self.pick_many(quantity=1, **kwargs)[0]
@ -106,8 +113,7 @@ class SessionRadio(SimpleRadio):
choices = self.get_choices(**kwargs)
picked_choices = super().pick_many(choices=choices, quantity=quantity)
if self.session:
for choice in picked_choices:
self.session.add(choice)
self.session.add(picked_choices)
return picked_choices
def validate_session(self, data, **context):
@ -191,7 +197,9 @@ class CustomMultiple(SessionRadio):
def validate_session(self, data, **context):
data = super().validate_session(data, **context)
if data.get("config") is None:
try:
data["config"] is not None
except KeyError:
raise serializers.ValidationError(
"You must provide a configuration for this radio"
)

Wyświetl plik

@ -0,0 +1,510 @@
import datetime
import json
import logging
import pickle
import random
from typing import List, Optional, Tuple
from django.core.cache import cache
from django.core.exceptions import ValidationError
from django.db import connection
from django.db.models import Q
from rest_framework import serializers
from funkwhale_api.federation import fields as federation_fields
from funkwhale_api.federation import models as federation_models
from funkwhale_api.moderation import filters as moderation_filters
from funkwhale_api.music.models import Artist, Library, Track, Upload
from funkwhale_api.tags.models import Tag
from . import filters, lb_recommendations, models
from .registries_v2 import registry
logger = logging.getLogger(__name__)
class SimpleRadio:
related_object_field = None
def clean(self, instance):
return
def weighted_pick(
self,
choices: List[Tuple[int, int]],
previous_choices: Optional[List[int]] = None,
) -> int:
total = sum(weight for c, weight in choices)
r = random.uniform(0, total)
upto = 0
for choice, weight in choices:
if upto + weight >= r:
return choice
upto += weight
class SessionRadio(SimpleRadio):
def __init__(self, session=None):
self.session = session
def start_session(self, user, **kwargs):
self.session = models.RadioSession.objects.create(
user=user, radio_type=self.radio_type, **kwargs
)
return self.session
def get_queryset(self, **kwargs):
actor = None
try:
actor = self.session.user.actor
except KeyError:
pass # Maybe logging would be helpful
qs = (
Track.objects.all()
.with_playable_uploads(actor=actor)
.select_related("artist", "album__artist", "attributed_to")
)
query = moderation_filters.get_filtered_content_query(
config=moderation_filters.USER_FILTER_CONFIG["TRACK"],
user=self.session.user,
)
return qs.exclude(query)
def get_queryset_kwargs(self):
return {}
def filter_queryset(self, queryset):
return queryset
def filter_from_session(self, queryset):
already_played = self.session.session_tracks.all().values_list(
"track", flat=True
)
queryset = queryset.exclude(pk__in=already_played)
return queryset
def cache_batch_radio_track(self, **kwargs):
BATCH_SIZE = 100
# get cached RadioTracks if any
try:
cached_evaluated_radio_tracks = pickle.loads(
cache.get(f"radiotracks{self.session.id}")
)
except TypeError:
cached_evaluated_radio_tracks = None
# get the queryset and apply filters
kwargs.update(self.get_queryset_kwargs())
queryset = self.get_queryset(**kwargs)
queryset = self.filter_from_session(queryset)
if kwargs["filter_playable"] is True:
queryset = queryset.playable_by(
self.session.user.actor if self.session.user else None
)
queryset = self.filter_queryset(queryset)
# select a random batch of the qs
sliced_queryset = queryset.order_by("?")[:BATCH_SIZE]
if len(sliced_queryset) <= 0 and not cached_evaluated_radio_tracks:
raise ValueError("No more radio candidates")
# create the radio session tracks into db in bulk
self.session.add(sliced_queryset)
# evaluate the queryset to save it in cache
radio_tracks = list(sliced_queryset)
if cached_evaluated_radio_tracks is not None:
radio_tracks.extend(cached_evaluated_radio_tracks)
logger.info(
f"Setting redis cache for radio generation with radio id {self.session.id}"
)
cache.set(f"radiotracks{self.session.id}", pickle.dumps(radio_tracks), 3600)
cache.set(f"radioqueryset{self.session.id}", sliced_queryset, 3600)
return sliced_queryset
def get_choices(self, quantity, **kwargs):
if cache.get(f"radiotracks{self.session.id}"):
cached_radio_tracks = pickle.loads(
cache.get(f"radiotracks{self.session.id}")
)
logger.info("Using redis cache for radio generation")
radio_tracks = cached_radio_tracks
if len(radio_tracks) < quantity:
logger.info(
"Not enough radio tracks in cache. Trying to generate new cache"
)
sliced_queryset = self.cache_batch_radio_track(**kwargs)
sliced_queryset = cache.get(f"radioqueryset{self.session.id}")
else:
sliced_queryset = self.cache_batch_radio_track(**kwargs)
return sliced_queryset[:quantity]
def pick_many(self, quantity, **kwargs):
if self.session:
sliced_queryset = self.get_choices(quantity=quantity, **kwargs)
else:
logger.info(
"No radio session. Can't track user playback. Won't cache queryset results"
)
sliced_queryset = self.get_choices(quantity=quantity, **kwargs)
return sliced_queryset
def validate_session(self, data, **context):
return data
@registry.register(name="random")
class RandomRadio(SessionRadio):
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
return qs.filter(artist__content_category="music").order_by("?")
@registry.register(name="random_library")
class RandomLibraryRadio(SessionRadio):
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
tracks_ids = self.session.user.actor.attributed_tracks.all().values_list(
"id", flat=True
)
query = Q(artist__content_category="music") & Q(pk__in=tracks_ids)
return qs.filter(query).order_by("?")
@registry.register(name="favorites")
class FavoritesRadio(SessionRadio):
def get_queryset_kwargs(self):
kwargs = super().get_queryset_kwargs()
if self.session:
kwargs["user"] = self.session.user
return kwargs
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
track_ids = kwargs["user"].track_favorites.all().values_list("track", flat=True)
return qs.filter(pk__in=track_ids, artist__content_category="music")
@registry.register(name="custom")
class CustomRadio(SessionRadio):
def get_queryset_kwargs(self):
kwargs = super().get_queryset_kwargs()
kwargs["user"] = self.session.user
kwargs["custom_radio"] = self.session.custom_radio
return kwargs
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
return filters.run(kwargs["custom_radio"].config, candidates=qs)
def validate_session(self, data, **context):
data = super().validate_session(data, **context)
try:
user = data["user"]
except KeyError:
user = context.get("user")
try:
assert data["custom_radio"].user == user or data["custom_radio"].is_public
except KeyError:
raise serializers.ValidationError("You must provide a custom radio")
except AssertionError:
raise serializers.ValidationError("You don't have access to this radio")
return data
@registry.register(name="custom_multiple")
class CustomMultiple(SessionRadio):
"""
Receive a vuejs generated config and use it to launch a radio session
"""
config = serializers.JSONField(required=True)
def get_config(self, data):
return data["config"]
def get_queryset_kwargs(self):
kwargs = super().get_queryset_kwargs()
kwargs["config"] = self.session.config
return kwargs
def validate_session(self, data, **context):
data = super().validate_session(data, **context)
try:
data["config"] is not None
except KeyError:
raise serializers.ValidationError(
"You must provide a configuration for this radio"
)
return data
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
return filters.run([kwargs["config"]], candidates=qs)
class RelatedObjectRadio(SessionRadio):
"""Abstract radio related to an object (tag, artist, user...)"""
related_object_field = serializers.IntegerField(required=True)
def clean(self, instance):
super().clean(instance)
if not instance.related_object:
raise ValidationError(
"Cannot start RelatedObjectRadio without related object"
)
if not isinstance(instance.related_object, self.model):
raise ValidationError("Trying to start radio with bad related object")
def get_related_object(self, pk):
return self.model.objects.get(pk=pk)
@registry.register(name="tag")
class TagRadio(RelatedObjectRadio):
model = Tag
related_object_field = serializers.CharField(required=True)
def get_related_object(self, name):
return self.model.objects.get(name=name)
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
query = (
Q(tagged_items__tag=self.session.related_object)
| Q(artist__tagged_items__tag=self.session.related_object)
| Q(album__tagged_items__tag=self.session.related_object)
)
return qs.filter(query)
def get_related_object_id_repr(self, obj):
return obj.name
def weighted_choice(choices):
total = sum(w for c, w in choices)
r = random.uniform(0, total)
upto = 0
for c, w in choices:
if upto + w >= r:
return c
upto += w
assert False, "Shouldn't get here"
class NextNotFound(Exception):
pass
@registry.register(name="similar")
class SimilarRadio(RelatedObjectRadio):
model = Track
def filter_queryset(self, queryset):
queryset = super().filter_queryset(queryset)
seeds = list(
self.session.session_tracks.all()
.values_list("track_id", flat=True)
.order_by("-id")[:3]
) + [self.session.related_object.pk]
for seed in seeds:
try:
return queryset.filter(pk=self.find_next_id(queryset, seed))
except NextNotFound:
continue
return queryset.none()
def find_next_id(self, queryset, seed):
with connection.cursor() as cursor:
query = """
SELECT next, count(next) AS c
FROM (
SELECT
track_id,
creation_date,
LEAD(track_id) OVER (
PARTITION by user_id order by creation_date asc
) AS next
FROM history_listening
INNER JOIN users_user ON (users_user.id = user_id)
WHERE users_user.privacy_level = 'instance' OR users_user.privacy_level = 'everyone' OR user_id = %s
ORDER BY creation_date ASC
) t WHERE track_id = %s AND next != %s GROUP BY next ORDER BY c DESC;
"""
cursor.execute(query, [self.session.user_id, seed, seed])
next_candidates = list(cursor.fetchall())
if not next_candidates:
raise NextNotFound()
matching_tracks = list(
queryset.filter(pk__in=[c[0] for c in next_candidates]).values_list(
"id", flat=True
)
)
next_candidates = [n for n in next_candidates if n[0] in matching_tracks]
if not next_candidates:
raise NextNotFound()
return random.choice([c[0] for c in next_candidates])
@registry.register(name="artist")
class ArtistRadio(RelatedObjectRadio):
model = Artist
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
return qs.filter(artist=self.session.related_object)
@registry.register(name="less-listened")
class LessListenedRadio(SessionRadio):
def clean(self, instance):
instance.related_object = instance.user
super().clean(instance)
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
listened = self.session.user.listenings.all().values_list("track", flat=True)
return (
qs.filter(artist__content_category="music")
.exclude(pk__in=listened)
.order_by("?")
)
@registry.register(name="less-listened_library")
class LessListenedLibraryRadio(SessionRadio):
def clean(self, instance):
instance.related_object = instance.user
super().clean(instance)
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
listened = self.session.user.listenings.all().values_list("track", flat=True)
tracks_ids = self.session.user.actor.attributed_tracks.all().values_list(
"id", flat=True
)
query = Q(artist__content_category="music") & Q(pk__in=tracks_ids)
return qs.filter(query).exclude(pk__in=listened).order_by("?")
@registry.register(name="actor-content")
class ActorContentRadio(RelatedObjectRadio):
"""
Play content from given actor libraries
"""
model = federation_models.Actor
related_object_field = federation_fields.ActorRelatedField(required=True)
def get_related_object(self, value):
return value
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
actor_uploads = Upload.objects.filter(
library__actor=self.session.related_object,
)
return qs.filter(pk__in=actor_uploads.values("track"))
def get_related_object_id_repr(self, obj):
return obj.full_username
@registry.register(name="library")
class LibraryRadio(RelatedObjectRadio):
"""
Play content from a given library
"""
model = Library
related_object_field = serializers.UUIDField(required=True)
def get_related_object(self, value):
return Library.objects.get(uuid=value)
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
actor_uploads = Upload.objects.filter(
library=self.session.related_object,
)
return qs.filter(pk__in=actor_uploads.values("track"))
def get_related_object_id_repr(self, obj):
return obj.uuid
@registry.register(name="recently-added")
class RecentlyAdded(SessionRadio):
def get_queryset(self, **kwargs):
date = datetime.date.today() - datetime.timedelta(days=30)
qs = super().get_queryset(**kwargs)
return qs.filter(
Q(artist__content_category="music"),
Q(creation_date__gt=date),
)
# Use this to experiment on the custom multiple radio with troi
@registry.register(name="troi")
class Troi(SessionRadio):
"""
Receive a vuejs generated config and use it to launch a troi radio session.
The config data should follow :
{"patch": "troi_patch_name", "troi_arg1":"troi_arg_1", "troi_arg2": ...}
Validation of the config (args) is done by troi during track fetch.
Funkwhale only checks if the patch is implemented
"""
config = serializers.JSONField(required=True)
def append_lb_config(self, data):
if self.session.user.settings is None:
logger.warning(
"No lb_user_name set in user settings. Some troi patches will fail"
)
return data
elif self.session.user.settings.get("lb_user_name") is None:
logger.warning(
"No lb_user_name set in user settings. Some troi patches will fail"
)
else:
data["user_name"] = self.session.user.settings["lb_user_name"]
if self.session.user.settings.get("lb_user_token") is None:
logger.warning(
"No lb_user_token set in user settings. Some troi patch will fail"
)
else:
data["user_token"] = self.session.user.settings["lb_user_token"]
return data
def get_queryset_kwargs(self):
kwargs = super().get_queryset_kwargs()
kwargs["config"] = self.session.config
return kwargs
def validate_session(self, data, **context):
data = super().validate_session(data, **context)
if data.get("config") is None:
raise serializers.ValidationError(
"You must provide a configuration for this radio"
)
return data
def get_queryset(self, **kwargs):
qs = super().get_queryset(**kwargs)
config = self.append_lb_config(json.loads(kwargs["config"]))
return lb_recommendations.run(config, candidates=qs)

Wyświetl plik

@ -0,0 +1,10 @@
import persisting_theory
class RadioRegistry_v2(persisting_theory.Registry):
def prepare_name(self, data, name=None):
setattr(data, "radio_type", name)
return name
registry = RadioRegistry_v2()

Wyświetl plik

@ -40,9 +40,11 @@ class RadioSerializer(serializers.ModelSerializer):
class RadioSessionTrackSerializerCreate(serializers.ModelSerializer):
count = serializers.IntegerField(required=False, allow_null=True)
class Meta:
model = models.RadioSessionTrack
fields = ("session",)
fields = ("session", "count")
class RadioSessionTrackSerializer(serializers.ModelSerializer):

Wyświetl plik

@ -5,7 +5,7 @@ from . import views
router = routers.OptionalSlashRouter()
router.register(r"sessions", views.RadioSessionViewSet, "sessions")
router.register(r"radios", views.RadioViewSet, "radios")
router.register(r"tracks", views.RadioSessionTrackViewSet, "tracks")
router.register(r"tracks", views.V1_RadioSessionTrackViewSet, "tracks")
urlpatterns = router.urls

Wyświetl plik

@ -0,0 +1,10 @@
from funkwhale_api.common import routers
from . import views
router = routers.OptionalSlashRouter()
router.register(r"sessions", views.V2_RadioSessionViewSet, "sessions")
urlpatterns = router.urls

Wyświetl plik

@ -1,3 +1,6 @@
import pickle
from django.core.cache import cache
from django.db.models import Q
from drf_spectacular.utils import extend_schema
from rest_framework import mixins, status, viewsets
@ -121,7 +124,7 @@ class RadioSessionViewSet(
return context
class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
class V1_RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
serializer_class = serializers.RadioSessionTrackSerializer
queryset = models.RadioSessionTrack.objects.all()
permission_classes = []
@ -133,21 +136,19 @@ class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet)
session = serializer.validated_data["session"]
if not request.user.is_authenticated and not request.session.session_key:
self.request.session.create()
try:
assert (request.user == session.user) or (
request.session.session_key == session.session_key
and session.session_key
)
except AssertionError:
if not request.user == session.user or (
not request.session.session_key == session.session_key
and not session.session_key
):
return Response(status=status.HTTP_403_FORBIDDEN)
try:
session.radio.pick()
session.radio(api_version=1).pick()
except ValueError:
return Response(
"Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND
)
session_track = session.session_tracks.all().latest("id")
# self.perform_create(serializer)
# dirty override here, since we use a different serializer for creation and detail
serializer = self.serializer_class(
instance=session_track, context=self.get_serializer_context()
@ -161,3 +162,99 @@ class RadioSessionTrackViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet)
if self.action == "create":
return serializers.RadioSessionTrackSerializerCreate
return super().get_serializer_class(*args, **kwargs)
class V2_RadioSessionViewSet(
mixins.CreateModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet
):
"""Returns a list of RadioSessions"""
serializer_class = serializers.RadioSessionSerializer
queryset = models.RadioSession.objects.all()
permission_classes = []
@action(detail=True, serializer_class=serializers.RadioSessionTrackSerializerCreate)
def tracks(self, request, pk, *args, **kwargs):
data = {"session": pk}
data["count"] = (
request.query_params["count"]
if "count" in request.query_params.keys()
else 1
)
serializer = serializers.RadioSessionTrackSerializerCreate(data=data)
serializer.is_valid(raise_exception=True)
session = serializer.validated_data["session"]
count = int(data["count"])
# this is used for test purpose.
filter_playable = (
request.query_params["filter_playable"]
if "filter_playable" in request.query_params.keys()
else True
)
if not request.user.is_authenticated and not request.session.session_key:
self.request.session.create()
if not request.user == session.user or (
not request.session.session_key == session.session_key
and not session.session_key
):
return Response(status=status.HTTP_403_FORBIDDEN)
try:
from . import radios_v2 # noqa
session.radio(api_version=2).pick_many(
count, filter_playable=filter_playable
)
except ValueError:
return Response(
"Radio doesn't have more candidates", status=status.HTTP_404_NOT_FOUND
)
# dirty override here, since we use a different serializer for creation and detail
evaluated_radio_tracks = pickle.loads(cache.get(f"radiotracks{session.id}"))
batch = evaluated_radio_tracks[:count]
serializer = TrackSerializer(
data=batch,
many="true",
)
serializer.is_valid()
# delete the tracks we sent from the cache
new_cached_radiotracks = evaluated_radio_tracks[count:]
cache.set(f"radiotracks{session.id}", pickle.dumps(new_cached_radiotracks))
return Response(
serializer.data,
status=status.HTTP_201_CREATED,
)
def get_queryset(self):
queryset = super().get_queryset()
if self.request.user.is_authenticated:
return queryset.filter(
Q(user=self.request.user)
| Q(session_key=self.request.session.session_key)
)
return queryset.filter(session_key=self.request.session.session_key).exclude(
session_key=None
)
def perform_create(self, serializer):
if (
not self.request.user.is_authenticated
and not self.request.session.session_key
):
self.request.session.create()
return serializer.save(
user=self.request.user if self.request.user.is_authenticated else None,
session_key=self.request.session.session_key,
)
def get_serializer_context(self):
context = super().get_serializer_context()
context["user"] = (
self.request.user if self.request.user.is_authenticated else None
)
return context

Wyświetl plik

@ -2,8 +2,8 @@ import json
import random
import pytest
from django.core.exceptions import ValidationError
from django.urls import reverse
from rest_framework.exceptions import ValidationError
from funkwhale_api.favorites.models import TrackFavorite
from funkwhale_api.radios import models, radios, serializers
@ -98,7 +98,7 @@ def test_can_get_choices_for_custom_radio(factories):
session = factories["radios.CustomRadioSession"](
custom_radio__config=[{"type": "artist", "ids": [artist.pk]}]
)
choices = session.radio.get_choices(filter_playable=False)
choices = session.radio(api_version=1).get_choices(filter_playable=False)
expected = [t.pk for t in tracks]
assert list(choices.values_list("id", flat=True)) == expected
@ -191,16 +191,17 @@ def test_can_get_track_for_session_from_api(factories, logged_in_api_client):
def test_related_object_radio_validate_related_object(factories):
user = factories["users.User"]()
# cannot start without related object
radio = radios.ArtistRadio()
radio = {"radio_type": "tag"}
serializer = serializers.RadioSessionSerializer()
with pytest.raises(ValidationError):
radio.start_session(user)
serializer.validate(data=radio)
# cannot start with bad related object type
radio = radios.ArtistRadio()
radio = {"radio_type": "tag", "related_object": "whatever"}
serializer = serializers.RadioSessionSerializer()
with pytest.raises(ValidationError):
radio.start_session(user, related_object=user)
serializer.validate(data=radio)
def test_can_start_artist_radio(factories):
@ -391,7 +392,7 @@ def test_get_choices_for_custom_radio_exclude_artist(factories):
{"type": "artist", "ids": [excluded_artist.pk], "not": True},
]
)
choices = session.radio.get_choices(filter_playable=False)
choices = session.radio(api_version=1).get_choices(filter_playable=False)
expected = [u.track.pk for u in included_uploads]
assert list(choices.values_list("id", flat=True)) == expected
@ -409,7 +410,7 @@ def test_get_choices_for_custom_radio_exclude_tag(factories):
{"type": "tag", "names": ["rock"], "not": True},
]
)
choices = session.radio.get_choices(filter_playable=False)
choices = session.radio(api_version=1).get_choices(filter_playable=False)
expected = [u.track.pk for u in included_uploads]
assert list(choices.values_list("id", flat=True)) == expected
@ -429,28 +430,3 @@ def test_can_start_custom_multiple_radio_from_api(api_client, factories):
format="json",
)
assert response.status_code == 201
def test_can_start_periodic_jams_troi_radio_from_api(api_client, factories):
factories["music.Track"].create_batch(5)
url = reverse("api:v1:radios:sessions-list")
config = {"patch": "periodic-jams", "type": "daily-jams"}
response = api_client.post(
url,
{"radio_type": "troi", "config": config},
format="json",
)
assert response.status_code == 201
# to do : send error to api ?
def test_can_catch_troi_radio_error(api_client, factories):
factories["music.Track"].create_batch(5)
url = reverse("api:v1:radios:sessions-list")
config = {"patch": "periodic-jams", "type": "not_existing_type"}
response = api_client.post(
url,
{"radio_type": "troi", "config": config},
format="json",
)
assert response.status_code == 201

Wyświetl plik

@ -0,0 +1,144 @@
import json
import logging
import pickle
import random
from django.core.cache import cache
from django.urls import reverse
from funkwhale_api.favorites.models import TrackFavorite
from funkwhale_api.radios import models, radios_v2
def test_can_get_track_for_session_from_api_v2(factories, logged_in_api_client):
actor = logged_in_api_client.user.create_actor()
track = factories["music.Upload"](
library__actor=actor, import_status="finished"
).track
url = reverse("api:v2:radios:sessions-list")
response = logged_in_api_client.post(url, {"radio_type": "random"})
session = models.RadioSession.objects.latest("id")
url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk})
response = logged_in_api_client.get(url, {"session": session.pk})
data = json.loads(response.content.decode("utf-8"))
assert data[0]["id"] == track.pk
next_track = factories["music.Upload"](
library__actor=actor, import_status="finished"
).track
response = logged_in_api_client.get(url, {"session": session.pk})
data = json.loads(response.content.decode("utf-8"))
assert data[0]["id"] == next_track.id
def test_can_use_radio_session_to_filter_choices_v2(factories):
factories["music.Upload"].create_batch(10)
user = factories["users.User"]()
radio = radios_v2.RandomRadio()
session = radio.start_session(user)
radio.pick_many(quantity=10, filter_playable=False)
# ensure 10 different tracks have been suggested
tracks_id = [
session_track.track.pk for session_track in session.session_tracks.all()
]
assert len(set(tracks_id)) == 10
def test_session_radio_excludes_previous_picks_v2(factories, logged_in_api_client):
tracks = factories["music.Track"].create_batch(5)
url = reverse("api:v2:radios:sessions-list")
response = logged_in_api_client.post(url, {"radio_type": "random"})
session = models.RadioSession.objects.latest("id")
url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk})
previous_choices = []
for i in range(5):
response = logged_in_api_client.get(
url, {"session": session.pk, "filter_playable": False}
)
pick = json.loads(response.content.decode("utf-8"))
assert pick[0]["title"] not in previous_choices
assert pick[0]["title"] in [t.title for t in tracks]
previous_choices.append(pick[0]["title"])
response = logged_in_api_client.get(url, {"session": session.pk})
assert (
json.loads(response.content.decode("utf-8"))
== "Radio doesn't have more candidates"
)
def test_can_get_choices_for_favorites_radio_v2(factories):
files = factories["music.Upload"].create_batch(10)
tracks = [f.track for f in files]
user = factories["users.User"]()
for i in range(5):
TrackFavorite.add(track=random.choice(tracks), user=user)
radio = radios_v2.FavoritesRadio()
session = radio.start_session(user=user)
choices = session.radio(api_version=2).get_choices(
quantity=100, filter_playable=False
)
assert len(choices) == user.track_favorites.all().count()
for favorite in user.track_favorites.all():
assert favorite.track in choices
def test_can_get_choices_for_custom_radio_v2(factories):
artist = factories["music.Artist"]()
files = factories["music.Upload"].create_batch(5, track__artist=artist)
tracks = [f.track for f in files]
factories["music.Upload"].create_batch(5)
session = factories["radios.CustomRadioSession"](
custom_radio__config=[{"type": "artist", "ids": [artist.pk]}]
)
choices = session.radio(api_version=2).get_choices(
quantity=1, filter_playable=False
)
expected = [t.pk for t in tracks]
for t in choices:
assert t.id in expected
def test_can_cache_radio_track(factories):
uploads = factories["music.Track"].create_batch(10)
user = factories["users.User"]()
radio = radios_v2.RandomRadio()
session = radio.start_session(user)
picked = session.radio(api_version=2).pick_many(quantity=1, filter_playable=False)
assert len(picked) == 1
for t in pickle.loads(cache.get(f"radiotracks{session.id}")):
assert t in uploads
def test_regenerate_cache_if_not_enought_tracks_in_it(
factories, caplog, logged_in_api_client
):
logger = logging.getLogger("funkwhale_api.radios.radios_v2")
caplog.set_level(logging.INFO)
logger.addHandler(caplog.handler)
factories["music.Track"].create_batch(10)
factories["users.User"]()
url = reverse("api:v2:radios:sessions-list")
response = logged_in_api_client.post(url, {"radio_type": "random"})
session = models.RadioSession.objects.latest("id")
url = reverse("api:v2:radios:sessions-tracks", kwargs={"pk": session.pk})
logged_in_api_client.get(url, {"count": 9, "filter_playable": False})
response = logged_in_api_client.get(url, {"count": 10, "filter_playable": False})
pick = json.loads(response.content.decode("utf-8"))
assert (
"Not enough radio tracks in cache. Trying to generate new cache" in caplog.text
)
assert len(pick) == 1

Wyświetl plik

@ -0,0 +1 @@
Cache radio queryset into redis. New radio track endpoint for api v2 is /api/v2/radios/sessions/{radiosessionid}/tracks (#2135)

Wyświetl plik

@ -98,6 +98,8 @@ services:
env_file:
- .env
image: typesense/typesense:0.24.0
networks:
- internal
volumes:
- ./typesense/data:/data
command: --data-dir /data --enable-cors