Skip to content

Commit

Permalink
add default cfg logic; fix data_mixture demo
Browse files Browse the repository at this point in the history
  • Loading branch information
cyruszhang committed Jan 24, 2025
1 parent a99c9b5 commit 3c9caf5
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 7 deletions.
6 changes: 3 additions & 3 deletions data_juicer/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .config import (export_config, get_init_configs, init_configs,
merge_config, prepare_side_configs)
from .config import (export_config, get_default_cfg, get_init_configs,
init_configs, merge_config, prepare_side_configs)

__all__ = [
'init_configs', 'get_init_configs', 'export_config', 'merge_config',
'prepare_side_configs'
'prepare_side_configs', 'get_default_cfg'
]
33 changes: 33 additions & 0 deletions data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,3 +920,36 @@ def get_init_configs(cfg: Union[Namespace, Dict]):
json.dump(cfg, f)
inited_dj_cfg = init_configs(['--config', temp_file])
return inited_dj_cfg


def get_default_cfg():
"""Get default config values from config_all.yaml"""
cfg = Namespace()

# Get path to config_all.yaml
config_dir = os.path.dirname(os.path.abspath(__file__))
default_config_path = os.path.join(config_dir,
'../../configs/config_all.yaml')

# Load default values from yaml
with open(default_config_path, 'r', encoding='utf-8') as f:
defaults = yaml.safe_load(f)

# Convert to flat dictionary for namespace
flat_defaults = {
'executor_type': 'default',
'ray_address': 'auto',
'suffixes': None,
'text_keys': 'text',
'add_suffix': False,
'export_path': './outputs',
# Add other top-level keys from config_all.yaml
**defaults
}

# Update cfg with defaults
for key, value in flat_defaults.items():
if not hasattr(cfg, key):
setattr(cfg, key, value)

return cfg
2 changes: 1 addition & 1 deletion data_juicer/core/executor/local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, cfg: Optional[Namespace] = None):
:param cfg: optional jsonargparse Namespace.
"""
super().__init__(cfg)
self.executor_type = 'local'
self.executor_type = 'default'
self.work_dir = self.cfg.work_dir

self.tracer = None
Expand Down
7 changes: 5 additions & 2 deletions demos/data_mixture/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import pandas as pd
import streamlit as st

from data_juicer.core.data.dataset_builder import DatasetBuilder
from data_juicer.config import get_default_cfg

if st.__version__ >= '1.23.0':
data_editor = st.data_editor
Expand Down Expand Up @@ -96,7 +96,10 @@ def mix_dataset():
' '.join([str(weight), ds_file])
for ds_file, weight in zip(ds_files, weights)
])
df = pd.DataFrame(DatasetBuilder(data_path).load_dataset())
cfg = get_default_cfg()
cfg.dataset_path = data_path
dataset_builder = DatasetBuilder(cfg)
df = pd.DataFrame(dataset_builder.load_dataset())

st.session_state.dataset = df
else:
Expand Down
28 changes: 27 additions & 1 deletion tests/config/test_config_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from jsonargparse import Namespace

from data_juicer.config import init_configs
from data_juicer.config import init_configs, get_default_cfg
from data_juicer.ops import load_ops
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase

Expand Down Expand Up @@ -276,5 +276,31 @@ def test_op_params_parsing(self):
self.assertIn(base_param_key, params)


def test_get_default_cfg(self):
"""Test getting default configuration from config_all.yaml"""
# Get default config
cfg = get_default_cfg()

# Verify basic default values
self.assertIsInstance(cfg, Namespace)

# Test essential defaults
self.assertEqual(cfg.executor_type, 'default')
self.assertEqual(cfg.ray_address, 'auto')
self.assertEqual(cfg.text_keys, 'text')
self.assertEqual(cfg.add_suffix, False)
self.assertEqual(cfg.export_path, '/path/to/result/dataset.jsonl')
self.assertEqual(cfg.suffixes, [])

# Test other important defaults from config_all.yaml
self.assertTrue(hasattr(cfg, 'np')) # Number of processes
self.assertTrue(hasattr(cfg, 'use_cache')) # Cache usage flag
self.assertTrue(hasattr(cfg, 'temp_dir')) # Temporary directory

# Test default values are of correct type
self.assertIsInstance(cfg.executor_type, str)
self.assertIsInstance(cfg.add_suffix, bool)
self.assertIsInstance(cfg.export_path, str)

if __name__ == '__main__':
unittest.main()

0 comments on commit 3c9caf5

Please sign in to comment.