Skip to content

Commit

Permalink
Remove dependency on schema definition
Browse files Browse the repository at this point in the history
  • Loading branch information
margrietpalm committed Dec 17, 2024
1 parent a377c58 commit 0fa27c2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 42 deletions.
2 changes: 1 addition & 1 deletion threedi_schema/domain/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ class ModelSettings(Base):
friction_coefficient = Column(Float)
friction_coefficient_file = Column(String(255))
embedded_cutoff_threshold = Column(Float)
epsg_code = Column(Integer)
# epsg_code = Column(Integer)
max_angle_1d_advection = Column(Float)
friction_averaging = Column(Boolean)
table_step_size_1d = Column(Float)
Expand Down
87 changes: 48 additions & 39 deletions threedi_schema/migrations/versions/0230_reproject_geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

import sqlalchemy as sa
from alembic import op
from sqlalchemy.orm.attributes import InstrumentedAttribute

from threedi_schema import models
from threedi_schema.migrations.exceptions import InvalidSRIDException

# revision identifiers, used by Alembic.
Expand All @@ -21,6 +19,12 @@
branch_labels = None
depends_on = None

GEOM_TABLES = ['boundary_condition_1d', 'boundary_condition_2d', 'channel', 'connection_node', 'measure_location',
'measure_map', 'memory_control', 'table_control', 'cross_section_location', 'culvert',
'dem_average_area', 'dry_weather_flow', 'dry_weather_flow_map', 'exchange_line', 'grid_refinement_line',
'grid_refinement_area', 'lateral_1d', 'lateral_2d', 'obstacle', 'orifice', 'pipe', 'potential_breach',
'pump', 'pump_map', 'surface', 'surface_map', 'weir', 'windshielding_1d']


def get_crs_info(srid):
# Create temporary spatialite to find crs unit and projection
Expand All @@ -42,7 +46,7 @@ def get_crs_info(srid):
def get_model_srid() -> int:
# Note: this will not work for models which are allowed to have no CRS (no geometries)
conn = op.get_bind()
srid_str = conn.execute(sa.text("SELECT epsg_code FROM model_settings LIMIT 1")).fetchone()
srid_str = conn.execute(sa.text("SELECT epsg_code FROM model_settings")).fetchone()
if srid_str is None or srid_str[0] is None:
if not has_geom():
return None
Expand All @@ -59,48 +63,55 @@ def get_model_srid() -> int:
return srid


def get_cols_for_model(model, skip_cols=None):
if skip_cols is None:
skip_cols = []
return [getattr(model, item) for item in model.__dict__
if item not in skip_cols
and isinstance(getattr(model, item), InstrumentedAttribute)]
def get_geom_type(table_name, geo_col_name):
connection = op.get_bind()
columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall()
for col in columns:
if col[1] == geo_col_name:
return col[2]


def create_sqlite_table_from_model(model, table_name, add_geom=True, srid=None):
cols = get_cols_for_model(model, skip_cols=["id", "geom"])
query = f"""
CREATE TABLE {table_name} (
id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
{','.join(f"{col.name} {col.type}" for col in cols)});
"""
def add_geometry_column(table: str, name: str, srid: int, geometry_type: str):
# Adding geometry columns via alembic doesn't work
query = (
f"SELECT AddGeometryColumn('{table}', '{name}', {srid}, '{geometry_type}', 'XY', 1);")
op.execute(sa.text(query))
if add_geom:
op.execute(sa.text(
f"SELECT AddGeometryColumn('{table_name}', 'geom', {srid}, '{model.geom.type.geometry_type}', 'XY', 1);"))


def transform_column(model, srid):
table_name = model.__tablename__
temp_table_name = f'_temp_230_{table_name}'
create_sqlite_table_from_model(model, temp_table_name, add_geom=True, srid=srid)
col_names = ",".join([col.name for col in get_cols_for_model(model, skip_cols=["geom"])])
def transform_column(table_name, srid):
connection = op.get_bind()
columns = connection.execute(sa.text(f"PRAGMA table_info('{table_name}')")).fetchall()
# get all column names and types
skip_cols = ['id', 'geom']
col_names = [col[1] for col in columns if col[1] not in skip_cols]
col_types = [col[2] for col in columns if col[1] not in skip_cols]
# Create temporary table
temp_table_name = f'_temp_230_{table_name}_{uuid.uuid4().hex}'
# Create new table, insert data, drop original and rename temp to table_name
col_str = ','.join(['id INTEGER PRIMARY KEY NOT NULL'] + [f'{cname} {ctype}' for cname, ctype in
zip(col_names, col_types)])
query = f"CREATE TABLE {temp_table_name} ({col_str});"
op.execute(sa.text(query))
# Add geometry column with new srid!
geom_type = get_geom_type(table_name, 'geom')
add_geometry_column(temp_table_name, 'geom', srid, geom_type)
# Copy transformed geometry and other columns to temp table
op.execute(sa.text(f"""
INSERT INTO `{temp_table_name}` ({col_names}, `geom`)
SELECT {col_names}, ST_Transform(`geom`, {srid}) AS `geom` FROM `{table_name}`
"""))

