From b584459fc97da0f90fc2c86d6fbeaabb8dc34dbf Mon Sep 17 00:00:00 2001 From: Piero Toffanin Date: Thu, 14 Jul 2022 16:02:59 -0400 Subject: [PATCH] AI-powered automatic sky removal --- .gitignore | 1 + opendm/ai.py | 51 ++++++++++ opendm/concurrency.py | 7 +- opendm/config.py | 6 ++ opendm/net.py | 164 ++++++++++++++++++++++++++++++ opendm/skyremoval/__init__.py | 0 opendm/skyremoval/guidedfilter.py | 37 +++++++ opendm/skyremoval/skyfilter.py | 103 +++++++++++++++++++ requirements.txt | 2 + stages/dataset.py | 48 ++++++++- 10 files changed, 416 insertions(+), 3 deletions(-) create mode 100644 opendm/ai.py create mode 100644 opendm/net.py create mode 100644 opendm/skyremoval/__init__.py create mode 100644 opendm/skyremoval/guidedfilter.py create mode 100644 opendm/skyremoval/skyfilter.py diff --git a/.gitignore b/.gitignore index 7f4f29ac..2bf60d58 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ settings.yaml .setupdevenv __pycache__ *.snap +storage/ diff --git a/opendm/ai.py b/opendm/ai.py new file mode 100644 index 00000000..834d4043 --- /dev/null +++ b/opendm/ai.py @@ -0,0 +1,51 @@ +import os +from opendm.net import download +from opendm import log +import zipfile +import time + +def get_model(namespace, url, version, name = "model.onnx"): + version = version.replace(".", "_") + + base_dir = os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")), "storage", "models") + namespace_dir = os.path.join(base_dir, namespace) + versioned_dir = os.path.join(namespace_dir, version) + + if not os.path.isdir(versioned_dir): + os.makedirs(versioned_dir, exist_ok=True) + + # Check if we need to download it + model_file = os.path.join(versioned_dir, name) + if not os.path.isfile(model_file): + log.ODM_INFO("Downloading AI model from %s ..." % url) + + last_update = 0 + + def callback(progress): + nonlocal last_update + + time_has_elapsed = time.time() - last_update >= 2 + + if time_has_elapsed or int(progress) == 100: + log.ODM_INFO("Downloading: %s%%" % int(progress)) + last_update = time.time() + + try: + downloaded_file = download(url, versioned_dir, progress_callback=callback) + except Exception as e: + log.ODM_WARNING("Cannot download %s: %s" % (url, str(e))) + return None + + if os.path.basename(downloaded_file).lower().endswith(".zip"): + log.ODM_INFO("Extracting %s ..." % downloaded_file) + with zipfile.ZipFile(downloaded_file, 'r') as z: + z.extractall(versioned_dir) + os.remove(downloaded_file) + + if not os.path.isfile(model_file): + log.ODM_WARNING("Cannot find %s (is the URL to the AI model correct?)" % model_file) + return None + else: + return model_file + else: + return model_file \ No newline at end of file diff --git a/opendm/concurrency.py b/opendm/concurrency.py index 3cb62488..adbeccba 100644 --- a/opendm/concurrency.py +++ b/opendm/concurrency.py @@ -25,7 +25,7 @@ def get_max_memory_mb(minimum = 100, use_at_most = 0.5): """ return max(minimum, (virtual_memory().available / 1024 / 1024) * use_at_most) -def parallel_map(func, items, max_workers=1, single_thread_fallback=True): +def parallel_map(func, items, max_workers=1, single_thread_fallback=True, copy_queue_items=True): """ Our own implementation for parallel processing which handles gracefully CTRL+C and reverts to @@ -66,7 +66,10 @@ def parallel_map(func, items, max_workers=1, single_thread_fallback=True): i = 1 for t in items: - pq.put((i, t.copy())) + if copy_queue_items: + pq.put((i, t.copy())) + else: + pq.put((i, t)) i += 1 def stop_workers(): diff --git a/opendm/config.py b/opendm/config.py index edef6f82..96bd5a64 100755 --- a/opendm/config.py +++ b/opendm/config.py @@ -237,6 +237,12 @@ def config(argv=None, parser=None): 'Can be one of: %(choices)s. Default: ' '%(default)s')) + parser.add_argument('--sky-removal', + action=StoreTrue, + nargs=0, + default=False, + help='Automatically compute image masks using AI to remove the sky. Default: %(default)s') + parser.add_argument('--use-3dmesh', action=StoreTrue, nargs=0, diff --git a/opendm/net.py b/opendm/net.py new file mode 100644 index 00000000..f2bb5b3e --- /dev/null +++ b/opendm/net.py @@ -0,0 +1,164 @@ +import requests +import math +import os +import time +try: + import queue +except ImportError: + import Queue as queue +import threading +from pyodm.utils import AtomicCounter +from pyodm.exceptions import RangeNotAvailableError, OdmError +from urllib3.exceptions import ReadTimeoutError + +def download(url, destination, progress_callback=None, parallel_downloads=16, parallel_chunks_size=10, timeout=30): + """Download files in parallel (download accelerator) + + Args: + url (str): URL to download + destination (str): directory where to download file. If the directory does not exist, it will be created. + progress_callback (function): an optional callback with one parameter, the download progress percentage. + parallel_downloads (int): maximum number of parallel downloads if the node supports http range. + parallel_chunks_size (int): size in MB of chunks for parallel downloads + timeout (int): seconds before timing out + Returns: + str: path to file + """ + if not os.path.exists(destination): + os.makedirs(destination, exist_ok=True) + + try: + + download_stream = requests.get(url, timeout=timeout, stream=True) + headers = download_stream.headers + + output_path = os.path.join(destination, os.path.basename(url)) + + # Keep track of download progress (if possible) + content_length = download_stream.headers.get('content-length') + total_length = int(content_length) if content_length is not None else None + downloaded = 0 + chunk_size = int(parallel_chunks_size * 1024 * 1024) + use_fallback = False + accept_ranges = headers.get('accept-ranges') + + # Can we do parallel downloads? + if accept_ranges is not None and accept_ranges.lower() == 'bytes' and total_length is not None and total_length > chunk_size and parallel_downloads > 1: + num_chunks = int(math.ceil(total_length / float(chunk_size))) + num_workers = parallel_downloads + + class nonloc: + completed_chunks = AtomicCounter(0) + merge_chunks = [False] * num_chunks + error = None + + def merge(): + current_chunk = 0 + + with open(output_path, "wb") as out_file: + while current_chunk < num_chunks and nonloc.error is None: + if nonloc.merge_chunks[current_chunk]: + chunk_file = "%s.part%s" % (output_path, current_chunk) + with open(chunk_file, "rb") as fd: + out_file.write(fd.read()) + + os.unlink(chunk_file) + + current_chunk += 1 + else: + time.sleep(0.1) + + def worker(): + while True: + task = q.get() + part_num, bytes_range = task + if bytes_range is None or nonloc.error is not None: + q.task_done() + break + + try: + # Download chunk + res = requests.get(url, stream=True, timeout=timeout, headers={'Range': 'bytes=%s-%s' % bytes_range}) + if res.status_code == 206: + with open("%s.part%s" % (output_path, part_num), 'wb') as fd: + bytes_written = 0 + try: + for chunk in res.iter_content(4096): + bytes_written += fd.write(chunk) + except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e: + raise OdmError(str(e)) + + if bytes_written != (bytes_range[1] - bytes_range[0] + 1): + # Process again + q.put((part_num, bytes_range)) + return + + with nonloc.completed_chunks.lock: + nonloc.completed_chunks.value += 1 + + if progress_callback is not None: + progress_callback(100.0 * nonloc.completed_chunks.value / num_chunks) + + nonloc.merge_chunks[part_num] = True + else: + nonloc.error = RangeNotAvailableError() + except OdmError as e: + time.sleep(5) + q.put((part_num, bytes_range)) + except Exception as e: + nonloc.error = e + finally: + q.task_done() + + q = queue.PriorityQueue() + threads = [] + for i in range(num_workers): + t = threading.Thread(target=worker) + t.start() + threads.append(t) + + merge_thread = threading.Thread(target=merge) + merge_thread.start() + + range_start = 0 + + for i in range(num_chunks): + range_end = min(range_start + chunk_size - 1, total_length - 1) + q.put((i, (range_start, range_end))) + range_start = range_end + 1 + + # block until all tasks are done + while not all(nonloc.merge_chunks) and nonloc.error is None: + time.sleep(0.1) + + # stop workers + for i in range(len(threads)): + q.put((-1, None)) + for t in threads: + t.join() + + merge_thread.join() + + if nonloc.error is not None: + if isinstance(nonloc.error, RangeNotAvailableError): + use_fallback = True + else: + raise nonloc.error + else: + use_fallback = True + + if use_fallback: + # Single connection, boring download + with open(output_path, 'wb') as fd: + for chunk in download_stream.iter_content(4096): + downloaded += len(chunk) + + if progress_callback is not None and total_length is not None: + progress_callback((100.0 * float(downloaded) / total_length)) + + fd.write(chunk) + + except (requests.exceptions.Timeout, requests.exceptions.ConnectionError, ReadTimeoutError) as e: + raise OdmError(e) + + return output_path \ No newline at end of file diff --git a/opendm/skyremoval/__init__.py b/opendm/skyremoval/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/opendm/skyremoval/guidedfilter.py b/opendm/skyremoval/guidedfilter.py new file mode 100644 index 00000000..1bf29777 --- /dev/null +++ b/opendm/skyremoval/guidedfilter.py @@ -0,0 +1,37 @@ +import numpy as np + +# Based on Fast Guided Filter +# Kaiming He, Jian Sun +# https://arxiv.org/abs/1505.00996 + +def box(img, radius): + dst = np.zeros_like(img) + (r, c) = img.shape + + s = [radius, 1] + c_sum = np.cumsum(img, 0) + dst[0:radius+1, :, ...] = c_sum[radius:2*radius+1, :, ...] + dst[radius+1:r-radius, :, ...] = c_sum[2*radius+1:r, :, ...] - c_sum[0:r-2*radius-1, :, ...] + dst[r-radius:r, :, ...] = np.tile(c_sum[r-1:r, :, ...], s) - c_sum[r-2*radius-1:r-radius-1, :, ...] + + s = [1, radius] + c_sum = np.cumsum(dst, 1) + dst[:, 0:radius+1, ...] = c_sum[:, radius:2*radius+1, ...] + dst[:, radius+1:c-radius, ...] = c_sum[:, 2*radius+1 : c, ...] - c_sum[:, 0 : c-2*radius-1, ...] + dst[:, c-radius: c, ...] = np.tile(c_sum[:, c-1:c, ...], s) - c_sum[:, c-2*radius-1 : c-radius-1, ...] + + return dst + + +def guided_filter(img, guide, radius, eps): + (r, c) = img.shape + + CNT = box(np.ones([r, c]), radius) + + mean_img = box(img, radius) / CNT + mean_guide = box(guide, radius) / CNT + + a = ((box(img * guide, radius) / CNT) - mean_img * mean_guide) / (((box(img * img, radius) / CNT) - mean_img * mean_img) + eps) + b = mean_guide - a * mean_img + + return (box(a, radius) / CNT) * img + (box(b, radius) / CNT) diff --git a/opendm/skyremoval/skyfilter.py b/opendm/skyremoval/skyfilter.py new file mode 100644 index 00000000..d82f4b7d --- /dev/null +++ b/opendm/skyremoval/skyfilter.py @@ -0,0 +1,103 @@ + +import time +import numpy as np +import cv2 +import os +import onnx +import onnxruntime as ort +from .guidedfilter import guided_filter +from opendm import log +from threading import Lock + +mutex = Lock() + +# Use GPU if it is available, otherwise CPU +provider = "CUDAExecutionProvider" if "CUDAExecutionProvider" in ort.get_available_providers() else "CPUExecutionProvider" + +class SkyFilter(): + + def __init__(self, model, width = 384, height = 384): + + self.model = model + self.width, self.height = width, height + + log.ODM_INFO(' ?> Using provider %s' % provider) + self.load_model() + + + def load_model(self): + log.ODM_INFO(' -> Loading the model') + onnx_model = onnx.load(self.model) + + # Check the model + try: + onnx.checker.check_model(onnx_model) + except onnx.checker.ValidationError as e: + log.ODM_INFO(' !> The model is invalid: %s' % e) + raise + else: + log.ODM_INFO(' ?> The model is valid!') + + self.session = ort.InferenceSession(self.model, providers=[provider]) + + + def get_mask(self, img): + + height, width, c = img.shape + + # Resize image to fit the model input + new_img = cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_AREA) + new_img = np.array(new_img, dtype=np.float32) + + # Input vector for onnx model + input_v = np.expand_dims(new_img.transpose((2, 0, 1)), axis=0) + ort_inputs = {self.session.get_inputs()[0].name: input_v} + + # Run the model + with mutex: + ort_outs = self.session.run(None, ort_inputs) + + # Get the output + output = np.array(ort_outs) + output = output[0][0].transpose((1, 2, 0)) + output = cv2.resize(output, (width, height), interpolation=cv2.INTER_LANCZOS4) + output = np.array([output, output, output]).transpose((1, 2, 0)) + output = np.clip(output, a_max=1.0, a_min=0.0) + + return self.refine(output, img) + + + def refine(self, pred, img): + guided_filter_radius, guided_filter_eps = 20, 0.01 + refined = guided_filter(img[:,:,2], pred[:,:,0], guided_filter_radius, guided_filter_eps) + + res = np.clip(refined, a_min=0, a_max=1) + + # Convert res to CV_8UC1 + res = np.array(res * 255., dtype=np.uint8) + + # Thresholding + res = cv2.threshold(res, 127, 255, cv2.THRESH_BINARY_INV)[1] + + return res + + + def run_img(self, img_path, dest): + + img = cv2.imread(img_path, cv2.IMREAD_COLOR) + if img is None: + return None + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = np.array(img / 255., dtype=np.float32) + + mask = self.get_mask(img) + + img_name = os.path.basename(img_path) + fpath = os.path.join(dest, img_name) + + fname, _ = os.path.splitext(fpath) + mask_name = fname + '_mask.png' + cv2.imwrite(mask_name, mask) + + return mask_name diff --git a/requirements.txt b/requirements.txt index 474059be..70c18b31 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,3 +29,5 @@ scipy==1.5.4 xmltodict==0.12.0 fpdf2==2.4.6 Shapely==1.7.1 +onnx==1.12.0 +onnxruntime==1.11.1 \ No newline at end of file diff --git a/stages/dataset.py b/stages/dataset.py index 75a1dc87..916cdb17 100644 --- a/stages/dataset.py +++ b/stages/dataset.py @@ -11,6 +11,9 @@ from opendm.geo import GeoFile from shutil import copyfile from opendm import progress from opendm import boundary +from opendm import ai +from opendm.skyremoval.skyfilter import SkyFilter +from opendm.concurrency import parallel_map def save_images_database(photos, database_file): with open(database_file, 'w') as f: @@ -113,7 +116,7 @@ class ODMLoadDatasetStage(types.ODM_Stage): try: p = types.ODM_Photo(f) p.set_mask(find_mask(f, masks)) - photos += [p] + photos.append(p) dataset_list.write(photos[-1].filename + '\n') except PhotoCorruptedException: log.ODM_WARNING("%s seems corrupted and will not be used" % os.path.basename(f)) @@ -145,6 +148,49 @@ class ODMLoadDatasetStage(types.ODM_Stage): for p in photos: p.override_camera_projection(args.camera_lens) + # Automatic sky removal + if args.sky_removal: + # For each image that : + # - Doesn't already have a mask, AND + # - Is not nadir (or if orientation info is missing), AND + # - There are no spaces in the image filename (OpenSfM requirement) + # Automatically generate a sky mask + + # Generate list of sky images + sky_images = [] + for p in photos: + if p.mask is None and (p.pitch is None or (-10 > p.pitch > 10)) and (not " " in p.filename): + sky_images.append({'file': os.path.join(images_dir, p.filename), 'p': p}) + + if len(sky_images) > 0: + log.ODM_INFO("Automatically generating sky masks for %s images" % len(sky_images)) + model = ai.get_model("skyremoval", "https://github.com/OpenDroneMap/SkyRemoval/releases/download/v1.0.5/model.zip", "v1.0.5") + if model is not None: + sf = SkyFilter(model=model) + + def parallel_sky_filter(item): + try: + mask_file = sf.run_img(item['file'], images_dir) + + # Check and set + if mask_file is not None and os.path.isfile(mask_file): + item['p'].set_mask(os.path.basename(mask_file)) + log.ODM_INFO("Wrote %s" % os.path.basename(mask_file)) + else: + log.ODM_WARNING("Cannot generate mask for %s" % img) + except Exception as e: + log.ODM_WARNING("Cannot generate mask for %s: %s" % (img, str(e))) + + parallel_map(parallel_sky_filter, sky_images, max_workers=args.max_concurrency) + + log.ODM_INFO("Sky masks generation completed!") + else: + log.ODM_WARNING("Cannot load AI model (you might need to be connected to the internet?)") + else: + log.ODM_WARNING("No images suitable for sky mask generation detected (are they all nadir?)") + + # End sky removal + # Save image database for faster restart save_images_database(photos, images_database_file) else: