federation/federation/entities/mixins.py

313 wiersze
11 KiB
Python
Czysty Zwykły widok Historia

import datetime
import importlib
import re
import warnings
from typing import List, Set, Union, Dict, Tuple
from commonmark import commonmark
2018-09-25 20:50:34 +00:00
from federation.entities.activitypub.enums import ActivityType
from federation.entities.utils import get_name_for_profile
from federation.utils.text import process_text_links, find_tags
2018-09-25 20:50:34 +00:00
class BaseEntity:
    """Common base for the entity classes in this module.

    Holds the shared identity fields (``id``, ``actor_id``, ``handle``, ...),
    per-instance containers for children, mentions and receivers, and the
    validation machinery driven by ``_required`` and ``validate_<attr>``
    methods.
    """
    # Tuple of classes that are accepted as children; checked with isinstance().
    _allowed_children: tuple = ()
    _children: List = None
    _mentions: Set = None
    _receivers: List = None
    _source_protocol: str = ""
    # Contains the original object from payload as a string
    _source_object: Union[str, Dict] = None
    _sender_key: str = ""
    # ActivityType
    activity: ActivityType = None
    activity_id: str = ""
    actor_id: str = ""
    # Server base url
    base_url: str = ""
    guid: str = ""
    handle: str = ""
    id: str = ""
    mxid: str = ""
    signature: str = ""

    def __init__(self, *args, **kwargs):
        """Initialise per-instance containers and apply keyword arguments.

        Positional arguments are accepted but not used here.

        If ``has_schema`` is passed and truthy, every keyword argument is set
        as an attribute without further checks. Otherwise only keywords that
        already exist as attributes are set; unknown ones emit a warning and
        are ignored.
        """
        self._required = ["id", "actor_id"]
        self._children = []
        self._mentions = set()
        self._receivers = []

        # make the assumption that if a schema is being used, the payload
        # is deserialized and validated properly
        if kwargs.get('has_schema'):
            for key, value in kwargs.items():
                setattr(self, key, value)
        else:
            for key, value in kwargs.items():
                if hasattr(self, key):
                    setattr(self, key, value)
                else:
                    warnings.warn("%s.__init__ got parameter %s which this class does not support - ignoring." % (
                        self.__class__.__name__, key
                    ))
        if not self.activity:
            # Fill a default activity if not given and type of entity class has one
            self.activity = getattr(self, "_default_activity", None)

    def as_protocol(self, protocol):
        """Return this entity converted to its protocol specific class.

        Imports ``federation.entities.<protocol>.entities``, looks up the class
        named ``<Protocol><ClassName>`` (``protocol.title()`` followed by this
        instance's class name) and returns the result of its ``from_base``.
        """
        entities = importlib.import_module(f"federation.entities.{protocol}.entities")
        klass = getattr(entities, f"{protocol.title()}{self.__class__.__name__}")
        return klass.from_base(self)

    def post_receive(self):
        """
        Run any actions after deserializing the payload into an entity.
        """
        pass

    def pre_send(self):
        """
        Run any actions before serializing the entity for sending.
        """
        pass

    def validate(self, direction: str = "inbound") -> None:
        """Do validation.

        1) Make sure all attrs in `_required` that are present have a non-empty value
        2) Check `_required` have been given
        3) Loop through attributes and call their `validate_<attr>` methods, if any.
        4) Validate allowed children
        5) Validate signatures (if inbound)

        Raises ``ValueError`` (from the helper methods) when a check fails.
        """
        import inspect

        attributes = []
        validates = []
        # Collect attributes and validation methods
        for attr in dir(self):
            if attr.startswith("_"):
                continue
            # Bound methods are not data attributes. The previous check compared
            # the type object to the string "method", which never matched.
            if inspect.ismethod(getattr(self, attr)):
                continue
            validator = getattr(self, "validate_{attr}".format(attr=attr), None)
            if validator:
                validates.append(validator)
            attributes.append(attr)
        self._validate_empty_attributes(attributes)
        self._validate_required(attributes)
        self._validate_attributes(validates)
        self._validate_children()
        if direction == "inbound":
            self._validate_signatures()

    def _validate_required(self, attributes):
        """Ensure required attributes are present."""
        required_fulfilled = set(self._required).issubset(set(attributes))
        if not required_fulfilled:
            raise ValueError(
                "Not all required attributes fulfilled. Required: {required}".format(required=set(self._required))
            )

    def _validate_attributes(self, validates):
        """Call individual attribute validators."""
        for validator in validates:
            validator()

    def _validate_empty_attributes(self, attributes):
        """Check that required attributes are not empty."""
        attrs_to_check = set(self._required) & set(attributes)
        for attr in attrs_to_check:
            value = getattr(self, attr)  # We should always have a value here
            if value is None or value == "":
                raise ValueError(
                    "Attribute %s cannot be None or an empty string since it is required." % attr
                )

    def _validate_children(self):
        """Check that the children we have are allowed here."""
        for child in self._children:
            if not isinstance(child, self._allowed_children):
                raise ValueError(
                    "Child %s is not allowed as a children for this %s type entity." % (
                        child, self.__class__
                    )
                )

    def _validate_signatures(self):
        """Override in subclasses where necessary"""
        pass

    def sign(self, private_key):
        """Implement in subclasses if needed."""
        pass

    def sign_with_parent(self, private_key):
        """Implement in subclasses if needed."""
        pass
