diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index 69e9d89da..10ecc0df8 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -6,4 +6,5 @@ sentencepiece
 tiktoken
 blobfile
 tabulate
-transformers
\ No newline at end of file
+transformers
+orjson
\ No newline at end of file
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index a4b401849..394d4127c 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -372,10 +372,10 @@ def __init__(self):
 
         # validation configs
         self.parser.add_argument(
-            "--validation.batch_size", type=int, default=0
+            "--validation.batch_size", type=int, default=None
         )
         self.parser.add_argument(
-            "--validation.dataset", type=str, default="c4_mini", help="Dataset to use"
+            "--validation.dataset", type=str, help="Dataset to use", default=None
         )
         self.parser.add_argument(
             "--validation.dataset_path",
@@ -383,9 +383,10 @@ def __init__(self):
             help="""
             Path to the dataset for validation in the file system. If provided, data will be
             loaded from this path instead of downloaded.""",
+            default=None
         )
         self.parser.add_argument(
-            "--validation.eval_freq", type=int, default=1, help="How often to evaluate the model and log metrics to aim."
+            "--validation.eval_freq", type=int, default=None, help="How often to evaluate the model and log metrics to aim."
         )
         self.parser.add_argument(
             "--validation.enable_val", type=bool, default=False, help="Whether to do validation."
diff --git a/torchtitan/datasets/__init__.py b/torchtitan/datasets/__init__.py
index 962ff1b57..bc05c8bfd 100644
--- a/torchtitan/datasets/__init__.py
+++ b/torchtitan/datasets/__init__.py
@@ -7,5 +7,5 @@
 from torchtitan.datasets.hf_datasets import build_hf_data_loader
 
 __all__ = [
-    "build_hf_data_loader"
+    "build_hf_data_loader",
 ]
diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py
index 06c9d3ea4..d1df4972f 100644
--- a/torchtitan/datasets/hf_datasets.py
+++ b/torchtitan/datasets/hf_datasets.py
@@ -220,9 +220,11 @@ def build_hf_data_loader(
     world_size,
     rank,
     infinite: bool = True,
+    hf_ds: Optional[HuggingFaceDataset] = None
 ):
-    hf_ds = HuggingFaceDataset(
-        dataset_name, dataset_path, data_processing_style, tokenizer, seq_len, world_size, rank, infinite
-    )
+    if not hf_ds:
+        hf_ds = HuggingFaceDataset(
+            dataset_name, dataset_path, data_processing_style, tokenizer, seq_len, world_size, rank, infinite
+        )
 
     return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size)
diff --git a/train.py b/train.py
index ceae1b36c..30abe831b 100644
--- a/train.py
+++ b/train.py
@@ -17,6 +17,7 @@
 from torchtitan.checkpoint import CheckpointManager, TrainState
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_hf_data_loader
+from torchtitan.datasets.hf_datasets import HuggingFaceDataset
 from torchtitan.tokenizers.tokenizer import build_tokenizer
 from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
@@ -103,8 +104,27 @@ def main(job_config: JobConfig):
         dp_rank
     )
 
-    # validation batch size
-    val_bs = job_config.validation.batch_size if job_config.validation.batch_size != 0 else job_config.metrics.batch_size
+    # check if everything is specified for validation loop
+    if not job_config.validation.dataset_path and not job_config.validation.dataset:
+        raise ValueError("You didn't specify the validation dataset.")
+
+    if not job_config.validation.eval_freq:
+        logger.info("You didn't specify the frequency of evaluation. The default value is 1024.")
+    if job_config.validation.batch_size == None:
+        logger.info("You didn't specify the batch size for validation. The batch size for the training will be used instead.")
+    val_bs = job_config.validation.batch_size if job_config.validation.batch_size else job_config.training.batch_size
+    eval_freq = job_config.validation.eval_freq if job_config.validation.eval_freq else 1024
+
+    val_dataset = HuggingFaceDataset(
+        job_config.validation.dataset,
+        job_config.validation.dataset_path,
+        job_config.training.data_processing_style,
+        tokenizer,
+        job_config.training.seq_len,
+        dp_degree,
+        dp_rank,
+        False,
+    )
 
     # build model (using meta init)
     model_cls = model_name_to_cls[model_name]
@@ -375,19 +395,21 @@ def loss_fn(pred, labels):
         if (
             job_config.validation.enable_val and train_state.step == 1
-            or train_state.step % job_config.validation.eval_freq == 0
+            or train_state.step % eval_freq == 0
         ):
             with torch.no_grad():
                 model.eval()
                 val_data_loader = build_hf_data_loader(
                     job_config.validation.dataset,
                     job_config.validation.dataset_path,
+                    job_config.training.data_processing_style,
                     tokenizer,
                     val_bs,
                     job_config.training.seq_len,
                     dp_degree,
                     dp_rank,
-                    False
+                    False,
+                    val_dataset
                 )
                 num_flop_per_token_val = utils.get_num_flop_per_token_forward(
                     utils.get_num_params(model, exclude_embedding=True),
@@ -404,7 +426,7 @@ def loss_fn(pred, labels):
                     gpu_memory_monitor,
                     data_loading_times,
                     time_last_val_log,
-                    job_config.validation.eval_freq,
+                    eval_freq,
                     color,
                     train_state.step,
                     num_flop_per_token_val,
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index f84338c70..da18d99a6 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -46,9 +46,9 @@ dataset = "chemlactica_train_mini"  # supported datasets: c4_test (2K), c4 (177M)
 data_process_style="chemlactica_style"
 
 [validation]
+enable_val = true
 batch_size = 1
 eval_freq = 3
-enable_val = true
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
 
 [experimental]
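
Usage note (illustration only, not part of the patch): the sketch below shows how the new optional `hf_ds` argument of `build_hf_data_loader` is intended to be used, mirroring the `train.py` changes above. The validation dataset is built once before the training loop and reused on every evaluation, so only the `DPAwareDataLoader` wrapper is rebuilt. It assumes `job_config`, `tokenizer`, `dp_degree`, `dp_rank`, and `val_bs` are defined as in `main()`.

```python
# Usage sketch for the new optional `hf_ds` argument (assumes job_config,
# tokenizer, dp_degree, dp_rank, and val_bs exist as in train.py's main()).
from torchtitan.datasets import build_hf_data_loader
from torchtitan.datasets.hf_datasets import HuggingFaceDataset

# Build the validation dataset once, before the training loop, so its state is
# not recreated on every evaluation step.
val_dataset = HuggingFaceDataset(
    job_config.validation.dataset,
    job_config.validation.dataset_path,
    job_config.training.data_processing_style,
    tokenizer,
    job_config.training.seq_len,
    dp_degree,
    dp_rank,
    False,  # infinite=False: a single pass over the validation data
)

# Inside the loop, pass the prebuilt dataset; build_hf_data_loader then skips
# dataset construction and only wraps it in a DPAwareDataLoader.
val_data_loader = build_hf_data_loader(
    job_config.validation.dataset,
    job_config.validation.dataset_path,
    job_config.training.data_processing_style,
    tokenizer,
    val_bs,
    job_config.training.seq_len,
    dp_degree,
    dp_rank,
    False,
    val_dataset,
)
```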