Skip to content

Commit

Permalink
move stopes utils
Browse files Browse the repository at this point in the history
  • Loading branch information
zyaoj committed Dec 1, 2024
1 parent 7af70d6 commit ffd751f
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/fairseq2/data/parquet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
concat_table,
get_parquet_dataset_metadata,
get_row_group_level_metadata,
pyarrow_table_to_torch_dict,
)

__all__ = [
Expand All @@ -40,4 +41,5 @@
"concat_table",
"get_parquet_dataset_metadata",
"get_row_group_level_metadata",
"pyarrow_table_to_torch_dict",
]
39 changes: 39 additions & 0 deletions src/fairseq2/data/parquet/arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Union

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc


def is_list_like(arr):
    """Return ``True`` when ``arr`` has an Arrow list or large_list type."""
    arrow_type = arr.type
    return pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type)


def _fix_list_offset(arr: pa.Array) -> pa.Array:
    """
    Recursively rebuild a list-like array so that its offsets start at 0,
    which makes ``arr.offsets`` directly usable downstream.

    Non-list arrays and arrays whose offset is already 0 are returned as-is.
    """
    if not is_list_like(arr) or arr.offset == 0:
        return arr

    # Flatten one level and recurse so nested lists are normalized too.
    rebased_values = _fix_list_offset(pc.list_flatten(arr))
    rebased_offsets = pc.subtract(arr.offsets, arr.offsets[0])

    if pa.types.is_large_list(arr.type):
        return pa.LargeListArray.from_arrays(rebased_offsets, rebased_values)
    return pa.ListArray.from_arrays(rebased_offsets, rebased_values)


def pyarrow_column_to_array(arg: Union[pa.ChunkedArray, pa.Array]) -> pa.Array:
    """Materialize a column as a single ``pa.Array`` with zero-based offsets.

    A single-chunk ``ChunkedArray`` reuses its chunk; multi-chunk inputs are
    combined first.  See https://github.com/apache/arrow/issues/37318
    """
    if not isinstance(arg, pa.Array):
        arg = arg.chunk(0) if arg.num_chunks == 1 else arg.combine_chunks()
    return _fix_list_offset(arg)
65 changes: 48 additions & 17 deletions src/fairseq2/data/parquet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
from numpy.typing import NDArray
from pyarrow.dataset import get_partition_keys # requires pyarrow >= 13

from fairseq2.data.parquet.arrow import pyarrow_column_to_array
from fairseq2.logging import get_log_writer
from fairseq2.data.data_pipeline import DataPipeline, DataPipelineBuilder, read_sequence

logger = get_log_writer(__name__)


@contextmanager
def pyarrow_cpu(nb_cpu: int) -> Generator[None, None, None]:
Expand All @@ -47,8 +51,9 @@ def torch_random_seed(seed: Optional[int] = None) -> Generator[None, None, None]
BatchOutputType = Union[pa.Table, pd.DataFrame, NestedDict]



