Skip to content

Commit

Permalink
Add check for primary endpoint ready (#46)
Browse files Browse the repository at this point in the history
* Add check for started service

* Add check for primary endpoint being ready

* Update charm unit tests

* Add unit test

* Fix unit test

* Make property private

* Improve primary endpoint check

* Revert change

* Add test for relation broken
  • Loading branch information
marceloneppel authored Sep 27, 2022
1 parent 946b48f commit 7b5f706
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 42 deletions.
46 changes: 37 additions & 9 deletions lib/charms/postgresql_k8s/v0/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Any charm using this library should import the `psycopg2` or `psycopg2-binary` dependency.
"""
import logging
from typing import Set

import psycopg2
from psycopg2 import sql
Expand All @@ -31,7 +32,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 4
LIBPATCH = 5


logger = logging.getLogger(__name__)
Expand All @@ -53,6 +54,10 @@ class PostgreSQLGetPostgreSQLVersionError(Exception):
"""Exception raised when retrieving PostgreSQL version fails."""


class PostgreSQLListUsersError(Exception):
"""Exception raised when retrieving PostgreSQL users list fails."""


class PostgreSQLUpdateUserPasswordError(Exception):
"""Exception raised when updating a user password fails."""

Expand Down Expand Up @@ -132,13 +137,18 @@ def create_user(
try:
with self._connect_to_database() as connection, connection.cursor() as cursor:
cursor.execute(f"SELECT TRUE FROM pg_roles WHERE rolname='{user}';")
user_definition = f"{user} WITH LOGIN{' SUPERUSER' if admin else ''} ENCRYPTED PASSWORD '{password}'"
user_definition = (
f"WITH LOGIN{' SUPERUSER' if admin else ''} ENCRYPTED PASSWORD '{password}'"
)
if extra_user_roles:
user_definition += f' {extra_user_roles.replace(",", " ")}'
if cursor.fetchone() is not None:
cursor.execute(f"ALTER ROLE {user_definition};")
statement = "ALTER ROLE {}"
else:
cursor.execute(f"CREATE ROLE {user_definition};")
statement = "CREATE ROLE {}"
cursor.execute(
sql.SQL(f"{statement} {user_definition};").format(sql.Identifier(user))
)
except psycopg2.Error as e:
logger.error(f"Failed to create user: {e}")
raise PostgreSQLCreateUserError()
Expand All @@ -161,12 +171,16 @@ def delete_user(self, user: str) -> None:
with self._connect_to_database(
database
) as connection, connection.cursor() as cursor:
cursor.execute(f"REASSIGN OWNED BY {user} TO postgres;")
cursor.execute(f"DROP OWNED BY {user};")
cursor.execute(
sql.SQL("REASSIGN OWNED BY {} TO {};").format(
sql.Identifier(user), sql.Identifier(self.user)
)
)
cursor.execute(sql.SQL("DROP OWNED BY {};").format(sql.Identifier(user)))

# Delete the user.
with self._connect_to_database() as connection, connection.cursor() as cursor:
cursor.execute(f"DROP ROLE {user};")
cursor.execute(sql.SQL("DROP ROLE {};").format(sql.Identifier(user)))
except psycopg2.Error as e:
logger.error(f"Failed to delete user: {e}")
raise PostgreSQLDeleteUserError()
Expand Down Expand Up @@ -202,11 +216,25 @@ def is_tls_enabled(self, check_current_host: bool = False) -> bool:
) as connection, connection.cursor() as cursor:
cursor.execute("SHOW ssl;")
return "on" in cursor.fetchone()[0]
except psycopg2.Error as e:
except psycopg2.Error:
# Connection errors happen when PostgreSQL has not started yet.
logger.debug(e.pgerror)
return False

def list_users(self) -> Set[str]:
"""Returns the list of PostgreSQL database users.
Returns:
List of PostgreSQL database users.
"""
try:
with self._connect_to_database() as connection, connection.cursor() as cursor:
cursor.execute("SELECT usename FROM pg_catalog.pg_user;")
usernames = cursor.fetchall()
return {username[0] for username in usernames}
except psycopg2.Error as e:
logger.error(f"Failed to list PostgreSQL database users: {e}")
raise PostgreSQLListUsersError()

def update_user_password(self, username: str, password: str) -> None:
"""Update a user password.
Expand Down
28 changes: 25 additions & 3 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,11 @@ def _on_postgresql_pebble_ready(self, event: WorkloadEvent) -> None:
self.unit.status = BlockedStatus(f"failed to patch pod with error {e}")
return

if not self._patroni.primary_endpoint_ready:
self.unit.status = WaitingStatus("awaiting for primary endpoint to be ready")
event.defer()
return

self._peers.data[self.app]["cluster_initialised"] = "True"

# Update the replication configuration.
Expand All @@ -416,7 +421,15 @@ def _on_postgresql_pebble_ready(self, event: WorkloadEvent) -> None:
self.unit.status = ActiveStatus()

def _on_upgrade_charm(self, _) -> None:
# Add labels required for replication when the pod loses them (like when it's deleted).
# Recreate k8s resources and add labels required for replication
# when the pod loses them (like when it's deleted).
try:
self._create_resources()
except ApiError:
logger.exception("failed to create k8s resources")
self.unit.status = BlockedStatus("failed to create k8s resources")
return

try:
self._patch_pod_labels(self.unit.name)
except ApiError as e:
Expand Down Expand Up @@ -581,7 +594,16 @@ def _on_stop(self, _) -> None:
logger.error(f"failed to delete resource: {resource}.")

def _on_update_status(self, _) -> None:
# Display an active status message if the current unit is the primary.
"""Display an active status message if the current unit is the primary."""
container = self.unit.get_container("postgresql")
if not container.can_connect():
return

services = container.pebble.get_services(names=[self._postgresql_service])
if len(services) == 0:
# Service has not been added nor started yet, so don't try to check Patroni API.
return

try:
if self._patroni.get_primary(unit_name_pattern=True) == self.unit.name:
self.unit.status = ActiveStatus("Primary")
Expand All @@ -594,8 +616,8 @@ def _patroni(self):
return Patroni(
self._endpoint,
self._endpoints,
self.primary_endpoint,
self._namespace,
self.app.planned_units(),
self._storage_path,
self.get_secret("app", USER_PASSWORD_KEY),
self.get_secret("app", REPLICATION_PASSWORD_KEY),
Expand Down
32 changes: 29 additions & 3 deletions src/patroni.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,30 @@ class NotReadyError(Exception):
"""Raised when not all cluster members healthy or finished initial sync."""


class EndpointNotReadyError(Exception):
"""Raised when an endpoint is not ready."""


class Patroni:
"""This class handles the communication with Patroni API and configuration files."""

def __init__(
self,
endpoint: str,
endpoints: List[str],
primary_endpoint: str,
namespace: str,
planned_units,
storage_path: str,
superuser_password: str,
replication_password: str,
tls_enabled: bool,
):
self._endpoint = endpoint
self._endpoints = endpoints
self._primary_endpoint = primary_endpoint
self._namespace = namespace
self._storage_path = storage_path
self._planned_units = planned_units
self._members_count = len(self._endpoints)
self._superuser_password = superuser_password
self._replication_password = replication_password
self._tls_enabled = tls_enabled
Expand Down Expand Up @@ -109,6 +114,27 @@ def are_all_members_ready(self) -> bool:

return all(member["state"] == "running" for member in r.json()["members"])

@property
def primary_endpoint_ready(self) -> bool:
"""Is the primary endpoint redirecting connections to the primary pod.
Returns:
Return whether the primary endpoint is redirecting connections to the primary pod.
"""
try:
for attempt in Retrying(stop=stop_after_delay(10), wait=wait_fixed(3)):
with attempt:
r = requests.get(
f"{'https' if self._tls_enabled else 'http'}://{self._primary_endpoint}:8008/health",
verify=self._verify,
)
if r.json()["state"] != "running":
raise EndpointNotReadyError
except RetryError:
return False

return True

@property
def member_started(self) -> bool:
"""Has the member started Patroni and PostgreSQL.
Expand Down Expand Up @@ -178,7 +204,7 @@ def render_postgresql_conf_file(self) -> None:
# TODO: add extra configurations here later.
rendered = template.render(
logging_collector="on",
synchronous_commit="on" if self._planned_units > 1 else "off",
synchronous_commit="on" if self._members_count > 1 else "off",
synchronous_standby_names="*",
)
self._render_file(f"{self._storage_path}/postgresql-k8s-operator.conf", rendered, 0o644)
Expand Down
6 changes: 5 additions & 1 deletion src/resources.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ metadata:
name: {{ app_name }}-primary
spec:
ports:
- port: 5432
- name: database
port: 5432
targetPort: 5432
- name: api
port: 8008
targetPort: 8008
selector:
cluster-name: patroni-{{ app_name }}
role: master
Expand Down
13 changes: 9 additions & 4 deletions tests/integration/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ async def check_database_users_existence(
users_that_should_exist: List[str],
users_that_should_not_exist: List[str],
admin: bool = False,
database_app_name: str = DATABASE_APP_NAME,
) -> None:
"""Checks that applications users exist in the database.
Expand All @@ -41,10 +42,12 @@ async def check_database_users_existence(
users_that_should_exist: List of users that should exist in the database
users_that_should_not_exist: List of users that should not exist in the database
admin: Whether to check if the existing users are superusers
database_app_name: Optional database app name
(the default value is the name on metadata.yaml)
"""
unit = ops_test.model.applications[DATABASE_APP_NAME].units[0]
unit = ops_test.model.applications[database_app_name].units[0]
unit_address = await get_unit_address(ops_test, unit.name)
password = await get_password(ops_test)
password = await get_password(ops_test, database_app_name=database_app_name)

# Retrieve all users in the database.
output = await execute_query_on_unit(
Expand Down Expand Up @@ -344,9 +347,11 @@ def get_expected_k8s_resources(namespace: str, application: str) -> set:
return resources


async def get_password(ops_test: OpsTest, username: str = "operator"):
async def get_password(
ops_test: OpsTest, username: str = "operator", database_app_name: str = DATABASE_APP_NAME
):
"""Retrieve a user password using the action."""
unit = ops_test.model.units.get(f"{DATABASE_APP_NAME}/0")
unit = ops_test.model.units.get(f"{database_app_name}/0")
action = await unit.run_action("get-password", **{"username": username})
result = await action.wait()
return result.results[f"{username}-password"]
Expand Down
45 changes: 31 additions & 14 deletions tests/integration/new_relations/test_new_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import yaml
from pytest_operator.plugin import OpsTest

from tests.integration.helpers import scale_application
from tests.integration.helpers import check_database_users_existence, scale_application
from tests.integration.new_relations.helpers import (
build_connection_string,
check_relation_data_existence,
Expand All @@ -32,8 +32,10 @@

@pytest.mark.abort_on_fail
@pytest.mark.database_relation_tests
async def test_deploy_charms(ops_test: OpsTest, application_charm, database_charm):
"""Deploy both charms (application and database) to use in the tests."""
async def test_database_relation_with_charm_libraries(
ops_test: OpsTest, application_charm, database_charm
):
"""Test basic functionality of database relation interface."""
# Deploy both charms (multiple units for each application to test that later they correctly
# set data in the relation application databag using only the leader unit).
async with ops_test.fast_forward():
Expand Down Expand Up @@ -66,17 +68,11 @@ async def test_deploy_charms(ops_test: OpsTest, application_charm, database_char
trust=True,
),
)
await ops_test.model.wait_for_idle(apps=APP_NAMES, status="active", wait_for_units=1)


@pytest.mark.database_relation_tests
async def test_database_relation_with_charm_libraries(ops_test: OpsTest):
"""Test basic functionality of database relation interface."""
# Relate the charms and wait for them exchanging some connection data.
await ops_test.model.add_relation(
f"{APPLICATION_APP_NAME}:{FIRST_DATABASE_RELATION_NAME}", DATABASE_APP_NAME
)
await ops_test.model.wait_for_idle(apps=APP_NAMES, status="active")
# Relate the charms and wait for them exchanging some connection data.
await ops_test.model.add_relation(
f"{APPLICATION_APP_NAME}:{FIRST_DATABASE_RELATION_NAME}", DATABASE_APP_NAME
)
await ops_test.model.wait_for_idle(apps=APP_NAMES, status="active", raise_on_blocked=True)

# Get the connection string to connect to the database using the read/write endpoint.
connection_string = await build_connection_string(
Expand Down Expand Up @@ -303,3 +299,24 @@ async def test_read_only_endpoint_in_scaled_up_cluster(ops_test: OpsTest):
"read-only-endpoints",
exists=True,
)


@pytest.mark.database_relation_tests
async def test_relation_broken(ops_test: OpsTest):
"""Test that the user is removed when the relation is broken."""
async with ops_test.fast_forward():
# Retrieve the relation user.
relation_user = await get_application_relation_data(
ops_test, APPLICATION_APP_NAME, FIRST_DATABASE_RELATION_NAME, "username"
)

# Break the relation.
await ops_test.model.applications[DATABASE_APP_NAME].remove_relation(
f"{DATABASE_APP_NAME}", f"{APPLICATION_APP_NAME}:{FIRST_DATABASE_RELATION_NAME}"
)
await ops_test.model.wait_for_idle(apps=APP_NAMES, status="active", raise_on_blocked=True)

# Check that the relation user was removed from the database.
await check_database_users_existence(
ops_test, [], [relation_user], database_app_name=DATABASE_APP_NAME
)
Loading

0 comments on commit 7b5f706

Please sign in to comment.