Always set non-null writer batch size (#7258)
lhoestq authored Oct 28, 2024
1 parent 444ce83 commit ff0149f
Showing 2 changed files with 46 additions and 51 deletions.
55 changes: 44 additions & 11 deletions src/datasets/arrow_writer.py
@@ -24,10 +24,11 @@
 from fsspec.core import url_to_fs
 
 from . import config
-from .features import Features, Image, Value
+from .features import Audio, Features, Image, Value, Video
 from .features.features import (
     FeatureType,
     _ArrayXDExtensionType,
+    _visit,
     cast_to_python_objects,
     generate_from_arrow_type,
     get_nested_type,
@@ -48,6 +49,45 @@
 type_ = type  # keep python's type function
 
 
+def get_writer_batch_size(features: Optional[Features]) -> Optional[int]:
+    """
+    Get the writer_batch_size that defines the maximum row group size in the parquet files.
+    The default in `datasets` is 1,000 but we lower it to 100 for image/audio datasets and 10 for videos.
+    This allows to optimize random access to parquet file, since accessing 1 row requires
+    to read its entire row group.
+    This can be improved to get optimized size for querying/iterating
+    but at least it matches the dataset viewer expectations on HF.
+    Args:
+        features (`datasets.Features` or `None`):
+            Dataset Features from `datasets`.
+    Returns:
+        writer_batch_size (`Optional[int]`):
+            Writer batch size to pass to a dataset builder.
+            If `None`, then it will use the `datasets` default.
+    """
+    if not features:
+        return None
+
+    batch_size = np.inf
+
+    def set_batch_size(feature: FeatureType) -> None:
+        nonlocal batch_size
+        if isinstance(feature, Image):
+            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)
+        elif isinstance(feature, Audio):
+            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
+        elif isinstance(feature, Video):
+            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS)
+        elif isinstance(feature, Value) and feature.dtype == "binary":
+            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS)
+
+    _visit(features, set_batch_size)
+
+    return None if batch_size is np.inf else batch_size
+
+
 class SchemaInferenceError(ValueError):
     pass
 
@@ -340,7 +380,9 @@ def __init__(
 
         self.fingerprint = fingerprint
         self.disable_nullable = disable_nullable
-        self.writer_batch_size = writer_batch_size
+        self.writer_batch_size = (
+            writer_batch_size or get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
+        )
         self.update_features = update_features
         self.with_metadata = with_metadata
         self.unit = unit
@@ -353,11 +395,6 @@ def __init__(
         self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None
         self.hkey_record = []
 
-        if self.writer_batch_size is None and self._features is not None:
-            from .io.parquet import get_writer_batch_size
-
-            self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
-
     def __len__(self):
         """Return the number of writed and staged examples"""
        return self._num_examples + len(self.current_examples) + len(self.current_rows)
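
With the hunks above, `__init__` now resolves the batch size eagerly instead of relying on the lazy `if self.writer_batch_size is None` fallback removed here (a similar block is removed from `_build_writer` below). A rough sketch of that precedence, as an illustration only — the default value is an assumed stand-in for `config.DEFAULT_MAX_BATCH_SIZE`:

# Illustration of the "first non-null wins" chain now used in ArrowWriter.__init__:
# an explicit writer_batch_size wins, then the feature-based row group size,
# then the library default (assumed to be 1,000, per the docstring above).
DEFAULT_MAX_BATCH_SIZE = 1000  # assumed stand-in for config.DEFAULT_MAX_BATCH_SIZE

def resolve_writer_batch_size(explicit, feature_based):
    return explicit or feature_based or DEFAULT_MAX_BATCH_SIZE

print(resolve_writer_batch_size(500, 100))    # 500  -> caller override wins
print(resolve_writer_batch_size(None, 100))   # 100  -> image/audio/video features lower it
print(resolve_writer_batch_size(None, None))  # 1000 -> datasets default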
@@ -402,10 +439,6 @@ def _build_writer(self, inferred_schema: pa.Schema):
             schema = schema.with_metadata({})
         self._schema = schema
         self.pa_writer = self._WRITER_CLASS(self.stream, schema)
-        if self.writer_batch_size is None:
-            from .io.parquet import get_writer_batch_size
-
-            self.writer_batch_size = get_writer_batch_size(self._features) or config.DEFAULT_MAX_BATCH_SIZE
 
     @property
     def schema(self):
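For reference, a minimal hypothetical sketch (not part of the commit) of how the relocated helper behaves, assuming the `datasets.arrow_writer.get_writer_batch_size` import path added above and the lowered row group sizes described in its docstring:

# Hypothetical usage of the helper added to arrow_writer.py in this commit.
from datasets import Features, Image, Value
from datasets.arrow_writer import get_writer_batch_size

image_features = Features({"image": Image(), "label": Value("int64")})
print(get_writer_batch_size(image_features))  # a lowered size for image data (e.g. 100)

text_features = Features({"text": Value("string")})
print(get_writer_batch_size(text_features))  # None -> the writer falls back to the default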
42 changes: 2 additions & 40 deletions src/datasets/io/parquet.py
@@ -2,11 +2,10 @@
 from typing import BinaryIO, Optional, Union
 
 import fsspec
-import numpy as np
 import pyarrow.parquet as pq
 
-from .. import Audio, Dataset, Features, Image, NamedSplit, Value, Video, config
-from ..features.features import FeatureType, _visit
+from .. import Dataset, Features, NamedSplit, config
+from ..arrow_writer import get_writer_batch_size
 from ..formatting import query_table
 from ..packaged_modules import _PACKAGED_DATASETS_MODULES
 from ..packaged_modules.parquet.parquet import Parquet
@@ -15,43 +14,6 @@
 from .abc import AbstractDatasetReader
 
 
-def get_writer_batch_size(features: Features) -> Optional[int]:
-    """
-    Get the writer_batch_size that defines the maximum row group size in the parquet files.
-    The default in `datasets` is 1,000 but we lower it to 100 for image datasets.
-    This allows to optimize random access to parquet file, since accessing 1 row requires
-    to read its entire row group.
-    This can be improved to get optimized size for querying/iterating
-    but at least it matches the dataset viewer expectations on HF.
-    Args:
-        ds_config_info (`datasets.info.DatasetInfo`):
-            Dataset info from `datasets`.
-    Returns:
-        writer_batch_size (`Optional[int]`):
-            Writer batch size to pass to a dataset builder.
-            If `None`, then it will use the `datasets` default.
-    """
-
-    batch_size = np.inf
-
-    def set_batch_size(feature: FeatureType) -> None:
-        nonlocal batch_size
-        if isinstance(feature, Image):
-            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS)
-        elif isinstance(feature, Audio):
-            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS)
-        elif isinstance(feature, Video):
-            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_VIDEO_DATASETS)
-        elif isinstance(feature, Value) and feature.dtype == "binary":
-            batch_size = min(batch_size, config.PARQUET_ROW_GROUP_SIZE_FOR_BINARY_DATASETS)
-
-    _visit(features, set_batch_size)
-
-    return None if batch_size is np.inf else batch_size
-
-
 class ParquetDatasetReader(AbstractDatasetReader):
     def __init__(
         self,