diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e1a4db53..a2143e61 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,7 +6,6 @@ on: - 'v*.*.*' concurrency: group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true jobs: start-runner: name: Start self-hosted EC2 runner @@ -17,7 +16,7 @@ jobs: ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} steps: - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v2 + uses: aws-actions/configure-aws-credentials@v3 with: aws-access-key-id: ${{ secrets.EC2_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.EC2_ACCESS_KEY_SECRET }} @@ -73,6 +72,7 @@ jobs: manylinux: auto rustup-components: rust-std working-directory: . + before-script-linux: yum install -y perl-core - name: Build package macOS Apple silicon if: ${{ matrix.os == 'ec2-macOS'}} @@ -224,7 +224,7 @@ jobs: timeout-minutes: 15 steps: - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v2 + uses: aws-actions/configure-aws-credentials@v3 with: aws-access-key-id: ${{ secrets.ARTIFACT_AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.ARTIFACT_AWS_SECRET_ACCESS_KEY }} @@ -293,7 +293,7 @@ jobs: timeout-minutes: 15 steps: - name: Configure AWS Credentials - uses: aws-actions/configure-aws-credentials@v2 + uses: aws-actions/configure-aws-credentials@v3 with: aws-access-key-id: ${{ secrets.ARTIFACT_AWS_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.ARTIFACT_AWS_SECRET_ACCESS_KEY }} @@ -328,7 +328,7 @@ jobs: if: ${{ always() }} steps: - name: Configure AWS credentials - uses: aws-actions/configure-aws-credentials@v2 + uses: aws-actions/configure-aws-credentials@v3 with: aws-access-key-id: ${{ secrets.EC2_ACCESS_KEY_ID }} aws-secret-access-key: ${{ secrets.EC2_ACCESS_KEY_SECRET }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 207464e6..92fb0abd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,21 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +## [0.3.3] - 2023-09-14 + +### Added +- Module `pathway.dt` to construct and manipulate DTypes. +- New argument `keep_queries` in `pw.io.http.rest_connector`. + +### Changed +- Internal representation of DTypes. Inputting types remains backwards compatible. +- Temporal functions now accept arguments of mixed types (ints and floats). For example, `pw.temporal.interval` can use ints even when the columns it interacts with are floats. +- Single-element arrays are now treated as arrays, not as scalars. + +### Fixed +- `to_string()` method on datetimes now always prints 9 fractional digits. +- `%f` format code in `strptime()` now parses the fractional part of a second correctly regardless of the number of digits.
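As a quick illustration of the new `pathway.dt` module announced above — a sketch using only names exported by `python/pathway/dt.py` later in this diff, with reprs following the definitions in `python/pathway/internals/dtype.py`:

>>> import pathway as pw
>>> pw.dt.INT
INT
>>> pw.dt.Optional(pw.dt.INT)
Optional(INT)
>>> pw.dt.wrap(int) is pw.dt.INT  # wrap() maps plain Python types to dt singletons
True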
+ ## [0.3.2] - 2023-09-07 ### Added diff --git a/Cargo.lock b/Cargo.lock index 1e07ddfb..74186311 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -243,9 +243,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" @@ -264,9 +264,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.28" +version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95ed24df0632f708f5f6d8082675bef2596f7084dee3dd55f632290bf35bfe0f" +checksum = "defd4e7873dbddba6c7c91e199c7fcb946abc4a6a4ac3195400bcfb01b5de877" dependencies = [ "android-tzdata", "iana-time-zone", @@ -1112,9 +1112,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.5" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" +checksum = "1a9bad9f94746442c783ca431b22403b519cd7fbeed0533fdd6328b2f2212128" [[package]] name = "lock_api" @@ -1446,7 +1446,7 @@ dependencies = [ [[package]] name = "pathway" -version = "0.3.2" +version = "0.3.3" dependencies = [ "arc-swap", "arcstr", @@ -1995,9 +1995,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.11" +version = "0.38.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0c3dde1fc030af041adc40e79c0e7fbcf431dd24870053d187d7c66e4b87453" +checksum = "bdf14a7a466ce88b5eac3da815b53aefc208ce7e74d1c263aabb04d88c4abeb1" dependencies = [ "bitflags 2.4.0", "errno", @@ -2093,9 +2093,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "2cc66a619ed80bf7a0f6b17dd063a84b88f6dea1813737cf469aef1d081142c2" dependencies = [ "itoa", "ryu", @@ -2476,9 +2476,9 @@ checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" [[package]] name = "toml_edit" -version = "0.19.14" +version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8123f27e969974a3dfba720fdb560be359f57b44302d280ba72e76a74480e8a" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ "indexmap 2.0.0", "toml_datetime", @@ -2816,6 +2816,6 @@ dependencies = [ [[package]] name = "xxhash-rust" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "735a71d46c4d68d71d4b24d03fdc2b98e38cea81730595801db779c04fe80d70" +checksum = "9828b178da53440fa9c766a3d2f73f7cf5d0ac1fe3980c1e5018d899fd19e07b" diff --git a/Cargo.toml b/Cargo.toml index 55ca9fba..ada39a8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pathway" -version = "0.3.2" +version = "0.3.3" edition = "2021" publish = false rust-version = "1.71.0" @@ -22,7 +22,7 @@ arcstr = { version = "1.1.5", default-features = false, features = ["serde", "st base32 = "0.4.0" bincode = "1.3.3" cfg-if = "1.0.0" -chrono = { version = "0.4.28", features = ["std", "clock"], default-features = false } +chrono = { version = "0.4.30", features = 
["std", "clock"], default-features = false } chrono-tz = "0.8.3" crossbeam-channel = "0.5.8" csv = "1.2.2" @@ -60,7 +60,7 @@ thiserror = "1.0.48" timely = { path = "./external/timely-dataflow/timely", features = ["bincode"] } tokio = "1.32.0" typed-arena = "2.0.2" -xxhash-rust = { version = "0.8.6", features = ["xxh3"] } +xxhash-rust = { version = "0.8.7", features = ["xxh3"] } [target.'cfg(target_os = "linux")'.dependencies] inotify = "0.10.2" diff --git a/integration_tests/kafka/test_kafka.py b/integration_tests/kafka/test_kafka.py index 9aa72f6b..b9626597 100644 --- a/integration_tests/kafka/test_kafka.py +++ b/integration_tests/kafka/test_kafka.py @@ -5,31 +5,15 @@ import pathlib import random -import pandas as pd from utils import KafkaTestContext import pathway as pw from pathway.internals.parse_graph import G -from pathway.tests.utils import wait_result_with_checker, write_lines - - -def expect_csv_checker(expected, output_path, usecols=("k", "v"), index_col=("k")): - expected = ( - pw.debug._markdown_to_pandas(expected) - .set_index(index_col, drop=False) - .sort_index() - ) - - def checker(): - result = ( - pd.read_csv(output_path, usecols=[*usecols, *index_col]) - .convert_dtypes() - .set_index(index_col, drop=False) - .sort_index() - ) - return expected.equals(result) - - return checker +from pathway.tests.utils import ( + expect_csv_checker, + wait_result_with_checker, + write_lines, +) def test_kafka_raw(tmp_path: pathlib.Path, kafka_context: KafkaTestContext): diff --git a/integration_tests/webserver/test_rest_connector.py b/integration_tests/webserver/test_rest_connector.py index f04f5286..a9ae63f9 100644 --- a/integration_tests/webserver/test_rest_connector.py +++ b/integration_tests/webserver/test_rest_connector.py @@ -9,6 +9,7 @@ import pathway as pw from pathway.tests.utils import ( CsvLinesNumberChecker, + expect_csv_checker, wait_result_with_checker, xfail_on_darwin, ) @@ -126,3 +127,54 @@ def target(): t = threading.Thread(target=target, daemon=True) t.start() assert wait_result_with_checker(CsvLinesNumberChecker(output_path, 4), 30) + + +@xfail_on_darwin(reason="running pw.run from separate process not supported") +@pytest.mark.xdist_group(name="streaming_tests") +def test_server_keep_queries(tmp_path: pathlib.Path): + port = int(os.environ.get("PATHWAY_MONITORING_HTTP_PORT", "20000")) + 10003 + output_path = tmp_path / "output.csv" + + class InputSchema(pw.Schema): + k: int + v: int + + def target(): + time.sleep(5) + requests.post( + f"http://127.0.0.1:{port}/", + json={"k": 1, "v": 1}, + ).raise_for_status() + requests.post( + f"http://127.0.0.1:{port}/", + json={"k": 1, "v": 2}, + ).raise_for_status() + + queries, response_writer = pw.io.http.rest_connector( + host="127.0.0.1", port=port, schema=InputSchema, keep_queries=True + ) + response_writer(queries.select(query_id=queries.id, result=pw.this.v)) + + sum = queries.groupby(pw.this.k).reduce( + key=pw.this.k, sum=pw.reducers.sum(pw.this.v) + ) + + pw.io.csv.write(sum, output_path) + + t = threading.Thread(target=target, daemon=True) + t.start() + + assert wait_result_with_checker( + expect_csv_checker( + """ + key | sum | diff + 1 | 1 | 1 + 1 | 1 | -1 + 1 | 3 | 1 + """, + output_path, + usecols=["sum", "diff"], + index_col=["key"], + ), + 10, + ) diff --git a/python/pathway/__init__.py b/python/pathway/__init__.py index 2dba4e4a..20ed6fca 100644 --- a/python/pathway/__init__.py +++ b/python/pathway/__init__.py @@ -6,6 +6,8 @@ import os +from pathway.internals.dtype import DATE_TIME_NAIVE, DATE_TIME_UTC, 
DURATION + # flake8: noqa: E402 if "PYTEST_CURRENT_TEST" not in os.environ: @@ -16,15 +18,13 @@ filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning) +import pathway.dt as dt import pathway.reducers as reducers from pathway import debug, demo, io from pathway.internals import ( ClassArg, ColumnExpression, ColumnReference, - DateTimeNaive, - DateTimeUtc, - Duration, FilteredJoinResult, GroupedJoinResult, GroupedTable, @@ -143,9 +143,9 @@ "column_definition", "TableSlice", "demo", - "DateTimeNaive", - "DateTimeUtc", - "Duration", + "DATE_TIME_NAIVE", + "DATE_TIME_UTC", + "DURATION", "unwrap", "SchemaProperties", ] diff --git a/python/pathway/demo/__init__.py b/python/pathway/demo/__init__.py index 0f01a885..731b9117 100644 --- a/python/pathway/demo/__init__.py +++ b/python/pathway/demo/__init__.py @@ -6,7 +6,7 @@ ... name: str ... age: int >>> pw.demo.replay_csv("./input_stream.csv", schema=InputSchema) -Table{'name': , 'age': } +Table{'name': STR, 'age': INT} """ # Copyright © 2023 Pathway @@ -55,19 +55,20 @@ def generate_custom_stream( Example: >>> value_functions = { - ... 'id': lambda x: x + 1, + ... 'number': lambda x: x + 1, ... 'name': lambda x: f'Person {x}', ... 'age': lambda x: 20 + x, ... } >>> class InputSchema(pw.Schema): - ... id: int + ... number: int ... name: str ... age: int >>> pw.demo.generate_custom_stream(value_functions, schema=InputSchema, nb_rows=10) - Table{'id': , 'name': , 'age': } + Table{'number': INT, 'name': STR, 'age': INT} - In the above example, a data stream is generated with 10 rows, where each row has columns 'id', 'name', and 'age'. - The 'id' column contains values incremented by 1 from 1 to 10, the 'name' column contains 'Person' + In the above example, a data stream is generated with 10 rows, where each row has columns \ + 'number', 'name', and 'age'. + The 'number' column contains values incremented by 1 from 1 to 10, the 'name' column contains 'Person' followed by the respective row index, and the 'age' column contains values starting from 20 incremented by the row index. """ @@ -133,14 +134,14 @@ def noisy_linear_stream(nb_rows: int = 10, input_rate: float = 1.0) -> pw.Table: random.seed(0) def _get_value(i): - return i + (2 * random.random() - 1) / 10 + return float(i + (2 * random.random() - 1) / 10) class InputSchema(pw.Schema): x: float = pw.column_definition(primary_key=True) y: float value_generators = { - "x": (lambda x: x), + "x": (lambda x: float(x)), "y": _get_value, } autocommit_duration_ms = 1000 @@ -183,7 +184,7 @@ class InputSchema(pw.Schema): value: float value_generators = { - "value": (lambda x: x + offset), + "value": (lambda x: float(x + offset)), } autocommit_duration_ms = 1000 return generate_custom_stream( @@ -235,10 +236,10 @@ def run(self): return pw.io.python.read( FileStreamSubject(), - schema=schema, + schema=schema.update_types(**{name: str for name in schema.column_names()}), autocommit_duration_ms=autocommit_ms, format="json", - ) + ).cast_to_types(**schema) def replay_csv_with_time( @@ -271,6 +272,10 @@ def replay_csv_with_time( """ + time_column_type = schema.as_dict().get(time_column, None) + if time_column_type != pw.dt.INT: + raise ValueError("Invalid schema. Time columns must be int.") + if unit not in ["s", "ms", "us", "ns"]: raise ValueError( "demo.replay_csv_with_time: unit should be either 's', 'ms, 'us', or 'ns'." 
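With the guard added above, `replay_csv_with_time` now fails fast on a non-`int` time column. A minimal sketch of that behavior, assuming keyword names (`schema`, `time_column`) as they appear in the function body; the full signature is not visible in this hunk:

>>> import pathway as pw
>>> class BadSchema(pw.Schema):
...     t: float
...     v: float
>>> pw.demo.replay_csv_with_time("./input_stream.csv", schema=BadSchema, time_column="t")
Traceback (most recent call last):
    ...
ValueError: Invalid schema. Time columns must be int.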
@@ -296,7 +301,7 @@ def run(self): csvreader = csv.DictReader(csvfile) for row in csvreader: values = {key: row[key] for key in columns} - current_value = float(values[time_column]) + current_value = int(values[time_column]) if last_time_value < 0: last_time_value = current_value tts = current_value - last_time_value @@ -308,7 +313,7 @@ def run(self): return pw.io.python.read( FileStreamSubject(), - schema=schema, + schema=schema.update_types(**{name: str for name in schema.column_names()}), autocommit_duration_ms=autocommit_ms, format="json", - ) + ).cast_to_types(**schema) diff --git a/python/pathway/dt.py b/python/pathway/dt.py new file mode 100644 index 00000000..ce5d6afe --- /dev/null +++ b/python/pathway/dt.py @@ -0,0 +1,37 @@ +from pathway.internals.dtype import ( + ANY, + ANY_TUPLE, + BOOL, + DATE_TIME_NAIVE, + DATE_TIME_UTC, + DURATION, + FLOAT, + INT, + POINTER, + STR, + Array, + Callable, + Optional, + Pointer, + Tuple, + wrap, +) + +__all__ = [ + "ANY", + "ANY_TUPLE", + "BOOL", + "DATE_TIME_NAIVE", + "DATE_TIME_UTC", + "DURATION", + "FLOAT", + "INT", + "POINTER", + "STR", + "Array", + "Callable", + "Optional", + "Pointer", + "Tuple", + "wrap", +] diff --git a/python/pathway/internals/__init__.py b/python/pathway/internals/__init__.py index e04f8ff4..fd745471 100644 --- a/python/pathway/internals/__init__.py +++ b/python/pathway/internals/__init__.py @@ -20,7 +20,6 @@ udf_async, unwrap, ) -from pathway.internals.datetime_types import DateTimeNaive, DateTimeUtc, Duration from pathway.internals.decorators import ( attribute, input_attribute, @@ -114,9 +113,9 @@ "schema_builder", "column_definition", "TableSlice", - "DateTimeNaive", - "DateTimeUtc", - "Duration", + "DATE_TIME_NAIVE", + "DATE_TIME_UTC", + "DURATION", "unwrap", "SchemaProperties", ] diff --git a/python/pathway/internals/_io_helpers.py b/python/pathway/internals/_io_helpers.py index 6cbefbf3..24cb8d02 100644 --- a/python/pathway/internals/_io_helpers.py +++ b/python/pathway/internals/_io_helpers.py @@ -4,8 +4,9 @@ from typing import List, Type -from pathway.internals import api, schema -from pathway.internals.dtype import unoptionalize +from pathway.internals import api +from pathway.internals import dtype as dt +from pathway.internals import schema from pathway.internals.table import Table @@ -24,9 +25,7 @@ def _form_value_fields(schema: Type[schema.Schema]) -> List[api.ValueField]: # XXX fix mapping schema types to PathwayType types = { - name: api._TYPES_TO_ENGINE_MAPPING.get( - unoptionalize(dtype), api.PathwayType.ANY - ) + name: dt.unoptionalize(dtype).to_engine() for name, dtype in schema.as_dict().items() } diff --git a/python/pathway/internals/api.py b/python/pathway/internals/api.py index 838c70cb..8159bfca 100644 --- a/python/pathway/internals/api.py +++ b/python/pathway/internals/api.py @@ -9,7 +9,7 @@ import pandas as pd from pathway.engine import * -from pathway.internals.datetime_types import DateTimeNaive, DateTimeUtc, Duration +from pathway.internals import dtype as dt from pathway.internals.schema import Schema Value = Union[ @@ -31,9 +31,7 @@ # XXX: engine calls return BasePointer, not Pointer class Pointer(BasePointer, Generic[TSchema]): """Pointer to row type. - Example: - >>> import pathway as pw >>> t1 = pw.debug.parse_to_table(''' ... age | owner | pet @@ -43,8 +41,8 @@ class Pointer(BasePointer, Generic[TSchema]): ... 7 | Bob | dog ... 
''') >>> t2 = t1.select(col=t1.id) - >>> t2.schema['col'] is pw.Pointer - True + >>> t2.schema['col'] + Pointer """ pass @@ -100,14 +98,14 @@ def denumpify(x): _TYPES_TO_ENGINE_MAPPING: Mapping[Any, PathwayType] = { - bool: PathwayType.BOOL, - int: PathwayType.INT, - float: PathwayType.FLOAT, - Pointer: PathwayType.POINTER, - str: PathwayType.STRING, - DateTimeNaive: PathwayType.DATE_TIME_NAIVE, - DateTimeUtc: PathwayType.DATE_TIME_UTC, - Duration: PathwayType.DURATION, - np.ndarray: PathwayType.ARRAY, - Any: PathwayType.ANY, + dt.BOOL: PathwayType.BOOL, + dt.INT: PathwayType.INT, + dt.FLOAT: PathwayType.FLOAT, + dt.POINTER: PathwayType.POINTER, + dt.STR: PathwayType.STRING, + dt.DATE_TIME_NAIVE: PathwayType.DATE_TIME_NAIVE, + dt.DATE_TIME_UTC: PathwayType.DATE_TIME_UTC, + dt.DURATION: PathwayType.DURATION, + dt.Array(): PathwayType.ARRAY, + dt.ANY: PathwayType.ANY, } diff --git a/python/pathway/internals/column.py b/python/pathway/internals/column.py index bbb588e8..2303f9e9 100644 --- a/python/pathway/internals/column.py +++ b/python/pathway/internals/column.py @@ -5,22 +5,12 @@ from abc import ABC from dataclasses import dataclass, field from itertools import chain -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - Optional, - Tuple, - get_args, - get_origin, -) - -import numpy as np +from types import EllipsisType +from typing import TYPE_CHECKING, Dict, Iterable, Optional, Tuple import pathway.internals as pw -from pathway.internals import api, trace -from pathway.internals.dtype import DType, types_lca +from pathway.internals import dtype as dt +from pathway.internals import trace from pathway.internals.expression import ColumnExpression, ColumnRefOrIxExpression from pathway.internals.helpers import SetOnceProperty, StableSet @@ -62,12 +52,12 @@ def table(self) -> Table: class Column(ABC): - dtype: DType + dtype: dt.DType universe: Universe lineage: SetOnceProperty[ColumnLineage] = SetOnceProperty() """Lateinit by operator.""" - def __init__(self, dtype: DType, universe: Universe) -> None: + def __init__(self, dtype: dt.DType, universe: Universe) -> None: super().__init__() self.dtype = dtype self.universe = universe @@ -101,7 +91,7 @@ class ColumnWithContext(Column, ABC): context: Context - def __init__(self, dtype: DType, context: Context, universe: Universe): + def __init__(self, dtype: dt.DType, context: Context, universe: Universe): super().__init__(dtype, universe) self.context = context @@ -111,7 +101,7 @@ def column_dependencies(self) -> StableSet[Column]: class IdColumn(ColumnWithContext): def __init__(self, context: Context) -> None: - super().__init__(DType(api.Pointer), context, context.universe) + super().__init__(dt.POINTER, context, context.universe) class ColumnWithExpression(ColumnWithContext): @@ -205,7 +195,7 @@ def _get_type_interpreter(self): return TypeInterpreter() - def expression_type(self, expression: ColumnExpression) -> DType: + def expression_type(self, expression: ColumnExpression) -> dt.DType: return self.expression_with_type(expression)._dtype def expression_with_type(self, expression: ColumnExpression) -> ColumnExpression: @@ -381,18 +371,20 @@ def get_flatten_column_dtype(flatten_column: ColumnWithExpression): from pathway.internals import type_interpreter dtype = type_interpreter.eval_type(flatten_column.expression) - if get_origin(dtype) == list: - return get_args(dtype)[0] - elif get_origin(dtype) == tuple: - return_dtype = get_args(dtype)[0] - for single_dtype in get_args(dtype)[1:]: - if single_dtype != Ellipsis: - return_dtype 
= types_lca(return_dtype, single_dtype) + if isinstance(dtype, dt.List): + return dtype.wrapped + if isinstance(dtype, dt.Tuple): + if dtype == dt.ANY_TUPLE: + return dt.ANY + assert not isinstance(dtype.args, EllipsisType) + return_dtype = dtype.args[0] + for single_dtype in dtype.args[1:]: + return_dtype = dt.types_lca(return_dtype, single_dtype) return return_dtype - elif dtype == str: - return str - elif dtype in {np.ndarray, Any}: - return Any + elif dtype == dt.STR: + return dt.STR + elif dtype in {dt.Array(), dt.ANY}: + return dt.ANY else: raise TypeError( f"Cannot flatten column {flatten_column.expression!r} of type {dtype}." diff --git a/python/pathway/internals/common.py b/python/pathway/internals/common.py index 2b2a5734..ef7a2a10 100644 --- a/python/pathway/internals/common.py +++ b/python/pathway/internals/common.py @@ -5,6 +5,7 @@ import functools from typing import Any, Callable, Optional, Union, overload +from pathway.internals import dtype as dt from pathway.internals import expression as expr from pathway.internals import operator as op from pathway.internals import table @@ -267,7 +268,7 @@ def numba_apply( def apply_with_type( fun: Callable, - ret_type: type, + ret_type: type | dt.DType, *args: expr.ColumnExpressionOrValue, **kwargs: expr.ColumnExpressionOrValue, ) -> expr.ColumnExpression: @@ -346,14 +347,17 @@ def declare_type( >>> t1 = pw.debug.parse_to_table(''' ... val ... 1 10 - ... 2 9 + ... 2 9.5 ... 3 8 ... 4 7''') >>> t1.schema.as_dict() - {'val': } - >>> t2 = t1.select(val = pw.declare_type(int | float, t1.val)) + {'val': FLOAT} + >>> t2 = t1.filter(t1.val == pw.cast(int, t1.val)) >>> t2.schema.as_dict() - {'val': int | float} + {'val': FLOAT} + >>> t3 = t2.select(val = pw.declare_type(int, t2.val)) + >>> t3.schema.as_dict() + {'val': INT} """ return expr.DeclareTypeExpression(target_type, col) @@ -372,7 +376,7 @@ def cast(target_type: Any, col: expr.ColumnExpressionOrValue) -> expr.CastExpres ... 3 8 ... 4 7''') >>> t1.schema.as_dict() - {'val': } + {'val': INT} >>> pw.debug.compute_and_print(t1, include_id=False) val 7 @@ -381,7 +385,7 @@ def cast(target_type: Any, col: expr.ColumnExpressionOrValue) -> expr.CastExpres 10 >>> t2 = t1.select(val = pw.cast(float, t1.val)) >>> t2.schema.as_dict() - {'val': } + {'val': FLOAT} >>> pw.debug.compute_and_print(t2, include_id=False) val 7.0 @@ -530,7 +534,7 @@ def unwrap(col: expr.ColumnExpressionOrValue) -> expr.ColumnExpression: ... 3 | None ... 
4 | 15''') >>> t1.schema.as_dict() - {'colA': , 'colB': typing.Optional[int]} + {'colA': INT, 'colB': Optional(INT)} >>> pw.debug.compute_and_print(t1, include_id=False) colA | colB 1 | 5 @@ -539,14 +543,14 @@ def unwrap(col: expr.ColumnExpressionOrValue) -> expr.ColumnExpression: 4 | 15 >>> t2 = t1.filter(t1.colA < 3) >>> t2.schema.as_dict() - {'colA': , 'colB': typing.Optional[int]} + {'colA': INT, 'colB': Optional(INT)} >>> pw.debug.compute_and_print(t2, include_id=False) colA | colB 1 | 5 2 | 9 >>> t3 = t2.select(colB = pw.unwrap(t2.colB)) >>> t3.schema.as_dict() - {'colB': } + {'colB': INT} >>> pw.debug.compute_and_print(t3, include_id=False) colB 5 diff --git a/python/pathway/internals/datetime_types.py b/python/pathway/internals/datetime_types.py deleted file mode 100644 index 44666a7f..00000000 --- a/python/pathway/internals/datetime_types.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright © 2023 Pathway - -import pandas as pd - - -class DateTimeNaive(pd.Timestamp): - pass - - -class DateTimeUtc(pd.Timestamp): - pass - - -class Duration(pd.Timedelta): - pass diff --git a/python/pathway/internals/decorators.py b/python/pathway/internals/decorators.py index 1c02ebf0..eea3d211 100644 --- a/python/pathway/internals/decorators.py +++ b/python/pathway/internals/decorators.py @@ -10,10 +10,10 @@ from pathway.internals.datasink import DataSink from pathway.internals.schema import Schema +from pathway.internals import dtype as dt from pathway.internals import operator as op from pathway.internals import row_transformer as rt from pathway.internals.datasource import EmptyDataSource, StaticDataSource -from pathway.internals.dtype import DType from pathway.internals.helpers import function_spec, with_optional_kwargs from pathway.internals.parse_graph import G @@ -68,7 +68,7 @@ def input_attribute(type=float): 9 | 10 10 | 11 """ - return rt.InputAttribute(dtype=DType(type)) + return rt.InputAttribute(dtype=dt.wrap(type)) def input_method(type=float): @@ -103,7 +103,7 @@ def input_method(type=float): ... 7''') >>> t2 = first_transformer(table=t1.select(a=t1.age)).table >>> t2.schema.as_dict() - {'fun': typing.Callable[..., int]} + {'fun': Callable(..., INT)} >>> t3 = second_transformer(table=t2.select(m=t2.fun)).table >>> pw.debug.compute_and_print(t1 + t3, include_id=False) age | val @@ -112,7 +112,7 @@ def input_method(type=float): 9 | 18 10 | 20 """ - return rt.InputMethod(dtype=DType(type)) + return rt.InputMethod(dtype=dt.wrap(type)) @with_optional_kwargs @@ -213,7 +213,7 @@ def method(func, **kwargs): ... 
7''') >>> t2 = simple_transformer(table=t1.select(a=t1.age)).table >>> t2.schema.as_dict() - {'b': , 'fun': typing.Callable[..., float]} + {'b': FLOAT, 'fun': Callable(..., FLOAT)} >>> pw.debug.compute_and_print(t1 + t2.select(t2.b), include_id=False) age | b 7 | 49 diff --git a/python/pathway/internals/dtype.py b/python/pathway/internals/dtype.py index 86b71edb..711a2b87 100644 --- a/python/pathway/internals/dtype.py +++ b/python/pathway/internals/dtype.py @@ -2,157 +2,502 @@ from __future__ import annotations -import sys -from typing import Any, Callable, NewType, Optional, Tuple, Union, get_args, get_origin +import collections +import typing +from abc import ABC +from types import EllipsisType +from warnings import warn + +import numpy as np +import pandas as pd from pathway.internals import api -DType = NewType("DType", object) +if typing.TYPE_CHECKING: + from pathway.internals.schema import Schema -NoneType = DType(type(None)) -# TODO: require later that types need to be typing.Optional or EmptyType +class DType(ABC): + _cache: dict[typing.Any, DType] = {} + def to_engine(self) -> api.PathwayType: + return api.PathwayType.ANY -def is_optional(input_type: DType) -> bool: - return ( - get_origin(input_type) is Union - and len(get_args(input_type)) == 2 - and isinstance(None, get_args(input_type)[1]) - ) + def __call__(self, *args, **kwargs): + raise RuntimeError("Not implemented.") + + def _set_args(self, *args): + raise RuntimeError("Not implemented.") + + @classmethod + def _cached_new(cls, *args): + key = (cls, args) + if key not in DType._cache: + ret = super().__new__(cls) + ret._set_args(*args) + cls._cache[key] = ret + return DType._cache[key] + + def __class_getitem__(cls, args): + if isinstance(args, tuple): + return cls(*args) + else: + return cls(args) # type: ignore[call-arg] + + +class _SimpleDType(DType): + wrapped: type + + def __repr__(self): + return { + INT: "INT", + BOOL: "BOOL", + STR: "STR", + FLOAT: "FLOAT", + }[self] + + def _set_args(self, wrapped): + self.wrapped = wrapped + + def __new__(cls, wrapped: type) -> _SimpleDType: + return cls._cached_new(wrapped) + + def to_engine(self) -> api.PathwayType: + return { + INT: api.PathwayType.INT, + BOOL: api.PathwayType.BOOL, + STR: api.PathwayType.STRING, + FLOAT: api.PathwayType.FLOAT, + }[self] + + def __call__(self, arg): + return self.wrapped(arg) + + +INT: DType = _SimpleDType(int) +BOOL: DType = _SimpleDType(bool) +STR: DType = _SimpleDType(str) +FLOAT: DType = _SimpleDType(float) + + +class _NoneDType(DType): + def __repr__(self): + return "NONE" + + def _set_args(self): + pass + + def __new__(cls) -> _NoneDType: + return cls._cached_new() + + def __call__(self): + return None + + +NONE: DType = _NoneDType() + + +class _AnyDType(DType): + def __repr__(self): + return "ANY" + + def _set_args(self): + pass + + def __new__(cls) -> _AnyDType: + return cls._cached_new() + + +ANY: DType = _AnyDType() + + +class Callable(DType): + arg_types: EllipsisType | tuple[DType | EllipsisType, ...] + return_type: DType + + def __repr__(self): + if isinstance(self.arg_types, EllipsisType): + return f"Callable(..., {self.return_type})" + else: + return f"Callable({self.arg_types}, {self.return_type})" + + def _set_args(self, arg_types, return_type): + if isinstance(arg_types, EllipsisType): + self.arg_types = ... 
+ else: + self.arg_types = tuple( + Ellipsis if isinstance(dtype, EllipsisType) else wrap(dtype) + for dtype in arg_types + ) + self.return_type = wrap(return_type) + + def __new__( + cls, + arg_types: EllipsisType | tuple[DType | EllipsisType, ...], + return_type: DType, + ) -> Callable: + return cls._cached_new(arg_types, return_type) + + +class Array(DType): + def __repr__(self): + return "Array" + + def _set_args(self): + pass + + def to_engine(self) -> api.PathwayType: + return api.PathwayType.ARRAY + + def __new__(cls) -> Array: + return cls._cached_new() + + +T = typing.TypeVar("T") + + +class Pointer(DType, typing.Generic[T]): + wrapped: typing.Optional[typing.Type[Schema]] = None + + def __repr__(self): + if self.wrapped is not None: + return f"Pointer({self.wrapped.__name__})" + else: + return "Pointer" + + def _set_args(self, wrapped): + self.wrapped = wrapped + + def to_engine(self) -> api.PathwayType: + return api.PathwayType.POINTER + + def __new__(cls, wrapped: typing.Optional[typing.Type[Schema]] = None) -> Pointer: + return cls._cached_new(wrapped) + + +POINTER: DType = Pointer() + + +class Optional(DType): + wrapped: DType + + def __init__(self, arg): + super().__init__() + + def __repr__(self): + return f"Optional({self.wrapped})" + + def _set_args(self, wrapped): + self.wrapped = wrap(wrapped) + + def __new__(cls, arg: DType) -> DType: # type:ignore[misc] + if arg == NONE or isinstance(arg, Optional) or arg == ANY: + return arg + return cls._cached_new(arg) + + def __call__(self, arg): + if arg is None: + return None + else: + return self.wrapped(arg) + + +class Tuple(DType): + args: tuple[DType, ...] | EllipsisType + + def __init__(self, *args): + super().__init__() + + def __repr__(self): + if isinstance(self.args, EllipsisType): + return "Tuple" + else: + return f"Tuple({', '.join(str(arg) for arg in self.args)})" + + def _set_args(self, args): + if isinstance(args, EllipsisType): + self.args = ... + else: + self.args = tuple(wrap(arg) for arg in args) + + def __new__(cls, *args: DType | EllipsisType) -> Tuple | List: # type: ignore[misc] + if any(isinstance(arg, EllipsisType) for arg in args): + if len(args) == 2: + arg, placeholder = args + assert isinstance(placeholder, EllipsisType) + assert isinstance(arg, DType) + return List(arg) + assert len(args) == 1 + return cls._cached_new(Ellipsis) + else: + return cls._cached_new(args) -def sanitize_type(input_type): - if ( +class List(DType): + wrapped: DType + + def __repr__(self): + return f"List({self.wrapped})" + + def __new__(cls, wrapped: DType) -> List: + return cls._cached_new(wrapped) + + def _set_args(self, wrapped): + self.wrapped = wrap(wrapped) + + +ANY_TUPLE: DType = Tuple(...) 
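Two properties of the machinery defined so far are worth spelling out: `_cached_new` interns every concrete DType, so structurally equal types are the same object, and a `Tuple` built with a trailing `...` normalizes to a `List`. A sketch of the behavior implied by the definitions above:

>>> from pathway.internals import dtype as dt
>>> dt.Optional(dt.INT) is dt.Optional(dt.INT)  # interned via _cached_new
True
>>> dt.Tuple(dt.INT, ...)  # trailing Ellipsis collapses to a List
List(INT)
>>> dt.Optional(dt.ANY) is dt.ANY  # Optional is a no-op on ANY (and NONE)
True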
+ + +class _DateTimeNaive(DType): + def __repr__(self): + return "DATE_TIME_NAIVE" + + def _set_args(self): + pass + + def __new__(cls) -> _DateTimeNaive: + return cls._cached_new() + + def __call__(self, arg): + ret = pd.Timestamp(arg) + assert ret.tzinfo is None + return ret + + +DATE_TIME_NAIVE = _DateTimeNaive() + + +class _DateTimeUtc(DType): + def __repr__(self): + return "DATE_TIME_UTC" + + def _set_args(self): + pass + + def __new__(cls) -> _DateTimeUtc: + return cls._cached_new() + + def __call__(self, arg): + ret = pd.Timestamp(arg) + assert ret.tzinfo is not None + return ret + + +DATE_TIME_UTC = _DateTimeUtc() + + +class _Duration(DType): + def __repr__(self): + return "DURATION" + + def _set_args(self): + pass + + def __new__(cls) -> _Duration: + return cls._cached_new() + + def __call__(self, arg): + return pd.Timedelta(arg) + + +DURATION = _Duration() + + +_NoneType = type(None) + + +def wrap(input_type) -> DType: + assert input_type != ... + assert input_type != Optional + assert input_type != Pointer + assert input_type != Tuple + assert input_type != Callable + assert input_type != Array + assert input_type != List + if isinstance(input_type, DType): + return input_type + if input_type == _NoneType: + return NONE + elif input_type == typing.Any: + return ANY + elif ( input_type is api.BasePointer or input_type is api.Pointer - or get_origin(input_type) is api.Pointer + or typing.get_origin(input_type) is api.Pointer ): - return api.Pointer - elif sys.version_info >= (3, 10) and isinstance(input_type, NewType): - # NewType is a class starting from Python 3.10 - return sanitize_type(input_type.__supertype__) + args = typing.get_args(input_type) + if len(args) == 0: + return Pointer() + else: + assert len(args) == 1 + return Pointer(args[0]) elif isinstance(input_type, str): - return Any # TODO: input_type is annotation for class - elif get_origin(input_type) is get_origin( - Callable - ): # for some weird reason get_origin(Callable) == collections.abc.Callable - return Callable - elif is_optional(input_type): - return sanitize_type(unoptionalize(input_type)) + return ANY # TODO: input_type is annotation for class + elif typing.get_origin(input_type) == collections.abc.Callable: + c_args = get_args(input_type) + if c_args == (): + return Callable(..., ANY) + arg_types, ret_type = c_args + if isinstance(arg_types, Tuple): + callable_args = arg_types.args + else: + assert isinstance(arg_types, EllipsisType) + callable_args = arg_types + assert isinstance(ret_type, DType), type(ret_type) + return Callable(callable_args, ret_type) + elif ( + typing.get_origin(input_type) is typing.Union + and len(typing.get_args(input_type)) == 2 + and isinstance(None, typing.get_args(input_type)[1]) + ): + arg, _ = get_args(input_type) + assert isinstance(arg, DType) + return Optional(arg) + elif input_type in [list, tuple, typing.List, typing.Tuple]: + return ANY_TUPLE + elif typing.get_origin(input_type) == list: + args = get_args(input_type) + (arg,) = args + return List(wrap(arg)) + elif typing.get_origin(input_type) == tuple: + args = get_args(input_type) + if args[-1] == ...: + arg, _ = args + return List(wrap(arg)) + else: + return Tuple(*[wrap(arg) for arg in args]) + elif input_type == np.ndarray: + return Array() else: - return input_type - + dtype = {int: INT, bool: BOOL, str: STR, float: FLOAT}.get(input_type, ANY) + if dtype == ANY: + warn(f"Unsupported type {input_type}, falling back to ANY.") + return dtype -def is_pointer(input_type: DType) -> bool: - return 
sanitize_type(input_type) == api.Pointer - -def dtype_equivalence(left: DType, right: DType): +def dtype_equivalence(left: DType, right: DType) -> bool: return dtype_issubclass(left, right) and dtype_issubclass(right, left) -def dtype_tuple_equivalence(left: DType, right: DType): - largs = get_args(left) - rargs = get_args(right) - if len(largs) != len(rargs): - return False - if largs[-1] == Ellipsis: - largs = largs[:-1] - if rargs[-1] == Ellipsis: - rargs = rargs[:-1] +def dtype_tuple_equivalence(left: Tuple | List, right: Tuple | List) -> bool: + if left == ANY_TUPLE or right == ANY_TUPLE: + return True + if isinstance(left, List) and isinstance(right, List): + return left.wrapped == right.wrapped + if isinstance(left, List): + assert isinstance(right, Tuple) + assert not isinstance(right.args, EllipsisType) + rargs = right.args + largs = tuple(left.wrapped for _arg in rargs) + elif isinstance(right, List): + assert isinstance(left, Tuple) + assert not isinstance(left.args, EllipsisType) + largs = left.args + rargs = tuple(right.wrapped for _arg in largs) + else: + assert isinstance(left, Tuple) + assert isinstance(right, Tuple) + assert not isinstance(left.args, EllipsisType) + assert not isinstance(right.args, EllipsisType) + largs = left.args + rargs = right.args if len(largs) != len(rargs): return False - return all(dtype_equivalence(l_arg, r_arg) for l_arg, r_arg in zip(largs, rargs)) + return all( + dtype_equivalence(typing.cast(DType, l_arg), typing.cast(DType, r_arg)) + for l_arg, r_arg in zip(largs, rargs) + ) -def dtype_issubclass(left: DType, right: DType): - if right == Any: # catch the case, when left=Optional[T] and right=Any +def dtype_issubclass(left: DType, right: DType) -> bool: + if right == ANY: # catch the case, when left=Optional[T] and right=Any return True - elif is_optional(left) and not is_optional(right): - return False - elif left == NoneType: - return is_optional(right) or right == NoneType - elif is_tuple_like(left) and is_tuple_like(right): - left = normalize_tuple_like(left) - right = normalize_tuple_like(right) - if left == Tuple or right == Tuple: + elif isinstance(left, Optional): + if isinstance(right, Optional): + return dtype_issubclass(unoptionalize(left), unoptionalize(right)) + else: + return False + elif left == NONE: + return isinstance(right, Optional) or right == NONE + elif isinstance(right, Optional): + return dtype_issubclass(left, unoptionalize(right)) + elif isinstance(left, (Tuple, List)) and isinstance(right, (Tuple, List)): + return dtype_tuple_equivalence(left, right) + elif isinstance(left, Pointer) and isinstance(right, Pointer): + return True # TODO + elif isinstance(left, _SimpleDType) and isinstance(right, _SimpleDType): + if left == INT and right == FLOAT: return True + elif left == BOOL and right == INT: + return False else: - return dtype_tuple_equivalence(left, right) - elif is_tuple_like(left) or is_tuple_like(right): - return False - left_ = sanitize_type(left) - right_ = sanitize_type(right) - if right_ is Any: # repeated to check after sanitization - return True - elif left_ is Any: - return False - elif left_ is int and right_ is float: + return issubclass(left.wrapped, right.wrapped) + elif isinstance(left, Callable) and isinstance(right, Callable): return True - elif left_ is bool and right_ is int: - return False - elif isinstance(left_, type): - return issubclass(left_, right_) - else: - return left_ == right_ - - -def is_tuple_like(dtype: DType): - return dtype in [list, tuple] or get_origin(dtype) in [list, 
tuple] - - -def normalize_tuple_like(dtype: DType): - if get_origin(dtype) == list: - args = get_args(dtype) - if len(args) == 0: - return DType(Tuple) - (arg,) = args - return DType(Tuple[arg]) - assert dtype == tuple or get_origin(dtype) == tuple - return dtype + return left == right def types_lca(left: DType, right: DType) -> DType: """LCA of two types.""" - if is_optional(left) or is_optional(right): - dtype = types_lca(unoptionalize(left), unoptionalize(right)) - return DType(Optional[dtype]) if dtype is not Any else DType(Any) - elif is_tuple_like(left) and is_tuple_like(right): - left = normalize_tuple_like(left) - right = normalize_tuple_like(right) - if left == Tuple or right == Tuple: - return DType(Tuple) + if isinstance(left, Optional) or isinstance(right, Optional): + return Optional(types_lca(unoptionalize(left), unoptionalize(right))) + elif isinstance(left, (Tuple, List)) and isinstance(right, (Tuple, List)): + if left == ANY_TUPLE or right == ANY_TUPLE: + return ANY_TUPLE elif dtype_tuple_equivalence(left, right): return left else: - return DType(Tuple) - elif is_tuple_like(left) or is_tuple_like(right): - return DType(Any) - elif dtype_issubclass(left, right): + return ANY_TUPLE + elif isinstance(left, Pointer) and isinstance(right, Pointer): + l_schema = left.wrapped + r_schema = right.wrapped + if l_schema is None: + return right + elif r_schema is None: + return left + if l_schema == r_schema: + return left + else: + return POINTER + if dtype_issubclass(left, right): return right elif dtype_issubclass(right, left): return left - elif left == NoneType: - return DType(Optional[right]) - elif right == NoneType: - return DType(Optional[left]) + + if left == NONE: + return Optional(right) + elif right == NONE: + return Optional(left) else: - return DType(Any) + return ANY def unoptionalize(dtype: DType) -> DType: - return get_args(dtype)[0] if is_optional(dtype) else dtype + return dtype.wrapped if isinstance(dtype, Optional) else dtype + + +def normalize_dtype(dtype: DType) -> DType: + if isinstance(dtype, Pointer): + return POINTER + if isinstance(dtype, Array): + return Array() + return dtype -def unoptionalize_pair(left_dtype: DType, right_dtype) -> Tuple[DType, DType]: +def unoptionalize_pair(left_dtype: DType, right_dtype) -> typing.Tuple[DType, DType]: """ Unpacks type out of typing.Optional and matches a second type with it if it is an EmptyType. """ - if left_dtype == NoneType and is_optional(right_dtype): + if left_dtype == NONE and isinstance(right_dtype, Optional): left_dtype = right_dtype - if right_dtype == NoneType and is_optional(left_dtype): + if right_dtype == NONE and isinstance(left_dtype, Optional): right_dtype = left_dtype return unoptionalize(left_dtype), unoptionalize(right_dtype) + + +def get_args(dtype: typing.Any) -> tuple[EllipsisType | DType, ...]: + arg_types = typing.get_args(dtype) + return tuple(wrap(arg) if arg != ... else ... 
for arg in arg_types) diff --git a/python/pathway/internals/expression.py b/python/pathway/internals/expression.py index 7bfea8c0..5be3a600 100644 --- a/python/pathway/internals/expression.py +++ b/python/pathway/internals/expression.py @@ -11,16 +11,16 @@ Callable, Dict, Iterable, - List, Optional, Tuple, Type, Union, ) -from pathway.internals import api, helpers +from pathway.internals import api +from pathway.internals import dtype as dt +from pathway.internals import helpers from pathway.internals.api import Value -from pathway.internals.dtype import DType, dtype_issubclass from pathway.internals.operator_input import OperatorInput from pathway.internals.shadows import inspect, operator from pathway.internals.trace import Trace @@ -40,7 +40,7 @@ class ColumnExpression(OperatorInput, ABC): _deps: helpers.SetOnceProperty[ Iterable[ColumnExpression] ] = helpers.SetOnceProperty(wrapper=tuple) - _dtype: DType + _dtype: dt.DType def __init__(self): self._trace = Trace.from_traceback() @@ -298,7 +298,7 @@ def to_string(self) -> MethodCallExpression: ... 3 ... 4''') >>> t1.schema.as_dict() - {'val': } + {'val': INT} >>> pw.debug.compute_and_print(t1, include_id=False) val 1 @@ -307,7 +307,7 @@ def to_string(self) -> MethodCallExpression: 4 >>> t2 = t1.select(val = pw.this.val.to_string()) >>> t2.schema.as_dict() - {'val': } + {'val': STR} >>> pw.debug.compute_and_print(t2.select(val=pw.this.val + "a"), include_id=False) val 1a @@ -316,13 +316,13 @@ def to_string(self) -> MethodCallExpression: 4a """ return MethodCallExpression( - [ + ( ( - Any, - str, + dt.ANY, + dt.STR, api.Expression.to_string, - ) - ], + ), + ), "to_string", self, ) @@ -590,7 +590,7 @@ def __init__( return_type = inspect.signature(self._fun).return_annotation except ValueError: return_type = Any - self._return_type = return_type + self._return_type = dt.wrap(return_type) self._args = tuple(_wrap_arg(arg) for arg in args) @@ -609,7 +609,7 @@ def __init__( ): super().__init__( fun, - return_type, + dt.wrap(return_type), *args, **kwargs, ) @@ -625,7 +625,7 @@ def __init__( ): super().__init__( fun, - return_type, + dt.wrap(return_type) if return_type is not None else None, *args, **kwargs, ) @@ -637,7 +637,7 @@ class CastExpression(ColumnExpression): def __init__(self, return_type: Any, expr: ColumnExpressionOrValue): super().__init__() - self._return_type = return_type + self._return_type = dt.wrap(return_type) self._expr = _wrap_arg(expr) self._deps = [self._expr] @@ -648,7 +648,7 @@ class DeclareTypeExpression(ColumnExpression): def __init__(self, return_type: Any, expr: ColumnExpressionOrValue): super().__init__() - self._return_type = return_type + self._return_type = dt.wrap(return_type) self._expr = _wrap_arg(expr) self._deps = [self._expr] @@ -777,13 +777,15 @@ def __init__( class MethodCallExpression(ColumnExpression): - _fun_mapping: List[Tuple[Any, Any, Callable]] + _fun_mapping: Tuple[Tuple[tuple[dt.DType, ...], dt.DType, Callable], ...] _name: str _args: Tuple[ColumnExpression, ...] def __init__( self, - fun_mapping: List[Tuple[Any, Any, Callable]], + fun_mapping: Tuple[ + Tuple[tuple[dt.DType, ...] | dt.DType, dt.DType, Callable], ... 
+ ], name: str, *args: ColumnExpressionOrValue, ) -> None: @@ -821,22 +823,21 @@ def __init__( ) def _wrap_mapping_key_in_tuple( - self, mapping: List[Tuple[Any, Any, Callable]] - ) -> List[Tuple[Any, Any, Callable]]: - result = [] - for key, result_dtype, value in mapping: - if not isinstance(key, tuple): - key = (key,) - result.append((key, result_dtype, value)) - return result + self, + mapping: tuple[tuple[tuple[dt.DType, ...] | dt.DType, dt.DType, Callable], ...], + ) -> tuple[tuple[tuple[dt.DType, ...], dt.DType, Callable], ...]: + return tuple( + (key if isinstance(key, tuple) else (key,), result_dtype, value) + for key, result_dtype, value in mapping + ) def get_function( - self, dtypes: Tuple[DType, ...] - ) -> Optional[Tuple[Tuple[DType, ...], DType, Callable]]: + self, dtypes: tuple[dt.DType, ...] + ) -> Optional[tuple[tuple[dt.DType, ...], dt.DType, Callable]]: for key, target_type, fun in self._fun_mapping: assert len(dtypes) == len(key) if all( - dtype_issubclass(arg_dtype, key_dtype) + dt.dtype_issubclass(arg_dtype, key_dtype) for arg_dtype, key_dtype in zip(dtypes, key) ): return (key, target_type, fun) diff --git a/python/pathway/internals/expression_printer.py b/python/pathway/internals/expression_printer.py index 573e06b0..2566d38d 100644 --- a/python/pathway/internals/expression_printer.py +++ b/python/pathway/internals/expression_printer.py @@ -180,7 +180,9 @@ def get_expression_info(expression: expr.ColumnExpression) -> str: def _type_name(return_type): - if isinstance(return_type, str): + from pathway.internals import dtype as dt + + if isinstance(return_type, str) or isinstance(return_type, dt.DType): return repr(return_type) else: return return_type.__name__ diff --git a/python/pathway/internals/expressions/date_time.py b/python/pathway/internals/expressions/date_time.py index 387e0b91..8a274383 100644 --- a/python/pathway/internals/expressions/date_time.py +++ b/python/pathway/internals/expressions/date_time.py @@ -2,9 +2,11 @@ from typing import Optional, Union +import pandas as pd + import pathway.internals.expression as expr from pathway.internals import api -from pathway.internals.datetime_types import DateTimeNaive, DateTimeUtc, Duration +from pathway.internals import dtype as dt class DateTimeNamespace: @@ -60,10 +62,10 @@ def nanosecond(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (DateTimeNaive, int, api.Expression.date_time_naive_nanosecond), - (DateTimeUtc, int, api.Expression.date_time_utc_nanosecond), - ], + ( + (dt.DATE_TIME_NAIVE, dt.INT, api.Expression.date_time_naive_nanosecond), + (dt.DATE_TIME_UTC, dt.INT, api.Expression.date_time_utc_nanosecond), + ), "dt.nanosecond", self._expression, ) @@ -99,10 +101,14 @@ def microsecond(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (DateTimeNaive, int, api.Expression.date_time_naive_microsecond), - (DateTimeUtc, int, api.Expression.date_time_utc_microsecond), - ], + ( + ( + dt.DATE_TIME_NAIVE, + dt.INT, + api.Expression.date_time_naive_microsecond, + ), + (dt.DATE_TIME_UTC, dt.INT, api.Expression.date_time_utc_microsecond), + ), "dt.microsecond", self._expression, ) @@ -138,10 +144,14 @@ def millisecond(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (DateTimeNaive, int, api.Expression.date_time_naive_millisecond), - (DateTimeUtc, int, api.Expression.date_time_utc_millisecond), - ], + ( + ( + dt.DATE_TIME_NAIVE, + dt.INT, + api.Expression.date_time_naive_millisecond, + ), + (dt.DATE_TIME_UTC, dt.INT, 
api.Expression.date_time_utc_millisecond), + ), "dt.millisecond", self._expression, ) @@ -177,10 +187,10 @@ def second(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (DateTimeNaive, int, api.Expression.date_time_naive_second), - (DateTimeUtc, int, api.Expression.date_time_utc_second), - ], + ( + (dt.DATE_TIME_NAIVE, dt.INT, api.Expression.date_time_naive_second), + (dt.DATE_TIME_UTC, dt.INT, api.Expression.date_time_utc_second), + ), "dt.second", self._expression, ) @@ -216,10 +226,10 @@ def minute(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (DateTimeNaive, int, api.Expression.date_time_naive_minute), - (DateTimeUtc, int, api.Expression.date_time_utc_minute), - ], + ( + (dt.DATE_TIME_NAIVE, dt.INT, api.Expression.date_time_naive_minute), + (dt.DATE_TIME_UTC, dt.INT, api.Expression.date_time_utc_minute), + ), "dt.minute", self._expression, ) @@ -251,10 +261,10 @@ def hour(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (DateTimeNaive, int, api.Expression.date_time_naive_hour), - (DateTimeUtc, int, api.Expression.date_time_utc_hour), - ], + ( + (dt.DATE_TIME_NAIVE, dt.INT, api.Expression.date_time_naive_hour), + (dt.DATE_TIME_UTC, dt.INT, api.Expression.date_time_utc_hour), + ), "dt.hour", self._expression, ) @@ -286,10 +296,10 @@ def day(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (DateTimeNaive, int, api.Expression.date_time_naive_day), - (DateTimeUtc, int, api.Expression.date_time_utc_day), - ], + ( + (dt.DATE_TIME_NAIVE, dt.INT, api.Expression.date_time_naive_day), + (dt.DATE_TIME_UTC, dt.INT, api.Expression.date_time_utc_day), + ), "dt.day", self._expression, ) @@ -321,10 +331,10 @@ def month(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (DateTimeNaive, int, api.Expression.date_time_naive_month), - (DateTimeUtc, int, api.Expression.date_time_utc_month), - ], + ( + (dt.DATE_TIME_NAIVE, dt.INT, api.Expression.date_time_naive_month), + (dt.DATE_TIME_UTC, dt.INT, api.Expression.date_time_utc_month), + ), "dt.month", self._expression, ) @@ -356,10 +366,10 @@ def year(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (DateTimeNaive, int, api.Expression.date_time_naive_year), - (DateTimeUtc, int, api.Expression.date_time_utc_year), - ], + ( + (dt.DATE_TIME_NAIVE, dt.INT, api.Expression.date_time_naive_year), + (dt.DATE_TIME_UTC, dt.INT, api.Expression.date_time_utc_year), + ), "dt.year", self._expression, ) @@ -428,10 +438,10 @@ def timestamp(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (DateTimeNaive, int, api.Expression.date_time_naive_timestamp), - (DateTimeUtc, int, api.Expression.date_time_utc_timestamp), - ], + ( + (dt.DATE_TIME_NAIVE, dt.INT, api.Expression.date_time_naive_timestamp), + (dt.DATE_TIME_UTC, dt.INT, api.Expression.date_time_utc_timestamp), + ), "dt.timestamp", self._expression, ) @@ -477,10 +487,18 @@ def strftime(self, fmt: Union[expr.ColumnExpression, str]) -> expr.ColumnExpress """ return expr.MethodCallExpression( - [ - ((DateTimeNaive, str), str, api.Expression.date_time_naive_strftime), - ((DateTimeUtc, str), str, api.Expression.date_time_utc_strftime), - ], + ( + ( + (dt.DATE_TIME_NAIVE, dt.STR), + dt.STR, + api.Expression.date_time_naive_strftime, + ), + ( + (dt.DATE_TIME_UTC, dt.STR), + dt.STR, + api.Expression.date_time_utc_strftime, + ), + ), "dt.strftime", self._expression, fmt, @@ -579,13 +597,13 @@ def strptime( if contains_timezone: fun = 
api.Expression.date_time_utc_strptime - return_type: type = DateTimeUtc + return_type: dt.DType = dt.DATE_TIME_UTC else: fun = api.Expression.date_time_naive_strptime - return_type = DateTimeNaive + return_type = dt.DATE_TIME_NAIVE return expr.MethodCallExpression( - [((str, str), return_type, fun)], + (((dt.STR, dt.STR), return_type, fun),), "dt.strptime", self._expression, fmt, @@ -669,13 +687,13 @@ def to_utc( """ return expr.MethodCallExpression( - [ + ( ( - (DateTimeNaive, str), - DateTimeUtc, + (dt.DATE_TIME_NAIVE, dt.STR), + dt.DATE_TIME_UTC, api.Expression.date_time_naive_to_utc, ), - ], + ), "dt.to_utc", self._expression, from_timezone, @@ -759,13 +777,13 @@ def to_naive_in_timezone( """ return expr.MethodCallExpression( - [ + ( ( - (DateTimeUtc, str), - DateTimeNaive, + (dt.DATE_TIME_UTC, dt.STR), + dt.DATE_TIME_NAIVE, api.Expression.date_time_utc_to_naive, ), - ], + ), "dt.to_to_naive_in_timezone", self._expression, timezone, @@ -773,7 +791,7 @@ def to_naive_in_timezone( def add_duration_in_timezone( self, - duration: Union[expr.ColumnExpression, Duration], + duration: Union[expr.ColumnExpression, pd.Timedelta], timezone: Union[expr.ColumnExpression, str], ) -> expr.ColumnExpression: """Adds Duration to DateTime taking into account time zone. @@ -818,7 +836,7 @@ def add_duration_in_timezone( def subtract_duration_in_timezone( self, - duration: Union[expr.ColumnExpression, Duration], + duration: Union[expr.ColumnExpression, pd.Timedelta], timezone: Union[expr.ColumnExpression, str], ) -> expr.ColumnExpression: """Subtracts Duration from DateTime taking into account time zone. @@ -863,7 +881,7 @@ def subtract_duration_in_timezone( def subtract_date_time_in_timezone( self, - date_time: Union[expr.ColumnExpression, DateTimeNaive], + date_time: Union[expr.ColumnExpression, pd.Timestamp], timezone: Union[expr.ColumnExpression, str], ) -> expr.ColumnBinaryOpExpression: """Subtracts two DateTimeNaives taking into account time zone. @@ -907,7 +925,7 @@ def subtract_date_time_in_timezone( return self.to_utc(timezone) - expr._wrap_arg(date_time).dt.to_utc(timezone) def round( - self, duration: Union[expr.ColumnExpression, Duration] + self, duration: Union[expr.ColumnExpression, pd.Timedelta] ) -> expr.ColumnExpression: """Rounds DateTime to precision specified by `duration` argument. @@ -947,25 +965,25 @@ def round( """ return expr.MethodCallExpression( - [ + ( ( - (DateTimeNaive, Duration), - DateTimeNaive, + (dt.DATE_TIME_NAIVE, dt.DURATION), + dt.DATE_TIME_NAIVE, api.Expression.date_time_naive_round, ), ( - (DateTimeUtc, Duration), - DateTimeUtc, + (dt.DATE_TIME_UTC, dt.DURATION), + dt.DATE_TIME_UTC, api.Expression.date_time_utc_round, ), - ], + ), "dt.round", self._expression, duration, ) def floor( - self, duration: Union[expr.ColumnExpression, Duration] + self, duration: Union[expr.ColumnExpression, pd.Timedelta] ) -> expr.ColumnExpression: """Truncates DateTime to precision specified by `duration` argument. 
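With the `Duration` alias removed, `round` and `floor` are annotated directly with `pd.Timedelta`, and a plain `pd.Timedelta` drives them at call sites. A hedged usage sketch — the table and column names here are invented for illustration:

>>> import pandas as pd
>>> import pathway as pw
>>> t = pw.debug.table_from_pandas(
...     pd.DataFrame({"ts": [pd.Timestamp("2023-05-15 12:39:21")]})
... )
>>> res = t.select(rounded=t.ts.dt.round(pd.Timedelta(minutes=30)))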
@@ -1005,18 +1023,18 @@ def floor( """ return expr.MethodCallExpression( - [ + ( ( - (DateTimeNaive, Duration), - DateTimeNaive, + (dt.DATE_TIME_NAIVE, dt.DURATION), + dt.DATE_TIME_NAIVE, api.Expression.date_time_naive_floor, ), ( - (DateTimeUtc, Duration), - DateTimeUtc, + (dt.DATE_TIME_UTC, dt.DURATION), + dt.DATE_TIME_UTC, api.Expression.date_time_utc_floor, ), - ], + ), "dt.floor", self._expression, duration, @@ -1061,7 +1079,7 @@ def nanoseconds(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [(Duration, int, api.Expression.duration_nanoseconds)], + ((dt.DURATION, dt.INT, api.Expression.duration_nanoseconds),), "dt.nanoseconds", self._expression, ) @@ -1105,7 +1123,7 @@ def microseconds(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [(Duration, int, api.Expression.duration_microseconds)], + ((dt.DURATION, dt.INT, api.Expression.duration_microseconds),), "dt.microseconds", self._expression, ) @@ -1149,7 +1167,7 @@ def milliseconds(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [(Duration, int, api.Expression.duration_milliseconds)], + ((dt.DURATION, dt.INT, api.Expression.duration_milliseconds),), "dt.milliseconds", self._expression, ) @@ -1191,7 +1209,7 @@ def seconds(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [(Duration, int, api.Expression.duration_seconds)], + ((dt.DURATION, dt.INT, api.Expression.duration_seconds),), "dt.seconds", self._expression, ) @@ -1233,7 +1251,7 @@ def minutes(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [(Duration, int, api.Expression.duration_minutes)], + ((dt.DURATION, dt.INT, api.Expression.duration_minutes),), "dt.minutes", self._expression, ) @@ -1277,7 +1295,7 @@ def hours(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [(Duration, int, api.Expression.duration_hours)], + ((dt.DURATION, dt.INT, api.Expression.duration_hours),), "dt.hours", self._expression, ) @@ -1321,7 +1339,7 @@ def days(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [(Duration, int, api.Expression.duration_days)], + ((dt.DURATION, dt.INT, api.Expression.duration_days),), "dt.days", self._expression, ) @@ -1364,7 +1382,7 @@ def weeks(self) -> expr.ColumnExpression: 52 """ return expr.MethodCallExpression( - [(Duration, int, api.Expression.duration_weeks)], + ((dt.DURATION, dt.INT, api.Expression.duration_weeks),), "dt.weeks", self._expression, ) @@ -1400,13 +1418,13 @@ def from_timestamp(self, unit: str) -> expr.ColumnExpression: if unit not in ("s", "ms", "us", "ns"): raise ValueError(f"unit has to be one of s, ms, us, ns but is {unit}.") return expr.MethodCallExpression( - [ + ( ( - (int, str), - DateTimeNaive, + (dt.INT, dt.STR), + dt.DATE_TIME_NAIVE, api.Expression.date_time_naive_from_timestamp, ), - ], + ), "datetime_from_timestamp", self._expression, unit, diff --git a/python/pathway/internals/expressions/numerical.py b/python/pathway/internals/expressions/numerical.py index 91bd8c6b..bdc9e6e3 100644 --- a/python/pathway/internals/expressions/numerical.py +++ b/python/pathway/internals/expressions/numerical.py @@ -1,10 +1,11 @@ # Copyright © 2023 Pathway import math -from typing import Optional, Union +from typing import Union import pathway.internals.expression as expr from pathway.internals import api +from pathway.internals import dtype as dt class NumericalNamespace: @@ -56,10 +57,10 @@ def abs(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (int, int, lambda x: 
api.Expression.apply(abs, x)), - (float, float, lambda x: api.Expression.apply(abs, x)), - ], + ( + (dt.INT, dt.INT, lambda x: api.Expression.apply(abs, x)), + (dt.FLOAT, dt.FLOAT, lambda x: api.Expression.apply(abs, x)), + ), "num.abs", self._expression, ) @@ -123,10 +124,18 @@ def round( """ return expr.MethodCallExpression( - [ - ((int, int), int, lambda x, y: api.Expression.apply(round, x, y)), - ((float, int), float, lambda x, y: api.Expression.apply(round, x, y)), - ], + ( + ( + (dt.INT, dt.INT), + dt.INT, + lambda x, y: api.Expression.apply(round, x, y), + ), + ( + (dt.FLOAT, dt.INT), + dt.FLOAT, + lambda x, y: api.Expression.apply(round, x, y), + ), + ), "num.round", self._expression, decimals, @@ -164,25 +173,25 @@ def fill_na(self, default_value: Union[int, float]) -> expr.ColumnExpression: # XXX Update to api.Expression.if_else when a isnan operator is supported. return expr.MethodCallExpression( - [ - (int, int, lambda x: x), + ( + (dt.INT, dt.INT, lambda x: x), ( - float, - float, + dt.FLOAT, + dt.FLOAT, lambda x: api.Expression.apply( lambda y: float(default_value) if math.isnan(y) else y, x ), ), ( - Optional[int], - int, + dt.Optional(dt.INT), + dt.INT, lambda x: api.Expression.apply( lambda y: int(default_value) if y is None else y, x ), ), ( - Optional[float], - float, + dt.Optional(dt.FLOAT), + dt.FLOAT, lambda x: api.Expression.apply( lambda y: float(default_value) if ((y is None) or math.isnan(y)) @@ -190,7 +199,7 @@ def fill_na(self, default_value: Union[int, float]) -> expr.ColumnExpression: x, ), ), - ], + ), "num.fill_na", self._expression, ) diff --git a/python/pathway/internals/expressions/string.py b/python/pathway/internals/expressions/string.py index de85c87c..31b2970e 100644 --- a/python/pathway/internals/expressions/string.py +++ b/python/pathway/internals/expressions/string.py @@ -4,6 +4,7 @@ import pathway.internals.expression as expr from pathway.internals import api +from pathway.internals import dtype as dt class StringNamespace: @@ -58,9 +59,7 @@ def lower(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (str, str, lambda x: api.Expression.apply(str.lower, x)), - ], + ((dt.STR, dt.STR, lambda x: api.Expression.apply(str.lower, x)),), "str.lower", self._expression, ) @@ -93,9 +92,7 @@ def upper(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (str, str, lambda x: api.Expression.apply(str.upper, x)), - ], + ((dt.STR, dt.STR, lambda x: api.Expression.apply(str.upper, x)),), "str.upper", self._expression, ) @@ -128,9 +125,7 @@ def reversed(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (str, str, lambda x: api.Expression.apply(lambda y: y[::-1], x)), - ], + ((dt.STR, dt.STR, lambda x: api.Expression.apply(lambda y: y[::-1], x)),), "str.reverse", self._expression, ) @@ -163,9 +158,7 @@ def len(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (str, int, lambda x: api.Expression.apply(len, x)), - ], + ((dt.STR, dt.INT, lambda x: api.Expression.apply(len, x)),), "str.len", self._expression, ) @@ -228,15 +221,15 @@ def replace( """ return expr.MethodCallExpression( - [ + ( ( - (str, str, str, int), - str, + (dt.STR, dt.STR, dt.STR, dt.INT), + dt.STR, lambda x, y, z, c: api.Expression.apply( lambda s1, s2, s3, cnt: s1.replace(s2, s3, cnt), x, y, z, c ), ), - ], + ), "str.replace", self._expression, old_value, @@ -272,13 +265,13 @@ def startswith( """ return expr.MethodCallExpression( - [ + ( ( - (str, str), - bool, + (dt.STR, dt.STR), + dt.BOOL, lambda 
x, y: api.Expression.apply(str.startswith, x, y), ), - ], + ), "str.starts_with", self._expression, prefix, @@ -312,13 +305,13 @@ """ return expr.MethodCallExpression( - [ + ( ( - (str, str), - bool, + (dt.STR, dt.STR), + dt.BOOL, lambda x, y: api.Expression.apply(str.endswith, x, y), ), - ], + ), "str.ends_with", self._expression, suffix, @@ -349,9 +342,7 @@ def swapcase(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (str, str, lambda x: api.Expression.apply(str.swapcase, x)), - ], + ((dt.STR, dt.STR, lambda x: api.Expression.apply(str.swapcase, x)),), "str.swap_case", self._expression, ) @@ -385,13 +376,13 @@ def strip( """ return expr.MethodCallExpression( - [ + ( ( - (str, Optional[str]), - str, + (dt.STR, dt.Optional(dt.STR)), + dt.STR, lambda x, y: api.Expression.apply(str.strip, x, y), ), - ], + ), "str.strip", self._expression, chars, @@ -418,9 +409,7 @@ def title(self) -> expr.ColumnExpression: """ return expr.MethodCallExpression( - [ - (str, str, lambda x: api.Expression.apply(str.title, x)), - ], + ((dt.STR, dt.STR, lambda x: api.Expression.apply(str.title, x)),), "str.title", self._expression, ) @@ -457,18 +446,18 @@ def count( """ return expr.MethodCallExpression( - [ + ( ( ( - str, - str, - Optional[int], - Optional[int], + dt.STR, + dt.STR, + dt.Optional(dt.INT), + dt.Optional(dt.INT), ), - int, + dt.INT, lambda *args: api.Expression.apply(str.count, *args), ), - ], + ), "str.count", self._expression, sub, @@ -509,18 +498,18 @@ def find( """ return expr.MethodCallExpression( - [ + ( ( ( - str, - str, - Optional[int], - Optional[int], + dt.STR, + dt.STR, + dt.Optional(dt.INT), + dt.Optional(dt.INT), ), - int, + dt.INT, lambda *args: api.Expression.apply(str.find, *args), ), - ], + ), "str.find", self._expression, sub, @@ -561,18 +550,18 @@ def rfind( """ return expr.MethodCallExpression( - [ + ( ( ( - str, - str, - Optional[int], - Optional[int], + dt.STR, + dt.STR, + dt.Optional(dt.INT), + dt.Optional(dt.INT), ), - int, + dt.INT, lambda *args: api.Expression.apply(str.rfind, *args), ), - ], + ), "str.rfind", self._expression, sub, @@ -625,13 +614,13 @@ """ return expr.MethodCallExpression( - [ + ( ( - (str, str), - str, + (dt.STR, dt.STR), + dt.STR, lambda x, y: api.Expression.apply(str.removeprefix, x, y), ), - ], + ), "str.remove_prefix", self._expression, prefix, @@ -682,13 +671,13 @@ """ return expr.MethodCallExpression( - [ + ( ( - (str, str), - str, + (dt.STR, dt.STR), + dt.STR, lambda x, y: api.Expression.apply(str.removesuffix, x, y), ), - ], + ), "str.remove_suffix", self._expression, suffix, @@ -724,10 +713,10 @@ def slice( """ return expr.MethodCallExpression( - [ + ( ( - (str, int, int), - str, + (dt.STR, dt.INT, dt.INT), + dt.STR, lambda x, y, z: api.Expression.apply( lambda s, slice_start, slice_end: s[slice_start:slice_end], x, @@ -735,7 +724,7 @@ z, ), ), - ], + ), "str.slice", self._expression, start, @@ -754,10 +743,10 @@ def parse_int(self, optional: bool = False) -> expr.ColumnExpression: >>> df = pd.DataFrame({"a": ["-5", "0", "200"]}, dtype=str) >>> table = pw.debug.table_from_pandas(df) >>> table.schema.as_dict() - {'a': <class 'str'>} + {'a': STR} >>> table = table.select(a=table.a.str.parse_int()) >>> table.schema.as_dict() - {'a': <class 'int'>} + {'a': INT} >>> pw.debug.compute_and_print(table, include_id=False) a -5 0 200 """ return expr.MethodCallExpression( - [ + ( ( - str, - Optional[int] if optional else int, + dt.STR, + dt.Optional(dt.INT) if optional else dt.INT, lambda x: api.Expression.parse_int(x, optional), - ) - ], + ), + ), "str.parse_int", self._expression, )
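The `parse_float` and `parse_bool` methods that follow repeat the pattern visible in `parse_int` above: with `optional=True`, the declared return dtype is widened to `dt.Optional(...)`. A hedged usage sketch, assuming the engine's optional parsing yields `None` for unparsable input (sample data invented):

```python
import pandas as pd
import pathway as pw

# A string column that may contain values which do not parse as integers.
df = pd.DataFrame({"a": ["-5", "0", "not-a-number"]}, dtype=str)
table = pw.debug.table_from_pandas(df)

# optional=True widens the result dtype from INT to Optional(INT),
# so a failed parse can surface as None rather than an error.
parsed = table.select(a=table.a.str.parse_int(optional=True))
pw.debug.compute_and_print(parsed, include_id=False)
```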
@@ -788,10 +777,10 @@ def parse_float(self, optional: bool = False) -> expr.ColumnExpression: >>> df = pd.DataFrame({"a": ["-5", "0.1", "200.999"]}, dtype=str) >>> table = pw.debug.table_from_pandas(df) >>> table.schema.as_dict() - {'a': <class 'str'>} + {'a': STR} >>> table = table.select(a=table.a.str.parse_float()) >>> table.schema.as_dict() - {'a': <class 'float'>} + {'a': FLOAT} >>> pw.debug.compute_and_print(table, include_id=False) a -5.0 0.1 200.999 """ return expr.MethodCallExpression( - [ + ( ( - str, - Optional[float] if optional else float, + dt.STR, + dt.Optional(dt.FLOAT) if optional else dt.FLOAT, lambda x: api.Expression.parse_float(x, optional), - ) - ], + ), + ), "str.parse_float", self._expression, ) @@ -842,7 +831,7 @@ def parse_bool( >>> df = pd.DataFrame({"a": ["0", "TRUE", "on"]}, dtype=str) >>> table = pw.debug.table_from_pandas(df) >>> table.schema.as_dict() - {'a': <class 'str'>} + {'a': STR} >>> pw.debug.compute_and_print(table, include_id=False) a 0 TRUE on >>> table = table.select(a=table.a.str.parse_bool()) >>> table.schema.as_dict() - {'a': <class 'bool'>} + {'a': BOOL} >>> pw.debug.compute_and_print(table, include_id=False) a False True True """ lowercase_false_values = [s.lower() for s in false_values] return expr.MethodCallExpression( - [ + ( ( - str, - Optional[bool] if optional else bool, + dt.STR, + dt.Optional(dt.BOOL) if optional else dt.BOOL, lambda x: api.Expression.parse_bool( x, lowercase_true_values, lowercase_false_values, optional ), - ) - ], + ), + ), "str.parse_bool", self._expression, ) diff --git a/python/pathway/internals/graph_runner/expression_evaluator.py b/python/pathway/internals/graph_runner/expression_evaluator.py index a480636d..b93b2e9f 100644 --- a/python/pathway/internals/graph_runner/expression_evaluator.py +++ b/python/pathway/internals/graph_runner/expression_evaluator.py @@ -7,7 +7,6 @@ from functools import cached_property from typing import ( TYPE_CHECKING, - Any, Callable, ClassVar, Dict, @@ -16,19 +15,12 @@ Tuple, Type, cast, - get_origin, ) import pathway.internals.column as clmn import pathway.internals.expression as expr from pathway.internals import api, asynchronous -from pathway.internals.dtype import ( - DType, - NoneType, - dtype_equivalence, - is_optional, - unoptionalize_pair, -) +from pathway.internals import dtype as dt from pathway.internals.expression_printer import get_expression_info from pathway.internals.expression_visitor import ExpressionVisitor from pathway.internals.graph_runner.scope_context import ScopeContext @@ -223,6 +215,7 @@ def eval_binary_op( if ( dtype := common_dtype_in_binary_operator(left_dtype, right_dtype) ) is not None: + dtype = dt.wrap(dtype) left_dtype = dtype right_dtype = dtype left_expression = expr.CastExpression(dtype, left_expression) @@ -242,7 +235,7 @@ def eval_binary_op( ) is not None: return result_expression left_dtype_unoptionalized, right_dtype_unoptionalized = unoptionalize_pair( - left_dtype_unoptionalized, right_dtype_unoptionalized = dt.unoptionalize_pair( left_dtype, right_dtype, ) @@ -260,8 +253,8 @@ expression_info = get_expression_info(expression) if ( - get_origin(left_dtype) == get_origin(Tuple) - and get_origin(right_dtype) == get_origin(Tuple) + isinstance(left_dtype, 
dt.Tuple) + and isinstance(right_dtype, dt.Tuple) and expression._operator in tuple_handling_operators ): warnings.warn( @@ -351,10 +344,10 @@ def eval_cast( target_type = expression._return_type if ( - dtype_equivalence(target_type, source_type) - or dtype_equivalence(DType(Optional[source_type]), target_type) - or (source_type == NoneType and is_optional(target_type)) - or target_type == Any + dt.dtype_equivalence(target_type, source_type) + or dt.dtype_equivalence(dt.Optional(source_type), target_type) + or (source_type == dt.NONE and isinstance(target_type, dt.Optional)) + or target_type == dt.ANY ): return arg # then cast is noop if ( @@ -383,8 +376,10 @@ def eval_coalesce( dtype = self.expression_type(expression) args: List[api.Expression] = [] for expr_arg in expression._args: - if not is_optional(dtype) and is_optional(expr_arg._dtype): - arg_dtype = DType(Optional[dtype]) + if not isinstance(dtype, dt.Optional) and isinstance( + expr_arg._dtype, dt.Optional + ): + arg_dtype = dt.Optional(dtype) else: arg_dtype = dtype arg_expr = expr.CastExpression(arg_dtype, expr_arg) diff --git a/python/pathway/internals/graph_runner/row_transformer_operator_handler.py b/python/pathway/internals/graph_runner/row_transformer_operator_handler.py index 753d7fdb..fd9fe06e 100644 --- a/python/pathway/internals/graph_runner/row_transformer_operator_handler.py +++ b/python/pathway/internals/graph_runner/row_transformer_operator_handler.py @@ -235,7 +235,7 @@ class RowReference: proper indexes understood by backend. """ - id: api.Pointer + id: api.BasePointer _context: api.Context _class_arg: rt.ClassArgMeta @@ -249,7 +249,7 @@ def __init__( context: api.Context, mapping: TransformerColumnIndexMapping, operator_id: int, - id: api.Pointer, + id: api.BasePointer, ): self.id = id self._context = context @@ -261,7 +261,7 @@ def __init__( def transformer(self): return TransformerReference(self) - def _with_class_arg(self, class_arg: rt.ClassArgMeta, id: api.Pointer): + def _with_class_arg(self, class_arg: rt.ClassArgMeta, id: api.BasePointer): return RowReference( class_arg, self._context, self._mapping, self._operator_id, id ) diff --git a/python/pathway/internals/operator_mapping.py b/python/pathway/internals/operator_mapping.py index 82e1156d..0358dce5 100644 --- a/python/pathway/internals/operator_mapping.py +++ b/python/pathway/internals/operator_mapping.py @@ -1,36 +1,27 @@ from typing import Any, Callable, Mapping, Optional, Tuple -import numpy as np - from pathway.internals import api -from pathway.internals.api import _TYPES_TO_ENGINE_MAPPING, Pointer -from pathway.internals.datetime_types import DateTimeNaive, DateTimeUtc, Duration -from pathway.internals.dtype import ( - DType, - NoneType, - is_optional, - types_lca, - unoptionalize, -) +from pathway.internals import dtype as dt +from pathway.internals.api import _TYPES_TO_ENGINE_MAPPING from pathway.internals.shadows import operator UnaryOperator = Callable[[Any], Any] ApiUnaryOperator = Callable[[api.Expression], api.Expression] UnaryOperatorMapping = Mapping[ - Tuple[UnaryOperator, Any], - type, + Tuple[UnaryOperator, dt.DType], + dt.DType, ] BinaryOperator = Callable[[Any, Any], Any] ApiBinaryOperator = Callable[[api.Expression, api.Expression], api.Expression] BinaryOperatorMapping = Mapping[ - Tuple[BinaryOperator, Any, Any], - type, + Tuple[BinaryOperator, dt.DType, dt.DType], + dt.DType, ] OptionalMapping = Mapping[ BinaryOperator, - Tuple[type, ApiBinaryOperator], + Tuple[dt.DType, ApiBinaryOperator], ] _unary_operators_to_engine: 
Mapping[UnaryOperator, api.UnaryOperator] = { @@ -61,10 +52,10 @@ } _unary_operators_mapping: UnaryOperatorMapping = { - (operator.inv, bool): bool, - (operator.neg, int): int, - (operator.neg, float): float, - (operator.neg, Duration): Duration, + (operator.inv, dt.BOOL): dt.BOOL, + (operator.neg, dt.INT): dt.INT, + (operator.neg, dt.FLOAT): dt.FLOAT, + (operator.neg, dt.DURATION): dt.DURATION, } @@ -74,7 +65,7 @@ def get_unary_operators_mapping(op, operand_dtype, default=None): def get_unary_expression(expr, op, expr_dtype, default=None): op_engine = _unary_operators_to_engine.get(op) - expr_dtype_engine = _TYPES_TO_ENGINE_MAPPING.get(expr_dtype) + expr_dtype_engine = _TYPES_TO_ENGINE_MAPPING.get(dt.normalize_dtype(expr_dtype)) if op_engine is None or expr_dtype_engine is None: return default expression = api.Expression.unary_expression(expr, op_engine, expr_dtype_engine) @@ -82,96 +73,96 @@ def get_unary_expression(expr, op, expr_dtype, default=None): _binary_operators_mapping: BinaryOperatorMapping = { - (operator.and_, bool, bool): bool, - (operator.or_, bool, bool): bool, - (operator.xor, bool, bool): bool, - (operator.eq, int, int): bool, - (operator.ne, int, int): bool, - (operator.lt, int, int): bool, - (operator.le, int, int): bool, - (operator.gt, int, int): bool, - (operator.ge, int, int): bool, - (operator.eq, bool, bool): bool, - (operator.ne, bool, bool): bool, - (operator.lt, bool, bool): bool, - (operator.le, bool, bool): bool, - (operator.gt, bool, bool): bool, - (operator.ge, bool, bool): bool, - (operator.add, int, int): int, - (operator.sub, int, int): int, - (operator.mul, int, int): int, - (operator.floordiv, int, int): int, - (operator.truediv, int, int): float, - (operator.mod, int, int): int, - (operator.pow, int, int): int, - (operator.lshift, int, int): int, - (operator.rshift, int, int): int, - (operator.and_, int, int): int, - (operator.or_, int, int): int, - (operator.xor, int, int): int, - (operator.eq, float, float): bool, - (operator.ne, float, float): bool, - (operator.lt, float, float): bool, - (operator.le, float, float): bool, - (operator.gt, float, float): bool, - (operator.ge, float, float): bool, - (operator.add, float, float): float, - (operator.sub, float, float): float, - (operator.mul, float, float): float, - (operator.floordiv, float, float): float, - (operator.truediv, float, float): float, - (operator.mod, float, float): float, - (operator.pow, float, float): float, - (operator.eq, str, str): bool, - (operator.ne, str, str): bool, - (operator.lt, str, str): bool, - (operator.le, str, str): bool, - (operator.gt, str, str): bool, - (operator.ge, str, str): bool, - (operator.add, str, str): str, - (operator.mul, str, int): str, - (operator.mul, int, str): str, - (operator.eq, Pointer, Pointer): bool, - (operator.ne, Pointer, Pointer): bool, - (operator.lt, Pointer, Pointer): bool, - (operator.le, Pointer, Pointer): bool, - (operator.gt, Pointer, Pointer): bool, - (operator.ge, Pointer, Pointer): bool, - (operator.eq, DateTimeNaive, DateTimeNaive): bool, - (operator.ne, DateTimeNaive, DateTimeNaive): bool, - (operator.lt, DateTimeNaive, DateTimeNaive): bool, - (operator.le, DateTimeNaive, DateTimeNaive): bool, - (operator.gt, DateTimeNaive, DateTimeNaive): bool, - (operator.ge, DateTimeNaive, DateTimeNaive): bool, - (operator.sub, DateTimeNaive, DateTimeNaive): Duration, - (operator.add, DateTimeNaive, Duration): DateTimeNaive, - (operator.sub, DateTimeNaive, Duration): DateTimeNaive, - (operator.eq, DateTimeUtc, DateTimeUtc): bool, - 
(operator.ne, DateTimeUtc, DateTimeUtc): bool, - (operator.lt, DateTimeUtc, DateTimeUtc): bool, - (operator.le, DateTimeUtc, DateTimeUtc): bool, - (operator.gt, DateTimeUtc, DateTimeUtc): bool, - (operator.ge, DateTimeUtc, DateTimeUtc): bool, - (operator.sub, DateTimeUtc, DateTimeUtc): Duration, - (operator.add, DateTimeUtc, Duration): DateTimeUtc, - (operator.sub, DateTimeUtc, Duration): DateTimeUtc, - (operator.eq, Duration, Duration): bool, - (operator.ne, Duration, Duration): bool, - (operator.lt, Duration, Duration): bool, - (operator.le, Duration, Duration): bool, - (operator.gt, Duration, Duration): bool, - (operator.ge, Duration, Duration): bool, - (operator.add, Duration, Duration): Duration, - (operator.add, Duration, DateTimeNaive): DateTimeNaive, - (operator.add, Duration, DateTimeUtc): DateTimeUtc, - (operator.sub, Duration, Duration): Duration, - (operator.mul, Duration, int): Duration, - (operator.mul, int, Duration): Duration, - (operator.floordiv, Duration, int): Duration, - (operator.floordiv, Duration, Duration): int, - (operator.truediv, Duration, Duration): float, - (operator.mod, Duration, Duration): Duration, - (operator.matmul, np.ndarray, np.ndarray): np.ndarray, + (operator.and_, dt.BOOL, dt.BOOL): dt.BOOL, + (operator.or_, dt.BOOL, dt.BOOL): dt.BOOL, + (operator.xor, dt.BOOL, dt.BOOL): dt.BOOL, + (operator.eq, dt.INT, dt.INT): dt.BOOL, + (operator.ne, dt.INT, dt.INT): dt.BOOL, + (operator.lt, dt.INT, dt.INT): dt.BOOL, + (operator.le, dt.INT, dt.INT): dt.BOOL, + (operator.gt, dt.INT, dt.INT): dt.BOOL, + (operator.ge, dt.INT, dt.INT): dt.BOOL, + (operator.eq, dt.BOOL, dt.BOOL): dt.BOOL, + (operator.ne, dt.BOOL, dt.BOOL): dt.BOOL, + (operator.lt, dt.BOOL, dt.BOOL): dt.BOOL, + (operator.le, dt.BOOL, dt.BOOL): dt.BOOL, + (operator.gt, dt.BOOL, dt.BOOL): dt.BOOL, + (operator.ge, dt.BOOL, dt.BOOL): dt.BOOL, + (operator.add, dt.INT, dt.INT): dt.INT, + (operator.sub, dt.INT, dt.INT): dt.INT, + (operator.mul, dt.INT, dt.INT): dt.INT, + (operator.floordiv, dt.INT, dt.INT): dt.INT, + (operator.truediv, dt.INT, dt.INT): dt.FLOAT, + (operator.mod, dt.INT, dt.INT): dt.INT, + (operator.pow, dt.INT, dt.INT): dt.INT, + (operator.lshift, dt.INT, dt.INT): dt.INT, + (operator.rshift, dt.INT, dt.INT): dt.INT, + (operator.and_, dt.INT, dt.INT): dt.INT, + (operator.or_, dt.INT, dt.INT): dt.INT, + (operator.xor, dt.INT, dt.INT): dt.INT, + (operator.eq, dt.FLOAT, dt.FLOAT): dt.BOOL, + (operator.ne, dt.FLOAT, dt.FLOAT): dt.BOOL, + (operator.lt, dt.FLOAT, dt.FLOAT): dt.BOOL, + (operator.le, dt.FLOAT, dt.FLOAT): dt.BOOL, + (operator.gt, dt.FLOAT, dt.FLOAT): dt.BOOL, + (operator.ge, dt.FLOAT, dt.FLOAT): dt.BOOL, + (operator.add, dt.FLOAT, dt.FLOAT): dt.FLOAT, + (operator.sub, dt.FLOAT, dt.FLOAT): dt.FLOAT, + (operator.mul, dt.FLOAT, dt.FLOAT): dt.FLOAT, + (operator.floordiv, dt.FLOAT, dt.FLOAT): dt.FLOAT, + (operator.truediv, dt.FLOAT, dt.FLOAT): dt.FLOAT, + (operator.mod, dt.FLOAT, dt.FLOAT): dt.FLOAT, + (operator.pow, dt.FLOAT, dt.FLOAT): dt.FLOAT, + (operator.eq, dt.STR, dt.STR): dt.BOOL, + (operator.ne, dt.STR, dt.STR): dt.BOOL, + (operator.lt, dt.STR, dt.STR): dt.BOOL, + (operator.le, dt.STR, dt.STR): dt.BOOL, + (operator.gt, dt.STR, dt.STR): dt.BOOL, + (operator.ge, dt.STR, dt.STR): dt.BOOL, + (operator.add, dt.STR, dt.STR): dt.STR, + (operator.mul, dt.STR, dt.INT): dt.STR, + (operator.mul, dt.INT, dt.STR): dt.STR, + (operator.eq, dt.POINTER, dt.POINTER): dt.BOOL, + (operator.ne, dt.POINTER, dt.POINTER): dt.BOOL, + (operator.lt, dt.POINTER, dt.POINTER): dt.BOOL, + (operator.le, 
dt.POINTER, dt.POINTER): dt.BOOL, + (operator.gt, dt.POINTER, dt.POINTER): dt.BOOL, + (operator.ge, dt.POINTER, dt.POINTER): dt.BOOL, + (operator.eq, dt.DATE_TIME_NAIVE, dt.DATE_TIME_NAIVE): dt.BOOL, + (operator.ne, dt.DATE_TIME_NAIVE, dt.DATE_TIME_NAIVE): dt.BOOL, + (operator.lt, dt.DATE_TIME_NAIVE, dt.DATE_TIME_NAIVE): dt.BOOL, + (operator.le, dt.DATE_TIME_NAIVE, dt.DATE_TIME_NAIVE): dt.BOOL, + (operator.gt, dt.DATE_TIME_NAIVE, dt.DATE_TIME_NAIVE): dt.BOOL, + (operator.ge, dt.DATE_TIME_NAIVE, dt.DATE_TIME_NAIVE): dt.BOOL, + (operator.sub, dt.DATE_TIME_NAIVE, dt.DATE_TIME_NAIVE): dt.DURATION, + (operator.add, dt.DATE_TIME_NAIVE, dt.DURATION): dt.DATE_TIME_NAIVE, + (operator.sub, dt.DATE_TIME_NAIVE, dt.DURATION): dt.DATE_TIME_NAIVE, + (operator.eq, dt.DATE_TIME_UTC, dt.DATE_TIME_UTC): dt.BOOL, + (operator.ne, dt.DATE_TIME_UTC, dt.DATE_TIME_UTC): dt.BOOL, + (operator.lt, dt.DATE_TIME_UTC, dt.DATE_TIME_UTC): dt.BOOL, + (operator.le, dt.DATE_TIME_UTC, dt.DATE_TIME_UTC): dt.BOOL, + (operator.gt, dt.DATE_TIME_UTC, dt.DATE_TIME_UTC): dt.BOOL, + (operator.ge, dt.DATE_TIME_UTC, dt.DATE_TIME_UTC): dt.BOOL, + (operator.sub, dt.DATE_TIME_UTC, dt.DATE_TIME_UTC): dt.DURATION, + (operator.add, dt.DATE_TIME_UTC, dt.DURATION): dt.DATE_TIME_UTC, + (operator.sub, dt.DATE_TIME_UTC, dt.DURATION): dt.DATE_TIME_UTC, + (operator.eq, dt.DURATION, dt.DURATION): dt.BOOL, + (operator.ne, dt.DURATION, dt.DURATION): dt.BOOL, + (operator.lt, dt.DURATION, dt.DURATION): dt.BOOL, + (operator.le, dt.DURATION, dt.DURATION): dt.BOOL, + (operator.gt, dt.DURATION, dt.DURATION): dt.BOOL, + (operator.ge, dt.DURATION, dt.DURATION): dt.BOOL, + (operator.add, dt.DURATION, dt.DURATION): dt.DURATION, + (operator.add, dt.DURATION, dt.DATE_TIME_NAIVE): dt.DATE_TIME_NAIVE, + (operator.add, dt.DURATION, dt.DATE_TIME_UTC): dt.DATE_TIME_UTC, + (operator.sub, dt.DURATION, dt.DURATION): dt.DURATION, + (operator.mul, dt.DURATION, dt.INT): dt.DURATION, + (operator.mul, dt.INT, dt.DURATION): dt.DURATION, + (operator.floordiv, dt.DURATION, dt.INT): dt.DURATION, + (operator.floordiv, dt.DURATION, dt.DURATION): dt.INT, + (operator.truediv, dt.DURATION, dt.DURATION): dt.FLOAT, + (operator.mod, dt.DURATION, dt.DURATION): dt.DURATION, + (operator.matmul, dt.Array(), dt.Array()): dt.Array(), } tuple_handling_operators = { @@ -184,14 +175,16 @@ def get_unary_expression(expr, op, expr_dtype, default=None): } -def get_binary_operators_mapping(op, left, right, default=None) -> DType: - return DType(_binary_operators_mapping.get((op, left, right), default)) +def get_binary_operators_mapping(op, left, right, default=None): + return _binary_operators_mapping.get( + (op, dt.normalize_dtype(left), dt.normalize_dtype(right)), default + ) def get_binary_expression(left, right, op, left_dtype, right_dtype, default=None): op_engine = _binary_operators_to_engine.get(op) - left_dtype_engine = _TYPES_TO_ENGINE_MAPPING.get(left_dtype) - right_dtype_engine = _TYPES_TO_ENGINE_MAPPING.get(right_dtype) + left_dtype_engine = _TYPES_TO_ENGINE_MAPPING.get(dt.normalize_dtype(left_dtype)) + right_dtype_engine = _TYPES_TO_ENGINE_MAPPING.get(dt.normalize_dtype(right_dtype)) if op_engine is None or left_dtype_engine is None or right_dtype_engine is None: return default @@ -202,26 +195,30 @@ def get_binary_expression(left, right, op, left_dtype, right_dtype, default=None _binary_operators_mapping_optionals: OptionalMapping = { - operator.eq: (bool, api.Expression.eq), - operator.ne: (bool, api.Expression.ne), + operator.eq: (dt.BOOL, api.Expression.eq), + operator.ne: (dt.BOOL, 
api.Expression.ne), } def get_binary_operators_mapping_optionals(op, left, right, default=None): - if left == right or left == NoneType or right == NoneType: + if left == right or left == dt.NONE or right == dt.NONE: return _binary_operators_mapping_optionals.get(op, default) else: return default def get_cast_operators_mapping( - expr: api.Expression, source_type: DType, target_type: DType, default=None + expr: api.Expression, source_type: dt.DType, target_type: dt.DType, default=None ) -> Optional[api.Expression]: - source_type_engine = _TYPES_TO_ENGINE_MAPPING.get(unoptionalize(source_type)) - target_type_engine = _TYPES_TO_ENGINE_MAPPING.get(unoptionalize(target_type)) + source_type_engine = _TYPES_TO_ENGINE_MAPPING.get( + dt.normalize_dtype(dt.unoptionalize(source_type)) + ) + target_type_engine = _TYPES_TO_ENGINE_MAPPING.get( + dt.normalize_dtype(dt.unoptionalize(target_type)) + ) if source_type_engine is None or target_type_engine is None: return default - if is_optional(source_type) and is_optional(target_type): + if isinstance(source_type, dt.Optional) and isinstance(target_type, dt.Optional): fun = api.Expression.cast_optional else: fun = api.Expression.cast @@ -234,12 +231,14 @@ def get_cast_operators_mapping( def common_dtype_in_binary_operator( - left_dtype: DType, right_dtype: DType -) -> Optional[DType]: + left_dtype: dt.DType, right_dtype: dt.DType +) -> Optional[dt.DType]: if ( - left_dtype in [int, Optional[int]] and right_dtype in [float, Optional[float]] + left_dtype in [dt.INT, dt.Optional(dt.INT)] + and right_dtype in [dt.FLOAT, dt.Optional(dt.FLOAT)] ) or ( - left_dtype in [float, Optional[float]] and right_dtype in [int, Optional[int]] + left_dtype in [dt.FLOAT, dt.Optional(dt.FLOAT)] + and right_dtype in [dt.INT, dt.Optional(dt.INT)] ): - return types_lca(left_dtype, right_dtype) + return dt.types_lca(left_dtype, right_dtype) return None diff --git a/python/pathway/internals/reducers.py b/python/pathway/internals/reducers.py index 85e74b03..7d3a8019 100644 --- a/python/pathway/internals/reducers.py +++ b/python/pathway/internals/reducers.py @@ -3,15 +3,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List from warnings import warn import numpy as np from pathway.internals import api +from pathway.internals import dtype as dt from pathway.internals import expression as expr from pathway.internals.common import apply_with_type -from pathway.internals.dtype import DType, dtype_issubclass class UnaryReducer(ABC): @@ -24,11 +23,11 @@ def __init__(self, *, name: str): self.name = name @abstractmethod - def return_type(self, arg_type: DType) -> DType: + def return_type(self, arg_type: dt.DType) -> dt.DType: ... @abstractmethod - def engine_reducer(self, arg_type: DType) -> api.Reducer: + def engine_reducer(self, arg_type: dt.DType) -> api.Reducer: ... 
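For orientation in the reducer hierarchy below: each `UnaryReducer` pairs a result-dtype rule with a choice of engine reducer, both keyed on the argument dtype. A simplified standalone sketch of that contract, with string tokens and a toy enum standing in for `dt.DType` and `api.Reducer` (not the actual Pathway classes):

```python
from abc import ABC, abstractmethod
from enum import Enum, auto

class EngineReducer(Enum):  # stand-in for api.Reducer
    INT_SUM = auto()
    FLOAT_SUM = auto()
    ARRAY_SUM = auto()

class UnaryReducerSketch(ABC):
    @abstractmethod
    def return_type(self, arg_type: str) -> str: ...

    @abstractmethod
    def engine_reducer(self, arg_type: str) -> EngineReducer: ...

class SumSketch(UnaryReducerSketch):
    # Sum keeps the column's dtype but picks a dtype-specific engine implementation.
    def return_type(self, arg_type: str) -> str:
        return arg_type

    def engine_reducer(self, arg_type: str) -> EngineReducer:
        if arg_type == "INT":
            return EngineReducer.INT_SUM
        if arg_type == "ARRAY":
            return EngineReducer.ARRAY_SUM
        return EngineReducer.FLOAT_SUM

assert SumSketch().engine_reducer("INT") is EngineReducer.INT_SUM
```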
@@ -39,48 +38,50 @@ def __init__(self, *, name: str, engine_reducer: api.Reducer): super().__init__(name=name) self._engine_reducer = engine_reducer - def engine_reducer(self, arg_type: DType) -> api.Reducer: + def engine_reducer(self, arg_type: dt.DType) -> api.Reducer: return self._engine_reducer class FixedOutputUnaryReducer(UnaryReducerWithDefault): - output_type: DType + output_type: dt.DType - def __init__(self, *, output_type: DType, name: str, engine_reducer: api.Reducer): + def __init__( + self, *, output_type: dt.DType, name: str, engine_reducer: api.Reducer + ): super().__init__(name=name, engine_reducer=engine_reducer) self.output_type = output_type - def return_type(self, arg_type: DType) -> DType: + def return_type(self, arg_type: dt.DType) -> dt.DType: return self.output_type class TypePreservingUnaryReducer(UnaryReducerWithDefault): - def return_type(self, arg_type: DType) -> DType: + def return_type(self, arg_type: dt.DType) -> dt.DType: return arg_type class SumReducer(UnaryReducer): - def return_type(self, arg_type: DType) -> DType: - for allowed_dtype in [DType(float), DType(np.ndarray)]: - if dtype_issubclass(arg_type, allowed_dtype): + def return_type(self, arg_type: dt.DType) -> dt.DType: + for allowed_dtype in [dt.FLOAT, dt.Array()]: + if dt.dtype_issubclass(arg_type, allowed_dtype): return arg_type raise TypeError( f"Pathway does not support using reducer {self}" + f" on column of type {arg_type}.\n" ) - def engine_reducer(self, arg_type: DType) -> api.Reducer: - if arg_type == DType(int): + def engine_reducer(self, arg_type: dt.DType) -> api.Reducer: + if arg_type == dt.INT: return api.Reducer.INT_SUM - elif arg_type == DType(np.ndarray): + elif arg_type == dt.Array(): return api.Reducer.ARRAY_SUM else: return api.Reducer.FLOAT_SUM class TupleWrappingReducer(UnaryReducerWithDefault): - def return_type(self, arg_type: DType) -> DType: - return DType(List[arg_type]) # type: ignore[valid-type] + def return_type(self, arg_type: dt.DType) -> dt.DType: + return dt.List(arg_type) # type: ignore[valid-type] _min = TypePreservingUnaryReducer(name="min", engine_reducer=api.Reducer.MIN) @@ -91,10 +92,14 @@ def return_type(self, arg_type: DType) -> DType: ) _tuple = TupleWrappingReducer(name="tuple", engine_reducer=api.Reducer.TUPLE) _argmin = FixedOutputUnaryReducer( - output_type=DType(api.Pointer), name="argmin", engine_reducer=api.Reducer.ARG_MIN + output_type=dt.POINTER, + name="argmin", + engine_reducer=api.Reducer.ARG_MIN, ) _argmax = FixedOutputUnaryReducer( - output_type=DType(api.Pointer), name="argmax", engine_reducer=api.Reducer.ARG_MAX + output_type=dt.POINTER, + name="argmax", + engine_reducer=api.Reducer.ARG_MAX, ) _unique = TypePreservingUnaryReducer(name="unique", engine_reducer=api.Reducer.UNIQUE) _any = TypePreservingUnaryReducer(name="any", engine_reducer=api.Reducer.ANY) diff --git a/python/pathway/internals/row_transformer.py b/python/pathway/internals/row_transformer.py index 87593668..c4fe7d0c 100644 --- a/python/pathway/internals/row_transformer.py +++ b/python/pathway/internals/row_transformer.py @@ -4,21 +4,20 @@ from abc import ABC, abstractmethod from functools import cached_property, lru_cache -from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type import pathway import pathway.internals.row_transformer_table as tt +from pathway.internals import dtype as dt from pathway.internals import operator as op from pathway.internals import parse_graph, schema -from 
pathway.internals.api import BasePointer, ref_scalar + from pathway.internals.api import BasePointer, Pointer, ref_scalar from pathway.internals.column import MaterializedColumn, MethodColumn -from pathway.internals.dtype import DType from pathway.internals.schema import Schema, schema_from_types from pathway.internals.shadows import inspect if TYPE_CHECKING: from pathway import Table - from pathway.internals.api import Pointer from pathway.internals.graph_runner.row_transformer_operator_handler import ( RowReference, ) @@ -77,7 +76,7 @@ def _match_args_to_properties( @classmethod @lru_cache - def method_call_transformer(cls, n, dtype: DType): + def method_call_transformer(cls, n, dtype: dt.DType): arg_names = [f"arg_{i}" for i in range(n)] def func(self): @@ -204,6 +203,7 @@ class AbstractAttribute(ABC): is_method = False is_output = False + _dtype: Optional[dt.DType] = None class_arg: ClassArgMeta # lateinit by parent ClassArg def __init__(self, **params) -> None: @@ -211,7 +211,7 @@ def __init__(self, **params) -> None: self.params = params self.name = self.params.get("name", None) if "dtype" in self.params: - self._dtype = self.params["dtype"] + self._dtype = dt.wrap(self.params["dtype"]) def __set_name__(self, owner, name): assert self.name is None @@ -224,8 +224,8 @@ def to_transformer_column( pass @property - def dtype(self) -> DType: - return DType(Any) + def dtype(self) -> dt.DType: + return dt.ANY class ToBeComputedAttribute(AbstractAttribute, ABC): @@ -244,13 +244,13 @@ def output_name(self): return self.params.get("output_name", self.name) @cached_property - def dtype(self) -> DType: - if hasattr(self, "_dtype"): + def dtype(self) -> dt.DType: + if self._dtype is not None: return self._dtype else: assert self.func is not None sig = inspect.signature(self.func) - return sig.return_annotation + return dt.wrap(sig.return_annotation) class AbstractOutputAttribute(ToBeComputedAttribute, ABC): @@ -289,8 +289,8 @@ def to_output_column(self, universe: Universe): return MethodColumn(self.dtype, universe) @cached_property - def dtype(self) -> DType: - return DType(Callable[..., super().dtype]) + def dtype(self) -> dt.DType: + return dt.wrap(Callable[..., super().dtype]) class InputAttribute(AbstractAttribute):
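Before the schema.py changes below, one recurring helper is worth a gloss: `dt.wrap(...)` converts a plain Python annotation such as `int` or `Optional[str]` into the new DType representation, so downstream code handles a single currency. A toy version conveying the idea; the mapping and names are assumptions for illustration, not the real `pathway.internals.dtype` implementation:

```python
import typing
from typing import Any, Union, get_args, get_origin

# Illustrative tokens; the real module exposes richer DType objects.
_SIMPLE = {int: "INT", float: "FLOAT", str: "STR", bool: "BOOL", Any: "ANY"}

def wrap(annotation) -> str:
    """Normalize a Python type annotation into a dtype token (sketch of dt.wrap)."""
    if annotation in _SIMPLE:
        return _SIMPLE[annotation]
    if get_origin(annotation) is Union:
        non_none = [a for a in get_args(annotation) if a is not type(None)]
        if len(non_none) == 1:  # Optional[T] is Union[T, None]
            return f"Optional({wrap(non_none[0])})"
    return "ANY"

assert wrap(int) == "INT"
assert wrap(typing.Optional[str]) == "Optional(STR)"
```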
diff --git a/python/pathway/internals/schema.py b/python/pathway/internals/schema.py index 97cf4be0..f38a1137 100644 --- a/python/pathway/internals/schema.py +++ b/python/pathway/internals/schema.py @@ -22,9 +22,8 @@ import numpy as np import pandas as pd +from pathway.internals import dtype as dt from pathway.internals import trace -from pathway.internals.datetime_types import DateTimeNaive, DateTimeUtc, Duration -from pathway.internals.dtype import DType, NoneType, dtype_issubclass from pathway.internals.runtime_type_check import runtime_type_check if TYPE_CHECKING: @@ -48,32 +47,34 @@ def schema_from_columns( return _schema_builder(_name, __dict) -def _type_converter(series): +def _type_converter(series) -> dt.DType: + if series.apply(lambda x: isinstance(x, (tuple, list))).all(): + return dt.ANY_TUPLE if (series.isna() | series.isnull()).all(): - return NoneType + return dt.NONE if (series.apply(lambda x: isinstance(x, np.ndarray))).all(): - return np.ndarray + return dt.Array() if pd.api.types.is_integer_dtype(series.dtype): - ret_type: Type = int + ret_type: dt.DType = dt.INT elif pd.api.types.is_float_dtype(series.dtype): - ret_type = float + ret_type = dt.FLOAT elif pd.api.types.is_bool_dtype(series.dtype): - ret_type = bool + ret_type = dt.BOOL elif pd.api.types.is_string_dtype(series.dtype): - ret_type = str + ret_type = dt.STR elif pd.api.types.is_datetime64_ns_dtype(series.dtype): if series.dt.tz is None: - ret_type = DateTimeNaive + ret_type = dt.DATE_TIME_NAIVE else: - ret_type = DateTimeUtc + ret_type = dt.DATE_TIME_UTC elif pd.api.types.is_timedelta64_dtype(series.dtype): - ret_type = Duration + ret_type = dt.DURATION elif pd.api.types.is_object_dtype(series.dtype): - ret_type = Any # type: ignore + ret_type = dt.ANY # type: ignore else: - ret_type = Any # type: ignore + ret_type = dt.ANY # type: ignore if series.isna().any() or series.isnull().any(): - return Optional[ret_type] + return dt.Optional(ret_type) else: return ret_type @@ -112,7 +113,7 @@ def schema_from_types( >>> import pathway as pw >>> s = pw.schema_from_types(foo=int, bar=str) >>> s.as_dict() - {'foo': <class 'int'>, 'bar': <class 'str'>} + {'foo': INT, 'bar': STR} >>> issubclass(s, pw.Schema) True """ @@ -157,8 +158,8 @@ def _create_column_definitions(schema: SchemaMetaclass): columns = {} for name, annotation in annotations.items(): - dtype = DType(annotation) - column = fields.pop(name, column_definition(dtype=dtype)) + coldtype = dt.wrap(annotation) + column = fields.pop(name, column_definition(dtype=coldtype)) if not isinstance(column, ColumnDefinition): raise ValueError( @@ -168,8 +169,9 @@ name = column.name or name if column.dtype is None: - column = dataclasses.replace(column, dtype=dtype) - elif dtype != column.dtype: + column = dataclasses.replace(column, dtype=coldtype) + column = dataclasses.replace(column, dtype=dt.wrap(column.dtype)) + if coldtype != column.dtype: raise TypeError( f"type annotation of column `{name}` does not match column definition" ) @@ -191,7 +193,7 @@ class SchemaProperties: class SchemaMetaclass(type): __columns__: Dict[str, ColumnDefinition] __properties__: SchemaProperties - __types__: Dict[str, DType] + __types__: Dict[str, dt.DType] @trace.trace_user_frame def __init__(self, *args, append_only=False, **kwargs) -> None: @@ -200,7 +202,8 @@ def __init__(self, *args, append_only=False, **kwargs) -> None: self.__columns__ = _create_column_definitions(self) self.__properties__ = SchemaProperties(append_only=append_only) self.__types__ = { - name: cast(DType, column.dtype) for name, column in self.__columns__.items() + name: cast(dt.DType, column.dtype) + for name, column in self.__columns__.items() } def __or__(self, other: Type[Schema]) -> Type[Schema]: # type: ignore @@ -243,6 +246,17 @@ def keys(self) -> KeysView[str]: def values(self) -> ValuesView[Any]: return self.__types__.values() + def update_types(self, **kwargs) -> Type[Schema]: + columns: Dict[str, ColumnDefinition] = dict(self.__columns__) + for name, dtype in kwargs.items(): + if name not in columns: + raise ValueError( + "Schema.update_types() argument name has to be an existing column name." + ) + columns[name] = dataclasses.replace(columns[name], dtype=dt.wrap(dtype)) + + return schema_builder(columns=columns, properties=self.__properties__) + 
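Since `update_types` is new in this release, a brief usage sketch may help; the schema and column names here are invented, and, as the code above shows, unknown column names raise `ValueError`:

```python
import pathway as pw

class InputSchema(pw.Schema):
    key: int
    raw_value: str

# Returns a rebuilt schema with raw_value's dtype replaced; key is kept as-is.
PatchedSchema = InputSchema.update_types(raw_value=float)
```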
def __getitem__(self, name): return self.__types__[name] @@ -299,14 +313,14 @@ class Schema(metaclass=SchemaMetaclass): ... 3 8 Alice cat ... 4 7 Bob dog''') >>> t1.schema.as_dict() - {'age': <class 'int'>, 'owner': <class 'str'>, 'pet': <class 'str'>} + {'age': INT, 'owner': STR, 'pet': STR} >>> issubclass(t1.schema, pw.Schema) True >>> class NewSchema(pw.Schema): ... foo: int >>> SchemaSum = NewSchema | t1.schema >>> SchemaSum.as_dict() - {'age': <class 'int'>, 'owner': <class 'str'>, 'pet': <class 'str'>, 'foo': <class 'int'>} + {'age': INT, 'owner': STR, 'pet': STR, 'foo': INT} """ def __init_subclass__(cls, /, append_only: bool = False, **kwargs) -> None: @@ -326,7 +340,7 @@ def is_subschema(left: Type[Schema], right: Type[Schema]): if left.keys() != right.keys(): return False for k in left.keys(): - if not dtype_issubclass(left[k], right[k]): + if not dt.dtype_issubclass(left[k], right[k]): return False return True @@ -343,12 +357,15 @@ def __repr__(self): class ColumnDefinition: primary_key: bool = False default_value: Optional[Any] = _no_default_value_marker - dtype: Optional[DType] = DType(Any) + dtype: Optional[dt.DType] = dt.ANY name: Optional[str] = None def has_default_value(self) -> bool: return self.default_value != _no_default_value_marker + def __post_init__(self): + assert self.dtype is None or isinstance(self.dtype, dt.DType) + def column_definition( *, @@ -380,11 +397,15 @@ ... timestamp: str = pw.column_definition(name="@timestamp") ... data: str >>> NewSchema.as_dict() - {'key': <class 'int'>, '@timestamp': <class 'str'>, 'data': <class 'str'>} + {'key': INT, '@timestamp': STR, 'data': STR} """ + from pathway.internals import dtype as dt return ColumnDefinition( - dtype=dtype, primary_key=primary_key, default_value=default_value, name=name + dtype=dt.wrap(dtype) if dtype is not None else None, + primary_key=primary_key, + default_value=default_value, + name=name, ) @@ -412,7 +433,7 @@ def schema_builder( ... 'data': pw.column_definition(dtype=int, default_value=0) ... }, name="my_schema") >>> schema.as_dict() - {'key': <class 'int'>, 'data': <class 'int'>} + {'key': INT, 'data': INT} """ if name is None: diff --git a/python/pathway/internals/table.py b/python/pathway/internals/table.py index 8cd4ce8d..bef07109 100644 --- a/python/pathway/internals/table.py +++ b/python/pathway/internals/table.py @@ -22,8 +22,8 @@ import pathway.internals.column as clmn import pathway.internals.expression as expr +from pathway.internals import dtype as dt from pathway.internals import groupby, thisclass, universes -from pathway.internals.api import Pointer from pathway.internals.arg_handlers import ( arg_handler, groupby_handler, @@ -37,7 +36,6 @@ table_to_datasink, ) from pathway.internals.desugaring import combine_args_kwargs, desugar -from pathway.internals.dtype import DType, is_pointer, types_lca, unoptionalize from pathway.internals.helpers import SetOnceProperty, StableSet from pathway.internals.ix import IxIndexer from pathway.internals.join import Joinable @@ -186,9 +185,9 @@ def schema(self) -> Type[Schema]: ... 7 | Bob | dog ... ''') >>> t1.schema.as_dict() - {'age': <class 'int'>, 'owner': <class 'str'>, 'pet': <class 'str'>} + {'age': INT, 'owner': STR, 'pet': STR} >>> t1.schema['age'] - <class 'int'> + INT """ return self._schema @@ -348,7 +347,7 @@ def concat_reindex(self, *tables: Table) -> Table: @trace_user_frame @staticmethod @runtime_type_check - def empty(**kwargs: DType) -> Table: + def empty(**kwargs: dt.DType) -> Table: """Creates an empty table with a schema specified by kwargs. 
Args: @@ -506,7 +505,7 @@ def filter(self, filter_expression: expr.ColumnExpression) -> Table: 7 | 0 """ filter_type = eval_type(filter_expression) - if filter_type != DType(bool): + if filter_type != dt.BOOL: raise TypeError( f"Filter argument of Table.filter() has to be bool, found {filter_type}." ) @@ -516,7 +515,7 @@ def filter(self, filter_expression: expr.ColumnExpression) -> Table: ) is not None and filter_col.table == self: name = filter_col.name dtype = self._columns[name].dtype - ret = ret.update_types(**{name: unoptionalize(dtype)}) + ret = ret.update_types(**{name: dt.unoptionalize(dtype)}) return ret @contextualized_operator @@ -911,7 +910,7 @@ def concat(self, *others: Table) -> Table: schema = { key: functools.reduce( - types_lca, [other.schema[key] for other in others], self.schema[key] + dt.types_lca, [other.schema[key] for other in others], self.schema[key] ) for key in self.keys() } @@ -999,7 +998,8 @@ def update_cells(self, other: Table) -> Table: ) schema = { - key: types_lca(self.schema[key], other.schema[key]) for key in other.keys() + key: dt.types_lca(self.schema[key], other.schema[key]) + for key in other.keys() } return Table._update_cells( self.cast_to_types(**schema), other.cast_to_types(**schema) @@ -1079,9 +1079,9 @@ def update_rows(self, other: Table) -> Table: raise ValueError( "Columns do not match between argument of Table.update_rows() and the updated table." ) - schema = { - key: types_lca(self.schema[key], other.schema[key]) for key in self.keys() + key: dt.types_lca(self.schema[key], other.schema[key]) + for key in self.keys() } return Table._update_rows( self.cast_to_types(**schema), other.cast_to_types(**schema) @@ -1242,7 +1242,7 @@ def _with_new_index( ) -> Table: self._validate_expression(new_index) index_type = eval_type(new_index) - if not is_pointer(index_type): + if not isinstance(index_type, dt.Pointer): raise TypeError( f"Pathway supports reindexing Tables with Pointer type only. The type used was {index_type}." ) @@ -1637,8 +1637,8 @@ def _sort_experimental( ) -> Table: if not isinstance(instance, expr.ColumnExpression): instance = expr.ColumnConstExpression(instance) - prev_column = clmn.MaterializedColumn(DType(Optional[Pointer]), self._universe) - next_column = clmn.MaterializedColumn(DType(Optional[Pointer]), self._universe) + prev_column = clmn.MaterializedColumn(dt.Optional(dt.POINTER), self._universe) + next_column = clmn.MaterializedColumn(dt.Optional(dt.POINTER), self._universe) context = clmn.SortingContext( self._universe, self._eval(key), @@ -1891,8 +1891,8 @@ def dtypes(self): ... 7 | Bob | dog ... 
''') >>> t1.dtypes() - {'age': <class 'int'>, 'owner': <class 'str'>, 'pet': <class 'str'>} + {'age': INT, 'owner': STR, 'pet': STR} >>> t1.schema['age'] - <class 'int'> + INT """ return self.schema.as_dict() diff --git a/python/pathway/internals/type_interpreter.py b/python/pathway/internals/type_interpreter.py index e9760e8e..6f44e53b 100644 --- a/python/pathway/internals/type_interpreter.py +++ b/python/pathway/internals/type_interpreter.py @@ -5,24 +5,11 @@ import datetime import warnings from dataclasses import dataclass -from typing import Any, Iterable, Optional, Tuple, TypeVar, get_args, get_origin - -import numpy as np +from types import EllipsisType +from typing import Any, Iterable, Optional, TypeVar +from pathway.internals import dtype as dt from pathway.internals import expression as expr -from pathway.internals.api import Pointer -from pathway.internals.datetime_types import DateTimeNaive, DateTimeUtc, Duration -from pathway.internals.dtype import ( - DType, - dtype_issubclass, - is_optional, - is_pointer, - is_tuple_like, - normalize_tuple_like, - types_lca, - unoptionalize, - unoptionalize_pair, -) from pathway.internals.expression_printer import get_expression_info from pathway.internals.expression_visitor import IdentityTransform from pathway.internals.operator_mapping import ( @@ -70,11 +57,12 @@ def _eval_column_val( self, expression: expr.ColumnReference, state: Optional[TypeInterpreterState] = None, - ) -> DType: + ) -> dt.DType: dtype = expression._column.dtype assert state is not None + assert isinstance(dtype, dt.DType) if state.check_colref_to_unoptionalize_from_colrefs(expression): - return unoptionalize(dtype) + return dt.unoptionalize(dtype) return dtype def eval_unary_op( @@ -116,11 +104,11 @@ def eval_binary_op( def _eval_binary_op( self, - left_dtype: DType, - right_dtype: DType, + left_dtype: dt.DType, + right_dtype: dt.DType, operator: Any, expression: expr.ColumnExpression, - ) -> DType: + ) -> dt.DType: original_left = left_dtype original_right = right_dtype if ( @@ -132,31 +120,37 @@ def _eval_binary_op( if ( dtype := get_binary_operators_mapping(operator, left_dtype, right_dtype) ) is not None: - return DType(dtype) + return dtype - left_dtype, right_dtype = unoptionalize_pair(left_dtype, right_dtype) + left_dtype, right_dtype = dt.unoptionalize_pair(left_dtype, right_dtype) if ( dtype_and_handler := get_binary_operators_mapping_optionals( operator, left_dtype, right_dtype ) ) is not None: - return DType(dtype_and_handler[0]) + return dtype_and_handler[0] if ( - get_origin(left_dtype) == get_origin(Tuple) - and get_origin(right_dtype) == get_origin(Tuple) + isinstance(left_dtype, dt.Tuple) + and isinstance(right_dtype, dt.Tuple) and operator in tuple_handling_operators ): - left_args = get_args(left_dtype) - right_args = get_args(right_dtype) - if len(left_args) == len(right_args): + left_args = left_dtype.args + right_args = right_dtype.args + if ( + isinstance(left_args, tuple) + and isinstance(right_args, tuple) + and len(left_args) == len(right_args) + ): results = tuple( self._eval_binary_op(left_arg, right_arg, operator, expression) for left_arg, right_arg in zip(left_args, right_args) + if not isinstance(left_arg, EllipsisType) + and not isinstance(right_arg, EllipsisType) ) assert all(result == results[0] for result in results) - return DType(results[0]) + return dt.wrap(results[0]) expression_info = get_expression_info(expression) raise TypeError( f"Pathway does not support using binary operator {operator.__name__}" @@ -175,12 +169,14 @@ def eval_const( type_ = type(expression._val) if 
isinstance(expression._val, datetime.datetime): if expression._val.tzinfo is None: - type_ = DateTimeNaive + dtype: dt.DType = dt.DATE_TIME_NAIVE else: - type_ = DateTimeUtc + dtype = dt.DATE_TIME_UTC elif isinstance(expression._val, datetime.timedelta): - type_ = Duration - return _wrap(expression, DType(type_)) + dtype = dt.DURATION + else: + dtype = dt.wrap(type_) + return _wrap(expression, dtype) def eval_reducer( self, @@ -194,7 +190,7 @@ def eval_reducer( def _eval_reducer( self, expression: expr.ReducerExpression, - ) -> DType: + ) -> dt.DType: [arg] = expression._args return expression._reducer.return_type(arg._dtype) @@ -214,7 +210,7 @@ def eval_count( **kwargs, ) -> expr.CountExpression: expression = super().eval_count(expression, state=state, **kwargs) - return _wrap(expression, DType(int)) + return _wrap(expression, dt.INT) def eval_apply( self, @@ -249,13 +245,16 @@ def eval_call( state: Optional[TypeInterpreterState] = None, **kwargs, ) -> expr.ColumnCallExpression: + dtype = expression._col_expr._column.dtype + assert isinstance(dtype, dt.Callable) expression = super().eval_call(expression, state=state, **kwargs) - arg_annots, ret_type = get_args(expression._col_expr._column.dtype) - if arg_annots is not Ellipsis: + arg_annots, ret_type = dtype.arg_types, dtype.return_type + if not isinstance(arg_annots, EllipsisType): arg_annots = arg_annots[1:] # ignoring self for arg_annot, arg in zip(arg_annots, expression._args, strict=True): # type: ignore arg_dtype = arg._dtype - assert dtype_issubclass(arg_dtype, arg_annot) + assert not isinstance(arg_annot, EllipsisType) + assert dt.dtype_issubclass(arg_dtype, arg_annot) return _wrap(expression, ret_type) def eval_ix( @@ -267,15 +266,15 @@ def eval_ix( expression = super().eval_ix(expression, state=state, **kwargs) key_dtype = expression._keys_expression._dtype if expression._optional: - if not dtype_issubclass(key_dtype, DType(Optional[Pointer])): + if not dt.dtype_issubclass(key_dtype, dt.Optional(dt.POINTER)): raise TypeError( f"Pathway supports indexing with Pointer type only. The type used was {key_dtype}." ) - if not is_optional(key_dtype): + if not isinstance(key_dtype, dt.Optional): return _wrap(expression, expression._column.dtype) - return _wrap(expression, DType(Optional[expression._column.dtype])) + return _wrap(expression, dt.Optional(expression._column.dtype)) else: - if not is_pointer(key_dtype): + if not isinstance(key_dtype, dt.Pointer): raise TypeError( f"Pathway supports indexing with Pointer type only. The type used was {key_dtype}." ) @@ -289,10 +288,12 @@ def eval_pointer( ) -> expr.PointerExpression: expression = super().eval_pointer(expression, state=state, **kwargs) arg_types = [arg._dtype for arg in expression._args] - if expression._optional and any(is_optional(arg) for arg in arg_types): - return _wrap(expression, DType(Optional[Pointer])) + if expression._optional and any( + isinstance(arg, dt.Optional) for arg in arg_types + ): + return _wrap(expression, dt.Optional(dt.POINTER)) else: - return _wrap(expression, DType(Pointer)) + return _wrap(expression, dt.POINTER) def eval_cast( self, @@ -322,17 +323,17 @@ def eval_coalesce( dtypes = [arg._dtype for arg in expression._args] ret_type = dtypes[0] non_optional_arg = False - for dt in dtypes: - ret_type = types_lca(dt, ret_type) - if not is_optional(dt): + for dtype in dtypes: + ret_type = dt.types_lca(dtype, ret_type) + if not isinstance(dtype, dt.Optional): # FIXME: do we want to be more radical and return now? 
# Maybe with a warning that some args are skipped? non_optional_arg = True - if ret_type is Any and any(dt is not Any for dt in dtypes): + if ret_type is dt.ANY and any(dtype is not dt.ANY for dtype in dtypes): raise TypeError( f"Cannot perform pathway.coalesce on columns of types {dtypes}." ) - ret_type = unoptionalize(ret_type) if non_optional_arg else ret_type + ret_type = dt.unoptionalize(ret_type) if non_optional_arg else ret_type return _wrap(expression, ret_type) def eval_require( @@ -350,7 +351,7 @@ def eval_require( ) val = self.eval_expression(expression._val, state=new_state, **kwargs) expression = expr.RequireExpression(val, *args) - ret_type = DType(Optional[val._dtype]) + ret_type = dt.Optional(val._dtype) return _wrap(expression, ret_type) def eval_not_none( @@ -360,7 +361,7 @@ def eval_not_none( **kwargs, ) -> expr.IsNotNoneExpression: ret = super().eval_not_none(expression, state=state, **kwargs) - return _wrap(ret, DType(bool)) + return _wrap(ret, dt.BOOL) def eval_none( self, @@ -369,7 +370,7 @@ def eval_none( **kwargs, ) -> expr.IsNoneExpression: ret = super().eval_none(expression, state=state, **kwargs) - return _wrap(ret, DType(bool)) + return _wrap(ret, dt.BOOL) def eval_ifelse( self, @@ -379,14 +380,14 @@ def eval_ifelse( ) -> expr.IfElseExpression: expression = super().eval_ifelse(expression, state=state, **kwargs) if_dtype = expression._if._dtype - if if_dtype != DType(bool): + if if_dtype != dt.BOOL: raise TypeError( f"First argument of pathway.if_else has to be bool, found {if_dtype}." ) then_dtype = expression._then._dtype else_dtype = expression._else._dtype - lca = types_lca(then_dtype, else_dtype) - if lca is Any: + lca = dt.types_lca(then_dtype, else_dtype) + if lca is dt.ANY: raise TypeError( f"Cannot perform pathway.if_else on columns of types {then_dtype} and {else_dtype}." ) @@ -400,7 +401,7 @@ def eval_make_tuple( ) -> expr.MakeTupleExpression: expression = super().eval_make_tuple(expression, state=state, **kwargs) dtypes = tuple(arg._dtype for arg in expression._args) - return _wrap(expression, DType(Tuple[dtypes])) + return _wrap(expression, dt.Tuple(*dtypes)) def eval_sequence_get( self, @@ -413,68 +414,68 @@ def eval_sequence_get( index_dtype = expression._index._dtype default_dtype = expression._default._dtype - if is_tuple_like(object_dtype): - object_dtype = normalize_tuple_like(object_dtype) - - if get_origin(object_dtype) != tuple and object_dtype not in [ - Any, - np.ndarray, + if not isinstance(object_dtype, (dt.Tuple, dt.List)) and object_dtype not in [ + dt.ANY, + dt.Array(), ]: raise TypeError(f"Object in {expression!r} has to be a sequence.") - if index_dtype != int: + if index_dtype != dt.INT: raise TypeError(f"Index in {expression!r} has to be an int.") - if object_dtype == np.ndarray: + if object_dtype == dt.Array(): warnings.warn( f"Object in {expression!r} is of type numpy.ndarray but its number of" + " dimensions is not known. Pathway cannot determine the return type" + " and will set Any as the return type. Please use " + "pathway.declare_type to set the correct return type." 
) - return _wrap(expression, DType(Any)) + return _wrap(expression, dt.ANY) + if object_dtype == dt.ANY: + return _wrap(expression, dt.ANY) - dtypes = get_args(object_dtype) + if isinstance(object_dtype, dt.List): + if expression._check_if_exists: + return _wrap(expression, dt.Optional(object_dtype.wrapped)) + else: + return _wrap(expression, object_dtype.wrapped) + assert isinstance(object_dtype, dt.Tuple) + if object_dtype == dt.ANY_TUPLE: + return _wrap(expression, dt.ANY) - if object_dtype == Any or len(dtypes) == 0: - return _wrap(expression, DType(Any)) + assert not isinstance(object_dtype.args, EllipsisType) + dtypes = object_dtype.args if ( expression._const_index is None ): # no specified position, index is an Expression + assert isinstance(dtypes[0], dt.DType) return_dtype = dtypes[0] for dtype in dtypes[1:]: - return_dtype = types_lca(return_dtype, dtype) + if isinstance(dtype, dt.DType): + return_dtype = dt.types_lca(return_dtype, dtype) if expression._check_if_exists: - return_dtype = types_lca(return_dtype, default_dtype) + return_dtype = dt.types_lca(return_dtype, default_dtype) return _wrap(expression, return_dtype) try: - return_dtype = dtypes[expression._const_index] - if return_dtype == Ellipsis: - raise IndexError - return _wrap(expression, return_dtype) + try_ret = dtypes[expression._const_index] + return _wrap(expression, try_ret) except IndexError: - if dtypes[-1] == Ellipsis: - return_dtype = dtypes[-2] - if expression._check_if_exists: - return_dtype = types_lca(return_dtype, default_dtype) - return _wrap(expression, return_dtype) - else: - message = ( - f"Index {expression._const_index} out of range for a tuple of" - + f" type {object_dtype}." + message = ( + f"Index {expression._const_index} out of range for a tuple of" + + f" type {object_dtype}." + ) + if expression._check_if_exists: + expression_info = get_expression_info(expression) + warnings.warn( + message + + " It refers to the following expression:\n" + + expression_info + + "Consider using just the default value without .get()." ) - if expression._check_if_exists: - expression_info = get_expression_info(expression) - warnings.warn( - message - + " It refers to the following expression:\n" - + expression_info - + "Consider using just the default value without .get()." 
- ) - return _wrap(expression, default_dtype) - else: - raise IndexError(message) + return _wrap(expression, default_dtype) + else: + raise IndexError(message) def eval_method_call( self, @@ -485,7 +486,7 @@ def eval_method_call( expression = super().eval_method_call(expression, state=state, **kwargs) dtypes = tuple([arg._dtype for arg in expression._args]) if (dtypes_and_handler := expression.get_function(dtypes)) is not None: - return _wrap(expression, DType(dtypes_and_handler[1])) + return _wrap(expression, dt.wrap(dtypes_and_handler[1])) if len(dtypes) > 0: with_arguments = f" with arguments of type {dtypes[1:]}" @@ -503,7 +504,7 @@ def eval_unwrap( ) -> expr.UnwrapExpression: expression = super().eval_unwrap(expression, state=state, **kwargs) dtype = expression._expr._dtype - return _wrap(expression, unoptionalize(dtype)) + return _wrap(expression, dt.unoptionalize(dtype)) class JoinTypeInterpreter(TypeInterpreter): @@ -518,18 +519,18 @@ def _eval_column_val( self, expression: expr.ColumnReference, state: Optional[TypeInterpreterState] = None, - ) -> DType: + ) -> dt.DType: dtype = expression._column.dtype assert state is not None if (expression.table == self.left and self.optionalize_left) or ( expression.table == self.right and self.optionalize_right ): if not state.check_colref_to_unoptionalize_from_tables(expression): - return DType(Optional[dtype]) + return dt.Optional(dtype) return super()._eval_column_val(expression, state=state) -def eval_type(expression: expr.ColumnExpression) -> DType: +def eval_type(expression: expr.ColumnExpression) -> dt.DType: return ( TypeInterpreter() .eval_expression(expression, state=TypeInterpreterState()) @@ -540,7 +541,8 @@ def eval_type(expression: expr.ColumnExpression) -> DType: ColExprT = TypeVar("ColExprT", bound=expr.ColumnExpression) -def _wrap(expression: ColExprT, dtype: DType) -> ColExprT: +def _wrap(expression: ColExprT, dtype: dt.DType) -> ColExprT: assert not hasattr(expression, "_dtype") + assert isinstance(dtype, dt.DType) expression._dtype = dtype return expression diff --git a/python/pathway/io/_utils.py b/python/pathway/io/_utils.py index 7326d495..154a854e 100644 --- a/python/pathway/io/_utils.py +++ b/python/pathway/io/_utils.py @@ -7,7 +7,8 @@ import numpy as np import pathway.internals as pw -from pathway.internals import api, datetime_types +from pathway.internals import api +from pathway.internals import dtype as dt from pathway.internals._io_helpers import _form_value_fields from pathway.internals.api import ConnectorMode, PathwayType from pathway.internals.schema import ColumnDefinition, Schema, SchemaProperties @@ -35,10 +36,10 @@ PathwayType.FLOAT: float, PathwayType.STRING: str, PathwayType.ANY: Any, - PathwayType.POINTER: api.Pointer, - PathwayType.DATE_TIME_NAIVE: datetime_types.DateTimeNaive, - PathwayType.DATE_TIME_UTC: datetime_types.DateTimeUtc, - PathwayType.DURATION: datetime_types.Duration, + PathwayType.POINTER: api.BasePointer, + PathwayType.DATE_TIME_NAIVE: dt.DATE_TIME_NAIVE, + PathwayType.DATE_TIME_UTC: dt.DATE_TIME_UTC, + PathwayType.DURATION: dt.DURATION, PathwayType.ARRAY: np.ndarray, } @@ -143,7 +144,8 @@ def _compat_schema( if types is not None: for name, dtype in types.items(): columns[name] = dataclasses.replace( - columns[name], dtype=_PATHWAY_TYPE_MAPPING.get(dtype, Any) + columns[name], + dtype=dt.wrap(_PATHWAY_TYPE_MAPPING.get(dtype, Any)), ) if default_values is not None: for name, default_value in default_values.items(): diff --git a/python/pathway/io/http/_server.py 
b/python/pathway/io/http/_server.py index afea7180..32c3ed87 100644 --- a/python/pathway/io/http/_server.py +++ b/python/pathway/io/http/_server.py @@ -17,6 +17,7 @@ class RestServerSubject(io.python.ConnectorSubject): _host: str _port: int _loop: asyncio.AbstractEventLoop + _keep_queries: bool def __init__( self, @@ -26,6 +27,7 @@ def __init__( loop: asyncio.AbstractEventLoop, tasks: Dict[Any, Any], schema: Type[pw.Schema], + keep_queries: bool, format: str = "raw", ) -> None: super().__init__() @@ -35,6 +37,7 @@ def __init__( self._loop = loop self._tasks = tasks self._schema = schema + self._keep_queries = keep_queries self._format = format def run(self): @@ -71,7 +74,8 @@ async def handle(self, request: web.Request): self._add(id, data) response = await self._fetch_response(id, event) - self._remove(id, data) + if not self._keep_queries: + self._remove(id, data) return web.json_response(status=200, data=response) async def _fetch_response(self, id, event) -> Any: @@ -93,6 +97,7 @@ def rest_connector( route: str = "/", schema: Optional[Type[pw.Schema]] = None, autocommit_duration_ms=1500, + keep_queries: bool = False, ) -> Tuple[pw.Table, Callable]: """ Runs a lightweight HTTP server and inputs a collection from the HTTP endpoint, @@ -106,7 +111,11 @@ def rest_connector( host: TCP/IP host or a sequence of hosts for the created endpoint; port: port for the created endpoint; route: route which will be listened to by the web server; - schema: schema of the resulting table. + schema: schema of the resulting table; + autocommit_duration_ms: the maximum time between two commits. Every + autocommit_duration_ms milliseconds, the updates received by the connector are + committed and pushed into Pathway's computation graph; + keep_queries: whether to keep queries in the resulting table once their responses have been served; defaults to False.
Returns: table: the table read; @@ -130,6 +139,7 @@ def rest_connector( loop=loop, tasks=tasks, schema=schema, + keep_queries=keep_queries, format=format, ), schema=schema, diff --git a/python/pathway/stdlib/ml/classifiers/_knn_via_lsh_flat.py b/python/pathway/stdlib/ml/classifiers/_knn_via_lsh_flat.py index 89bbfe32..36c3ed59 100644 --- a/python/pathway/stdlib/ml/classifiers/_knn_via_lsh_flat.py +++ b/python/pathway/stdlib/ml/classifiers/_knn_via_lsh_flat.py @@ -37,7 +37,7 @@ class DataPoint(pw.Schema): class Query(pw.Schema): - data: DataPoint + data: np.ndarray k: int diff --git a/python/pathway/stdlib/ml/smart_table_ops/_fuzzy_join.py b/python/pathway/stdlib/ml/smart_table_ops/_fuzzy_join.py index 682d5410..c6a6f771 100644 --- a/python/pathway/stdlib/ml/smart_table_ops/_fuzzy_join.py +++ b/python/pathway/stdlib/ml/smart_table_ops/_fuzzy_join.py @@ -6,7 +6,6 @@ from enum import IntEnum, auto from typing import Any, Callable, Optional -# TODO change to `import pathway as pw` when it is not imported as part of stdlib, OR move the whole file to stdlib import pathway.internals as pw from pathway.internals.fingerprints import fingerprint from pathway.internals.helpers import StableSet @@ -22,7 +21,7 @@ class Feature(pw.Schema): class Edge(pw.Schema): - node: Node + node: pw.Pointer[Node] feature: pw.Pointer[Feature] weight: float diff --git a/python/pathway/stdlib/temporal/_window.py b/python/pathway/stdlib/temporal/_window.py index df422e74..9a3ef1b4 100644 --- a/python/pathway/stdlib/temporal/_window.py +++ b/python/pathway/stdlib/temporal/_window.py @@ -8,6 +8,7 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import pathway.internals as pw +from pathway.internals import dtype as dt from pathway.internals.arg_handlers import arg_handler, windowby_handler from pathway.internals.desugaring import desugar from pathway.internals.join import validate_join_condition @@ -100,8 +101,6 @@ def _apply( key: pw.ColumnExpression, shard: Optional[pw.ColumnExpression], ) -> pw.GroupedTable: - target = self._compute_group_repr(table, key, shard) - if self.max_gap is not None: check_joint_types( { @@ -110,6 +109,7 @@ def _apply( } ) + target = self._compute_group_repr(table, key, shard) tmp = target.groupby(target._pw_window).reduce( _pw_window_start=pw.reducers.min(key), _pw_window_end=pw.reducers.max(key), @@ -316,20 +316,28 @@ def _apply( } ) - target = table.select(_pw_window=pw.apply(self._assign_windows, shard, key)) + from pathway.internals.type_interpreter import eval_type + + target = table.select( + _pw_window=pw.apply_with_type( + self._assign_windows, + dt.List( + dt.Tuple( + eval_type(shard), # type:ignore + eval_type(key), + eval_type(key), + ) + ), + shard, + key, + ) + ) target = target.flatten(target._pw_window, *table) target = target.with_columns( _pw_shard=pw.this._pw_window.get(0), _pw_window_start=pw.this._pw_window.get(1), _pw_window_end=pw.this._pw_window.get(2), ) - from pathway.internals.type_interpreter import eval_type - - target = target.update_types( - _pw_shard=eval_type(shard), # type: ignore - _pw_window_start=eval_type(key), - _pw_window_end=eval_type(key), - ) return target.groupby( target._pw_window, target._pw_shard, @@ -360,29 +368,45 @@ def _join( ) left_window = left.select( - _pw_window=pw.apply(self._assign_windows, None, left_time_expression) + _pw_window=pw.apply_with_type( + self._assign_windows, + dt.List( + dt.Tuple( + dt.NONE, + eval_type(left_time_expression), + eval_type(left_time_expression), + ) + ), + None, + 
left_time_expression, + ) ) left_window = left_window.flatten(left_window._pw_window, *left) left_window = left_window.with_columns( _pw_window_start=pw.this._pw_window.get(1), _pw_window_end=pw.this._pw_window.get(2), - ).update_types( - _pw_window_start=eval_type(left_time_expression), - _pw_window_end=eval_type(left_time_expression), ) right_window = right.select( - _pw_window=pw.apply(self._assign_windows, None, right_time_expression) + _pw_window=pw.apply_with_type( + self._assign_windows, + dt.List( + dt.Tuple( + dt.NONE, + eval_type(right_time_expression), + eval_type(right_time_expression), + ) + ), + None, + right_time_expression, + ) ) right_window = right_window.flatten(right_window._pw_window, *right) right_window = right_window.with_columns( _pw_window_start=pw.this._pw_window.get(1), _pw_window_end=pw.this._pw_window.get(2), - ).update_types( - _pw_window_start=eval_type(right_time_expression), - _pw_window_end=eval_type(right_time_expression), ) for cond in on: diff --git a/python/pathway/stdlib/temporal/utils.py b/python/pathway/stdlib/temporal/utils.py index 0517c38f..0dd00104 100644 --- a/python/pathway/stdlib/temporal/utils.py +++ b/python/pathway/stdlib/temporal/utils.py @@ -1,9 +1,9 @@ # Copyright © 2023 Pathway import datetime -from typing import Any, Dict, Tuple, Union, get_args +from typing import Any, Dict, Tuple, Union -from pathway.internals.datetime_types import DateTimeNaive, DateTimeUtc, Duration +from pathway.internals import dtype as dt from pathway.internals.type_interpreter import eval_type TimeEventType = Union[int, float, datetime.datetime] @@ -19,12 +19,12 @@ def get_default_shift(interval: IntervalType) -> TimeEventType: return 0.0 -def _get_possible_types(type: Any) -> Any: +def _get_possible_types(type: Any) -> Tuple[dt.DType, ...]: if type is TimeEventType: - return [int, float, DateTimeNaive, DateTimeUtc] + return (dt.INT, dt.FLOAT, dt.DATE_TIME_NAIVE, dt.DATE_TIME_UTC) if type is IntervalType: - return [int, float, Duration, Duration] - return get_args(type) + return (dt.INT, dt.FLOAT, dt.DURATION, dt.DURATION) + raise ValueError("Type has to be either TimeEventType or IntervalType.") def check_joint_types(parameters: Dict[str, Tuple[Any, Any]]) -> None: @@ -48,7 +48,15 @@ def check_joint_types(parameters: Dict[str, Tuple[Any, Any]]) -> None: for name, (_variable, expected_type) in parameters.items() } ) - if not any([types == ex_types for ex_types in expected_types]): + for ex_types in expected_types: + if all( + [ + dt.dtype_issubclass(dtype, ex_dtype) + for (dtype, ex_dtype) in zip(types.values(), ex_types.values()) + ] + ): + break + else: expected_types_string = " or ".join( repr(tuple(ex_types.values())) for ex_types in expected_types ) diff --git a/python/pathway/tests/expressions/test_datetimes.py b/python/pathway/tests/expressions/test_datetimes.py index 93eca812..f4e33b50 100644 --- a/python/pathway/tests/expressions/test_datetimes.py +++ b/python/pathway/tests/expressions/test_datetimes.py @@ -1,5 +1,6 @@ # Copyright © 2023 Pathway +import datetime import operator from typing import Any, List, Tuple @@ -9,8 +10,13 @@ from dateutil import tz import pathway as pw +from pathway import dt from pathway.debug import table_from_markdown, table_from_pandas -from pathway.tests.utils import assert_table_equality, assert_table_equality_wo_index +from pathway.tests.utils import ( + assert_table_equality, + assert_table_equality_wo_index, + run_all, +) @pytest.mark.parametrize( @@ -159,6 +165,8 @@ def test_date_time(method_name: str, is_naive: bool) 
-> None: "%u", "%V", "%Y-%m-%d %H:%M:%S.%f", + "%Y-%m-%d %H:%M:%S.%%f", + "%%H:%%M:%%S", # test that %%sth are not changed to values ], ) def test_strftime(fmt_out: str, is_naive: bool) -> None: @@ -192,9 +200,8 @@ def test_strftime(fmt_out: str, is_naive: bool) -> None: df_converted = pd.DataFrame({"ts": df.ts.dt.tz_convert(tz.UTC)}) df_new = pd.DataFrame({"txt": df_converted.ts.dt.strftime(fmt_out)}) table = table_from_pandas(df) - fmt_out_pw = fmt_out.replace( - "%f", "%6f" - ) # TODO: do we want to be consistent with python in printing? + fmt_out_pw = fmt_out.replace("%f", "%6f") + fmt_out_pw = fmt_out_pw.replace("%%6f", "%%f") table_pw = table.select(txt=table.ts.dt.strftime(fmt_out_pw)) table_pd = table_from_pandas(df_new) assert_table_equality(table_pw, table_pd) @@ -246,9 +253,25 @@ def test_strftime_with_format_in_column() -> None: (["2023-03-25T16:43:21", "2023-03-26T16:43:21"], "%Y-%m-%dT%H:%M:%S"), (["2023-03-25 04:43:21 AM", "2023-03-26 04:43:21 PM"], "%Y-%m-%d %I:%M:%S %p"), ( - ["2023-03-25 16:43:21.123456789", "2023-03-26 16:43:21.123456789"], + [ + "1900-01-01 00:00:00.396", + "1900-01-01 00:00:00.396093123", + "2023-03-25 16:43:21.123456789", + "2023-03-26 16:43:21.123456789", + "2023-03-26 16:43:21.12", + ], "%Y-%m-%d %H:%M:%S.%f", ), + ( + [ + "1900-01-01 %f00:00:00.396", + "1900-01-01 %f00:00:00.396093123", + "2023-03-25 %f16:43:21.123456789", + "2023-03-26 %f16:43:21.123456789", + "2023-03-26 %f16:43:21.12", + ], + "%Y-%m-%d %%f%H:%M:%S.%f", + ), ], ) def test_strptime_naive(data: List[str], fmt: str) -> None: @@ -256,12 +279,51 @@ def test_strptime_naive(data: List[str], fmt: str) -> None: table_pd = table_from_pandas(df) table = table_from_pandas(pd.DataFrame({"a": data})) table_pw = table.select(ts=table.a.dt.strptime(fmt)) - if "%Y" not in fmt: - table_pw = table_pw.select(ts=table_pw.ts.dt.strftime("%H:%M:%S.%f")) - table_pd = table_pd.select(ts=table_pd.ts.dt.strftime("%H:%M:%S.%f")) assert_table_equality(table_pw, table_pd) +@pytest.mark.parametrize( + "data,fmt", + [ + (["1960-02-03", "2023-03-25", "2023-03-26", "2123-03-26"], "%Y-%m-%d"), + (["03.02.1960", "25.03.2023", "26.03.2023", "26.03.2123"], "%d.%m.%Y"), + (["02.03.1960", "03.25.2023", "03.26.2023", "03.26.2123"], "%m.%d.%Y"), + (["12:34:00", "01:22:12", "13:00:34", "23:59:59"], "%H:%M:%S"), + (["12:34:00 PM", "01:22:12 AM", "01:00:34 PM", "11:59:59 PM"], "%I:%M:%S %p"), + ( + ["12:34:00.000000", "01:22:12.12345", "13:00:34.11"], + "%H:%M:%S.%f", + ), + (["2023-03-25 16:43:21", "2023-03-26 16:43:21"], "%Y-%m-%d %H:%M:%S"), + (["2023-03-25T16:43:21", "2023-03-26T16:43:21"], "%Y-%m-%dT%H:%M:%S"), + (["2023-03-25 04:43:21 AM", "2023-03-26 04:43:21 PM"], "%Y-%m-%d %I:%M:%S %p"), + ( + [ + "1900-01-01 00:00:00.396", + "1900-01-01 00:00:00.396093", + "2023-03-25 16:43:21.123456", + "2023-03-26 16:43:21.123456", + "2023-03-26 16:43:21.12", + ], + "%Y-%m-%d %H:%M:%S.%f", + ), + ], +) +def test_strptime_naive_with_python_datetime(data: List[str], fmt: str) -> None: + table = table_from_pandas(pd.DataFrame({"a": data})).select( + ts=pw.this.a.dt.strptime(fmt) + ) + + @pw.udf + def parse_datetime(date_str: str) -> dt.DATE_TIME_NAIVE: + return datetime.datetime.strptime(date_str, fmt) # type: ignore [return-value] + + expected = table_from_pandas(pd.DataFrame({"a": data})).select( + ts=parse_datetime(pw.this.a) + ) + assert_table_equality(table, expected) + + @pytest.mark.parametrize( "data,fmt", [ @@ -279,11 +341,24 @@ def test_strptime_naive(data: List[str], fmt: str) -> None: ), ( [ - "2023-03-25 
16:43:21.123456789+01:23", - "2023-03-26 16:43:21.123456789+01:23", + "1900-01-01 00:00:00.396-11:05", + "1900-01-01 00:00:00.396093123-11:05", + "2023-03-25 16:43:21.123456789-11:05", + "2023-03-26 16:43:21.123456789-11:05", + "2023-03-26 16:43:21.12-11:05", ], "%Y-%m-%d %H:%M:%S.%f%z", ), + ( + [ + "1900%f01-01 00:00:00.396-11:05", + "1900%f01-01 00:00:00.396093123-11:05", + "2023%f03-25 16:43:21.123456789-11:05", + "2023%f03-26 16:43:21.123456789-11:05", + "2023%f03-26 16:43:21.12-11:05", + ], + "%Y%%f%m-%d %H:%M:%S.%f%z", + ), ], ) def test_strptime_time_zone_aware(data: List[str], fmt: str) -> None: @@ -294,6 +369,50 @@ def test_strptime_time_zone_aware(data: List[str], fmt: str) -> None: assert_table_equality(table_pw, table_pd) +@pytest.mark.parametrize( + "data,fmt", + [ + ( + ["2023-03-25 16:43:21+01:23", "2023-03-26 16:43:21+01:23"], + "%Y-%m-%d %H:%M:%S%z", + ), + ( + ["2023-03-25T16:43:21+01:23", "2023-03-26T16:43:21+01:23"], + "%Y-%m-%dT%H:%M:%S%z", + ), + ( + ["2023-03-25 04:43:21 AM +01:23", "2023-03-26 04:43:21 PM +01:23"], + "%Y-%m-%d %I:%M:%S %p %z", + ), + ( + [ + "1900-01-01 00:00:00.396-11:05", + "1900-01-01 00:00:00.396093-11:05", + "2023-03-25 16:43:21.1234-11:05", + "2023-03-26 16:43:21.12345-11:05", + "2023-03-26 16:43:21.12-11:05", + ], + "%Y-%m-%d %H:%M:%S.%f%z", + ), + ], +) +def test_strptime_time_zone_aware_with_python_datetime( + data: List[str], fmt: str +) -> None: + table = table_from_pandas(pd.DataFrame({"a": data})).select( + ts=pw.this.a.dt.strptime(fmt) + ) + + @pw.udf + def parse_datetime(date_str: str) -> dt.DATE_TIME_UTC: + return datetime.datetime.strptime(date_str, fmt) # type: ignore [return-value] + + expected = table_from_pandas(pd.DataFrame({"a": data})).select( + ts=parse_datetime(pw.this.a) + ) + assert_table_equality(table, expected) + + def test_strptime_with_format_in_column() -> None: pairs = [ ("1960-02-03 12:45:12", "%Y-%m-%d %H:%M:%S"), @@ -323,6 +442,53 @@ def test_strptime_with_format_in_column() -> None: assert_table_equality_wo_index(res, expected) +def test_strptime_naive_errors_on_wrong_specifier() -> None: + table_from_pandas(pd.DataFrame({"a": ["2023-03-26 16:43:21-12"]})).select( + t=pw.this.a.dt.strptime("%Y-%m-%d %H:%M:%S-%f") + ) + with pytest.raises( + ValueError, + match="parse error: Cannot use format: %Y-%m-%d %H:%M:%S-%f." + + " Using %f without the leading dot is not supported.", + ): + run_all() + + +def test_strptime_naive_errors_on_wrong_format() -> None: + table_from_pandas(pd.DataFrame({"a": ["2023-03-26T16:43:21.12"]})).select( + t=pw.this.a.dt.strptime("%Y-%m-%d %H:%M:%S.%f") + ) + with pytest.raises( + ValueError, + match="parse error: Cannot format date: 2023-03-26T16:43:21.12 using format: %Y-%m-%d %H:%M:%S%.f.", + ): + run_all() + + +def test_strptime_utc_errors_on_wrong_specifier() -> None: + table_from_pandas(pd.DataFrame({"a": ["2023-03-26 16:43:21-12+0100"]})).select( + t=pw.this.a.dt.strptime("%Y-%m-%d %H:%M:%S-%f%z") + ) + with pytest.raises( + ValueError, + match="parse error: Cannot use format: %Y-%m-%d %H:%M:%S-%f%z." 
+ + " Using %f without the leading dot is not supported.", + ): + run_all() + + +def test_strptime_utc_errors_on_wrong_format() -> None: + table_from_pandas(pd.DataFrame({"a": ["2023-03-26T16:43:21.12-0100"]})).select( + t=pw.this.a.dt.strptime("%Y-%m-%d %H:%M:%S.%f%z") + ) + with pytest.raises( + ValueError, + match="parse error: Cannot format date: 2023-03-26T16:43:21.12-0100 using" + + " format: %Y-%m-%d %H:%M:%S%.f%z.", + ): + run_all() + + def test_date_time_naive_to_utc() -> None: table = table_from_markdown( """ diff --git a/python/pathway/tests/expressions/test_string.py b/python/pathway/tests/expressions/test_string.py index 6ee46ce3..8c52b2aa 100644 --- a/python/pathway/tests/expressions/test_string.py +++ b/python/pathway/tests/expressions/test_string.py @@ -450,3 +450,52 @@ def test_to_string_for_optional_type(): res = table.select(a=pw.this.a.to_string()) assert_table_equality(res, expected) + + +def test_to_string_for_datetime_naive(): + t = T( + """ + | t + 1 | 2019-12-31T23:49:59.999999999 + 2 | 2019-12-31T23:49:59.0001 + 3 | 2020-03-04T11:13:00.345612 + 4 | 2023-03-26T12:00:00.000000001 + """ + ) + expected = T( + """ + | t + 1 | 2019-12-31T23:49:59.999999999 + 2 | 2019-12-31T23:49:59.000100000 + 3 | 2020-03-04T11:13:00.345612000 + 4 | 2023-03-26T12:00:00.000000001 + """ + ) + assert_table_equality( + t.select(t=pw.this.t.dt.strptime("%Y-%m-%dT%H:%M:%S.%f").to_string()), expected + ) + + +def test_to_string_for_datetime_utc(): + t = T( + """ + | t + 1 | 2019-12-31T23:49:59.999999999+0100 + 2 | 2019-12-31T23:49:59.0001+0100 + 3 | 2020-03-04T11:13:00.345612+0100 + 4 | 2023-03-26T12:00:00.000000001+0100 + """ + ) + expected = T( + """ + | t + 1 | 2019-12-31T22:49:59.999999999+0000 + 2 | 2019-12-31T22:49:59.000100000+0000 + 3 | 2020-03-04T10:13:00.345612000+0000 + 4 | 2023-03-26T11:00:00.000000001+0000 + """ + ) + assert_table_equality( + t.select(t=pw.this.t.dt.strptime("%Y-%m-%dT%H:%M:%S.%f%z").to_string()), + expected, + ) diff --git a/python/pathway/tests/temporal/test_asof_joins.py b/python/pathway/tests/temporal/test_asof_joins.py index 244b37ff..392b5c94 100644 --- a/python/pathway/tests/temporal/test_asof_joins.py +++ b/python/pathway/tests/temporal/test_asof_joins.py @@ -5,7 +5,7 @@ import pytest import pathway as pw -from pathway.internals.datetime_types import DateTimeNaive, DateTimeUtc +from pathway.internals.dtype import DATE_TIME_NAIVE, DATE_TIME_UTC from pathway.internals.join_mode import JoinMode from pathway.tests.utils import ( T, @@ -367,10 +367,10 @@ def test_with_timestamps(): @pytest.mark.parametrize( "left_type,right_type", [ - (int, float), - (DateTimeNaive, int), - (float, DateTimeNaive), - (DateTimeNaive, DateTimeUtc), + (int, DATE_TIME_UTC), + (DATE_TIME_NAIVE, int), + (float, DATE_TIME_NAIVE), + (DATE_TIME_NAIVE, DATE_TIME_UTC), ], ) def test_incorrect_args(left_type, right_type): diff --git a/python/pathway/tests/temporal/test_interval_joins.py b/python/pathway/tests/temporal/test_interval_joins.py index 84a09c05..feb43aec 100644 --- a/python/pathway/tests/temporal/test_interval_joins.py +++ b/python/pathway/tests/temporal/test_interval_joins.py @@ -8,8 +8,7 @@ import pytest import pathway as pw -from pathway.internals.datetime_types import DateTimeNaive, DateTimeUtc -from pathway.internals.dtype import NoneType # type: ignore +from pathway.internals.dtype import DATE_TIME_NAIVE, DATE_TIME_UTC, NONE from pathway.tests.utils import T, assert_table_equality, assert_table_equality_wo_index @@ -112,8 +111,8 @@ def test_interval_join_time_only( 11 | 6 | 
5 """ ) - left = pw.Table.empty(a=int, b=NoneType) - right = pw.Table.empty(a=NoneType, b=int) + left = pw.Table.empty(a=int, b=NONE) + right = pw.Table.empty(a=NONE, b=int) if join_type in [pw.JoinMode.LEFT, pw.JoinMode.OUTER]: expected = expected.concat_reindex(left) @@ -534,6 +533,55 @@ def test_interval_join_float(max_time_difference: float) -> None: assert_table_equality_wo_index(res, expected) +@pytest.mark.parametrize("max_time_difference", [1, 1.5]) +def test_interval_join_int_float(max_time_difference: float) -> None: + t1 = T( + """ + | a | t + 0 | 1 | 3 + 1 | 2 | 7 + """ + ) + t2 = T( + """ + | b | t + 0 | 1 | 1.6 + 1 | 2 | 2.1 + 2 | 3 | 4.4 + 3 | 4 | 6.1 + 4 | 5 | 7.8 + 5 | 6 | 0.0 + 6 | 7 | 9.2 + 7 | 8 | 8.3 + """ + ) + if max_time_difference == 1: + expected = T( + """ + | a | b + 0 | 1 | 2 + 1 | 2 | 4 + 2 | 2 | 5 + """ + ) + else: + expected = T( + """ + | a | b + 0 | 1 | 1 + 1 | 1 | 2 + 2 | 1 | 3 + 3 | 2 | 4 + 4 | 2 | 5 + 5 | 2 | 8 + """ + ) + res = t1.interval_join_inner( + t2, t1.t, t2.t, pw.temporal.interval(-max_time_difference, max_time_difference) + ).select(t1.a, t2.b) + assert_table_equality_wo_index(res, expected) + + @pytest.mark.parametrize( "join_type", [pw.JoinMode.INNER, pw.JoinMode.LEFT, pw.JoinMode.RIGHT, pw.JoinMode.OUTER], @@ -1040,17 +1088,17 @@ def test_with_timestamps() -> None: @pytest.mark.parametrize( "left_type,right_type,lower_bound,upper_bound", [ - (int, int, 1, 1.2), - (int, int, -0.1, 1.1), - (int, int, -0.1, 1), - (int, float, 0, 1), - (float, int, 0, 1), - (float, float, 1, 0.2), - (DateTimeNaive, DateTimeNaive, 1, 2), - (DateTimeNaive, DateTimeNaive, datetime.timedelta(days=1), 2), + (int, int, 1, datetime.timedelta(days=1)), + (int, int, datetime.timedelta(days=-1), datetime.timedelta(days=1)), + (int, int, datetime.timedelta(days=-1), 1), + (int, DATE_TIME_NAIVE, 0, 1), + (DATE_TIME_NAIVE, int, 0, 1), + (float, DATE_TIME_NAIVE, 1, 0.2), + (DATE_TIME_NAIVE, DATE_TIME_NAIVE, 1, 2), + (DATE_TIME_NAIVE, DATE_TIME_NAIVE, datetime.timedelta(days=1), 2), ( - DateTimeUtc, - DateTimeNaive, + DATE_TIME_UTC, + DATE_TIME_NAIVE, datetime.timedelta(days=1), datetime.timedelta(days=2), ), @@ -1107,18 +1155,13 @@ def test_incorrect_args_specific(): """ ) - t1 = t1.with_columns(t=pw.declare_type(DateTimeNaive, pw.this.t)) + t1 = t1.with_columns(t=pw.declare_type(DATE_TIME_NAIVE, pw.this.t)) t2 = t2.with_columns(t=pw.declare_type(int, pw.this.t)) expected_error_message = """Arguments (self_time_expression, other_time_expression, - lower_bound, upper_bound) have to be of types (, , - , ) or (, , , - ) or (, - , , - ) or (, - , , - ) but are of types - (, , , ).""".replace( + lower_bound, upper_bound) have to be of types (INT, INT, INT, INT) or (FLOAT, FLOAT, + FLOAT, FLOAT) or (DATE_TIME_NAIVE, DATE_TIME_NAIVE, DURATION, DURATION) or (DATE_TIME_UTC, + DATE_TIME_UTC, DURATION, DURATION) but are of types (DATE_TIME_NAIVE, INT, INT, INT).""".replace( "\n", "" ) with pytest.raises(TypeError) as error: diff --git a/python/pathway/tests/temporal/test_window_joins.py b/python/pathway/tests/temporal/test_window_joins.py index 0c197d2a..23155c27 100644 --- a/python/pathway/tests/temporal/test_window_joins.py +++ b/python/pathway/tests/temporal/test_window_joins.py @@ -1,9 +1,11 @@ # Copyright © 2023 Pathway +import datetime + import pytest import pathway as pw -from pathway.internals.datetime_types import DateTimeNaive +from pathway.internals.dtype import DATE_TIME_NAIVE, DATE_TIME_UTC from pathway.tests.utils import T, assert_table_equality, 
assert_table_equality_wo_index @@ -893,26 +895,33 @@ def test_window_join_float(w: pw.temporal.Window) -> None: @pytest.mark.parametrize( "left_type,right_type,window,error_str", [ - (int, int, pw.temporal.tumbling(duration=1.2), ", window.hop"), ( int, int, + pw.temporal.tumbling(duration=datetime.timedelta(days=1)), + ", window.hop", + ), + ( + int, + DATE_TIME_NAIVE, pw.temporal.tumbling(duration=2, offset=1.1), ", window.hop, window.offset", ), ( int, int, - pw.temporal.sliding(hop=2, duration=3.5), + pw.temporal.sliding( + hop=datetime.timedelta(days=1), duration=datetime.timedelta(days=2) + ), ", window.hop, window.duration", ), - (int, float, pw.temporal.tumbling(duration=1.2), ", window.hop"), - (int, float, pw.temporal.tumbling(duration=1.2), ", window.hop"), - (float, int, pw.temporal.session(max_gap=2), ", window.max_gap"), - (float, int, pw.temporal.session(predicate=lambda a, b: False), ""), + (DATE_TIME_NAIVE, float, pw.temporal.tumbling(duration=1.2), ", window.hop"), + (int, DATE_TIME_UTC, pw.temporal.tumbling(duration=1.2), ", window.hop"), + (float, DATE_TIME_NAIVE, pw.temporal.session(max_gap=2), ", window.max_gap"), + (DATE_TIME_UTC, int, pw.temporal.session(predicate=lambda a, b: False), ""), ( - DateTimeNaive, - DateTimeNaive, + DATE_TIME_NAIVE, + DATE_TIME_NAIVE, pw.temporal.sliding(hop=2, duration=3.5), ", window.hop, window.duration", ), diff --git a/python/pathway/tests/temporal/test_windows.py b/python/pathway/tests/temporal/test_windows.py index 5264709d..ab77284b 100644 --- a/python/pathway/tests/temporal/test_windows.py +++ b/python/pathway/tests/temporal/test_windows.py @@ -9,7 +9,8 @@ import pytest import pathway as pw -from pathway.internals.datetime_types import DateTimeNaive +from pathway import dt +from pathway.internals.dtype import DATE_TIME_NAIVE, DATE_TIME_UTC from pathway.tests.utils import T, assert_table_equality_wo_index @@ -123,6 +124,35 @@ def test_session_max_gap(): assert_table_equality_wo_index(result, res) +def test_session_max_gap_mixed(): + t = T( + """ + | t + 1 | 10 + 2 | 11 + 3 | 12 + 4 | 30 + 5 | 34 + 6 | 35 + """ + ) + + gb = t.windowby(t.t, window=pw.temporal.session(max_gap=1.5)) + result = gb.reduce( + min_t=pw.reducers.min(pw.this.t), + count=pw.reducers.count(), + ) + res = T( + """ + min_t | count + 10 | 3 + 30 | 1 + 34 | 2 + """ + ) + assert_table_equality_wo_index(result, res) + + def test_session_window_creation(): with pytest.raises(ValueError): pw.temporal.session() @@ -239,6 +269,39 @@ def test_sliding_larger_hop(): assert_table_equality_wo_index(result, res) +def test_sliding_larger_hop_mixed(): + t = T( + """ + | t + 0 | 11.3 + 1 | 12.1 + 2 | 13.3 + 3 | 14.7 + 4 | 15.3 + 5 | 16.1 + 6 | 17.8 + """ + ) + + gb = t.windowby(t.t, window=pw.temporal.sliding(duration=4, hop=6)) + result = gb.reduce( + pw.this._pw_shard, + pw.this._pw_window_start, + pw.this._pw_window_end, + min_t=pw.reducers.min(pw.this.t), + max_t=pw.reducers.max(pw.this.t), + count=pw.reducers.count(), + ) + + res = T( + """ + _pw_shard | _pw_window_start | _pw_window_end | min_t | max_t | count + | 12 | 16 | 12.1 | 15.3 | 4 + """ + ).update_types(_pw_window_start=dt.FLOAT, _pw_window_end=dt.FLOAT) + assert_table_equality_wo_index(result, res) + + def test_tumbling(): t = T( """ @@ -513,22 +576,33 @@ def test_windows_with_datetimes(w): @pytest.mark.parametrize( "dtype,window,error_str", [ - (int, pw.temporal.tumbling(duration=1.2), ", window.hop"), ( int, - pw.temporal.tumbling(duration=2, offset=1.1), - ", window.hop, window.offset", + 
pw.temporal.tumbling(duration=datetime.timedelta(days=1)), + ", window.hop", ), ( int, + pw.temporal.tumbling( + duration=datetime.timedelta(days=1), + offset=datetime.datetime(year=1970, month=1, day=1), + ), + ", window.hop, window.offset", + ), + ( + DATE_TIME_UTC, pw.temporal.sliding(hop=2, duration=3.5), ", window.hop, window.duration", ), - (int, pw.temporal.tumbling(duration=1.2), ", window.hop"), - (int, pw.temporal.tumbling(duration=1.2), ", window.hop"), - (float, pw.temporal.session(max_gap=2), ", window.max_gap"), + (DATE_TIME_NAIVE, pw.temporal.tumbling(duration=1.2), ", window.hop"), + ( + int, + pw.temporal.tumbling(duration=datetime.timedelta(days=1)), + ", window.hop", + ), + (DATE_TIME_NAIVE, pw.temporal.session(max_gap=2), ", window.max_gap"), ( - DateTimeNaive, + DATE_TIME_NAIVE, pw.temporal.sliding(hop=2, duration=3.5), ", window.hop, window.duration", ), diff --git a/python/pathway/tests/test_api.py b/python/pathway/tests/test_api.py index 23d9e1da..e67aae3f 100644 --- a/python/pathway/tests/test_api.py +++ b/python/pathway/tests/test_api.py @@ -1060,3 +1060,54 @@ def build(s): with pytest.raises(Exception): api.run_with_new_graph(build, event_loop) + + +@pytest.mark.parametrize( + "value", + [ + None, + 42, + 42.0, + "", + "foo", + True, + False, + (), + (42,), + (1, 2, 3), + np.array([], dtype=int), + np.array([42]), + np.array([1, 2, 3]), + np.array([], dtype=float), + np.array([42.0]), + np.array([1.0, 2.0, 3.0]), + ], + ids=repr, +) +def test_value_type_via_python(event_loop, value): + def build(s): + key = api.ref_scalar() + universe = s.static_universe([key]) + column = s.static_column(universe, [(key, value)], dtype=type(value)) + table = s.table(universe, [column]) + + def fun(values): + [inner] = values + assert type(value) is type(inner) + if isinstance(value, np.ndarray): + assert value.dtype == inner.dtype + assert np.array_equal(value, inner) + else: + assert value == inner + return inner + + new_column = s.map_column(table, fun, api.EvalProperties(dtype=type(value))) + new_table = s.table(universe, [new_column]) + + return table, new_table + + table, new_table = api.run_with_new_graph(build, event_loop) + + assert_equal_tables_wo_index( + table, new_table + ) # wo_index knows how to compare arrays diff --git a/python/pathway/tests/test_common.py b/python/pathway/tests/test_common.py index 16da8fee..d9c7736c 100644 --- a/python/pathway/tests/test_common.py +++ b/python/pathway/tests/test_common.py @@ -17,6 +17,7 @@ import pathway.internals.shadows.operator as operator from pathway.debug import table_from_pandas, table_to_pandas from pathway.internals import api +from pathway.internals import dtype as dt from pathway.internals.expression import NumbaApplyExpression from pathway.tests.utils import ( T, @@ -1014,7 +1015,7 @@ def test_flatten_incorrect_type(): ) with pytest.raises( TypeError, - match=re.escape("Cannot flatten column .a of type ."), + match=re.escape("Cannot flatten column .a of type INT."), ): t = t.flatten(t.a) @@ -4199,6 +4200,27 @@ def move_to_pathway_with_the_right_type(list: List, dtype: Any): assert_table_equality(table, expected) +def test_cast_optional(): + tab = T( + """ + | a + 1 | 1 + 2 | + 3 | 1 + """ + ) + ret = tab.select(a=pw.cast(Optional[float], pw.this.a)) + expected = T( + """ + | a + 1 | 1.0 + 2 | + 3 | 1.0 + """ + ).update_types(a=Optional[float]) + assert_table_equality(ret, expected) + + def test_lazy_coalesce(): tab = T( """ @@ -4746,7 +4768,7 @@ def test_sequence_get_unchecked_fixed_length_dynamic_index_1(): t2 = 
t1.select(tup=pw.make_tuple(pw.this.i, pw.this.s), a=pw.this.a) t3 = t2.select(r=pw.this.tup[pw.this.a]) - assert t3.schema.as_dict()["r"] == Any + assert t3.schema.as_dict()["r"] == dt.ANY def test_sequence_get_unchecked_fixed_length_dynamic_index_2(): @@ -4794,7 +4816,7 @@ def test_sequence_get_checked_fixed_length_dynamic_index(): t2 = t1.select(tup=pw.make_tuple(pw.this.a, pw.this.b), c=pw.this.c) t3 = t2.select(r=pw.this.tup.get(pw.this.c)) - assert t3.schema.as_dict()["r"] == Optional[int] + assert t3.schema.as_dict()["r"] == dt.Optional(dt.INT) assert_table_equality_wo_types(t3, expected) @@ -4904,9 +4926,7 @@ def test_sequence_get_unchecked_fixed_length_errors(): with pytest.raises( IndexError, match=( - re.escape( - "Index 2 out of range for a tuple of type typing.Tuple[int, int]." - ) + re.escape("Index 2 out of range for a tuple of type Tuple(INT, INT).") + r"[\s\S]*" ), ): @@ -4936,9 +4956,7 @@ def test_sequence_get_checked_fixed_length_errors(): with pytest.warns( UserWarning, match=( - re.escape( - "Index 2 out of range for a tuple of type typing.Tuple[int, int]." - ) + re.escape("Index 2 out of range for a tuple of type Tuple(INT, INT).") + " It refers to the following expression:\n\t" + re.escape("(.tup).get(2, .c)\n") + rf"called in .*{file_name}.*\n" @@ -5402,6 +5420,7 @@ def test_tuple_reducer_consistency(): t2 = left_res.select( pet=pw.this.pet.get(3), owner=pw.this.owner.get(3), age=pw.this.age.get(3) ) + print(t2.schema) joined = left.join( t2, diff --git a/python/pathway/tests/test_demo.py b/python/pathway/tests/test_demo.py new file mode 100644 index 00000000..dd4eb931 --- /dev/null +++ b/python/pathway/tests/test_demo.py @@ -0,0 +1,167 @@ +# Copyright © 2023 Pathway + +import pathlib + +import pytest + +import pathway as pw +from pathway.tests.utils import T, assert_table_equality_wo_index, write_csv + + +def test_generate_custom_stream(): + value_functions = { + "number": lambda x: x + 1, + "name": lambda x: f"Person_{x}", + "age": lambda x: 20 + x, + } + + class InputSchema(pw.Schema): + number: int + name: str + age: int + + table = pw.demo.generate_custom_stream( + value_functions, schema=InputSchema, nb_rows=5, input_rate=1000 + ) + + expected_md = """ + number | name | age + 1 | Person_0 | 20 + 2 | Person_1 | 21 + 3 | Person_2 | 22 + 4 | Person_3 | 23 + 5 | Person_4 | 24 + """ + expected = T(expected_md) + assert_table_equality_wo_index(table, expected) + + +def test_generate_range_stream(): + table = pw.demo.range_stream(nb_rows=5, input_rate=1000) + + expected_md = """ + value + 0.0 + 1.0 + 2.0 + 3.0 + 4.0 + """ + expected = T(expected_md) + expected = expected.select(value=pw.cast(float, pw.this.value)) + assert_table_equality_wo_index(table, expected) + + +def test_generate_range_stream_offset(): + table = pw.demo.range_stream(nb_rows=5, offset=10, input_rate=1000) + + expected_md = """ + value + 10.0 + 11.0 + 12.0 + 13.0 + 14.0 + """ + expected = T(expected_md) + expected = expected.select(value=pw.cast(float, pw.this.value)) + assert_table_equality_wo_index(table, expected) + + +def test_generate_range_stream_negative_offset(): + table = pw.demo.range_stream(nb_rows=5, offset=-10, input_rate=1000) + + expected_md = """ + value + -10.0 + -9.0 + -8.0 + -7.0 + -6.0 + """ + expected = T(expected_md) + expected = expected.select(value=pw.cast(float, pw.this.value)) + assert_table_equality_wo_index(table, expected) + + +def test_generate_noisy_linear_stream(): + table = pw.demo.noisy_linear_stream(nb_rows=5, input_rate=1000) + + expected_md = """ + x + 0.0 + 
1.0 + 2.0 + 3.0 + 4.0 + """ + expected = T(expected_md) + # We only compare the x values because the y are random and are "high precision" floats + expected = expected.select(x=pw.cast(float, pw.this.x)) + table = table.select(pw.this.x) + assert_table_equality_wo_index(table, expected) + + +def test_demo_replay(tmp_path: pathlib.Path): + data = """ + k | v + 1 | foo + 2 | bar + 3 | baz + """ + input_path = tmp_path / "input.csv" + write_csv(input_path, data) + + class InputSchema(pw.Schema): + k: int + v: str + + table = pw.demo.replay_csv( + str(input_path), + schema=InputSchema, + input_rate=1000, + ) + expected = T(data) + assert_table_equality_wo_index(table, expected) + + +def test_demo_replay_with_time(tmp_path: pathlib.Path): + data = """ + k | v | time + 1 | foo | 0 + 2 | bar | 2 + 3 | baz | 4 + """ + input_path = tmp_path / "input.csv" + write_csv(input_path, data) + + class InputSchema(pw.Schema): + k: int + v: str + time: int + + table = pw.demo.replay_csv_with_time( + str(input_path), + schema=InputSchema, + time_column="time", + unit="ns", + autocommit_ms=10, + ) + expected = T(data) + assert_table_equality_wo_index(table, expected) + + +def test_demo_replay_with_time_wrong_schema(): + class InputSchema(pw.Schema): + time: str + k: int + v: str + + with pytest.raises(ValueError, match="Invalid schema. Time columns must be int."): + pw.demo.replay_csv_with_time( + "/tmp", + schema=InputSchema, + time_column="time", + unit="ns", + autocommit_ms=10, + ) diff --git a/python/pathway/tests/test_expression_repr.py b/python/pathway/tests/test_expression_repr.py index 90e56ea4..86293cd1 100644 --- a/python/pathway/tests/test_expression_repr.py +++ b/python/pathway/tests/test_expression_repr.py @@ -124,8 +124,8 @@ def test_cast(): 1 | Alice | 10 """ ) - assert repr(pw.cast(int, t.pet)) == "pathway.cast(int, .pet)" - assert repr(pw.cast(float, t.pet)) == "pathway.cast(float, .pet)" + assert repr(pw.cast(int, t.pet)) == "pathway.cast(INT, .pet)" + assert repr(pw.cast(float, t.pet)) == "pathway.cast(FLOAT, .pet)" def test_declare_type(): @@ -136,11 +136,11 @@ def test_declare_type(): """ ) assert ( - repr(pw.declare_type(int, t.pet)) == "pathway.declare_type(int, .pet)" + repr(pw.declare_type(int, t.pet)) == "pathway.declare_type(INT, .pet)" ) assert ( repr(pw.declare_type(float, t.pet)) - == "pathway.declare_type(float, .pet)" + == "pathway.declare_type(FLOAT, .pet)" ) diff --git a/python/pathway/tests/test_flatten.py b/python/pathway/tests/test_flatten.py index b894ab7b..e5c929bf 100644 --- a/python/pathway/tests/test_flatten.py +++ b/python/pathway/tests/test_flatten.py @@ -2,12 +2,12 @@ from __future__ import annotations -from typing import List +from typing import Any, List import pandas as pd import pathway as pw -from pathway import Pointer, Table, this +from pathway import Table, this from pathway.tests.utils import T, assert_table_equality_wo_index @@ -93,5 +93,5 @@ def test_flatten_empty_lists(): assert_table_equality_wo_index( tab.flatten(this.col, origin_id=this.id), - Table.empty(col=str, origin_id=Pointer), + Table.empty(col=Any, origin_id=pw.Pointer), ) diff --git a/python/pathway/tests/test_sorting.py b/python/pathway/tests/test_sorting.py index ced83e87..3731824d 100644 --- a/python/pathway/tests/test_sorting.py +++ b/python/pathway/tests/test_sorting.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Optional - from pathway import ( ClassArg, Table, @@ -14,6 +12,7 @@ this, transformer, ) +from pathway.dt import Optional from 
pathway.stdlib.indexing.sorting import ( binsearch_oracle, build_sorted_index, @@ -517,11 +516,11 @@ def test_compute_max_key(): 0 | 3 """ ).with_columns(val=items.pointer_from(this.val)) - expected = expected.update_types(val=Optional[expected.schema["val"]]) # noqa + expected = expected.update_types(val=Optional(expected.schema["val"])) # noqa assert_table_equality( - expected, result, + expected, ) diff --git a/python/pathway/tests/test_types.py b/python/pathway/tests/test_types.py index 08801f12..b517df30 100644 --- a/python/pathway/tests/test_types.py +++ b/python/pathway/tests/test_types.py @@ -19,13 +19,9 @@ def test_date_time_naive_schema(): diff=pw.this.t1 - pw.this.t2 ) schema = table_with_datetimes.schema.as_dict() - assert ( - repr(schema["t1"]) == "" - ) - assert ( - repr(schema["t2"]) == "" - ) - assert repr(schema["diff"]) == "" + assert repr(schema["t1"]) == "DATE_TIME_NAIVE" + assert repr(schema["t2"]) == "DATE_TIME_NAIVE" + assert repr(schema["diff"]) == "DURATION" def test_date_time_utc_schema(): @@ -43,10 +39,6 @@ def test_date_time_utc_schema(): diff=pw.this.t1 - pw.this.t2 ) schema = table_with_datetimes.schema.as_dict() - assert ( - repr(schema["t1"]) == "" - ) - assert ( - repr(schema["t2"]) == "" - ) - assert repr(schema["diff"]) == "" + assert repr(schema["t1"]) == "DATE_TIME_UTC" + assert repr(schema["t2"]) == "DATE_TIME_UTC" + assert repr(schema["diff"]) == "DURATION" diff --git a/python/pathway/tests/utils.py b/python/pathway/tests/utils.py index 42a2c5a6..3799ebd9 100644 --- a/python/pathway/tests/utils.py +++ b/python/pathway/tests/utils.py @@ -68,6 +68,25 @@ def __call__(self): return len(result) == self.n_lines +def expect_csv_checker(expected, output_path, usecols=("k", "v"), index_col=("k")): + expected = ( + pw.debug._markdown_to_pandas(expected) + .set_index(index_col, drop=False) + .sort_index() + ) + + def checker(): + result = ( + pd.read_csv(output_path, usecols=[*usecols, *index_col]) + .convert_dtypes() + .set_index(index_col, drop=False) + .sort_index() + ) + return expected.equals(result) + + return checker + + @dataclass(frozen=True) class TestDataSource(datasource.DataSource): __test__ = False diff --git a/src/engine/dataflow/operators/time_column.rs b/src/engine/dataflow/operators/time_column.rs index f25bd363..6c2bf7c5 100644 --- a/src/engine/dataflow/operators/time_column.rs +++ b/src/engine/dataflow/operators/time_column.rs @@ -17,6 +17,7 @@ use std::cmp::{max, Ord}; use std::rc::Rc; use timely::communication::Push; use timely::dataflow::channels::pact::Pipeline; +use timely::dataflow::operators::generic::builder_rc::OperatorBuilder; use timely::dataflow::operators::generic::OutputHandle; use timely::dataflow::operators::CapabilityRef; use timely::dataflow::operators::Operator; @@ -28,6 +29,7 @@ use timely::progress::{PathSummary, Timestamp}; type KeyValArr = Arranged::Timestamp, R>>>>>; + #[derive(Debug, Default, Hash, Clone, Eq, Ord, PartialEq, PartialOrd, Serialize, Deserialize)] pub struct TimeKey { pub time: CT, @@ -214,7 +216,6 @@ where CT: ExchangeData + Abelian, R: ExchangeData + Abelian, { - // For each (data, Alt(time), diff) we add a (data, Neu(time), -diff). 
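// Sketch of the data flow, as far as the surrounding code shows: prepend_time_to_key
// re-keys every record by a TimeKey whose `time` field holds the record's release
// threshold, so the arrangement is ordered by release time; postpone_core then emits
// a buffered record once the maximum value of the "current time" column seen so far
// reaches that threshold. The forget and cut_off operators added below reuse the
// same pair of extractors.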
fn prepend_time_to_key( &self, time_column_extractor: &'static impl Fn(&V) -> CT, @@ -242,8 +243,8 @@ pub trait TimeColumnBuffer< fn postpone( &self, scope: G, - threshold_time_column_extractor: &'static impl Fn(&V) -> CT, - current_time_column_extractor: &'static impl Fn(&V) -> CT, + threshold_time_extractor: &'static impl Fn(&V) -> CT, + current_time_extractor: &'static impl Fn(&V) -> CT, ) -> Collection where G::Timestamp: @@ -263,8 +264,8 @@ where fn postpone( &self, mut scope: G, - threshold_time_column_extractor: &'static impl Fn(&V) -> CT, - current_time_column_extractor: &'static impl Fn(&V) -> CT, + threshold_time_extractor: &'static impl Fn(&V) -> CT, + current_time_extractor: &'static impl Fn(&V) -> CT, ) -> Collection where G::Timestamp: @@ -278,9 +279,9 @@ where let ret_in_buffer_timespace_col = postpone_core( self.enter(inner) .concat(&retractions.negate()) - .prepend_time_to_key(threshold_time_column_extractor) + .prepend_time_to_key(threshold_time_extractor) .arrange(), - current_time_column_extractor, + current_time_extractor, ); retractions.set(&ret_in_buffer_timespace_col); ret_in_buffer_timespace_col.leave() @@ -288,15 +289,14 @@ where } } -fn push_key_values_to_output>, P>( +fn push_key_values_to_output( wrapper: &mut CursorStorageWrapper, output: &mut OutputHandle<'_, C::Time, ((K, C::Val), C::Time, C::R), P>, capability: &CapabilityRef<'_, C::Time>, k: &K, time: &Option, - total_weight: bool, + weight_fun: &'static impl Fn(&mut CursorStorageWrapper) -> Option, ) where - T: Timestamp + Lattice + Clone, K: Ord + Clone + Data + 'static, C::Val: Ord + Clone + Data + 'static, C::Time: Lattice + timely::progress::Timestamp, @@ -311,11 +311,8 @@ fn push_key_values_to_output>, P>( >, { while wrapper.cursor.val_valid(wrapper.storage) { - let weight = if total_weight { - key_val_total_weight(wrapper) - } else { - key_val_total_original_weight(wrapper) - }; + let weight = weight_fun(wrapper); + //rust can't do if let Some(weight) = weight && !weight.is_zero() //while we can do if let ... 
{ if !weight.is_zero()} that introduces if-s //that can be collapsed, which is bad as well @@ -347,8 +344,7 @@ pub fn postpone_core< R: ExchangeData + Abelian + Diff, >( mut input_arrangement: KeyValArr, V, R>, - current_time_column_extractor: &'static impl Fn(&V) -> CT, - // ) -> KeyValArr + current_time_extractor: &'static impl Fn(&V) -> CT, ) -> Collection where G: Scope, @@ -363,7 +359,6 @@ where .stream .unary(Pipeline, "buffer", move |_capability, _operator_info| { let mut input_buffer = Vec::new(); - let mut source_trace = input_arrangement.trace.clone(); let mut max_column_time: Option = None; let mut last_arrangement_key: Option> = None; move |input, output| { @@ -410,7 +405,7 @@ where }; }, ); - let current_column_time = current_time_column_extractor( + let current_column_time = current_time_extractor( batch_wrapper.cursor.val(batch_wrapper.storage), ); @@ -438,7 +433,7 @@ where &capability, k, &max_curr_time, - false, + &key_val_total_original_weight, ); local_last_arrangement_key = local_last_arrangement_key.map(|timekey: TimeKey| { @@ -482,7 +477,7 @@ where &capability, k, &max_curr_time, - true, + &key_val_total_weight, ); input_wrapper.cursor.step_key(input_wrapper.storage); } @@ -496,13 +491,174 @@ where last_arrangement_key = local_last_arrangement_key; - source_trace.advance_upper(&mut upper_limit); + input_arrangement.trace.advance_upper(&mut upper_limit); - source_trace.set_logical_compaction(upper_limit.borrow()); - source_trace.set_physical_compaction(upper_limit.borrow()); + input_arrangement + .trace + .set_logical_compaction(upper_limit.borrow()); + input_arrangement + .trace + .set_physical_compaction(upper_limit.borrow()); }); } }) }; Collection { inner: stream } } + +pub trait TimeColumnForget< + G: Scope, + K: ExchangeData + Abelian + Shard, + V: ExchangeData + Abelian, + R: ExchangeData + Abelian, + CT: Abelian + Ord + ExchangeData, +> +{ + fn forget( + &self, + scope: G, + threshold_time_extractor: &'static impl Fn(&V) -> CT, + current_time_extractor: &'static impl Fn(&V) -> CT, + ) -> Collection + where + G::Timestamp: + Timestamp + Default + PathSummary + Epsilon; +} + +impl TimeColumnForget for Collection +where + G: Scope + MaybeTotalScope, + G::Timestamp: Lattice, + K: ExchangeData + Abelian + Shard, + V: ExchangeData + Abelian, + CT: ExchangeData + Abelian, + R: ExchangeData + Abelian, + TimeKey: Hashable, +{ + fn forget( + &self, + scope: G, + threshold_time_extractor: &'static impl Fn(&V) -> CT, + current_time_extractor: &'static impl Fn(&V) -> CT, + ) -> Collection + where + G::Timestamp: + Timestamp + Default + PathSummary + Epsilon, + { + self.concat( + &self + .postpone(scope, threshold_time_extractor, current_time_extractor) + .negate(), + ) + } +} +pub trait TimeColumnCutoff< + G: Scope, + K: ExchangeData + Abelian + Shard, + V: ExchangeData + Abelian, + R: ExchangeData + Abelian, + CT: Abelian + Ord + ExchangeData, +> +{ + fn cut_off( + &self, + scope: G, + threshold_time_extractor: &'static impl Fn(&V) -> CT, + current_column_extractor: &'static impl Fn(&V) -> CT, + ) -> (Collection, Collection) + where + G::Timestamp: + Timestamp + Default + PathSummary + Epsilon; +} + +impl TimeColumnCutoff for Collection +where + G: Scope + MaybeTotalScope, + G::Timestamp: Lattice, + K: ExchangeData + Abelian + Shard, + V: ExchangeData + Abelian, + CT: ExchangeData + Abelian, + R: ExchangeData + Abelian, + TimeKey: Hashable, +{ + fn cut_off( + &self, + _scope: G, + threshold_time_extractor: &'static impl Fn(&V) -> CT, + current_time_extractor: &'static impl 
Fn(&V) -> CT, + ) -> (Collection, Collection) + where + G::Timestamp: + Timestamp + Default + PathSummary + Epsilon, + { + ignore_late(self, threshold_time_extractor, current_time_extractor) + } +} + +#[allow(clippy::too_many_lines)] +pub fn ignore_late< + G: ScopeParent, + CT: ExchangeData, + K: ExchangeData + Shard, + V, + R: ExchangeData + Abelian + Diff, +>( + input_collection: &Collection, + threshold_time_extractor: &'static impl Fn(&V) -> CT, + current_time_extractor: &'static impl Fn(&V) -> CT, +) -> (Collection, Collection) +where + G: Scope, + G::Timestamp: Lattice + Ord, + CT: ExchangeData + Abelian, + K: ExchangeData + Abelian, + V: ExchangeData + Abelian, +{ + let mut builder = + OperatorBuilder::new("ignore_late".to_owned(), input_collection.inner.scope()); + + let mut input = builder.new_input(&input_collection.inner, Pipeline); + let (mut output, stream) = builder.new_output(); + let (mut late_output, late_stream) = builder.new_output(); + + builder.build(move |_| { + let mut input_buffer = Vec::new(); + let mut max_column_time: Option = None; + move |_frontiers| { + let mut output_handle = output.activate(); + let mut late_output_handle = late_output.activate(); + input.for_each(|capability, batch| { + batch.swap(&mut input_buffer); + let mut max_curr_time = None; + + for ((_key, val), time, _weight) in &input_buffer { + let candidate_column_time = current_time_extractor(val); + if max_column_time.is_none() + || &candidate_column_time > max_column_time.as_ref().unwrap() + { + max_column_time = Some(candidate_column_time.clone()); + } + //for complex handling this will need some changes, to accommodate anitchains + if max_curr_time.is_none() || max_curr_time.as_ref().unwrap() < time { + max_curr_time = Some(time.clone()); + }; + } + + for entry in input_buffer.drain(..) { + let ((_key, val), _time, _weight) = &entry; + let threshold = threshold_time_extractor(val); + if max_column_time.as_ref().unwrap() < &threshold { + output_handle.session(&capability).give(entry); + } else { + late_output_handle.session(&capability).give(entry); + } + } + }); + } + }); + + ( + Collection { inner: stream }, + Collection { inner: late_stream }, + ) +} diff --git a/src/engine/time.rs b/src/engine/time.rs index 5c3894ee..f7733930 100644 --- a/src/engine/time.rs +++ b/src/engine/time.rs @@ -74,6 +74,17 @@ pub trait DateTime { .unwrap() .timestamp_nanos() } + + fn sanitize_format_string(format: &str) -> Result { + let format = format.replace(".%f", "%.f"); + if format.matches("%f").count() == format.matches("%%f").count() { + Ok(format) + } else { + Err(Error::ParseError(format!( + "Cannot use format: {format}. Using %f without the leading dot is not supported." 
+ ))) + } + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] @@ -87,17 +98,20 @@ impl DateTimeNaive { } pub fn strptime(date_string: &str, format: &str) -> Result { - if let Ok(datetime) = chrono::NaiveDateTime::parse_from_str(date_string, format) { + let format = Self::sanitize_format_string(format)?; + if let Ok(datetime) = chrono::NaiveDateTime::parse_from_str(date_string, &format) { Ok(Self { timestamp: datetime.timestamp_nanos(), }) - } else if let Ok(date) = chrono::NaiveDate::parse_from_str(date_string, format) { + } else if let Ok(date) = chrono::NaiveDate::parse_from_str(date_string, &format) { let datetime = date.and_hms_opt(0, 0, 0).unwrap(); Ok(Self { timestamp: datetime.timestamp_nanos(), }) - } else if let Ok(time) = chrono::NaiveTime::parse_from_str(date_string, format) { - let datetime = chrono::Utc::now().date_naive().and_time(time); + } else if let Ok(time) = chrono::NaiveTime::parse_from_str(date_string, &format) { + let datetime = chrono::NaiveDate::from_ymd_opt(1900, 1, 1) + .unwrap() + .and_time(time); Ok(Self { timestamp: datetime.timestamp_nanos(), }) @@ -202,7 +216,7 @@ impl Sub for DateTimeNaive { impl Display for DateTimeNaive { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "{}", self.strftime("%Y-%m-%dT%H:%M:%S%.f")) + write!(fmt, "{}", self.strftime("%Y-%m-%dT%H:%M:%S%.9f")) } } @@ -217,7 +231,8 @@ impl DateTimeUtc { } pub fn strptime(date_string: &str, format: &str) -> Result { - match chrono::DateTime::parse_from_str(date_string, format) { + let format = Self::sanitize_format_string(format)?; + match chrono::DateTime::parse_from_str(date_string, &format) { Ok(datetime) => Ok(Self { timestamp: datetime.timestamp_nanos(), }), @@ -296,7 +311,7 @@ impl Sub for DateTimeUtc { impl Display for DateTimeUtc { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "{}", self.strftime("%Y-%m-%dT%H:%M:%S%.f%z")) + write!(fmt, "{}", self.strftime("%Y-%m-%dT%H:%M:%S%.9f%z")) } } diff --git a/src/python_api.rs b/src/python_api.rs index f8a83462..0a623f43 100644 --- a/src/python_api.rs +++ b/src/python_api.rs @@ -252,16 +252,8 @@ impl<'source> FromPyObject<'source> for Value { } else if let Ok(b) = ob.extract::<&PyBool>() { // Fallback checks from now on Ok(Value::Bool(b.is_true())) - } else if let Ok(i) = ob.extract::() { - Ok(Value::Int(i)) - } else if let Ok(f) = ob.extract::() { - // XXX: bigints go here - Ok(Value::Float(f.into())) - } else if let Ok(k) = ob.extract::() { - Ok(Value::Pointer(k)) - } else if let Ok(s) = ob.downcast::() { - Ok(s.to_str()?.into()) } else if let Ok(array) = ob.extract::>() { + // single-element arrays convert to scalars, so we need to check for arrays first Ok(Value::from(array.as_array().to_owned())) } else if let Ok(array) = ob.extract::>() { Ok(Value::from(array.as_array().mapv(i64::from))) @@ -271,6 +263,15 @@ impl<'source> FromPyObject<'source> for Value { Ok(Value::from(array.as_array().to_owned())) } else if let Ok(array) = ob.extract::>() { Ok(Value::from(array.as_array().mapv(f64::from))) + } else if let Ok(i) = ob.extract::() { + Ok(Value::Int(i)) + } else if let Ok(f) = ob.extract::() { + // XXX: bigints go here + Ok(Value::Float(f.into())) + } else if let Ok(k) = ob.extract::() { + Ok(Value::Pointer(k)) + } else if let Ok(s) = ob.downcast::() { + Ok(s.to_str()?.into()) } else if let Ok(t) = ob.extract::>() { Ok(Value::from(t.as_slice())) } else { diff --git a/tests/test_json_output.rs b/tests/test_json_output.rs index 99e47a88..99837d85 
100644 --- a/tests/test_json_output.rs +++ b/tests/test_json_output.rs @@ -133,7 +133,7 @@ fn test_json_date_time_naive_serialization() -> eyre::Result<()> { assert_eq!(result.payloads.len(), 1); assert_eq!( from_utf8(&result.payloads[0])?, - r#"{"a":"2023-05-15T10:51:00","diff":1,"time":0}"# + r#"{"a":"2023-05-15T10:51:00.000000000","diff":1,"time":0}"# ); assert_eq!(result.values.len(), 0); @@ -153,7 +153,7 @@ fn test_json_date_time_utc_serialization() -> eyre::Result<()> { assert_eq!(result.payloads.len(), 1); assert_eq!( from_utf8(&result.payloads[0])?, - r#"{"a":"2023-05-15T10:51:00+0000","diff":1,"time":0}"# + r#"{"a":"2023-05-15T10:51:00.000000000+0000","diff":1,"time":0}"# ); assert_eq!(result.values.len(), 0); diff --git a/tests/test_psql_output.rs b/tests/test_psql_output.rs index b65c367c..a2e6bfbc 100644 --- a/tests/test_psql_output.rs +++ b/tests/test_psql_output.rs @@ -209,7 +209,7 @@ fn test_psql_format_date_time_naive() -> eyre::Result<()> { { assert_eq!( Value::DateTimeNaive(DateTimeNaive::new(1684147860000000000)).to_postgres_output(), - "2023-05-15T10:51:00".to_string() + "2023-05-15T10:51:00.000000000".to_string() ); let result = formatter.format( @@ -249,7 +249,7 @@ fn test_psql_format_date_time_naive() -> eyre::Result<()> { { assert_eq!( Value::DateTimeNaive(DateTimeNaive::new(0)).to_postgres_output(), - "1970-01-01T00:00:00".to_string() + "1970-01-01T00:00:00.000000000".to_string() ); let result = formatter.format( @@ -276,7 +276,7 @@ fn test_psql_format_date_time_utc() -> eyre::Result<()> { { assert_eq!( Value::DateTimeUtc(DateTimeUtc::new(1684147860000000000)).to_postgres_output(), - "2023-05-15T10:51:00+0000".to_string() + "2023-05-15T10:51:00.000000000+0000".to_string() ); let result = formatter.format( @@ -312,7 +312,7 @@ fn test_psql_format_date_time_utc() -> eyre::Result<()> { { assert_eq!( Value::DateTimeUtc(DateTimeUtc::new(0)).to_postgres_output(), - "1970-01-01T00:00:00+0000".to_string() + "1970-01-01T00:00:00.000000000+0000".to_string() ); let result = formatter.format( diff --git a/tests/test_postpone.rs b/tests/test_time_column.rs similarity index 77% rename from tests/test_postpone.rs rename to tests/test_time_column.rs index 8e3fe5ec..12f4a198 100644 --- a/tests/test_postpone.rs +++ b/tests/test_time_column.rs @@ -5,7 +5,8 @@ use differential_dataflow::operators::arrange::ArrangeByKey; use operator_test_utils::run_test; use pathway_engine::engine::dataflow::operators::time_column::{ - postpone_core, SelfCompactionTime, TimeColumnBuffer, TimeKey, + postpone_core, SelfCompactionTime, TimeColumnBuffer, TimeColumnCutoff, TimeColumnForget, + TimeKey, }; #[test] @@ -408,3 +409,122 @@ fn test_wrapper_two_times_forward_late() { .arrange_by_key() }); } + +#[test] +fn test_forget() { + let input = vec![ + vec![((1, (100, 11)), 0, 1), ((2, (100, 22)), 0, 1)], + vec![((3, (200, 33)), 1, 1), ((4, (200, 44)), 1, 1)], + vec![((5, (300, 55)), 2, 1), ((6, (300, 66)), 2, 1)], + ]; + let expected = vec![ + vec![((1, (100, 11)), 0, 1), ((2, (100, 22)), 0, 1)], + vec![ + ((1, (100, 11)), 1, -1), + ((2, (100, 22)), 1, -1), + ((3, (200, 33)), 1, 1), + ((4, (200, 44)), 1, 1), + ], + vec![ + ((3, (200, 33)), 2, -1), + ((4, (200, 44)), 2, -1), + ((5, (300, 55)), 2, 1), + ((6, (300, 66)), 2, 1), + ], + ]; + + run_test(input, expected, |coll| { + coll.forget(coll.scope(), &|(t, _d)| *t + 1, &|(t, _d)| *t) + .arrange_by_key() + }); +} + +#[test] +fn test_cutoff() { + let input = vec![ + vec![((1, (100, 11)), 0, 1), ((2, (100, 22)), 0, 1)], + vec![((3, (100, 33)), 1, 1), ((4, 
(200, 44)), 1, 1)], + vec![((5, (300, 55)), 2, 1), ((6, (300, 66)), 2, 1)], + ]; + let expected = vec![ + vec![((1, (100, 11)), 0, 1), ((2, (100, 22)), 0, 1)], + vec![((4, (200, 44)), 1, 1)], + vec![((5, (300, 55)), 2, 1), ((6, (300, 66)), 2, 1)], + ]; + + let expected_late = vec![vec![((3, (100, 33)), 1, 1)]]; + + run_test(input.clone(), expected, |coll| { + coll.cut_off(coll.scope(), &|(t, _d)| *t + 1, &|(t, _d)| *t) + .0 + .arrange_by_key() + }); + + run_test(input, expected_late, |coll| { + coll.cut_off(coll.scope(), &|(t, _d)| *t + 1, &|(t, _d)| *t) + .1 + .arrange_by_key() + }); +} + +#[test] +fn test_cutoff_late() { + let input = vec![ + vec![((1, (100, 11)), 0, 1), ((2, (200, 22)), 0, 1)], + vec![((3, (200, 33)), 1, 1), ((4, (100, 44)), 1, 1)], + vec![((5, (300, 55)), 2, 1), ((6, (300, 66)), 2, 1)], + ]; + let expected = vec![ + vec![((2, (200, 22)), 0, 1)], + vec![((3, (200, 33)), 1, 1)], + vec![((5, (300, 55)), 2, 1), ((6, (300, 66)), 2, 1)], + ]; + let expected_late = vec![vec![((1, (100, 11)), 0, 1)], vec![((4, (100, 44)), 1, 1)]]; + + run_test(input.clone(), expected, |coll| { + coll.cut_off(coll.scope(), &|(t, _d)| *t + 1, &|(t, _d)| *t) + .0 + .arrange_by_key() + }); + + run_test(input, expected_late, |coll| { + coll.cut_off(coll.scope(), &|(t, _d)| *t + 1, &|(t, _d)| *t) + .1 + .arrange_by_key() + }); +} + +#[test] +fn test_cutoff_late_multiple() { + let input = vec![ + vec![((1, (100, 11)), 0, 1), ((2, (200, 22)), 0, 1)], + vec![((3, (200, 33)), 1, 1), ((4, (100, 44)), 1, 1)], + vec![ + ((5, (300, 55)), 2, 1), + ((6, (100, 66)), 2, 1), + ((7, (-100, 77)), 2, 1), + ], + ]; + let expected = vec![ + vec![((2, (200, 22)), 0, 1)], + vec![((3, (200, 33)), 1, 1)], + vec![((5, (300, 55)), 2, 1)], + ]; + + let expected_late = vec![ + vec![((1, (100, 11)), 0, 1)], + vec![((4, (100, 44)), 1, 1)], + vec![((6, (100, 66)), 2, 1), ((7, (-100, 77)), 2, 1)], + ]; + run_test(input.clone(), expected, |coll| { + coll.cut_off(coll.scope(), &|(t, _d)| *t + 1, &|(t, _d)| *t) + .0 + .arrange_by_key() + }); + + run_test(input, expected_late, |coll| { + coll.cut_off(coll.scope(), &|(t, _d)| *t + 1, &|(t, _d)| *t) + .1 + .arrange_by_key() + }); +}
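A short usage sketch of the new keep_queries flag from the rest_connector change above; the host, port, and schema are illustrative, and the name response_writer for the returned callable is an assumption, not taken from the diff:

    import pathway as pw

    class QuerySchema(pw.Schema):
        query: str

    # With keep_queries=True, answered queries stay in the resulting table
    # instead of being retracted once their responses are served.
    queries, response_writer = pw.io.http.rest_connector(
        host="127.0.0.1",
        port=8080,
        schema=QuerySchema,
        keep_queries=True,
    )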