class PublicMixin(BaseEntity):
    """Adds a boolean ``public`` flag, defaulting to ``False``."""
    public = False

    def validate_public(self):
        """Raise ``ValueError`` unless ``public`` is an actual ``bool``."""
        if isinstance(self.public, bool):
            return
        raise ValueError("Public is not valid - it should be True or False")
class TargetIDMixin(BaseEntity):
    """Adds target reference fields; at least one of them must be set."""
    target_id = ""
    target_handle = ""
    target_guid = ""

    def validate(self, *args, **kwargs) -> None:
        """Run the inherited validation, then require some target attribute.

        Raises ``ValueError`` when ``target_id``, ``target_handle`` and
        ``target_guid`` are all falsy.
        """
        super().validate(*args, **kwargs)
        # Ensure one of the target attributes is filled at least
        has_target = self.target_id or self.target_handle or self.target_guid
        if not has_target:
            raise ValueError("Must give one of the target attributes for TargetIDMixin.")
class RootTargetIDMixin(BaseEntity):
    """Adds root target reference fields (id, handle and guid).

    All three default to an empty string and no validation is attached to
    them in this class.
    """
    root_target_id = ""
    root_target_handle = ""
    root_target_guid = ""
class ParticipationMixin(TargetIDMixin):
    """Reflects a participation to something."""
    participation = ""
    # Values accepted by validate_participation().
    _participation_valid_values = ["reaction", "subscription", "comment"]

    def __init__(self, *args, **kwargs):
        """Initialise the entity and mark ``participation`` as required."""
        super().__init__(*args, **kwargs)
        self._required += ["participation"]

    def validate_participation(self):
        """Ensure participation is of a certain type.

        Raises ``ValueError`` when ``participation`` is not one of
        ``_participation_valid_values``.
        """
        if self.participation in self._participation_valid_values:
            return
        allowed = ", ".join(self._participation_valid_values)
        raise ValueError("participation should be one of: {valid}".format(valid=allowed))
class CreatedAtMixin(BaseEntity):
    """Adds a required ``created_at`` timestamp."""
    created_at = None

    def __init__(self, *args, **kwargs):
        """Initialise the entity and default ``created_at`` to the current time.

        The default is only applied when ``created_at`` was not passed as a
        keyword argument; it is a naive ``datetime.datetime.now()`` value.
        """
        super().__init__(*args, **kwargs)
        self._required += ["created_at"]
        if "created_at" not in kwargs:
            self.created_at = datetime.datetime.now()