col_str = ','.join(['id'] + col_names)
query = f"""
INSERT INTO {temp_table_name} ({col_str}, geom)
SELECT {col_str}, ST_Transform(geom, {srid}) AS geom FROM {table_name}
"""
op.execute(sa.text(query))
# Discard geometry column in old table
op.execute(sa.text(f"SELECT DiscardGeometryColumn('{table_name}', 'geom')"))
op.execute(sa.text(f"SELECT DiscardGeometryColumn('{temp_table_name}', 'geom')"))
# Remove old table
op.execute(sa.text(f"DROP TABLE `{table_name}`"))
op.execute(sa.text(f"DROP TABLE '{table_name}'"))
# Rename temp table
op.execute(sa.text(f"ALTER TABLE `{temp_table_name}` RENAME TO `{table_name}`;"))
op.execute(sa.text(f"ALTER TABLE '{temp_table_name}' RENAME TO '{table_name}';"))
# Recover geometry stuff
op.execute(sa.text(f"SELECT RecoverGeometryColumn('{table_name}', "
f"'geom', {srid}, '{model.geom.type.geometry_type}', 'XY')"))
f"'geom', {srid}, '{geom_type}', 'XY')"))
op.execute(sa.text(f"SELECT CreateSpatialIndex('{table_name}', 'geom')"))
op.execute(sa.text(f"SELECT RecoverSpatialIndex('{table_name}', 'geom')"))

Expand All @@ -114,8 +125,7 @@ def prep_spatialite(srid: int):

def has_geom():
connection = op.get_bind()
geom_tables = [model.__tablename__ for model in models.DECLARED_MODELS if hasattr(model, "geom")]
has_data = [connection.execute(sa.text(f'SELECT COUNT(*) FROM {table}')).fetchone()[0] > 0 for table in geom_tables]
has_data = [connection.execute(sa.text(f'SELECT COUNT(*) FROM {table}')).fetchone()[0] > 0 for table in GEOM_TABLES]
return any(has_data)


Expand All @@ -129,12 +139,11 @@ def upgrade():
# prepare spatialite databases
prep_spatialite(srid)
# transform all geometries
for model in models.DECLARED_MODELS:
if hasattr(model, "geom"):
transform_column(model, srid)
for table_name in GEOM_TABLES:
transform_column(table_name, srid)
# remove crs from model_settings
# with op.batch_alter_table('model_settings') as batch_op:
# batch_op.drop_column('epsg_code')
with op.batch_alter_table('model_settings') as batch_op:
batch_op.drop_column('epsg_code')


def downgrade():
Expand Down
7 changes: 5 additions & 2 deletions threedi_schema/tests/test_migration_230_crs_reprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path

import pytest
from sqlalchemy import text

from threedi_schema import models, ModelSchema, ThreediDatabase
from threedi_schema.migrations.exceptions import InvalidSRIDException
Expand All @@ -23,6 +24,7 @@ def db(tmp_path_factory, sqlite_path):
return ThreediDatabase(tmp_sqlite)


# TODO - match other testing and use generic fixtures
@pytest.mark.parametrize("epsg_code", [
999999, # non-existing
2227, # projected / US survey foot
Expand All @@ -31,13 +33,14 @@ def db(tmp_path_factory, sqlite_path):
def test_check_valid_crs(db, epsg_code):
session = db.get_session()
# Update the epsg_code in ModelSettings
model_settings_to_update = session.query(models.ModelSettings).filter_by(id=0).first()
model_settings_to_update.epsg_code = epsg_code
session.execute(text(f"UPDATE model_settings SET epsg_code = {epsg_code}"))
session.commit()
with pytest.raises(InvalidSRIDException) as exc_info:
db.schema.upgrade(backup=False)
session.execute(text("UPDATE model_settings SET epsg_code = 28992"))


# TODO - match other testing and use generic fixtures
def test_migration(tmp_path_factory):
# Ensure all geometries are transformed
sqlite_path = data_dir.joinpath("v2_bergermeer_221.sqlite")
Expand Down

0 comments on commit 0fa27c2

Please sign in to comment.