diff --git a/wagtail/core/models/__init__.py b/wagtail/core/models/__init__.py index 5199b26063..a1651505db 100644 --- a/wagtail/core/models/__init__.py +++ b/wagtail/core/models/__init__.py @@ -55,6 +55,7 @@ from wagtail.core.utils import ( from wagtail.search import index from .sites import Site, SiteManager, SiteRootPath # noqa +from .view_restrictions import BaseViewRestriction logger = logging.getLogger('wagtail.core') @@ -3433,64 +3434,6 @@ class PagePermissionTester: return not self.page_is_root -class BaseViewRestriction(models.Model): - NONE = 'none' - PASSWORD = 'password' - GROUPS = 'groups' - LOGIN = 'login' - - RESTRICTION_CHOICES = ( - (NONE, _("Public")), - (LOGIN, _("Private, accessible to logged-in users")), - (PASSWORD, _("Private, accessible with the following password")), - (GROUPS, _("Private, accessible to users in specific groups")), - ) - - restriction_type = models.CharField( - max_length=20, choices=RESTRICTION_CHOICES) - password = models.CharField(verbose_name=_('password'), max_length=255, blank=True) - groups = models.ManyToManyField(Group, verbose_name=_('groups'), blank=True) - - def accept_request(self, request): - if self.restriction_type == BaseViewRestriction.PASSWORD: - passed_restrictions = request.session.get(self.passed_view_restrictions_session_key, []) - if self.id not in passed_restrictions: - return False - - elif self.restriction_type == BaseViewRestriction.LOGIN: - if not request.user.is_authenticated: - return False - - elif self.restriction_type == BaseViewRestriction.GROUPS: - if not request.user.is_superuser: - current_user_groups = request.user.groups.all() - - if not any(group in current_user_groups for group in self.groups.all()): - return False - - return True - - def mark_as_passed(self, request): - """ - Update the session data in the request to mark the user as having passed this - view restriction - """ - has_existing_session = (settings.SESSION_COOKIE_NAME in request.COOKIES) - passed_restrictions = request.session.setdefault(self.passed_view_restrictions_session_key, []) - if self.id not in passed_restrictions: - passed_restrictions.append(self.id) - request.session[self.passed_view_restrictions_session_key] = passed_restrictions - if not has_existing_session: - # if this is a session we've created, set it to expire at the end - # of the browser session - request.session.set_expiry(0) - - class Meta: - abstract = True - verbose_name = _('view restriction') - verbose_name_plural = _('view restrictions') - - class PageViewRestriction(BaseViewRestriction): page = models.ForeignKey( 'Page', verbose_name=_('page'), related_name='view_restrictions', on_delete=models.CASCADE diff --git a/wagtail/core/models/view_restrictions.py b/wagtail/core/models/view_restrictions.py new file mode 100644 index 0000000000..e7619dc993 --- /dev/null +++ b/wagtail/core/models/view_restrictions.py @@ -0,0 +1,62 @@ +from django.conf import settings +from django.contrib.auth.models import Group +from django.db import models +from django.utils.translation import gettext_lazy as _ + + +class BaseViewRestriction(models.Model): + NONE = 'none' + PASSWORD = 'password' + GROUPS = 'groups' + LOGIN = 'login' + + RESTRICTION_CHOICES = ( + (NONE, _("Public")), + (LOGIN, _("Private, accessible to logged-in users")), + (PASSWORD, _("Private, accessible with the following password")), + (GROUPS, _("Private, accessible to users in specific groups")), + ) + + restriction_type = models.CharField( + max_length=20, choices=RESTRICTION_CHOICES) + password = models.CharField(verbose_name=_('password'), max_length=255, blank=True) + groups = models.ManyToManyField(Group, verbose_name=_('groups'), blank=True) + + def accept_request(self, request): + if self.restriction_type == BaseViewRestriction.PASSWORD: + passed_restrictions = request.session.get(self.passed_view_restrictions_session_key, []) + if self.id not in passed_restrictions: + return False + + elif self.restriction_type == BaseViewRestriction.LOGIN: + if not request.user.is_authenticated: + return False + + elif self.restriction_type == BaseViewRestriction.GROUPS: + if not request.user.is_superuser: + current_user_groups = request.user.groups.all() + + if not any(group in current_user_groups for group in self.groups.all()): + return False + + return True + + def mark_as_passed(self, request): + """ + Update the session data in the request to mark the user as having passed this + view restriction + """ + has_existing_session = (settings.SESSION_COOKIE_NAME in request.COOKIES) + passed_restrictions = request.session.setdefault(self.passed_view_restrictions_session_key, []) + if self.id not in passed_restrictions: + passed_restrictions.append(self.id) + request.session[self.passed_view_restrictions_session_key] = passed_restrictions + if not has_existing_session: + # if this is a session we've created, set it to expire at the end + # of the browser session + request.session.set_expiry(0) + + class Meta: + abstract = True + verbose_name = _('view restriction') + verbose_name_plural = _('view restrictions')