From 384d14c28f63d7c8f33de16eddd580e3bdc2bae4 Mon Sep 17 00:00:00 2001 From: Egor Voynov Date: Mon, 25 Mar 2024 14:08:40 +0100 Subject: [PATCH] add tests --- pghoard/webserver.py | 2 +- test/test_webserver.py | 77 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/pghoard/webserver.py b/pghoard/webserver.py index 81112a37..5bb5d407 100644 --- a/pghoard/webserver.py +++ b/pghoard/webserver.py @@ -125,7 +125,7 @@ def process_queue_item(self, download_result): os.unlink(src_tmp_file_path) return os.rename(src_tmp_file_path, dst_file_path) - metadata = download_result.payload["metadata"] or {} + metadata = download_result.payload.get("metadata", {}) self.log.info( "Renamed %s to %s. Original upload from %r, hash %s:%s", download_result.payload["target_path"], op.target_path, metadata.get("host"), metadata.get("hash-algorithm"), metadata.get("hash") diff --git a/test/test_webserver.py b/test/test_webserver.py index af785ca1..4d2604bf 100644 --- a/test/test_webserver.py +++ b/test/test_webserver.py @@ -8,6 +8,7 @@ import logging import os import socket +import threading import time from distutils.version import LooseVersion from http.client import HTTPConnection @@ -20,11 +21,12 @@ from pghoard import postgres_command, wal from pghoard.archive_sync import ArchiveSync -from pghoard.common import get_pg_wal_directory +from pghoard.common import CallbackEvent, get_pg_wal_directory from pghoard.object_store import HTTPRestore from pghoard.pgutil import create_connection_string from pghoard.postgres_command import archive_command, restore_command from pghoard.restore import Restore +from pghoard.webserver import DownloadResultsProcessor, PendingDownloadOp # pylint: disable=attribute-defined-outside-init from .base import CONSTANT_TEST_RSA_PRIVATE_KEY, CONSTANT_TEST_RSA_PUBLIC_KEY @@ -770,3 +772,76 @@ def test_uncontrolled_target_path(self, pghoard): conn.request("GET", wal_file, headers=headers) status = conn.getresponse().status assert status == 400 + + +class TestDownloadResultsProcessor: + wal_name = "000000060000000000000001" + + def save_wal_and_dowload_callback(self, pg_wal_dir, drp, wal_name=None, is_valid_wal=True): + if wal_name is None: + wal_name = self.wal_name + tmp_path = os.path.join(pg_wal_dir, f"{wal_name}.pghoard.tmp") + target_path = os.path.join(pg_wal_dir, f"{wal_name}.pghoard.prefetch") + assert not os.path.exists(tmp_path) + assert not os.path.exists(target_path) + + # save WAL on FS + if is_valid_wal: + wal = wal_header_for_file(wal_name) + else: + another_wal_name = "000000DD00000000000000DD" + assert wal_name != another_wal_name + wal = wal_header_for_file(another_wal_name) + with open(tmp_path, "wb") as out_file: + out_file.write(wal) + + download_result = CallbackEvent(success=True, payload={"target_path": tmp_path}, opaque=wal_name) + pending_op = PendingDownloadOp( + started_at=time.monotonic(), target_path=target_path, filetype="xlog", filename=wal_name + ) + drp.pending_download_ops[wal_name] = pending_op + drp.download_results.put(download_result) + return tmp_path, target_path + + def init_download_results_processor(self): + return DownloadResultsProcessor(threading.RLock(), logging.getLogger("WebServer"), Queue(), {}, []) + + def test_rename_valid_wal(self, tmpdir): + drp = self.init_download_results_processor() + tmp_path, target_path = self.save_wal_and_dowload_callback(tmpdir, drp) + download_result_item = drp.download_results.get() + drp.process_queue_item(download_result_item) + assert os.path.exists(target_path) + assert not os.path.exists(tmp_path) + + def test_dont_save_invalid_wal(self, tmpdir): + drp = self.init_download_results_processor() + tmp_path, target_path = self.save_wal_and_dowload_callback(tmpdir, drp, is_valid_wal=False) + download_result_item = drp.download_results.get() + drp.process_queue_item(download_result_item) + assert not os.path.exists(target_path) + assert not os.path.exists(tmp_path) + + def test_skip_not_pending_op(self, tmpdir): + drp = self.init_download_results_processor() + tmp_path, target_path = self.save_wal_and_dowload_callback(tmpdir, drp) + download_result_item = drp.download_results.get() + drp.pending_download_ops = {} + drp.process_queue_item(download_result_item) + assert not os.path.exists(target_path) + assert not os.path.exists(tmp_path) + + def test_dont_overwrite_existing_target_file(self, tmpdir): + drp = self.init_download_results_processor() + tmp_path, target_path = self.save_wal_and_dowload_callback(tmpdir, drp) + existing_file_data = b"-" + with open(target_path, "wb") as out_file: + out_file.write(existing_file_data) + assert os.path.exists(target_path) + assert os.path.exists(tmp_path) + + download_result_item = drp.download_results.get() + drp.process_queue_item(download_result_item) + assert os.path.exists(target_path) + assert open(target_path, "rb").read() == existing_file_data + assert os.path.exists(tmp_path)