SQL UPSERT function for #63

pull/68/head
Konstantin Gründger 2019-01-27 13:59:12 +01:00
rodzic 23cc34da0a
commit 90ab582ca3
2 zmienionych plików z 49 dodań i 2 usunięć

Wyświetl plik

@ -1,7 +1,9 @@
from celery.utils.log import get_task_logger
from sqlalchemy import insert, distinct
from sqlalchemy.sql import null, and_, func, not_
from sqlalchemy.sql import null, and_, func, not_, case
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert
from ogn.collect.celery import app
from ogn.model import Country, DeviceInfo, DeviceInfoOrigin, AircraftBeacon, ReceiverBeacon, Device, Receiver
@ -11,6 +13,28 @@ from ogn.utils import get_ddb, get_flarmnet
logger = get_task_logger(__name__)
def compile_query(query):
"""Via http://nicolascadou.com/blog/2014/01/printing-actual-sqlalchemy-queries"""
compiler = query.compile if not hasattr(query, 'statement') else query.statement.compile
return compiler(dialect=postgresql.dialect())
def upsert(session, model, rows, update_cols):
"""Insert rows in model. On conflicting update columns if new value IS NOT NULL."""
table = model.__table__
stmt = insert(table).values(rows)
on_conflict_stmt = stmt.on_conflict_do_update(
index_elements=table.primary_key.columns,
set_={k: case([(getattr(stmt.excluded, k) != null(), getattr(stmt.excluded, k))], else_=getattr(model, k)) for k in update_cols},
)
# print(compile_query(on_conflict_stmt))
session.execute(on_conflict_stmt)
def update_device_infos(session, address_origin, path=None):
if address_origin == DeviceInfoOrigin.flarmnet:
device_infos = get_flarmnet(fln_file=path)

Wyświetl plik

@ -3,7 +3,7 @@ import unittest
from tests.base import TestBaseDB
from ogn.model import AircraftBeacon, ReceiverBeacon, Device, Receiver
from ogn.collect.database import add_missing_devices, add_missing_receivers
from ogn.collect.database import add_missing_devices, add_missing_receivers, upsert
class TestDatabase(TestBaseDB):
@ -30,6 +30,29 @@ class TestDatabase(TestBaseDB):
receiver_beacon = receiver_beacons[0]
self.assertEqual(receiver_beacon.receiver.name, 'Bene')
def test_insert_duplicate_beacons(self):
session = self.session
row1 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:51:00', 'ground_speed': None}
row2 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:52:00', 'ground_speed': 0}
row3 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:53:00', 'ground_speed': 1}
row4 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:54:00', 'ground_speed': None}
upsert(session=session, model=AircraftBeacon, rows=[row1, row2, row3, row4], update_cols=['ground_speed'])
row5 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:51:00', 'ground_speed': 2}
row6 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:52:00', 'ground_speed': 3}
row7 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:53:00', 'ground_speed': None}
row8 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:54:00', 'ground_speed': None}
upsert(session=session, model=AircraftBeacon, rows=[row5, row6, row7, row8], update_cols=['ground_speed'])
result = session.query(AircraftBeacon).order_by(AircraftBeacon.timestamp).all()
self.assertEqual(result[0].ground_speed, 2)
self.assertEqual(result[1].ground_speed, 3)
self.assertEqual(result[2].ground_speed, 1)
self.assertEqual(result[3].ground_speed, None)
if __name__ == '__main__':
unittest.main()