Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure correct use_* values when matching tables have no data #143

Closed
8 changes: 7 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,20 @@ Changelog of threedi-schema



0.228.1 (unreleased)
0.228.2 (unreleased)
--------------------

- Rename sqlite table "tags" to "tag"
- Remove indices referring to removed tables in previous migrations
- Remove columns referencing v2 in geometry_column
- Ensure correct use_* values when matching tables have no data


0.228.1 (2024-11-26)
--------------------

- Add `progress_func` argument to schema.upgrade


0.228.0 (2024-11-25)
--------------------
Expand Down
2 changes: 1 addition & 1 deletion threedi_schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from .domain import constants, custom_types, models # NOQA

# fmt: off
__version__ = '0.228.1.dev0'
__version__ = '0.228.2.dev0'

# fmt: on
18 changes: 14 additions & 4 deletions threedi_schema/application/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..infrastructure.spatial_index import ensure_spatial_indexes
from ..infrastructure.spatialite_versions import copy_models, get_spatialite_version
from .errors import MigrationMissingError, UpgradeFailedError
from .upgrade_utils import setup_logging

__all__ = ["ModelSchema"]

Expand All @@ -39,11 +40,12 @@ def get_schema_version():
return int(env.get_head_revision())


def _upgrade_database(db, revision="head", unsafe=True):
def _upgrade_database(db, revision="head", unsafe=True, progress_func=None):
"""Upgrade ThreediDatabase instance"""
engine = db.engine

config = get_alembic_config(engine, unsafe=unsafe)
if progress_func is not None:
setup_logging(db.schema, revision, config, progress_func)
alembic_command.upgrade(config, revision)


Expand Down Expand Up @@ -88,6 +90,7 @@ def upgrade(
backup=True,
upgrade_spatialite_version=False,
convert_to_geopackage=False,
progress_func=None,
):
"""Upgrade the database to the latest version.

Expand All @@ -106,6 +109,9 @@ def upgrade(

Specify 'convert_to_geopackage=True' to also convert from spatialite
to geopackage file version after the upgrade.

Specify a 'progress_func' to handle progress updates. `progress_func` should
expect a single argument representing the fraction of progress
"""
try:
rev_nr = get_schema_version() if revision == "head" else int(revision)
Expand All @@ -127,9 +133,13 @@ def upgrade(
)
if backup:
with self.db.file_transaction() as work_db:
_upgrade_database(work_db, revision=revision, unsafe=True)
_upgrade_database(
work_db, revision=revision, unsafe=True, progress_func=progress_func
)
else:
_upgrade_database(self.db, revision=revision, unsafe=False)
_upgrade_database(
self.db, revision=revision, unsafe=False, progress_func=progress_func
)
if upgrade_spatialite_version:
self.upgrade_spatialite_version()
elif convert_to_geopackage:
Expand Down
81 changes: 81 additions & 0 deletions threedi_schema/application/upgrade_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import logging
from typing import Callable, TYPE_CHECKING

from alembic.config import Config
from alembic.script import ScriptDirectory

if TYPE_CHECKING:
from .schema import ModelSchema
else:
ModelSchema = None


class ProgressHandler(logging.Handler):
def __init__(self, progress_func, total_steps):
super().__init__()
self.progress_func = progress_func
self.total_steps = total_steps
self.current_step = 0

def emit(self, record):
msg = record.getMessage()
if msg.startswith("Running upgrade"):
self.progress_func(100 * self.current_step / self.total_steps)
self.current_step += 1


def get_upgrade_steps_count(
config: Config, current_revision: int, target_revision: str = "head"
) -> int:
"""
Count number of upgrade steps for a schematisation upgrade.

Args:
config: Config parameter containing the configuration information
current_revision: current revision as integer
target_revision: target revision as zero-padded 4 digit string or "head"
"""
if target_revision != "head":
try:
int(target_revision)
except TypeError:
# this should lead to issues in the upgrade pipeline, lets not take over that error handling here
return 0
# walk_revisions also includes the revision from current_revision to previous
# reduce the number of steps with 1
offset = -1
# The first defined revision is 200; revision numbers < 200 will cause walk_revisions to fail
if current_revision < 200:
current_revision = 200
# set offset to 0 because previous to current is not included in walk_revisions
offset = 0
if target_revision != "head" and int(target_revision) < current_revision:
# assume that this will be correctly handled by alembic
return 0
current_revision_str = f"{current_revision:04d}"
script = ScriptDirectory.from_config(config)
# Determine upgrade steps
revisions = script.walk_revisions(current_revision_str, target_revision)
return len(list(revisions)) + offset


