feat(trajectory_encoder): left to right encoding
AdityaNG committed Feb 29, 2024
1 parent b9b610d commit 5efa0e3
Showing 5 changed files with 111 additions and 43 deletions.
27 changes: 22 additions & 5 deletions drivellava/model.py
@@ -8,6 +8,7 @@
from PIL import Image

from drivellava.constants import LLAVA_PATH
from drivellava.trajectory_encoder import TrajectoryEncoder


def load_image(image_file):
@@ -28,7 +29,9 @@ def load_images(image_files):


class DriveLLaVA:
def __init__(self, args):
def __init__(self, args, trajectory_encoder: TrajectoryEncoder):

self.trajectory_encoder = trajectory_encoder

if LLAVA_PATH not in sys.path:
sys.path.append(LLAVA_PATH)
@@ -113,15 +116,15 @@ def run(self, query: str, image_files: List[str]):
else:
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

print("qs", qs)
# print("qs", qs)

# Prepare conversation
conv = conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

print("prompt", prompt)
# print("prompt", prompt)

# Process images
images = load_images(image_files)
@@ -139,7 +142,7 @@ def run(self, query: str, image_files: List[str]):
.to(self.model.device)
)

print("input_ids", input_ids)
# print("input_ids", input_ids)

# Inference
with torch.inference_mode():
@@ -162,4 +165,18 @@ def run(self, query: str, image_files: List[str]):

outputs = outputs[0]

return outputs
# Output is of the format: "Selected Trajectory: %T"
# where %T is the selected trajectory_token

# Extract the selected trajectory_token
trajectory_token = outputs.split(":")[1].strip()

print("trajectory_token", trajectory_token)

model_trajectory_quantized = self.trajectory_encoder.decode(
trajectory_token
)

print("model_trajectory_quantized", model_trajectory_quantized)

return model_trajectory_quantized
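
The new tail of run() assumes the model's reply matches the literal format "Selected Trajectory: %T" and decodes %T through the trajectory encoder. A minimal sketch of that parsing step in isolation, with a hypothetical token table standing in for a real TrajectoryEncoder (toy_token2trajectory and parse_trajectory_token below are illustrative stand-ins, not part of this commit):

import numpy as np

# Hypothetical stand-in for TrajectoryEncoder.token2trajectory
toy_token2trajectory = {
    "A": np.zeros((100, 2), dtype=np.float32),  # 100 (x, z) waypoints
}

def parse_trajectory_token(outputs: str) -> str:
    """Extract %T from 'Selected Trajectory: %T'."""
    # split(":")[1] is the segment after the first colon (and before any
    # second one); strip() drops the surrounding whitespace
    return outputs.split(":")[1].strip()

token = parse_trajectory_token("Selected Trajectory: A")
assert token == "A"
trajectory_2d = toy_token2trajectory[token]  # decode() then lifts this to 3D

One fragility worth keeping in mind: split(":")[1] silently truncates at any second colon in the reply, and raises IndexError if the model emits no colon at all.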
33 changes: 11 additions & 22 deletions drivellava/scripts/eval.py
@@ -30,10 +30,10 @@ def main():

# from transformers.models.llava.configuration_llava import LlavaConfig

fine_tuned_model_path = "liuhaotian/llava-v1.5-7b"
# fine_tuned_model_path = os.path.expanduser(
# "~/Datasets/checkpoints/checkpoint-1000/"
# )
# fine_tuned_model_path = "liuhaotian/llava-v1.5-7b"
fine_tuned_model_path = os.path.expanduser(
"~/Datasets/checkpoints/checkpoint-600/"
)

args = type(
"Args",
@@ -53,7 +53,13 @@ def main():
},
)()

model = DriveLLaVA(args)
trajectory_encoder = TrajectoryEncoder(
num_trajectory_templates=NUM_TRAJECTORY_TEMPLATES,
trajectory_size=TRAJECTORY_SIZE,
trajectory_templates_npy=TRAJECTORY_TEMPLATES_NPY,
trajectory_templates_kmeans_pkl=TRAJECTORY_TEMPLATES_KMEANS_PKL,
)
model = DriveLLaVA(args, trajectory_encoder)

print(dir(model.tokenizer))
# print(model.tokenizer.get_vocab())
@@ -67,7 +73,6 @@ def main():
# assert os.path.isfile(encoded_video_path), encoded_video_path

pose_path = encoded_video_path.replace("data_", "pose_data_").replace(
# pose_path = encoded_video_path.replace("img_data_", "pose_data_").replace(
"val",
"pose_val",
)
@@ -82,13 +87,6 @@ def main():
if os.path.isfile(frame_path):
decoded_imgs_list.append(frame_path)

trajectory_encoder = TrajectoryEncoder(
num_trajectory_templates=NUM_TRAJECTORY_TEMPLATES,
trajectory_size=TRAJECTORY_SIZE,
trajectory_templates_npy=TRAJECTORY_TEMPLATES_NPY,
trajectory_templates_kmeans_pkl=TRAJECTORY_TEMPLATES_KMEANS_PKL,
)

pose_dataset = CommaVQPoseQuantizedDataset(
pose_path,
num_frames=NUM_FRAMES,
@@ -130,15 +128,6 @@ def main():
decoded_imgs_list[i],
],
)
print(
"model_trajectory_quantized",
len(model_trajectory_quantized),
model_trajectory_quantized,
)
model_trajectory_quantized = model_trajectory_quantized[0]
model_trajectory_quantized = trajectory_encoder.decode(
model_trajectory_quantized
)

dx = trajectory[1:, 2] - trajectory[:-1, 2]
speed = dx / (1.0 / 20.0)
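
The speed computation at the end of the eval loop differences consecutive forward (z) positions and divides by the frame period, so the 20 reads as an assumed 20 Hz pose rate. The same arithmetic on a self-contained toy trajectory (the values are made up):

import numpy as np

FPS = 20.0  # pose rate implied by dx / (1.0 / 20.0)

# Toy (N, 3) trajectory: columns x (lateral), y (height), z (forward), metres
trajectory = np.array(
    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.5], [0.0, 0.0, 1.1]], dtype=np.float32
)

dx = trajectory[1:, 2] - trajectory[:-1, 2]  # forward displacement per frame
speed = dx / (1.0 / FPS)                     # metres per second
print(speed)  # [10. 12.]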
26 changes: 18 additions & 8 deletions drivellava/scripts/train.py
@@ -10,6 +10,8 @@
from typing import List

from drivellava.constants import COMMAVQ_DIR

# from drivellava.constants import ENCODED_JSON, VAL_ENCODED_JSON
from drivellava.trajectory_encoder import TrajectoryEncoder


@@ -25,6 +27,8 @@ def load_json_dataset(
loaded = json.load(f)
for index in range(len(loaded)):
assert len(loaded[index]["conversations"][1]["value"]) == 1
# print('val', loaded[index]["conversations"][1]["value"])
# exit()

loaded[index]["conversations"][1]["value"] = (
"Selected Trajectory: "
@@ -39,8 +43,17 @@


def main():
# train = load_json_dataset(ENCODED_JSON)
# val = load_json_dataset(VAL_ENCODED_JSON)

trajectory_encoder = TrajectoryEncoder()

# train = load_json_dataset(
# ENCODED_JSON,
# trajectory_encoder,
# )
# val = load_json_dataset(
# VAL_ENCODED_JSON,
# trajectory_encoder,
# )

# train_json_path = os.path.abspath("checkpoints/train.json")
# val_json_path = os.path.abspath("checkpoints/val.json")
@@ -57,8 +70,6 @@ def main():
train_json_path = os.path.join(COMMAVQ_DIR, "train.json")
val_json_path = os.path.join(COMMAVQ_DIR, "val.json")

trajectory_encoder = TrajectoryEncoder()

train = load_json_dataset(
[
train_json_path,
@@ -125,13 +136,13 @@ def main():
--group_by_modality_length True \
--bf16 True \
--output_dir {OUTPUT_DIR} \
--num_train_epochs 5 \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "epoch" \
--save_strategy "steps" \
--save_steps 1000 \
--save_steps 50 \
--save_total_limit 1 \
--learning_rate 2e-3 \
--weight_decay 0. \
Expand All @@ -144,8 +155,7 @@ def main():
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb \
--freeze_backbone \
--freeze_mm_mlp_adapter
--freeze_backbone
"""

print(finetune_script)
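
The label rewrite in load_json_dataset is what makes eval.py's split(":") parsing work: each assistant turn starts as a single trajectory token and gains the "Selected Trajectory: " prefix. The continuation of that assignment is collapsed in the diff view above, so the exact expression below is an assumption; a sketch on one fabricated conversation record:

# One fabricated record in the LLaVA conversation format
sample = {
    "conversations": [
        {"from": "human", "value": "<image>\n...prompt..."},
        {"from": "gpt", "value": "A"},  # a single trajectory token
    ]
}

assert len(sample["conversations"][1]["value"]) == 1
# Assumed continuation: prefix + original one-character token
sample["conversations"][1]["value"] = (
    "Selected Trajectory: " + sample["conversations"][1]["value"]
)
print(sample["conversations"][1]["value"])  # Selected Trajectory: A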
15 changes: 11 additions & 4 deletions drivellava/sparse_llava_dataset.py
@@ -117,9 +117,17 @@ def get_drivellava_prompt(
trajectory_encoder: TrajectoryEncoder,
default_image_token: str = DEFAULT_IMAGE_TOKEN,
):
traj_list = list(trajectory_encoder.token2trajectory.keys())
random.shuffle(traj_list)
traj_str = ",".join(list(map(str, traj_list)))
# traj_list = list(trajectory_encoder.token2trajectory.keys())
left_tokens, center_tokens, right_tokens = (
trajectory_encoder.left_to_right()
)
traj_str = (
"The trajectory tokens are sorted from left to center to right\n"
)
traj_str += "Left: " + ",".join(list(map(str, left_tokens))) + "\n"
traj_str += "Center: " + ",".join(list(map(str, center_tokens))) + "\n"
traj_str += "Right: " + ",".join(list(map(str, right_tokens))) + "\n"

P1 = (
f"{default_image_token}\nYou are DriveLLaVA, a "
+ "self-driving car. You will select the "
@@ -178,7 +186,6 @@ def generate_sparse_dataset(

data = []


# Iterate over the embeddings in batches and decode the images
for i in tqdm(
range(WINDOW_LENGTH, len(pose_dataset) - WINDOW_LENGTH, SKIP_FRAMES),
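
With left_to_right() feeding the prompt, get_drivellava_prompt enumerates the token alphabet in spatial order instead of shuffling it. Assuming hypothetical two-token groups, the assembled traj_str would read like this sketch:

# Hypothetical groups in the shape left_to_right() returns them
left_tokens, center_tokens, right_tokens = ["A", "B"], ["C", "D"], ["E", "F"]

traj_str = "The trajectory tokens are sorted from left to center to right\n"
traj_str += "Left: " + ",".join(map(str, left_tokens)) + "\n"
traj_str += "Center: " + ",".join(map(str, center_tokens)) + "\n"
traj_str += "Right: " + ",".join(map(str, right_tokens)) + "\n"
print(traj_str)
# The trajectory tokens are sorted from left to center to right
# Left: A,B
# Center: C,D
# Right: E,F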
53 changes: 49 additions & 4 deletions drivellava/trajectory_encoder.py
@@ -114,7 +114,7 @@ def decode(self, token: str):
trajectory_2d = self.token2trajectory[token]

height_axis = np.zeros_like(trajectory_2d[:, 0])
trajectory_2d = np.stack(
trajectory_3d = np.stack(
(
trajectory_2d[:, 0],
height_axis,
@@ -124,14 +124,59 @@
)

# trajectory_templates is of shape (B, F, 100, 2)
trajectory_2d = trajectory_2d.reshape((self.trajectory_size, 3))
trajectory_2d = trajectory_2d.astype(np.float32)
return trajectory_2d
trajectory_3d = trajectory_3d.reshape((self.trajectory_size, 3))
trajectory_3d = trajectory_3d.astype(np.float32)
return trajectory_3d

def left_to_right(
self,
):
"""Arrange the tokens from left to right
-ve x is left
+ve x is right
"""
x = self.trajectory_templates[:, :, 0]
# x_mean, x_std = x.mean(), x.std()
# y_mean, y_std = (
# self.trajectory_templates[:, :, 1].mean(),
# self.trajectory_templates[:, :, 1].std(),
# )

# x_mean is ~0
# Sort the trajectory_templates by the mean of the x axis
sorted_indices = np.argsort(x.mean(axis=1))

# Split the sorted_indices into left, center and right
left = sorted_indices[: self.num_trajectory_templates // 3]
center = sorted_indices[
self.num_trajectory_templates
// 3 : 2
* self.num_trajectory_templates
// 3
]
right = sorted_indices[2 * self.num_trajectory_templates // 3 :]

left_tokens = [self.trajectory_index_2_token[i] for i in left]
center_tokens = [self.trajectory_index_2_token[i] for i in center]
right_tokens = [self.trajectory_index_2_token[i] for i in right]

# print("self.trajectory_templates", self.trajectory_templates.shape)
# print("left", left.shape)
# print("center", center.shape)
# print("right", right.shape)

# print("left_tokens", left_tokens)
# print("center_tokens", center_tokens)
# print("right_tokens", right_tokens)

return left_tokens, center_tokens, right_tokens


if __name__ == "__main__":
trajectory_encoder = TrajectoryEncoder()

trajectory_encoder.left_to_right()

print(
[
i.encode(ENCODING).decode(ENCODING)
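
The left_to_right() split is easiest to see on a tiny template bank: sort template indices by the mean of their x column, then cut the sorted order into thirds. A self-contained sketch assuming six made-up templates of two (x, z) waypoints each and a hypothetical index-to-token map:

import numpy as np

# Six toy templates, two waypoints each, columns (x, z); -ve x is left
templates = np.array([
    [[-2.0, 1.0], [-2.1, 2.0]],  # far left
    [[ 2.0, 1.0], [ 2.1, 2.0]],  # far right
    [[-0.1, 1.0], [ 0.1, 2.0]],  # near centre
    [[-1.0, 1.0], [-1.2, 2.0]],  # left
    [[ 1.0, 1.0], [ 1.2, 2.0]],  # right
    [[ 0.0, 1.0], [ 0.0, 2.0]],  # centre
])
tokens = ["A", "B", "C", "D", "E", "F"]  # hypothetical index -> token map

n = len(templates)
order = np.argsort(templates[:, :, 0].mean(axis=1))  # ascending mean x
left = order[: n // 3]
center = order[n // 3 : 2 * n // 3]
right = order[2 * n // 3 :]

print([tokens[i] for i in left])    # ['A', 'D'] (most negative mean x)
print([tokens[i] for i in center])  # ['C', 'F']
print([tokens[i] for i in right])   # ['E', 'B']

Because the boundaries use integer division, the three groups are only equal in size when num_trajectory_templates is divisible by 3.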
