2017-06-23 21:00:42 +00:00
|
|
|
import random
|
2018-06-10 08:55:16 +00:00
|
|
|
|
2017-06-23 21:00:42 +00:00
|
|
|
from django.core.exceptions import ValidationError
|
2018-06-10 08:55:16 +00:00
|
|
|
from django.db.models import Count
|
|
|
|
from rest_framework import serializers
|
2017-06-23 21:00:42 +00:00
|
|
|
from taggit.models import Tag
|
2018-06-10 08:55:16 +00:00
|
|
|
|
|
|
|
from funkwhale_api.music.models import Artist, Track
|
2017-06-23 21:00:42 +00:00
|
|
|
from funkwhale_api.users.models import User
|
2018-01-07 21:13:32 +00:00
|
|
|
|
2018-06-10 08:55:16 +00:00
|
|
|
from . import filters, models
|
2017-06-23 21:00:42 +00:00
|
|
|
from .registries import registry
|
|
|
|
|
2018-01-07 21:13:32 +00:00
|
|
|
|
2017-06-23 21:00:42 +00:00
|
|
|
class SimpleRadio(object):
    """Base radio: stateless helpers that draw random items from a pool of choices."""

    def clean(self, instance):
        """Validate ``instance`` before starting a radio; the base radio accepts anything."""
        return

    def pick(self, choices, previous_choices=None):
        """Return one random element of ``choices`` that is not in ``previous_choices``.

        :param choices: iterable of hashable candidates (duplicates are collapsed).
        :param previous_choices: iterable of items to exclude; ``None`` excludes nothing.
        :raises ValueError: if no candidate remains after the exclusion.
        """
        excluded = () if previous_choices is None else previous_choices
        remaining = set(choices).difference(excluded)
        # random.sample() needs a sequence: sets were deprecated as a population
        # in Python 3.9 and raise TypeError since 3.11, hence the list().
        return random.sample(list(remaining), 1)[0]

    def pick_many(self, choices, quantity):
        """Return ``quantity`` distinct random elements of ``choices``.

        :param choices: iterable of hashable candidates (duplicates are collapsed).
        :param quantity: number of elements to draw.
        :raises ValueError: if ``quantity`` is negative or exceeds the number of
            distinct choices.
        """
        # See pick(): random.sample() rejects sets on Python >= 3.11.
        return random.sample(list(set(choices)), quantity)

    def weighted_pick(self, choices, previous_choices=None):
        """Return one choice drawn with probability proportional to its weight.

        :param choices: re-iterable sequence of ``(choice, weight)`` pairs; it is
            iterated twice, so a one-shot iterator will not work.
        :param previous_choices: accepted for signature compatibility; ignored.
        :returns: the selected choice, or ``None`` when ``choices`` is empty.
        """
        total = sum(weight for _, weight in choices)
        threshold = random.uniform(0, total)
        upto = 0
        fallback = None
        for choice, weight in choices:
            if upto + weight >= threshold:
                return choice
            upto += weight
            if weight > 0:
                fallback = choice
        # Float rounding can leave the running ``upto`` slightly below ``total``
        # (sum() uses compensated summation on Python >= 3.12), so the threshold
        # may never be reached; fall back to the last positively weighted choice.
        return fallback
|
|
|
|
|
|
|
|
|
|
|
|
class SessionRadio(SimpleRadio):
    """Radio that draws tracks from a queryset and can record picks on a session.

    ``start_session`` reads ``self.radio_type``, which this class never sets.
    NOTE(review): presumably assigned by ``registry.register``; confirm.
    """

    def __init__(self, session=None):
        self.session = session

    def start_session(self, user, **kwargs):
        """Create a RadioSession for ``user``, keep it on ``self`` and return it."""
        new_session = models.RadioSession.objects.create(
            user=user, radio_type=self.radio_type, **kwargs
        )
        self.session = new_session
        return new_session

    def get_queryset(self, **kwargs):
        """Return every track that has at least one upload; ``kwargs`` is unused here."""
        with_counts = Track.objects.annotate(uploads_count=Count("uploads"))
        return with_counts.filter(uploads_count__gt=0)

    def get_queryset_kwargs(self):
        """Return extra keyword arguments merged into ``get_queryset``; none by default."""
        return {}

    def get_choices(self, **kwargs):
        """Build the candidate queryset, minus tracks already played in the session.

        Values from ``get_queryset_kwargs`` override same-named caller kwargs.
        """
        kwargs.update(self.get_queryset_kwargs())
        candidates = self.get_queryset(**kwargs)
        if not self.session:
            return candidates
        return self.filter_from_session(candidates)

    def filter_from_session(self, queryset):
        """Exclude from ``queryset`` the tracks listed in the session's session_tracks."""
        played = self.session.session_tracks.all().values_list("track", flat=True)
        return queryset.exclude(pk__in=played)

    def pick(self, **kwargs):
        """Pick a single track; ``kwargs`` are forwarded to ``get_choices``."""
        picked = self.pick_many(quantity=1, **kwargs)
        return picked[0]

    def pick_many(self, quantity, **kwargs):
        """Pick ``quantity`` random tracks and add each of them to the session, if any."""
        candidates = self.get_choices(**kwargs)
        picked = super().pick_many(choices=candidates, quantity=quantity)
        if self.session:
            for track in picked:
                self.session.add(track)
        return picked

    def validate_session(self, data, **context):
        """Hook for subclasses to validate session data; returns ``data`` unchanged."""
        return data
|
|
|
|
|
|
|
|
|
2018-06-09 13:36:16 +00:00
|
|
|
@registry.register(name="random")
class RandomRadio(SessionRadio):
    """Radio drawing from all playable tracks, in random database order."""

    def get_queryset(self, **kwargs):
        """Return the parent queryset shuffled with ``order_by("?")``."""
        playable = super().get_queryset(**kwargs)
        shuffled = playable.order_by("?")
        return shuffled
