diff --git a/CHANGES.rst b/CHANGES.rst index 2a22fb3b..9fe05654 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,12 +1,16 @@ Changelog of threedi-modelchecker ================================= - - -2.15.1 (unreleased) +2.16.0 (unreleased) ------------------- -- Nothing changed yet. +- Adapt to schema 230 where all geometries use the model CRS and model_settings.epsg_code is no longer available +- Remove checks for model_settings.epsg_code (317 and 318) +- Remove usage of epsg 4326 in the tests because this CRS is no longer valid +- Remove no longer needed transformations +- Add checks for matching epsg in all geometries and raster files +- Add checks for valid epsg (existing code, projected, in meters) which requires pyproj +- Change ConnectionNodeCheck (201) to require minimum distance of 10cm 2.15.0 (2025-01-08) diff --git a/pyproject.toml b/pyproject.toml index 4f118f5c..95c5902c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,8 @@ dependencies = [ "Click", "GeoAlchemy2>=0.9,!=0.11.*", "SQLAlchemy>=1.4", - "threedi-schema==0.229.*" + "pyproj", + "threedi-schema==0.230.*" ] [project.optional-dependencies] diff --git a/threedi_modelchecker/__init__.py b/threedi_modelchecker/__init__.py index 0207c9d5..e5742e9a 100644 --- a/threedi_modelchecker/__init__.py +++ b/threedi_modelchecker/__init__.py @@ -1,5 +1,5 @@ from .model_checks import * # NOQA # fmt: off -__version__ = '2.15.1.dev0' +__version__ = '2.16.0.dev0' # fmt: on diff --git a/threedi_modelchecker/checks/base.py b/threedi_modelchecker/checks/base.py index 6892c7dc..db8c90f0 100644 --- a/threedi_modelchecker/checks/base.py +++ b/threedi_modelchecker/checks/base.py @@ -2,6 +2,7 @@ from enum import IntEnum from typing import List, NamedTuple +from geoalchemy2.functions import ST_SRID from sqlalchemy import and_, false, func, types from sqlalchemy.orm.session import Session from threedi_schema.domain import custom_types @@ -412,3 +413,26 @@ def description(self) -> str: return ( 
f"{self.table.name}.{self.column} is not a comma seperated list of integers" ) + + +class EPSGGeomCheck(BaseCheck): + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.ref_epsg_name = "" + + def get_invalid(self, session: Session) -> List[NamedTuple]: + self.ref_epsg_name = session.ref_epsg_name + if session.ref_epsg_code is None: + return [] + return ( + self.to_check(session) + .filter(ST_SRID(self.column) != session.ref_epsg_code) + .all() + ) + + def description(self) -> str: + return f"The epsg of {self.table.name}.{self.column_name} should match {self.ref_epsg_name}" diff --git a/threedi_modelchecker/checks/factories.py b/threedi_modelchecker/checks/factories.py index 91cbc67c..accb87c0 100644 --- a/threedi_modelchecker/checks/factories.py +++ b/threedi_modelchecker/checks/factories.py @@ -5,6 +5,7 @@ from .base import ( EnumCheck, + EPSGGeomCheck, ForeignKeyCheck, GeometryCheck, GeometryTypeCheck, @@ -12,6 +13,7 @@ TypeCheck, UniqueCheck, ) +from .raster import RasterHasMatchingEPSGCheck @dataclass @@ -124,3 +126,25 @@ def generate_enum_checks(table, custom_level_map=None, **kwargs): level = get_level(table, column, custom_level_map) enum_checks.append(EnumCheck(column, level=level, **kwargs)) return enum_checks + + +def generate_epsg_geom_checks(table, custom_level_map=None, **kwargs): + custom_level_map = custom_level_map or {} + column = table.columns.get("geom") + if column is not None: + level = get_level(table, column, custom_level_map) + return [EPSGGeomCheck(column=column, level=level, **kwargs)] + else: + return [] + + +def generate_epsg_raster_checks(table, raster_columns, custom_level_map=None, **kwargs): + custom_level_map = custom_level_map or {} + checks = [] + for column in table.columns: + if column in raster_columns: + level = get_level(table, column, custom_level_map) + checks.append( + RasterHasMatchingEPSGCheck(column=column, level=level, **kwargs) + ) + return checks diff --git 
a/threedi_modelchecker/checks/geo_query.py b/threedi_modelchecker/checks/geo_query.py deleted file mode 100644 index 50750c10..00000000 --- a/threedi_modelchecker/checks/geo_query.py +++ /dev/null @@ -1,24 +0,0 @@ -from geoalchemy2 import functions as geo_func -from sqlalchemy import func -from sqlalchemy.orm import Query -from sqlalchemy.sql import literal -from threedi_schema import models - -DEFAULT_EPSG = 28992 - - -def epsg_code_query(): - epsg_code = Query(models.ModelSettings.epsg_code).limit(1).scalar_subquery() - return func.coalesce(epsg_code, literal(DEFAULT_EPSG)).label("epsg_code") - - -def transform(col): - return geo_func.ST_Transform(col, epsg_code_query()) - - -def distance(col_1, col_2): - return geo_func.ST_Distance(transform(col_1), transform(col_2)) - - -def length(col): - return geo_func.ST_Length(transform(col)) diff --git a/threedi_modelchecker/checks/location.py b/threedi_modelchecker/checks/location.py index 65709c6d..7f7a41a9 100644 --- a/threedi_modelchecker/checks/location.py +++ b/threedi_modelchecker/checks/location.py @@ -1,11 +1,10 @@ from typing import List, NamedTuple -from sqlalchemy import func +from geoalchemy2.functions import ST_Distance, ST_NPoints, ST_PointN from sqlalchemy.orm import aliased, Session from threedi_schema.domain import models from threedi_modelchecker.checks.base import BaseCheck -from threedi_modelchecker.checks.geo_query import distance class PointLocationCheck(BaseCheck): @@ -32,7 +31,7 @@ def get_invalid(self, session): self.ref_table, self.ref_table.id == self.ref_column, ) - .filter(distance(self.column, self.ref_table.geom) > self.max_distance) + .filter(ST_Distance(self.column, self.ref_table.geom) > self.max_distance) .all() ) @@ -70,15 +69,13 @@ def __init__( def get_invalid(self, session: Session) -> List[NamedTuple]: start_node = aliased(self.ref_table_start) end_node = aliased(self.ref_table_end) - tol = self.max_distance - start_point = func.ST_PointN(self.column, 1) - end_point = 
func.ST_PointN(self.column, func.ST_NPoints(self.column)) - - start_ok = distance(start_point, start_node.geom) <= tol - end_ok = distance(end_point, end_node.geom) <= tol - start_ok_if_reversed = distance(end_point, start_node.geom) <= tol - end_ok_if_reversed = distance(start_point, end_node.geom) <= tol + start_point = ST_PointN(self.column, 1) + end_point = ST_PointN(self.column, ST_NPoints(self.column)) + start_ok = ST_Distance(start_point, start_node.geom) <= tol + end_ok = ST_Distance(end_point, end_node.geom) <= tol + start_ok_if_reversed = ST_Distance(end_point, start_node.geom) <= tol + end_ok_if_reversed = ST_Distance(start_point, end_node.geom) <= tol return ( self.to_check(session) .join(start_node, start_node.id == self.ref_column_start) diff --git a/threedi_modelchecker/checks/other.py b/threedi_modelchecker/checks/other.py index bd99e9b8..6304696a 100644 --- a/threedi_modelchecker/checks/other.py +++ b/threedi_modelchecker/checks/other.py @@ -2,6 +2,8 @@ from dataclasses import dataclass from typing import List, Literal, NamedTuple +import pyproj +from geoalchemy2.functions import ST_Distance, ST_Length from sqlalchemy import ( and_, case, @@ -19,7 +21,6 @@ from .base import BaseCheck, CheckLevel from .cross_section_definitions import cross_section_configuration_for_record -from .geo_query import distance, length, transform class CorrectAggregationSettingsExist(BaseCheck): @@ -258,8 +259,7 @@ def get_invalid(self, session): class ConnectionNodesLength(BaseCheck): """Check that the distance between `start_node` and `end_node` is at least - `min_distance`. The coords will be transformed into (the first entry) of - ModelSettings.epsg_code. + `min_distance`. 
""" def __init__( @@ -290,7 +290,7 @@ def get_invalid(self, session): self.to_check(session) .join(start_node, start_node.id == self.start_node) .join(end_node, end_node.id == self.end_node) - .filter(distance(start_node.geom, end_node.geom) < self.min_distance) + .filter(ST_Distance(start_node.geom, end_node.geom) < self.min_distance) ) return list(q.with_session(session).all()) @@ -321,30 +321,26 @@ def get_invalid(self, session: Session) -> List[NamedTuple]: distance between all connection nodes. """ query = text( - f"""SELECT * - FROM connection_node AS cn1, connection_node AS cn2 - WHERE - distance(cn1.geom, cn2.geom, 1) < :min_distance - AND cn1.ROWID != cn2.ROWID - AND cn2.ROWID IN ( - SELECT ROWID - FROM SpatialIndex - WHERE ( - f_table_name = "connection_node" - AND search_frame = Buffer(cn1.geom, {self.minimum_distance / 2}))); + f""" + SELECT * + FROM connection_node AS cn1, connection_node AS cn2 + WHERE ST_Distance(cn1.geom, cn2.geom) < {self.minimum_distance} + AND cn1.ROWID != cn2.ROWID + AND cn2.ROWID IN ( + SELECT ROWID + FROM SpatialIndex + WHERE ( + f_table_name = "connection_node" + AND search_frame = Buffer(cn1.geom, {self.minimum_distance / 2}))) """ ) - results = ( - session.connection() - .execute(query, {"min_distance": self.minimum_distance}) - .fetchall() - ) - + results = session.connection().execute(query).fetchall() return results def description(self) -> str: + return ( - f"The connection_node is within {self.minimum_distance} degrees of " + f"The connection_node is within {self.minimum_distance * 100 } cm of " f"another connection_node." 
) @@ -548,10 +544,10 @@ def get_invalid(self, session: Session) -> List[NamedTuple]: linestring = models.Channel.geom tol = self.min_distance breach_point = func.Line_Locate_Point( - transform(linestring), transform(func.ST_PointN(self.column, 1)) + linestring, func.ST_PointN(self.column, 1) ) - dist_1 = breach_point * length(linestring) - dist_2 = (1 - breach_point) * length(linestring) + dist_1 = breach_point * ST_Length(linestring) + dist_2 = (1 - breach_point) * ST_Length(linestring) return ( self.to_check(session) .join(models.Channel, self.table.c.channel_id == models.Channel.id) @@ -577,10 +573,8 @@ def get_invalid(self, session: Session) -> List[NamedTuple]: # First fetch the position of each potential breach per channel def get_position(point, linestring): - breach_point = func.Line_Locate_Point( - transform(linestring), transform(func.ST_PointN(point, 1)) - ) - return (breach_point * length(linestring)).label("position") + breach_point = func.Line_Locate_Point(linestring, func.ST_PointN(point, 1)) + return (breach_point * ST_Length(linestring)).label("position") potential_breaches = sorted( session.query(self.table, get_position(self.column, models.Channel.geom)) @@ -776,7 +770,7 @@ def get_invalid(self, session: Session) -> List[NamedTuple]: self.table.c.id, self.table.c.area, self.table.c.geom, - func.ST_Area(transform(self.table.c.geom)).label("calculated_area"), + func.ST_Area(self.table.c.geom).label("calculated_area"), ).subquery() return ( session.query(all_results) @@ -1114,3 +1108,64 @@ def get_invalid(self, session: Session) -> List[NamedTuple]: def description(self) -> str: return f"The values in {self.table.name}.{self.column_name} should add up to" + + +class ModelEPSGCheckValid(BaseCheck): + def __init__(self, *args, **kwargs): + super().__init__(column=models.ModelSettings.id, *args, **kwargs) + self.epsg_code = None + + def get_invalid(self, session: Session) -> List[NamedTuple]: + self.epsg_code = session.ref_epsg_code + if self.epsg_code 
class ModelEPSGCheckUnits(BaseCheck):
    """Check that every axis of the reference EPSG is expressed in metres.

    The code to inspect is read from ``session.ref_epsg_code``; the check
    reports nothing when that code is missing or unknown to pyproj.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(column=models.ModelSettings.id, *args, **kwargs)
        # Remembered by get_invalid so that description can mention the code.
        self.epsg_code = None

    def get_invalid(self, session: Session) -> List[NamedTuple]:
        """Return the checked rows when any CRS axis is not in metres."""
        self.epsg_code = session.ref_epsg_code
        if self.epsg_code is None:
            return []
        try:
            crs = pyproj.CRS.from_epsg(self.epsg_code)
        except pyproj.exceptions.CRSError:
            # handled by ModelEPSGCheckValid
            return []
        has_non_metre_axis = any(
            axis.unit_name != "metre" for axis in crs.axis_info
        )
        if has_non_metre_axis:
            return self.to_check(session).all()
        return []

    def description(self) -> str:
        """Return the message shown when the units are not metres."""
        return f"EPSG {self.epsg_code} is not fully defined in metres"
class RasterHasMatchingEPSGCheck(BaseRasterCheck):
    """Check whether a raster's EPSG code matches the reference EPSG code.

    The reference code and the name of its source are read from the
    ``ref_epsg_code`` and ``ref_epsg_name`` attributes of the session that is
    passed to ``get_invalid``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Both values are taken from the session in get_invalid.
        self.ref_epsg_name = ""
        self.ref_epsg_code = None

    def get_invalid(self, session):
        """Return the invalid rows; empty when there is no reference EPSG."""
        # Store the code before the early return so that a code from an
        # earlier session cannot linger on a reused check instance.
        self.ref_epsg_code = session.ref_epsg_code
        if self.ref_epsg_code is None:
            return []
        self.ref_epsg_name = session.ref_epsg_name
        return super().get_invalid(session)

    def is_valid(self, path: str, interface_cls: Type[RasterInterface]):
        """Return True when the raster at ``path`` passes this check.

        Rasters that are not a valid geotiff or that have no projection are
        accepted here; only a raster with a projection whose EPSG code is
        missing or different from the reference is rejected.
        """
        if self.ref_epsg_code is None:
            return True
        with interface_cls(path) as raster:
            if not raster.is_valid_geotiff or not raster.has_projection:
                return True
            if raster.epsg_code is None:
                return False
            return raster.epsg_code == self.ref_epsg_code

    def description(self):
        """Return the message shown for a raster with a deviating EPSG."""
        return (
            f"The file in {self.column_name} has no EPSG code or the EPSG code "
            f"does not match {self.ref_epsg_name}"
        )
A length of at least 5.0 m is recommended to avoid timestep reduction.", ) for table in [models.Channel, models.Culvert] @@ -1369,8 +1368,8 @@ def is_none_or_empty(col): invalid=Query(models.ExchangeLine) .join(models.Channel, models.Channel.id == models.ExchangeLine.channel_id) .filter( - geo_query.length(models.ExchangeLine.geom) - < (0.8 * geo_query.length(models.Channel.geom)) + geo_func.ST_Length(models.ExchangeLine.geom) + < (0.8 * geo_func.ST_Length(models.Channel.geom)) ), message=( "exchange_line.geom should not be significantly shorter than its " @@ -1384,7 +1383,7 @@ def is_none_or_empty(col): invalid=Query(models.ExchangeLine) .join(models.Channel, models.Channel.id == models.ExchangeLine.channel_id) .filter( - geo_query.distance(models.ExchangeLine.geom, models.Channel.geom) > 500.0 + geo_func.ST_Distance(models.ExchangeLine.geom, models.Channel.geom) > 500.0 ), message=("exchange_line.geom is far (> 500 m) from its corresponding channel"), ), @@ -1455,7 +1454,7 @@ def is_none_or_empty(col): invalid=Query(models.PotentialBreach) .join(models.Channel, models.Channel.id == models.PotentialBreach.channel_id) .filter( - geo_query.distance( + geo_func.ST_Distance( func.PointN(models.PotentialBreach.geom, 1), models.Channel.geom ) > TOLERANCE_M @@ -1579,19 +1578,6 @@ def is_none_or_empty(col): column=models.ModelSettings.manhole_aboveground_storage_area, min_value=0, ), - QueryCheck( - error_code=317, - column=models.ModelSettings.epsg_code, - invalid=CONDITIONS["has_no_dem"].filter(models.ModelSettings.epsg_code == None), - message="model_settings.epsg_code may not be NULL if no dem file is provided", - ), - QueryCheck( - error_code=318, - level=CheckLevel.WARNING, - column=models.ModelSettings.epsg_code, - invalid=CONDITIONS["has_dem"].filter(models.ModelSettings.epsg_code == None), - message="if model_settings.epsg_code is NULL, it will be extracted from the DEM later, however, the modelchecker will use ESPG:28992 for its spatial checks", - ), QueryCheck( 
def get_epsg_data_from_raster(session) -> Tuple[Optional[int], str]:
    """Retrieve the EPSG code of the DEM raster of the schematisation.

    Only ``model_settings.dem_file`` is inspected: the first non-empty value
    is resolved to a path through ``session.model_checker_context`` and opened
    with the raster interface of that context.

    Returns:
        A tuple ``(epsg_code, epsg_source)``. ``epsg_source`` is
        ``"model_settings.dem_file"`` when a code was found. ``(None, "")`` is
        returned when there is no context, no DEM file is set, the file cannot
        be resolved to a path, or the raster has no EPSG code.
    """
    no_epsg = (None, "")
    context = session.model_checker_context
    if context is None:
        # Without a context there is neither a raster interface nor a way to
        # resolve the file name to a path.
        return no_epsg

    raster = models.ModelSettings.dem_file
    row = session.query(raster).filter(raster != None, raster != "").first()  # noqa: E711
    if row is None:
        return no_epsg

    if isinstance(context, ServerContext):
        # A server context maps column names to paths; anything else (or a
        # missing entry) means the raster is not available.
        abs_path = None
        if isinstance(context.available_rasters, dict):
            abs_path = context.available_rasters.get(raster.name)
    else:
        abs_path = context.base_path.joinpath("rasters", row[0])
    if abs_path is None:
        return no_epsg

    with context.raster_interface(abs_path) as ro:
        if ro.epsg_code is not None:
            return ro.epsg_code, "model_settings.dem_file"
    return no_epsg
sqlalchemy.orm.session.Session """ s = threedi_db.get_session(future=True) - factories.inject_session(s) s.model_checker_context = LocalContext(base_path=data_dir) diff --git a/threedi_modelchecker/tests/factories.py b/threedi_modelchecker/tests/factories.py index 5086fbc3..b496cae0 100644 --- a/threedi_modelchecker/tests/factories.py +++ b/threedi_modelchecker/tests/factories.py @@ -4,6 +4,14 @@ from factory import Faker from threedi_schema import constants, models +# Default geometry strings + +DEFAULT_POINT = "SRID=28992;POINT (142742 473443)" +# line from DEFAULT_POINT to another point +DEFAULT_LINE = "SRID=28992;LINESTRING (142742 473443, 142747 473448)" +# polygon containing DEFAULT_POINT +DEFAULT_POLYGON = "SRID=28992;POLYGON ((142742 473443, 142743 473443, 142744 473444, 142742 473444, 142742 473443))" + def inject_session(session): """Inject the session into all factories""" @@ -47,7 +55,7 @@ class Meta: sqlalchemy_session = None code = Faker("name") - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT class ChannelFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -58,7 +66,7 @@ class Meta: display_name = Faker("name") code = Faker("name") exchange_type = constants.CalculationType.CONNECTED - geom = "SRID=4326;LINESTRING(-71.064544 42.28787, -71.0645 42.287)" + geom = DEFAULT_LINE class WeirFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -75,7 +83,7 @@ class Meta: sewerage = False connection_node_id_start = 1 connection_node_id_end = 1 - geom = "SRID=4326;LINESTRING(4.885534714757985 52.38513158257129,4.88552805617346 52.38573773758626)" + geom = DEFAULT_LINE class BoundaryConditions2DFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -86,7 +94,7 @@ class Meta: type = constants.BoundaryType.WATERLEVEL.value timeseries = "0,-0.5" display_name = Faker("name") - geom = "SRID=4326;LINESTRING(4.885534714757985 52.38513158257129,4.88552805617346 52.38573773758626)" + geom = DEFAULT_LINE class 
BoundaryConditions1DFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -97,7 +105,7 @@ class Meta: type = constants.BoundaryType.WATERLEVEL timeseries = "0,-0.5" connection_node_id = 1 - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT class PumpMapFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -105,7 +113,7 @@ class Meta: model = models.PumpMap sqlalchemy_session = None - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT class PumpFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -120,7 +128,7 @@ class Meta: start_level = 1.0 lower_stop_level = 0.0 capacity = 5.0 - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT class CrossSectionLocationFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -132,7 +140,7 @@ class Meta: reference_level = 0.0 friction_type = constants.FrictionType.CHEZY friction_value = 0.0 - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT class AggregationSettingsFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -163,7 +171,7 @@ class Meta: timeseries = "0,-0.1" connection_node_id = 1 - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT class Lateral2DFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -172,7 +180,7 @@ class Meta: sqlalchemy_session = None timeseries = "0,-0.2" - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT type = constants.Later2dType.SURFACE @@ -195,7 +203,7 @@ class Meta: model = models.DryWeatherFlowMap sqlalchemy_session = None - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT class DryWeatherFlowFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -203,7 +211,7 @@ class Meta: model = models.DryWeatherFlow sqlalchemy_session = None - geom = "SRID=4326;POLYGON((30 10, 40 40, 20 40, 10 20, 30 10))" + geom = DEFAULT_POLYGON class DryWeatherFlowDistributionFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -223,7 +231,7 @@ class Meta: sqlalchemy_session = None area = 0.0 - geom = 
"SRID=4326;POLYGON((30 10, 40 40, 20 40, 10 20, 30 10))" + geom = DEFAULT_POLYGON # surface_parameters = factory.SubFactory(SurfaceParameterFactory) @@ -234,7 +242,7 @@ class Meta: sqlalchemy_session = None percentage = 100.0 - geom = "SRID=4326;LINESTRING(30 10, 10 30, 40 40)" + geom = DEFAULT_LINE class TagsFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -253,7 +261,7 @@ class Meta: measure_operator = constants.MeasureOperators.greater_than target_type = constants.StructureControlTypes.channel target_id = 10 - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT class ControlMemoryFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -270,7 +278,7 @@ class Meta: is_active = True upper_threshold = 1.0 lower_threshold = -1.0 - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT class ControlMeasureMapFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -279,7 +287,7 @@ class Meta: sqlalchemy_session = None control_type = constants.MeasureVariables.waterlevel - geom = "SRID=4326;LINESTRING(30 10, 10 30, 40 40)" + geom = DEFAULT_LINE class ControlMeasureLocationFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -287,7 +295,7 @@ class Meta: model = models.ControlMeasureLocation sqlalchemy_session = None - geom = "SRID=4326;POINT(-71.064544 42.28787)" + geom = DEFAULT_POINT class CulvertFactory(factory.alchemy.SQLAlchemyModelFactory): @@ -298,7 +306,7 @@ class Meta: code = "code" display_name = Faker("name") exchange_type = constants.CalculationTypeCulvert.ISOLATED_NODE - geom = "SRID=4326;LINESTRING(-71.064544 42.28787, -71.0645 42.287)" + geom = DEFAULT_LINE friction_value = 0.03 friction_type = 2 invert_level_start = 0.1 @@ -314,7 +322,7 @@ class Meta: display_name = Faker("name") code = "code" - geom = "SRID=4326;LINESTRING(-71.06452 42.2874, -71.06452 42.286)" + geom = DEFAULT_LINE class VegetationDragFactory(factory.alchemy.SQLAlchemyModelFactory): diff --git a/threedi_modelchecker/tests/test_checks_base.py 
b/threedi_modelchecker/tests/test_checks_base.py index 18a48cef..f50ee39e 100644 --- a/threedi_modelchecker/tests/test_checks_base.py +++ b/threedi_modelchecker/tests/test_checks_base.py @@ -4,11 +4,11 @@ from sqlalchemy.orm import Query from threedi_schema import constants, custom_types, models -from threedi_modelchecker.checks import geo_query from threedi_modelchecker.checks.base import ( _sqlalchemy_to_sqlite_types, AllEqualCheck, EnumCheck, + EPSGGeomCheck, ForeignKeyCheck, GeometryCheck, GeometryTypeCheck, @@ -158,7 +158,6 @@ def test_unique_check_multiple_description(): def test_all_equal_check(session): factories.ModelSettingsFactory(minimum_table_step_size=0.5) factories.ModelSettingsFactory(minimum_table_step_size=0.5) - check = AllEqualCheck(models.ModelSettings.minimum_table_step_size) invalid_rows = check.get_invalid(session) assert len(invalid_rows) == 0 @@ -310,7 +309,7 @@ def test_type_check_boolean(session): def test_geometry_check(session): - factories.ConnectionNodeFactory(geom="SRID=4326;POINT(-371.064544 42.28787)") + factories.ConnectionNodeFactory(geom=factories.DEFAULT_POINT) geometry_check = GeometryCheck(models.ConnectionNode.geom) invalid_rows = geometry_check.get_invalid(session) @@ -319,9 +318,7 @@ def test_geometry_check(session): def test_geometry_type_check(session): - factories.ConnectionNodeFactory.create_batch( - 2, geom="SRID=4326;POINT(-71.064544 42.28787)" - ) + factories.ConnectionNodeFactory.create_batch(2, geom=factories.DEFAULT_POINT) geometry_type_check = GeometryTypeCheck(models.ConnectionNode.geom) invalid_rows = geometry_type_check.get_invalid(session) @@ -532,58 +529,6 @@ def test_global_settings_use_1d_flow_and_no_1d_elements(session): assert len(errors) == 0 -def test_length_geom_linestring_in_4326(session): - factories.ModelSettingsFactory(epsg_code=28992) - channel_too_short = factories.ChannelFactory( - geom="SRID=4326;LINESTRING(" - "-0.38222938832999598 -0.13872236685816669, " - "-0.38222930900909202 
-0.13872236685816669)", - ) - factories.ChannelFactory( - geom="SRID=4326;LINESTRING(" - "-0.38222938468305784 -0.13872235682908687, " - "-0.38222931083256106 -0.13872235591735235, " - "-0.38222930992082654 -0.13872207236791409, " - "-0.38222940929989008 -0.13872235591735235)", - ) - - q = Query(models.Channel).filter(geo_query.length(models.Channel.geom) < 0.05) - check_length_linestring = QueryCheck( - column=models.Channel.geom, - invalid=q, - message="Length of the channel is too short, should be at least 0.05m", - ) - - errors = check_length_linestring.get_invalid(session) - assert len(errors) == 1 - assert errors[0].id == channel_too_short.id - - -def test_length_geom_linestring_missing_epsg_from_global_settings(session): - factories.ChannelFactory( - geom="SRID=4326;LINESTRING(" - "-0.38222938832999598 -0.13872236685816669, " - "-0.38222930900909202 -0.13872236685816669)", - ) - factories.ChannelFactory( - geom="SRID=4326;LINESTRING(" - "-0.38222938468305784 -0.13872235682908687, " - "-0.38222931083256106 -0.13872235591735235, " - "-0.38222930992082654 -0.13872207236791409, " - "-0.38222940929989008 -0.13872235591735235)", - ) - - q = Query(models.Channel).filter(geo_query.length(models.Channel.geom) < 0.05) - check_length_linestring = QueryCheck( - column=models.Channel.geom, - invalid=q, - message="Length of the channel is too short, should be at least 0.05m", - ) - - errors = check_length_linestring.get_invalid(session) - assert len(errors) == 1 - - @pytest.mark.parametrize( "min_value,max_value,left_inclusive,right_inclusive", [ @@ -646,3 +591,20 @@ def test_list_of_inst_check(session, tag_ids_string, nof_invalid_expected): check = ListOfIntsCheck(column=models.DryWeatherFlow.tags) invalid_rows = check.get_invalid(session) assert len(invalid_rows) == nof_invalid_expected + + +@pytest.mark.parametrize( + "ref_epsg, valid", + [ + (28992, True), + (4326, False), + (None, True), # skip check when there is no epsg + ], +) +def test_epsg_geom_check(session, 
ref_epsg, valid): + factories.ConnectionNodeFactory() + session.ref_epsg_code = ref_epsg + session.ref_epsg_name = "foo" + check = EPSGGeomCheck(column=models.ConnectionNode.geom) + invalids = check.get_invalid(session) + assert (len(invalids) == 0) == valid diff --git a/threedi_modelchecker/tests/test_checks_factories.py b/threedi_modelchecker/tests/test_checks_factories.py index 6eeb33f3..adbf9755 100644 --- a/threedi_modelchecker/tests/test_checks_factories.py +++ b/threedi_modelchecker/tests/test_checks_factories.py @@ -5,6 +5,8 @@ from threedi_modelchecker.checks.factories import ( ForeignKeyCheckSetting, generate_enum_checks, + generate_epsg_geom_checks, + generate_epsg_raster_checks, generate_foreign_key_checks, generate_geometry_checks, generate_not_null_checks, @@ -112,3 +114,34 @@ def test_gen_enum_checks_custom_mapping(name): checks = {check.column.name: check for check in enum_checks} assert checks["aggregation_method"].level == CheckLevel.WARNING assert checks["flow_variable"].level == CheckLevel.ERROR + + +@pytest.mark.parametrize( + "model, nof_checks_expected", + [ + (models.ModelSettings, 0), + (models.ConnectionNode, 1), + ], +) +def test_gen_epsg_geom_checks(model, nof_checks_expected): + assert len(generate_epsg_geom_checks(model.__table__)) == nof_checks_expected + + +@pytest.mark.parametrize( + "model, nof_checks_expected", + [ + (models.ConnectionNode, 0), + (models.GroundWater, 1), + (models.ModelSettings, 2), + ], +) +def test_gen_epsg_raster_checks(model, nof_checks_expected): + raster_columns = [ + models.ModelSettings.dem_file, + models.ModelSettings.friction_coefficient_file, + models.GroundWater.equilibrium_infiltration_rate_file, + ] + assert ( + len(generate_epsg_raster_checks(model.__table__, raster_columns)) + == nof_checks_expected + ) diff --git a/threedi_modelchecker/tests/test_checks_location.py b/threedi_modelchecker/tests/test_checks_location.py index af3c3572..668f0460 100644 --- 
a/threedi_modelchecker/tests/test_checks_location.py +++ b/threedi_modelchecker/tests/test_checks_location.py @@ -12,51 +12,47 @@ ) from threedi_modelchecker.tests import factories +SRID = 28992 +POINT1 = "142742 473443" +POINT2 = "142743 473443" # POINT1 + 1 meter +POINT3 = "142747 473443" # POINT1 + 5 meter -def test_point_location_check(session): - factories.ChannelFactory( - id=1, - geom="SRID=4326;LINESTRING(5.387204 52.155172, 5.387204 52.155262)", - ) - factories.CrossSectionLocationFactory( - channel_id=1, geom="SRID=4326;POINT(5.387204 52.155200)" - ) - factories.CrossSectionLocationFactory( - channel_id=1, geom="SRID=4326;POINT(5.387218 52.155244)" + +@pytest.mark.parametrize( + "bc_geom, nof_invalid", [(POINT1, 0), (POINT2, 0), (POINT3, 1)] +) +def test_point_location_check(session, bc_geom, nof_invalid): + factories.ConnectionNodeFactory(id=0, geom=f"SRID={SRID};POINT({POINT1})") + factories.BoundaryConditions1DFactory( + id=0, connection_node_id=0, geom=f"SRID={SRID};POINT({bc_geom})" ) errors = PointLocationCheck( - column=models.CrossSectionLocation.geom, - ref_column=models.CrossSectionLocation.channel_id, - ref_table=models.Channel, - max_distance=0.1, + column=models.BoundaryCondition1D.geom, + ref_column=models.BoundaryCondition1D.connection_node_id, + ref_table=models.ConnectionNode, + max_distance=1, ).get_invalid(session) - assert len(errors) == 1 + assert len(errors) == nof_invalid @pytest.mark.parametrize( "channel_geom, nof_invalid", [ - ("LINESTRING(5.387204 52.155172, 5.387204 52.155262)", 0), - ("LINESTRING(5.387218 52.155172, 5.387218 52.155262)", 0), # within tolerance - ("LINESTRING(5.387204 52.155262, 5.387204 52.155172)", 0), # reversed - ( - "LINESTRING(5.387218 52.155262, 5.387218 52.155172)", - 0, - ), # reversed, within tolerance - ( - "LINESTRING(5.387204 52.164151, 5.387204 52.155262)", - 1, - ), # startpoint is wrong - ("LINESTRING(5.387204 52.155172, 5.387204 52.164151)", 1), # endpoint is wrong + (f"LINESTRING({POINT1}, 
{POINT3})", 0), + (f"LINESTRING({POINT2}, {POINT3})", 0), # within tolerance + (f"LINESTRING({POINT3}, {POINT1})", 0), # reversed + (f"LINESTRING({POINT3}, {POINT2})", 0), # reversed, within tolerance + (f"LINESTRING(142732 473443, {POINT3})", 1), # startpoint is wrong + (f"LINESTRING({POINT1}, 142752 473443)", 1), # endpoint is wrong ], ) def test_linestring_location_check(session, channel_geom, nof_invalid): - factories.ConnectionNodeFactory(id=1, geom="SRID=4326;POINT(5.387204 52.155172)") - factories.ConnectionNodeFactory(id=2, geom="SRID=4326;POINT(5.387204 52.155262)") + factories.ConnectionNodeFactory(id=1, geom=f"SRID={SRID};POINT({POINT1})") + factories.ConnectionNodeFactory(id=2, geom=f"SRID={SRID};POINT({POINT3})") factories.ChannelFactory( connection_node_id_start=1, connection_node_id_end=2, - geom=f"SRID=4326;{channel_geom}", + geom=f"SRID={SRID};{channel_geom}", ) errors = LinestringLocationCheck( column=models.Channel.geom, @@ -72,20 +68,17 @@ def test_linestring_location_check(session, channel_geom, nof_invalid): @pytest.mark.parametrize( "channel_geom, nof_invalid", [ - ("LINESTRING(5.387204 52.155172, 5.387204 52.155262)", 0), - ( - "LINESTRING(5.387204 52.164151, 5.387204 52.155262)", - 1, - ), # startpoint is wrong + (f"LINESTRING({POINT1}, {POINT3})", 0), + (f"LINESTRING(142732 473443, {POINT3})", 1), # startpoint is wrong ], ) def test_connection_node_linestring_location_check(session, channel_geom, nof_invalid): - factories.ConnectionNodeFactory(id=1, geom="SRID=4326;POINT(5.387204 52.155172)") - factories.ConnectionNodeFactory(id=2, geom="SRID=4326;POINT(5.387204 52.155262)") + factories.ConnectionNodeFactory(id=1, geom=f"SRID={SRID};POINT({POINT1})") + factories.ConnectionNodeFactory(id=2, geom=f"SRID={SRID};POINT({POINT3})") factories.ChannelFactory( connection_node_id_start=1, connection_node_id_end=2, - geom=f"SRID=4326;{channel_geom}", + geom=f"SRID={SRID};{channel_geom}", ) errors = ConnectionNodeLinestringLocationCheck(
column=models.Channel.geom, max_distance=1.01 @@ -96,11 +89,8 @@ def test_connection_node_linestring_location_check(session, channel_geom, nof_in @pytest.mark.parametrize( "channel_geom, nof_invalid", [ - ("LINESTRING(5.387204 52.155172, 5.387204 52.155262)", 0), - ( - "LINESTRING(5.387204 52.164151, 5.387204 52.155262)", - 1, - ), # startpoint is wrong + (f"LINESTRING({POINT1}, {POINT3})", 0), + (f"LINESTRING(142732 473443, {POINT3})", 1), # startpoint is wrong ], ) @pytest.mark.parametrize( @@ -110,22 +100,20 @@ def test_connection_node_linestring_location_check(session, channel_geom, nof_in def test_control_measure_map_linestring_map_location_check( session, control_table, control_type, channel_geom, nof_invalid ): - factories.ControlMeasureLocationFactory( - id=1, geom="SRID=4326;POINT(5.387204 52.155172)" - ) - factories.ControlMemoryFactory(id=1, geom="SRID=4326;POINT(5.387204 52.155262)") - factories.ControlTableFactory(id=1, geom="SRID=4326;POINT(5.387204 52.155262)") + factories.ControlMeasureLocationFactory(id=1, geom=f"SRID={SRID};POINT({POINT1})") + factories.ControlMemoryFactory(id=1, geom=f"SRID={SRID};POINT({POINT3})") + factories.ControlTableFactory(id=1, geom=f"SRID={SRID};POINT({POINT3})") factories.ControlMeasureMapFactory( measure_location_id=1, control_id=1, control_type="memory", - geom=f"SRID=4326;{channel_geom}", + geom=f"SRID={SRID};{channel_geom}", ) factories.ControlMeasureMapFactory( measure_location_id=1, control_id=1, control_type="table", - geom=f"SRID=4326;{channel_geom}", + geom=f"SRID={SRID};{channel_geom}", ) errors = ControlMeasureMapLinestringMapLocationCheck( control_table=control_table, @@ -138,23 +126,19 @@ def test_control_measure_map_linestring_map_location_check( @pytest.mark.parametrize( "channel_geom, nof_invalid", [ - ("LINESTRING(5.387204 52.155172, 5.387204 52.155262)", 0), - ( - "LINESTRING(5.387204 52.164151, 5.387204 52.155262)", - 1, - ), # startpoint is wrong + (f"LINESTRING({POINT1}, {POINT3})", 0), + 
(f"LINESTRING(142732 473443, {POINT3})", 1), # startpoint is wrong ], ) def test_dwf_map_linestring_map_location_check(session, channel_geom, nof_invalid): - factories.ConnectionNodeFactory(id=1, geom="SRID=4326;POINT(5.387204 52.155172)") + factories.ConnectionNodeFactory(id=1, geom=f"SRID={SRID};POINT({POINT1})") + factories.DryWeatherFlowFactory( id=1, - geom="SRID=4326;POLYGON((5.387204 52.155262, 5.387204 52.155172, 5.387304 52.155172, 5.387304 52.155262, 5.387204 52.155262))", + geom=f"SRID={SRID};POLYGON(({POINT1}, {POINT3}, 142742 473445, {POINT1}))", ) factories.DryWheatherFlowMapFactory( - connection_node_id=1, - dry_weather_flow_id=1, - geom=f"SRID=4326;{channel_geom}", + connection_node_id=1, dry_weather_flow_id=1, geom=f"SRID={SRID};{channel_geom}" ) errors = DWFMapLinestringLocationCheck(max_distance=1.01).get_invalid(session) assert len(errors) == nof_invalid @@ -163,20 +147,15 @@ def test_dwf_map_linestring_map_location_check(session, channel_geom, nof_invali @pytest.mark.parametrize( "channel_geom, nof_invalid", [ - ("LINESTRING(5.387204 52.155172, 5.387204 52.155262)", 0), - ( - "LINESTRING(5.387204 52.164151, 5.387204 52.155262)", - 1, - ), # startpoint is wrong + (f"LINESTRING({POINT1}, {POINT3})", 0), + (f"LINESTRING(142732 473443, {POINT3})", 1), # startpoint is wrong ], ) def test_pump_map_linestring_map_location_check(session, channel_geom, nof_invalid): - factories.ConnectionNodeFactory(id=1, geom="SRID=4326;POINT(5.387204 52.155172)") - factories.PumpFactory(id=1, geom="SRID=4326;POINT(5.387204 52.155262)") + factories.ConnectionNodeFactory(id=1, geom=f"SRID={SRID};POINT({POINT1})") + factories.PumpFactory(id=1, geom=f"SRID={SRID};POINT({POINT3})") factories.PumpMapFactory( - connection_node_id_end=1, - pump_id=1, - geom=f"SRID=4326;{channel_geom}", + connection_node_id_end=1, pump_id=1, geom=f"SRID={SRID};{channel_geom}" ) errors = PumpMapLinestringLocationCheck(max_distance=1.01).get_invalid(session) assert len(errors) == nof_invalid @@ 
-185,23 +164,18 @@ def test_pump_map_linestring_map_location_check(session, channel_geom, nof_inval @pytest.mark.parametrize( "channel_geom, nof_invalid", [ - ("LINESTRING(5.387204 52.155172, 5.387204 52.155262)", 0), - ( - "LINESTRING(5.387204 52.164151, 5.387204 52.155262)", - 1, - ), # startpoint is wrong + (f"LINESTRING({POINT1}, {POINT3})", 0), + (f"LINESTRING(142732 473443, {POINT3})", 1), # startpoint is wrong ], ) def test_surface_map_linestring_map_location_check(session, channel_geom, nof_invalid): - factories.ConnectionNodeFactory(id=1, geom="SRID=4326;POINT(5.387204 52.155172)") + factories.ConnectionNodeFactory(id=1, geom=f"SRID={SRID};POINT({POINT1})") factories.SurfaceFactory( id=1, - geom="SRID=4326;POLYGON((5.387204 52.155262, 5.387204 52.155172, 5.387304 52.155172, 5.387304 52.155262, 5.387204 52.155262))", + geom=f"SRID={SRID};POLYGON(({POINT1}, {POINT3}, 142742 473445, {POINT1}))", ) factories.SurfaceMapFactory( - connection_node_id=1, - surface_id=1, - geom=f"SRID=4326;{channel_geom}", + connection_node_id=1, surface_id=1, geom=f"SRID={SRID};{channel_geom}" ) errors = SurfaceMapLinestringLocationCheck(max_distance=1.01).get_invalid(session) assert len(errors) == nof_invalid diff --git a/threedi_modelchecker/tests/test_checks_other.py b/threedi_modelchecker/tests/test_checks_other.py index 73aec95a..04038e17 100644 --- a/threedi_modelchecker/tests/test_checks_other.py +++ b/threedi_modelchecker/tests/test_checks_other.py @@ -1,8 +1,7 @@ from unittest import mock import pytest -from sqlalchemy import func, select, text -from sqlalchemy.orm import aliased, Query +from sqlalchemy import select, text from threedi_schema import constants, models, ThreediDatabase from threedi_schema.beta_features import BETA_COLUMNS, BETA_VALUES @@ -25,6 +24,9 @@ FeatureClosedCrossSectionCheck, InflowNoFeaturesCheck, MaxOneRecordCheck, + ModelEPSGCheckProjected, + ModelEPSGCheckUnits, + ModelEPSGCheckValid, NodeSurfaceConnectionsCheck, OpenChannelsWithNestedNewton, 
PotentialBreachInterdistanceCheck, @@ -41,6 +43,11 @@ from . import factories +SRID = 28992 +POINT1 = "142742 473443" +POINT2 = "142743 473443" # POINT1 + 1 meter +POINT3 = "142747 473443" # POINT1 + 5 meter + @pytest.mark.parametrize( "aggregation_method,flow_variable,expected_result", @@ -73,29 +80,23 @@ def test_aggregation_settings( def test_connection_nodes_length(session): - factories.ModelSettingsFactory(epsg_code=28992) - factories.ConnectionNodeFactory(id=1, geom="SRID=4326;POINT(-71.064544 42.28787)") - factories.ConnectionNodeFactory(id=2, geom="SRID=4326;POINT(-71.0645 42.287)") - factories.ConnectionNodeFactory( - id=3, geom="SRID=4326;POINT(-0.38222938832999598 -0.13872236685816669)" - ), - factories.ConnectionNodeFactory( - id=4, geom="SRID=4326;POINT(-0.38222930900909202 -0.13872236685816669)" - ), + factories.ModelSettingsFactory() + factories.ConnectionNodeFactory(id=1, geom=f"SRID={SRID};POINT({POINT1})") + factories.ConnectionNodeFactory(id=2, geom=f"SRID={SRID};POINT({POINT2})") factories.WeirFactory( connection_node_id_start=1, connection_node_id_end=2, ) weir_too_short = factories.WeirFactory( - connection_node_id_start=3, - connection_node_id_end=4, + connection_node_id_start=1, + connection_node_id_end=1, ) check_length = ConnectionNodesLength( column=models.Weir.id, start_node=models.Weir.connection_node_id_start, end_node=models.Weir.connection_node_id_end, - min_distance=0.05, + min_distance=0.01, ) errors = check_length.get_invalid(session) @@ -104,11 +105,9 @@ def test_connection_nodes_length(session): def test_connection_nodes_length_missing_start_node(session): - factories.ModelSettingsFactory(epsg_code=28992) + factories.ModelSettingsFactory() factories.WeirFactory(connection_node_id_start=1, connection_node_id_end=2) - factories.ConnectionNodeFactory( - id=2, geom="SRID=4326;POINT(-0.38222930900909202 -0.13872236685816669)" - ) + factories.ConnectionNodeFactory(id=2) check_length = ConnectionNodesLength( column=models.Weir.id, 
start_node=models.Weir.connection_node_id_start, @@ -123,12 +122,9 @@ def test_connection_nodes_length_missing_start_node(session): def test_connection_nodes_length_missing_end_node(session): if session.bind.name == "postgresql": pytest.skip("Postgres only accepts coords in epsg 4326") - factories.ModelSettingsFactory(epsg_code=28992) + factories.ModelSettingsFactory() factories.WeirFactory(connection_node_id_start=1, connection_node_id_end=2) - factories.ConnectionNodeFactory( - id=1, geom="SRID=4326;POINT(-0.38222930900909202 -0.13872236685816669)" - ) - + factories.ConnectionNodeFactory(id=1) check_length = ConnectionNodesLength( column=models.Weir.id, start_node=models.Weir.connection_node_id_start, @@ -142,29 +138,29 @@ def test_connection_nodes_length_missing_end_node(session): def test_open_channels_with_nested_newton(session): factories.NumericalSettingsFactory(use_nested_newton=0) - factories.ConnectionNodeFactory(id=1, geom="SRID=4326;POINT(-71.064544 42.28787)") - factories.ConnectionNodeFactory(id=2, geom="SRID=4326;POINT(-71.0645 42.287)") + factories.ConnectionNodeFactory(id=1, geom=f"SRID={SRID};POINT({POINT1})") + factories.ConnectionNodeFactory(id=2, geom=f"SRID={SRID};POINT({POINT2})") factories.ChannelFactory( id=1, connection_node_id_start=1, connection_node_id_end=2, - geom="SRID=4326;LINESTRING(-71.064544 42.28787, -71.0645 42.287)", + geom=f"SRID={SRID};LINESTRING({POINT1}, {POINT2})", ) factories.CrossSectionLocationFactory( channel_id=1, cross_section_shape=constants.CrossSectionShape.TABULATED_TRAPEZIUM, cross_section_table="0,1\n1,0", - geom="SRID=4326;POINT(-71.0645 42.287)", + geom=f"SRID={SRID};POINT({POINT2})", ) factories.ChannelFactory( id=2, connection_node_id_start=1, connection_node_id_end=2, - geom="SRID=4326;LINESTRING(-71.064544 42.28787, -71.0645 42.287)", + geom=f"SRID={SRID};LINESTRING({POINT1}, {POINT2})", ) factories.CrossSectionLocationFactory( channel_id=2, - geom="SRID=4326;POINT(-71.0645 42.287)", + 
geom=f"SRID={SRID};POINT({POINT2})", cross_section_shape=constants.CrossSectionShape.EGG, ) @@ -197,33 +193,31 @@ def test_channel_manhole_level_check( # using factories, create one minimal test case which passes, and one which fails # once that works, parametrise. # use nested factories for channel and connectionNode - starting_coordinates = "4.718300 52.696686" - ending_coordinates = "4.718255 52.696709" factories.ConnectionNodeFactory( id=1, - geom=f"SRID=4326;POINT({starting_coordinates})", + geom=f"SRID={SRID};POINT({POINT1})", bottom_level=manhole_level, ) factories.ConnectionNodeFactory( id=2, - geom=f"SRID=4326;POINT({ending_coordinates})", + geom=f"SRID={SRID};POINT({POINT2})", bottom_level=manhole_level, ) factories.ChannelFactory( id=1, - geom=f"SRID=4326;LINESTRING({starting_coordinates}, {ending_coordinates})", + geom=f"SRID={SRID};LINESTRING({POINT1}, {POINT2})", connection_node_id_start=1, connection_node_id_end=2, ) # starting cross-section location factories.CrossSectionLocationFactory( - geom="SRID=4326;POINT(4.718278 52.696697)", + geom=f"SRID={SRID};POINT(142742.25 473443)", reference_level=starting_reference_level, channel_id=1, ) # ending cross-section location factories.CrossSectionLocationFactory( - geom="SRID=4326;POINT(4.718264 52.696704)", + geom=f"SRID={SRID};POINT(142743.25 473443)", reference_level=ending_reference_level, channel_id=1, ) @@ -232,33 +226,18 @@ def test_channel_manhole_level_check( assert len(errors) == errors_number -def test_node_distance(session): - con1_too_close = factories.ConnectionNodeFactory( - geom="SRID=4326;POINT(4.728282 52.64579283592512)" - ) - con2_too_close = factories.ConnectionNodeFactory( - geom="SRID=4326;POINT(4.72828 52.64579283592512)" - ) - # Good distance - factories.ConnectionNodeFactory( - geom="SRID=4326;POINT(4.726838755789598 52.64514133594995)" - ) - - # sanity check to see the distances between the nodes - node_a = aliased(models.ConnectionNode) - node_b = aliased(models.ConnectionNode) - 
distances_query = Query(func.ST_Distance(node_a.geom, node_b.geom, 1)).filter( - node_a.id != node_b.id - ) - # Shows the distances between all 3 nodes: node 1 and 2 are too close - distances_query.with_session(session).all() - - check = ConnectionNodesDistance(minimum_distance=10) +@pytest.mark.parametrize("other_x, valid", [(142740.05, False), (142740.15, True)]) +def test_node_distance_alt(session, other_x, valid): + factories.ConnectionNodeFactory(id=0, geom=f"SRID={SRID};POINT(142740 473443)") + factories.ConnectionNodeFactory(id=1, geom=f"SRID={SRID};POINT({other_x} 473443)") + # Note that the test uses plain sqlite, so this needs to be committed + session.commit() + check = ConnectionNodesDistance(minimum_distance=0.1) invalid = check.get_invalid(session) - assert len(invalid) == 2 - invalid_ids = [i.id for i in invalid] - assert con1_too_close.id in invalid_ids - assert con2_too_close.id in invalid_ids + assert (len(invalid) == 0) == valid + # Remove connection nodes + session.query(models.ConnectionNode).delete() + session.commit() class TestCrossSectionSameConfiguration: @@ -390,16 +369,13 @@ def test_cross_section_same_configuration( """ factories.ChannelFactory( id=1, - geom="SRID=4326;LINESTRING(4.718301 52.696686, 4.718255 52.696709)", ) factories.ChannelFactory( id=2, - geom="SRID=4326;LINESTRING(4.718301 52.696686, 4.718255 52.696709)", ) # shape 1 is always open factories.CrossSectionLocationFactory( channel_id=1, - geom="SRID=4326;POINT(4.718278 52.696697)", cross_section_shape=1, cross_section_width=3, cross_section_height=4, @@ -407,7 +383,6 @@ def test_cross_section_same_configuration( # the second one is parametrised factories.CrossSectionLocationFactory( channel_id=1 if same_channels else 2, - geom="SRID=4326;POINT(4.718265 52.696704)", cross_section_shape=shape, cross_section_width=width, cross_section_height=height, @@ -434,20 +409,20 @@ def test_spatial_index_disabled(empty_sqlite_v4): @pytest.mark.parametrize( - "x,y,ok", + "start,ok", [ 
- (-71.064544, 42.28787, True), # at start - (-71.0645, 42.287, True), # at end - (-71.06452, 42.2874, True), # middle - (-71.064544, 42.287869, False), # close to start - (-71.064499, 42.287001, False), # close to end + ("142742 473443", True), # at start + ("142747 473448", True), # at end + ("142744.5 473445.5", True), # middle + ("142742 473443.01", False), # close to start + ("142747 473447.99", False), # close to end ], ) -def test_potential_breach_start_end(session, x, y, ok): - # channel geom: LINESTRING (-71.064544 42.28787, -71.0645 42.287) +def test_potential_breach_start_end(session, start, ok): + # channel geom: "SRID={SRID};LINESTRING (142742 473443, 142747 473448)" factories.ChannelFactory(id=1) factories.PotentialBreachFactory( - geom=f"SRID=4326;LINESTRING({x} {y}, -71.064544 42.286)", channel_id=1 + geom=f"SRID={SRID};LINESTRING({start}, 142750 473450)", channel_id=1 ) check = PotentialBreachStartEndCheck(models.PotentialBreach.geom, min_distance=1.0) invalid = check.get_invalid(session) @@ -458,21 +433,23 @@ def test_potential_breach_start_end(session, x, y, ok): @pytest.mark.parametrize( - "x,y,ok", + "start,channel_id,ok", [ - (-71.06452, 42.2874, True), # exactly on other - (-71.06452, 42.287401, False), # too close to other - (-71.0645, 42.287, True), # far enough from other + (POINT1, 1, True), # exactly on other + ("142742.5 473443.5", 1, False), # too close to other + ("142742.5 473443.5", 2, True), # too close to other but other channel + ("142740 473440", 1, True), # far enough from other ], ) -def test_potential_breach_interdistance(session, x, y, ok): - # channel geom: LINESTRING (-71.064544 42.28787, -71.0645 42.287) +def test_potential_breach_interdistance(session, start, channel_id, ok): + # channel geom: "SRID=28992;LINESTRING (142742 473443, 142747 473448)" factories.ChannelFactory(id=1) + factories.ChannelFactory(id=2) factories.PotentialBreachFactory( - geom="SRID=4326;LINESTRING(-71.06452 42.2874, -71.0646 42.286)",
channel_id=1 + geom=f"SRID={SRID};LINESTRING({POINT1}, {POINT3})", channel_id=1 ) factories.PotentialBreachFactory( - geom=f"SRID=4326;LINESTRING({x} {y}, -71.064544 42.286)", channel_id=1 + geom=f"SRID={SRID};LINESTRING({start}, 142737 473438)", channel_id=channel_id ) check = PotentialBreachInterdistanceCheck( models.PotentialBreach.geom, min_distance=1.0 @@ -484,23 +461,6 @@ def test_potential_breach_interdistance(session, x, y, ok): assert len(invalid) == 1 -def test_potential_breach_interdistance_other_channel(session): - factories.ChannelFactory(id=1) - factories.ChannelFactory(id=2) - factories.PotentialBreachFactory( - geom="SRID=4326;LINESTRING(-71.06452 42.2874, -71.0646 42.286)", channel_id=1 - ) - factories.PotentialBreachFactory( - geom="SRID=4326;LINESTRING(-71.06452 42.287401, -71.064544 42.286)", - channel_id=2, - ) - check = PotentialBreachInterdistanceCheck( - models.PotentialBreach.geom, min_distance=1.0 - ) - invalid = check.get_invalid(session) - assert len(invalid) == 0 - - @pytest.mark.parametrize( "storage_area,time_step,expected_result,capacity", [ @@ -605,13 +565,18 @@ def test_feature_closed_cross_section(session, shape, expected_result): @pytest.mark.parametrize( "defined_area, max_difference, expected_result", [ - (1.5, 1.2, 0), - (2, 1, 1), - (1, 0.5, 1), + (1.2, 0.5, 0), + (1, 1, 0), + (2.1, 1, 1), + (1.6, 0.5, 1), ], ) def test_defined_area(session, defined_area, max_difference, expected_result): - geom = "SRID=4326;POLYGON((4.7 52.5, 4.7 52.50001, 4.70001 52.50001, 4.70001 52.50001))" + # create square polygon with area 1 + x0 = int(POINT1.split()[0]) + y0 = int(POINT1.split()[1]) + geom = f"SRID={SRID};POLYGON(({x0} {y0}, {x0 + 1} {y0}, {x0 + 1} {y0 + 1}, {x0} {y0 + 1}, {x0} {y0}))" + factories.SurfaceFactory(area=defined_area, geom=geom) check = DefinedAreaCheck(models.Surface.area, max_difference=max_difference) invalid = check.get_invalid(session) @@ -961,3 +926,50 @@ def test_dwf_distribution_sum_check(session, distribution, 
valid): check = DWFDistributionSumCheck() invalids = check.get_invalid(session) assert (len(invalids) == 0) == valid + + +@pytest.mark.parametrize( + "ref_epsg, valid", + [ + (28992, True), + (999999, False), + ], +) +def test_model_epsg_check_valid(session, ref_epsg, valid): + factories.ModelSettingsFactory() + session.ref_epsg_code = ref_epsg + check = ModelEPSGCheckValid() + invalids = check.get_invalid(session) + assert (len(invalids) == 0) == valid + + +@pytest.mark.parametrize( + "ref_epsg, valid", + [ + (28992, True), + (4326, False), + (None, True), # skip check when there is no epsg + ], +) +def test_model_epsg_check_projected(session, ref_epsg, valid): + factories.ModelSettingsFactory() + session.ref_epsg_code = ref_epsg + check = ModelEPSGCheckProjected() + invalids = check.get_invalid(session) + assert (len(invalids) == 0) == valid + + +@pytest.mark.parametrize( + "ref_epsg, valid", + [ + (28992, True), + (4326, False), + (None, True), # skip check when there is no epsg + ], +) +def test_model_epsg_check_units(session, ref_epsg, valid): + factories.ModelSettingsFactory() + session.ref_epsg_code = ref_epsg + check = ModelEPSGCheckUnits() + invalids = check.get_invalid(session) + assert (len(invalids) == 0) == valid diff --git a/threedi_modelchecker/tests/test_checks_raster.py b/threedi_modelchecker/tests/test_checks_raster.py index c55fcf69..5480ce32 100644 --- a/threedi_modelchecker/tests/test_checks_raster.py +++ b/threedi_modelchecker/tests/test_checks_raster.py @@ -10,8 +10,6 @@ RasterGridSizeCheck, RasterHasMatchingEPSGCheck, RasterHasOneBandCheck, - RasterHasProjectionCheck, - RasterIsProjectedCheck, RasterIsValidCheck, RasterPixelCountCheck, RasterRangeCheck, @@ -230,23 +228,6 @@ def test_one_band_err(tmp_path, interface_cls): assert not check.is_valid(path, interface_cls) -@pytest.mark.parametrize( - "interface_cls", [GDALRasterInterface, RasterIORasterInterface] -) -def test_has_projection_ok(valid_geotiff, interface_cls): - check = 
RasterHasProjectionCheck(column=models.ModelSettings.dem_file) - assert check.is_valid(valid_geotiff, interface_cls) - - -@pytest.mark.parametrize( - "interface_cls", [GDALRasterInterface, RasterIORasterInterface] -) -def test_has_projection_err(tmp_path, interface_cls): - path = create_geotiff(tmp_path / "raster.tiff", epsg=None) - check = RasterHasProjectionCheck(column=models.ModelSettings.dem_file) - assert not check.is_valid(path, interface_cls) - - NULL_EPSG_CODE = ( 'PROJCS["unnamed",GEOGCS["GRS 1980(IUGG, 1980)"' + ',DATUM["unknown",SPHEROID["GRS80",6378137,298.257222101]],' @@ -267,42 +248,16 @@ def test_has_projection_err(tmp_path, interface_cls): (28992, 28992, True), (28992, 27700, False), (NULL_EPSG_CODE, 28992, False), - (27700, None, False), + (27700, None, True), ], ) def test_has_epsg(tmp_path, interface_cls, raster_epsg, sqlite_epsg, validity): path = create_geotiff(tmp_path / "raster.tiff", epsg=raster_epsg) check = RasterHasMatchingEPSGCheck(column=models.ModelSettings.dem_file) - check.epsg_code = sqlite_epsg + check.ref_epsg_code = sqlite_epsg assert check.is_valid(path, interface_cls) == validity -@pytest.mark.parametrize( - "interface_cls", [GDALRasterInterface, RasterIORasterInterface] -) -def test_is_projected_ok(valid_geotiff, interface_cls): - check = RasterIsProjectedCheck(column=models.ModelSettings.dem_file) - assert check.is_valid(valid_geotiff, interface_cls) - - -@pytest.mark.parametrize( - "interface_cls", [GDALRasterInterface, RasterIORasterInterface] -) -def test_is_projected_err(tmp_path, interface_cls): - path = create_geotiff(tmp_path / "raster.tiff", epsg=4326) - check = RasterIsProjectedCheck(column=models.ModelSettings.dem_file) - assert not check.is_valid(path, interface_cls) - - -@pytest.mark.parametrize( - "interface_cls", [GDALRasterInterface, RasterIORasterInterface] -) -def test_is_projected_no_projection(tmp_path, interface_cls): - path = create_geotiff(tmp_path / "raster.tiff", epsg=None) - check = 
RasterIsProjectedCheck(column=models.ModelSettings.dem_file) - assert check.is_valid(path, interface_cls) - - @pytest.mark.parametrize( "interface_cls", [GDALRasterInterface, RasterIORasterInterface] ) @@ -426,9 +381,7 @@ def test_raster_range_no_data(tmp_path, interface_cls): "check", [ RasterHasOneBandCheck(column=models.ModelSettings.dem_file), - RasterHasProjectionCheck(column=models.ModelSettings.dem_file), RasterHasMatchingEPSGCheck(column=models.ModelSettings.dem_file), - RasterIsProjectedCheck(column=models.ModelSettings.dem_file), RasterSquareCellsCheck(column=models.ModelSettings.dem_file), RasterGridSizeCheck(column=models.ModelSettings.dem_file), RasterRangeCheck(column=models.ModelSettings.dem_file, min_value=0), diff --git a/threedi_modelchecker/tests/test_model_checks.py b/threedi_modelchecker/tests/test_model_checks.py index a8c1ee5b..af4635b9 100644 --- a/threedi_modelchecker/tests/test_model_checks.py +++ b/threedi_modelchecker/tests/test_model_checks.py @@ -6,9 +6,12 @@ from threedi_modelchecker.config import CHECKS from threedi_modelchecker.model_checks import ( BaseCheck, + get_epsg_data_from_raster, LocalContext, ThreediModelChecker, ) +from threedi_modelchecker.tests import factories +from threedi_modelchecker.tests.test_checks_raster import create_geotiff @pytest.fixture @@ -65,3 +68,20 @@ def test_individual_checks(threedi_db, check): with threedi_db.get_session() as session: session.model_checker_context = LocalContext(base_path=threedi_db.base_path) assert len(check.get_invalid(session)) == 0 + + +def test_get_epsg_data_from_raster_empty_schema(session): + epsg_code, epsg_name = get_epsg_data_from_raster(session) + assert epsg_code is None + assert epsg_name == "" + + +def test_get_epsg_data_from_raster(session, threedi_db, tmp_path): + old_context_path = session.model_checker_context.base_path + session.model_checker_context.base_path = tmp_path + create_geotiff(tmp_path / "rasters" / "raster.tiff", epsg=28992) + 
factories.ModelSettingsFactory(dem_file="raster.tiff") + epsg_code, epsg_name = get_epsg_data_from_raster(session) + assert epsg_code == 28992 + assert epsg_name == "model_settings.dem_file" + session.model_checker_context.base_path = old_context_path