Skip to content

Commit

Permalink
feat(train): dataset augmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
AdityaNG committed Feb 28, 2024
1 parent 09df6a8 commit 9118563
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
37 changes: 30 additions & 7 deletions drivellava/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,30 @@

import json
import os
import random
import subprocess
import sys
from typing import List

from drivellava.constants import COMMAVQ_DIR
from drivellava.trajectory_encoder import TrajectoryEncoder


def load_json_dataset(
    json_list: List[str],
    trajectory_encoder: TrajectoryEncoder,
) -> List[dict]:
    """Load conversation records from JSON files and re-prompt each one.

    Every file in ``json_list`` is parsed as a JSON list of records. For
    each record the first conversation turn
    (``record["conversations"][0]["value"]``) is overwritten with a fresh
    prompt from ``get_drivellava_prompt``, so every record gets its own
    call to the prompt generator.

    Args:
        json_list: Paths of the JSON dataset files to load, in order.
        trajectory_encoder: Encoder handed to ``get_drivellava_prompt``
            to build the prompt text.

    Returns:
        A single flat list with the (mutated) records of all files,
        in file order.

    Raises:
        ValueError: If the second conversation turn of a record does not
            have length 1.
    """
    # Kept as a function-scope import, exactly where the original had it.
    from drivellava.sparse_llava_dataset import get_drivellava_prompt

    data = []
    for json_path in json_list:
        with open(json_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)

        for index, entry in enumerate(loaded):
            conversations = entry["conversations"]
            answer = conversations[1]["value"]
            # Explicit check instead of ``assert`` so the validation still
            # runs under ``python -O``.
            # NOTE(review): presumably the answer is one trajectory token
            # character -- confirm against the dataset generator.
            if len(answer) != 1:
                raise ValueError(
                    f"{json_path}: record {index} has an answer of length "
                    f"{len(answer)} (expected 1): {answer!r}"
                )
            conversations[0]["value"] = get_drivellava_prompt(
                trajectory_encoder
            )

        data.extend(loaded)

    return data
Expand All @@ -49,17 +52,37 @@ def main():
train_json_path = os.path.join(COMMAVQ_DIR, "train.json")
val_json_path = os.path.join(COMMAVQ_DIR, "val.json")

trajectory_encoder = TrajectoryEncoder()

train = load_json_dataset(
[
train_json_path,
]
],
trajectory_encoder,
)
val = load_json_dataset(
[
val_json_path,
]
],
trajectory_encoder,
)

# Shuffle train and val
random.shuffle(train)
random.shuffle(val)

new_train_json_path = os.path.abspath("checkpoints/train.json")
new_val_json_path = os.path.abspath("checkpoints/val.json")

# Save train to a temp file
with open(new_train_json_path, "w", encoding="utf-8") as f:
json_data = json.dumps(train, ensure_ascii=False, indent=4)
f.write(json_data)

with open(new_val_json_path, "w", encoding="utf-8") as f:
json_data = json.dumps(val, ensure_ascii=False, indent=4)
f.write(json_data)

print(f"Train: {len(train)}")
print(f"Val: {len(val)}")

Expand All @@ -68,7 +91,7 @@ def main():
DEEPSPEED_SCRIPT = "deepspeed llava/train/train_mem.py"
DEEPSPEED_JSON = os.path.abspath("./config/zero3.json")
MODEL_NAME = "liuhaotian/llava-v1.5-7b"
DATA_PATH = train_json_path # Replace with your JSON data path
DATA_PATH = new_train_json_path # Replace with your JSON data path
IMAGE_FOLDER = os.path.expanduser(
"~/Datasets/commavq"
) # Replace with your image folder path
Expand Down
19 changes: 16 additions & 3 deletions drivellava/sparse_llava_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import os
import random
import sys

import cv2
Expand Down Expand Up @@ -113,15 +114,27 @@ def visualize_pose(


# Prompt phrasings used for augmentation. Each template takes the
# placeholders ``{image_token}`` and ``{traj_str}`` and is only formatted
# once it has been selected, instead of building all nine on every call.
_DRIVELLAVA_PROMPT_TEMPLATES = (
    "{image_token}\nYou are DriveLLaVA, a "
    "self-driving car. You will select the "
    "appropriate trajectory token given the "
    "above image as context.\n"
    "You may select one from the "
    "following templates: {traj_str}",
    "{image_token} As DriveLLaVA, the autonomous vehicle, your task is "
    "to analyze the given image and determine the optimal driving path. "
    "Choose the most suitable trajectory option from the list provided "
    "based on the visual information. {traj_str}",
    "{image_token} You are the AI system DriveLLaVA, responsible for "
    "navigating self-driving cars. With the image provided as your "
    "guide, select the correct trajectory from the options below to "
    "ensure a safe and efficient route. {traj_str}",
    "{image_token} Imagine yourself as DriveLLaVA, an advanced "
    "self-driving vehicle intelligence. Examine the scenario depicted in "
    "the image and decide on the best course of action by selecting an "
    "appropriate trajectory from the given templates. {traj_str}",
    "{image_token} You embody DriveLLaVA, the brain behind autonomous "
    "driving technology. Given the context shown in the image, it's your "
    "job to pick the right trajectory from the available choices to "
    "navigate safely. {traj_str}",
    "{image_token} As DriveLLaVA, a pioneering self-driving car AI, "
    "you're tasked with interpreting the visual cues in the provided "
    "image to choose the most suitable trajectory from the list of "
    "options to ensure a smooth journey. {traj_str}",
    "{image_token} You, as DriveLLaVA, are at the forefront of "
    "autonomous navigation. Assess the situation depicted in the image "
    "and select the trajectory that best aligns with safe and efficient "
    "driving principles from the options provided. {traj_str}",
    "{image_token} Functioning as DriveLLaVA, the self-driving car's "
    "decision-making system, you must look at the image and determine "
    "the best path forward by choosing from the predefined trajectory "
    "templates. {traj_str}",
    "{image_token} You are DriveLLaVA, an AI designed for autonomous "
    "vehicles. Your objective is to analyze the context presented in the "
    "image and select a trajectory that guarantees the safety and "
    "comfort of your passengers from the given templates. {traj_str}",
)


def get_drivellava_prompt(trajectory_encoder: TrajectoryEncoder) -> str:
    """Return a randomly phrased DriveLLaVA prompt with shuffled options.

    Two independent augmentations are applied on every call: the keys of
    ``trajectory_encoder.token2trajectory`` are listed in a freshly
    shuffled order, and the surrounding instruction text is drawn at
    random from ``_DRIVELLAVA_PROMPT_TEMPLATES``.

    Args:
        trajectory_encoder: Object whose ``token2trajectory`` mapping
            supplies the selectable tokens (its keys).

    Returns:
        The prompt text, starting with ``DEFAULT_IMAGE_TOKEN`` and ending
        with the comma-separated, shuffled token list.
    """
    # Shuffle a copy so the encoder's own mapping is left untouched.
    traj_list = list(trajectory_encoder.token2trajectory)
    random.shuffle(traj_list)
    traj_str = ",".join(map(str, traj_list))

    template = random.choice(_DRIVELLAVA_PROMPT_TEMPLATES)
    return template.format(
        image_token=DEFAULT_IMAGE_TOKEN,
        traj_str=traj_str,
    )


def generate_sparse_dataset(
Expand Down

0 comments on commit 9118563

Please sign in to comment.