From 17589bbc5427233dc8c6219b613ae033cfa30478 Mon Sep 17 00:00:00 2001
From: Nathanael Maytan
Date: Tue, 6 Aug 2024 07:46:53 -0400
Subject: [PATCH] Accept Python dict from client "write_dataframe" and
 TableAdapter (#771)

* Accept dict for client 'write_dataframe'

* Add test for writing dataframe from dict

* Better generic dict name in TableStructure

* Add support for dict to TableAdapter

* Simplify generated_minimal example

* Use newer TableAdapter name, rather than alias

* Update changelog

* Rename from_dict methods

* Remove commented ignore for Pandas warning from new test
---
 CHANGELOG.md                        |  7 +++++++
 tiled/_tests/test_writing.py        | 25 ++++++++++++++++++++++
 tiled/adapters/table.py             | 32 +++++++++++++++++++++++++++++
 tiled/client/container.py           |  2 ++
 tiled/examples/generated_minimal.py | 17 +++++++--------
 tiled/serialization/table.py        |  5 ++++-
 tiled/structures/table.py           |  9 ++++++++
 7 files changed, 86 insertions(+), 11 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d31c4d342..7ba645001 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,13 @@ Write the date in place of the "Unreleased" in the case a new version is released.
 
 ## Unreleased
 
+### Added
+- Add a method to `TableAdapter` which accepts a Python dictionary.
+
+### Changed
+- Make `tiled.client` accept a Python dictionary passed to `write_dataframe()`.
+- The `generated_minimal` example no longer requires pandas and instead uses a Python dict.
+
 ### Fixed
 - A bug in `Context.__getstate__` caused pickling to fail if applied twice.
 
diff --git a/tiled/_tests/test_writing.py b/tiled/_tests/test_writing.py
index 66cd8516b..61a0bf47f 100644
--- a/tiled/_tests/test_writing.py
+++ b/tiled/_tests/test_writing.py
@@ -174,6 +174,31 @@ def test_write_dataframe_partitioned(tree):
     assert result.specs == specs
 
 
+def test_write_dataframe_dict(tree):
+    with Context.from_app(
+        build_app(tree, validation_registry=validation_registry)
+    ) as context:
+        client = from_context(context)
+
+        data = {f"Column{i}": (1 + i) * numpy.ones(5) for i in range(5)}
+        df = pandas.DataFrame(data)
+        metadata = {"scan_id": 1, "method": "A"}
+        specs = [Spec("SomeSpec")]
+
+        with record_history() as history:
+            client.write_dataframe(data, metadata=metadata, specs=specs)
+        # one request for metadata, one for data
+        assert len(history.requests) == 1 + 1
+
+        results = client.search(Key("scan_id") == 1)
+        result = results.values().first()
+        result_dataframe = result.read()
+
+        pandas.testing.assert_frame_equal(result_dataframe, df)
+        assert result.metadata == metadata
+        assert result.specs == specs
+
+
 @pytest.mark.parametrize(
     "coo",
     [
diff --git a/tiled/adapters/table.py b/tiled/adapters/table.py
index b169ca05d..4b7d3ff0c 100644
--- a/tiled/adapters/table.py
+++ b/tiled/adapters/table.py
@@ -57,6 +57,38 @@ def from_pandas(
             ddf, metadata=metadata, specs=specs, access_policy=access_policy
         )
 
+    @classmethod
+    def from_dict(
+        cls,
+        *args: Any,
+        metadata: Optional[JSON] = None,
+        specs: Optional[List[Spec]] = None,
+        access_policy: Optional[AccessPolicy] = None,
+        npartitions: int = 1,
+        **kwargs: Any,
+    ) -> "TableAdapter":
+        """Construct a TableAdapter from a Python dictionary of columns.
+
+        Parameters
+        ----------
+        args : passed through to `dask.dataframe.from_dict` (typically the data dict)
+        metadata : optional metadata to attach to this adapter
+        specs : optional list of Specs; defaults to `[Spec("dataframe")]`
+        access_policy : optional AccessPolicy for this adapter
+        npartitions : number of partitions for the Dask DataFrame
+        kwargs : passed through to `dask.dataframe.from_dict`
+
+        Returns
+        -------
+        TableAdapter
+        """
+        ddf = dask.dataframe.from_dict(*args, npartitions=npartitions, **kwargs)
+        if specs is None:
+            specs = [Spec("dataframe")]
+        return cls.from_dask_dataframe(
+            ddf, metadata=metadata, specs=specs, access_policy=access_policy
+        )
+
     @classmethod
     def from_dask_dataframe(
         cls,
diff --git a/tiled/client/container.py b/tiled/client/container.py
index 53719dfb8..7a87ce507 100644
--- a/tiled/client/container.py
+++ b/tiled/client/container.py
@@ -942,6 +942,8 @@ def write_dataframe(
 
         if isinstance(dataframe, dask.dataframe.DataFrame):
             structure = TableStructure.from_dask_dataframe(dataframe)
+        elif isinstance(dataframe, dict):
+            structure = TableStructure.from_dict(dataframe)
         else:
             structure = TableStructure.from_pandas(dataframe)
         client = self.new(
diff --git a/tiled/examples/generated_minimal.py b/tiled/examples/generated_minimal.py
index bee425dce..38774f1a5 100644
--- a/tiled/examples/generated_minimal.py
+++ b/tiled/examples/generated_minimal.py
@@ -1,9 +1,8 @@
 import numpy
-import pandas
 import xarray
 
 from tiled.adapters.array import ArrayAdapter
-from tiled.adapters.dataframe import DataFrameAdapter
+from tiled.adapters.dataframe import TableAdapter
 from tiled.adapters.mapping import MapAdapter
 from tiled.adapters.xarray import DatasetAdapter
 
@@ -11,14 +10,12 @@
     {
         "A": ArrayAdapter.from_array(numpy.ones((100, 100))),
         "B": ArrayAdapter.from_array(numpy.ones((100, 100, 100))),
-        "C": DataFrameAdapter.from_pandas(
-            pandas.DataFrame(
-                {
-                    "x": 1 * numpy.ones(100),
-                    "y": 2 * numpy.ones(100),
-                    "z": 3 * numpy.ones(100),
-                }
-            ),
+        "C": TableAdapter.from_dict(
+            {
+                "x": 1 * numpy.ones(100),
+                "y": 2 * numpy.ones(100),
+                "z": 3 * numpy.ones(100),
+            },
             npartitions=3,
         ),
         "D": DatasetAdapter.from_dataset(
diff --git a/tiled/serialization/table.py b/tiled/serialization/table.py
index 999ce1aa7..b339e6b40 100644
--- a/tiled/serialization/table.py
+++ b/tiled/serialization/table.py
@@ -10,7 +10,10 @@
 def serialize_arrow(df, metadata, preserve_index=True):
     import pyarrow
 
-    table = pyarrow.Table.from_pandas(df, preserve_index=preserve_index)
+    if isinstance(df, dict):
+        table = pyarrow.Table.from_pydict(df)
+    else:
+        table = pyarrow.Table.from_pandas(df, preserve_index=preserve_index)
     sink = pyarrow.BufferOutputStream()
     with pyarrow.ipc.new_file(sink, table.schema) as writer:
         writer.write_table(table)
diff --git a/tiled/structures/table.py b/tiled/structures/table.py
index 81a35d5c4..8cf6de0f1 100644
--- a/tiled/structures/table.py
+++ b/tiled/structures/table.py
@@ -47,6 +47,15 @@ def from_pandas(cls, df):
         data_uri = B64_ENCODED_PREFIX + schema_b64
         return cls(arrow_schema=data_uri, npartitions=1, columns=list(df.columns))
 
+    @classmethod
+    def from_dict(cls, d):
+        import pyarrow
+
+        schema_bytes = pyarrow.Table.from_pydict(d).schema.serialize()
+        schema_b64 = base64.b64encode(schema_bytes).decode("utf-8")
+        data_uri = B64_ENCODED_PREFIX + schema_b64
+        return cls(arrow_schema=data_uri, npartitions=1, columns=list(d.keys()))
+
     @property
     def arrow_schema_decoded(self):
         import pyarrow
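
Usage sketch (illustrative, not part of the patch itself): with this change applied, a client can hand write_dataframe() a plain dict of columns instead of constructing a pandas.DataFrame first. The server URI and API key below are placeholders and assume a running, writable Tiled server.

    import numpy
    from tiled.client import from_uri

    # Placeholder URI/key; assumes a writable Tiled server is running.
    client = from_uri("http://localhost:8000", api_key="secret")

    # A dict of column-name -> array is now accepted directly; the client
    # builds the structure via TableStructure.from_dict() (see the
    # container.py hunk above) before uploading.
    client.write_dataframe(
        {"x": numpy.ones(10), "y": 2 * numpy.ones(10)},
        metadata={"scan_id": 1},
    )

Reading the node back with result.read() still returns a pandas.DataFrame, as the new test_write_dataframe_dict demonstrates; server-side, TableAdapter.from_dict() builds the same structure, as in the generated_minimal example above.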