diff --git a/drivellava/scripts/eval.py b/drivellava/scripts/eval.py index c5185c8..320d489 100644 --- a/drivellava/scripts/eval.py +++ b/drivellava/scripts/eval.py @@ -32,7 +32,7 @@ def main(): # fine_tuned_model_path = "liuhaotian/llava-v1.5-7b" fine_tuned_model_path = os.path.expanduser( - "~/Datasets/checkpoints/checkpoint-600/" + "~/Datasets/checkpoints/checkpoint-100/" ) args = type( diff --git a/drivellava/scripts/train.py b/drivellava/scripts/train.py index d82e9b2..788630c 100644 --- a/drivellava/scripts/train.py +++ b/drivellava/scripts/train.py @@ -7,7 +7,7 @@ import random import subprocess import sys -from typing import List +from typing import Dict, List from drivellava.constants import COMMAVQ_DIR @@ -42,6 +42,67 @@ def load_json_dataset( return data +def load_json_dataset_balanced( + json_list: List[str], + trajectory_encoder: TrajectoryEncoder, +): + from drivellava.sparse_llava_dataset import get_drivellava_prompt + + data = [] + for json_path in json_list: + with open(json_path, "r", encoding="utf-8") as f: + loaded = json.load(f) + for index in range(len(loaded)): + assert len(loaded[index]["conversations"][1]["value"]) == 1 + # print('val', loaded[index]["conversations"][1]["value"]) + # exit() + + loaded[index]["conversations"][1]["value"] = ( + "Selected Trajectory: " + + loaded[index]["conversations"][1]["value"] + ) + loaded[index]["conversations"][0]["value"] = ( + get_drivellava_prompt(trajectory_encoder) + ) + data.extend(loaded) + + # Balance by the class given by data[index]["conversations"][1]["value"] + class_dist: Dict[str, int] = {} + for index in range(len(data)): + class_name = data[index]["conversations"][1]["value"] + if class_name in class_dist: + class_dist[class_name] += 1 + else: + class_dist[class_name] = 1 + + min_class = min(class_dist.values()) + max_class = max(class_dist.values()) + mean_class = sum(class_dist.values()) / len(class_dist) + std_class = sum( + [(x - mean_class) ** 2 for x in class_dist.values()] + ) / len(class_dist) + std_class = std_class**0.5 + print( + f"Min class: {min_class}, Max class: {max_class}, " + f"Mean class: {mean_class}, Std class: {std_class}" + ) + + threshold = min_class * 5 + + final_data = [] + final_dist: Dict[str, int] = {} + for index in range(len(data)): + class_name = data[index]["conversations"][1]["value"] + if class_name in final_dist: + final_dist[class_name] += 1 + else: + final_dist[class_name] = 1 + if final_dist[class_name] < threshold: + final_data.append(data[index]) + + return final_data + + def main(): trajectory_encoder = TrajectoryEncoder() @@ -70,7 +131,7 @@ def main(): train_json_path = os.path.join(COMMAVQ_DIR, "train.json") val_json_path = os.path.join(COMMAVQ_DIR, "val.json") - train = load_json_dataset( + train = load_json_dataset_balanced( [ train_json_path, ], @@ -83,6 +144,9 @@ def main(): trajectory_encoder, ) + print(f"Train: {len(train)}") + print(f"Val: {len(val)}") + # Shuffle train and val random.shuffle(train) random.shuffle(val) @@ -99,9 +163,6 @@ def main(): json_data = json.dumps(val, ensure_ascii=False, indent=4) f.write(json_data) - print(f"Train: {len(train)}") - print(f"Val: {len(val)}") - # Assign paths to variables WORKING_DIR = os.path.abspath("./LLaVA/") DEEPSPEED_SCRIPT = "deepspeed llava/train/train_mem.py" @@ -136,7 +197,7 @@ def main(): --group_by_modality_length True \ --bf16 True \ --output_dir {OUTPUT_DIR} \ - --num_train_epochs 1 \ + --num_train_epochs 4 \ --per_device_train_batch_size 16 \ --per_device_eval_batch_size 4 \ --gradient_accumulation_steps 1 \ @@ -154,8 +215,7 @@ def main(): --gradient_checkpointing True \ --dataloader_num_workers 4 \ --lazy_preprocess True \ - --report_to wandb \ - --freeze_backbone + --report_to wandb """ print(finetune_script)