Skip to content

Commit

Permalink
Support reprojection (#432)
Browse files Browse the repository at this point in the history
* Adapt to schema 230 where all geometries use the model CRS and model_settings.epsg_code is no longer available

* Remove checks for model_settings.epsg_code (317 and 318)

* Remove usage of epsg 4326 in the tests because this CRS is no longer valid

* Remove no longer needed transformations

* Add checks for mathing epsg in all geometries and raster files

* Add checks for valid epsg (existing code, projected, in meters) which requires pyproj

* Change ConnectionNodeCheck (201) to require minimum distance of 10cm
  • Loading branch information
margrietpalm authored Jan 16, 2025
1 parent 81cad44 commit 0f96faf
Show file tree
Hide file tree
Showing 19 changed files with 494 additions and 459 deletions.
12 changes: 8 additions & 4 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
Changelog of threedi-modelchecker
=================================



2.15.1 (unreleased)
2.16.0 (unreleased)
-------------------

- Nothing changed yet.
- Adapt to schema 230 where all geometries use the model CRS and model_settings.epsg_code is no longer available
- Remove checks for model_settings.epsg_code (317 and 318)
- Remove usage of epsg 4326 in the tests because this CRS is no longer valid
- Remove no longer needed transformations
- Add checks for mathing epsg in all geometries and raster files
- Add checks for valid epsg (existing code, projected, in meters) which requires pyproj
- Change ConnectionNodeCheck (201) to require minimum distance of 10cm


2.15.0 (2025-01-08)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ dependencies = [
"Click",
"GeoAlchemy2>=0.9,!=0.11.*",
"SQLAlchemy>=1.4",
"threedi-schema==0.229.*"
"pyproj",
"threedi-schema==0.230.*"
]

[project.optional-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion threedi_modelchecker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .model_checks import * # NOQA

# fmt: off
__version__ = '2.15.1.dev0'
__version__ = '2.16.0.dev0'
# fmt: on
24 changes: 24 additions & 0 deletions threedi_modelchecker/checks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from enum import IntEnum
from typing import List, NamedTuple

from geoalchemy2.functions import ST_SRID
from sqlalchemy import and_, false, func, types
from sqlalchemy.orm.session import Session
from threedi_schema.domain import custom_types
Expand Down Expand Up @@ -412,3 +413,26 @@ def description(self) -> str:
return (
f"{self.table.name}.{self.column} is not a comma seperated list of integers"
)


class EPSGGeomCheck(BaseCheck):
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ref_epsg_name = ""

def get_invalid(self, session: Session) -> List[NamedTuple]:
self.ref_epsg_name = session.ref_epsg_name
if session.ref_epsg_code is None:
return []
return (
self.to_check(session)
.filter(ST_SRID(self.column) != session.ref_epsg_code)
.all()
)

def description(self) -> str:
return f"The epsg of {self.table.name}.{self.column_name} should match {self.ref_epsg_name}"
24 changes: 24 additions & 0 deletions threedi_modelchecker/checks/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

from .base import (
EnumCheck,
EPSGGeomCheck,
ForeignKeyCheck,
GeometryCheck,
GeometryTypeCheck,
NotNullCheck,
TypeCheck,
UniqueCheck,
)
from .raster import RasterHasMatchingEPSGCheck


@dataclass
Expand Down Expand Up @@ -124,3 +126,25 @@ def generate_enum_checks(table, custom_level_map=None, **kwargs):
level = get_level(table, column, custom_level_map)
enum_checks.append(EnumCheck(column, level=level, **kwargs))
return enum_checks


def generate_epsg_geom_checks(table, custom_level_map=None, **kwargs):
custom_level_map = custom_level_map or {}
column = table.columns.get("geom")
if column is not None:
level = get_level(table, column, custom_level_map)
return [EPSGGeomCheck(column=column, level=level, **kwargs)]
else:
return []


def generate_epsg_raster_checks(table, raster_columns, custom_level_map=None, **kwargs):
custom_level_map = custom_level_map or {}
checks = []
for column in table.columns:
if column in raster_columns:
level = get_level(table, column, custom_level_map)
checks.append(
RasterHasMatchingEPSGCheck(column=column, level=level, **kwargs)
)
return checks
24 changes: 0 additions & 24 deletions threedi_modelchecker/checks/geo_query.py

This file was deleted.

19 changes: 8 additions & 11 deletions threedi_modelchecker/checks/location.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import List, NamedTuple

from sqlalchemy import func
from geoalchemy2.functions import ST_Distance, ST_NPoints, ST_PointN
from sqlalchemy.orm import aliased, Session
from threedi_schema.domain import models

from threedi_modelchecker.checks.base import BaseCheck
from threedi_modelchecker.checks.geo_query import distance


class PointLocationCheck(BaseCheck):
Expand All @@ -32,7 +31,7 @@ def get_invalid(self, session):
self.ref_table,
self.ref_table.id == self.ref_column,
)
.filter(distance(self.column, self.ref_table.geom) > self.max_distance)
.filter(ST_Distance(self.column, self.ref_table.geom) > self.max_distance)
.all()
)