|
2018-02-27 17:35:54 +00:00
|
|
|
|
2017-06-23 21:00:42 +00:00
|
|
|
|
2018-06-09 13:36:16 +00:00
|
|
|
@registry.register(name="favorites")
class FavoritesRadio(SessionRadio):
    """Radio restricted to tracks present in a user's ``track_favorites``."""

    def get_queryset_kwargs(self):
        """Add the session owner as ``user`` when a session is attached."""
        kwargs = super().get_queryset_kwargs()
        if not self.session:
            return kwargs
        kwargs["user"] = self.session.user
        return kwargs

    def get_queryset(self, **kwargs):
        """Return playable tracks favorited by ``kwargs["user"]``.

        Raises KeyError when no ``user`` is supplied, e.g. without a session
        and without the caller passing one.
        """
        playable = super().get_queryset(**kwargs)
        favorites = kwargs["user"].track_favorites.all()
        favorite_track_ids = favorites.values_list("track", flat=True)
        return playable.filter(pk__in=favorite_track_ids)
|
2017-06-23 21:00:42 +00:00
|
|
|
|
|
|
|
|
2018-06-09 13:36:16 +00:00
|
|
|
@registry.register(name="custom")
class CustomRadio(SessionRadio):
    """Radio whose tracks are selected by a user-defined ``custom_radio`` config."""

    def get_queryset_kwargs(self):
        """Add the session's ``user`` and ``custom_radio`` when a session is attached.

        Without a session, the caller must pass ``custom_radio`` explicitly
        (``get_queryset`` reads it), instead of failing on ``None.user``.
        """
        kwargs = super().get_queryset_kwargs()
        if self.session:
            kwargs["user"] = self.session.user
            kwargs["custom_radio"] = self.session.custom_radio
        return kwargs

    def get_queryset(self, **kwargs):
        """Run the custom radio's filter ``config`` over the playable tracks."""
        qs = super().get_queryset(**kwargs)
        return filters.run(kwargs["custom_radio"].config, candidates=qs)

    def validate_session(self, data, **context):
        """Check that ``data`` names a custom radio the requesting user may play.

        The user is read from ``data["user"]``, falling back to ``context["user"]``.

        :raises serializers.ValidationError: if no custom radio is provided (missing
            or ``None``), or if it is neither owned by the user nor public.
        """
        data = super().validate_session(data, **context)
        try:
            user = data["user"]
        except KeyError:
            user = context["user"]
        try:
            custom_radio = data["custom_radio"]
        except KeyError:
            custom_radio = None
        if custom_radio is None:
            raise serializers.ValidationError("You must provide a custom radio")
        # Explicit check instead of ``assert``: asserts are stripped when Python
        # runs with -O, which would silently disable this access control.
        if not (custom_radio.user == user or custom_radio.is_public):
            raise serializers.ValidationError("You don't have access to this radio")
        return data
|
|
|
|
|
|
|
|
|
2017-06-23 21:00:42 +00:00
|
|
|
class RelatedObjectRadio(SessionRadio):
    """Abstract radio related to an object (tag, artist, user...).

    Subclasses set ``model``: the class the session's related object must be
    an instance of.
    """

    def clean(self, instance):
        """Reject a session whose related object is missing or of the wrong type.

        :raises ValidationError: (``django.core.exceptions``) on either failure.
        """
        super().clean(instance)
        related = instance.related_object
        if not related:
            raise ValidationError(
                "Cannot start RelatedObjectRadio without related object"
            )
        if not isinstance(related, self.model):
            raise ValidationError("Trying to start radio with bad related object")

    def get_related_object(self, pk):
        """Fetch the ``self.model`` instance whose primary key is ``pk``."""
        manager = self.model.objects
        return manager.get(pk=pk)
|
|
|
|
|
|
|
|
|
2018-06-09 13:36:16 +00:00
|
|
|
@registry.register(name="tag")
class TagRadio(RelatedObjectRadio):
    """Radio playing tracks tagged with the session's related ``Tag``."""

    model = Tag

    def get_queryset(self, **kwargs):
        """Restrict playable tracks to those carrying the session's tag."""
        playable = super().get_queryset(**kwargs)
        wanted_tags = [self.session.related_object]
        return playable.filter(tags__in=wanted_tags)
|
2017-06-23 21:00:42 +00:00
|
|
|
|
2018-06-09 13:36:16 +00:00
|
|
|
|
|
|
|
@registry.register(name="artist")
class ArtistRadio(RelatedObjectRadio):
    """Radio playing tracks by the session's related ``Artist``."""

    model = Artist

    def get_queryset(self, **kwargs):
        """Restrict playable tracks to those whose ``artist`` is the related object."""
        playable = super().get_queryset(**kwargs)
        wanted_artist = self.session.related_object
        return playable.filter(artist=wanted_artist)
|
2017-06-23 21:00:42 +00:00
|
|
|
|
|
|
|
|
2018-06-09 13:36:16 +00:00
|
|
|
@registry.register(name="less-listened")
class LessListenedRadio(RelatedObjectRadio):
    """Radio of tracks the session's user has no listening record for."""

    model = User

    def clean(self, instance):
        """Use the session's own user as its related object, then validate it."""
        instance.related_object = instance.user
        super().clean(instance)

    def get_queryset(self, **kwargs):
        """Exclude every track in the session user's listenings, in random order."""
        playable = super().get_queryset(**kwargs)
        listenings = self.session.user.listenings.all()
        listened_track_ids = listenings.values_list("track", flat=True)
        unheard = playable.exclude(pk__in=listened_track_ids)
        return unheard.order_by("?")
|