Rewrote parts of LRE

pull/979/head
Piero Toffanin 2019-05-12 13:41:39 -04:00
parent 14855d010c
commit 75d80cf6c0
2 changed files with 141 additions and 93 deletions

View file

@@ -62,12 +62,10 @@ class LocalRemoteExecutor:
         # Shared variables across threads
         class nonloc:
             error = None
-            local_is_processing = False
             semaphore = None
 
-        handle_result_mutex = threading.Lock()
-        unfinished_tasks = AtomicCounter(len(self.project_paths))
-        node_task_limit = AtomicCounter(0)
+        calculate_task_limit_lock = threading.Lock()
+        finished_tasks = AtomicCounter(0)
 
         # Create queue
         q = queue.Queue()
@@ -75,23 +73,28 @@ class LocalRemoteExecutor:
             log.ODM_DEBUG("LRE: Adding to queue %s" % pp)
             q.put(taskClass(pp, self.node, self.params))
 
-        def cleanup_remote_tasks():
-            if self.params['tasks']:
-                log.ODM_WARNING("LRE: Attempting to cleanup remote tasks")
-            else:
-                log.ODM_WARNING("LRE: No remote tasks to cleanup")
-
-            for task in self.params['tasks']:
-                try:
-                    removed = task.remove()
-                except exceptions.OdmError:
-                    removed = False
-                log.ODM_DEBUG("Removing remote task %s... %s" % (task.uuid, 'OK' if removed else 'NO'))
+        def remove_task_safe(task):
+            try:
+                removed = task.remove()
+            except exceptions.OdmError:
+                removed = False
+            return removed
+
+        def cleanup_remote_tasks():
+            if self.params['tasks']:
+                log.ODM_WARNING("LRE: Attempting to cleanup remote tasks")
+            else:
+                log.ODM_INFO("LRE: No remote tasks to cleanup")
+
+            for task in self.params['tasks']:
+                log.ODM_DEBUG("Removing remote task %s... %s" % (task.uuid, 'OK' if remove_task_safe(task) else 'NO'))
 
         def handle_result(task, local, error = None, partial=False):
-            try:
-                handle_result_mutex.acquire()
-                acquire_semaphore_on_exit = False
+            def cleanup_remote():
+                if not partial and task.remote_task:
+                    log.ODM_DEBUG("Cleaning up remote task (%s)... %s" % (task.remote_task.uuid, 'OK' if remove_task_safe(task.remote_task) else 'NO'))
+                    self.params['tasks'].remove(task.remote_task)
+                    task.remote_task = None
 
             if error:
                 log.ODM_WARNING("LRE: %s failed with: %s" % (task, str(error)))
@@ -102,13 +105,24 @@ class LocalRemoteExecutor:
                 if str(error) == "Child was terminated by signal 15":
                     system.exit_gracefully()
 
-                if isinstance(error, NodeTaskLimitReachedException) and not nonloc.semaphore and node_task_limit.value > 0:
-                    sem_value = max(1, node_task_limit.value)
-                    nonloc.semaphore = threading.Semaphore(sem_value)
-                    log.ODM_DEBUG("LRE: Node task limit reached. Setting semaphore to %s" % sem_value)
-                    for i in range(sem_value):
-                        nonloc.semaphore.acquire()
-                    acquire_semaphore_on_exit = True
+                if isinstance(error, NodeTaskLimitReachedException) and not nonloc.semaphore:
+                    # Estimate the maximum number of tasks based on how many tasks
+                    # are currently running
+                    with calculate_task_limit_lock:
+                        node_task_limit = 0
+                        for t in self.params['tasks']:
+                            try:
+                                info = t.info()
+                                if info.status == TaskStatus.RUNNING:
+                                    node_task_limit += 1
+                            except exceptions.OdmError:
+                                pass
+
+                        sem_value = max(1, node_task_limit)
+                        nonloc.semaphore = threading.Semaphore(sem_value)
+                        log.ODM_DEBUG("LRE: Node task limit reached. Setting semaphore to %s" % sem_value)
+                        for i in range(sem_value):
+                            nonloc.semaphore.acquire()
 
                 # Retry, but only if the error is not related to a task failure
                 if task.retries < task.max_retries and not isinstance(error, exceptions.TaskFailedError):
@@ -118,35 +132,46 @@ class LocalRemoteExecutor:
                     if not isinstance(error, NodeTaskLimitReachedException):
                         task.retries += 1
                         task.wait_until = datetime.datetime.now() + datetime.timedelta(seconds=task.retries * task.retry_timeout)
 
+                    cleanup_remote()
+                    q.task_done()
+
                     log.ODM_DEBUG("LRE: Re-queueing %s (retries: %s)" % (task, task.retries))
                     q.put(task)
+                    return
                 else:
                     nonloc.error = error
-                    unfinished_tasks.increment(-1)
+                    finished_tasks.increment()
             else:
-                if not local and not partial:
-                    node_task_limit.increment(-1)
-
                 if not partial:
                     log.ODM_INFO("LRE: %s finished successfully" % task)
-                    unfinished_tasks.increment(-1)
+                    finished_tasks.increment()
 
-            if local:
-                nonloc.local_is_processing = False
-
-            if not task.finished:
-                if not acquire_semaphore_on_exit and nonloc.semaphore: nonloc.semaphore.release()
-                task.finished = True
-                q.task_done()
-            finally:
-                handle_result_mutex.release()
-
-            if acquire_semaphore_on_exit and nonloc.semaphore:
-                log.ODM_INFO("LRE: Waiting...")
-                nonloc.semaphore.acquire()
+            cleanup_remote()
+            if not local and not partial and nonloc.semaphore: nonloc.semaphore.release()
+            if not partial: q.task_done()
 
-        def worker():
+        def local_worker():
             while True:
-                # If we've found a limit on the maximum number of tasks
+                # Block until a new queue item is available
+                task = q.get()
+
+                if task is None or nonloc.error is not None:
+                    q.task_done()
+                    break
+
+                # Process local
+                try:
+                    task.process(True, handle_result)
+                except Exception as e:
+                    handle_result(task, True, e)
+
+        def remote_worker():
+            while True:
+                had_semaphore = bool(nonloc.semaphore)
+
+                # If we've found an estimate of the limit on the maximum number of tasks
                 # a node can process, we block until some tasks have completed
                 if nonloc.semaphore: nonloc.semaphore.acquire()
@@ -158,34 +183,34 @@ class LocalRemoteExecutor:
                     if nonloc.semaphore: nonloc.semaphore.release()
                     break
 
-                task.finished = False
-
-                if not nonloc.local_is_processing or not self.node_online:
-                    # Process local
-                    try:
-                        nonloc.local_is_processing = True
-                        task.process(True, handle_result)
-                    except Exception as e:
-                        handle_result(task, True, e)
-                else:
-                    # Process remote
-                    try:
-                        task.process(False, handle_result)
-                        node_task_limit.increment() # Called after upload, but before processing is started
-                    except Exception as e:
-                        handle_result(task, False, e)
+                # Special case in which we've just created a semaphore
+                if not had_semaphore and nonloc.semaphore:
+                    log.ODM_INFO("Just found semaphore, sending %s back to the queue" % task)
+                    q.put(task)
+                    q.task_done()
+                    continue
+
+                # Process remote
+                try:
+                    task.process(False, handle_result)
+                except Exception as e:
+                    handle_result(task, False, e)
 
         # Create queue thread
-        t = threading.Thread(target=worker)
+        local_thread = threading.Thread(target=local_worker)
+        if self.node_online:
+            remote_thread = threading.Thread(target=remote_worker)
 
         system.add_cleanup_callback(cleanup_remote_tasks)
 
-        # Start worker process
-        t.start()
+        # Start workers
+        local_thread.start()
+        if self.node_online:
+            remote_thread.start()
 
         # block until all tasks are done (or CTRL+C)
         try:
-            while unfinished_tasks.value > 0:
+            while finished_tasks.value < len(self.project_paths):
                 time.sleep(0.5)
         except KeyboardInterrupt:
             log.ODM_WARNING("LRE: CTRL+C")
@@ -194,15 +219,20 @@ class LocalRemoteExecutor:
         # stop workers
         if nonloc.semaphore: nonloc.semaphore.release()
         q.put(None)
+        if self.node_online:
+            q.put(None)
 
         # Wait for queue thread
-        t.join()
+        local_thread.join()
+        if self.node_online:
+            remote_thread.join()
 
         # Wait for all remains threads
         for thrds in self.params['threads']:
             thrds.join()
 
         system.remove_cleanup_callback(cleanup_remote_tasks)
+        cleanup_remote_tasks()
 
         if nonloc.error is not None:
             # Try not to leak access token
@@ -224,7 +254,7 @@ class Task:
         self.max_retries = max_retries
         self.retries = 0
         self.retry_timeout = retry_timeout
-        self.finished = False
+        self.remote_task = None
 
     def process(self, local, done):
         def handle_result(error = None, partial=False):
@@ -233,9 +263,7 @@ class Task:
         log.ODM_INFO("LRE: About to process %s %s" % (self, 'locally' if local else 'remotely'))
 
         if local:
-            t = threading.Thread(target=self._process_local, args=(handle_result, ))
-            self.params['threads'].append(t)
-            t.start()
+            self._process_local(handle_result) # Block until complete
         else:
             now = datetime.datetime.now()
             if self.wait_until > now:
@@ -313,6 +341,7 @@ class Task:
                     progress_callback=print_progress,
                     skip_post_processing=True,
                     outputs=outputs)
+        self.remote_task = task
 
         # Cleanup seed file
         os.remove(seed_file)
@@ -332,7 +361,6 @@ class Task:
                 # stop the process and re-add the task to the queue.
                 if info.status == TaskStatus.QUEUED:
                     log.ODM_WARNING("LRE: %s (%s) turned from RUNNING to QUEUED. Re-adding to back of the queue." % (self, task.uuid))
-                    task.remove()
                     raise NodeTaskLimitReachedException("Delayed task limit reached")
                 elif info.status == TaskStatus.RUNNING:
                     # Print a status message once in a while
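
Taken out of the diff, the core of the new error path is: when the node rejects a submission with NodeTaskLimitReachedException, count how many already-submitted remote tasks are actually RUNNING, size a semaphore from that count, and acquire every slot up front so the remote worker only resumes submitting as running tasks complete. Below is a minimal self-contained sketch of that pattern; FakeRemoteTask, estimate_node_semaphore and the plain string statuses are illustrative stand-ins, not ODM or pyodm APIs.

import threading

class FakeRemoteTask:
    """Illustrative stand-in for a submitted remote task; only .info().status is consulted."""
    def __init__(self, status):
        self.status = status
    def info(self):
        return self  # mimics an info object exposing .status

def estimate_node_semaphore(tasks):
    # Count the tasks the node reports as actually running
    running = 0
    for t in tasks:
        try:
            if t.info().status == "RUNNING":
                running += 1
        except Exception:
            pass  # a task we can no longer query does not count toward the limit

    sem_value = max(1, running)
    sem = threading.Semaphore(sem_value)
    # Start with every slot taken: submissions resume only as running tasks finish
    for _ in range(sem_value):
        sem.acquire()
    return sem, sem_value

tasks = [FakeRemoteTask("RUNNING"), FakeRemoteTask("RUNNING"), FakeRemoteTask("QUEUED")]
sem, value = estimate_node_semaphore(tasks)
print("Estimated node task limit: %s" % value)  # -> 2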

View file

@@ -3,6 +3,7 @@ import unittest
 import threading
 from opendm.remote import LocalRemoteExecutor, Task, NodeTaskLimitReachedException
 from pyodm import Node
+from pyodm.types import TaskStatus
 
 class TestRemote(unittest.TestCase):
     def setUp(self):
@@ -26,29 +27,48 @@ class TestRemote(unittest.TestCase):
         MAX_QUEUE = 2
         class nonloc:
             local_task_check = False
-            remote_queue = 0
+            remote_queue = 1
+
+        class OdmTaskMock:
+            def __init__(self, running, queue_num):
+                self.running = running
+                self.queue_num = queue_num
+                self.uuid = 'xxxxx-xxxxx-xxxxx-xxxxx-xxxx' + str(queue_num)
+
+            def info(self):
+                class StatusMock:
+                    status = TaskStatus.RUNNING if self.running else TaskStatus.QUEUED
+                return StatusMock()
+
+            def remove(self):
+                return True
 
         class TaskMock(Task):
             def process_local(self):
-                # First task should be submodel_0000
-                if not nonloc.local_task_check: nonloc.local_task_check = self.project_path.endswith("0000")
-                time.sleep(3)
+                # First task should be 0000 or 0001
+                if not nonloc.local_task_check: nonloc.local_task_check = self.project_path.endswith("0000") or self.project_path.endswith("0001")
+                time.sleep(1)
 
             def process_remote(self, done):
-                time.sleep(0.2)
+                time.sleep(0.05) # file upload
+                self.remote_task = OdmTaskMock(nonloc.remote_queue <= MAX_QUEUE, nonloc.remote_queue)
+                self.params['tasks'].append(self.remote_task)
+                nonloc.remote_queue += 1
 
                 # Upload successful
                 done(error=None, partial=True)
 
                 # Async processing
                 def monitor():
-                    nonloc.remote_queue += 1
-                    time.sleep(0.3)
+                    time.sleep(0.2)
 
                     try:
-                        if nonloc.remote_queue > MAX_QUEUE:
-                            nonloc.remote_queue = 0
+                        if self.remote_task.queue_num > MAX_QUEUE:
+                            nonloc.remote_queue -= 1
                             raise NodeTaskLimitReachedException("Delayed task limit reached")
+
+                        time.sleep(0.5)
+                        nonloc.remote_queue -= 1
                         done()
                     except Exception as e:
                         done(e)
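
The structural change these tests exercise is the split of the single worker into a local worker and a remote worker that share one queue, each stopped by its own None sentinel. A stripped-down sketch of that arrangement, with a toy process() in place of ODM's Task.process and no semaphore or retry logic:

import queue
import threading

q = queue.Queue()

def process(task, local):
    # Toy stand-in for Task.process(local, handle_result)
    print("%s worker processed %s" % ("local" if local else "remote", task))

def make_worker(local):
    def worker():
        while True:
            task = q.get()
            if task is None:      # one sentinel stops one worker
                q.task_done()
                break
            process(task, local)
            q.task_done()
    return worker

for i in range(4):
    q.put("submodel_%04d" % i)

local_thread = threading.Thread(target=make_worker(True))
remote_thread = threading.Thread(target=make_worker(False))
local_thread.start()
remote_thread.start()

q.put(None)  # stop local worker
q.put(None)  # stop remote worker
local_thread.join()
remote_thread.join()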