From 0d36142fd6b90268163ab8e62e9a1014f73790ca Mon Sep 17 00:00:00 2001 From: Doggie B <3859395+fubuloubu@users.noreply.github.com> Date: Mon, 28 Oct 2024 18:45:43 -0400 Subject: [PATCH] feat(runner): add BacktestRunner w/ `silverback test` command --- setup.py | 1 + silverback/__init__.py | 1 + silverback/_cli.py | 27 +++++++++ silverback/pytest.py | 117 ++++++++++++++++++++++++++++++++++++++ silverback/runner.py | 64 +++++++++++++++++++++ tests/backtest_merge.yaml | 5 ++ 6 files changed, 215 insertions(+) create mode 100644 silverback/pytest.py create mode 100644 tests/backtest_merge.yaml diff --git a/setup.py b/setup.py index 7c6393e7..55b2a5c6 100644 --- a/setup.py +++ b/setup.py @@ -75,6 +75,7 @@ ], entry_points={ "console_scripts": ["silverback=silverback._cli:cli"], + "pytest11": ["silverback_test=silverback.pytest"], }, python_requires=">=3.10,<4", extras_require=extras_require, diff --git a/silverback/__init__.py b/silverback/__init__.py index 1f55c662..75a3e070 100644 --- a/silverback/__init__.py +++ b/silverback/__init__.py @@ -22,6 +22,7 @@ def __getattr__(name: str): __all__ = [ "StateSnapshot", + "BacktestRunner", "CircuitBreaker", "SilverbackBot", "SilverbackException", diff --git a/silverback/_cli.py b/silverback/_cli.py index 043ff5e6..3dd83634 100644 --- a/silverback/_cli.py +++ b/silverback/_cli.py @@ -1,10 +1,12 @@ import asyncio import os +import sys from datetime import datetime, timedelta, timezone from pathlib import Path from typing import TYPE_CHECKING, Optional import click +import pytest import yaml # type: ignore[import-untyped] from ape.cli import ( AccountAliasPromptChoice, @@ -13,6 +15,7 @@ account_option, ape_cli_context, network_option, + verbosity_option, ) from ape.exceptions import Abort, ApeException from ape.logging import LogLevel @@ -171,6 +174,30 @@ def worker(cli_ctx, account, workers, max_exceptions, shutdown_timeout, bot): asyncio.run(run_worker(bot.broker, worker_count=workers, shutdown_timeout=shutdown_timeout)) +@cli.command( + section="Local Commands", + add_help_option=False, # NOTE: This allows pass-through to pytest's help + short_help="Run bot backtests (`tests/backtest_*.yaml`)", + context_settings=dict(ignore_unknown_options=True), +) +@ape_cli_context() +@verbosity_option() +@network_option( + default=os.environ.get("SILVERBACK_NETWORK_CHOICE", "auto"), + callback=_network_callback, +) +@click.option("--bot", "bots", multiple=True) +@click.argument("pytest_args", nargs=-1, type=click.UNPROCESSED) +def test(cli_ctx, network, bots, pytest_args): + os.environ["SILVERBACK_FORK_MODE"] = "1" + + return_code = pytest.main([*pytest_args], ["silverback.pytest"]) + + if return_code: + # only exit with non-zero status to make testing easier + sys.exit(return_code) + + @cli.command(section="Cloud Commands (https://silverback.apeworx.io)") @auth_required def login(auth: "FiefAuth"): diff --git a/silverback/pytest.py b/silverback/pytest.py new file mode 100644 index 00000000..46e2e5c8 --- /dev/null +++ b/silverback/pytest.py @@ -0,0 +1,117 @@ +import asyncio +import os +from pathlib import Path + +import pytest +import yaml # type: ignore[import] +from ape.utils import cached_property + +from silverback._importer import import_from_string +from silverback.exceptions import SilverbackException +from silverback.runner import BacktestRunner + + +class AssertionViolation(SilverbackException): + pass + + +def pytest_collect_file(parent, file_path): + if file_path.suffix == ".yaml" and file_path.name.startswith("backtest"): + return BacktestFile.from_parent(parent, path=file_path) + + +class BacktestFile(pytest.File): + def collect(self): + raw = yaml.safe_load(self.path.open()) + if not (network_triple := raw.get("network")): + raise ValueError(f"{self.path} is missing key 'network'.") + + start_block = raw.get("start_block", 0) + stop_block = raw.get("stop_block", -1) + assertion_checks = raw.get("assertions", {}) + + raw_bot_paths = raw.get("bots") + if isinstance(raw_bot_paths, list): + for bot_path in raw_bot_paths: + if ":" in bot_path: + bot_path, bot_name = bot_path.split(":") + bot_path = Path(bot_path) + else: + bot_path = Path(bot_path) + bot_name = "bot" + + yield BacktestItem.from_parent( + self, + name=f"{self.name}[{bot_name}]", + file_path=self.path, + bot_path=bot_path, + bot_name=bot_name, + network_triple=network_triple, + start_block=start_block, + stop_block=stop_block, + assertion_checks=assertion_checks, + ) + + else: + if ":" in raw_bot_paths: + bot_path, bot_name = raw_bot_paths.split(":") + bot_path = Path(bot_path) + else: + bot_path = Path(raw_bot_paths) + bot_name = "bot" + + yield BacktestItem.from_parent( + self, + name=self.name, + file_path=self.path, + bot_path=bot_path, + bot_name=bot_name, + network_triple=network_triple, + start_block=start_block, + stop_block=stop_block, + assertion_checks=assertion_checks, + ) + + +class BacktestItem(pytest.Item): + def __init__( + self, + *, + file_path, + bot_path, + bot_name, + network_triple, + start_block, + stop_block, + assertion_checks, + **kwargs, + ): + super().__init__(**kwargs) + self.file_path = file_path + self.bot_path = bot_path + self.bot_name = bot_name + self.network_triple = network_triple + self.start_block = start_block + self.stop_block = stop_block + self.assertion_checks = assertion_checks + + self.assertion_failures = 0 + self.overruns = 0 + + @cached_property + def runner(self): + os.environ["SILVERBACK_NETWORK_CHOICE"] = self.network_triple + os.environ["PYTHONPATH"] = str(self.bot_path.parent) + app = import_from_string(f"{self.bot_path.stem}:{self.bot_name}") + return BacktestRunner(app, start_block=self.start_block, stop_block=self.stop_block) + + def check_assertions(self, result: dict): + pass + + def runtest(self): + asyncio.run(self.runner.run()) + self.raise_run_status() + + def raise_run_status(self): + if self.overruns > 0 or self.assertion_failures > 0: + raise AssertionViolation() diff --git a/silverback/runner.py b/silverback/runner.py index 46987fa4..b3a69a90 100644 --- a/silverback/runner.py +++ b/silverback/runner.py @@ -5,6 +5,7 @@ from ape.logging import logger from ape.utils import ManagerAccessMixin from ape_ethereum.ecosystem import keccak +from click import progressbar from ethpm_types import EventABI from packaging.specifiers import SpecifierSet from packaging.version import Version @@ -420,3 +421,66 @@ async def _event_task(self, task_data: TaskData): await self._checkpoint(last_block_seen=event.block_number) await self._handle_task(await event_log_task_kicker.kiq(event)) await self._checkpoint(last_block_processed=event.block_number) + + +class BacktestRunner(BaseRunner): + def __init__( + self, + app: SilverbackBot, + start_block: int, + stop_block: int, + *args, + **kwargs, + ): + super().__init__(app, *args, **kwargs) + + # NOTE: Takes time to do the data collection + with progressbar( + chain.blocks.range(start_block, stop_block + 1), + length=(stop_block - start_block), + ) as blocks: + self.blocks = list(blocks) + + logger.info( + f"Using {self.__class__.__name__}:" + f" num_blocks={stop_block - start_block}" + f" max_exceptions={self.max_exceptions}" + ) + + async def _block_task(self, task_data: TaskData): + new_block_task_kicker = self._create_task_kicker(task_data) + + async for block in async_wrap_iter(iter(self.blocks)): + await self._checkpoint(last_block_seen=block.number) + await self._handle_task(await new_block_task_kicker.kiq(block)) + await self._checkpoint(last_block_processed=block.number) + + async def _event_task(self, task_data: TaskData): + if not (event_signature := task_data.labels.get("event_signature")): + raise StartupFailure("No Event Signature provided.") + + event_abi = EventABI.from_signature(event_signature) + + if not (contract_address := task_data.labels.get("contract_address")): + raise StartupFailure("Contract instance required.") + + if ( + not ( + events := chain.contracts.instance_at(contract_address)._events_.get(event_abi.name) + ) + or len(events) == 0 + ): + raise StartupFailure( + "Contract '{contract_address}' does not have event '{event_abi.name}'." + ) + + event_log_task_kicker = self._create_task_kicker(task_data) + + async for block in async_wrap_iter(iter(self.blocks)): + txn_hashes = iter(tx.txn_hash.hex() for tx in block.transactions) + receipts = map(chain.get_receipt, txn_hashes) + async for logs in async_wrap_iter(map(events[0].from_receipt, receipts)): + async for log in async_wrap_iter(iter(logs)): + await self._checkpoint(last_block_seen=log.block_number) + await self._handle_task(await event_log_task_kicker.kiq(log)) + await self._checkpoint(last_block_processed=log.block_number) diff --git a/tests/backtest_merge.yaml b/tests/backtest_merge.yaml new file mode 100644 index 00000000..67f16b7b --- /dev/null +++ b/tests/backtest_merge.yaml @@ -0,0 +1,5 @@ +bots: ["example"] +network: "ethereum:mainnet" +start_block: 15_338_009 +stop_block: 15_338_018 +something_else: blah