diff --git a/tests/users/models/test_domain.py b/tests/users/models/test_domain.py index e8e6860..408fdf8 100644 --- a/tests/users/models/test_domain.py +++ b/tests/users/models/test_domain.py @@ -3,6 +3,36 @@ import pytest from users.models import Domain +def test_valid_domain(): + """ + Tests that a valid domain is valid + """ + + assert Domain.is_valid_domain("example.com") + assert Domain.is_valid_domain("xn----gtbspbbmkef.xn--p1ai") + assert Domain.is_valid_domain("underscore_subdomain.example.com") + assert Domain.is_valid_domain("something.versicherung") + assert Domain.is_valid_domain("11.com") + assert Domain.is_valid_domain("a.cn") + assert Domain.is_valid_domain("sub1.sub2.sample.co.uk") + assert Domain.is_valid_domain("somerandomexample.xn--fiqs8s") + assert not Domain.is_valid_domain("über.com") + assert not Domain.is_valid_domain("example.com:4444") + assert not Domain.is_valid_domain("example.-com") + assert not Domain.is_valid_domain("foo@bar.com") + assert not Domain.is_valid_domain("example.") + assert not Domain.is_valid_domain("example.com.") + assert not Domain.is_valid_domain("-example.com") + assert not Domain.is_valid_domain("_example.com") + assert not Domain.is_valid_domain("_example._com") + assert not Domain.is_valid_domain("example_.com") + assert not Domain.is_valid_domain("example") + assert not Domain.is_valid_domain("a......b.com") + assert not Domain.is_valid_domain("a.123") + assert not Domain.is_valid_domain("123.123") + assert not Domain.is_valid_domain("123.123.123.123") + + @pytest.mark.django_db def test_recursive_block(): """ diff --git a/users/models/domain.py b/users/models/domain.py index d9fe38f..eb82436 100644 --- a/users/models/domain.py +++ b/users/models/domain.py @@ -1,5 +1,6 @@ import json import logging +import re import ssl from functools import cached_property from typing import Optional @@ -8,6 +9,7 @@ import httpx import pydantic import urlman from django.conf import settings +from django.core.exceptions import ValidationError from django.db import models from core.models import Config @@ -53,6 +55,14 @@ class DomainStates(StateGraph): return cls.outdated +def _domain_validator(value: str): + if not Domain.is_valid_domain(value): + raise ValidationError( + "%(value)s is not a valid domain", + params={"value": value}, + ) + + class Domain(StatorModel): """ Represents a domain that a user can have an account on. @@ -71,7 +81,9 @@ class Domain(StatorModel): display domains for now, until we start doing better probing. """ - domain = models.CharField(max_length=250, primary_key=True) + domain = models.CharField( + max_length=250, primary_key=True, validators=[_domain_validator] + ) service_domain = models.CharField( max_length=250, null=True, @@ -119,6 +131,19 @@ class Domain(StatorModel): class Meta: indexes: list = [] + @classmethod + def is_valid_domain(cls, domain: str) -> bool: + """ + Check if a domain is valid, domain must be lowercase + """ + return ( + re.match( + r"^(?:[a-z0-9](?:[a-z0-9-_]{0,61}[a-z0-9])?\.)+[a-z0-9][a-z0-9-_]{0,61}[a-z]$", + domain, + ) + is not None + ) + @classmethod def get_remote_domain(cls, domain: str) -> "Domain": return cls.objects.get_or_create(domain=domain.lower(), local=False)[0] diff --git a/users/shortcuts.py b/users/shortcuts.py index 7357ee9..7621e90 100644 --- a/users/shortcuts.py +++ b/users/shortcuts.py @@ -18,6 +18,8 @@ def by_handle_or_404(request, handle, local=True, fetch=False) -> Identity: domain = domain_instance.domain else: username, domain = handle.split("@", 1) + if not Domain.is_valid_domain(domain): + raise Http404("Invalid domain") # Resolve the domain to the display domain domain_instance = Domain.get_domain(domain) if domain_instance is None: