
Fix race in MultiWriterIdGenerator #11045

Merged 7 commits on Oct 12, 2021
1 change: 1 addition & 0 deletions changelog.d/11045.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug when using multiple event persister workers where events were not correctly sent down `/sync` due to a race.
82 changes: 67 additions & 15 deletions synapse/storage/util/id_generators.py
@@ -36,7 +36,7 @@
 )

 import attr
-from sortedcontainers import SortedSet
+from sortedcontainers import SortedList, SortedSet

 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import (
@@ -265,6 +265,15 @@ def __init__(
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids: SortedSet[int] = SortedSet()

+        # We also need to track when we've requested some new stream IDs but
+        # they haven't yet been added to the `_unfinished_ids` set. Every time
+        # we request a new stream ID we add the current max stream ID to the
+        # list, and remove it once we've added the newly allocated IDs to the
+        # `_unfinished_ids` set. This means that we *may* be allocated stream
+        # IDs above those in the list, and so we can't advance the local current
+        # position beyond the minimum stream ID in this list.
+        self._in_flight_fetches: SortedList[int] = SortedList()
+
         # Set of local IDs that we've processed that are larger than the current
         # position, due to there being smaller unpersisted IDs.
         self._finished_ids: Set[int] = set()
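To see what the new list buys, with made-up numbers: if a fetch begins while the max allocated stream ID seen so far is 5, then 5 is recorded in `_in_flight_fetches`; the fetch can only be allocated IDs strictly greater than 5, so the local position must stay capped at 5 until the fetch completes. A minimal sketch of that invariant (not the Synapse code itself):

```python
from sortedcontainers import SortedList

in_flight_fetches: SortedList = SortedList()

# A fetch begins while the max seen allocated stream ID is 5.
max_seen_allocated = 5
in_flight_fetches.add(max_seen_allocated)

# Any ID this fetch is allocated will be strictly greater than 5, so the
# local current position may not advance past the smallest recorded marker.
position_cap = in_flight_fetches[0]
assert position_cap == 5

# Once the fetched IDs (say 6 and 7) have been added to _unfinished_ids,
# the marker is removed and those IDs hold the position back instead.
in_flight_fetches.remove(max_seen_allocated)
```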
@@ -290,6 +299,9 @@ def __init__(
         )
         self._known_persisted_positions: List[int] = []

+        # The maximum stream ID that we have seen allocated across any writer.
+        self._max_seen_allocated_stream_id = 1
+
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)

         # We check that the table and sequence haven't diverged.
@@ -305,6 +317,10 @@ def __init__(
         # This goes and fills out the above state from the database.
         self._load_current_ids(db_conn, tables)

+        self._max_seen_allocated_stream_id = max(
+            self._current_positions.values(), default=1
+        )
+
     def _load_current_ids(
         self,
         db_conn: LoggingDatabaseConnection,
@@ -411,10 +427,32 @@ def _load_current_ids(
         cur.close()

     def _load_next_id_txn(self, txn: Cursor) -> int:
-        return self._sequence_gen.get_next_id_txn(txn)
+        stream_ids = self._load_next_mult_id_txn(txn, 1)
+        return stream_ids[0]

     def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
-        return self._sequence_gen.get_next_mult_txn(txn, n)
+        # We need to track that we've requested some more stream IDs, and what
+        # the current max allocated stream ID is. This is to prevent a race
+        # where we've been allocated stream IDs but they have not yet been added
+        # to the `_unfinished_ids` set, allowing the current position to advance
+        # past them.
+        with self._lock:
+            current_max = self._max_seen_allocated_stream_id
+            self._in_flight_fetches.add(current_max)
+
+        try:
+            stream_ids = self._sequence_gen.get_next_mult_txn(txn, n)
+
+            with self._lock:
+                self._unfinished_ids.update(stream_ids)
+                self._max_seen_allocated_stream_id = max(
+                    self._max_seen_allocated_stream_id, self._unfinished_ids[-1]
+                )
+        finally:
+            with self._lock:
+                self._in_flight_fetches.remove(current_max)
+
+        return stream_ids

     def get_next(self) -> AsyncContextManager[int]:
         """
@@ -463,9 +501,6 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:

         next_id = self._load_next_id_txn(txn)

-        with self._lock:
-            self._unfinished_ids.add(next_id)
-
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)

@@ -497,15 +532,27 @@ def _mark_id_as_finished(self, next_id: int) -> None:

         new_cur: Optional[int] = None

-        if self._unfinished_ids:
+        if self._unfinished_ids or self._in_flight_fetches:
             # If there are unfinished IDs then the new position will be the
-            # largest finished ID less than the minimum unfinished ID.
+            # largest finished ID strictly less than the minimum unfinished
+            # ID.
+
+            # The minimum unfinished ID needs to take account of both
+            # `_unfinished_ids` and `_in_flight_fetches`.
+            if self._unfinished_ids and self._in_flight_fetches:
+                # `_in_flight_fetches` stores the maximum safe stream ID, so
+                # we add one to make it equivalent to the minimum unsafe ID.
+                min_unfinished = min(
+                    self._unfinished_ids[0], self._in_flight_fetches[0] + 1
+                )
+            elif self._in_flight_fetches:
+                min_unfinished = self._in_flight_fetches[0] + 1
+            else:
+                min_unfinished = self._unfinished_ids[0]

             finished = set()

-            min_unfinshed = self._unfinished_ids[0]
             for s in self._finished_ids:
-                if s < min_unfinshed:
+                if s < min_unfinished:
                     if new_cur is None or new_cur < s:
                         new_cur = s
                 else:
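A quick worked check of the `min_unfinished` arithmetic, with made-up values: an entry in `_in_flight_fetches` is the maximum *safe* ID, so the `+ 1` converts it to the first ID that might still be handed out.

```python
# Hypothetical state: IDs 7 and 8 are registered as unfinished, and one
# fetch is in flight that began when the max seen allocated ID was 5.
unfinished_ids = [7, 8]       # stands in for self._unfinished_ids (sorted)
in_flight_fetches = [5]       # stands in for self._in_flight_fetches (sorted)

min_unfinished = min(unfinished_ids[0], in_flight_fetches[0] + 1)
assert min_unfinished == 6

# Finished IDs strictly below 6 may advance the position; 6 itself must
# wait, because the in-flight fetch may yet be allocated it.
```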
@@ -575,6 +622,10 @@ def advance(self, instance_name: str, new_id: int) -> None:
                 new_id, self._current_positions.get(instance_name, 0)
             )

+            self._max_seen_allocated_stream_id = max(
+                self._max_seen_allocated_stream_id, new_id
+            )
+
             self._add_persisted_position(new_id)

     def get_persisted_upto_position(self) -> int:
@@ -605,7 +656,11 @@ def _add_persisted_position(self, new_id: int) -> None:
         # to report a recent position when asked, rather than a potentially old
         # one (if this instance hasn't written anything for a while).
         our_current_position = self._current_positions.get(self._instance_name)
-        if our_current_position and not self._unfinished_ids:
+        if (
+            our_current_position
+            and not self._unfinished_ids
+            and not self._in_flight_fetches
+        ):
             self._current_positions[self._instance_name] = max(
                 our_current_position, new_id
             )
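The extra `not self._in_flight_fetches` clause matters even when `_unfinished_ids` is empty; with made-up numbers (a sketch mirroring the condition above, not the real method):

```python
# Hypothetical state for this writer:
our_current_position = 10
unfinished_ids: list = []     # nothing registered as unfinished yet ...
in_flight_fetches = [12]      # ... but a fetch began at max-seen 12

new_id = 15  # a position persisted by some other writer

# Without the in-flight check we would jump to 15, yet the pending fetch
# may still return ID 13, leaving it stranded behind our advertised
# position. With the check, the jump is skipped.
if our_current_position and not unfinished_ids and not in_flight_fetches:
    our_current_position = max(our_current_position, new_id)

assert our_current_position == 10
```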
@@ -697,9 +752,6 @@ async def __aenter__(self) -> Union[int, List[int]]:
                 db_autocommit=True,
             )

-        with self.id_gen._lock:
-            self.id_gen._unfinished_ids.update(self.stream_ids)
-
         if self.multiple_ids is None:
             return self.stream_ids[0] * self.id_gen._return_factor
         else:
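With registration moved into `_load_next_mult_id_txn`, under the same lock bracket as the sequence fetch, `__aenter__` no longer needs to touch `_unfinished_ids` itself. For context, callers consume IDs through the async context manager; a usage sketch, with `id_gen` standing for a configured `MultiWriterIdGenerator` and `persist` a hypothetical coroutine:

```python
async def persist_event(id_gen, persist) -> None:
    # The stream ID is marked finished when the block exits, on success
    # or on exception, which is what lets the current position advance.
    async with id_gen.get_next() as stream_id:
        await persist(stream_id)
```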