From 819cd8f09abbdf451e4e3d78232016f96eb1d864 Mon Sep 17 00:00:00 2001 From: Aditya NG Date: Sun, 3 Mar 2024 10:27:59 +0530 Subject: [PATCH] feat(train): eval dataset pointed --- LLaVA/llava/train/train.py | 6 +++++- drivellava/scripts/compile_jsons.py | 12 +++++++++--- drivellava/scripts/train.py | 11 ++++++++--- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/LLaVA/llava/train/train.py b/LLaVA/llava/train/train.py index 477c668..f8242ea 100644 --- a/LLaVA/llava/train/train.py +++ b/LLaVA/llava/train/train.py @@ -780,8 +780,12 @@ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_path=data_args.data_path, data_args=data_args) data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + + eval_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.validation_data_path, + data_args=data_args) return dict(train_dataset=train_dataset, - eval_dataset=None, + eval_dataset=eval_dataset, data_collator=data_collator) diff --git a/drivellava/scripts/compile_jsons.py b/drivellava/scripts/compile_jsons.py index 0f47240..1ca4f36 100644 --- a/drivellava/scripts/compile_jsons.py +++ b/drivellava/scripts/compile_jsons.py @@ -10,9 +10,11 @@ from drivellava.constants import ENCODED_JSON, VAL_ENCODED_JSON from drivellava.sparse_llava_dataset import get_drivellava_prompt from drivellava.trajectory_encoder import ( - TrajectoryEncoder, NUM_TRAJECTORY_TEMPLATES + NUM_TRAJECTORY_TEMPLATES, + TrajectoryEncoder, ) + def load_json_dataset( json_list: List[str], trajectory_encoder: TrajectoryEncoder, @@ -57,8 +59,12 @@ def main(): random.shuffle(train) random.shuffle(val) - new_train_json_path = os.path.abspath(f"checkpoints/train_{str(NUM_TRAJECTORY_TEMPLATES)}.json") - new_val_json_path = os.path.abspath(f"checkpoints/val_{NUM_TRAJECTORY_TEMPLATES}.json") + new_train_json_path = os.path.abspath( + f"checkpoints/train_{str(NUM_TRAJECTORY_TEMPLATES)}.json" + ) + new_val_json_path = os.path.abspath( + f"checkpoints/val_{NUM_TRAJECTORY_TEMPLATES}.json" + ) # Save train to a temp file with open(new_train_json_path, "w", encoding="utf-8") as f: diff --git a/drivellava/scripts/train.py b/drivellava/scripts/train.py index befdac7..48e8f32 100644 --- a/drivellava/scripts/train.py +++ b/drivellava/scripts/train.py @@ -74,8 +74,12 @@ def load_json_dataset_balanced( def main(): - train_json_path = os.path.join(COMMAVQ_DIR, f"train_{str(NUM_TRAJECTORY_TEMPLATES)}.json") - val_json_path = os.path.join(COMMAVQ_DIR, f"val_{str(NUM_TRAJECTORY_TEMPLATES)}.json") + train_json_path = os.path.join( + COMMAVQ_DIR, f"train_{str(NUM_TRAJECTORY_TEMPLATES)}.json" + ) + val_json_path = os.path.join( + COMMAVQ_DIR, f"val_{str(NUM_TRAJECTORY_TEMPLATES)}.json" + ) train = load_json_dataset_balanced( [ @@ -113,7 +117,7 @@ def main(): DEEPSPEED_JSON = os.path.abspath("./config/zero3.json") MODEL_NAME = "liuhaotian/llava-v1.5-7b" DATA_PATH = new_train_json_path # Replace with your JSON data path - # VAL_DATA_PATH = new_val_json_path + VAL_DATA_PATH = new_val_json_path IMAGE_FOLDER = os.path.expanduser( "~/Datasets/commavq" ) # Replace with your image folder path @@ -131,6 +135,7 @@ def main(): --model_name_or_path {MODEL_NAME} \ --version llava_llama_2 \ --data_path {DATA_PATH} \ + --validation_data_path {VAL_DATA_PATH} \ --image_folder {IMAGE_FOLDER} \ --vision_tower {VISION_TOWER} \ --mm_projector_type mlp2x_gelu \