Commit

feat(drivellava/scripts/generate_sparse_llava_dataset_parallel.py): parallel dataset generation
AdityaNG committed Feb 25, 2024
1 parent 8ec0bd1 commit c28fda0
Showing 7 changed files with 293 additions and 204 deletions.
20 changes: 20 additions & 0 deletions README.md
@@ -31,6 +31,26 @@ $ drivellava

Read the [CONTRIBUTING.md](CONTRIBUTING.md) file.

## Dataset

```bash
# Clone DriveLLaVA
cd ~/
git clone https://github.com/AdityaNG/DriveLLaVA

# Clone the commaVQ dataset and decoder; GIT_LFS_SKIP_SMUDGE=1 skips
# the large LFS blobs at clone time so they can be pulled selectively
mkdir -p ~/Datasets
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/commaai/commavq ~/Datasets/commavq
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/commaai/commavq-gpt2m ~/Datasets/commavq-gpt2m

# Pull only the ONNX decoder weights
cd ~/Datasets/commavq-gpt2m
git lfs pull --include "*.onnx"

# Pull the full dataset
cd ~/Datasets/commavq
git lfs pull

# Generate the sparse LLaVA dataset
cd ~/DriveLLaVA
python3 -m drivellava.scripts.generate_sparse_llava_dataset
```
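A parallel variant of the same generation script is added in this commit; on a multi-core machine it can be used instead:

```bash
python3 -m drivellava.scripts.generate_sparse_llava_dataset_parallel
```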

## Running the scripts

1 change: 1 addition & 0 deletions drivellava/constants.py
@@ -52,6 +52,7 @@ def __getitem__(self, index):

ENCODED_JSON_ALL = ENCODED_JSON + VAL_ENCODED_JSON


def get_image_path(encoded_video_path: str, index: int) -> str:
return os.path.join(
encoded_video_path.replace("val", "img_val").replace(".npy", ""),
198 changes: 8 additions & 190 deletions drivellava/scripts/generate_sparse_llava_dataset.py
@@ -2,215 +2,33 @@
Generates image frames for the commavq dataset
"""

import json
import os

import cv2
import numpy as np
from tqdm import tqdm

from drivellava.constants import (
DECODER_ONNX_PATH,
ENCODED_POSE_ALL,
get_image_path,
get_json,
)
from drivellava.datasets.commavq import CommaVQPoseQuantizedDataset
from drivellava.onnx import load_model_from_onnx_comma
from drivellava.trajectory_encoder import (
NUM_TRAJECTORY_TEMPLATES,
TRAJECTORY_SIZE,
TRAJECTORY_TEMPLATES_KMEANS_PKL,
TRAJECTORY_TEMPLATES_NPY,
TrajectoryEncoder,
)
from drivellava.utils import (
decode_image,
plot_bev_trajectory,
plot_steering_traj,
)


def visualize_pose(
frame_path: str,
trajectory_encoder: TrajectoryEncoder,
trajectory_encoded: str,
trajectory: np.ndarray,
):
img = cv2.imread(frame_path)

trajectory_quantized = trajectory_encoder.decode(trajectory_encoded)

print(
"trajectory[0]",
(np.min(trajectory[:, 0]), np.max(trajectory[:, 0])),
)
print(
"trajectory[1]",
(np.min(trajectory[:, 1]), np.max(trajectory[:, 1])),
)
print(
"trajectory[2]",
(np.min(trajectory[:, 2]), np.max(trajectory[:, 2])),
)
dx = trajectory[1:, 2] - trajectory[:-1, 2]
speed = dx / (1.0 / 20.0)
    # m/s to km/h
    speed_kmph = speed * 3.6
    # km/h to mph
    speed_mph = speed_kmph * 0.621371

img = plot_steering_traj(
img,
trajectory,
color=(255, 0, 0),
)

img = plot_steering_traj(
img,
trajectory_quantized,
color=(0, 255, 0),
)

img_bev = plot_bev_trajectory(trajectory, img, color=(255, 0, 0))
img_bev = plot_bev_trajectory(trajectory_quantized, img, color=(0, 255, 0))

# Write speed on img
font = cv2.FONT_HERSHEY_SIMPLEX
bottomLeftCornerOfText = (10, 50)
fontScale = 0.5
fontColor = (255, 255, 255)
lineType = 2

img = cv2.resize(img, (0, 0), fx=2, fy=2)

cv2.putText(
img,
f"Speed: {speed_mph.mean():.2f} mph",
bottomLeftCornerOfText,
font,
fontScale,
fontColor,
lineType,
)

cv2.imshow("frame_path", img)

cv2.imshow("frame_path_bev", cv2.resize(img_bev, (0, 0), fx=2, fy=2))

key = cv2.waitKey(1)

if key == ord("q"):
exit()
+from drivellava.constants import ENCODED_POSE_ALL
+from drivellava.sparse_llava_dataset import generate_sparse_dataset
+from drivellava.trajectory_encoder import TRAJECTORY_SIZE


def main():

NUM_FRAMES = TRAJECTORY_SIZE
WINDOW_LENGTH = 21 * 2 - 1
SKIP_FRAMES = 20 * 20
PLOT_TRAJ = False

batch_size = 1
decoder_onnx = load_model_from_onnx_comma(DECODER_ONNX_PATH, device="cuda")

trajectory_encoder = TrajectoryEncoder(
num_trajectory_templates=NUM_TRAJECTORY_TEMPLATES,
trajectory_size=TRAJECTORY_SIZE,
trajectory_templates_npy=TRAJECTORY_TEMPLATES_NPY,
trajectory_templates_kmeans_pkl=TRAJECTORY_TEMPLATES_KMEANS_PKL,
)

for pose_index, pose_path in tqdm(
enumerate(ENCODED_POSE_ALL),
desc="Generating sparse LLaVA dataset",
total=len(ENCODED_POSE_ALL),
):

-        pose_dataset = CommaVQPoseQuantizedDataset(
+        generate_sparse_dataset(
             pose_path,
-            num_frames=NUM_FRAMES,
-            window_length=WINDOW_LENGTH,
-            polyorder=1,
-            trajectory_encoder=trajectory_encoder,
-        )
-
-        encoded_video_path = pose_path.replace("pose_data", "data").replace(
-            "pose_val", "val"
+            pose_index,
+            NUM_FRAMES,
+            WINDOW_LENGTH,
+            SKIP_FRAMES,
         )

json_path = get_json(encoded_video_path)

if os.path.isfile(json_path):
continue

data = []

assert os.path.exists(encoded_video_path)

# embeddings: (1200, 8, 16) -> (B, x, y)
embeddings = np.load(encoded_video_path)

# Iterate over the embeddings in batches and decode the images
for i in tqdm(
range(
WINDOW_LENGTH, len(pose_dataset) - WINDOW_LENGTH, SKIP_FRAMES
),
desc="Video",
disable=True,
):
frame_path = get_image_path(encoded_video_path, i)

trajectory, trajectory_encoded = pose_dataset[i]

if not os.path.isfile(frame_path):
embeddings_batch = embeddings[i : i + batch_size]
frames = decode_image(
decoder_onnx,
embeddings_batch,
batch_size,
)
frame = frames[0]
os.makedirs(os.path.dirname(frame_path), exist_ok=True)
cv2.imwrite(frame_path, frame)

unique_id = pose_index * 100000 + i
data += [
{
"id": str(unique_id),
"image": frame_path,
"conversations": [
{
"from": "human",
"value": (
"<image>\nYou are DriveLLaVA, a "
+ "self-driving car. You will select the "
+ "appropriate trrajectory token given the "
+ "above image as context.\n"
+ "You may select one from the "
+ "following templates: "
+ ",".join(
trajectory_encoder.token2trajectory.keys()
)
),
},
{"from": "gpt", "value": trajectory_encoded},
],
}
]

if PLOT_TRAJ:
visualize_pose(
frame_path,
trajectory_encoder,
trajectory_encoded,
trajectory,
)

# Write to json
with open(json_path, "w") as f:
json.dump(data, f, indent=4)


if __name__ == "__main__":
main()
61 changes: 61 additions & 0 deletions drivellava/scripts/generate_sparse_llava_dataset_parallel.py
@@ -0,0 +1,61 @@
"""
Generates image frames for the commavq dataset using parallel processing
"""

import concurrent.futures

from tqdm import tqdm

from drivellava.constants import ENCODED_POSE_ALL
from drivellava.sparse_llava_dataset import generate_sparse_dataset
from drivellava.trajectory_encoder import TRAJECTORY_SIZE


def generate_frame(pose_path_num_frames_window_length_skip_frames):
"""
Wrapper function to call generate_sparse_dataset with all necessary
arguments.
This is needed because ProcessPoolExecutor.map only supports functions
with a single argument.
"""
pose_path, num_frames, window_length, skip_frames = (
pose_path_num_frames_window_length_skip_frames
)
generate_sparse_dataset(
pose_path,
num_frames,
window_length,
skip_frames,
)


def main():

NUM_FRAMES = TRAJECTORY_SIZE
WINDOW_LENGTH = 21 * 2 - 1
SKIP_FRAMES = 20 * 20

# Prepare a list of arguments for each task
tasks = [
(pose_path, NUM_FRAMES, WINDOW_LENGTH, SKIP_FRAMES)
for pose_path in ENCODED_POSE_ALL
]

# Initialize progress bar
pbar = tqdm(
total=len(ENCODED_POSE_ALL), desc="Generating sparse LLaVA dataset"
)

# Use ProcessPoolExecutor to parallelize dataset generation
with concurrent.futures.ProcessPoolExecutor() as executor:
        # Map generate_frame across all tasks. executor.map yields
        # results in submission order, so the progress bar advances
        # as each result becomes available.
for _ in executor.map(generate_frame, tasks):
pbar.update(1)

pbar.close()


if __name__ == "__main__":
main()
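Note that `executor.map` yields results in submission order, so the bar above advances in task order rather than completion order; `executor.map` also accepts multiple iterables (one per positional argument), which would remove the need for the tuple-packing wrapper. A minimal sketch of a completion-order alternative, reusing the same `generate_frame` and `tasks` as above with `submit`/`as_completed`:

```python
import concurrent.futures

from tqdm import tqdm


def main_as_completed(tasks):
    with concurrent.futures.ProcessPoolExecutor() as executor:
        # Submit all tasks up front; as_completed yields each future
        # as soon as it finishes, regardless of submission order.
        futures = [executor.submit(generate_frame, task) for task in tasks]
        for future in tqdm(
            concurrent.futures.as_completed(futures),
            total=len(futures),
            desc="Generating sparse LLaVA dataset",
        ):
            future.result()  # surface any exception raised in the worker
```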
21 changes: 9 additions & 12 deletions drivellava/scripts/train.py
@@ -2,17 +2,14 @@
Trains LLAVA model on the cumulative dataset.
"""

+import json
 import os
-import sys
 import subprocess
+import sys
 from typing import List
 
-import json
+from drivellava.constants import ENCODED_JSON, VAL_ENCODED_JSON
 
-from drivellava.constants import (
-    ENCODED_JSON,
-    VAL_ENCODED_JSON,
-)

def load_json_dataset(
json_list: List[str],
@@ -21,9 +18,10 @@ def load_json_dataset(
for json_path in json_list:
with open(json_path, "r") as f:
data.extend(json.load(f))

return data


def main():
train = load_json_dataset(ENCODED_JSON)
val = load_json_dataset(VAL_ENCODED_JSON)
@@ -50,9 +48,10 @@ def main():
sys.path.append(WORKING_DIR)

# Command to run the script
-    finetune_script = f'''
+    finetune_script = f"""
     {DEEPSPEED_SCRIPT} \
-    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
+    --lora_enable True --lora_r 128 --lora_alpha 256 \
+    --mm_projector_lr 2e-5 \
--deepspeed {DEEPSPEED_JSON} \
--model_name_or_path {MODEL_NAME} \
--version v1 \
@@ -88,15 +87,13 @@ def main():
--report_to wandb \
--freeze_backbone \
--freeze_mm_mlp_adapter
-    '''
+    """

print(finetune_script)

# Run the command in WORKING_DIR
subprocess.run(finetune_script, cwd=WORKING_DIR, shell=True)




if __name__ == "__main__":
main()
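One possible hardening, noted as an assumption about intent and not part of this commit: passing `check=True` makes `subprocess.run` raise `CalledProcessError` when the DeepSpeed command exits non-zero, so a failed fine-tune stops the script instead of passing silently.

```python
# Sketch: raise on a non-zero DeepSpeed exit instead of continuing silently.
subprocess.run(finetune_script, cwd=WORKING_DIR, shell=True, check=True)
```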
