From 460e263588a2bcf607d26d17a35e2bd3f65c4281 Mon Sep 17 00:00:00 2001 From: Piero Toffanin Date: Fri, 12 Apr 2019 13:45:47 -0400 Subject: [PATCH] Multithread DEM tiling, more memory tolerant, better DEM gapfill --- opendm/concurrency.py | 15 ----- opendm/dem/commands.py | 145 +++++++++++++++++++++++++++++------------ opendm/mesh.py | 3 +- scripts/odm_dem.py | 5 +- 4 files changed, 105 insertions(+), 63 deletions(-) diff --git a/opendm/concurrency.py b/opendm/concurrency.py index 5e390528..36b332be 100644 --- a/opendm/concurrency.py +++ b/opendm/concurrency.py @@ -9,18 +9,3 @@ def get_max_memory(minimum = 5, use_at_most = 0.5): """ return max(minimum, (100 - virtual_memory().percent) * use_at_most) - -def get_max_concurrency_for_dem(available_cores, input_file, use_at_most = 0.8): - """ - DEM generation requires ~2x the input file size of memory per available core. - :param available_cores number of cores available (return value will never exceed this value) - :param input_file path to input file - :use_at_most use at most this fraction of the available memory when calculating a concurrency value. 0.9 = assume that we can only use 90% of available memory. - :return maximum number of cores recommended to use for DEM processing. 
- """ - memory_available = virtual_memory().available * use_at_most - file_size = os.path.getsize(input_file) - memory_required_per_core = max(1, file_size * 2) - - return min(available_cores, max(1, int(memory_available) / int(memory_required_per_core))) - diff --git a/opendm/dem/commands.py b/opendm/dem/commands.py index 056a7b9c..bd75df37 100644 --- a/opendm/dem/commands.py +++ b/opendm/dem/commands.py @@ -1,17 +1,20 @@ -import os, glob +import os +import sys import gippy import numpy import math +import time from opendm.system import run from opendm import point_cloud from opendm.concurrency import get_max_memory -import pprint - -from scipy import ndimage +from scipy import ndimage, signal from datetime import datetime from opendm import log -from loky import get_reusable_executor -from functools import partial +try: + import Queue as queue +except: + import queue +import threading from . import pdal @@ -26,11 +29,15 @@ def classify(lasFile, slope=0.15, cellsize=1, maxWindowSize=18, verbose=False): log.ODM_INFO('Created %s in %s' % (os.path.relpath(lasFile), datetime.now() - start)) return lasFile +error = None def create_dem(input_point_cloud, dem_type, output_type='max', radiuses=['0.56'], gapfill=True, - outdir='', resolution=0.1, max_workers=None, max_tile_size=4096, + outdir='', resolution=0.1, max_workers=1, max_tile_size=2048, verbose=False, decimation=None): """ Create DEM from multiple radii, and optionally gapfill """ + global error + error = None + start = datetime.now() if not os.path.exists(outdir): @@ -39,18 +46,14 @@ def create_dem(input_point_cloud, dem_type, output_type='max', radiuses=['0.56'] extent = point_cloud.get_extent(input_point_cloud) log.ODM_INFO("Point cloud bounds are [minx: %s, maxx: %s] [miny: %s, maxy: %s]" % (extent['minx'], extent['maxx'], extent['miny'], extent['maxy'])) - # extent = { - # 'maxx': 100, - # 'minx': 0, - # 'maxy': 100, - # 'miny': 0 - # } ext_width = extent['maxx'] - extent['minx'] ext_height = extent['maxy'] 
- extent['miny'] final_dem_resolution = (int(math.ceil(ext_width / float(resolution))), int(math.ceil(ext_height / float(resolution)))) - num_splits = int(math.ceil(max(final_dem_resolution) / float(max_tile_size))) + final_dem_pixels = final_dem_resolution[0] * final_dem_resolution[1] + + num_splits = int(max(1, math.ceil(math.log(math.ceil(final_dem_pixels / float(max_tile_size * max_tile_size)))/math.log(2)))) num_tiles = num_splits * num_splits log.ODM_INFO("DEM resolution is %s, max tile size is %s, will split DEM generation into %s tiles" % (final_dem_resolution, max_tile_size, num_tiles)) @@ -94,12 +97,7 @@ def create_dem(input_point_cloud, dem_type, output_type='max', radiuses=['0.56'] # Sort tiles by increasing radius tiles.sort(key=lambda t: float(t['radius']), reverse=True) - # pp = pprint.PrettyPrinter(indent=4) - # pp.pprint(queue) - # TODO: parallel queue - queue = tiles[:] - - for q in queue: + def process_one(q): log.ODM_INFO("Generating %s (%s, radius: %s, resolution: %s)" % (q['filename'], output_type, q['radius'], resolution)) d = pdal.json_gdal_base(q['filename'], output_type, q['radius'], resolution, q['bounds']) @@ -114,7 +112,64 @@ def create_dem(input_point_cloud, dem_type, output_type='max', radiuses=['0.56'] pdal.json_add_readers(d, [input_point_cloud]) pdal.run_pipeline(d, verbose=verbose) - + + def worker(): + global error + + while True: + (num, q) = pq.get() + if q is None or error is not None: + pq.task_done() + break + + try: + process_one(q) + except Exception as e: + error = e + finally: + pq.task_done() + + if max_workers > 1: + use_single_thread = False + pq = queue.PriorityQueue() + threads = [] + for i in range(max_workers): + t = threading.Thread(target=worker) + t.start() + threads.append(t) + + for t in tiles: + pq.put((i, t.copy())) + + def stop_workers(): + for i in range(len(threads)): + pq.put((-1, None)) + for t in threads: + t.join() + + # block until all tasks are done + try: + while pq.unfinished_tasks > 0: + 
time.sleep(0.5) + except KeyboardInterrupt: + print("CTRL+C terminating...") + stop_workers() + sys.exit(1) + + stop_workers() + + if error is not None: + # Try to reprocess using a single thread + # in case this was a memory error + log.ODM_WARNING("DEM processing failed with multiple threads, let's retry with a single thread...") + use_single_thread = True + else: + use_single_thread = True + + if use_single_thread: + # Boring, single thread processing + for q in tiles: + process_one(q) output_file = "%s.tif" % dem_type output_path = os.path.abspath(os.path.join(outdir, output_file)) @@ -128,7 +183,7 @@ def create_dem(input_point_cloud, dem_type, output_type='max', radiuses=['0.56'] vrt_path = os.path.abspath(os.path.join(outdir, "merged.vrt")) run('gdalbuildvrt "%s" "%s"' % (vrt_path, '" "'.join(map(lambda t: t['filename'], tiles)))) - geotiff_path = os.path.abspath(os.path.join(outdir, 'merged.tiff')) + geotiff_path = os.path.abspath(os.path.join(outdir, 'merged.tif')) # Build GeoTIFF kwargs = { @@ -138,48 +193,48 @@ def create_dem(input_point_cloud, dem_type, output_type='max', radiuses=['0.56'] 'geotiff': geotiff_path } - run('gdal_translate ' + if gapfill: + run('gdal_fillnodata.py ' '-co NUM_THREADS={threads} ' '--config GDAL_CACHEMAX {max_memory}% ' + '-b 1 ' + '-of GTiff ' '{vrt} {geotiff}'.format(**kwargs)) - - if gapfill: - gapfill_and_smooth(geotiff_path, output_path) - os.remove(geotiff_path) else: - log.ODM_INFO("Skipping gap-fill interpolation") - os.rename(geotiff_path, output_path) + run('gdal_translate ' + '-co NUM_THREADS={threads} ' + '--config GDAL_CACHEMAX {max_memory}% ' + '{vrt} {geotiff}'.format(**kwargs)) - # TODO cleanup + post_process(geotiff_path, output_path) + os.remove(geotiff_path) + + if os.path.exists(vrt_path): os.remove(vrt_path) + for t in tiles: + if os.path.exists(t['filename']): os.remove(t['filename']) log.ODM_INFO('Completed %s in %s' % (output_file, datetime.now() - start)) - -def gapfill_and_smooth(geotiff_path, 
output_path): - """ Gap fill with nearest neighbor interpolation and apply median smoothing """ +def post_process(geotiff_path, output_path, smoothing_iterations=1): + """ Apply median smoothing """ start = datetime.now() if not os.path.exists(geotiff_path): raise Exception('File %s does not exist!' % geotiff_path) - log.ODM_INFO('Starting gap-filling with nearest interpolation...') + log.ODM_INFO('Starting post processing (smoothing)...') img = gippy.GeoImage(geotiff_path) nodata = img[0].nodata() arr = img[0].read() - # Nearest neighbor interpolation at bad points - indices = ndimage.distance_transform_edt(arr == nodata, - return_distances=False, - return_indices=True) - arr = arr[tuple(indices)] - # Median filter (careful, changing the value 5 might require tweaking) # the lines below. There's another numpy function that takes care of # these edge cases, but it's slower. - from scipy import signal - arr = signal.medfilt(arr, 5) + for i in range(smoothing_iterations): + log.ODM_INFO("Smoothing iteration %s" % str(i + 1)) + arr = signal.medfilt(arr, 5) # Fill corner points with nearest value if arr.shape >= (4, 4): @@ -188,6 +243,10 @@ def gapfill_and_smooth(geotiff_path, output_path): arr[-1][:2] = arr[-2][0] = arr[-2][1] arr[-1][-2:] = arr[-2][-1] = arr[-2][-2] + # Median filter leaves a bunch of zeros in nodata areas + locs = numpy.where(arr == 0.0) + arr[locs] = nodata + # write output imgout = gippy.GeoImage.create_from(img, output_path) imgout.set_nodata(nodata) @@ -195,6 +254,6 @@ def gapfill_and_smooth(geotiff_path, output_path): output_path = imgout.filename() imgout = None - log.ODM_INFO('Completed gap-filling to create %s in %s' % (os.path.relpath(output_path), datetime.now() - start)) + log.ODM_INFO('Completed post processing to create %s in %s' % (os.path.relpath(output_path), datetime.now() - start)) return output_path \ No newline at end of file diff --git a/opendm/mesh.py b/opendm/mesh.py index 4e090914..d17e4470 100644 --- a/opendm/mesh.py +++ 
b/opendm/mesh.py @@ -5,7 +5,6 @@ from opendm.dem import commands from opendm import system from opendm import log from opendm import context -from opendm.concurrency import get_max_concurrency_for_dem from scipy import signal, ndimage import numpy as np @@ -33,7 +32,7 @@ def create_25dmesh(inPointCloud, outMesh, dsm_radius=0.07, dsm_resolution=0.05, outdir=tmp_directory, resolution=dsm_resolution, verbose=verbose, - max_workers=get_max_concurrency_for_dem(available_cores, inPointCloud) + max_workers=available_cores ) if method == 'gridded': diff --git a/scripts/odm_dem.py b/scripts/odm_dem.py index dfaf45d5..dc20b7cf 100644 --- a/scripts/odm_dem.py +++ b/scripts/odm_dem.py @@ -9,7 +9,6 @@ from opendm import types from opendm import gsd from opendm.dem import commands from opendm.cropper import Cropper -from opendm.concurrency import get_max_concurrency_for_dem class ODMDEMCell(ecto.Cell): def declare_params(self, params): @@ -89,14 +88,14 @@ class ODMDEMCell(ecto.Cell): commands.create_dem( tree.odm_georeferencing_model_laz, product, - output_type='idw' if product == 'dtm' else 'max' + output_type='idw' if product == 'dtm' else 'max', radiuses=map(str, radius_steps), gapfill=args.dem_gapfill_steps > 0, outdir=odm_dem_root, resolution=resolution / 100.0, decimation=args.dem_decimation, verbose=args.verbose, - max_workers=get_max_concurrency_for_dem(args.max_concurrency,tree.odm_georeferencing_model_laz) + max_workers=args.max_concurrency ) if args.crop > 0: