diff --git a/coreplugins/objdetect/api.py b/coreplugins/objdetect/api.py index b3161948..970b78d7 100644 --- a/coreplugins/objdetect/api.py +++ b/coreplugins/objdetect/api.py @@ -6,7 +6,7 @@ from app.plugins.views import TaskView, GetTaskResult, TaskResultOutputError from app.plugins.worker import run_function_async from django.utils.translation import gettext_lazy as _ -def detect(orthophoto, model, progress_callback=None): +def detect(orthophoto, model, classes=None, progress_callback=None): import os from webodm import settings @@ -17,7 +17,7 @@ def detect(orthophoto, model, progress_callback=None): return {'error': "GeoDeep library is missing"} try: - return {'output': gdetect(orthophoto, model, output_type='geojson', max_threads=settings.WORKERS_MAX_THREADS, progress_callback=progress_callback)} + return {'output': gdetect(orthophoto, model, output_type='geojson', classes=classes, max_threads=settings.WORKERS_MAX_THREADS, progress_callback=progress_callback)} except Exception as e: return {'error': str(e)} @@ -31,10 +31,20 @@ class TaskObjDetect(TaskView): orthophoto = os.path.abspath(task.get_asset_download_path("orthophoto.tif")) model = request.data.get('model', 'cars') - if not model in ['cars', 'trees']: + # model --> (modelID, classes) + model_map = { + 'cars': ('cars', None), + 'trees': ('trees', None), + 'athletic': ('aerovision', ['tennis-court', 'track-field', 'soccer-field', 'baseball-field', 'swimming-pool', 'basketball-court']), + 'boats': ('aerovision', ['boat']), + 'planes': ('aerovision', ['plane']), + } + + if not model in model_map: return Response({'error': 'Invalid model'}, status=status.HTTP_200_OK) - celery_task_id = run_function_async(detect, orthophoto, model, with_progress=True).task_id + model_id, classes = model_map[model] + celery_task_id = run_function_async(detect, orthophoto, model_id, classes, with_progress=True).task_id return Response({'celery_task_id': celery_task_id}, status=status.HTTP_200_OK) diff --git a/coreplugins/objdetect/public/ObjDetectPanel.jsx b/coreplugins/objdetect/public/ObjDetectPanel.jsx index 695b9f6b..02c003e6 100644 --- a/coreplugins/objdetect/public/ObjDetectPanel.jsx +++ b/coreplugins/objdetect/public/ObjDetectPanel.jsx @@ -175,7 +175,10 @@ export default class ObjDetectPanel extends React.Component { const { loading, permanentError, objLayer, detecting, model, progress } = this.state; const models = [ {label: _('Cars'), value: 'cars'}, - {label: _('Trees'), value: 'trees'}, + {label: _('Trees'), value: 'trees'}, + {label: _('Athletic Facilities'), value: 'athletic'}, + {label: _('Boats'), value: 'boats'}, + {label: _('Planes'), value: 'planes'} ] let content = ""; diff --git a/requirements.txt b/requirements.txt index d3db17e5..61abce0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ drf-nested-routers==0.11.1 funcsigs==1.0.2 futures==3.1.1 gunicorn==19.8.0 -geodeep==0.9.7 +geodeep==0.9.8 itypes==1.1.0 kombu==4.6.7 Markdown==3.3.4