From c0b5bb20f5c63f2367b09a79f9f09d0572d4eee5 Mon Sep 17 00:00:00 2001 From: Nicky Bulthuis Date: Fri, 28 Jun 2024 14:21:27 +0200 Subject: [PATCH] Adds Year and Month support to Duration (#53) * Added support for Duration PxY (year) and PxM (month) --- CHANGELOG.rst | 11 +++- odata_query/ast.py | 18 ++++-- odata_query/grammar.py | 2 +- odata_query/sql/base.py | 6 +- .../django/test_odata_to_django_q.py | 14 +++++ .../sql/test_odata_to_athena_sql.py | 8 +++ tests/integration/sql/test_odata_to_sql.py | 16 +++++ tests/integration/sql/test_odata_to_sqlite.py | 8 +++ .../test_odata_to_sqlalchemy_core.py | 12 ++++ .../test_odata_to_sqlalchemy_orm.py | 12 ++++ tests/unit/test_odata_parser.py | 59 +++++++++++-------- 11 files changed, 132 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 11ce4cf..0d2510e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,13 @@ All notable changes to this project will be documented in this file. The format is based on `Keep a Changelog `_\ , and this project adheres to `Semantic Versioning `_. +[Unreleased] +--------------------- + +Added +^^^^^ + +* Parser: Added support for durations specified in years and months. [0.10.0] - 2024-01-21 --------------------- @@ -35,9 +42,9 @@ Fixed * SQLAlchemy functions defined by ``odata_query`` - - no longer clash with other functions defined in ``sqlalchemy.func`` with + - no longer clash with other functions defined in ``sqlalchemy.func`` with the same name. - - inherit cache to prevent SQLAlchemy performance warnings. + - inherit cache to prevent SQLAlchemy performance warnings. [0.8.1] - 2023-02-17 diff --git a/odata_query/ast.py b/odata_query/ast.py index 0e268d5..918f757 100644 --- a/odata_query/ast.py +++ b/odata_query/ast.py @@ -6,7 +6,7 @@ from dateutil.parser import isoparse -DURATION_PATTERN = re.compile(r"([+-])?P(\d+D)?(?:T(\d+H)?(\d+M)?(\d+(?:\.\d+)?S)?)?") +DURATION_PATTERN = re.compile(r"([+-])?P(\d+Y)?(\d+M)?(\d+D)?(?:T(\d+H)?(\d+M)?(\d+(?:\.\d+)?S)?)?") @dataclass(frozen=True) @@ -125,7 +125,15 @@ class Duration(_Literal): @property def py_val(self) -> dt.timedelta: - sign, days, hours, minutes, seconds = self.unpack() + sign, years, months, days, hours, minutes, seconds = self.unpack() + + # Initialize days to 0 if None + days = float(days or 0) + + # Approximate conversion, adjust as necessary for more precision + days += float(years or 0) * 365.25 # Average including leap years + days += float(months or 0) * 30.44 # Average month length + delta = dt.timedelta( days=float(days or 0), hours=float(hours or 0), @@ -150,14 +158,16 @@ def unpack( if not match: raise ValueError(f"Could not unpack Duration with value {self.val}") - sign, days, hours, minutes, seconds = match.groups() + sign, years, months, days, hours, minutes, seconds = match.groups() + _years = years[:-1] if years else None + _months = months[:-1] if months else None _days = days[:-1] if days else None _hours = hours[:-1] if hours else None _minutes = minutes[:-1] if minutes else None _seconds = seconds[:-1] if seconds else None - return sign, _days, _hours, _minutes, _seconds + return sign, _years, _months, _days, _hours, _minutes, _seconds @dataclass(frozen=True) diff --git a/odata_query/grammar.py b/odata_query/grammar.py index 3f88cbc..88035f9 100644 --- a/odata_query/grammar.py +++ b/odata_query/grammar.py @@ -119,7 +119,7 @@ def error(self, token: Token): # Primitive literals #################################################################################### - @_(r"duration'[+-]?P(?:\d+D)?(?:T(?:\d+H)?(?:\d+M)?(?:\d+(?:\.\d+)?S)?)?'") + @_(r"duration'[+-]?P(?:\d+Y)?(?:\d+M)?(?:\d+D)?(?:T(?:\d+H)?(?:\d+M)?(?:\d+(?:\.\d+)?S)?)?'") def DURATION(self, t): ":meta private:" val = t.value.upper() diff --git a/odata_query/sql/base.py b/odata_query/sql/base.py index 239421c..a305f07 100644 --- a/odata_query/sql/base.py +++ b/odata_query/sql/base.py @@ -65,10 +65,14 @@ def visit_DateTime(self, node: ast.DateTime) -> str: def visit_Duration(self, node: ast.Duration) -> str: ":meta private:" - sign, days, hours, minutes, seconds = node.unpack() + sign, years, months, days, hours, minutes, seconds = node.unpack() sign = sign or "" intervals = [] + if years: + intervals.append(f"INTERVAL '{years}' YEAR") + if months: + intervals.append(f"INTERVAL '{months}' MONTH") if days: intervals.append(f"INTERVAL '{days}' DAY") if hours: diff --git a/tests/integration/django/test_odata_to_django_q.py b/tests/integration/django/test_odata_to_django_q.py index c442748..2092785 100644 --- a/tests/integration/django/test_odata_to_django_q.py +++ b/tests/integration/django/test_odata_to_django_q.py @@ -130,6 +130,20 @@ def tz(offset: int) -> dt.tzinfo: + Value(dt.timedelta(days=1, hours=1, minutes=1, seconds=1)) ), ), + ( + "published_at eq 2019-01-01T00:00:00 add duration'P1Y'", + Q( + published_at__exact=Value(dt.datetime(2019, 1, 1, 0, 0, 0)) + + Value(dt.timedelta(days=365.25)) # 1 times 365.25 (average year in days) + ), + ), + ( + "published_at eq 2019-01-01T00:00:00 add duration'P2M'", + Q( + published_at__exact=Value(dt.datetime(2019, 1, 1, 0, 0, 0)) + + Value(dt.timedelta(days=60.88)) # 2 times 30.44 (average month in days) + ), + ), ("contains(title, 'copy')", Q(title__contains=Value("copy"))), ("startswith(title, 'copy')", Q(title__startswith=Value("copy"))), ("endswith(title, 'bla')", Q(title__endswith=Value("bla"))), diff --git a/tests/integration/sql/test_odata_to_athena_sql.py b/tests/integration/sql/test_odata_to_athena_sql.py index 7ca5788..4bca330 100644 --- a/tests/integration/sql/test_odata_to_athena_sql.py +++ b/tests/integration/sql/test_odata_to_athena_sql.py @@ -53,6 +53,14 @@ "period_start add duration'P365D' ge period_end", '"period_start" + INTERVAL \'365\' DAY >= "period_end"', ), + ( + "period_start add duration'P1Y' ge period_end", + '"period_start" + INTERVAL \'1\' YEAR >= "period_end"', + ), + ( + "period_start add duration'P2M' ge period_end", + '"period_start" + INTERVAL \'2\' MONTH >= "period_end"', + ), ( "period_start add duration'P365DT12H1M1.1S' ge period_end", "\"period_start\" + (INTERVAL '365' DAY + INTERVAL '12' HOUR + INTERVAL '1' MINUTE + INTERVAL '1.1' SECOND) >= \"period_end\"", diff --git a/tests/integration/sql/test_odata_to_sql.py b/tests/integration/sql/test_odata_to_sql.py index 7c1b250..7c0ab8a 100644 --- a/tests/integration/sql/test_odata_to_sql.py +++ b/tests/integration/sql/test_odata_to_sql.py @@ -63,11 +63,27 @@ "period_start add duration'PT1S' ge period_end", '"period_start" + INTERVAL \'1\' SECOND >= "period_end"', ), + ( + "period_start add duration'P1Y' ge period_end", + '"period_start" + INTERVAL \'1\' YEAR >= "period_end"', + ), + ( + "period_start add duration'P2M' ge period_end", + '"period_start" + INTERVAL \'2\' MONTH >= "period_end"', + ), ("year(period_start) eq 2019", 'EXTRACT (YEAR FROM "period_start") = 2019'), ( "period_end lt now() sub duration'P365D'", "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '365' DAY", ), + ( + "period_end lt now() sub duration'P1Y'", + "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '1' YEAR", + ), + ( + "period_end lt now() sub duration'P2M'", + "\"period_end\" < CURRENT_TIMESTAMP - INTERVAL '2' MONTH", + ), ( "startswith(trim(meter_id), '999')", "TRIM(\"meter_id\") LIKE '999%'", diff --git a/tests/integration/sql/test_odata_to_sqlite.py b/tests/integration/sql/test_odata_to_sqlite.py index 31ceb5b..71aa34c 100644 --- a/tests/integration/sql/test_odata_to_sqlite.py +++ b/tests/integration/sql/test_odata_to_sqlite.py @@ -61,6 +61,14 @@ "period_start add duration'PT1S' ge period_end", '"period_start" + INTERVAL \'1\' SECOND >= "period_end"', ), + ( + "period_start add duration'P1Y' ge period_end", + '"period_start" + INTERVAL \'1\' YEAR >= "period_end"', + ), + ( + "period_start add duration'P2M' ge period_end", + '"period_start" + INTERVAL \'2\' MONTH >= "period_end"', + ), ( "year(period_start) eq 2019", "CAST(STRFTIME('%Y', \"period_start\") AS INTEGER) = 2019", diff --git a/tests/integration/sqlalchemy/test_odata_to_sqlalchemy_core.py b/tests/integration/sqlalchemy/test_odata_to_sqlalchemy_core.py index e82f2a8..8303100 100644 --- a/tests/integration/sqlalchemy/test_odata_to_sqlalchemy_core.py +++ b/tests/integration/sqlalchemy/test_odata_to_sqlalchemy_core.py @@ -138,6 +138,18 @@ def tz(offset: int) -> dt.tzinfo: == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) + dt.timedelta(days=1, hours=1, minutes=1, seconds=1), ), + ( + "published_at eq 2019-01-01T00:00:00 add duration'P1Y'", + BlogPost.c.published_at + == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) + + dt.timedelta(days=365.25), # 1 times 365.25 (average year in days) + ), + ( + "published_at eq 2019-01-01T00:00:00 add duration'P2M'", + BlogPost.c.published_at + == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) + + dt.timedelta(days=60.88), # 2 times 30.44 (average month in days) + ), ("contains(title, 'copy')", BlogPost.c.title.contains("copy")), ("startswith(title, 'copy')", BlogPost.c.title.startswith("copy")), ("endswith(title, 'bla')", BlogPost.c.title.endswith("bla")), diff --git a/tests/integration/sqlalchemy/test_odata_to_sqlalchemy_orm.py b/tests/integration/sqlalchemy/test_odata_to_sqlalchemy_orm.py index 495d8ca..b392d2c 100644 --- a/tests/integration/sqlalchemy/test_odata_to_sqlalchemy_orm.py +++ b/tests/integration/sqlalchemy/test_odata_to_sqlalchemy_orm.py @@ -119,6 +119,18 @@ def tz(offset: int) -> dt.tzinfo: == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) + dt.timedelta(days=1, hours=1, minutes=1, seconds=1), ), + ( + "published_at eq 2019-01-01T00:00:00 add duration'P1Y'", + BlogPost.published_at + == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) + + dt.timedelta(days=365.25), # 1 times 365.25 (average year in days) + ), + ( + "published_at eq 2019-01-01T00:00:00 add duration'P2M'", + BlogPost.published_at + == literal(dt.datetime(2019, 1, 1, 0, 0, 0)) + + dt.timedelta(days=60.88), # 2 times 30.44 (average month in days) + ), ("contains(title, 'copy')", BlogPost.title.contains("copy")), ("startswith(title, 'copy')", BlogPost.title.startswith("copy")), ("endswith(title, 'bla')", BlogPost.title.endswith("bla")), diff --git a/tests/unit/test_odata_parser.py b/tests/unit/test_odata_parser.py index a040fca..9560242 100644 --- a/tests/unit/test_odata_parser.py +++ b/tests/unit/test_odata_parser.py @@ -75,32 +75,36 @@ def test_primitive_literal_parsing(value: str, expected_type: type): @pytest.mark.parametrize( "value, expected_unpacked", [ - ("duration'P12DT23H59M59.9S'", (None, "12", "23", "59", "59.9")), - ("duration'-P12DT23H59M59.9S'", ("-", "12", "23", "59", "59.9")), - ("duration'P12D'", (None, "12", None, None, None)), - ("duration'-P12D'", ("-", "12", None, None, None)), - ("duration'PT23H59M59.9S'", (None, None, "23", "59", "59.9")), - ("duration'-PT23H59M59.9S'", ("-", None, "23", "59", "59.9")), - ("duration'PT23H59M59.9S'", (None, None, "23", "59", "59.9")), - ("duration'-PT23H59M59.9S'", ("-", None, "23", "59", "59.9")), - ("duration'PT23H59M59S'", (None, None, "23", "59", "59")), - ("duration'-PT23H59M59S'", ("-", None, "23", "59", "59")), - ("duration'PT23H'", (None, None, "23", None, None)), - ("duration'-PT23H'", ("-", None, "23", None, None)), - ("duration'PT59M'", (None, None, None, "59", None)), - ("duration'-PT59M'", ("-", None, None, "59", None)), - ("duration'PT59S'", (None, None, None, None, "59")), - ("duration'-PT59S'", ("-", None, None, None, "59")), - ("duration'PT59.9S'", (None, None, None, None, "59.9")), - ("duration'-PT59.9S'", ("-", None, None, None, "59.9")), - ("duration'P12DT23H'", (None, "12", "23", None, None)), - ("duration'-P12DT23H'", ("-", "12", "23", None, None)), - ("duration'P12DT59M'", (None, "12", None, "59", None)), - ("duration'-P12DT59M'", ("-", "12", None, "59", None)), - ("duration'P12DT59S'", (None, "12", None, None, "59")), - ("duration'-P12DT59S'", ("-", "12", None, None, "59")), - ("duration'P12DT59.9S'", (None, "12", None, None, "59.9")), - ("duration'-P12DT59.9S'", ("-", "12", None, None, "59.9")), + ("duration'P12DT23H59M59.9S'", (None, None, None, "12", "23", "59", "59.9")), + ("duration'-P12DT23H59M59.9S'", ("-", None, None, "12", "23", "59", "59.9")), + ("duration'P12D'", (None, None, None, "12", None, None, None)), + ("duration'-P12D'", ("-", None, None, "12", None, None, None)), + ("duration'PT23H59M59.9S'", (None, None, None, None, "23", "59", "59.9")), + ("duration'-PT23H59M59.9S'", ("-", None, None, None, "23", "59", "59.9")), + ("duration'PT23H59M59.9S'", (None, None, None, None, "23", "59", "59.9")), + ("duration'-PT23H59M59.9S'", ("-", None, None, None, "23", "59", "59.9")), + ("duration'PT23H59M59S'", (None, None, None, None, "23", "59", "59")), + ("duration'-PT23H59M59S'", ("-", None, None, None, "23", "59", "59")), + ("duration'PT23H'", (None, None, None, None, "23", None, None)), + ("duration'-PT23H'", ("-", None, None, None, "23", None, None)), + ("duration'PT59M'", (None, None, None, None, None, "59", None)), + ("duration'-PT59M'", ("-", None, None, None, None, "59", None)), + ("duration'PT59S'", (None, None, None, None, None, None, "59")), + ("duration'-PT59S'", ("-", None, None, None, None, None, "59")), + ("duration'PT59.9S'", (None, None, None, None, None, None, "59.9")), + ("duration'-PT59.9S'", ("-", None, None, None, None, None, "59.9")), + ("duration'P12DT23H'", (None, None, None, "12", "23", None, None)), + ("duration'-P12DT23H'", ("-", None, None, "12", "23", None, None)), + ("duration'P12DT59M'", (None, None, None, "12", None, "59", None)), + ("duration'-P12DT59M'", ("-", None, None, "12", None, "59", None)), + ("duration'P12DT59S'", (None, None, None, "12", None, None, "59")), + ("duration'-P12DT59S'", ("-", None, None, "12", None, None, "59")), + ("duration'P12DT59.9S'", (None, None, None, "12", None, None, "59.9")), + ("duration'-P12DT59.9S'", ("-", None, None, "12", None, None, "59.9")), + ("duration'-P1Y'", ("-", "1", None, None, None, None, None)), + ("duration'P1Y'", (None, "1", None, None, None, None, None)), + ("duration'-P12M'", ("-", None, "12", None, None, None, None)), + ("duration'P2Y3M'", (None, "2", "3", None, None, None, None)), ], ) def test_duration_parsing(value: str, expected_unpacked: tuple): @@ -166,6 +170,9 @@ def test_geography_literal_parsing(value: str, expected: str): dt.timedelta(days=12, hours=23, minutes=59, seconds=59.9), ), ("duration'P12D'", dt.timedelta(days=12)), + ("duration'P1Y'", dt.timedelta(days=365.25)), # Average including leap years + ("duration'P1M'", dt.timedelta(days=30.44)), # Average month length + ("duration'P1M2D'", dt.timedelta(days=32.44)), ], ) def test_python_value_of_literals(odata_val: str, exp_py_val):