SQL UPSERT function for #63

pull/68/head
Konstantin Gründger 2019-01-27 13:59:12 +01:00
rodzic 23cc34da0a
commit 90ab582ca3
2 zmienionych plików z 49 dodań i 2 usunięć

Wyświetl plik

@ -1,7 +1,9 @@
from celery.utils.log import get_task_logger from celery.utils.log import get_task_logger
from sqlalchemy import insert, distinct from sqlalchemy import insert, distinct
from sqlalchemy.sql import null, and_, func, not_ from sqlalchemy.sql import null, and_, func, not_, case
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert
from ogn.collect.celery import app from ogn.collect.celery import app
from ogn.model import Country, DeviceInfo, DeviceInfoOrigin, AircraftBeacon, ReceiverBeacon, Device, Receiver from ogn.model import Country, DeviceInfo, DeviceInfoOrigin, AircraftBeacon, ReceiverBeacon, Device, Receiver
@ -11,6 +13,28 @@ from ogn.utils import get_ddb, get_flarmnet
logger = get_task_logger(__name__) logger = get_task_logger(__name__)
def compile_query(query):
"""Via http://nicolascadou.com/blog/2014/01/printing-actual-sqlalchemy-queries"""
compiler = query.compile if not hasattr(query, 'statement') else query.statement.compile
return compiler(dialect=postgresql.dialect())
def upsert(session, model, rows, update_cols):
"""Insert rows in model. On conflicting update columns if new value IS NOT NULL."""
table = model.__table__
stmt = insert(table).values(rows)
on_conflict_stmt = stmt.on_conflict_do_update(
index_elements=table.primary_key.columns,
set_={k: case([(getattr(stmt.excluded, k) != null(), getattr(stmt.excluded, k))], else_=getattr(model, k)) for k in update_cols},
)
# print(compile_query(on_conflict_stmt))
session.execute(on_conflict_stmt)
def update_device_infos(session, address_origin, path=None): def update_device_infos(session, address_origin, path=None):
if address_origin == DeviceInfoOrigin.flarmnet: if address_origin == DeviceInfoOrigin.flarmnet:
device_infos = get_flarmnet(fln_file=path) device_infos = get_flarmnet(fln_file=path)

Wyświetl plik

@ -3,7 +3,7 @@ import unittest
from tests.base import TestBaseDB from tests.base import TestBaseDB
from ogn.model import AircraftBeacon, ReceiverBeacon, Device, Receiver from ogn.model import AircraftBeacon, ReceiverBeacon, Device, Receiver
from ogn.collect.database import add_missing_devices, add_missing_receivers from ogn.collect.database import add_missing_devices, add_missing_receivers, upsert
class TestDatabase(TestBaseDB): class TestDatabase(TestBaseDB):
@ -30,6 +30,29 @@ class TestDatabase(TestBaseDB):
receiver_beacon = receiver_beacons[0] receiver_beacon = receiver_beacons[0]
self.assertEqual(receiver_beacon.receiver.name, 'Bene') self.assertEqual(receiver_beacon.receiver.name, 'Bene')
def test_insert_duplicate_beacons(self):
session = self.session
row1 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:51:00', 'ground_speed': None}
row2 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:52:00', 'ground_speed': 0}
row3 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:53:00', 'ground_speed': 1}
row4 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:54:00', 'ground_speed': None}
upsert(session=session, model=AircraftBeacon, rows=[row1, row2, row3, row4], update_cols=['ground_speed'])
row5 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:51:00', 'ground_speed': 2}
row6 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:52:00', 'ground_speed': 3}
row7 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:53:00', 'ground_speed': None}
row8 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:54:00', 'ground_speed': None}
upsert(session=session, model=AircraftBeacon, rows=[row5, row6, row7, row8], update_cols=['ground_speed'])
result = session.query(AircraftBeacon).order_by(AircraftBeacon.timestamp).all()
self.assertEqual(result[0].ground_speed, 2)
self.assertEqual(result[1].ground_speed, 3)
self.assertEqual(result[2].ground_speed, 1)
self.assertEqual(result[3].ground_speed, None)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()