diff --git a/pghoard/pghoard.py b/pghoard/pghoard.py index 396ecda1..8eba7ed3 100644 --- a/pghoard/pghoard.py +++ b/pghoard/pghoard.py @@ -47,7 +47,7 @@ from pghoard.receivexlog import PGReceiveXLog from pghoard.transfer import (TransferAgent, TransferQueue, UploadEvent, UploadEventProgressTracker) from pghoard.walreceiver import WALReceiver -from pghoard.webserver import WebServer +from pghoard.webserver import DownloadResultsProcessor, WebServer @dataclass @@ -149,6 +149,10 @@ def __init__(self, config_path): self.webserver = WebServer( self.config, self.requested_basebackup_sites, self.compression_queue, self.transfer_queue, self.metrics ) + self.download_results_processor = DownloadResultsProcessor( + self.config, self.webserver.lock, self.webserver.download_results, self.webserver.pending_download_ops, + self.webserver.prefetch_404 + ) self.wal_file_deleter = WALFileDeleterThread( config=self.config, wal_file_deletion_queue=self.wal_file_deletion_queue, metrics=self.metrics @@ -701,6 +705,7 @@ def start_threads_on_startup(self): self.inotify.start() self.upload_tracker.start() self.webserver.start() + self.download_results_processor.start() self.wal_file_deleter.start() for compressor in self.compressors: compressor.start() @@ -983,6 +988,8 @@ def _get_all_threads(self): if hasattr(self, "webserver"): all_threads.append(self.webserver) + if hasattr(self, "download_results_processor"): + all_threads.append(self.download_results_processor) all_threads.extend(self.basebackups.values()) all_threads.extend(self.receivexlogs.values()) all_threads.extend(self.walreceivers.values()) diff --git a/pghoard/webserver.py b/pghoard/webserver.py index 7d5edb20..bacc5ac4 100644 --- a/pghoard/webserver.py +++ b/pghoard/webserver.py @@ -20,7 +20,7 @@ from queue import Empty, Queue from socketserver import ThreadingMixIn from threading import RLock -from typing import Dict +from typing import Any, Dict from rohmu.errors import Error, FileNotFoundFromStorageError @@ -102,16 +102,19 @@ class DownloadResultsProcessor(PGHoardThread): Processes download_results queue, validates WAL and renames tmp file to target (".prefetch") """ def __init__( - self, lock: RLock, log: logging.Logger, download_results: Queue, pending_download_ops: Dict[str, PendingDownloadOp], - prefetch_404: deque + self, config: Dict[str, Any], lock: RLock, download_results: Queue, + pending_download_ops: Dict[str, PendingDownloadOp], prefetch_404: deque ) -> None: super().__init__(name=self.__class__.__name__) - self.running = False - self.log = log + self.log = logging.getLogger("WebServer") self.lock = lock self.download_results = download_results self.pending_download_ops = pending_download_ops self.prefetch_404 = prefetch_404 + # PGHoard expects the threads to have these attributes + self.running = False + self.config = config + self.site_transfers = {} def run_safe(self) -> None: self.running = True @@ -174,9 +177,6 @@ def process_queue_item(self, download_result: CallbackEvent) -> None: pending_download_op.target_path, metadata.get("host"), metadata.get("hash-algorithm"), metadata.get("hash") ) - def stop(self) -> None: - self.running = False - class WebServer(PGHoardThread): def __init__(self, config, requested_basebackup_sites, compression_queue, transfer_queue, metrics): @@ -197,9 +197,6 @@ def __init__(self, config, requested_basebackup_sites, compression_queue, transf self.log.debug("WebServer initialized with address: %r port: %r", self.address, self.port) self.is_initialized = threading.Event() self.prefetch_404 = deque(maxlen=32) # pylint: disable=attribute-defined-outside-init - self.download_results_processor = DownloadResultsProcessor( - self.lock, self.log, self.download_results, self.pending_download_ops, self.prefetch_404 - ) def run_safe(self): # We bind the port only when we start running @@ -216,7 +213,6 @@ def run_safe(self): download_results=self.download_results, prefetch_404=self.prefetch_404, metrics=self.metrics) - self.download_results_processor.start() self.is_initialized.set() self.server.serve_forever() @@ -224,8 +220,6 @@ def close(self): self.log.debug("Closing WebServer") if self.server: self.server.shutdown() - self.download_results_processor.stop() - self.download_results_processor.join() self.log.debug("Closed WebServer") self._running = False diff --git a/test/test_webserver.py b/test/test_webserver.py index 9fb86ca3..e996bd0e 100644 --- a/test/test_webserver.py +++ b/test/test_webserver.py @@ -777,7 +777,7 @@ def test_uncontrolled_target_path(self, pghoard): @pytest.fixture(name="download_results_processor") def fixture_download_results_processor() -> DownloadResultsProcessor: - return DownloadResultsProcessor(threading.RLock(), logging.getLogger("WebServer"), Queue(), {}, deque()) + return DownloadResultsProcessor({}, threading.RLock(), Queue(), {}, deque()) class TestDownloadResultsProcessor: