diff --git a/moonstreamdb-v3/moonstreamdbv3/db.py b/moonstreamdb-v3/moonstreamdbv3/db.py index a4d733ab..9bb47505 100644 --- a/moonstreamdb-v3/moonstreamdbv3/db.py +++ b/moonstreamdb-v3/moonstreamdbv3/db.py @@ -5,7 +5,7 @@ Moonstream database connection. import logging import os from contextlib import contextmanager -from typing import Generator +from typing import Generator, Optional from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker @@ -56,23 +56,28 @@ def create_moonstream_engine( pool_size: int, statement_timeout: int, pool_pre_ping: bool = False, + schema: Optional[str] = None, ): # Pooling: https://docs.sqlalchemy.org/en/14/core/pooling.html#sqlalchemy.pool.QueuePool # Statement timeout: https://stackoverflow.com/a/44936982 + options = f"-c statement_timeout={statement_timeout}" + if schema is not None: + options += f" -c search_path={schema}" return create_engine( url=url, pool_pre_ping=pool_pre_ping, pool_size=pool_size, - connect_args={"options": f"-c statement_timeout={statement_timeout}"}, + connect_args={"options": options}, ) class MoonstreamDBEngine: - def __init__(self) -> None: + def __init__(self, schema: Optional[str] = None) -> None: self._engine = create_moonstream_engine( url=MOONSTREAM_DB_URI, # type: ignore pool_size=MOONSTREAM_POOL_SIZE, statement_timeout=MOONSTREAM_DB_STATEMENT_TIMEOUT_MILLIS, + schema=schema, ) self._session_local = sessionmaker(bind=self.engine) @@ -106,11 +111,12 @@ class MoonstreamDBEngine: class MoonstreamDBEngineRO: - def __init__(self) -> None: + def __init__(self, schema: Optional[str] = None) -> None: self._RO_engine = create_moonstream_engine( url=MOONSTREAM_DB_URI_READ_ONLY, # type: ignore pool_size=MOONSTREAM_POOL_SIZE, statement_timeout=MOONSTREAM_DB_STATEMENT_TIMEOUT_MILLIS, + schema=schema, ) self._RO_session_local = sessionmaker(bind=self.RO_engine)