diff --git a/ogn/gateway/manage.py b/ogn/gateway/manage.py index 3c6d756..2e4a85d 100644 --- a/ogn/gateway/manage.py +++ b/ogn/gateway/manage.py @@ -18,8 +18,9 @@ def run(aprs_user="anon-dev"): gateway.connect_db() while user_interrupted is False: - print("Connect OGN gateway") + print("Connect OGN gateway as {}".format(aprs_user)) gateway.connect(aprs_user) + socket_open = True try: gateway.run() @@ -28,10 +29,11 @@ def run(aprs_user="anon-dev"): user_interrupted = True except BrokenPipeError: print("BrokenPipeError") - except socket.err: + except socket.error: print("socket error") + socket_open = False - print("Disconnect OGN gateway") - gateway.disconnect() + if socket_open: + gateway.disconnect() print("\nExit OGN gateway") diff --git a/tests/gateway/__init__.py b/tests/gateway/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/gateway/test_manage.py b/tests/gateway/test_manage.py new file mode 100644 index 0000000..178673a --- /dev/null +++ b/tests/gateway/test_manage.py @@ -0,0 +1,35 @@ +import unittest +import unittest.mock as mock + +from ogn.gateway.manage import run + + +class GatewayTest(unittest.TestCase): + + # try simple user interrupt + @mock.patch('ogn.gateway.manage.ognGateway') + def test_user_interruption(self, mock_gateway): + instance = mock_gateway.return_value + instance.run.side_effect = KeyboardInterrupt() + + run("user_1") + + instance.connect_db.assert_called_once_with() + instance.connect.assert_called_once_with("user_1") + instance.run.assert_called_once_with() + instance.disconnect.assert_called_once_with() + + # make BrokenPipeErrors and a socket error (may happen) and then a user interrupt (important!) + @mock.patch('ogn.gateway.manage.ognGateway') + def test_BrokenPipeError(self, mock_gateway): + instance = mock_gateway.return_value + instance.run.side_effect = [BrokenPipeError(), BrokenPipeError(), KeyboardInterrupt()] + + run("user_2") + + instance.connect_db.assert_called_once_with() + self.assertTrue(instance.run.call_count, 3) + self.assertTrue(instance.disconnect.call_count, 2) # not called if socket crashed + +if __name__ == '__main__': + unittest.main()