def from_pyarrow_to_torch_tensor(
arr: Union[pa.Array, pa.ChunkedArray], strict: bool = True
arr: Union[pa.Array, pa.ChunkedArray], strict: bool = False
) -> NestedDictValue:
"""
struct_array = pa.Array.from_pandas([{"x": 4, "y": "RR"}] * 10)
Expand All @@ -60,12 +65,14 @@ def from_pyarrow_to_torch_tensor(
if arr.null_count != 0:
raise ValueError("to torch conversion does not support null values")

if isinstance(arr, pa.ChunkedArray):
arr = arr.chunks[0] if arr.num_chunks == 1 else arr.combine_chunks()
arr = pyarrow_column_to_array(arr)

arr_type = arr.type
if pa.types.is_primitive(arr_type):
return torch.from_numpy(arr.to_numpy(zero_copy_only=True))
try:
return torch.from_numpy(arr.to_numpy(zero_copy_only=True))
except Exception:
pass

try:
return torch.from_numpy(arr.to_numpy(zero_copy_only=True))
Expand All @@ -76,19 +83,29 @@ def from_pyarrow_to_torch_tensor(
return from_pyarrow_to_torch_tensor(arr.dictionary_decode())

if pa.types.is_string(arr_type):
return list(map(str, arr.to_pandas()))

if (
pa.types.is_list(arr_type) or pa.types.is_large_list(arr_type)
) and pa.types.is_primitive(arr_type.value_type):
return torch.nested.as_nested_tensor(
list(map(torch.from_numpy, arr.to_pandas()))
)
return arr.to_pandas().tolist()

if pa.types.is_list(arr_type) or pa.types.is_large_list(arr_type):
if pa.types.is_primitive(arr_type.value_type):
return arr.to_pandas().map(torch.from_numpy).tolist()

if pa.types.is_fixed_size_list(arr_type.value_type) and pa.types.is_primitive(
arr_type.value_type.value_type
):
# FIXME: get the column global dtype for empty seq case
return (
arr.to_pandas()
.map(
lambda x: torch.from_numpy(
np.vstack(x) if len(x) > 0 else np.array([], dtype=np.float32)
)
)
.tolist()
)

if pa.types.is_fixed_size_list(arr_type) and pa.types.is_primitive(
arr_type.value_type
):
return torch.from_numpy(np.reshape(arr.values, (-1, arr_type.list_size)))
if pa.types.is_fixed_size_list(arr_type):
if pa.types.is_primitive(arr_type.value_type):
return torch.from_numpy(np.reshape(arr.values, (-1, arr_type.list_size)))

if pa.types.is_struct(arr_type):
return {
Expand All @@ -103,7 +120,7 @@ def from_pyarrow_to_torch_tensor(
if strict:
raise NotImplementedError(f"{arr_type} cannot be converted to torch.Tensor")
else:
return arr
return arr  # keeping as in the original pyarrow form


def pyarrow_table_to_torch_dict(tt: pa.Table, strict: bool = True) -> NestedDict:
Expand Down Expand Up @@ -505,3 +522,17 @@ def get_fragment_minimal_stats(frag):

df_stats = pd.DataFrame(stats)
return df_stats


def pyarrow_table_to_torch_dict(tt: pa.Table, strict: bool = False) -> NestedDict:
    """Convert each column of an Arrow table into a torch-friendly value.

    :param tt: input pyarrow table.
    :param strict: forwarded to ``from_pyarrow_to_torch_tensor``; when True,
        unsupported column types raise instead of being passed through.
    :returns: mapping of column name to converted value; columns that raise
        ``ValueError`` during conversion are kept in their original
        pyarrow form.
    """
    out: NestedDict = {}
    for col in tt.column_names:
        try:
            out[col] = from_pyarrow_to_torch_tensor(tt[col], strict)
        except ValueError as e:
            # BUGFIX: the original passed `str(e)` as an extra positional
            # argument to an already-formatted f-string, so the exception
            # text was dropped from the log; include it in the message.
            logger.info(
                f"Column {col} of type {tt[col].type} was not converted "
                f"to torch as expected: {e}"
            )
            # Best-effort: keep the unconverted pyarrow column.
            out[col] = tt[col]
    return out
11 changes: 11 additions & 0 deletions tests/unit/data/parquet/test_parquet_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from fairseq2.data.parquet.utils import (
get_parquet_dataset_metadata,
get_row_group_level_metadata,
pyarrow_table_to_torch_dict,
NestedDict,
pa,
pq,
)

Expand Down Expand Up @@ -79,3 +82,11 @@ def test_get_parquet_dataset_metadata(multi_partition_file_dataset):
"cat",
]
)


def test_nested_text_conversion() -> None:
    """Nested string columns must pass through unchanged as pyarrow arrays."""
    table = pa.Table.from_pydict(
        {"nested_text": pa.array([["abc", "efg"], ["xyz"]])}
    )
    result = pyarrow_table_to_torch_dict(table)
    # we want to keep this type unchanged
    assert isinstance(result["nested_text"], pa.Array)

0 comments on commit ffd751f

Please sign in to comment.