diff --git a/CHANGES.rst b/CHANGES.rst index 50c5767c..90cc5745 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -2,10 +2,10 @@ Changelog of threedi-schema =================================================== -0.230.1 (unreleased) +0.300 (unreleased) -------------------- -- Nothing changed yet. +- Convert spatialite to geopackage during upgrade 0.230.0 (2025-01-16) diff --git a/threedi_schema/application/schema.py b/threedi_schema/application/schema.py index 5b2d5bc4..6775ce53 100644 --- a/threedi_schema/application/schema.py +++ b/threedi_schema/application/schema.py @@ -117,12 +117,33 @@ def epsg_code(self): def epsg_source(self): return self._get_epsg_data()[1] + @property + def is_geopackage(self): + with self.db.get_session() as session: + return bool( + session.execute( + text( + "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='gpkg_contents';" + ) + ).scalar() + ) + + @property + def is_spatialite(self): + with self.db.get_session() as session: + return bool( + session.execute( + text( + "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='spatial_ref_sys';" + ) + ).scalar() + ) + def upgrade( self, revision="head", backup=True, upgrade_spatialite_version=False, - convert_to_geopackage=False, progress_func=None, custom_epsg_code=None, ): @@ -141,9 +162,6 @@ def upgrade( Specify 'upgrade_spatialite_version=True' to also upgrade the spatialite file version after the upgrade. - Specify 'convert_to_geopackage=True' to also convert from spatialite - to geopackage file version after the upgrade. - Specify a 'progress_func' to handle progress updates. `progress_func` should expect a single argument representing the fraction of progress @@ -156,11 +174,6 @@ def upgrade( raise ValueError( f"Incorrect version format: {revision}. Expected 'head' or a numeric value." ) - if convert_to_geopackage and rev_nr < 300: - raise UpgradeFailedError( - f"Cannot convert to geopackage for {revision=} because geopackage support is " - "enabled from revision 300", - ) v = self.get_version() if v is not None and v < constants.LATEST_SOUTH_MIGRATION_ID: raise MigrationMissingError( @@ -168,6 +181,22 @@ def upgrade( f"{constants.LATEST_SOUTH_MIGRATION_ID}. Please consult the " f"3Di documentation on how to update legacy databases." ) + if ( + v is not None + and v <= constants.LAST_SPTL_SCHEMA_VERSION + and not self.is_spatialite + ): + raise UpgradeFailedError( + f"Cannot upgrade from {revision=} because {self.db.path} is not a spatialite" + ) + elif ( + v is not None + and v > constants.LAST_SPTL_SCHEMA_VERSION + and not self.is_geopackage + ): + raise UpgradeFailedError( + f"Cannot upgrade from {revision=} because {self.db.path} is not a geopackage" + ) def run_upgrade(_revision): if backup: @@ -201,11 +230,20 @@ def run_upgrade(_revision): self._set_custom_epsg_code(custom_epsg_code) run_upgrade("0230") self._remove_custom_epsg_code() - run_upgrade(revision) - if upgrade_spatialite_version: + # First upgrade to LAST_SPTL_SCHEMA_VERSION. + # When the requested revision <= LAST_SPTL_SCHEMA_VERSION, this is the only upgrade step + run_upgrade( + revision + if rev_nr <= constants.LAST_SPTL_SCHEMA_VERSION + else f"{constants.LAST_SPTL_SCHEMA_VERSION:04d}" + ) + # only upgrade spatialite version is target revision is <= LAST_SPTL_SCHEMA_VERSION + if rev_nr <= constants.LAST_SPTL_SCHEMA_VERSION and upgrade_spatialite_version: self.upgrade_spatialite_version() - elif convert_to_geopackage: + # Finish upgrade if target revision > LAST_SPTL_SCHEMA_VERSION + elif rev_nr > constants.LAST_SPTL_SCHEMA_VERSION: self.convert_to_geopackage() + run_upgrade(revision) def _set_custom_epsg_code(self, custom_epsg_code: int): if ( @@ -269,7 +307,7 @@ def set_spatial_indexes(self): f"{schema_version}. Current version: {version}." ) - ensure_spatial_indexes(self.db, models.DECLARED_MODELS) + ensure_spatial_indexes(self.db.engine, models.DECLARED_MODELS) def upgrade_spatialite_version(self): """Upgrade the version of the spatialite file to the version of the @@ -282,7 +320,11 @@ def upgrade_spatialite_version(self): """ lib_version, file_version = get_spatialite_version(self.db) if file_version == 3 and lib_version in (4, 5): - self.validate_schema() + if self.get_version() != constants.LAST_SPTL_SCHEMA_VERSION: + raise MigrationMissingError( + f"This tool requires schema version " + f"{constants.LAST_SPTL_SCHEMA_VERSION:}. Current version: {self.get_version()}." + ) with self.db.file_transaction(start_empty=True) as work_db: rev_nr = min(get_schema_version(), 229) first_rev = f"{rev_nr:04d}" @@ -327,8 +369,15 @@ def convert_to_geopackage(self): return # Ensure database is upgraded and views are recreated - self.upgrade() - self.validate_schema() + revision = self.get_version() + if revision is None or revision <= constants.LAST_SPTL_SCHEMA_VERSION: + self.upgrade( + revision=f"{constants.LAST_SPTL_SCHEMA_VERSION:04d}", backup=False + ) + elif revision > constants.LAST_SPTL_SCHEMA_VERSION: + UpgradeFailedError( + f"Cannot convert schema version {revision} to geopackage" + ) # Make necessary modifications for conversion on temporary database with self.db.file_transaction(start_empty=False, copy_results=False) as work_db: # remove spatialite specific tables that break conversion @@ -410,3 +459,4 @@ def convert_to_geopackage(self): "CREATE TABLE views_geometry_columns(view_name TEXT, view_geometry TEXT, view_rowid TEXT, f_table_name VARCHAR(256), f_geometry_column VARCHAR(256))" ) ) + ensure_spatial_indexes(self.db.engine, models.DECLARED_MODELS) diff --git a/threedi_schema/domain/constants.py b/threedi_schema/domain/constants.py index 90e86a26..eda9d354 100644 --- a/threedi_schema/domain/constants.py +++ b/threedi_schema/domain/constants.py @@ -1,6 +1,7 @@ from enum import Enum LATEST_SOUTH_MIGRATION_ID = 160 +LAST_SPTL_SCHEMA_VERSION = 230 VERSION_TABLE_NAME = "schema_version" diff --git a/threedi_schema/infrastructure/spatial_index.py b/threedi_schema/infrastructure/spatial_index.py index 1c963d88..e9271154 100644 --- a/threedi_schema/infrastructure/spatial_index.py +++ b/threedi_schema/infrastructure/spatial_index.py @@ -1,50 +1,47 @@ -from geoalchemy2.types import Geometry -from sqlalchemy import func, text +from sqlalchemy import func, inspect, text __all__ = ["ensure_spatial_indexes"] -def _ensure_spatial_index(connection, column): - """Ensure presence of a spatial index for given geometry olumn""" - if ( - connection.execute( - func.RecoverSpatialIndex(column.table.name, column.name) - ).scalar() - is not None - ): - return False - +def create_spatial_index(connection, column): + """ + Create spatial index for given column. + Note that this will fail if the spatial index already exists! + """ idx_name = f"{column.table.name}_{column.name}" - connection.execute(text(f"DROP TABLE IF EXISTS idx_{idx_name}")) - for prefix in {"gii_", "giu_", "gid_"}: - connection.execute(text(f"DROP TRIGGER IF EXISTS {prefix}{idx_name}")) - if ( - connection.execute( - func.CreateSpatialIndex(column.table.name, column.name) - ).scalar() - != 1 - ): - raise RuntimeError(f"Spatial index creation for {idx_name} failed") - + try: + connection.execute(func.gpkgAddSpatialIndex(column.table.name, column.name)) + except Exception as e: + raise RuntimeError( + f"Spatial index creation for {idx_name} failed with error {e}" + ) return True -def ensure_spatial_indexes(db, models): +def get_missing_spatial_indexes(engine, models): + """ + Collect all rtree tables that should exist + There can only be one geometry column per table and we assume any geometry column is named geom + """ + inspector = inspect(engine) + table_names = inspector.get_table_names() + return [ + model + for model in models + if "geom" in model.__table__.columns + and f"rtree_{model.__table__.name}_geom" not in table_names + ] + + +def ensure_spatial_indexes(engine, models): """Ensure presence of spatial indexes for all geometry columns""" created = False - engine = db.engine - + no_spatial_index_models = get_missing_spatial_indexes(engine, models) with engine.connect() as connection: with connection.begin(): - for model in models: - geom_columns = [ - x for x in model.__table__.columns if isinstance(x.type, Geometry) - ] - if len(geom_columns) > 1: - # Pragmatic fix: spatialindex breaks on multiple geometry columns per table - geom_columns = [x for x in geom_columns if x.name == "the_geom"] - if geom_columns: - created &= _ensure_spatial_index(connection, geom_columns[0]) - + for model in no_spatial_index_models: + created &= create_spatial_index( + connection, model.__table__.columns["geom"] + ) if created: connection.execute(text("VACUUM")) diff --git a/threedi_schema/migrations/versions/0300_geopackage.py b/threedi_schema/migrations/versions/0300_geopackage.py new file mode 100644 index 00000000..078dd341 --- /dev/null +++ b/threedi_schema/migrations/versions/0300_geopackage.py @@ -0,0 +1,30 @@ +"""Reproject geometries to model CRS + +Revision ID: 0230 +Revises: +Create Date: 2024-11-12 12:30 + +""" +import sqlite3 +import uuid + +import sqlalchemy as sa +from alembic import op + +from threedi_schema.migrations.exceptions import InvalidSRIDException + +# revision identifiers, used by Alembic. +revision = "0300" +down_revision = "0230" +branch_labels = None +depends_on = None + + +def upgrade(): + # this upgrade only changes the model version + pass + + +def downgrade(): + # Not implemented on purpose + pass diff --git a/threedi_schema/scripts.py b/threedi_schema/scripts.py index 98b82a68..b3316cd7 100644 --- a/threedi_schema/scripts.py +++ b/threedi_schema/scripts.py @@ -42,7 +42,6 @@ def migrate( revision=revision, backup=backup, upgrade_spatialite_version=upgrade_spatialite_version, - convert_to_geopackage=convert_to_geopackage, ) click.echo("The migrated schema revision is: %s" % schema.get_version()) diff --git a/threedi_schema/tests/conftest.py b/threedi_schema/tests/conftest.py index f0c3d2ce..353bc670 100644 --- a/threedi_schema/tests/conftest.py +++ b/threedi_schema/tests/conftest.py @@ -56,8 +56,7 @@ def in_memory_sqlite(): @pytest.fixture -def sqlite_latest(in_memory_sqlite): +def sqlite_latest(empty_sqlite_v4): """An in-memory database with the latest schema version""" - db = ThreediDatabase("") - in_memory_sqlite.schema.upgrade("head", backup=False, custom_epsg_code=28992) - return db + empty_sqlite_v4.schema.upgrade("head", backup=False, custom_epsg_code=28992) + return empty_sqlite_v4 diff --git a/threedi_schema/tests/test_gpkg.py b/threedi_schema/tests/test_gpkg.py index a4fa9259..5f62b5ec 100644 --- a/threedi_schema/tests/test_gpkg.py +++ b/threedi_schema/tests/test_gpkg.py @@ -1,31 +1,18 @@ import pytest -from sqlalchemy import text + +from threedi_schema.domain import constants @pytest.mark.parametrize("upgrade_spatialite", [True, False]) def test_convert_to_geopackage(oldest_sqlite, upgrade_spatialite): - if upgrade_spatialite: - oldest_sqlite.schema.upgrade(upgrade_spatialite_version=True) + # if upgrade_spatialite: + oldest_sqlite.schema.upgrade( + upgrade_spatialite_version=upgrade_spatialite, + revision=f"{constants.LAST_SPTL_SCHEMA_VERSION:04d}", + ) oldest_sqlite.schema.convert_to_geopackage() # Ensure that after the conversion the geopackage is used assert oldest_sqlite.path.suffix == ".gpkg" - with oldest_sqlite.session_scope() as session: - gpkg_table_exists = bool( - session.execute( - text( - "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='gpkg_contents';" - ) - ).scalar() - ) - spatialite_table_exists = bool( - session.execute( - text( - "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='spatial_ref_sys';" - ) - ).scalar() - ) - - assert gpkg_table_exists - assert not spatialite_table_exists - assert oldest_sqlite.schema.validate_schema() + assert not oldest_sqlite.schema.is_spatialite + assert oldest_sqlite.schema.is_geopackage diff --git a/threedi_schema/tests/test_migration_230_crs_reprojection.py b/threedi_schema/tests/test_migration_230_crs_reprojection.py index fe790aa2..32e5476c 100644 --- a/threedi_schema/tests/test_migration_230_crs_reprojection.py +++ b/threedi_schema/tests/test_migration_230_crs_reprojection.py @@ -21,7 +21,7 @@ def test_check_valid_crs(in_memory_sqlite, epsg_code): def test_migration(tmp_path_factory, oldest_sqlite): schema = ModelSchema(oldest_sqlite) - schema.upgrade(backup=False) + schema.upgrade(backup=False, revision="0230") cursor = sqlite3.connect(schema.db.path).cursor() query = cursor.execute("SELECT srid FROM geometry_columns where f_table_name = 'geom'") epsg_matches = [int(item[0])==28992 for item in query.fetchall()] diff --git a/threedi_schema/tests/test_schema.py b/threedi_schema/tests/test_schema.py index 15c57702..968b8f32 100644 --- a/threedi_schema/tests/test_schema.py +++ b/threedi_schema/tests/test_schema.py @@ -7,6 +7,8 @@ from threedi_schema.application import errors from threedi_schema.application.schema import get_schema_version from threedi_schema.domain import constants +from threedi_schema.domain.models import DECLARED_MODELS +from threedi_schema.infrastructure.spatial_index import get_missing_spatial_indexes from threedi_schema.infrastructure.spatialite_versions import get_spatialite_version @@ -135,9 +137,9 @@ def test_full_upgrade_oldest(oldest_sqlite): run_upgrade_test(schema) -def test_full_upgrade_empty(in_memory_sqlite): +def test_full_upgrade_empty(empty_sqlite_v4): """Upgrade an empty database to the latest version""" - schema = ModelSchema(in_memory_sqlite) + schema = ModelSchema(empty_sqlite_v4) schema.upgrade( backup=False, upgrade_spatialite_version=False, custom_epsg_code=28992 ) @@ -163,10 +165,7 @@ def run_upgrade_test(schema): schema_cols = get_columns_from_schema(schema, table) sqlite_cols = get_columns_from_sqlite(session, table) assert set(schema_cols).issubset(set(sqlite_cols)) - check_result = session.execute( - text("SELECT CheckSpatialIndex('connection_node', 'geom')") - ).scalar() - assert check_result == 1 + assert get_missing_spatial_indexes(schema.db.engine, DECLARED_MODELS) == [] def test_upgrade_with_custom_epsg_code(in_memory_sqlite): @@ -190,9 +189,9 @@ def test_upgrade_with_custom_epsg_code(in_memory_sqlite): assert all([srid == 28992 for srid in srids]) -def test_upgrade_with_custom_epsg_code_version_too_new(in_memory_sqlite): +def test_upgrade_with_custom_epsg_code_version_too_new(empty_sqlite_v4): """Set custom epsg code for schema version > 229""" - schema = ModelSchema(in_memory_sqlite) + schema = ModelSchema(empty_sqlite_v4) schema.upgrade( revision="0230", backup=False, @@ -282,14 +281,18 @@ def test_upgrade_without_backup(south_latest_sqlite): assert db is south_latest_sqlite -def test_convert_to_geopackage_raise(oldest_sqlite): - if get_schema_version() >= 300: - pytest.skip("Warning not expected beyond schema 300") - schema = ModelSchema(oldest_sqlite) - with pytest.raises(errors.UpgradeFailedError): - schema.upgrade( - backup=False, upgrade_spatialite_version=False, convert_to_geopackage=True - ) +@pytest.mark.parametrize( + "is_var, version", + [("is_spatialite", constants.LAST_SPTL_SCHEMA_VERSION), ("is_geopackage", 300)], +) +def test_upgrade_incorrect_format(in_memory_sqlite, is_var, version): + schema = ModelSchema(in_memory_sqlite) + with mock.patch.object( + type(schema), is_var, new=mock.PropertyMock(return_value=False) + ): + with mock.patch.object(type(schema), "get_version", return_value=version): + with pytest.raises(errors.UpgradeFailedError): + schema.upgrade() def test_upgrade_revision_exception(oldest_sqlite): @@ -304,7 +307,11 @@ def test_upgrade_spatialite_3(oldest_sqlite): pytest.skip("Nothing to test: spatialite library version equals file version") schema = ModelSchema(oldest_sqlite) - schema.upgrade(backup=False, upgrade_spatialite_version=True) + schema.upgrade( + backup=False, + upgrade_spatialite_version=True, + revision=f"{constants.LAST_SPTL_SCHEMA_VERSION:04d}", + ) _, file_version_after = get_spatialite_version(oldest_sqlite) assert file_version_after == 4 @@ -317,6 +324,8 @@ def test_upgrade_spatialite_3(oldest_sqlite): assert check_result == 1 +# TODO: remove this because ensure_spatial_indexes is already tested +@pytest.mark.skip(reason="will be removed") def test_set_spatial_indexes(in_memory_sqlite): engine = in_memory_sqlite.engine @@ -341,8 +350,8 @@ def test_set_spatial_indexes(in_memory_sqlite): class TestGetEPSGData: - def test_no_epsg(self, in_memory_sqlite): - schema = ModelSchema(in_memory_sqlite) + def test_no_epsg(self, empty_sqlite_v4): + schema = ModelSchema(empty_sqlite_v4) schema.upgrade( backup=False, upgrade_spatialite_version=False, custom_epsg_code=28992 ) @@ -354,3 +363,26 @@ def test_with_epsg(self, oldest_sqlite): schema.upgrade(backup=False, upgrade_spatialite_version=False) assert schema.epsg_code == 28992 assert schema.epsg_source == "boundary_condition_1d.geom" + + +def test_is_spatialite(in_memory_sqlite): + schema = ModelSchema(in_memory_sqlite) + schema.upgrade( + backup=False, + upgrade_spatialite_version=False, + custom_epsg_code=28992, + revision=f"{constants.LAST_SPTL_SCHEMA_VERSION:04d}", + ) + assert schema.is_spatialite + + +def test_is_geopackage(oldest_sqlite): + schema = ModelSchema(oldest_sqlite) + schema.upgrade( + backup=False, + upgrade_spatialite_version=False, + custom_epsg_code=28992, + revision=f"{constants.LAST_SPTL_SCHEMA_VERSION:04d}", + ) + schema.convert_to_geopackage() + assert schema.is_geopackage diff --git a/threedi_schema/tests/test_spatial_index.py b/threedi_schema/tests/test_spatial_index.py new file mode 100644 index 00000000..8744448c --- /dev/null +++ b/threedi_schema/tests/test_spatial_index.py @@ -0,0 +1,58 @@ +import pytest +from geoalchemy2 import Geometry +from sqlalchemy import Column, create_engine, func, Integer +from sqlalchemy.event import listen +from sqlalchemy.orm import declarative_base, sessionmaker + +from threedi_schema.application.threedi_database import load_spatialite +from threedi_schema.infrastructure.spatial_index import ( + create_spatial_index, + ensure_spatial_indexes, + get_missing_spatial_indexes, +) + +Base = declarative_base() + + +class Model(Base): + __tablename__ = "model" + + id = Column(Integer, primary_key=True) + geom = Column(Geometry("POINT")) + + +@pytest.fixture() +def engine(): + engine = create_engine("sqlite:///:memory:") + listen(engine, "connect", load_spatialite) + + Base.metadata.create_all(engine) + Session = sessionmaker(bind=engine) + session = Session() + session.execute(func.gpkgCreateBaseTables()) + return engine + + +def test_get_missing_spatial_indexes(engine): + assert get_missing_spatial_indexes(engine, [Model]) == [Model] + + +def test_create_spatial_index(engine): + with engine.connect() as connection: + with connection.begin(): + create_spatial_index(connection, Model.__table__.columns["geom"]) + assert get_missing_spatial_indexes(engine, [Model]) == [] + + +def test_create_spatial_index_fail(engine): + with engine.connect() as connection: + with connection.begin(): + create_spatial_index(connection, Model.__table__.columns["geom"]) + with pytest.raises(RuntimeError): + create_spatial_index(connection, Model.__table__.columns["geom"]) + + +def test_ensure_spatial_index(engine): + assert get_missing_spatial_indexes(engine, [Model]) == [Model] + ensure_spatial_indexes(engine, [Model]) + assert get_missing_spatial_indexes(engine, [Model]) == [] diff --git a/threedi_schema/tests/test_upgrade_utils.py b/threedi_schema/tests/test_upgrade_utils.py index 21463b70..0d0e6982 100644 --- a/threedi_schema/tests/test_upgrade_utils.py +++ b/threedi_schema/tests/test_upgrade_utils.py @@ -48,9 +48,9 @@ def test_upgrade_with_progress_func(oldest_sqlite): schema = oldest_sqlite.schema progress_func = MagicMock() schema.upgrade( + revision="0201", backup=False, upgrade_spatialite_version=False, progress_func=progress_func, - revision="0201", ) assert progress_func.call_args_list == [call(0.0), call(50.0)]