AI-powered automatic sky removal

pull/1502/head
Piero Toffanin 2022-07-14 16:02:59 -04:00
rodzic d61d0e0cbe
commit b584459fc9
10 zmienionych plików z 416 dodań i 3 usunięć

1
.gitignore vendored
Wyświetl plik

@ -27,3 +27,4 @@ settings.yaml
.setupdevenv .setupdevenv
__pycache__ __pycache__
*.snap *.snap
storage/

51
opendm/ai.py 100644
Wyświetl plik

@ -0,0 +1,51 @@
import os
from opendm.net import download
from opendm import log
import zipfile
import time
def get_model(namespace, url, version, name = "model.onnx"):
version = version.replace(".", "_")
base_dir = os.path.join(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")), "storage", "models")
namespace_dir = os.path.join(base_dir, namespace)
versioned_dir = os.path.join(namespace_dir, version)
if not os.path.isdir(versioned_dir):
os.makedirs(versioned_dir, exist_ok=True)
# Check if we need to download it
model_file = os.path.join(versioned_dir, name)
if not os.path.isfile(model_file):
log.ODM_INFO("Downloading AI model from %s ..." % url)
last_update = 0
def callback(progress):
nonlocal last_update
time_has_elapsed = time.time() - last_update >= 2
if time_has_elapsed or int(progress) == 100:
log.ODM_INFO("Downloading: %s%%" % int(progress))
last_update = time.time()
try:
downloaded_file = download(url, versioned_dir, progress_callback=callback)
except Exception as e:
log.ODM_WARNING("Cannot download %s: %s" % (url, str(e)))
return None
if os.path.basename(downloaded_file).lower().endswith(".zip"):
log.ODM_INFO("Extracting %s ..." % downloaded_file)
with zipfile.ZipFile(downloaded_file, 'r') as z:
z.extractall(versioned_dir)
os.remove(downloaded_file)
if not os.path.isfile(model_file):
log.ODM_WARNING("Cannot find %s (is the URL to the AI model correct?)" % model_file)
return None
else:
return model_file
else:
return model_file

Wyświetl plik

@ -25,7 +25,7 @@ def get_max_memory_mb(minimum = 100, use_at_most = 0.5):
""" """
return max(minimum, (virtual_memory().available / 1024 / 1024) * use_at_most) return max(minimum, (virtual_memory().available / 1024 / 1024) * use_at_most)
def parallel_map(func, items, max_workers=1, single_thread_fallback=True): def parallel_map(func, items, max_workers=1, single_thread_fallback=True, copy_queue_items=True):
""" """
Our own implementation for parallel processing Our own implementation for parallel processing
which handles gracefully CTRL+C and reverts to which handles gracefully CTRL+C and reverts to
@ -66,7 +66,10 @@ def parallel_map(func, items, max_workers=1, single_thread_fallback=True):
i = 1 i = 1
for t in items: for t in items:
pq.put((i, t.copy())) if copy_queue_items:
pq.put((i, t.copy()))
else:
pq.put((i, t))
i += 1 i += 1
def stop_workers(): def stop_workers():

Wyświetl plik

@ -237,6 +237,12 @@ def config(argv=None, parser=None):
'Can be one of: %(choices)s. Default: ' 'Can be one of: %(choices)s. Default: '
'%(default)s')) '%(default)s'))
parser.add_argument('--sky-removal',
action=StoreTrue,
nargs=0,
default=False,
help='Automatically compute image masks using AI to remove the sky. Default: %(default)s')
parser.add_argument('--use-3dmesh', parser.add_argument('--use-3dmesh',
action=StoreTrue, action=StoreTrue,
nargs=0, nargs=0,

164
opendm/net.py 100644
Wyświetl plik

@ -0,0 +1,164 @@
import requests
import math
import os
import time
try:
import queue
except ImportError:
import Queue as queue
import threading
from pyodm.utils import AtomicCounter
from pyodm.exceptions import RangeNotAvailableError, OdmError
from urllib3.exceptions import ReadTimeoutError
def download(url, destination, progress_callback=None, parallel_downloads=16, parallel_chunks_size=10, timeout=30):
"""Download files in parallel (download accelerator)
Args:
url (str): URL to download
destination (str): directory where to download file. If the directory does not exist, it will be created.
progress_callback (function): an optional callback with one parameter, the download progress percentage.
parallel_downloads (int): maximum number of parallel downloads if the node supports http range.
parallel_chunks_size (int): size in MB of chunks for parallel downloads
timeout (int): seconds before timing out
Returns:
str: path to file
"""
if not os.path.exists(destination):
os.makedirs(destination, exist_ok=True)
try:
download_stream = requests.get(url, timeout=timeout, stream=True)
headers = download_stream.headers
output_path = os.path.join(destination, os.path.basename(url))
# Keep track of download progress (if possible)
content_length = download_stream.headers.get('content-length')
total_length = int(content_length) if content_length is not None else None
downloaded = 0
chunk_size = int(parallel_chunks_size * 1024 * 1024)
use_fallback = False
accept_ranges = headers.get('accept-ranges')
# Can we do parallel downloads?
if accept_ranges is not None and accept_ranges.lower() == 'bytes' and total_length is not None and total_length > chunk_size and parallel_downloads > 1:
num_chunks = int(math.ceil(total_length / float(chunk_size)))
num_workers = parallel_downloads
class nonloc:
completed_chunks = AtomicCounter(0)
merge_chunks = [False] * num_chunks
error = None
def merge():
current_chunk = 0
with open(output_path, "wb") as out_file:
while current_chunk < num_chunks and nonloc.error is None:
if nonloc.merge_chunks[current_chunk]:
chunk_file = "%s.part%s" % (output_path, current_chunk)
with open(chunk_file, "rb") as fd:
out_file.write(fd.read())
os.unlink(chunk_file)
current_chunk += 1
else:
time.sleep(0.1)
def worker():
while True:
task = q.get()
part_num, bytes_range = task
if bytes_range is None or nonloc.error is not None:
q.task_done()
break
try:
# Download chunk
res = requests.get(url, stream=True, timeout=timeout, headers={'Range': 'bytes=%s-%s' % bytes_range})
if res.status_code == 206:
with open("%s.part%s" % (output_path, part_num), 'wb') as fd:
bytes_written = 0
try:
for chunk in res.iter_content(4096):
bytes_written += fd.write(chunk)
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
raise OdmError(str(e))
if bytes_written != (bytes_range[1] - bytes_range[0] + 1):
# Process again
q.put((part_num, bytes_range))
return
with nonloc.completed_chunks.lock:
nonloc.completed_chunks.value += 1
if progress_callback is not None:
progress_callback(100.0 * nonloc.completed_chunks.value / num_chunks)
nonloc.merge_chunks[part_num] = True
else:
nonloc.error = RangeNotAvailableError()
except OdmError as e:
time.sleep(5)
q.put((part_num, bytes_range))
except Exception as e:
nonloc.error = e
finally:
q.task_done()
q = queue.PriorityQueue()
threads = []
for i in range(num_workers):
t = threading.Thread(target=worker)
t.start()
threads.append(t)
merge_thread = threading.Thread(target=merge)
merge_thread.start()
range_start = 0
for i in range(num_chunks):
range_end = min(range_start + chunk_size - 1, total_length - 1)
q.put((i, (range_start, range_end)))
range_start = range_end + 1
# block until all tasks are done
while not all(nonloc.merge_chunks) and nonloc.error is None:
time.sleep(0.1)
# stop workers
for i in range(len(threads)):
q.put((-1, None))
for t in threads:
t.join()
merge_thread.join()
if nonloc.error is not None:
if isinstance(nonloc.error, RangeNotAvailableError):
use_fallback = True
else:
raise nonloc.error
else:
use_fallback = True
if use_fallback:
# Single connection, boring download
with open(output_path, 'wb') as fd:
for chunk in download_stream.iter_content(4096):
downloaded += len(chunk)
if progress_callback is not None and total_length is not None:
progress_callback((100.0 * float(downloaded) / total_length))
fd.write(chunk)
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError, ReadTimeoutError) as e:
raise OdmError(e)
return output_path

Wyświetl plik

@ -0,0 +1,37 @@
import numpy as np
# Based on Fast Guided Filter
# Kaiming He, Jian Sun
# https://arxiv.org/abs/1505.00996
def box(img, radius):
dst = np.zeros_like(img)
(r, c) = img.shape
s = [radius, 1]
c_sum = np.cumsum(img, 0)
dst[0:radius+1, :, ...] = c_sum[radius:2*radius+1, :, ...]
dst[radius+1:r-radius, :, ...] = c_sum[2*radius+1:r, :, ...] - c_sum[0:r-2*radius-1, :, ...]
dst[r-radius:r, :, ...] = np.tile(c_sum[r-1:r, :, ...], s) - c_sum[r-2*radius-1:r-radius-1, :, ...]
s = [1, radius]
c_sum = np.cumsum(dst, 1)
dst[:, 0:radius+1, ...] = c_sum[:, radius:2*radius+1, ...]
dst[:, radius+1:c-radius, ...] = c_sum[:, 2*radius+1 : c, ...] - c_sum[:, 0 : c-2*radius-1, ...]
dst[:, c-radius: c, ...] = np.tile(c_sum[:, c-1:c, ...], s) - c_sum[:, c-2*radius-1 : c-radius-1, ...]
return dst
def guided_filter(img, guide, radius, eps):
(r, c) = img.shape
CNT = box(np.ones([r, c]), radius)
mean_img = box(img, radius) / CNT
mean_guide = box(guide, radius) / CNT
a = ((box(img * guide, radius) / CNT) - mean_img * mean_guide) / (((box(img * img, radius) / CNT) - mean_img * mean_img) + eps)
b = mean_guide - a * mean_img
return (box(a, radius) / CNT) * img + (box(b, radius) / CNT)

Wyświetl plik

@ -0,0 +1,103 @@
import time
import numpy as np
import cv2
import os
import onnx
import onnxruntime as ort
from .guidedfilter import guided_filter
from opendm import log
from threading import Lock
mutex = Lock()
# Use GPU if it is available, otherwise CPU
provider = "CUDAExecutionProvider" if "CUDAExecutionProvider" in ort.get_available_providers() else "CPUExecutionProvider"
class SkyFilter():
def __init__(self, model, width = 384, height = 384):
self.model = model
self.width, self.height = width, height
log.ODM_INFO(' ?> Using provider %s' % provider)
self.load_model()
def load_model(self):
log.ODM_INFO(' -> Loading the model')
onnx_model = onnx.load(self.model)
# Check the model
try:
onnx.checker.check_model(onnx_model)
except onnx.checker.ValidationError as e:
log.ODM_INFO(' !> The model is invalid: %s' % e)
raise
else:
log.ODM_INFO(' ?> The model is valid!')
self.session = ort.InferenceSession(self.model, providers=[provider])
def get_mask(self, img):
height, width, c = img.shape
# Resize image to fit the model input
new_img = cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_AREA)
new_img = np.array(new_img, dtype=np.float32)
# Input vector for onnx model
input_v = np.expand_dims(new_img.transpose((2, 0, 1)), axis=0)
ort_inputs = {self.session.get_inputs()[0].name: input_v}
# Run the model
with mutex:
ort_outs = self.session.run(None, ort_inputs)
# Get the output
output = np.array(ort_outs)
output = output[0][0].transpose((1, 2, 0))
output = cv2.resize(output, (width, height), interpolation=cv2.INTER_LANCZOS4)
output = np.array([output, output, output]).transpose((1, 2, 0))
output = np.clip(output, a_max=1.0, a_min=0.0)
return self.refine(output, img)
def refine(self, pred, img):
guided_filter_radius, guided_filter_eps = 20, 0.01
refined = guided_filter(img[:,:,2], pred[:,:,0], guided_filter_radius, guided_filter_eps)
res = np.clip(refined, a_min=0, a_max=1)
# Convert res to CV_8UC1
res = np.array(res * 255., dtype=np.uint8)
# Thresholding
res = cv2.threshold(res, 127, 255, cv2.THRESH_BINARY_INV)[1]
return res
def run_img(self, img_path, dest):
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
return None
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.array(img / 255., dtype=np.float32)
mask = self.get_mask(img)
img_name = os.path.basename(img_path)
fpath = os.path.join(dest, img_name)
fname, _ = os.path.splitext(fpath)
mask_name = fname + '_mask.png'
cv2.imwrite(mask_name, mask)
return mask_name

