From 209a1b5603972bf54b0c86a7c4cc9f37a95aea9c Mon Sep 17 00:00:00 2001 From: Piero Toffanin Date: Tue, 7 Feb 2017 11:43:17 -0500 Subject: [PATCH] TestWatch class to monitor and intercept on demand multithreaded calls, changed location of assets during testing --- app/.gitignore | 1 + app/background.py | 26 +++---------- app/tests/test_api_task.py | 26 +++---------- app/tests/test_testwatch.py | 40 +++++++++++++++++++ app/testwatch.py | 77 +++++++++++++++++++++++++++++++++++++ webodm/settings.py | 2 + 6 files changed, 131 insertions(+), 41 deletions(-) create mode 100644 app/.gitignore create mode 100644 app/tests/test_testwatch.py create mode 100644 app/testwatch.py diff --git a/app/.gitignore b/app/.gitignore new file mode 100644 index 00000000..b2e60fb7 --- /dev/null +++ b/app/.gitignore @@ -0,0 +1 @@ +media_test/ diff --git a/app/background.py b/app/background.py index 0d28f424..57b31cb8 100644 --- a/app/background.py +++ b/app/background.py @@ -1,24 +1,11 @@ from threading import Thread + +import logging from django import db from webodm import settings +from app.testwatch import testWatch - -# TODO: design class such that: -# 1. test cases can choose which functions to intercept (prevent from executing) -# 2. test cases can see how many times a function has been called (and with which parameters) -# 3. test cases can pause until a function has been called -class TestWatch: - stats = {} - - def called(self, func, *args, **kwargs): - list = TestWatch.stats[func] if func in TestWatch.stats else [] - list.append({'f': func, 'args': args, 'kwargs': kwargs}) - print(list) - - def clear(self): - TestWatch.stats = {} - -testWatch = TestWatch() +logger = logging.getLogger('app.logger') def background(func): """ @@ -30,9 +17,7 @@ def background(func): if 'background' in kwargs: del kwargs['background'] if background: - if settings.TESTING: - # During testing, intercept all background requests and execute them on the same thread - testWatch.called(func.__name__, *args, **kwargs) + if testWatch.hook_pre(func, *args, **kwargs): return # Create a function that closes all # db connections at the end of the thread @@ -44,6 +29,7 @@ def background(func): ret = func(*args, **kwargs) finally: db.connections.close_all() + testWatch.hook_post(func, *args, **kwargs) return ret t = Thread(target=execute_and_close_db) diff --git a/app/tests/test_api_task.py b/app/tests/test_api_task.py index c5b9f59d..355677c3 100644 --- a/app/tests/test_api_task.py +++ b/app/tests/test_api_task.py @@ -1,18 +1,16 @@ import os import subprocess - import time -from django import db from django.contrib.auth.models import User from rest_framework import status from rest_framework.test import APIClient -from app import scheduler from app.models import Project, Task, ImageUpload from app.tests.classes import BootTransactionTestCase from nodeodm import status_codes from nodeodm.models import ProcessingNode +from app.testwatch import testWatch # We need to test the task API in a TransactionTestCase because # task processing happens on a separate thread, and normal TestCases @@ -133,6 +131,7 @@ class TestApiTask(BootTransactionTestCase): # Neither should an individual tile # Z/X/Y coords are choosen based on node-odm test dataset for orthophoto_tiles/ res = client.get("/api/projects/{}/tasks/{}/tiles/16/16020/42443.png".format(project.id, task.id)) + print(res.status_code) self.assertTrue(res.status_code == status.HTTP_404_NOT_FOUND) # Cannot access a tiles.json we have no access to @@ -160,6 +159,8 @@ class TestApiTask(BootTransactionTestCase): }) self.assertTrue(res.status_code == status.HTTP_404_NOT_FOUND) + testWatch.clear() + # Assign processing node to task via API res = client.patch("/api/projects/{}/tasks/{}/".format(project.id, task.id), { 'processing_node': pnode.id @@ -167,28 +168,12 @@ class TestApiTask(BootTransactionTestCase): self.assertTrue(res.status_code == status.HTTP_200_OK) # On update scheduler.processing_pending_tasks should have been called in the background - time.sleep(DELAY) - - print("HERE") - from app.background import testWatch - print(testWatch.stats) + testWatch.wait_until_call("app.scheduler.process_pending_tasks", timeout=5) # Processing should have completed task.refresh_from_db() self.assertTrue(task.status == status_codes.RUNNING) - - # TODO: need a way to prevent multithreaded code from executing - # and a way to notify our test case that multithreaded code should have - # executed - - # TODO: at this point we might not even need a TransactionTestCase? - - #from app import scheduler - #scheduler.process_pending_tasks(background=True) - - # time.sleep(3) - # TODO: check # TODO: what happens when nodes go offline, or an offline node is assigned to a task # TODO: check raw/non-raw assets once task is finished processing @@ -196,4 +181,3 @@ class TestApiTask(BootTransactionTestCase): # Teardown processing node node_odm.terminate() - #time.sleep(20) diff --git a/app/tests/test_testwatch.py b/app/tests/test_testwatch.py new file mode 100644 index 00000000..6e9558c3 --- /dev/null +++ b/app/tests/test_testwatch.py @@ -0,0 +1,40 @@ +from django.test import TestCase + +from app.testwatch import TestWatch + + +def test(a, b): + return a + b + +class TestTestWatch(TestCase): + def test_methods(self): + tw = TestWatch() + + self.assertTrue(tw.get_calls_count("app.tests.test_testwatch.test") == 0) + self.assertTrue(tw.get_calls_count("app.tests.test_testwatch.nonexistant") == 0) + + # Test watch count + tw.hook_pre(test, 1, 2) + test(1, 2) + tw.hook_post(test, 1, 2) + + self.assertTrue(tw.get_calls_count("app.tests.test_testwatch.test") == 1) + + tw.hook_pre(test, 1, 2) + test(1, 2) + tw.hook_post(test, 1, 2) + + self.assertTrue(tw.get_calls_count("app.tests.test_testwatch.test") == 2) + + @TestWatch.watch(testWatch=tw) + def test2(d): + d['flag'] = not d['flag'] + + # Test intercept + tw.intercept("app.tests.test_testwatch.test2") + d = {'flag': True} + test2(d) + self.assertTrue(d['flag']) + + + diff --git a/app/testwatch.py b/app/testwatch.py new file mode 100644 index 00000000..f56c4c58 --- /dev/null +++ b/app/testwatch.py @@ -0,0 +1,77 @@ +import time + +import logging + +from webodm import settings + +logger = logging.getLogger('app.logger') + +class TestWatch: + def __init__(self): + self.clear() + + def clear(self): + self._calls = {} + self._intercept_list = {} + + def func_to_name(f): + return "{}.{}".format(f.__module__, f.__name__) + + def intercept(self, fname): + self._intercept_list[fname] = True + + def should_prevent_execution(self, func): + return TestWatch.func_to_name(func) in self._intercept_list + + def get_calls(self, fname): + return self._calls[fname] if fname in self._calls else [] + + def get_calls_count(self, fname): + return len(self.get_calls(fname)) + + def wait_until_call(self, fname, count = 1, timeout = 30): + SLEEP_INTERVAL = 0.125 + TIMEOUT_LIMIT = timeout / SLEEP_INTERVAL + c = 0 + while self.get_calls_count(fname) < count and c < TIMEOUT_LIMIT: + time.sleep(SLEEP_INTERVAL) + c += 1 + + if c >= TIMEOUT_LIMIT: + raise TimeoutError("wait_until_call has timed out waiting for {}".format(fname)) + + return self.get_calls(fname) + + def log_call(self, func, *args, **kwargs): + fname = TestWatch.func_to_name(func) + logger.info("{} called".format(fname)) + list = self._calls[fname] if fname in self._calls else [] + list.append({'f': fname, 'args': args, 'kwargs': kwargs}) + self._calls[fname] = list + + def hook_pre(self, func, *args, **kwargs): + if settings.TESTING and self.should_prevent_execution(func): + logger.info(func.__name__ + " intercepted") + self.log_call(func, *args, **kwargs) + return True # Intercept + return False # Do not intercept + + def hook_post(self, func, *args, **kwargs): + if settings.TESTING: + self.log_call(func, *args, **kwargs) + + def watch(**kwargs): + """ + Decorator that adds pre/post hook calls + """ + tw = kwargs.get('testWatch', testWatch) + def outer(func): + def wrapper(*args, **kwargs): + if tw.hook_pre(func, *args, **kwargs): return + ret = func(*args, **kwargs) + tw.hook_post(func, *args, **kwargs) + return ret + return wrapper + return outer + +testWatch = TestWatch() \ No newline at end of file diff --git a/webodm/settings.py b/webodm/settings.py index 304f57be..a91cb304 100644 --- a/webodm/settings.py +++ b/webodm/settings.py @@ -224,6 +224,8 @@ REST_FRAMEWORK = { } TESTING = sys.argv[1:2] == ['test'] +if TESTING: + MEDIA_ROOT = os.path.join(BASE_DIR, 'app', 'media_test') try: from .local_settings import *