diff --git a/remotecv/worker.py b/remotecv/worker.py index 42153d6..33f6d10 100644 --- a/remotecv/worker.py +++ b/remotecv/worker.py @@ -8,13 +8,15 @@ # http://www.opensource.org/licenses/mit-license # Copyright (c) 2011 globo.com timehome@corp.globo.com -import argparse import logging import sys from http.server import HTTPServer from importlib import import_module from threading import Thread +import click +from click_option_group import optgroup + from remotecv.error_handler import ErrorHandler from remotecv.healthcheck import HealthCheckHandler from remotecv.importer import Importer @@ -73,169 +75,212 @@ def import_modules(): context.metrics.initialize() -def main(params=None): - if params is None: - params = sys.argv[1:] - parser = argparse.ArgumentParser(description="Runs RemoteCV.") - - conn_group = parser.add_argument_group("Worker Backend") - conn_group.add_argument( - "-b", - "--backend", - default="pyres", - choices=["pyres", "celery"], - help="Worker backend", - ) - - conn_group = parser.add_argument_group("Pyres Connection Arguments") - conn_group.add_argument("--host", default="localhost", help="Redis host") - conn_group.add_argument( - "--port", default=6379, type=int, help="Redis port" - ) - conn_group.add_argument( - "--database", default=0, type=int, help="Redis database" - ) - conn_group.add_argument("--password", default=None, help="Redis password") - conn_group.add_argument( - "--redis-mode", - default=SINGLE_NODE, - choices=[SINGLE_NODE, SENTINEL], - help="Redis mode", - ) - conn_group.add_argument( - "--sentinel-instances", - default="localhost:26376", - help="Redis Sentinel instances e.g. 'localhost:26376,localhost:26377'", - ) - conn_group.add_argument( - "--sentinel-password", default=None, help="Redis Sentinel password" - ) - conn_group.add_argument( - "--master-instance", - default=None, - help="Redis Sentinel master instance", - ) - conn_group.add_argument( - "--master-password", - default=None, - help="Redis Sentinel master password", - ) - conn_group.add_argument( - "--master-database", - default=0, - type=int, - help="Redis Sentinel master database", - ) - conn_group.add_argument( - "--socket-timeout", - default=10.0, - type=float, - help="Redis Sentinel socket timeout", - ) - - conn_group = parser.add_argument_group("Celery/SQS Connection Arguments") - conn_group.add_argument( - "--region", default="us-east-1", help="AWS SQS Region" - ) - conn_group.add_argument("--key_id", default="", help="AWS access key id") - conn_group.add_argument( - "--key_secret", default="", help="AWS access key secret" - ) - conn_group.add_argument( - "--polling_interval", default=20, help="AWS polling interval" - ) - - other_group = parser.add_argument_group("Other arguments") - other_group.add_argument( - "--server-port", default=8080, type=int, help="Server http port" - ) - other_group.add_argument( - "--with-healthcheck", - default=False, - action="store_true", - help="Start an healthchecker http endpoint", - ) - other_group.add_argument( - "-l", "--level", default="debug", help="Logging level" - ) - other_group.add_argument( - "-o", "--loader", default="remotecv.http_loader", help="Loader used" - ) - other_group.add_argument( - "-s", - "--store", - default="remotecv.result_store.redis_store", - help="Loader used", - ) - other_group.add_argument( - "-t", - "--timeout", - default=None, - type=int, - help="Timeout in seconds for image detection", - ) - other_group.add_argument( - "--sentry_url", default=None, help="URL used to send errors to sentry" - ) - other_group.add_argument( - "--metrics", - default="remotecv.metrics.logger_metrics", - help="Metrics client, should be the full name of a python module", - ) - - memcache_store_group = parser.add_argument_group( - "Memcache store arguments" - ) - memcache_store_group.add_argument( - "--memcache_hosts", - default="localhost:11211", - help="Comma separated list of memcache hosts", - ) - - parser.add_argument("args", nargs=argparse.REMAINDER) - - arguments = parser.parse_args(params) - logging.basicConfig(level=getattr(logging, arguments.level.upper())) - - config.backend = arguments.backend - config.redis_host = arguments.host - config.redis_port = arguments.port - config.redis_database = arguments.database - config.redis_password = arguments.password - config.redis_mode = arguments.redis_mode - config.redis_sentinel_instances = arguments.sentinel_instances - config.redis_sentinel_password = arguments.sentinel_password - config.redis_sentinel_socket_timeout = arguments.socket_timeout - config.redis_sentinel_master_instance = arguments.master_instance - config.redis_sentinel_master_password = arguments.master_password - config.redis_sentinel_master_database = arguments.master_database - - config.region = arguments.region - config.key_id = arguments.key_id - config.key_secret = arguments.key_secret - config.polling_interval = arguments.polling_interval - - config.timeout = arguments.timeout - config.server_port = arguments.server_port - config.log_level = arguments.level.upper() - config.loader = import_module(arguments.loader) - config.store = import_module(arguments.store) - config.metrics = arguments.metrics - - config.memcache_hosts = arguments.memcache_hosts - - config.extra_args = sys.argv[:1] + arguments.args - - config.error_handler = ErrorHandler(arguments.sentry_url) +@click.command() +@optgroup.group("Worker Backend") +@optgroup.option( + "-b", + "--backend", + envvar="BACKEND", + default="pyres", + type=click.Choice(["pyres", "celery"]), + help="Worker backend", +) +@optgroup.group("Pyres Connection Arguments") +@optgroup.option( + "--host", envvar="REDIS_HOST", default="localhost", help="Redis host" +) +@optgroup.option( + "--port", envvar="REDIS_PORT", default=6379, help="Redis port" +) +@optgroup.option( + "--database", envvar="REDIS_DATABASE", default=0, help="Redis database" +) +@optgroup.option( + "--password", envvar="REDIS_PASSWORD", default=None, help="Redis password" +) +@optgroup.option( + "--redis-mode", + envvar="REDIS_MODE", + default=SINGLE_NODE, + type=click.Choice([SINGLE_NODE, SENTINEL]), + help="Redis mode", +) +@optgroup.option( + "--sentinel-instances", + envvar="REDIS_SENTINEL_INSTANCES", + default="localhost:26376", + help="Redis Sentinel instances e.g. 'localhost:26376,localhost:26377'", +) +@optgroup.option( + "--sentinel-password", + envvar="REDIS_SENTINEL_PASSWORD", + default=None, + help="Redis Sentinel password", +) +@optgroup.option( + "--master-instance", + envvar="REDIS_MASTER_INSTANCE", + default=None, + help="Redis Sentinel master instance", +) +@optgroup.option( + "--master-password", + envvar="REDIS_MASTER_PASSWORD", + default=None, + help="Redis Sentinel master password", +) +@optgroup.option( + "--master-database", + envvar="REDIS_MASTER_DATABASE", + default=0, + help="Redis Sentinel master database", +) +@optgroup.option( + "--socket-timeout", + envvar="REDIS_SENTINEL_SOCKET_TIMEOUT", + default=10.0, + help="Redis Sentinel socket timeout", +) +@optgroup.group("Celery/SQS Connection Arguments") +@optgroup.option( + "--region", + envvar="AWS_REGION", + default="us-east-1", + help="AWS SQS Region", +) +@optgroup.option( + "--key-id", + envvar="AWS_ACCESS_KEY_ID", + default=None, + help="AWS access key id", +) +@optgroup.option( + "--key-secret", + envvar="AWS_SECRET_ACCESS_KEY", + default=None, + help="AWS access key secret", +) +@optgroup.option( + "--polling-interval", + envvar="SQS_POLLING_INTERVAL", + default=20, + help="AWS polling interval", +) +@optgroup.option( + "--celery-commands", + envvar="CELERY_COMMANDS", + default=[], + multiple=True, + help="SQS command", +) +@optgroup.group("Other arguments") +@optgroup.option( + "--server-port", + envvar="HTTP_SERVER_PORT", + default=8080, + help="HTTP server port", +) +@optgroup.option( + "--with-healthcheck", + envvar="WITH_HEALTHCHECK", + is_flag=True, + default=False, + help="Start a healthcheck http endpoint", +) +@optgroup.option( + "-l", + "--level", + envvar="LOG_LEVEL", + type=click.Choice(["debug", "info", "warning", "error", "critical"]), + default="debug", + help="Logging level", +) +@optgroup.option( + "-o", + "--loader", + envvar="IMAGE_LOADER", + default="remotecv.http_loader", + help="Image loader", +) +@optgroup.option( + "-s", + "--store", + envvar="DETECTOR_STORAGE", + default="remotecv.result_store.redis_store", + help="Detector result store", +) +@optgroup.option( + "-t", + "--timeout", + envvar="DETECTOR_TIMEOUT", + default=None, + type=click.INT, + help="Timeout in seconds for image detection", +) +@optgroup.option( + "--sentry-url", + envvar="SENTRY_URL", + default=None, + help="Sentry URL", +) +@optgroup.option( + "--metrics", + envvar="METRICS_CLIENT", + default="remotecv.metrics.logger_metrics", + help="Metrics client, should be the full name of a python module", +) +@optgroup.group("Memcached store arguments") +@optgroup.option( + "--memcached-hosts", + envvar="MEMCACHED_HOSTS", + default="localhost:11211", + help="Comma separated list of memcached hosts", +) +def main(**params): + """Runs RemoteCV""" + + logging.basicConfig(level=getattr(logging, params["level"].upper())) + + config.backend = params["backend"] + config.redis_host = params["host"] + config.redis_port = params["port"] + config.redis_database = params["database"] + config.redis_password = params["password"] + config.redis_mode = params["redis_mode"] + config.redis_sentinel_instances = params["sentinel_instances"] + config.redis_sentinel_password = params["sentinel_password"] + config.redis_sentinel_socket_timeout = params["socket_timeout"] + config.redis_sentinel_master_instance = params["master_instance"] + config.redis_sentinel_master_password = params["master_password"] + config.redis_sentinel_master_database = params["master_database"] + + config.region = params["region"] + config.key_id = params["key_id"] + config.key_secret = params["key_secret"] + config.polling_interval = params["polling_interval"] + + config.timeout = params["timeout"] + config.server_port = params["server_port"] + config.log_level = params["level"].upper() + config.loader = import_module(params["loader"]) + config.store = import_module(params["store"]) + + config.metrics = params["metrics"] + + config.memcache_hosts = params["memcached_hosts"] + + config.extra_args = sys.argv[:1] + list(params["celery_commands"]) + + config.error_handler = ErrorHandler(params["sentry_url"]) import_modules() - if arguments.with_healthcheck: + if params["with_healthcheck"]: start_http_server() - if arguments.backend == "pyres": + if params["backend"] == "pyres": start_pyres_worker() - elif arguments.backend == "celery": + elif params["backend"] == "celery": start_celery_worker() diff --git a/setup.py b/setup.py index 19fcb93..5e104cc 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,8 @@ "pyres==1.*,>=1.5.0", "redis==4.*,>=4.2.0", "sentry-sdk==0.*,>=0.14.2", + "click==8.*", + "click-option-group==0.5.*", ] setup( diff --git a/tests/test_worker.py b/tests/test_worker.py new file mode 100644 index 0000000..92a596f --- /dev/null +++ b/tests/test_worker.py @@ -0,0 +1,44 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +from unittest import TestCase, mock + +from preggy import expect +from click.testing import CliRunner + +from remotecv import worker + + +@mock.patch("remotecv.worker.start_pyres_worker") +@mock.patch("remotecv.worker.start_celery_worker") +@mock.patch("remotecv.worker.start_http_server") +class UniqueQueueTestCase(TestCase): + def setUp(self): + self.runner = CliRunner() + + def test_should_start_pyres_worker( + self, http_mock, celery_mock, pyres_mock + ): + result = self.runner.invoke(worker.main) + expect(result.exit_code).to_equal(0) + pyres_mock.assert_called_once() + celery_mock.assert_not_called() + http_mock.assert_not_called() + + def test_should_start_celery_worker( + self, http_mock, celery_mock, pyres_mock + ): + result = self.runner.invoke(worker.main, ["-b", "celery"]) + expect(result.exit_code).to_equal(0) + pyres_mock.assert_not_called() + celery_mock.assert_called_once() + http_mock.assert_not_called() + + def test_should_start_healthcheck( + self, http_mock, celery_mock, pyres_mock + ): + result = self.runner.invoke(worker.main, ["--with-healthcheck"]) + expect(result.exit_code).to_equal(0) + pyres_mock.assert_called_once() + celery_mock.assert_not_called() + http_mock.assert_called_once()