kopia lustrzana https://github.com/glidernet/ogn-python
SQL UPSERT function for #63
rodzic
23cc34da0a
commit
90ab582ca3
|
@ -1,7 +1,9 @@
|
||||||
from celery.utils.log import get_task_logger
|
from celery.utils.log import get_task_logger
|
||||||
|
|
||||||
from sqlalchemy import insert, distinct
|
from sqlalchemy import insert, distinct
|
||||||
from sqlalchemy.sql import null, and_, func, not_
|
from sqlalchemy.sql import null, and_, func, not_, case
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
|
|
||||||
from ogn.collect.celery import app
|
from ogn.collect.celery import app
|
||||||
from ogn.model import Country, DeviceInfo, DeviceInfoOrigin, AircraftBeacon, ReceiverBeacon, Device, Receiver
|
from ogn.model import Country, DeviceInfo, DeviceInfoOrigin, AircraftBeacon, ReceiverBeacon, Device, Receiver
|
||||||
|
@ -11,6 +13,28 @@ from ogn.utils import get_ddb, get_flarmnet
|
||||||
logger = get_task_logger(__name__)
|
logger = get_task_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def compile_query(query):
|
||||||
|
"""Via http://nicolascadou.com/blog/2014/01/printing-actual-sqlalchemy-queries"""
|
||||||
|
compiler = query.compile if not hasattr(query, 'statement') else query.statement.compile
|
||||||
|
return compiler(dialect=postgresql.dialect())
|
||||||
|
|
||||||
|
|
||||||
|
def upsert(session, model, rows, update_cols):
|
||||||
|
"""Insert rows in model. On conflicting update columns if new value IS NOT NULL."""
|
||||||
|
|
||||||
|
table = model.__table__
|
||||||
|
|
||||||
|
stmt = insert(table).values(rows)
|
||||||
|
|
||||||
|
on_conflict_stmt = stmt.on_conflict_do_update(
|
||||||
|
index_elements=table.primary_key.columns,
|
||||||
|
set_={k: case([(getattr(stmt.excluded, k) != null(), getattr(stmt.excluded, k))], else_=getattr(model, k)) for k in update_cols},
|
||||||
|
)
|
||||||
|
|
||||||
|
# print(compile_query(on_conflict_stmt))
|
||||||
|
session.execute(on_conflict_stmt)
|
||||||
|
|
||||||
|
|
||||||
def update_device_infos(session, address_origin, path=None):
|
def update_device_infos(session, address_origin, path=None):
|
||||||
if address_origin == DeviceInfoOrigin.flarmnet:
|
if address_origin == DeviceInfoOrigin.flarmnet:
|
||||||
device_infos = get_flarmnet(fln_file=path)
|
device_infos = get_flarmnet(fln_file=path)
|
||||||
|
|
|
@ -3,7 +3,7 @@ import unittest
|
||||||
from tests.base import TestBaseDB
|
from tests.base import TestBaseDB
|
||||||
|
|
||||||
from ogn.model import AircraftBeacon, ReceiverBeacon, Device, Receiver
|
from ogn.model import AircraftBeacon, ReceiverBeacon, Device, Receiver
|
||||||
from ogn.collect.database import add_missing_devices, add_missing_receivers
|
from ogn.collect.database import add_missing_devices, add_missing_receivers, upsert
|
||||||
|
|
||||||
|
|
||||||
class TestDatabase(TestBaseDB):
|
class TestDatabase(TestBaseDB):
|
||||||
|
@ -30,6 +30,29 @@ class TestDatabase(TestBaseDB):
|
||||||
receiver_beacon = receiver_beacons[0]
|
receiver_beacon = receiver_beacons[0]
|
||||||
self.assertEqual(receiver_beacon.receiver.name, 'Bene')
|
self.assertEqual(receiver_beacon.receiver.name, 'Bene')
|
||||||
|
|
||||||
|
def test_insert_duplicate_beacons(self):
|
||||||
|
session = self.session
|
||||||
|
|
||||||
|
row1 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:51:00', 'ground_speed': None}
|
||||||
|
row2 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:52:00', 'ground_speed': 0}
|
||||||
|
row3 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:53:00', 'ground_speed': 1}
|
||||||
|
row4 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:54:00', 'ground_speed': None}
|
||||||
|
|
||||||
|
upsert(session=session, model=AircraftBeacon, rows=[row1, row2, row3, row4], update_cols=['ground_speed'])
|
||||||
|
|
||||||
|
row5 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:51:00', 'ground_speed': 2}
|
||||||
|
row6 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:52:00', 'ground_speed': 3}
|
||||||
|
row7 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:53:00', 'ground_speed': None}
|
||||||
|
row8 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:54:00', 'ground_speed': None}
|
||||||
|
|
||||||
|
upsert(session=session, model=AircraftBeacon, rows=[row5, row6, row7, row8], update_cols=['ground_speed'])
|
||||||
|
|
||||||
|
result = session.query(AircraftBeacon).order_by(AircraftBeacon.timestamp).all()
|
||||||
|
self.assertEqual(result[0].ground_speed, 2)
|
||||||
|
self.assertEqual(result[1].ground_speed, 3)
|
||||||
|
self.assertEqual(result[2].ground_speed, 1)
|
||||||
|
self.assertEqual(result[3].ground_speed, None)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Ładowanie…
Reference in New Issue