diff --git a/threedi_schema/domain/models.py b/threedi_schema/domain/models.py index 61dd81e..0b701b2 100644 --- a/threedi_schema/domain/models.py +++ b/threedi_schema/domain/models.py @@ -350,7 +350,7 @@ class ModelSettings(Base): friction_coefficient = Column(Float) friction_coefficient_file = Column(String(255)) embedded_cutoff_threshold = Column(Float) - epsg_code = Column(Integer) + # epsg_code = Column(Integer) max_angle_1d_advection = Column(Float) friction_averaging = Column(Boolean) table_step_size_1d = Column(Float) diff --git a/threedi_schema/migrations/versions/0230_reproject_geometries.py b/threedi_schema/migrations/versions/0230_reproject_geometries.py index 477643d..38dbe57 100644 --- a/threedi_schema/migrations/versions/0230_reproject_geometries.py +++ b/threedi_schema/migrations/versions/0230_reproject_geometries.py @@ -10,9 +10,7 @@ import sqlalchemy as sa from alembic import op -from sqlalchemy.orm.attributes import InstrumentedAttribute -from threedi_schema import models from threedi_schema.migrations.exceptions import InvalidSRIDException # revision identifiers, used by Alembic. @@ -21,6 +19,12 @@ branch_labels = None depends_on = None +GEOM_TABLES = ['boundary_condition_1d', 'boundary_condition_2d', 'channel', 'connection_node', 'measure_location', + 'measure_map', 'memory_control', 'table_control', 'cross_section_location', 'culvert', + 'dem_average_area', 'dry_weather_flow', 'dry_weather_flow_map', 'exchange_line', 'grid_refinement_line', + 'grid_refinement_area', 'lateral_1d', 'lateral_2d', 'obstacle', 'orifice', 'pipe', 'potential_breach', + 'pump', 'pump_map', 'surface', 'surface_map', 'weir', 'windshielding_1d'] + def get_crs_info(srid): # Create temporary spatialite to find crs unit and projection @@ -42,7 +46,7 @@ def get_crs_info(srid): def get_model_srid() -> int: # Note: this will not work for models which are allowed to have no CRS (no geometries) conn = op.get_bind() - srid_str = conn.execute(sa.text("SELECT epsg_code FROM model_settings LIMIT 1")).fetchone() + srid_str = conn.execute(sa.text("SELECT epsg_code FROM model_settings")).fetchone() if srid_str is None or srid_str[0] is None: if not has_geom(): return None @@ -59,48 +63,55 @@ def get_model_srid() -> int: return srid -def get_cols_for_model(model, skip_cols=None): - if skip_cols is None: - skip_cols = [] - return [getattr(model, item) for item in model.__dict__ - if item not in skip_cols - and isinstance(getattr(model, item), InstrumentedAttribute)] +def get_geom_type(table_name, geo_col_name): + connection = op.get_bind() + columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall() + for col in columns: + if col[1] == geo_col_name: + return col[2] -def create_sqlite_table_from_model(model, table_name, add_geom=True, srid=None): - cols = get_cols_for_model(model, skip_cols=["id", "geom"]) - query = f""" - CREATE TABLE {table_name} ( - id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, - {','.join(f"{col.name} {col.type}" for col in cols)}); - """ +def add_geometry_column(table: str, name: str, srid: int, geometry_type: str): + # Adding geometry columns via alembic doesn't work + query = ( + f"SELECT AddGeometryColumn('{table}', '{name}', {srid}, '{geometry_type}', 'XY', 1);") op.execute(sa.text(query)) - if add_geom: - op.execute(sa.text( - f"SELECT AddGeometryColumn('{table_name}', 'geom', {srid}, '{model.geom.type.geometry_type}', 'XY', 1);")) -def transform_column(model, srid): - table_name = model.__tablename__ - temp_table_name = f'_temp_230_{table_name}' - create_sqlite_table_from_model(model, temp_table_name, add_geom=True, srid=srid) - col_names = ",".join([col.name for col in get_cols_for_model(model, skip_cols=["geom"])]) +def transform_column(table_name, srid): + connection = op.get_bind() + columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall() + # get all column names and types + skip_cols = ['id', 'geom'] + col_names = [col[1] for col in columns if col[1] not in skip_cols] + col_types = [col[2] for col in columns if col[1] not in skip_cols] + # Create temporary table + temp_table_name = f'_temp_230_{table_name}_{uuid.uuid4().hex}' + # Create new table, insert data, drop original and rename temp to table_name + col_str = ','.join(['id INTEGER PRIMARY KEY NOT NULL'] + [f'{cname} {ctype}' for cname, ctype in + zip(col_names, col_types)]) + query = f"CREATE TABLE {temp_table_name} ({col_str});" + op.execute(sa.text(query)) + # Add geometry column with new srid! + geom_type = get_geom_type(table_name, 'geom') + add_geometry_column(temp_table_name, 'geom', srid, geom_type) # Copy transformed geometry and other columns to temp table - op.execute(sa.text(f""" - INSERT INTO `{temp_table_name}` ({col_names}, `geom`) - SELECT {col_names}, ST_Transform(`geom`, {srid}) AS `geom` FROM `{table_name}` - """)) - + col_str = ','.join(['id'] + col_names) + query = f""" + INSERT INTO {temp_table_name} ({col_str}, geom) + SELECT {col_str}, ST_Transform(geom, {srid}) AS geom FROM {table_name} + """ + op.execute(sa.text(query)) # Discard geometry column in old table op.execute(sa.text(f"SELECT DiscardGeometryColumn('{table_name}', 'geom')")) op.execute(sa.text(f"SELECT DiscardGeometryColumn('{temp_table_name}', 'geom')")) # Remove old table - op.execute(sa.text(f"DROP TABLE `{table_name}`")) + op.execute(sa.text(f"DROP TABLE '{table_name}'")) # Rename temp table - op.execute(sa.text(f"ALTER TABLE `{temp_table_name}` RENAME TO `{table_name}`;")) + op.execute(sa.text(f"ALTER TABLE '{temp_table_name}' RENAME TO '{table_name}';")) # Recover geometry stuff op.execute(sa.text(f"SELECT RecoverGeometryColumn('{table_name}', " - f"'geom', {srid}, '{model.geom.type.geometry_type}', 'XY')")) + f"'geom', {srid}, '{geom_type}', 'XY')")) op.execute(sa.text(f"SELECT CreateSpatialIndex('{table_name}', 'geom')")) op.execute(sa.text(f"SELECT RecoverSpatialIndex('{table_name}', 'geom')")) @@ -114,8 +125,7 @@ def prep_spatialite(srid: int): def has_geom(): connection = op.get_bind() - geom_tables = [model.__tablename__ for model in models.DECLARED_MODELS if hasattr(model, "geom")] - has_data = [connection.execute(sa.text(f'SELECT COUNT(*) FROM {table}')).fetchone()[0] > 0 for table in geom_tables] + has_data = [connection.execute(sa.text(f'SELECT COUNT(*) FROM {table}')).fetchone()[0] > 0 for table in GEOM_TABLES] return any(has_data) @@ -129,12 +139,11 @@ def upgrade(): # prepare spatialite databases prep_spatialite(srid) # transform all geometries - for model in models.DECLARED_MODELS: - if hasattr(model, "geom"): - transform_column(model, srid) + for table_name in GEOM_TABLES: + transform_column(table_name, srid) # remove crs from model_settings - # with op.batch_alter_table('model_settings') as batch_op: - # batch_op.drop_column('epsg_code') + with op.batch_alter_table('model_settings') as batch_op: + batch_op.drop_column('epsg_code') def downgrade(): diff --git a/threedi_schema/tests/test_migration_230_crs_reprojection.py b/threedi_schema/tests/test_migration_230_crs_reprojection.py index 9a66b42..f6c67f5 100644 --- a/threedi_schema/tests/test_migration_230_crs_reprojection.py +++ b/threedi_schema/tests/test_migration_230_crs_reprojection.py @@ -4,6 +4,7 @@ from pathlib import Path import pytest +from sqlalchemy import text from threedi_schema import models, ModelSchema, ThreediDatabase from threedi_schema.migrations.exceptions import InvalidSRIDException @@ -23,6 +24,7 @@ def db(tmp_path_factory, sqlite_path): return ThreediDatabase(tmp_sqlite) +# TODO - match other testing and use generic fixtures @pytest.mark.parametrize("epsg_code", [ 999999, # non-existing 2227, # projected / US survey foot @@ -31,13 +33,14 @@ def db(tmp_path_factory, sqlite_path): def test_check_valid_crs(db, epsg_code): session = db.get_session() # Update the epsg_code in ModelSettings - model_settings_to_update = session.query(models.ModelSettings).filter_by(id=0).first() - model_settings_to_update.epsg_code = epsg_code + session.execute(text(f"UPDATE model_settings SET epsg_code = {epsg_code}")) session.commit() with pytest.raises(InvalidSRIDException) as exc_info: db.schema.upgrade(backup=False) + session.execute(text("UPDATE model_settings SET epsg_code = 28992")) +# TODO - match other testing and use generic fixtures def test_migration(tmp_path_factory): # Ensure all geometries are transformed sqlite_path = data_dir.joinpath("v2_bergermeer_221.sqlite")