Expand Down Expand Up @@ -70,15 +69,13 @@ def __init__(
def get_invalid(self, session: Session) -> List[NamedTuple]:
start_node = aliased(self.ref_table_start)
end_node = aliased(self.ref_table_end)

tol = self.max_distance
start_point = func.ST_PointN(self.column, 1)
end_point = func.ST_PointN(self.column, func.ST_NPoints(self.column))

start_ok = distance(start_point, start_node.geom) <= tol
end_ok = distance(end_point, end_node.geom) <= tol
start_ok_if_reversed = distance(end_point, start_node.geom) <= tol
end_ok_if_reversed = distance(start_point, end_node.geom) <= tol
start_point = ST_PointN(self.column, 1)
end_point = ST_PointN(self.column, ST_NPoints(self.column))
start_ok = ST_Distance(start_point, start_node.geom) <= tol
end_ok = ST_Distance(end_point, end_node.geom) <= tol
start_ok_if_reversed = ST_Distance(end_point, start_node.geom) <= tol
end_ok_if_reversed = ST_Distance(start_point, end_node.geom) <= tol
return (
self.to_check(session)
.join(start_node, start_node.id == self.ref_column_start)
Expand Down
115 changes: 85 additions & 30 deletions threedi_modelchecker/checks/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from dataclasses import dataclass
from typing import List, Literal, NamedTuple

import pyproj
from geoalchemy2.functions import ST_Distance, ST_Length
from sqlalchemy import (
and_,
case,
Expand All @@ -19,7 +21,6 @@

from .base import BaseCheck, CheckLevel
from .cross_section_definitions import cross_section_configuration_for_record
from .geo_query import distance, length, transform


class CorrectAggregationSettingsExist(BaseCheck):
Expand Down Expand Up @@ -258,8 +259,7 @@ def get_invalid(self, session):

class ConnectionNodesLength(BaseCheck):
"""Check that the distance between `start_node` and `end_node` is at least
`min_distance`. The coords will be transformed into (the first entry) of
ModelSettings.epsg_code.
`min_distance`.
"""

def __init__(
Expand Down Expand Up @@ -290,7 +290,7 @@ def get_invalid(self, session):
self.to_check(session)
.join(start_node, start_node.id == self.start_node)
.join(end_node, end_node.id == self.end_node)
.filter(distance(start_node.geom, end_node.geom) < self.min_distance)
.filter(ST_Distance(start_node.geom, end_node.geom) < self.min_distance)
)
return list(q.with_session(session).all())

Expand Down Expand Up @@ -321,30 +321,26 @@ def get_invalid(self, session: Session) -> List[NamedTuple]:
distance between all connection nodes.
"""
query = text(
f"""SELECT *
FROM connection_node AS cn1, connection_node AS cn2
WHERE
distance(cn1.geom, cn2.geom, 1) < :min_distance
AND cn1.ROWID != cn2.ROWID
AND cn2.ROWID IN (
SELECT ROWID
FROM SpatialIndex
WHERE (
f_table_name = "connection_node"
AND search_frame = Buffer(cn1.geom, {self.minimum_distance / 2})));
f"""
SELECT *
FROM connection_node AS cn1, connection_node AS cn2
WHERE ST_Distance(cn1.geom, cn2.geom) < {self.minimum_distance}
AND cn1.ROWID != cn2.ROWID
AND cn2.ROWID IN (
SELECT ROWID
FROM SpatialIndex
WHERE (
f_table_name = "connection_node"
AND search_frame = Buffer(cn1.geom, {self.minimum_distance / 2})))
"""
)
results = (
session.connection()
.execute(query, {"min_distance": self.minimum_distance})
.fetchall()
)

results = session.connection().execute(query).fetchall()
return results

def description(self) -> str:

return (
f"The connection_node is within {self.minimum_distance} degrees of "
f"The connection_node is within {self.minimum_distance * 100 } cm of "
f"another connection_node."
)

Expand Down Expand Up @@ -548,10 +544,10 @@ def get_invalid(self, session: Session) -> List[NamedTuple]:
linestring = models.Channel.geom
tol = self.min_distance
breach_point = func.Line_Locate_Point(
transform(linestring), transform(func.ST_PointN(self.column, 1))
linestring, func.ST_PointN(self.column, 1)
)
dist_1 = breach_point * length(linestring)
dist_2 = (1 - breach_point) * length(linestring)
dist_1 = breach_point * ST_Length(linestring)
dist_2 = (1 - breach_point) * ST_Length(linestring)
return (
self.to_check(session)
.join(models.Channel, self.table.c.channel_id == models.Channel.id)
Expand All @@ -577,10 +573,8 @@ def get_invalid(self, session: Session) -> List[NamedTuple]:

# First fetch the position of each potential breach per channel
def get_position(point, linestring):
breach_point = func.Line_Locate_Point(
transform(linestring), transform(func.ST_PointN(point, 1))
)
return (breach_point * length(linestring)).label("position")
breach_point = func.Line_Locate_Point(linestring, func.ST_PointN(point, 1))
return (breach_point * ST_Length(linestring)).label("position")

potential_breaches = sorted(
session.query(self.table, get_position(self.column, models.Channel.geom))
Expand Down Expand Up @@ -776,7 +770,7 @@ def get_invalid(self, session: Session) -> List[NamedTuple]:
self.table.c.id,
self.table.c.area,
self.table.c.geom,
func.ST_Area(transform(self.table.c.geom)).label("calculated_area"),
func.ST_Area(self.table.c.geom).label("calculated_area"),
).subquery()
return (
session.query(all_results)
Expand Down Expand Up @@ -1114,3 +1108,64 @@ def get_invalid(self, session: Session) -> List[NamedTuple]:

def description(self) -> str:
return f"The values in {self.table.name}.{self.column_name} should add up to"


class ModelEPSGCheckValid(BaseCheck):
def __init__(self, *args, **kwargs):
super().__init__(column=models.ModelSettings.id, *args, **kwargs)
self.epsg_code = None

def get_invalid(self, session: Session) -> List[NamedTuple]:
self.epsg_code = session.ref_epsg_code
if self.epsg_code is not None:
try:
pyproj.CRS.from_epsg(self.epsg_code)
except pyproj.exceptions.CRSError:
return self.to_check(session).all()
return []

def description(self) -> str:
return f"Found invalid EPSG: {self.epsg_code}"


class ModelEPSGCheckProjected(BaseCheck):
def __init__(self, *args, **kwargs):
super().__init__(column=models.ModelSettings.id, *args, **kwargs)
self.epsg_code = None

def get_invalid(self, session: Session) -> List[NamedTuple]:
self.epsg_code = session.ref_epsg_code
if self.epsg_code is not None:
try:
crs = pyproj.CRS.from_epsg(self.epsg_code)
except pyproj.exceptions.CRSError:
# handled by ModelEPSGCheckValid
return []
if not crs.is_projected:
return self.to_check(session).all()
return []

def description(self) -> str:
return f"EPSG {self.epsg_code} is not projected"


class ModelEPSGCheckUnits(BaseCheck):
def __init__(self, *args, **kwargs):
super().__init__(column=models.ModelSettings.id, *args, **kwargs)
self.epsg_code = None

def get_invalid(self, session: Session) -> List[NamedTuple]:
self.epsg_code = session.ref_epsg_code
if self.epsg_code is not None:
try:
crs = pyproj.CRS.from_epsg(self.epsg_code)
except pyproj.exceptions.CRSError:
# handled by ModelEPSGCheckValid
return []
for ax in crs.axis_info:
if not ax.unit_name == "metre":
return self.to_check(session).all()
return []

def description(self) -> str:
return f"EPSG {self.epsg_code} is not fully defined in metres"
Loading

0 comments on commit 0f96faf

Please sign in to comment.