Commit 464cd8b: savepoint

odesenfans committed Oct 30, 2023 (1 parent: 61691ac)
Showing 4 changed files with 124 additions and 80 deletions.
7 changes: 7 additions & 0 deletions src/aleph/db/accessors/pending_messages.py
@@ -76,6 +76,13 @@ def get_pending_messages(
    return session.execute(select_stmt).scalars()


def get_pending_message(session: DbSession, pending_message_id: int) -> Optional[PendingMessageDb]:
    select_stmt = select(PendingMessageDb).where(
        PendingMessageDb.id == pending_message_id
    )
    return session.execute(select_stmt).scalar_one_or_none()


def count_pending_messages(session: DbSession, chain: Optional[Chain] = None) -> int:
"""
Counts pending messages.
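Note on the new accessor: scalar_one_or_none() returns None when no row matches, so callers have to handle a pending message that was already deleted, for example by a concurrent worker. A minimal usage sketch; session_factory and the surrounding function are illustrative, not part of this commit:

from aleph.db.accessors.pending_messages import get_pending_message


def describe_pending_message(session_factory, pending_message_id: int) -> str:
    # Look the row up by primary key in a short-lived session.
    with session_factory() as session:
        pending_message = get_pending_message(
            session=session, pending_message_id=pending_message_id
        )
    if pending_message is None:
        # Row already gone: another worker processed or rejected it.
        return f"pending message {pending_message_id}: already deleted"
    return f"pending message {pending_message_id}: {pending_message.item_hash}"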
123 changes: 65 additions & 58 deletions src/aleph/jobs/fetch_pending_messages.py
@@ -11,28 +11,30 @@
    AsyncIterator,
    Sequence,
    NewType,
    Optional,
)

import aio_pika.abc
from aleph_message.models import ItemHash
from configmanager import Config
from setproctitle import setproctitle

from ..chains.signature_verifier import SignatureVerifier
from aleph.db.accessors.pending_messages import (
    make_pending_message_fetched_statement,
    get_next_pending_messages,
    get_pending_message,
)
from aleph.db.connection import make_engine, make_session_factory
from aleph.db.models import PendingMessageDb, MessageDb
from aleph.db.models import MessageDb
from aleph.handlers.message_handler import MessageHandler
from aleph.services.ipfs import IpfsService
from aleph.services.ipfs.common import make_ipfs_client
from aleph.services.storage.fileystem_engine import FileSystemStorageEngine
from aleph.storage import StorageService
from aleph.toolkit.logging import setup_logging
from aleph.toolkit.monitoring import setup_sentry
from aleph.toolkit.timestamp import utc_now
from aleph.types.db_session import DbSessionFactory
from .job_utils import prepare_loop, MessageJob
from .job_utils import prepare_loop, MessageJob, make_pending_message_queue
from ..chains.signature_verifier import SignatureVerifier
from ..services.cache.node_cache import NodeCache

LOGGER = getLogger(__name__)
@@ -47,15 +49,23 @@ def __init__(
        session_factory: DbSessionFactory,
        message_handler: MessageHandler,
        max_retries: int,
        pending_message_queue: aio_pika.abc.AbstractQueue,
    ):
        super().__init__(
            session_factory=session_factory,
            message_handler=message_handler,
            max_retries=max_retries,
        )
        self.pending_message_queue = pending_message_queue

    async def fetch_pending_message(self, pending_message: PendingMessageDb):
    async def fetch_pending_message(
        self, pending_message_id: int
    ) -> Optional[MessageDb]:
        with self.session_factory() as session:
            pending_message = get_pending_message(
                session=session, pending_message_id=pending_message_id
            )

            try:
                message = await self.message_handler.verify_and_fetch(
                    session=session, pending_message=pending_message
@@ -76,6 +86,7 @@ async def fetch_pending_message(self, pending_message: PendingMessageDb):
                    exception=e,
                )
                session.commit()
                return None

    async def fetch_pending_messages(
        self, config: Config, node_cache: NodeCache, loop: bool = True
@@ -87,61 +98,55 @@ async def fetch_pending_messages(
        await node_cache.set(retry_messages_cache_key, 0)
        max_concurrent_tasks = config.aleph.jobs.pending_messages.max_concurrency.value
        fetch_tasks: Set[asyncio.Task] = set()
        task_message_dict: Dict[asyncio.Task, PendingMessageDb] = {}
        task_message_dict: Dict[asyncio.Task, ItemHash] = {}
        messages_being_fetched: Set[str] = set()
        fetched_messages: List[MessageDb] = []

        while True:
            with self.session_factory() as session:
                if fetch_tasks:
                    finished_tasks, fetch_tasks = await asyncio.wait(
                        fetch_tasks, return_when=asyncio.FIRST_COMPLETED
                    )
                    for finished_task in finished_tasks:
                        pending_message = task_message_dict.pop(finished_task)
                        messages_being_fetched.remove(pending_message.item_hash)
                        await node_cache.decr(retry_messages_cache_key)

                if len(fetch_tasks) < max_concurrent_tasks:
                    pending_messages = get_next_pending_messages(
                        session=session,
                        current_time=utc_now(),
                        limit=max_concurrent_tasks - len(fetch_tasks),
                        offset=len(fetch_tasks),
                        exclude_item_hashes=messages_being_fetched,
                        fetched=False,
                    )
            if fetch_tasks:
                finished_tasks, fetch_tasks = await asyncio.wait(
                    fetch_tasks, return_when=asyncio.FIRST_COMPLETED
                )
                for finished_task in finished_tasks:
                    pending_message_hash = task_message_dict.pop(finished_task)
                    messages_being_fetched.remove(pending_message_hash)
                    await node_cache.decr(retry_messages_cache_key)

            if len(fetch_tasks) < max_concurrent_tasks:
                for i in range(0, max_concurrent_tasks - len(fetch_tasks)):
                    message = await self.pending_message_queue.get(fail=False)
                    if message is None:
                        break

                    for pending_message in pending_messages:
                        # Avoid processing the same message twice at the same time.
                        if pending_message.item_hash in messages_being_fetched:
                    async with message.process(requeue=True, ignore_processed=True):
                        pending_message_hash = ItemHash(
                            message.routing_key.split(".")[1]
                        )
                        # Avoid fetching the same message twice at the same time.
                        if pending_message_hash in messages_being_fetched:
                            await message.reject(requeue=True)
                            continue

                        # Check if the message is already processing
                        messages_being_fetched.add(pending_message.item_hash)

                        messages_being_fetched.add(pending_message_hash)
                        await node_cache.incr(retry_messages_cache_key)

                        pending_message_id = int(message.body.decode("utf-8"))
                        message_task = asyncio.create_task(
                            self.fetch_pending_message(
                                pending_message=pending_message,
                                pending_message_id=pending_message_id,
                            )
                        )
                        fetch_tasks.add(message_task)
                        task_message_dict[message_task] = pending_message
                        task_message_dict[message_task] = pending_message_hash

                if fetched_messages:
                    yield fetched_messages
                    fetched_messages = []
            if fetched_messages:
                yield fetched_messages
                fetched_messages = []

                if not PendingMessageDb.count(session):
                    # If not in loop mode, stop if there are no more pending messages
                    if not loop:
                        break
                    # If we are done, wait a few seconds until retrying
            if not fetch_tasks:
                LOGGER.info("waiting 1 second(s) for new pending messages...")
                await asyncio.sleep(1)
            # If not in loop mode, stop if there are no more pending messages
            if not loop:
                if not messages_being_fetched:
                    break

    def make_pipeline(
        self,
@@ -179,27 +184,29 @@ async def fetch_messages_task(config: Config):
        storage_service=storage_service,
        config=config,
    )
    pending_message_queue = await make_pending_message_queue(
        config=config, routing_key="fetch.*"
    )
    fetcher = PendingMessageFetcher(
        session_factory=session_factory,
        message_handler=message_handler,
        max_retries=config.aleph.jobs.pending_messages.max_retries.value,
        pending_message_queue=pending_message_queue,
    )

    while True:
        with session_factory() as session:
            try:
                fetch_pipeline = fetcher.make_pipeline(
                    config=config, node_cache=node_cache
                )
                async for fetched_messages in fetch_pipeline:
                    for fetched_message in fetched_messages:
                        LOGGER.info(
                            "Successfully fetched %s", fetched_message.item_hash
                        )
        try:
            fetch_pipeline = fetcher.make_pipeline(
                config=config, node_cache=node_cache
            )
            async for fetched_messages in fetch_pipeline:
                for fetched_message in fetched_messages:
                    LOGGER.info(
                        "Successfully fetched %s", fetched_message.item_hash
                    )

            except Exception:
                LOGGER.exception("Error in pending messages job")
                session.rollback()
        except Exception:
            LOGGER.exception("Unexpected error in pending messages fetch job")

        LOGGER.debug("Waiting 1 second(s) for new pending messages...")
        await asyncio.sleep(1)
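The loop above fixes a small wire protocol: each queue message carries the PendingMessageDb.id as a UTF-8 encoded string in its body, and its routing key has the form fetch.<item_hash>, which is what the fetch.* binding matches and what message.routing_key.split(".")[1] recovers. This commit does not show the publishing side, so the following is only a sketch of a compatible publisher; the exchange handle and function name are assumptions:

import aio_pika


async def publish_fetch_request(
    exchange: aio_pika.abc.AbstractExchange,  # assumed: the pending-message topic exchange
    pending_message_id: int,
    item_hash: str,
) -> None:
    # Body: the DB id of the pending message, UTF-8 encoded.
    # Routing key: "fetch.<item_hash>", matched by the "fetch.*" binding.
    await exchange.publish(
        aio_pika.Message(body=str(pending_message_id).encode("utf-8")),
        routing_key=f"fetch.{item_hash}",
    )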
52 changes: 51 additions & 1 deletion src/aleph/jobs/job_utils.py
@@ -4,6 +4,7 @@
from typing import Dict, Union
from typing import Tuple

import aio_pika
from configmanager import Config
from sqlalchemy import update

@@ -28,6 +29,46 @@
MAX_RETRY_INTERVAL: int = 300


async def _make_pending_queue(
    config: Config, exchange_name: str, queue_name: str, routing_key: str
) -> aio_pika.abc.AbstractQueue:
    mq_conn = await aio_pika.connect_robust(
        host=config.p2p.mq_host.value,
        port=config.rabbitmq.port.value,
        login=config.rabbitmq.username.value,
        password=config.rabbitmq.password.value,
    )
    channel = await mq_conn.channel()
    exchange = await channel.declare_exchange(
        name=exchange_name,
        type=aio_pika.ExchangeType.TOPIC,
        auto_delete=False,
    )
    queue = await channel.declare_queue(
        name=queue_name, durable=True, auto_delete=False
    )
    await queue.bind(exchange, routing_key=routing_key)
    return queue


async def make_pending_tx_queue(config: Config) -> aio_pika.abc.AbstractQueue:
    return await _make_pending_queue(
        config=config,
        exchange_name=config.rabbitmq.pending_tx_exchange.value,
        queue_name="pending-tx-queue",
        routing_key="#",
    )


async def make_pending_message_queue(config: Config, routing_key: str) -> aio_pika.abc.AbstractQueue:
    return await _make_pending_queue(
        config=config,
        exchange_name=config.rabbitmq.pending_message_exchange.value,
        queue_name="pending_message_queue",
        routing_key=routing_key,
    )


def compute_next_retry_interval(attempts: int) -> dt.timedelta:
"""
Computes the time interval for the next attempt/retry of a message.
@@ -39,7 +80,7 @@ def compute_next_retry_interval(attempts: int) -> dt.timedelta:
    :return: The time interval between the previous processing attempt and the next one.
    """

    seconds = 2 ** attempts
    seconds = 2**attempts
    return dt.timedelta(seconds=min(seconds, MAX_RETRY_INTERVAL))
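For illustration, the backoff above grows as 2**attempts seconds and is capped at MAX_RETRY_INTERVAL (300 s), so retries wait 1 s, 2 s, 4 s, ... 256 s, then a flat 5 minutes from the ninth attempt on (2**9 = 512 > 300). A quick check against the function as committed:

from aleph.jobs.job_utils import compute_next_retry_interval

for attempts in (0, 1, 4, 8, 9, 20):
    print(attempts, compute_next_retry_interval(attempts))
# 0 -> 0:00:01, 1 -> 0:00:02, 4 -> 0:00:16, 8 -> 0:04:16, 9 and 20 -> 0:05:00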


@@ -92,11 +133,15 @@ def __init__(
        session_factory: DbSessionFactory,
        message_handler: MessageHandler,
        max_retries: int,
        pending_message_queue: aio_pika.abc.AbstractQueue,
    ):
        self.session_factory = session_factory
        self.message_handler = message_handler
        self.max_retries = max_retries

        self.task_lock = asyncio.Lock()
        self.retry_task = None

    def _handle_rejection(
        self,
        session: DbSession,
@@ -119,6 +164,11 @@ def _handle_rejection(

        return RejectedMessage(pending_message=pending_message, error_code=error_code)

    async def _reinsert_failed_messages(self):
        # Left unfinished in this savepoint commit; the body below is a
        # placeholder. Note that task_lock is an asyncio.Lock, so it has to be
        # entered with "async with" from a coroutine rather than called.
        async with self.task_lock:
            ...
    def _handle_retry(
        self,
        session: DbSession,
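To make the intended consumption pattern concrete: the fetcher declares its queue through make_pending_message_queue with a fetch.* routing key, then polls it with queue.get(fail=False), which returns None on an empty queue instead of raising aio_pika.exceptions.QueueEmpty. A minimal sketch, assuming a standard node Config; drain_one is illustrative, not part of this commit:

from configmanager import Config

from aleph.jobs.job_utils import make_pending_message_queue


async def drain_one(config: Config) -> None:
    # Declares the durable queue and binds it to the pending-message exchange.
    queue = await make_pending_message_queue(config=config, routing_key="fetch.*")
    # fail=False: return None when the queue is empty instead of raising.
    message = await queue.get(fail=False)
    if message is None:
        return
    # Ack on normal exit, requeue if the block raises.
    async with message.process(requeue=True):
        pending_message_id = int(message.body.decode("utf-8"))
        print("would fetch pending message", pending_message_id)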
22 changes: 1 addition & 21 deletions src/aleph/jobs/process_pending_txs.py
@@ -28,7 +28,7 @@
from aleph.toolkit.timestamp import utc_now
from aleph.types.chain_sync import ChainSyncProtocol
from aleph.types.db_session import DbSessionFactory
from .job_utils import prepare_loop
from .job_utils import prepare_loop, make_pending_tx_queue

LOGGER = logging.getLogger(__name__)

@@ -98,26 +98,6 @@ async def process_pending_txs(self) -> None:
        )


async def make_pending_tx_queue(config: Config) -> aio_pika.abc.AbstractQueue:
    mq_conn = await aio_pika.connect_robust(
        host=config.p2p.mq_host.value,
        port=config.rabbitmq.port.value,
        login=config.rabbitmq.username.value,
        password=config.rabbitmq.password.value,
    )
    channel = await mq_conn.channel()
    pending_tx_exchange = await channel.declare_exchange(
        name=config.rabbitmq.pending_tx_exchange.value,
        type=aio_pika.ExchangeType.TOPIC,
        auto_delete=False,
    )
    pending_tx_queue = await channel.declare_queue(
        name="pending-tx-queue", durable=True, auto_delete=False
    )
    await pending_tx_queue.bind(pending_tx_exchange, routing_key="#")
    return pending_tx_queue


async def handle_txs_task(config: Config):
    engine = make_engine(config=config, application_name="aleph-txs")
    session_factory = make_session_factory(engine)
