diff --git a/backend/moonstream/providers/bugout.py b/backend/moonstream/providers/bugout.py index 1ace3ce9..313a3027 100644 --- a/backend/moonstream/providers/bugout.py +++ b/backend/moonstream/providers/bugout.py @@ -21,6 +21,8 @@ from ..settings import ETHTXPOOL_HUMBUG_CLIENT_ID logger = logging.getLogger(__name__) logger.setLevel(logging.WARN) +allowed_tags = ["tag:erc721"] + class BugoutEventProviderError(Exception): """ @@ -315,9 +317,12 @@ class EthereumTXPoolProvider(BugoutEventProvider): ] subscriptions_filters = [] for address in addresses: - subscriptions_filters.extend( - [f"?#from_address:{address}", f"?#to_address:{address}"] - ) + if address in allowed_tags: + subscriptions_filters.append(address) + else: + subscriptions_filters.extend( + [f"?#from_address:{address}", f"?#to_address:{address}"] + ) return subscriptions_filters diff --git a/backend/moonstream/providers/ethereum_blockchain.py b/backend/moonstream/providers/ethereum_blockchain.py index 1eb87bde..dc607012 100644 --- a/backend/moonstream/providers/ethereum_blockchain.py +++ b/backend/moonstream/providers/ethereum_blockchain.py @@ -8,11 +8,14 @@ from bugout.data import BugoutResource from moonstreamdb.models import ( EthereumBlock, EthereumTransaction, + EthereumAddress, + EthereumLabel, ) from sqlalchemy import or_, and_, text from sqlalchemy.orm import Session, Query from sqlalchemy.sql.functions import user + from .. import data from ..stream_boundaries import validate_stream_boundary from ..stream_queries import StreamQuery @@ -23,6 +26,7 @@ logger.setLevel(logging.WARN) event_type = "ethereum_blockchain" +allowed_tags = ["tag:erc721"] description = f"""Event provider for transactions from the Ethereum blockchain. @@ -79,6 +83,7 @@ class Filters: from_addresses: List[str] = field(default_factory=list) to_addresses: List[str] = field(default_factory=list) + labels: List[str] = field(default_factory=list) def default_filters(subscriptions: List[BugoutResource]) -> Filters: @@ -91,8 +96,11 @@ def default_filters(subscriptions: List[BugoutResource]) -> Filters: Optional[str], subscription.resource_data.get("address") ) if subscription_address is not None: - filters.from_addresses.append(subscription_address) - filters.to_addresses.append(subscription_address) + if subscription_address in allowed_tags: + filters.labels.append(subscription_address.split(":")[1]) + else: + filters.from_addresses.append(subscription_address) + filters.to_addresses.append(subscription_address) else: logger.warn( f"Could not find subscription address for subscription with resource id: {subscription.id}" @@ -157,14 +165,20 @@ def parse_filters( parsed_filters.from_addresses.append(address) parsed_filters.to_addresses.append(address) - if not (parsed_filters.from_addresses or parsed_filters.to_addresses): + if not ( + parsed_filters.from_addresses + or parsed_filters.to_addresses + or parsed_filters.labels + ): return None return parsed_filters def query_ethereum_transactions( - db_session: Session, stream_boundary: data.StreamBoundary, parsed_filters: Filters + db_session: Session, + stream_boundary: data.StreamBoundary, + parsed_filters: Filters, ) -> Query: """ Builds a database query for Ethereum transactions that occurred within the window of time that @@ -198,15 +212,41 @@ def query_ethereum_transactions( query = query.filter(EthereumBlock.timestamp <= stream_boundary.end_time) # We want to take a big disjunction (OR) over ALL the filters, be they on "from" address or "to" address - address_clauses = [ - EthereumTransaction.from_address == address - for address in parsed_filters.from_addresses - ] + [ - EthereumTransaction.to_address == address - for address in parsed_filters.to_addresses - ] - if address_clauses: - query = query.filter(or_(*address_clauses)) + address_clauses = [] + + address_clauses.extend( + [ + EthereumTransaction.from_address == address + for address in parsed_filters.from_addresses + ] + + [ + EthereumTransaction.to_address == address + for address in parsed_filters.to_addresses + ] + ) + + labels_clause = [] + + if parsed_filters.labels: + label_clause = ( + db_session.query(EthereumAddress) + .join(EthereumLabel, EthereumAddress.id == EthereumLabel.address_id) + .filter( + or_( + *[ + EthereumLabel.label.contains(label) + for label in list(set(parsed_filters.labels)) + ] + ) + ) + .exists() + ) + labels_clause.append(label_clause) + + subscriptions_clause = address_clauses + labels_clause + + if subscriptions_clause: + query = query.filter(or_(*subscriptions_clause)) return query @@ -353,8 +393,7 @@ def next_event( query_ethereum_transactions(db_session, next_stream_boundary, parsed_filters) .order_by(text("timestamp asc")) .limit(1) - .one_or_none() - ) + ).one_or_none() if maybe_ethereum_transaction is None: return None @@ -394,9 +433,7 @@ def previous_event( ) .order_by(text("timestamp desc")) .limit(1) - .one_or_none() - ) - + ).one_or_none() if maybe_ethereum_transaction is None: return None return ethereum_transaction_event(maybe_ethereum_transaction)