def setup_logging(
schema: ModelSchema,
target_revision: str,
config: Config,
progress_func: Callable[[float], None],
):
"""
Set up logging for schematisation upgrade

Args:
schema: ModelSchema object representing the current schema of the application
target_revision: A str specifying the target revision for migration
config: Config object containing configuration settings
progress_func: A Callable with a single argument of type float, used to track progress during migration
"""
n_steps = get_upgrade_steps_count(config, schema.get_version(), target_revision)
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
handler = ProgressHandler(progress_func, total_steps=n_steps)
logger.addHandler(handler)
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ def set_use_inteception():
);
"""))

op.execute(sa.text("""
DELETE FROM interception
WHERE (interception IS NULL OR interception = '')
AND (interception_file IS NULL OR interception_file = '');
"""))


def delete_all_but_matching_id(table, settings_id):
op.execute(f"DELETE FROM {table} WHERE id NOT IN (SELECT {settings_id} FROM model_settings);")
Expand Down
11 changes: 11 additions & 0 deletions threedi_schema/migrations/versions/0223_upgrade_db_inflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,17 @@ def populate_surface_and_dry_weather_flow():
# Populate tables with default values
populate_dry_weather_flow_distribution()
populate_surface_parameters()
update_use_0d_inflow()

def update_use_0d_inflow():
op.execute(sa.text("""
UPDATE simulation_template_settings
SET use_0d_inflow = 0
WHERE
(SELECT COUNT(*) FROM surface) = 0
AND
(SELECT COUNT(*) FROM dry_weather_flow) = 0;
"""))


def set_surface_parameters_id():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,17 @@ def fix_geometry_columns():
op.execute(sa.text(migration_query))


def update_use_structure_control():
op.execute("""
UPDATE simulation_template_settings SET use_structure_control = CASE
WHEN
(SELECT COUNT(*) FROM table_control) = 0 AND
(SELECT COUNT(*) FROM memory_control) = 0 THEN 0
ELSE use_structure_control
END;
""")


def upgrade():
# Remove existing tables (outside of the specs) that conflict with new table names
drop_conflicting(op, list(ADD_TABLES.keys()) + [new_name for _, new_name in RENAME_TABLES])
Expand All @@ -395,6 +406,7 @@ def upgrade():
rename_measure_operator('memory_control')
move_setting('model_settings', 'use_structure_control',
'simulation_template_settings', 'use_structure_control')
update_use_structure_control()
remove_tables(DEL_TABLES)
# Fix geometry columns and also make all but geom column nullable
fix_geometry_columns()
Expand Down
37 changes: 37 additions & 0 deletions threedi_schema/migrations/versions/0229_clean_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import sqlalchemy as sa
from alembic import op

from threedi_schema import models

# revision identifiers, used by Alembic.
revision = "0229"
down_revision = "0228"
Expand Down Expand Up @@ -60,6 +62,41 @@ def upgrade():
clean_geometry_columns()
clean_triggers()

def update_use_settings():
# Ensure that use_* settings are only True when there is actual data for them
use_settings = [
(models.ModelSettings.use_groundwater_storage, models.GroundWater),
(models.ModelSettings.use_groundwater_flow, models.GroundWater),
(models.ModelSettings.use_interflow, models.Interflow),
(models.ModelSettings.use_simple_infiltration, models.SimpleInfiltration),
(models.ModelSettings.use_vegetation_drag_2d, models.VegetationDrag),
(models.ModelSettings.use_interception, models.Interception)
]
connection = op.get_bind() # Get the connection for raw SQL execution
for setting, table in use_settings:
use_row = connection.execute(
sa.select(getattr(models.ModelSettings, setting.name))
).scalar()
if not use_row:
continue
row = connection.execute(sa.select(table)).first()
use_row = (row is not None)
if use_row:
use_row = not all(
getattr(row, column.name) in (None, "")
for column in table.__table__.columns
if column.name != "id"
)
if not use_row:
connection.execute(
sa.update(models.ModelSettings)
.values({setting.name: False})
)


def upgrade():
update_use_settings()


def downgrade():
pass
56 changes: 56 additions & 0 deletions threedi_schema/tests/test_upgrade_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging
from pathlib import Path
from unittest.mock import call, MagicMock

import pytest

from threedi_schema.application import upgrade_utils
from threedi_schema.application.schema import get_alembic_config
from threedi_schema.application.threedi_database import ThreediDatabase

data_dir = Path(__file__).parent / "data"


def test_progress_handler():
progress_func = MagicMock()
mock_record = MagicMock(levelno=logging.INFO, percent=40)
expected_calls = [call(100 * i / 5) for i in range(5)]
handler = upgrade_utils.ProgressHandler(progress_func, total_steps=5)
for _ in range(5):
handler.handle(mock_record)
assert progress_func.call_args_list == expected_calls


@pytest.mark.parametrize(
"target_revision, nsteps_expected", [("0226", 5), ("0200", 0), (None, 0)]
)
def test_get_upgrade_steps_count(target_revision, nsteps_expected):
schema = ThreediDatabase(data_dir.joinpath("v2_bergermeer_221.sqlite")).schema
nsteps = upgrade_utils.get_upgrade_steps_count(
config=get_alembic_config(),
current_revision=schema.get_version(),
target_revision=target_revision,
)
assert nsteps == nsteps_expected


def test_get_upgrade_steps_count_pre_200(oldest_sqlite):
schema = oldest_sqlite.schema
nsteps = upgrade_utils.get_upgrade_steps_count(
config=get_alembic_config(),
current_revision=schema.get_version(),
target_revision="0226",
)
assert nsteps == 27


def test_upgrade_with_progress_func(oldest_sqlite):
schema = oldest_sqlite.schema
progress_func = MagicMock()
schema.upgrade(
backup=False,
upgrade_spatialite_version=False,
progress_func=progress_func,
revision="0201",
)
assert progress_func.call_args_list == [call(0.0), call(50.0)]