Fetch and upload embedded images before sending to Matrix

test
Jason Robinson 2020-12-30 00:06:35 +02:00
rodzic 82ac0ce3cf
commit e4db91503b
3 zmienionych plików z 82 dodań i 4 usunięć

Wyświetl plik

@ -1,15 +1,19 @@
import json
import logging
import mimetypes
import os
from typing import Dict, List
from urllib.parse import quote
from uuid import uuid4
import requests
from federation.entities.base import Post, Profile
from federation.entities.matrix.enums import EventType
from federation.entities.mixins import BaseEntity
from federation.entities.utils import get_base_attributes
from federation.utils.matrix import get_matrix_configuration, appservice_auth_header
from federation.utils.network import fetch_document
from federation.utils.network import fetch_document, fetch_file
# Module-level logger; every log call in this module goes through the "federation" logger.
logger = logging.getLogger("federation")
@ -41,6 +45,11 @@ class MatrixEntityMixin(BaseEntity):
config = get_matrix_configuration()
return f"{config['homeserver_base_url']}/_matrix/client/r0"
# noinspection PyMethodMayBeStatic
def get_endpoint_media(self) -> str:
    """
    Return the base URL of the homeserver media API (r0).

    The URL is built from the ``homeserver_base_url`` value of the Matrix configuration.
    """
    homeserver_base_url = get_matrix_configuration()["homeserver_base_url"]
    return f"{homeserver_base_url}/_matrix/media/r0"
def get_profile_room_id(self):
# TODO: we should cache these.
doc, status, error = fetch_document(
@ -89,9 +98,60 @@ class MatrixRoomMessage(Post, MatrixEntityMixin):
def pre_send(self):
    """
    Do various pre-send things.

    Runs the parent class ``pre_send`` first, then calls ``get_profile_room_id`` and
    ``upload_embedded_images``, in that order.
    """
    super().pre_send()
    # Get profile room ID
    self.get_profile_room_id()
    # Upload embedded images and replace the HTTP urls in the message with MXC urls so clients show the images
    self.upload_embedded_images()
def upload_embedded_images(self):
    """
    Upload embedded images.

    Replaces the HTTP urls in the message with MXC urls so that Matrix clients will show the images.

    Each ``(url, name)`` pair in ``self.embedded_images`` is downloaded to a temporary file, uploaded
    to the homeserver media endpoint and, on success, its url is replaced in ``self.raw_content``
    with the ``content_uri`` from the upload response.

    This is best-effort: an image that cannot be fetched or uploaded is logged and skipped, leaving
    its original url in the content. The temporary file is removed whether or not the upload succeeds.
    """
    for url, name in self.embedded_images:
        headers = appservice_auth_header()
        content_type, _encoding = mimetypes.guess_type(url)
        # Only set the header when the type could be guessed; never send a ``None`` value.
        if content_type:
            headers["Content-Type"] = content_type

        # Random name if none
        if not name:
            # ``guess_extension`` cannot take ``None`` and may itself return ``None``;
            # fall back to a name without extension rather than a literal "None" suffix.
            extension = mimetypes.guess_extension(content_type, strict=False) if content_type else None
            name = f"{uuid4()}{extension or ''}"

        # Need to fetch it locally first
        # noinspection PyBroadException
        try:
            image_file = fetch_file(url=url, timeout=60)
        except Exception as ex:
            logger.warning("MatrixRoomMessage.pre_send | Failed to retrieve image %s to be uploaded: %s",
                           url, ex)
            continue

        # Then upload
        # noinspection PyBroadException
        try:
            headers["Content-Length"] = str(os.stat(image_file).st_size)
            # Escape both query values; a Matrix ID contains characters such as "@" and ":".
            filename_param = quote(name)
            user_id_param = quote(self.mxid, safe="")
            # Call through ``self`` so that an overridden ``get_endpoint_media`` is honoured.
            upload_url = f"{self.get_endpoint_media()}/upload?filename={filename_param}&user_id={user_id_param}"
            with open(image_file, "rb") as f:
                response = requests.post(
                    upload_url,
                    data=f.read(),
                    headers=headers,
                    timeout=60,
                )
            response.raise_for_status()
        except Exception as ex:
            logger.warning("MatrixRoomMessage.pre_send | Failed to upload image %s: %s",
                           url, ex)
            continue
        finally:
            # The downloaded copy is only needed for the upload; never leave it behind.
            try:
                os.remove(image_file)
            except OSError:
                pass

        # Replace in raw content
        try:
            # Parse the body once and reuse it for both logging and lookup.
            payload = response.json()
            logger.debug("MatrixRoomMessage.pre_send | Got response %s", payload)
            content_uri = payload["content_uri"]
            self.raw_content = self.raw_content.replace(url, content_uri)
        except Exception as ex:
            logger.error("MatrixRoomMessage.pre_send | Failed to find content_uri from the image upload "
                         "response: %s", ex)
class MatrixProfile(Profile, MatrixEntityMixin):

Wyświetl plik

@ -211,7 +211,7 @@ class RawContentMixin(BaseEntity):
Returns a Tuple of (url, filename).
"""
images = []
if self._media_type != "text/markdown":
if self._media_type != "text/markdown" or self.raw_content is None:
return images
regex = r"!\[([\w ]*)\]\((https?://[\w\d\-\./]+\.[\w]*((?<=jpg)|(?<=gif)|(?<=png)|(?<=jpeg)))\)"
matches = re.finditer(regex, self.raw_content, re.MULTILINE | re.IGNORECASE)

Wyświetl plik

@ -3,8 +3,9 @@ import datetime
import logging
import re
import socket
from typing import Optional
from typing import Optional, Dict
from urllib.parse import quote
from uuid import uuid4
import requests
from requests.exceptions import RequestException, HTTPError, SSLError
@ -107,6 +108,22 @@ def fetch_host_ip(host: str) -> str:
return ip
def fetch_file(url: str, timeout: int = 30, extra_headers: Optional[Dict] = None) -> str:
    """
    Download a file with a temporary name and return the name.

    The file is created in the platform temporary directory. The caller owns the returned
    file and is responsible for deleting it once done.

    :param url: URL of the file to download
    :param timeout: Timeout in seconds passed to ``requests.get``
    :param extra_headers: Optional headers merged over the default ``user-agent`` header
    :returns: Filesystem path of the downloaded file
    :raises requests.exceptions.RequestException: On connection failure or an HTTP error status
    :raises OSError: If the temporary file cannot be created or written
    """
    # Function-scope imports keep this block self-contained with respect to the module imports.
    import os
    import tempfile

    headers = {'user-agent': USER_AGENT}
    if extra_headers:
        headers.update(extra_headers)
    response = requests.get(url, timeout=timeout, headers=headers, stream=True)
    try:
        response.raise_for_status()
        # ``mkstemp`` creates the file exclusively, with a unique name, in the
        # platform temp dir instead of a hard-coded "/tmp".
        fd, name = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        except Exception:
            # Do not leave a partially written file behind when the download fails midway.
            try:
                os.remove(name)
            except OSError:
                pass
            raise
    finally:
        # With ``stream=True`` the connection is only released once the response is closed.
        response.close()
    return name
def parse_http_date(date):
"""
Parse a date format as specified by HTTP RFC7231 section 7.1.1.1.
@ -185,6 +202,7 @@ def send_document(url, data, timeout=10, method="post", *args, **kwargs):
response = request_func(url, *args, **kwargs)
logger.debug("send_document: response status code %s", response.status_code)
return response.status_code, None
# TODO support rate limit 429 code
except RequestException as ex:
logger.debug("send_document: exception %s", ex)
return None, ex