diff --git a/docs/user-guide/documentdataset.rst b/docs/user-guide/documentdataset.rst
index 07ef41a2..445a0933 100644
--- a/docs/user-guide/documentdataset.rst
+++ b/docs/user-guide/documentdataset.rst
@@ -68,14 +68,16 @@ Let's walk through this code line by line.
            "books_dataset/books_02.jsonl"]

 * ``books = DocumentDataset.read_json(files, add_filename=True)`` This will read the files listed into memory.
-  The ``add_filename=True`` option preserves the name of the shard (``books_00.jsonl``, ``books_01.jsonl``, etc.) as an additional ``filename`` field.
-  When the dataset is written back to disk, this option (in conjunction with the ``write_to_filename`` option) ensure that documents stay in their original shard.
+  The ``add_filename=True`` option preserves the name of the shard (``books_00.jsonl``, ``books_01.jsonl``, etc.) as an additional ``file_name`` field.
+  When the dataset is written back to disk, this option (in conjunction with the ``write_to_filename`` option and the ``filename_col`` argument) ensures that documents stay in their original shard.
   This can be useful for manually inspecting the results of filtering shard by shard.
+  ``add_filename`` can also be given as a string, in which case that string is used as the name of the column instead of the default ``file_name``.
 * ``filter_step = ...`` This constructs and applies a heuristic filter for the length of the document.
   More information is provided in the filtering page of the documentation.
 * ``long_books.to_json("long_books/", write_to_filename=True)`` This writes the filtered dataset to a new directory.
   As mentioned above, the ``write_to_filename=True`` preserves the sharding of the dataset.
   If the dataset was not read in with ``add_filename=True``, setting ``write_to_filename=True`` will throw an error.
+  If the dataset was read with ``add_filename="path"``, then ``filename_col="path"`` must also be passed alongside ``write_to_filename=True``.

 ``DocumentDataset`` is just a wrapper around a `Dask dataframe `_.
 The underlying dataframe can be accessed with the ``DocumentDataset.df`` member variable.
diff --git a/nemo_curator/datasets/doc_dataset.py b/nemo_curator/datasets/doc_dataset.py
index 5b4caf51..1b70f57b 100644
--- a/nemo_curator/datasets/doc_dataset.py
+++ b/nemo_curator/datasets/doc_dataset.py
@@ -52,7 +52,7 @@ def read_json(
         backend: Literal["pandas", "cudf"] = "pandas",
         files_per_partition: Optional[int] = None,
         blocksize: Optional[str] = "1gb",
-        add_filename: bool = False,
+        add_filename: Union[bool, str] = False,
         input_meta: Union[str, dict] = None,
         columns: Optional[List[str]] = None,
         **kwargs,
@@ -64,7 +64,9 @@ def read_json(
             input_files: The path of the input file(s).
             backend: The backend to use for reading the data.
             files_per_partition: The number of files to read per partition.
-            add_filename: Whether to add a "file_name" column to the DataFrame.
+            add_filename: Whether to add a filename column to the DataFrame.
+                If True, a new column is added to the DataFrame called `file_name`.
+                If str, sets new column name. Default is False.
             input_meta: A dictionary or a string formatted as a dictionary, which outlines
                 the field names and their respective data types within the JSONL input file.
             columns: If not None, only these columns will be read from the file.
@@ -91,7 +93,7 @@ def read_parquet(
         backend: Literal["pandas", "cudf"] = "pandas",
         files_per_partition: Optional[int] = None,
         blocksize: Optional[str] = "1gb",
-        add_filename=False,
+        add_filename: Union[bool, str] = False,
         columns: Optional[List[str]] = None,
         **kwargs,
     ) -> "DocumentDataset":
@@ -102,7 +104,9 @@ def read_parquet(
             input_files: The path of the input file(s).
             backend: The backend to use for reading the data.
             files_per_partition: The number of files to read per partition.
-            add_filename: Whether to add a "file_name" column to the DataFrame.
+            add_filename: Whether to add a filename column to the DataFrame.
+                If True, a new column is added to the DataFrame called `file_name`.
+                If str, sets new column name. Default is False.
             columns: If not None, only these columns will be read from the file.
                 There is a significant performance gain when specifying columns for Parquet files.

@@ -135,7 +139,9 @@ def read_pickle(
             input_files: The path of the input file(s).
             backend: The backend to use for reading the data.
             files_per_partition: The number of files to read per partition.
-            add_filename: Whether to add a "file_name" column to the DataFrame.
+            add_filename: Whether to add a filename column to the DataFrame.
+                If True, a new column is added to the DataFrame called `file_name`.
+                If str, sets new column name. Default is False.
             columns: If not None, only these columns will be read from the file.

         """
@@ -154,6 +160,7 @@ def to_json(
         output_path: str,
         write_to_filename: bool = False,
         keep_filename_column: bool = False,
+        filename_col: str = "file_name",
     ):
         """
         See nemo_curator.utils.distributed_utils.write_to_disk docstring for parameters.
@@ -165,6 +172,7 @@ def to_json(
             write_to_filename=write_to_filename,
             keep_filename_column=keep_filename_column,
             output_type="jsonl",
+            filename_col=filename_col,
         )

     def to_parquet(
@@ -234,7 +242,7 @@ def _read_json_or_parquet(
     input_files: Union[str, List[str]],
     file_type: str,
     backend: Literal["cudf", "pandas"],
-    add_filename: bool,
+    add_filename: Union[bool, str] = False,
     files_per_partition: Optional[int] = None,
     blocksize: Optional[str] = None,
     input_meta: Union[str, dict] = None,
diff --git a/nemo_curator/datasets/parallel_dataset.py b/nemo_curator/datasets/parallel_dataset.py
index b9a5eee1..afce3877 100644
--- a/nemo_curator/datasets/parallel_dataset.py
+++ b/nemo_curator/datasets/parallel_dataset.py
@@ -1,11 +1,11 @@
 import csv
-from typing import List, Optional, Tuple, Union
+from typing import List, Tuple, Union

 import dask.dataframe as dd
 import pandas as pd

 from nemo_curator.datasets.doc_dataset import DocumentDataset
-from nemo_curator.utils.distributed_utils import write_to_disk
+from nemo_curator.utils.distributed_utils import _resolve_filename_col, write_to_disk
 from nemo_curator.utils.file_utils import remove_path_extension
 from nemo_curator.utils.import_utils import gpu_only_import

@@ -31,7 +31,7 @@ def read_simple_bitext(
         src_lang: str,
         tgt_lang: str,
         backend: str = "pandas",
-        add_filename: bool = False,
+        add_filename: Union[bool, str] = False,
         npartitions: int = 16,
     ):
         """See `read_single_simple_bitext_file_pair` docstring for what "simple_bitext" means and usage of other parameters.
@@ -99,7 +99,7 @@ def read_single_simple_bitext_file_pair(
         tgt_lang: str,
         doc_id: str = None,
         backend: str = "cudf",
-        add_filename: bool = False,
+        add_filename: Union[bool, str] = False,
     ) -> Union[dd.DataFrame, "dask_cudf.DataFrame"]:
         """This function reads a pair of "simple bitext" files into a pandas DataFrame.
         A simple bitext is a commonly data format in machine translation.
@@ -129,7 +137,10 @@ def read_single_simple_bitext_file_pair(
             tgt_lang (str): Target language, in ISO-639-1 (two character) format (e.g. 'en')
             doc_id (str, optional): A string document id to assign to every segment in the file. Defaults to None.
             backend (str, optional): Backend of the data frame. Defaults to "cudf".
-            add_filename (bool, optional): Add "file_name" as an extra field to every segment in the file. Defaults to False.
+            add_filename (Union[bool, str]): Whether to add a filename column to the DataFrame.
+                If True, a new column is added to the DataFrame called `file_name`.
+                If str, sets new column name. Default is False.
+
         Returns:
             Union[dd.DataFrame, dask_cudf.DataFrame]

@@ -162,6 +165,8 @@ def read_single_simple_bitext_file_pair(
         df_combined["tgt_lang"] = tgt_lang

         if add_filename:
-            df_combined["file_name"] = remove_path_extension(src_input_file)
+            df_combined[_resolve_filename_col(add_filename)] = remove_path_extension(
+                src_input_file
+            )

         return df_combined
diff --git a/nemo_curator/download/arxiv.py b/nemo_curator/download/arxiv.py
index 449503e8..538d567d 100644
--- a/nemo_curator/download/arxiv.py
+++ b/nemo_curator/download/arxiv.py
@@ -415,6 +415,7 @@ def download_arxiv(
         output_type=output_type,
         keep_raw_download=keep_raw_download,
         force_download=force_download,
+        filename_col="file_name",
     )

     return dataset
diff --git a/nemo_curator/download/commoncrawl.py b/nemo_curator/download/commoncrawl.py
index 68ad0de4..de7e333b 100644
--- a/nemo_curator/download/commoncrawl.py
+++ b/nemo_curator/download/commoncrawl.py
@@ -442,6 +442,7 @@ def download_common_crawl(
         output_type=output_type,
         keep_raw_download=keep_raw_download,
         force_download=force_download,
+        filename_col="file_name",
     )

     return dataset
diff --git a/nemo_curator/download/doc_builder.py b/nemo_curator/download/doc_builder.py
index dbeac133..5ad3094e 100644
--- a/nemo_curator/download/doc_builder.py
+++ b/nemo_curator/download/doc_builder.py
@@ -112,12 +112,16 @@ def _download_and_extract_single_partition(
     keep_raw_download: bool,
     force_download: bool,
     input_meta: Union[str, dict] = None,
+    filename_col: str = "file_name",
 ) -> pd.DataFrame:
     url, output_path = paths

     if os.path.exists(output_path) and not force_download:
         partition = read_single_partition(
-            [output_path], backend="pandas", filetype=output_type, add_filename=True
+            [output_path],
+            backend="pandas",
+            filetype=output_type,
+            add_filename=filename_col,
         )
         return partition

@@ -141,8 +145,10 @@ def _download_and_extract_single_partition(
     partition = pd.DataFrame(records)
     filename = os.path.basename(output_path)
     output_dir = os.path.dirname(output_path)
-    partition["file_name"] = filename
-    single_partition_write_with_filename(partition, output_dir, output_type=output_type)
+    partition[filename_col] = filename
+    single_partition_write_with_filename(
+        partition, output_dir, output_type=output_type, filename_col=filename_col
+    )
     if not keep_raw_download:
         os.remove(downloaded_file)

@@ -160,6 +166,7 @@ def download_and_extract(
     keep_raw_download=False,
     force_download=False,
     input_meta: Union[str, dict] = None,
+    filename_col: str = "file_name",
 ) -> DocumentDataset:
     """
     Downloads and extracts a dataset into a format accepted by the NeMo Curator
@@ -178,7 +185,7 @@ def download_and_extract(
          directly read from them instead.
       input_meta: A dictionary or a string formatted as a dictionary, which outlines
          the field names and their respective data types within the JSONL input file.
-
+      filename_col: The name of the column that contains the filename. Default is "file_name".
     Returns:
       A DocumentDataset of the downloaded data
     """
@@ -202,6 +209,7 @@ def download_and_extract(
             force_download=force_download,
             enforce_metadata=False,
             input_meta=input_meta,
+            filename_col=filename_col,
             meta=output_format,
         )

diff --git a/nemo_curator/download/wikipedia.py b/nemo_curator/download/wikipedia.py
index 32088a27..54b85d45 100644
--- a/nemo_curator/download/wikipedia.py
+++ b/nemo_curator/download/wikipedia.py
@@ -811,6 +811,7 @@ def download_wikipedia(
         output_type=output_type,
         keep_raw_download=keep_raw_download,
         force_download=force_download,
+        filename_col="file_name",
     )

     return dataset
diff --git a/nemo_curator/modules/dataset_ops.py b/nemo_curator/modules/dataset_ops.py
index e996946b..f11184a5 100644
--- a/nemo_curator/modules/dataset_ops.py
+++ b/nemo_curator/modules/dataset_ops.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Callable, List, Optional
+from typing import Callable, List, Optional

 import dask.dataframe as dd
 import numpy as np
@@ -17,9 +17,10 @@ def __init__(
         seed: Optional[int] = None,
         npartitions: Optional[int] = None,
         partition_to_filename: Callable[[int], str] = default_filename,
+        filename_col: str = "file_name",
     ) -> None:
         """
-        Randomly permutes the dataset. This will make the original "file_name" column invalid, so if the column is present it will be overwritten.
+        Randomly permutes the dataset. This will make the original `filename_col` column invalid, so if the column is present it will be overwritten.

         Args:
             seed: The random seed that will be used to determine which partition (file) each datapoint goes to. Setting the seed will guarantee determinism, but may be slightly slower (20-30% slower)
@@ -35,6 +36,7 @@ def __init__(
         self.npartitions = npartitions
         self.partition_to_filename = partition_to_filename
         self.rand_col = "_shuffle_rand"
+        self.filename_col = filename_col

     def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
         if self.seed is None:
@@ -52,8 +54,10 @@ def shuffle_deterministic(self, dataset: DocumentDataset) -> DocumentDataset:
         shuffled_df = dataset.df.set_index(self.rand_col, npartitions=new_npartitions)
         shuffled_df = shuffled_df.reset_index(drop=True)

-        if "file_name" in shuffled_df:
-            shuffled_df["file_name"] = shuffled_df.map_partitions(self._add_filename)
+        if self.filename_col in shuffled_df:
+            shuffled_df[self.filename_col] = shuffled_df.map_partitions(
+                self._add_filename
+            )

         return DocumentDataset(shuffled_df)

@@ -98,15 +102,15 @@ def _partition_shuffle(self, partition, partition_info=None):
             drop=True
         )

-        if "file_name" in partition:
+        if self.filename_col in partition:
             filename = self.partition_to_filename(partition_num)
-            partition["file_name"] = filename
+            partition[self.filename_col] = filename

         return partition

     def _add_filename(self, partition, partition_info=None):
         if partition_info is None:
-            return ["file_name"] * len(partition)
+            return [self.filename_col] * len(partition)

         filename = self.partition_to_filename(partition_info["number"])
diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py
index 7ff463a7..a9d792b4 100644
--- a/nemo_curator/utils/distributed_utils.py
+++ b/nemo_curator/utils/distributed_utils.py
@@ -26,9 +26,8 @@
 import warnings
 from contextlib import nullcontext
 from datetime import datetime
-from itertools import zip_longest
 from pathlib import Path
-from typing import Callable, Dict, List, Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union

 import dask.dataframe as dd
 import numpy as np
@@ -273,16 +272,30 @@ def _set_torch_to_use_rmm():
     torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


+def _resolve_filename_col(filename: Union[bool, str]) -> Union[str, bool]:
+    if filename is False:
+        return False
+    elif filename is True:
+        return "file_name"
+    elif isinstance(filename, str):
+        return filename
+    else:
+        msg = f"Unknown filename value: {filename}"
+        raise ValueError(msg)
+
+
 def select_columns(
     df: Union[dd.DataFrame, pd.DataFrame, "cudf.DataFrame"],
     columns: List[str],
     filetype: Literal["jsonl", "json", "parquet"],
-    add_filename: bool,
+    add_filename: Union[bool, str],
 ) -> Union[dd.DataFrame, pd.DataFrame, "cudf.DataFrame"]:
     # We exclude parquet because the parquet readers already support column selection
     if filetype in ["jsonl", "json"] and columns is not None:
-        if add_filename and "file_name" not in columns:
-            columns.append("file_name")
+        if add_filename:
+            filename_str = _resolve_filename_col(add_filename)
+            if filename_str not in columns:
+                columns.append(filename_str)
         df = df[columns]

     return df
@@ -292,19 +305,21 @@ def read_single_partition(
     files: List[str],
     backend: Literal["cudf", "pandas"] = "cudf",
     filetype: str = "jsonl",
-    add_filename: bool = False,
+    add_filename: Union[bool, str] = False,
     input_meta: Union[str, dict] = None,
     io_columns: Optional[List[str]] = None,
     **kwargs,
 ) -> Union["cudf.DataFrame", pd.DataFrame]:
     """
     This function reads a file with cuDF, sorts the columns of the DataFrame
-    and adds a "file_name" column.
+    and adds a filename column.

     Args:
         files: The path to the jsonl files to read.
         backend: The backend to use for reading the data. Either "cudf" or "pandas".
-        add_filename: Whether to add a "file_name" column to the DataFrame.
+        add_filename: Whether to add a filename column to the DataFrame.
+            If True, a new column is added to the DataFrame called `file_name`.
+            If str, sets new column name. Default is False.
         input_meta: A dictionary or a string formatted as a dictionary, which outlines
             the field names and their respective data types within the JSONL input file.
         columns: If not None, only these columns will be read from the file.
@@ -368,7 +383,7 @@ def read_single_partition(
     for file in files:
         df = read_f(file, **read_kwargs, **kwargs)
         if add_filename:
-            df["file_name"] = os.path.basename(file)
+            df[_resolve_filename_col(add_filename)] = os.path.basename(file)
         df = select_columns(df, io_columns, filetype, add_filename)
         df_ls.append(df)

@@ -384,7 +399,7 @@ def read_data_blocksize(
     backend: Literal["cudf", "pandas"],
     file_type: Literal["parquet", "jsonl"],
     blocksize: str,
-    add_filename: bool = False,
+    add_filename: Union[bool, str] = False,
     input_meta: Union[str, dict] = None,
     columns: Optional[List[str]] = None,
     **kwargs,
@@ -392,7 +407,6 @@ def read_data_blocksize(

     read_kwargs = dict()

-    postprocessing_func: Optional[Callable[[dd.DataFrame], dd.DataFrame]] = None
     if file_type == "jsonl":
         warnings.warn(
             "If underlying JSONL data does not have a consistent schema, reading with blocksize will fail. "
@@ -427,9 +441,8 @@ def extract_filename(path: str) -> str:
                 return os.path.basename(path)

-            read_kwargs["include_path_column"] = add_filename
+            read_kwargs["include_path_column"] = _resolve_filename_col(add_filename)
             read_kwargs["path_converter"] = extract_filename
-            postprocessing_func = lambda df: df.rename(columns={"path": "file_name"})
     elif file_type == "parquet":
         if backend == "cudf" and not DASK_CUDF_PARQUET_READ_INCONSISTENT_SCHEMA:
@@ -457,8 +470,6 @@ def extract_filename(path: str) -> str:

     with dask.config.set({"dataframe.backend": backend}):
         df = read_func(input_files, blocksize=blocksize, **read_kwargs, **kwargs)
-        if postprocessing_func is not None:
-            df = postprocessing_func(df)

     output = select_columns(df, columns, file_type, add_filename)
     return output[sorted(output.columns)]
@@ -468,7 +479,7 @@ def read_data_files_per_partition(
     input_files: List[str],
     file_type: Literal["parquet", "json", "jsonl"],
     backend: Literal["cudf", "pandas"] = "cudf",
-    add_filename: bool = False,
+    add_filename: Union[bool, str] = False,
     files_per_partition: Optional[int] = None,
     input_meta: Union[str, dict] = None,
     columns: Optional[List[str]] = None,
@@ -500,7 +511,7 @@ def read_pandas_pickle(
     file: str,
-    add_filename: bool = False,
+    add_filename: Union[bool, str] = False,
     columns: Optional[List[str]] = None,
     **kwargs,
 ) -> pd.DataFrame:
@@ -530,7 +541,7 @@ def read_data(
     backend: Literal["cudf", "pandas"] = "cudf",
     blocksize: Optional[str] = None,
     files_per_partition: Optional[int] = 1,
-    add_filename: bool = False,
+    add_filename: Union[bool, str] = False,
     input_meta: Union[str, dict] = None,
     columns: Optional[List[str]] = None,
     **kwargs,
@@ -679,6 +690,7 @@ def single_partition_write_with_filename(
     output_file_dir: str,
     keep_filename_column: bool = False,
     output_type: str = "jsonl",
+    filename_col: str = "file_name",
 ):
     """
     This function processes a DataFrame and writes it to disk
@@ -686,14 +698,15 @@ def single_partition_write_with_filename(
     Args:
         df: A DataFrame.
         output_file_dir: The output file path.
-        keep_filename_column: Boolean representing whether to keep or drop the "file_name" column, if it exists.
+        keep_filename_column: Boolean representing whether to keep or drop the `filename_col` column, if it exists.
         output_type: The type of output file to write. Can be "jsonl" or "parquet".
+        filename_col: The name of the column that contains the filename. Default is "file_name".
     Returns:
         If the DataFrame is non-empty, return a Series containing a single element, True.
         If the DataFrame is empty, return a Series containing a single element, False.

""" - assert "file_name" in df.columns + assert filename_col in df.columns if len(df) > 0: empty_partition = False @@ -709,14 +722,14 @@ def single_partition_write_with_filename( success_ser = pd.Series([empty_partition]) if not empty_partition: - filenames = df.file_name.unique() + filenames = df[filename_col].unique() filenames = list(filenames.values_host) if is_cudf_type(df) else list(filenames) num_files = len(filenames) for filename in filenames: - out_df = df[df.file_name == filename] if num_files > 1 else df + out_df = df[df[filename_col] == filename] if num_files > 1 else df if not keep_filename_column: - out_df = out_df.drop("file_name", axis=1) + out_df = out_df.drop(filename_col, axis=1) filename = ( Path(filename).stem if output_type != "bitext" else Path(filename).name @@ -824,24 +837,26 @@ def _merge_tmp_simple_bitext_partitions(tmp_output_dir: str, output_dir: str): def write_to_disk( df, output_path: str, - write_to_filename: bool = False, + write_to_filename: Union[bool, str] = False, keep_filename_column: bool = False, output_type: str = "jsonl", ): """ This function writes a Dask DataFrame to the specified file path. If write_to_filename is True, then it expects the - DataFrame to have a "file_name" column that specifies where to write the document. + DataFrame to have a `filename_col` that specifies where to write the document. Args: df: A Dask DataFrame. output_path: The output file path. - write_to_filename: Boolean representing whether to write the filename using the "file_name" column. - keep_filename_column: Boolean representing whether to keep or drop the "file_name" column, if it exists. + write_to_filename: Whether to write the filename using the filename column. + If True the `file_name` column is used to write out. + If str, uses that as the filename column to write to. + keep_filename_column: Boolean representing whether to keep or drop the filename column, if it exists. output_type: The type of output file to write. Can be "jsonl" or "parquet". 
- """ + filename_col = _resolve_filename_col(write_to_filename) # output_path is a file name if isinstance(output_path, str) and output_path.endswith(".jsonl"): if df.npartitions == 1: @@ -856,9 +871,9 @@ def write_to_disk( ) # output_path is a directory - elif write_to_filename and "file_name" not in df.columns: + elif write_to_filename and filename_col not in df.columns: raise ValueError( - "write_using_filename is True but no file_name column found in DataFrame" + f"write_using_filename is True but no {filename_col} column found in DataFrame" ) if is_cudf_type(df): @@ -870,12 +885,14 @@ def write_to_disk( # output_path is a directory if write_to_filename and output_type != "bitext": + os.makedirs(output_path, exist_ok=True) output = df.map_partitions( single_partition_write_with_filename, output_path, keep_filename_column=keep_filename_column, output_type=output_type, + filename_col=filename_col, meta=output_meta, enforce_metadata=False, ) @@ -890,7 +907,7 @@ def write_to_disk( os.makedirs(output_path, exist_ok=True) tmp_output_file_dir = os.path.join(output_path, ".tmp") os.makedirs(tmp_output_file_dir, exist_ok=True) - file_name = os.path.basename(list(df.file_name.unique())[0]) + file_name = os.path.basename(list(df[filename_col].unique())[0]) else: tmp_output_file_dir = os.path.join(output_path, ".tmp") os.makedirs(tmp_output_file_dir, exist_ok=True) diff --git a/nemo_curator/utils/file_utils.py b/nemo_curator/utils/file_utils.py index 4632346a..5df61a2f 100644 --- a/nemo_curator/utils/file_utils.py +++ b/nemo_curator/utils/file_utils.py @@ -220,6 +220,7 @@ def write_dataframe_by_meta( output_type: str = "jsonl", include_values: List[str] = None, exclude_values: List[str] = None, + filename_col: str = "file_name", ): counts = df[metadata_field].value_counts().to_dict() @@ -236,7 +237,10 @@ def write_dataframe_by_meta( if remove_metadata: meta_slice = meta_slice.drop(columns=[metadata_field]) single_partition_write_with_filename( - meta_slice, meta_output_dir, output_type=output_type + meta_slice, + meta_output_dir, + output_type=output_type, + filename_col=filename_col, ) return counts @@ -294,13 +298,14 @@ def separate_by_metadata( input_type: str = "jsonl", include_values: List[str] = None, exclude_values: List[str] = None, + filename_col: str = "file_name", ) -> dict: """ Saves the dataframe to subfolders named after a metadata Args: input_data: Either a DataFrame or a string representing the path to the input directory. - If a DataFrame is provided, it must have a "file_name" column for the shard. + If a DataFrame is provided, it must have a filename_col for the shard. output_dir: The base directory for which all metadata based subdirs will be created under metadata_field: The metadata field to split on remove_metadata: Whether to remove the metadata from the dataframe when saving it @@ -310,7 +315,7 @@ def separate_by_metadata( If provided, only the items matching these values should be kept. exclude_values: A list of strings representing specific values to be excluded or ignored. If provided, any items matching these values should be skipped. - + filename_col: The column name in the DataFrame that contains the filename. Default is "file_name". Returns: A delayed dictionary mapping each metadata to the count of entries with that metadata value. 
@@ -357,7 +362,7 @@ def separate_by_metadata(
             get_all_files_paths_under(input_data),
             file_type=input_type,
             backend="pandas",
-            add_filename=True,
+            add_filename=filename_col,
         )

     delayed_counts = [
         delayed(write_dataframe_by_meta)(
@@ -368,6 +373,7 @@ def separate_by_metadata(
             output_type,
             include_values,
             exclude_values,
+            filename_col,
         )
         for partition in input_data.to_delayed()
     ]
diff --git a/tests/test_io.py b/tests/test_io.py
index 546ccf27..a4280ab1 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -147,25 +147,24 @@ def test_meta_str(self, jsonl_dataset):

 class TestWriteWithFilename:
     @pytest.mark.parametrize("keep_filename_column", [True, False])
     @pytest.mark.parametrize("file_ext", ["jsonl", "parquet"])
+    @pytest.mark.parametrize("filename_col", ["file_name", "filename"])
     def test_multifile_single_partition(
-        self,
-        tmp_path,
-        keep_filename_column,
-        file_ext,
+        self, tmp_path, keep_filename_column, file_ext, filename_col
     ):
-        df = pd.DataFrame({"a": [1, 2, 3], "file_name": ["file0", "file1", "file1"]})
+        df = pd.DataFrame({"a": [1, 2, 3], filename_col: ["file0", "file1", "file1"]})
         single_partition_write_with_filename(
             df=df,
             output_file_dir=tmp_path,
             keep_filename_column=keep_filename_column,
             output_type=file_ext,
+            filename_col=filename_col,
         )
         assert os.path.exists(tmp_path / f"file0.{file_ext}")
         assert os.path.exists(tmp_path / f"file1.{file_ext}")

         if not keep_filename_column:
-            df = df.drop("file_name", axis=1)
+            df = df.drop(filename_col, axis=1)

         df1 = read_single_partition(
             files=[tmp_path / f"file0.{file_ext}"], backend="pandas", filetype=file_ext
@@ -219,18 +218,19 @@ def test_multifile_single_partition_error(self, tmp_path):
             ("parquet", DocumentDataset.read_parquet),
         ],
     )
-    def test_multifile_multi_partition(self, tmp_path, file_ext, read_f):
-        df1 = pd.DataFrame({"a": [1, 2, 3], "file_name": ["file1", "file2", "file2"]})
+    @pytest.mark.parametrize("filename_col", ["file_name", "filename"])
+    def test_multifile_multi_partition(self, tmp_path, file_ext, read_f, filename_col):
+        df1 = pd.DataFrame({"a": [1, 2, 3], filename_col: ["file1", "file2", "file2"]})
         df2 = df1.copy()
-        df2["file_name"] = "file3"
+        df2[filename_col] = "file3"
         df3 = df1.copy()
-        df3["file_name"] = ["file4", "file5", "file6"]
+        df3[filename_col] = ["file4", "file5", "file6"]

         ddf = dd.concat([df1, df2, df3])
-        ddf["file_name"] = ddf["file_name"] + f".{file_ext}"
+        ddf[filename_col] = ddf[filename_col] + f".{file_ext}"

         write_to_disk(
             df=ddf,
             output_path=tmp_path / file_ext,
-            write_to_filename=True,
+            write_to_filename=filename_col,
             output_type=file_ext,
         )
@@ -239,7 +239,7 @@ def test_multifile_multi_partition(self, tmp_path, file_ext, read_f):
             blocksize=None,
             files_per_partition=2,
             backend="pandas",
-            add_filename=True,
+            add_filename=filename_col,
         ).df

         assert_eq(got_df, ddf, check_index=False)
diff --git a/tests/test_read_data.py b/tests/test_read_data.py
index 29013479..0c9a2aa5 100644
--- a/tests/test_read_data.py
+++ b/tests/test_read_data.py
@@ -289,19 +289,23 @@ def test_read_data_fpp_partitioning(
         pytest.param("cudf", marks=pytest.mark.gpu),
     ],
 )
-def test_read_data_blocksize_add_filename_jsonl(mock_multiple_jsonl_files, backend):
+@pytest.mark.parametrize("filename_arg", [True, "some_filename"])
+def test_read_data_blocksize_add_filename_jsonl(
+    mock_multiple_jsonl_files, backend, filename_arg
+):
     df = read_data_blocksize(
         input_files=mock_multiple_jsonl_files,
         backend=backend,
         file_type="jsonl",
         blocksize="128Mib",
-        add_filename=True,
+        add_filename=filename_arg,
         input_meta=None,
         columns=None,
     )

-    assert "file_name" in df.columns
-    file_names = df["file_name"].unique().compute()
+    filename_str = "file_name" if filename_arg is True else filename_arg
+    assert filename_str in df.columns
+    file_names = df[filename_str].unique().compute()
     if backend == "cudf":
         file_names = file_names.to_pandas()
@@ -318,7 +322,10 @@ def test_read_data_blocksize_add_filename_jsonl(mock_multiple_jsonl_files, backe
         pytest.param("cudf", marks=pytest.mark.gpu),
     ],
 )
-def test_read_data_blocksize_add_filename_parquet(mock_multiple_parquet_files, backend):
+@pytest.mark.parametrize("filename_arg", [True, "some_filename"])
+def test_read_data_blocksize_add_filename_parquet(
+    mock_multiple_parquet_files, backend, filename_arg
+):
     with pytest.raises(
         ValueError,
         match="add_filename and blocksize cannot be set at the same time for Parquet files",
@@ -328,7 +335,7 @@ def test_read_data_blocksize_add_filename_parquet(mock_multiple_parquet_files, b
             backend=backend,
             file_type="parquet",
             blocksize="128Mib",
-            add_filename=True,
+            add_filename=filename_arg,
             input_meta=None,
             columns=None,
         )
@@ -343,8 +350,13 @@
         ("pandas", "parquet"),
     ],
 )
+@pytest.mark.parametrize("filename_arg", [True, "some_filename"])
 def test_read_data_fpp_add_filename(
-    mock_multiple_jsonl_files, mock_multiple_parquet_files, backend, file_type
+    mock_multiple_jsonl_files,
+    mock_multiple_parquet_files,
+    backend,
+    file_type,
+    filename_arg,
 ):
     input_files = (
         mock_multiple_jsonl_files
@@ -357,14 +369,16 @@ def test_read_data_fpp_add_filename(
         backend=backend,
         file_type=file_type,
         files_per_partition=NUM_FILES,
-        add_filename=True,
+        add_filename=filename_arg,
         input_meta=None,
         columns=None,
     )

+    filename_str = "file_name" if filename_arg is True else filename_arg
+    assert filename_str in df.columns
     assert list(df.columns) == list(df.head().columns)
-    assert set(df.columns) == {"file_name", "id", "text"}
-    file_names = df["file_name"].unique().compute()
+    assert set(df.columns) == {filename_str, "id", "text"}
+    file_names = df[filename_str].unique().compute()
     if backend == "cudf":
         file_names = file_names.to_pandas()
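
Usage sketch (not part of the patch): the snippet below illustrates how the ``add_filename``/``filename_col`` options introduced in this diff fit together when reading and writing a sharded dataset. It is a minimal, hedged example -- the ``books_dataset/`` paths and the ``source_file`` column name are placeholder assumptions, and only the ``DocumentDataset.read_json``/``to_json`` parameters shown in this diff are used.

    from nemo_curator.datasets.doc_dataset import DocumentDataset

    # Hypothetical input shards; any JSONL paths would work here.
    files = [
        "books_dataset/books_00.jsonl",
        "books_dataset/books_01.jsonl",
    ]

    # Passing a string instead of True stores each document's shard name in a
    # column named "source_file" rather than the default "file_name".
    books = DocumentDataset.read_json(files, add_filename="source_file")

    # ... filtering or other curation steps would go here ...

    # Because add_filename was given as a string, the same name must be passed
    # as filename_col so each document is written back to its original shard.
    books.to_json(
        "filtered_books/",
        write_to_filename=True,
        filename_col="source_file",
    )

Reading with the plain ``add_filename=True`` keeps the previous behaviour: the column is called ``file_name`` and ``filename_col`` can be left at its default.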