Skip to content

Commit

Permalink
feat(train): class balancing
Browse files Browse the repository at this point in the history
  • Loading branch information
AdityaNG committed Feb 29, 2024
1 parent 5efa0e3 commit c50aaa9
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 9 deletions.
2 changes: 1 addition & 1 deletion drivellava/scripts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def main():

# fine_tuned_model_path = "liuhaotian/llava-v1.5-7b"
fine_tuned_model_path = os.path.expanduser(
"~/Datasets/checkpoints/checkpoint-600/"
"~/Datasets/checkpoints/checkpoint-100/"
)

args = type(
Expand Down
76 changes: 68 additions & 8 deletions drivellava/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import random
import subprocess
import sys
from typing import List
from typing import Dict, List

from drivellava.constants import COMMAVQ_DIR

Expand Down Expand Up @@ -42,6 +42,67 @@ def load_json_dataset(
return data


def load_json_dataset_balanced(
json_list: List[str],
trajectory_encoder: TrajectoryEncoder,
):
from drivellava.sparse_llava_dataset import get_drivellava_prompt

data = []
for json_path in json_list:
with open(json_path, "r", encoding="utf-8") as f:
loaded = json.load(f)
for index in range(len(loaded)):
assert len(loaded[index]["conversations"][1]["value"]) == 1
# print('val', loaded[index]["conversations"][1]["value"])
# exit()

loaded[index]["conversations"][1]["value"] = (
"Selected Trajectory: "
+ loaded[index]["conversations"][1]["value"]
)
loaded[index]["conversations"][0]["value"] = (
get_drivellava_prompt(trajectory_encoder)
)
data.extend(loaded)

# Balance by the class given by data[index]["conversations"][1]["value"]
class_dist: Dict[str, int] = {}
for index in range(len(data)):
class_name = data[index]["conversations"][1]["value"]
if class_name in class_dist:
class_dist[class_name] += 1
else:
class_dist[class_name] = 1

min_class = min(class_dist.values())
max_class = max(class_dist.values())
mean_class = sum(class_dist.values()) / len(class_dist)
std_class = sum(
[(x - mean_class) ** 2 for x in class_dist.values()]
) / len(class_dist)
std_class = std_class**0.5
print(
f"Min class: {min_class}, Max class: {max_class}, "
f"Mean class: {mean_class}, Std class: {std_class}"
)

threshold = min_class * 5

final_data = []
final_dist: Dict[str, int] = {}
for index in range(len(data)):
class_name = data[index]["conversations"][1]["value"]
if class_name in final_dist:
final_dist[class_name] += 1
else:
final_dist[class_name] = 1
if final_dist[class_name] < threshold:
final_data.append(data[index])

return final_data


def main():

trajectory_encoder = TrajectoryEncoder()
Expand Down Expand Up @@ -70,7 +131,7 @@ def main():
train_json_path = os.path.join(COMMAVQ_DIR, "train.json")
val_json_path = os.path.join(COMMAVQ_DIR, "val.json")

train = load_json_dataset(
train = load_json_dataset_balanced(
[
train_json_path,
],
Expand All @@ -83,6 +144,9 @@ def main():
trajectory_encoder,
)

print(f"Train: {len(train)}")
print(f"Val: {len(val)}")

# Shuffle train and val
random.shuffle(train)
random.shuffle(val)
Expand All @@ -99,9 +163,6 @@ def main():
json_data = json.dumps(val, ensure_ascii=False, indent=4)
f.write(json_data)

print(f"Train: {len(train)}")
print(f"Val: {len(val)}")

# Assign paths to variables
WORKING_DIR = os.path.abspath("./LLaVA/")
DEEPSPEED_SCRIPT = "deepspeed llava/train/train_mem.py"
Expand Down Expand Up @@ -136,7 +197,7 @@ def main():
--group_by_modality_length True \
--bf16 True \
--output_dir {OUTPUT_DIR} \
--num_train_epochs 1 \
--num_train_epochs 4 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
Expand All @@ -154,8 +215,7 @@ def main():
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb \
--freeze_backbone
--report_to wandb
"""

print(finetune_script)
Expand Down

0 comments on commit c50aaa9

Please sign in to comment.