class RawContentMixin(BaseEntity):
    """Adds required ``raw_content`` plus helpers for rendering it.

    Provides accessors for embedded images, rendered HTML, tags and mention
    extraction. Markdown specific behaviour only applies when ``_media_type``
    is ``text/markdown``.
    """
    _media_type: str = "text/markdown"
    _mentions: Set = None
    # Pre-rendered content; when truthy it is returned by rendered_content as is.
    _rendered_content: str = ""
    raw_content: str = ""

    def __init__(self, *args, **kwargs):
        """Initialise the entity and mark ``raw_content`` as required."""
        self._mentions = set()
        super().__init__(*args, **kwargs)
        self._required += ["raw_content"]

    @property
    def embedded_images(self) -> List[Tuple[str, str]]:
        """
        Returns a list of images from the raw_content.
        Currently only markdown supported.

        Returns a Tuple of (url, filename).
        """
        if self._media_type != "text/markdown" or self.raw_content is None:
            return []
        regex = r"!\[([\w ]*)\]\((https?://[\w\d\-\./]+\.[\w]*((?<=jpg)|(?<=gif)|(?<=png)|(?<=jpeg)))\)"
        matches = re.finditer(regex, self.raw_content, re.MULTILINE | re.IGNORECASE)
        # Group 1 is the text inside the square brackets, group 2 the URL.
        return [(match.group(2), match.group(1) or "") for match in matches]

    @property
    def rendered_content(self) -> str:
        """Returns the rendered version of raw_content, or just raw_content.

        A truthy ``_rendered_content`` is returned unchanged. Otherwise
        markdown content is rendered to HTML with tags, URL mentions and
        remaining plain links turned into anchors. Anything else is returned
        as ``raw_content``.
        """
        # Return early so configuration is only loaded when rendering is needed.
        if self._rendered_content:
            return self._rendered_content
        if self._media_type != "text/markdown" or not self.raw_content:
            return self.raw_content

        linkifier = None
        try:
            # Imported inside the try so the ImportError handler below also
            # covers a failing import of the configuration helper itself.
            from federation.utils.django import get_configuration
            config = get_configuration()
            if config["tags_path"]:
                def linkifier(tag: str) -> str:
                    return f'<a class="mention hashtag" ' \
                           f' href="{config["base_url"]}{config["tags_path"].replace(":tag:", tag.lower())}" ' \
                           f'rel="noopener noreferrer">' \
                           f'#<span>{tag}</span></a>'
        except ImportError:
            linkifier = None

        # Do tags
        _tags, rendered = find_tags(self.raw_content, replacer=linkifier)
        # Render markdown to HTML
        rendered = commonmark(rendered).strip()
        # Do mentions
        for mention in self._mentions or ():
            # Only linkify mentions that are URL's
            if not mention.startswith("http"):
                continue
            display_name = get_name_for_profile(mention) or mention
            rendered = rendered.replace(
                "@{%s}" % mention,
                f'@<a class="mention" href="{mention}"><span>{display_name}</span></a>',
            )
        # Finally linkify remaining URL's that are not links
        rendered = process_text_links(rendered)
        return rendered

    @property
    def tags(self) -> List[str]:
        """Returns a `list` of unique tags contained in `raw_content`."""
        if not self.raw_content:
            return []
        tags, _text = find_tags(self.raw_content)
        return sorted(tags)

    def extract_mentions(self):
        """Collect ``@{...}`` mentions from ``raw_content`` into ``_mentions``.

        A mention body may be a single value or two parts separated by ``;``,
        in which case the second part is stored. Bodies with more than one
        ``;`` are ignored. Does nothing when ``raw_content`` is empty or None.
        """
        # Guard like the other accessors do; re.findall raises on None.
        if not self.raw_content:
            return
        matches = re.findall(r'@{([\S ][^{}]+)}', self.raw_content)
        for mention in matches:
            splits = mention.split(";")
            if len(splits) == 1:
                self._mentions.add(splits[0].strip(' }'))
            elif len(splits) == 2:
                self._mentions.add(splits[1].strip(' }'))
class OptionalRawContentMixin(RawContentMixin):
    """A version of the RawContentMixin where `raw_content` is not required."""

    def __init__(self, *args, **kwargs):
        """Initialise as the parent does, then drop ``raw_content`` from ``_required``."""
        super().__init__(*args, **kwargs)
        # The parent constructor added it; take it out again.
        self._required.remove("raw_content")
class EntityTypeMixin(BaseEntity):
    """
    Provides a required field for entity type.
    """
    entity_type = ""

    def __init__(self, *args, **kwargs):
        """Initialise the entity and mark ``entity_type`` as required."""
        super().__init__(*args, **kwargs)
        self._required += ["entity_type"]
class ProviderDisplayNameMixin(BaseEntity):
    """Provides a field for provider display name.

    The field defaults to an empty string and is not added to ``_required``.
    """
    provider_display_name = ""