From 0f2387b5ab5167efcf873f7cf8a6c51ee67d22ab Mon Sep 17 00:00:00 2001 From: Piero Toffanin Date: Thu, 27 Mar 2025 20:33:36 -0400 Subject: [PATCH] Crop support in objdetect plugin --- coreplugins/objdetect/api.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/coreplugins/objdetect/api.py b/coreplugins/objdetect/api.py index 970b78d7..3bf8c92f 100644 --- a/coreplugins/objdetect/api.py +++ b/coreplugins/objdetect/api.py @@ -6,8 +6,11 @@ from app.plugins.views import TaskView, GetTaskResult, TaskResultOutputError from app.plugins.worker import run_function_async from django.utils.translation import gettext_lazy as _ -def detect(orthophoto, model, classes=None, progress_callback=None): +def detect(orthophoto, model, classes=None, crop=None, progress_callback=None): import os + import subprocess + import shutil + import tempfile from webodm import settings try: @@ -17,6 +20,31 @@ def detect(orthophoto, model, classes=None, progress_callback=None): return {'error': "GeoDeep library is missing"} try: + if crop is not None: + # Make a VRT with the crop area + + gdalwarp_bin = shutil.which("gdalwarp") + if gdalwarp_bin is None: + return {'error': 'Cannot find gdalwarp'} + + tmpdir = os.path.join(settings.MEDIA_TMP, os.path.basename(tempfile.mkdtemp('_objdetect', dir=settings.MEDIA_TMP))) + + crop_geojson = os.path.join(tmpdir, "crop.geojson") + ortho_vrt = os.path.join(tmpdir, "orthophoto.vrt") + with open(crop_geojson, "w", encoding="utf-8") as f: + f.write(crop) + p = subprocess.Popen([gdalwarp_bin, "-cutline", crop_geojson, + '--config', 'GDALWARP_DENSIFY_CUTLINE', 'NO', + '-crop_to_cutline', '-of', 'VRT', + orthophoto, ortho_vrt], cwd=tmpdir, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + out = out.decode('utf-8').strip() + err = err.decode('utf-8').strip() + if p.returncode != 0: + return {'error': f'Error calling gdalwarp: {str(err)}'} + + orthophoto = ortho_vrt + return {'output': gdetect(orthophoto, model, output_type='geojson', classes=classes, max_threads=settings.WORKERS_MAX_THREADS, progress_callback=progress_callback)} except Exception as e: return {'error': str(e)} @@ -44,7 +72,7 @@ class TaskObjDetect(TaskView): return Response({'error': 'Invalid model'}, status=status.HTTP_200_OK) model_id, classes = model_map[model] - celery_task_id = run_function_async(detect, orthophoto, model_id, classes, with_progress=True).task_id + celery_task_id = run_function_async(detect, orthophoto, model_id, classes, task.crop.geojson if task.crop is not None else None, with_progress=True).task_id return Response({'celery_task_id': celery_task_id}, status=status.HTTP_200_OK)