Skip to content

Commit

Permalink
Issuing a watchdog to ensure our issuing thread is still alive (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
blast-hardcheese authored Apr 4, 2024
1 parent df29bbe commit bbbaf0f
Showing 1 changed file with 48 additions and 2 deletions.
50 changes: 48 additions & 2 deletions src/replit/database/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Async and dict-like interfaces for interacting with Replit Database."""

import asyncio
from collections import abc
import json
import threading
Expand Down Expand Up @@ -82,8 +83,17 @@ class AsyncDatabase:
:param unbind Callable: Permit additional behavior after Database close
"""

__slots__ = ("db_url", "sess", "client", "_get_db_url", "_unbind", "_refresh_timer")
__slots__ = (
"db_url",
"sess",
"client",
"_get_db_url",
"_unbind",
"_refresh_timer",
"_watchdog_timer",
)
_refresh_timer: Optional[threading.Timer]
_watchdog_timer: Optional[threading.Timer]

def __init__(
self,
Expand Down Expand Up @@ -113,6 +123,9 @@ def __init__(
if self._get_db_url:
self._refresh_timer = threading.Timer(3600, self._refresh_db)
self._refresh_timer.start()
watched_thread = threading.main_thread()
self._watchdog_timer = threading.Timer(1, self._watchdog, args=[watched_thread])
self._watchdog_timer.start()

def _refresh_db(self) -> None:
if self._refresh_timer:
Expand All @@ -125,6 +138,14 @@ def _refresh_db(self) -> None:
self._refresh_timer = threading.Timer(3600, self._refresh_db)
self._refresh_timer.start()

def _watchdog(self, watched_thread: threading.Thread) -> None:
if not watched_thread.is_alive():
return asyncio.run(self.close())
if self._watchdog_timer:
self._watchdog_timer.cancel()
self._watchdog_timer = threading.Timer(1, self._watchdog, args=[watched_thread])
self._watchdog_timer.start()

def update_db_url(self, db_url: str) -> None:
"""Update the database url.
Expand Down Expand Up @@ -292,6 +313,9 @@ async def close(self) -> None:
if self._refresh_timer:
self._refresh_timer.cancel()
self._refresh_timer = None
if self._watchdog_timer:
self._watchdog_timer.cancel()
self._watchdog_timer = None
if self._unbind:
# Permit signaling to surrounding scopes that we have closed
self._unbind()
Expand Down Expand Up @@ -485,8 +509,16 @@ class Database(abc.MutableMapping):
:param unbind Callable: Permit additional behavior after Database close
"""

__slots__ = ("db_url", "sess", "_get_db_url", "_unbind", "_refresh_timer")
__slots__ = (
"db_url",
"sess",
"_get_db_url",
"_unbind",
"_refresh_timer",
"_watchdog_timer",
)
_refresh_timer: Optional[threading.Timer]
_watchdog_timer: Optional[threading.Timer]

def __init__(
self,
Expand Down Expand Up @@ -518,6 +550,9 @@ def __init__(
if self._get_db_url:
self._refresh_timer = threading.Timer(3600, self._refresh_db)
self._refresh_timer.start()
watched_thread = threading.main_thread()
self._watchdog_timer = threading.Timer(1, self._watchdog, args=[watched_thread])
self._watchdog_timer.start()

def _refresh_db(self) -> None:
if self._refresh_timer:
Expand All @@ -530,6 +565,14 @@ def _refresh_db(self) -> None:
self._refresh_timer = threading.Timer(3600, self._refresh_db)
self._refresh_timer.start()

def _watchdog(self, watched_thread: threading.Thread) -> None:
if not watched_thread.is_alive():
return self.close()
if self._watchdog_timer:
self._watchdog_timer.cancel()
self._watchdog_timer = threading.Timer(1, self._watchdog, args=[watched_thread])
self._watchdog_timer.start()

def update_db_url(self, db_url: str) -> None:
"""Update the database url.
Expand Down Expand Up @@ -720,6 +763,9 @@ def close(self) -> None:
if self._refresh_timer:
self._refresh_timer.cancel()
self._refresh_timer = None
if self._watchdog_timer:
self._watchdog_timer.cancel()
self._watchdog_timer = None
if self._unbind:
# Permit signaling to surrounding scopes that we have closed
self._unbind()

0 comments on commit bbbaf0f

Please sign in to comment.