From fe4a3c62605607b044d124b6fd076190352d27ed Mon Sep 17 00:00:00 2001 From: Piero Toffanin Date: Sun, 12 May 2019 13:41:39 -0400 Subject: [PATCH] Rewrote parts of LRE Former-commit-id: 75d80cf6c073cd9b255e73bb61e5914f9c729c69 --- opendm/remote.py | 196 ++++++++++++++++++++++++------------------- tests/test_remote.py | 38 +++++++-- 2 files changed, 141 insertions(+), 93 deletions(-) diff --git a/opendm/remote.py b/opendm/remote.py index 0a58c639..36ae99e4 100644 --- a/opendm/remote.py +++ b/opendm/remote.py @@ -62,91 +62,116 @@ class LocalRemoteExecutor: # Shared variables across threads class nonloc: error = None - local_is_processing = False semaphore = None - handle_result_mutex = threading.Lock() - unfinished_tasks = AtomicCounter(len(self.project_paths)) - node_task_limit = AtomicCounter(0) + calculate_task_limit_lock = threading.Lock() + finished_tasks = AtomicCounter(0) # Create queue q = queue.Queue() for pp in self.project_paths: log.ODM_DEBUG("LRE: Adding to queue %s" % pp) q.put(taskClass(pp, self.node, self.params)) + + def remove_task_safe(task): + try: + removed = task.remove() + except exceptions.OdmError: + removed = False + return removed def cleanup_remote_tasks(): if self.params['tasks']: log.ODM_WARNING("LRE: Attempting to cleanup remote tasks") else: - log.ODM_WARNING("LRE: No remote tasks to cleanup") + log.ODM_INFO("LRE: No remote tasks to cleanup") for task in self.params['tasks']: - try: - removed = task.remove() - except exceptions.OdmError: - removed = False - log.ODM_DEBUG("Removing remote task %s... %s" % (task.uuid, 'OK' if removed else 'NO')) + log.ODM_DEBUG("Removing remote task %s... %s" % (task.uuid, 'OK' if remove_task_safe(task) else 'NO')) def handle_result(task, local, error = None, partial=False): - try: - handle_result_mutex.acquire() - acquire_semaphore_on_exit = False + def cleanup_remote(): + if not partial and task.remote_task: + log.ODM_DEBUG("Cleaning up remote task (%s)... %s" % (task.remote_task.uuid, 'OK' if remove_task_safe(task.remote_task) else 'NO')) + self.params['tasks'].remove(task.remote_task) + task.remote_task = None - if error: - log.ODM_WARNING("LRE: %s failed with: %s" % (task, str(error))) - - # Special case in which the error is caused by a SIGTERM signal - # this means a local processing was terminated either by CTRL+C or - # by canceling the task. - if str(error) == "Child was terminated by signal 15": - system.exit_gracefully() + if error: + log.ODM_WARNING("LRE: %s failed with: %s" % (task, str(error))) + + # Special case in which the error is caused by a SIGTERM signal + # this means a local processing was terminated either by CTRL+C or + # by canceling the task. + if str(error) == "Child was terminated by signal 15": + system.exit_gracefully() - if isinstance(error, NodeTaskLimitReachedException) and not nonloc.semaphore and node_task_limit.value > 0: - sem_value = max(1, node_task_limit.value) + if isinstance(error, NodeTaskLimitReachedException) and not nonloc.semaphore: + # Estimate the maximum number of tasks based on how many tasks + # are currently running + with calculate_task_limit_lock: + node_task_limit = 0 + for t in self.params['tasks']: + try: + info = t.info() + if info.status == TaskStatus.RUNNING: + node_task_limit += 1 + except exceptions.OdmError: + pass + + sem_value = max(1, node_task_limit) nonloc.semaphore = threading.Semaphore(sem_value) log.ODM_DEBUG("LRE: Node task limit reached. Setting semaphore to %s" % sem_value) for i in range(sem_value): nonloc.semaphore.acquire() - acquire_semaphore_on_exit = True - # Retry, but only if the error is not related to a task failure - if task.retries < task.max_retries and not isinstance(error, exceptions.TaskFailedError): - # Put task back in queue - # Don't increment the retry counter if this task simply reached the task - # limit count. - if not isinstance(error, NodeTaskLimitReachedException): - task.retries += 1 - task.wait_until = datetime.datetime.now() + datetime.timedelta(seconds=task.retries * task.retry_timeout) - log.ODM_DEBUG("LRE: Re-queueing %s (retries: %s)" % (task, task.retries)) - q.put(task) - else: - nonloc.error = error - unfinished_tasks.increment(-1) - else: - if not local and not partial: - node_task_limit.increment(-1) - - if not partial: - log.ODM_INFO("LRE: %s finished successfully" % task) - unfinished_tasks.increment(-1) - - if local: - nonloc.local_is_processing = False - - if not task.finished: - if not acquire_semaphore_on_exit and nonloc.semaphore: nonloc.semaphore.release() - task.finished = True + # Retry, but only if the error is not related to a task failure + if task.retries < task.max_retries and not isinstance(error, exceptions.TaskFailedError): + # Put task back in queue + # Don't increment the retry counter if this task simply reached the task + # limit count. + if not isinstance(error, NodeTaskLimitReachedException): + task.retries += 1 + task.wait_until = datetime.datetime.now() + datetime.timedelta(seconds=task.retries * task.retry_timeout) + cleanup_remote() q.task_done() - finally: - handle_result_mutex.release() - if acquire_semaphore_on_exit and nonloc.semaphore: - log.ODM_INFO("LRE: Waiting...") - nonloc.semaphore.acquire() - def worker(): + log.ODM_DEBUG("LRE: Re-queueing %s (retries: %s)" % (task, task.retries)) + q.put(task) + return + else: + nonloc.error = error + finished_tasks.increment() + else: + if not partial: + log.ODM_INFO("LRE: %s finished successfully" % task) + finished_tasks.increment() + + cleanup_remote() + + if not local and not partial and nonloc.semaphore: nonloc.semaphore.release() + if not partial: q.task_done() + + def local_worker(): while True: - # If we've found a limit on the maximum number of tasks + # Block until a new queue item is available + task = q.get() + + if task is None or nonloc.error is not None: + q.task_done() + break + + # Process local + try: + task.process(True, handle_result) + except Exception as e: + handle_result(task, True, e) + + + def remote_worker(): + while True: + had_semaphore = bool(nonloc.semaphore) + + # If we've found an estimate of the limit on the maximum number of tasks # a node can process, we block until some tasks have completed if nonloc.semaphore: nonloc.semaphore.acquire() @@ -157,35 +182,35 @@ class LocalRemoteExecutor: q.task_done() if nonloc.semaphore: nonloc.semaphore.release() break - - task.finished = False - if not nonloc.local_is_processing or not self.node_online: - # Process local - try: - nonloc.local_is_processing = True - task.process(True, handle_result) - except Exception as e: - handle_result(task, True, e) - else: - # Process remote - try: - task.process(False, handle_result) - node_task_limit.increment() # Called after upload, but before processing is started - except Exception as e: - handle_result(task, False, e) + # Special case in which we've just created a semaphore + if not had_semaphore and nonloc.semaphore: + log.ODM_INFO("Just found semaphore, sending %s back to the queue" % task) + q.put(task) + q.task_done() + continue + + # Process remote + try: + task.process(False, handle_result) + except Exception as e: + handle_result(task, False, e) # Create queue thread - t = threading.Thread(target=worker) + local_thread = threading.Thread(target=local_worker) + if self.node_online: + remote_thread = threading.Thread(target=remote_worker) system.add_cleanup_callback(cleanup_remote_tasks) - # Start worker process - t.start() + # Start workers + local_thread.start() + if self.node_online: + remote_thread.start() # block until all tasks are done (or CTRL+C) try: - while unfinished_tasks.value > 0: + while finished_tasks.value < len(self.project_paths): time.sleep(0.5) except KeyboardInterrupt: log.ODM_WARNING("LRE: CTRL+C") @@ -194,15 +219,20 @@ class LocalRemoteExecutor: # stop workers if nonloc.semaphore: nonloc.semaphore.release() q.put(None) + if self.node_online: + q.put(None) # Wait for queue thread - t.join() + local_thread.join() + if self.node_online: + remote_thread.join() # Wait for all remains threads for thrds in self.params['threads']: thrds.join() system.remove_cleanup_callback(cleanup_remote_tasks) + cleanup_remote_tasks() if nonloc.error is not None: # Try not to leak access token @@ -224,7 +254,7 @@ class Task: self.max_retries = max_retries self.retries = 0 self.retry_timeout = retry_timeout - self.finished = False + self.remote_task = None def process(self, local, done): def handle_result(error = None, partial=False): @@ -233,9 +263,7 @@ class Task: log.ODM_INFO("LRE: About to process %s %s" % (self, 'locally' if local else 'remotely')) if local: - t = threading.Thread(target=self._process_local, args=(handle_result, )) - self.params['threads'].append(t) - t.start() + self._process_local(handle_result) # Block until complete else: now = datetime.datetime.now() if self.wait_until > now: @@ -313,6 +341,7 @@ class Task: progress_callback=print_progress, skip_post_processing=True, outputs=outputs) + self.remote_task = task # Cleanup seed file os.remove(seed_file) @@ -332,7 +361,6 @@ class Task: # stop the process and re-add the task to the queue. if info.status == TaskStatus.QUEUED: log.ODM_WARNING("LRE: %s (%s) turned from RUNNING to QUEUED. Re-adding to back of the queue." % (self, task.uuid)) - task.remove() raise NodeTaskLimitReachedException("Delayed task limit reached") elif info.status == TaskStatus.RUNNING: # Print a status message once in a while diff --git a/tests/test_remote.py b/tests/test_remote.py index cc32d5d1..69b5c323 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -3,6 +3,7 @@ import unittest import threading from opendm.remote import LocalRemoteExecutor, Task, NodeTaskLimitReachedException from pyodm import Node +from pyodm.types import TaskStatus class TestRemote(unittest.TestCase): def setUp(self): @@ -26,29 +27,48 @@ class TestRemote(unittest.TestCase): MAX_QUEUE = 2 class nonloc: local_task_check = False - remote_queue = 0 + remote_queue = 1 + + class OdmTaskMock: + def __init__(self, running, queue_num): + self.running = running + self.queue_num = queue_num + self.uuid = 'xxxxx-xxxxx-xxxxx-xxxxx-xxxx' + str(queue_num) + + def info(self): + class StatusMock: + status = TaskStatus.RUNNING if self.running else TaskStatus.QUEUED + return StatusMock() + + def remove(self): + return True class TaskMock(Task): def process_local(self): - # First task should be submodel_0000 - if not nonloc.local_task_check: nonloc.local_task_check = self.project_path.endswith("0000") - time.sleep(3) + # First task should be 0000 or 0001 + if not nonloc.local_task_check: nonloc.local_task_check = self.project_path.endswith("0000") or self.project_path.endswith("0001") + time.sleep(1) def process_remote(self, done): - time.sleep(0.2) + time.sleep(0.05) # file upload + + self.remote_task = OdmTaskMock(nonloc.remote_queue <= MAX_QUEUE, nonloc.remote_queue) + self.params['tasks'].append(self.remote_task) + nonloc.remote_queue += 1 # Upload successful done(error=None, partial=True) # Async processing def monitor(): - nonloc.remote_queue += 1 - time.sleep(0.3) + time.sleep(0.2) try: - if nonloc.remote_queue > MAX_QUEUE: - nonloc.remote_queue = 0 + if self.remote_task.queue_num > MAX_QUEUE: + nonloc.remote_queue -= 1 raise NodeTaskLimitReachedException("Delayed task limit reached") + time.sleep(0.5) + nonloc.remote_queue -= 1 done() except Exception as e: done(e)