diff --git a/README.md b/README.md index 95bccc0..2360c5f 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,11 @@ modules: # Defaults to false. accept_invites_only_from_local_users: false + # Optional: if set to true, invites for suspended users will be auto + # accepted. + # Defaults to false. + accept_invites_for_suspended_users: false + # (For workerised Synapse deployments) # # This module should only be active on a single worker process at once, diff --git a/synapse_auto_accept_invite/__init__.py b/synapse_auto_accept_invite/__init__.py index 8983181..88b0d02 100644 --- a/synapse_auto_accept_invite/__init__.py +++ b/synapse_auto_accept_invite/__init__.py @@ -25,6 +25,7 @@ class InviteAutoAccepterConfig: accept_invites_only_for_direct_messages: bool = False accept_invites_only_from_local_users: bool = False + accept_invites_for_suspended_users: bool = False worker_to_run_on: Optional[str] = None @@ -70,12 +71,16 @@ def parse_config(config: Dict[str, Any]) -> InviteAutoAccepterConfig: accept_invites_only_from_local_users = config.get( "accept_invites_only_from_local_users", False ) + accept_invites_for_suspended_users = config.get( + "accept_invites_for_suspended_users", False + ) worker_to_run_on = config.get("worker_to_run_on", None) return InviteAutoAccepterConfig( accept_invites_only_for_direct_messages=accept_invites_only_for_direct_messages, accept_invites_only_from_local_users=accept_invites_only_from_local_users, + accept_invites_for_suspended_users=accept_invites_for_suspended_users, worker_to_run_on=worker_to_run_on, ) @@ -86,50 +91,64 @@ async def on_new_event(self, event: EventBase, *args: Any) -> None: Args: event: The incoming event. """ + # Check if the event is an invite for a local user. - is_invite_for_local_user = ( + if not ( event.type == "m.room.member" and event.is_state() and event.membership == "invite" and self._api.is_mine(event.state_key) - ) + ): + return # Only accept invites for direct messages if the configuration mandates it. is_direct_message = event.content.get("is_direct", False) - is_allowed_by_direct_message_rules = ( - not self._config.accept_invites_only_for_direct_messages - or is_direct_message is True - ) + if ( + self._config.accept_invites_only_for_direct_messages + and is_direct_message is False + ): + return # Only accept invites from remote users if the configuration mandates it. is_from_local_user = self._api.is_mine(event.sender) - is_allowed_by_local_user_rules = ( - not self._config.accept_invites_only_from_local_users - or is_from_local_user is True - ) - if ( - is_invite_for_local_user - and is_allowed_by_direct_message_rules - and is_allowed_by_local_user_rules + self._config.accept_invites_only_from_local_users + and is_from_local_user is False ): - # Make the user join the room. We run this as a background process to circumvent a race condition - # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12) - run_as_background_process( - "retry_make_join", - self._retry_make_join, - event.state_key, - event.state_key, - event.room_id, - "join", - bg_start_span=False, - ) + return - if is_direct_message: - # Mark this room as a direct message! - await self._mark_room_as_direct_message( - event.state_key, event.sender, event.room_id - ) + # Check the user is activated. + recipient = await self._api.get_userinfo_by_id(event.state_key) + + # Ignore if the user doesn't exist. + if recipient is None: + return + + # Never accept invites for deactivated users. + if recipient.is_deactivated: + return + + # Only accept invites for suspended remote if the configuration mandates it. + if not self._config.accept_invites_for_suspended_users and recipient.suspended: + return + + # Make the user join the room. We run this as a background process to circumvent a race condition + # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12) + run_as_background_process( + "retry_make_join", + self._retry_make_join, + event.state_key, + event.state_key, + event.room_id, + "join", + bg_start_span=False, + ) + + if is_direct_message: + # Mark this room as a direct message! + await self._mark_room_as_direct_message( + event.state_key, event.sender, event.room_id + ) async def _mark_room_as_direct_message( self, user_id: str, dm_user_id: str, room_id: str diff --git a/tests/__init__.py b/tests/__init__.py index abb3ad5..f73ef94 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -18,6 +18,7 @@ import attr from synapse.module_api import ModuleApi +from synapse.types import UserID, UserInfo from synapse_auto_accept_invite import InviteAutoAccepter @@ -73,6 +74,22 @@ def create_module( module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test" module_api.worker_name = worker_name module_api.sleep.return_value = make_multiple_awaitable(None) + module_api.get_userinfo_by_id.return_value = UserInfo( + user_id=UserID.from_string("@user:test"), + is_admin=False, + is_guest=False, + consent_server_notice_sent=None, + consent_ts=None, + consent_version=None, + appservice_id=None, + creation_ts=0, + user_type=None, + is_deactivated=False, + locked=False, + is_shadow_banned=False, + approved=True, + suspended=False, + ) config = InviteAutoAccepter.parse_config(config_override) diff --git a/tests/test_accept_invite.py b/tests/test_accept_invite.py index 1fb55b7..eaa49fd 100644 --- a/tests/test_accept_invite.py +++ b/tests/test_accept_invite.py @@ -17,6 +17,7 @@ import aiounittest from frozendict import frozendict +from synapse.types import UserID, UserInfo from synapse_auto_accept_invite import InviteAutoAccepter from tests import MockEvent, create_module, make_awaitable @@ -31,6 +32,7 @@ def setUp(self) -> None: # We know our module API is a mock, but mypy doesn't. self.mocked_update_membership: Mock = self.module._api.update_room_membership # type: ignore[assignment] + self.get_userinfo_by_id: Mock = self.module._api.get_userinfo_by_id # type: ignore[assignment] async def test_simple_accept_invite(self) -> None: """Tests that receiving an invite for a local user makes the module attempt to @@ -453,6 +455,138 @@ def test_runs_on_only_one_worker(self) -> None: Mock, specified_module._api.register_third_party_rules_callbacks ).assert_called_once() + async def test_ignore_invite_from_missing_user(self) -> None: + """Tests that receiving an invite for a missing user is ignored.""" + invite = MockEvent( + sender=self.remote_invitee, + state_key=self.invitee, + type="m.room.member", + content={"membership": "invite"}, + ) + self.get_userinfo_by_id.return_value = None + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await self.module.on_new_event(event=invite) # type: ignore[arg-type] + self.mocked_update_membership.assert_not_called() + + async def test_ignore_invite_from_deactivated_user(self) -> None: + """Tests that receiving an invite for a deactivated user is ignored.""" + invite = MockEvent( + sender=self.remote_invitee, + state_key=self.invitee, + type="m.room.member", + content={"membership": "invite"}, + ) + self.get_userinfo_by_id.return_value = UserInfo( + user_id=UserID.from_string("@user:test"), + is_admin=False, + is_guest=False, + consent_server_notice_sent=None, + consent_ts=None, + consent_version=None, + appservice_id=None, + creation_ts=0, + user_type=None, + is_deactivated=True, + locked=False, + is_shadow_banned=False, + approved=True, + suspended=False, + ) + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await self.module.on_new_event(event=invite) # type: ignore[arg-type] + self.mocked_update_membership.assert_not_called() + + async def test_ignore_invite_from_suspended_user(self) -> None: + """Tests that receiving an invite for a suspended user is ignored by default.""" + invite = MockEvent( + sender=self.remote_invitee, + state_key=self.invitee, + type="m.room.member", + content={"membership": "invite"}, + ) + self.get_userinfo_by_id.return_value = UserInfo( + user_id=UserID.from_string("@user:test"), + is_admin=False, + is_guest=False, + consent_server_notice_sent=None, + consent_ts=None, + consent_version=None, + appservice_id=None, + creation_ts=0, + user_type=None, + is_deactivated=False, + locked=False, + is_shadow_banned=False, + approved=True, + suspended=True, + ) + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await self.module.on_new_event(event=invite) # type: ignore[arg-type] + self.mocked_update_membership.assert_not_called() + + async def test_accept_invite_for_suspended_user_if_enabled( + self, + ) -> None: + """Tests that, if the module is configured to accept invites for suspended users, invites + are still automatically accepted. + """ + module = create_module( + config_override={"accept_invites_for_suspended_users": True}, + ) + + get_userinfo_by_id: Mock = module._api.get_userinfo_by_id # type: ignore[assignment] + get_userinfo_by_id.return_value = UserInfo( + user_id=UserID.from_string("@user:test"), + is_admin=False, + is_guest=False, + consent_server_notice_sent=None, + consent_ts=None, + consent_version=None, + appservice_id=None, + creation_ts=0, + user_type=None, + is_deactivated=False, + locked=False, + is_shadow_banned=False, + approved=True, + suspended=True, + ) + + mocked_update_membership: Mock = module._api.update_room_membership # type: ignore[assignment] + join_event = MockEvent( + sender="someone", + state_key="someone", + type="m.room.member", + content={"membership": "join"}, + ) + mocked_update_membership.return_value = make_awaitable(join_event) + + invite = MockEvent( + sender=self.user_id, + state_key=self.invitee, + type="m.room.member", + content={"membership": "invite"}, + ) + + # Stop mypy from complaining that we give on_new_event a MockEvent rather than an + # EventBase. + await module.on_new_event(event=invite) # type: ignore[arg-type] + + await self.retry_assertions( + mocked_update_membership, + 1, + sender=invite.state_key, + target=invite.state_key, + room_id=invite.room_id, + new_membership="join", + ) + async def retry_assertions( self, mock: Mock, call_count: int, **kwargs: Any ) -> None: