Add locks to fix racing conditions

pull/1674/head
Adrien-ANTON-LUDWIG 2023-07-13 11:51:13 +00:00
rodzic 9b9ba724c6
commit 87f82a1582
1 zmienionych plików z 17 dodań i 5 usunięć

Wyświetl plik

@ -6,6 +6,7 @@ import math
import time
import shutil
import functools
import threading
from joblib import delayed, Parallel
from opendm.system import run
from opendm import point_cloud
@ -314,7 +315,7 @@ def median_smoothing(geotiff_path, output_path, smoothing_iterations=1, window_s
log.ODM_INFO('Starting smoothing...')
# imgout needs to be 'w+' (write/read) to work in place for all the iterations but the first
with rasterio.open(geotiff_path) as img, rasterio.open(output_path, 'w+', BIGTIFF="IF_SAFER", **img.profile) as imgout:
with rasterio.open(geotiff_path, num_threads=num_workers,) as img, rasterio.open(output_path, "w+", BIGTIFF="IF_SAFER", num_threds=num_workers, **img.profile) as imgout:
nodata = img.nodatavals[0]
dtype = img.dtypes[0]
shape = img.shape
@ -329,19 +330,26 @@ def median_smoothing(geotiff_path, output_path, smoothing_iterations=1, window_s
filter = functools.partial(ndimage.median_filter, size=9, output=dtype, mode='nearest')
# We cannot read/write to the same file from multiple threads without causing race conditions.
# To safely read/write from multiple threads, we use a lock to protect the DatasetReader/Writer.
read_lock = threading.Lock()
write_lock = threading.Lock()
# threading backend and GIL released filter are important for memory efficiency and multi-core performance
Parallel(n_jobs=num_workers, backend='threading')(delayed(window_filter_2d)(img, imgout, nodata , window, 9, filter) for window in windows)
Parallel(n_jobs=num_workers, backend='threading')(delayed(window_filter_2d)(img, imgout, nodata , window, 9, filter, read_lock, write_lock) for window in windows)
# After the first iteration, modifications are done in place
if i == 0:
img = imgout
# We now read and write to the same file
read_lock = write_lock
log.ODM_INFO('Completed smoothing to create %s in %s' % (output_path, datetime.now() - start))
return output_path
def window_filter_2d(img, imgout, nodata, window, kernel_size, filter):
def window_filter_2d(img, imgout, nodata, window, kernel_size, filter, read_lock, write_lock):
"""
Apply a filter to dem within a window, expects to work with kernal based filters
@ -350,6 +358,8 @@ def window_filter_2d(img, imgout, nodata, window, kernel_size, filter):
:param window: the window to apply the filter, should be a list contains row start, col_start, row_end, col_end
:param kernel_size: the size of the kernel for the filter, works with odd numbers, need to test if it works with even numbers
:param filter: the filter function which takes a 2d array as input and filter results as output.
:param read_lock: threading lock for the read operation
:param write_lock: threading lock for the write operation
"""
shape = img.shape[:2]
if window[0] < 0 or window[1] < 0 or window[2] > shape[0] or window[3] > shape[1]:
@ -360,7 +370,8 @@ def window_filter_2d(img, imgout, nodata, window, kernel_size, filter):
width = expanded_window[3] - expanded_window[1]
height = expanded_window[2] - expanded_window[0]
rasterio_window = rasterio.windows.Window(col_off=expanded_window[1], row_off=expanded_window[0], width=width, height=height)
win_arr = img.read(indexes=1, window=rasterio_window)
with read_lock:
win_arr = img.read(indexes=1, window=rasterio_window)
# Should have a better way to handle nodata, similar to the way the filter algorithms handle the border (reflection, nearest, interpolation, etc).
# For now will follow the old approach to guarantee identical outputs
@ -373,7 +384,8 @@ def window_filter_2d(img, imgout, nodata, window, kernel_size, filter):
width = window[3] - window[1]
height = window[2] - window[0]
rasterio_window = rasterio.windows.Window(col_off=window[1], row_off=window[0], width=width, height=height)
imgout.write(win_arr, indexes=1, window=rasterio_window)
with write_lock:
imgout.write(win_arr, indexes=1, window=rasterio_window)
def get_dem_radius_steps(stats_file, steps, resolution, multiplier = 1.0):