Added stream query parser

pull/105/head
Neeraj Kashyap 2021-08-18 07:48:39 -07:00
rodzic b5ac5fbd96
commit 8df9cdeee1
2 zmienionych plików z 117 dodań i 0 usunięć

Wyświetl plik

@ -0,0 +1,56 @@
"""
Stream queries - data structure, and parser.
"""
from dataclasses import dataclass, field
import logging
from typing import cast, List, Tuple
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@dataclass
class StreamQuery:
subscription_types: List[str] = field(default_factory=list)
# Subscriptions are expected to be specified in the form of an ordered pair:
# (<subscription_type>, <address>)
subscriptions: List[Tuple[str, str]] = field(default_factory=list)
SUBSCRIPTION_TYPE_PREFIX = "type:"
SUBSCRIPTION_PREFIX = "sub:"
SUBSCRIPTION_SEPARATOR = ":"
def parse_query_string(q: str) -> StreamQuery:
"""
Parses a query string (as specified in query parameters on a call to the /streams/ endpoint).
Args:
1. q - Query string. It is parsed as follows:
a. Query string is tokenized (by splitting on whitespace).
b. Tokens of the form "type:<subscription_type>" populate the subscription_types field of the resulting StreamQuery
c. Tokens of the form "sub:<subscription_type>:<address> populate the subscriptions field of the resulting StreamQuery
Returns: Parsed StreamQuery object.
"""
subscription_types: List[str] = []
subscriptions: List[Tuple[str, str]] = []
tokens = q.split()
for token in tokens:
if token.startswith(SUBSCRIPTION_TYPE_PREFIX):
subscription_types.append(token[len(SUBSCRIPTION_TYPE_PREFIX) :])
elif token.startswith(SUBSCRIPTION_PREFIX):
contents = token[len(SUBSCRIPTION_PREFIX) :]
components = tuple(contents.split(SUBSCRIPTION_SEPARATOR))
if len(components) == 2:
subscriptions.append(cast(Tuple[str, str], components))
else:
logger.error(f"Invalid subscription token: {token}")
else:
logger.error(f"Invalid token: {token}")
return StreamQuery(
subscription_types=subscription_types, subscriptions=subscriptions
)

Wyświetl plik

@ -0,0 +1,61 @@
import unittest
from urllib import parse
from .stream_queries import parse_query_string
class TestParseQueryString(unittest.TestCase):
def test_single_subscription_type(self):
q = "type:ethereum_blockchain"
query = parse_query_string(q)
self.assertListEqual(query.subscription_types, ["ethereum_blockchain"])
self.assertListEqual(query.subscriptions, [])
def test_multiple_subscription_types(self):
q = "type:ethereum_blockchain type:ethereum_whalewatch"
query = parse_query_string(q)
self.assertListEqual(
query.subscription_types, ["ethereum_blockchain", "ethereum_whalewatch"]
)
self.assertListEqual(query.subscriptions, [])
def test_single_subscription(self):
q = "sub:ethereum_blockchain:0xbb2569ca55552fb4c1d73ec536e06a620c3d3d66"
query = parse_query_string(q)
self.assertListEqual(query.subscription_types, [])
self.assertListEqual(
query.subscriptions,
[("ethereum_blockchain", "0xbb2569ca55552fb4c1d73ec536e06a620c3d3d66")],
)
def test_multiple_subscriptions(self):
q = "sub:ethereum_blockchain:0xbb2569ca55552fb4c1d73ec536e06a620c3d3d66 sub:ethereum_blockchain:0x2819c144d5946404c0516b6f817a960db37d4929 sub:ethereum_txpool:0x2819c144d5946404c0516b6f817a960db37d4929"
query = parse_query_string(q)
self.assertListEqual(query.subscription_types, [])
self.assertListEqual(
query.subscriptions,
[
("ethereum_blockchain", "0xbb2569ca55552fb4c1d73ec536e06a620c3d3d66"),
("ethereum_blockchain", "0x2819c144d5946404c0516b6f817a960db37d4929"),
("ethereum_txpool", "0x2819c144d5946404c0516b6f817a960db37d4929"),
],
)
def test_multiple_subscription_types_and_subscriptions(self):
q = "type:ethereum_whalewatch type:solana_blockchain sub:ethereum_blockchain:0xbb2569ca55552fb4c1d73ec536e06a620c3d3d66 sub:ethereum_blockchain:0x2819c144d5946404c0516b6f817a960db37d4929 sub:ethereum_txpool:0x2819c144d5946404c0516b6f817a960db37d4929"
query = parse_query_string(q)
self.assertListEqual(
query.subscription_types, ["ethereum_whalewatch", "solana_blockchain"]
)
self.assertListEqual(
query.subscriptions,
[
("ethereum_blockchain", "0xbb2569ca55552fb4c1d73ec536e06a620c3d3d66"),
("ethereum_blockchain", "0x2819c144d5946404c0516b6f817a960db37d4929"),
("ethereum_txpool", "0x2819c144d5946404c0516b6f817a960db37d4929"),
],
)
if __name__ == "__main__":
unittest.main()