From 90ab582ca354d3ee25b65a60d39a745e0e02f89e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konstantin=20Gru=CC=88ndger?= Date: Sun, 27 Jan 2019 13:59:12 +0100 Subject: [PATCH] SQL UPSERT function for #63 --- ogn/collect/database.py | 26 +++++++++++++++++++++++++- tests/collect/test_database.py | 25 ++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/ogn/collect/database.py b/ogn/collect/database.py index 30cdafc..03ddf10 100644 --- a/ogn/collect/database.py +++ b/ogn/collect/database.py @@ -1,7 +1,9 @@ from celery.utils.log import get_task_logger from sqlalchemy import insert, distinct -from sqlalchemy.sql import null, and_, func, not_ +from sqlalchemy.sql import null, and_, func, not_, case +from sqlalchemy.dialects import postgresql +from sqlalchemy.dialects.postgresql import insert from ogn.collect.celery import app from ogn.model import Country, DeviceInfo, DeviceInfoOrigin, AircraftBeacon, ReceiverBeacon, Device, Receiver @@ -11,6 +13,28 @@ from ogn.utils import get_ddb, get_flarmnet logger = get_task_logger(__name__) +def compile_query(query): + """Via http://nicolascadou.com/blog/2014/01/printing-actual-sqlalchemy-queries""" + compiler = query.compile if not hasattr(query, 'statement') else query.statement.compile + return compiler(dialect=postgresql.dialect()) + + +def upsert(session, model, rows, update_cols): + """Insert rows in model. On conflicting update columns if new value IS NOT NULL.""" + + table = model.__table__ + + stmt = insert(table).values(rows) + + on_conflict_stmt = stmt.on_conflict_do_update( + index_elements=table.primary_key.columns, + set_={k: case([(getattr(stmt.excluded, k) != null(), getattr(stmt.excluded, k))], else_=getattr(model, k)) for k in update_cols}, + ) + + # print(compile_query(on_conflict_stmt)) + session.execute(on_conflict_stmt) + + def update_device_infos(session, address_origin, path=None): if address_origin == DeviceInfoOrigin.flarmnet: device_infos = get_flarmnet(fln_file=path) diff --git a/tests/collect/test_database.py b/tests/collect/test_database.py index be6b33c..4fba4ed 100644 --- a/tests/collect/test_database.py +++ b/tests/collect/test_database.py @@ -3,7 +3,7 @@ import unittest from tests.base import TestBaseDB from ogn.model import AircraftBeacon, ReceiverBeacon, Device, Receiver -from ogn.collect.database import add_missing_devices, add_missing_receivers +from ogn.collect.database import add_missing_devices, add_missing_receivers, upsert class TestDatabase(TestBaseDB): @@ -30,6 +30,29 @@ class TestDatabase(TestBaseDB): receiver_beacon = receiver_beacons[0] self.assertEqual(receiver_beacon.receiver.name, 'Bene') + def test_insert_duplicate_beacons(self): + session = self.session + + row1 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:51:00', 'ground_speed': None} + row2 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:52:00', 'ground_speed': 0} + row3 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:53:00', 'ground_speed': 1} + row4 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:54:00', 'ground_speed': None} + + upsert(session=session, model=AircraftBeacon, rows=[row1, row2, row3, row4], update_cols=['ground_speed']) + + row5 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:51:00', 'ground_speed': 2} + row6 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:52:00', 'ground_speed': 3} + row7 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:53:00', 'ground_speed': None} + row8 = {'name': 'FLRDD0815', 'receiver_name': 'Koenigsdf', 'timestamp': '2019-01-26 11:54:00', 'ground_speed': None} + + upsert(session=session, model=AircraftBeacon, rows=[row5, row6, row7, row8], update_cols=['ground_speed']) + + result = session.query(AircraftBeacon).order_by(AircraftBeacon.timestamp).all() + self.assertEqual(result[0].ground_speed, 2) + self.assertEqual(result[1].ground_speed, 3) + self.assertEqual(result[2].ground_speed, 1) + self.assertEqual(result[3].ground_speed, None) + if __name__ == '__main__': unittest.main()