
Commit

import problem
FilyaGeikyan committed Sep 10, 2024
1 parent 6568304 commit 741cff5
Showing 6 changed files with 40 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .ci/docker/requirements.txt
@@ -6,4 +6,5 @@ sentencepiece
 tiktoken
 blobfile
 tabulate
-transformers
+transformers
+orjson
7 changes: 4 additions & 3 deletions torchtitan/config_manager.py
@@ -372,20 +372,21 @@ def __init__(self):
 
         # validation configs
         self.parser.add_argument(
-            "--validation.batch_size", type=int, default=0
+            "--validation.batch_size", type=int, default=None
         )
         self.parser.add_argument(
-            "--validation.dataset", type=str, default="c4_mini", help="Dataset to use"
+            "--validation.dataset", type=str, help="Dataset to use", default=None
        )
         self.parser.add_argument(
             "--validation.dataset_path",
             type=str,
             help="""
                 Path to the dataset for validation in the file system. If provided, data will be
                 loaded from this path instead of downloaded.""",
+            default=None
         )
         self.parser.add_argument(
-            "--validation.eval_freq", type=int, default=1, help="How often to evaluate the model and log metrics to aim."
+            "--validation.eval_freq", type=int, default=None, help="How often to evaluate the model and log metrics to aim."
         )
         self.parser.add_argument(
             "--validation.enable_val", type=bool, default=False, help="Whether to do validation."
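
Note on the config change above: with the argparse defaults switched from concrete values to None, downstream code can tell whether the user actually set a validation option and fall back to the training settings otherwise. A minimal sketch of that resolution step, assuming a parsed JobConfig-like object named cfg; the helper name is illustrative, not part of this commit:

def resolve_validation_settings(cfg):
    # Fall back to the training batch size when --validation.batch_size is unset.
    val_bs = (
        cfg.validation.batch_size
        if cfg.validation.batch_size is not None
        else cfg.training.batch_size
    )
    # Fall back to an evaluation frequency of 1024, mirroring the logic in train.py below.
    eval_freq = (
        cfg.validation.eval_freq
        if cfg.validation.eval_freq is not None
        else 1024
    )
    return val_bs, eval_freq
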
2 changes: 1 addition & 1 deletion torchtitan/datasets/__init__.py
@@ -7,5 +7,5 @@
 from torchtitan.datasets.hf_datasets import build_hf_data_loader
 
 __all__ = [
-    "build_hf_data_loader"
+    "build_hf_data_loader",
 ]
8 changes: 5 additions & 3 deletions torchtitan/datasets/hf_datasets.py
@@ -220,9 +220,11 @@ def build_hf_data_loader(
     world_size,
     rank,
     infinite: bool = True,
+    hf_ds: Optional[HuggingFaceDataset] = None
 ):
-    hf_ds = HuggingFaceDataset(
-        dataset_name, dataset_path, data_processing_style, tokenizer, seq_len, world_size, rank, infinite
-    )
+    if not hf_ds:
+        hf_ds = HuggingFaceDataset(
+            dataset_name, dataset_path, data_processing_style, tokenizer, seq_len, world_size, rank, infinite
+        )
 
     return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size)
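
The new optional hf_ds parameter lets a caller build the validation dataset once and hand it to build_hf_data_loader on every evaluation pass instead of re-creating it each time. A sketch of that usage, assuming an existing tokenizer, seq_len, world_size, rank, and val_batch_size in the surrounding training script; the dataset name and processing style are taken from the debug config:

from torchtitan.datasets.hf_datasets import HuggingFaceDataset, build_hf_data_loader

# Build the validation dataset a single time (infinite=False so each pass is finite).
val_dataset = HuggingFaceDataset(
    "c4_test", None, "chemlactica_style", tokenizer, seq_len, world_size, rank, False
)

# Each time a validation pass starts, rebuild only the loader; passing the
# pre-built dataset skips the HuggingFaceDataset(...) construction inside
# build_hf_data_loader.
val_loader = build_hf_data_loader(
    "c4_test", None, "chemlactica_style", tokenizer,
    val_batch_size, seq_len, world_size, rank, False, val_dataset,
)
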
32 changes: 27 additions & 5 deletions train.py
@@ -17,6 +17,7 @@
 from torchtitan.checkpoint import CheckpointManager, TrainState
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_hf_data_loader
+from torchtitan.datasets.hf_datasets import HuggingFaceDataset
 from torchtitan.tokenizers.tokenizer import build_tokenizer
 from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
@@ -103,8 +104,27 @@ def main(job_config: JobConfig):
         dp_rank
     )
 
-    # validation batch size
-    val_bs = job_config.validation.batch_size if job_config.validation.batch_size != 0 else job_config.metrics.batch_size
+    # check if everything is specified for validation loop
+    if not job_config.validation.dataset_path and not job_config.validation.dataset:
+        raise ValueError("You didn't specify the validation dataset.")
+
+    if not job_config.validation.eval_freq:
+        logger.info("You didn't specify the frequency of evaluation. The default value is 1024.")
+    if job_config.validation.batch_size == None:
+        logger.info("You didn't specify the batch size for validation. The batch size for the training will be used instead.")
+    val_bs = job_config.validation.batch_size if job_config.validation.batch_size else job_config.training.batch_size
+    eval_freq = job_config.validation.eval_freq if job_config.validation.eval_freq else 1024
+
+    val_dataset = HuggingFaceDataset(
+        job_config.validation.dataset,
+        job_config.validation.dataset_path,
+        job_config.training.data_processing_style,
+        tokenizer,
+        job_config.training.seq_len,
+        dp_degree,
+        dp_rank,
+        False,
+    )
 
     # build model (using meta init)
     model_cls = model_name_to_cls[model_name]
@@ -375,19 +395,21 @@ def loss_fn(pred, labels):
         if (
             job_config.validation.enable_val
             and train_state.step == 1
-            or train_state.step % job_config.validation.eval_freq == 0
+            or train_state.step % eval_freq == 0
         ):
             with torch.no_grad():
                 model.eval()
                 val_data_loader = build_hf_data_loader(
                     job_config.validation.dataset,
                     job_config.validation.dataset_path,
                     job_config.training.data_processing_style,
                     tokenizer,
                     val_bs,
                     job_config.training.seq_len,
                     dp_degree,
                     dp_rank,
-                    False
+                    False,
+                    val_dataset
                 )
                 num_flop_per_token_val = utils.get_num_flop_per_token_forward(
                     utils.get_num_params(model, exclude_embedding=True),
@@ -404,7 +426,7 @@ def loss_fn(pred, labels):
                     gpu_memory_monitor,
                     data_loading_times,
                     time_last_val_log,
-                    job_config.validation.eval_freq,
+                    eval_freq,
                     color,
                     train_state.step,
                     num_flop_per_token_val,
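
For reference, the gating above runs a validation pass at step 1 (when validation is enabled) and at every multiple of the resolved eval_freq. A small sketch of that condition with the fallback applied; the helper name is illustrative, not part of the commit:

def should_validate(step: int, enable_val: bool, eval_freq: int) -> bool:
    # Mirrors the condition in train.py; `and` binds tighter than `or`.
    return (enable_val and step == 1) or step % eval_freq == 0

# With eval_freq = 3, as in train_configs/debug_model.toml, this fires at
# steps 1, 3, 6, 9, ...
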
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -46,9 +46,9 @@ dataset = "chemlactica_train_mini" # supported datasets: c4_test (2K), c4 (177M)
 data_process_style="chemlactica_style"
 
 [validation]
-enable_val = true
 batch_size = 1
 eval_freq = 3
+enable_val = true
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
 
 [experimental]