Wyświetl plik

@ -29,3 +29,5 @@ scipy==1.5.4
xmltodict==0.12.0 xmltodict==0.12.0
fpdf2==2.4.6 fpdf2==2.4.6
Shapely==1.7.1 Shapely==1.7.1
onnx==1.12.0
onnxruntime==1.11.1

Wyświetl plik

@ -11,6 +11,9 @@ from opendm.geo import GeoFile
from shutil import copyfile from shutil import copyfile
from opendm import progress from opendm import progress
from opendm import boundary from opendm import boundary
from opendm import ai
from opendm.skyremoval.skyfilter import SkyFilter
from opendm.concurrency import parallel_map
def save_images_database(photos, database_file): def save_images_database(photos, database_file):
with open(database_file, 'w') as f: with open(database_file, 'w') as f:
@ -113,7 +116,7 @@ class ODMLoadDatasetStage(types.ODM_Stage):
try: try:
p = types.ODM_Photo(f) p = types.ODM_Photo(f)
p.set_mask(find_mask(f, masks)) p.set_mask(find_mask(f, masks))
photos += [p] photos.append(p)
dataset_list.write(photos[-1].filename + '\n') dataset_list.write(photos[-1].filename + '\n')
except PhotoCorruptedException: except PhotoCorruptedException:
log.ODM_WARNING("%s seems corrupted and will not be used" % os.path.basename(f)) log.ODM_WARNING("%s seems corrupted and will not be used" % os.path.basename(f))
@ -145,6 +148,49 @@ class ODMLoadDatasetStage(types.ODM_Stage):
for p in photos: for p in photos:
p.override_camera_projection(args.camera_lens) p.override_camera_projection(args.camera_lens)
# Automatic sky removal
if args.sky_removal:
# For each image that :
# - Doesn't already have a mask, AND
# - Is not nadir (or if orientation info is missing), AND
# - There are no spaces in the image filename (OpenSfM requirement)
# Automatically generate a sky mask
# Generate list of sky images
sky_images = []
for p in photos:
if p.mask is None and (p.pitch is None or (-10 > p.pitch > 10)) and (not " " in p.filename):
sky_images.append({'file': os.path.join(images_dir, p.filename), 'p': p})
if len(sky_images) > 0:
log.ODM_INFO("Automatically generating sky masks for %s images" % len(sky_images))
model = ai.get_model("skyremoval", "https://github.com/OpenDroneMap/SkyRemoval/releases/download/v1.0.5/model.zip", "v1.0.5")
if model is not None:
sf = SkyFilter(model=model)
def parallel_sky_filter(item):
try:
mask_file = sf.run_img(item['file'], images_dir)
# Check and set
if mask_file is not None and os.path.isfile(mask_file):
item['p'].set_mask(os.path.basename(mask_file))
log.ODM_INFO("Wrote %s" % os.path.basename(mask_file))
else:
log.ODM_WARNING("Cannot generate mask for %s" % img)
except Exception as e:
log.ODM_WARNING("Cannot generate mask for %s: %s" % (img, str(e)))
parallel_map(parallel_sky_filter, sky_images, max_workers=args.max_concurrency)
log.ODM_INFO("Sky masks generation completed!")
else:
log.ODM_WARNING("Cannot load AI model (you might need to be connected to the internet?)")
else:
log.ODM_WARNING("No images suitable for sky mask generation detected (are they all nadir?)")
# End sky removal
# Save image database for faster restart # Save image database for faster restart
save_images_database(photos, images_database_file) save_images_database(photos, images_database_file)
else: else: