From 3c286afa6a4060596db9a123d170d295b4b61704 Mon Sep 17 00:00:00 2001
From: Aditya
Date: Sun, 25 Feb 2024 13:33:22 +0000
Subject: [PATCH] feat(eval): drivellava inference script

---
 .dockerignore                                 |   1 +
 .gitignore                                    |   1 +
 drivellava/model.py                           | 150 ++++++++++++++++
 drivellava/scripts/eval.py                    | 165 ++++++++++++++++++
 drivellava/scripts/generate_commavq_images.py |  10 +-
 .../scripts/generate_sparse_llava_dataset.py  |  10 +-
 drivellava/sparse_llava_dataset.py            |  34 ++--
 7 files changed, 348 insertions(+), 23 deletions(-)
 create mode 100644 drivellava/model.py
 create mode 100644 drivellava/scripts/eval.py

diff --git a/.dockerignore b/.dockerignore
index 9c2001d..168dc64 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -125,3 +125,4 @@ checkpoints_pretrained/
 docker/

 core
+test_media/*
diff --git a/.gitignore b/.gitignore
index 285c852..c4e3162 100644
--- a/.gitignore
+++ b/.gitignore
@@ -132,3 +132,4 @@ dmypy.json
 .github/templates/*

 checkpoints/*
+test_media/*
diff --git a/drivellava/model.py b/drivellava/model.py
new file mode 100644
index 0000000..9fce147
--- /dev/null
+++ b/drivellava/model.py
@@ -0,0 +1,150 @@
+import os
+import re
+import sys
+from io import BytesIO
+
+import requests
+import torch
+from PIL import Image
+
+
+def image_parser(args):
+    out = args.image_file.split(args.sep)
+    return out
+
+
+def load_image(image_file):
+    if image_file.startswith("http") or image_file.startswith("https"):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert("RGB")
+    else:
+        image = Image.open(image_file).convert("RGB")
+    return image
+
+
+def load_images(image_files):
+    out = []
+    for image_file in image_files:
+        image = load_image(image_file)
+        out.append(image)
+    return out
+
+
+class DriveLLaVA:
+    def __init__(self, args):
+
+        LLAVA_PATH = os.path.abspath("./LLaVA")
+
+        if LLAVA_PATH not in sys.path:
+            sys.path.append(LLAVA_PATH)
+
+        from llava.mm_utils import get_model_name_from_path
+        from llava.model.builder import load_pretrained_model
+        from llava.utils import disable_torch_init
+
+        # Model Initialization
+        # disable_torch_init() skips default weight init to speed up loading
+        disable_torch_init()
+
+        self.model_name = get_model_name_from_path(args.model_path)
+        self.tokenizer, self.model, self.image_processor, self.context_len = (
+            load_pretrained_model(
+                args.model_path, args.model_base, self.model_name
+            )
+        )
+
+        # Infer conversation mode based on model name
+        if "llama-2" in self.model_name.lower():
+            self.conv_mode = "llava_llama_2"
+        elif "mistral" in self.model_name.lower():
+            self.conv_mode = "mistral_instruct"
+        elif "v1.6-34b" in self.model_name.lower():
+            self.conv_mode = "chatml_direct"
+        elif "v1" in self.model_name.lower():
+            self.conv_mode = "llava_v1"
+        elif "mpt" in self.model_name.lower():
+            self.conv_mode = "mpt"
+        else:
+            self.conv_mode = "llava_v0"
+
+        if args.conv_mode is not None and self.conv_mode != args.conv_mode:
+            print(
+                f"[WARNING] the auto inferred conversation mode is "
+                f"{self.conv_mode}, while `--conv-mode` is {args.conv_mode}, "
+                f"using {args.conv_mode}"
+            )
+            self.conv_mode = args.conv_mode
+
+        self.args = args
+
+    def run(self, query, image_files):
+
+        from llava.constants import (
+            DEFAULT_IM_END_TOKEN,
+            DEFAULT_IM_START_TOKEN,
+            DEFAULT_IMAGE_TOKEN,
+            IMAGE_PLACEHOLDER,
+            IMAGE_TOKEN_INDEX,
+        )
+        from llava.conversation import conv_templates
+        from llava.mm_utils import process_images, tokenizer_image_token
+
+        # Process query
+        qs = query
+        image_token_se = (
+            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+        )
+        if IMAGE_PLACEHOLDER in qs:
+            if self.model.config.mm_use_im_start_end:
+                qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
+            else:
+                qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
+        else:
+            if self.model.config.mm_use_im_start_end:
+                qs = image_token_se + "\n" + qs
+            else:
+                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+        # Prepare conversation
+        conv = conv_templates[self.conv_mode].copy()
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+        # Process images
+        # image_files = image_parser(self.args)
+        images = load_images(image_files)
+        image_sizes = [x.size for x in images]
+        images_tensor = process_images(
+            images, self.image_processor, self.model.config
+        ).to(self.model.device, dtype=torch.float16)
+
+        # Tokenize prompt
+        input_ids = (
+            tokenizer_image_token(
+                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+            )
+            .unsqueeze(0)
+            .cuda()
+        )
+
+        # Inference
+        with torch.inference_mode():
+            output_ids = self.model.generate(
+                input_ids,
+                images=images_tensor,
+                image_sizes=image_sizes,
+                do_sample=True if self.args.temperature > 0 else False,
+                temperature=self.args.temperature,
+                top_p=self.args.top_p,
+                num_beams=self.args.num_beams,
+                max_new_tokens=self.args.max_new_tokens,
+                use_cache=True,
+            )
+
+        outputs = self.tokenizer.batch_decode(
+            output_ids, skip_special_tokens=True
+        )[0].strip()
+        print(outputs)
+
+        return outputs
diff --git a/drivellava/scripts/eval.py b/drivellava/scripts/eval.py
new file mode 100644
index 0000000..616a6f9
--- /dev/null
+++ b/drivellava/scripts/eval.py
@@ -0,0 +1,165 @@
+"""
+Evaluates DriveLLaVA on a video sequence
+"""
+
+import os
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+from drivellava.constants import get_image_path
+from drivellava.datasets.commavq import CommaVQPoseQuantizedDataset
+from drivellava.model import DriveLLaVA
+from drivellava.sparse_llava_dataset import get_drivellava_prompt
+from drivellava.trajectory_encoder import (
+    NUM_TRAJECTORY_TEMPLATES,
+    TRAJECTORY_SIZE,
+    TRAJECTORY_TEMPLATES_KMEANS_PKL,
+    TRAJECTORY_TEMPLATES_NPY,
+    TrajectoryEncoder,
+)
+from drivellava.utils import plot_bev_trajectory, plot_steering_traj
+
+
+def main():
+
+    fine_tuned_model_path = "liuhaotian/llava-v1.5-7b"
+
+    args = type(
+        "Args",
+        (),
+        {
+            "model_path": fine_tuned_model_path,
+            "model_base": None,
+            # "model_name": get_model_name_from_path(fine_tuned_model_path),
+            # "query": prompt,
+            "conv_mode": None,
+            # "image_file": image_file,
+            "sep": ",",
+            "temperature": 0,
+            "top_p": None,
+            "num_beams": 1,
+            "max_new_tokens": 512,
+        },
+    )()
+
+    model = DriveLLaVA(args)
+
+    NUM_FRAMES = 20 * 1
+
+    encoded_video_path = "/root/Datasets/commavq/data_0_to_2500/000e83c564317de4668c2cb372f89b91_6.npy"  # noqa
+
+    assert os.path.isfile(encoded_video_path), encoded_video_path
+
+    pose_path = encoded_video_path.replace("data_", "pose_data_").replace(
+        "val", "pose_val"
+    )
+    assert os.path.isfile(pose_path), pose_path
+
+    decoded_imgs_list = []
+
+    for frame_index in range(1200):
+        frame_path = get_image_path(encoded_video_path, frame_index)
+        if os.path.isfile(frame_path):
+            decoded_imgs_list.append(frame_path)
+
+    trajectory_encoder = TrajectoryEncoder(
+        num_trajectory_templates=NUM_TRAJECTORY_TEMPLATES,
+        trajectory_size=TRAJECTORY_SIZE,
+        trajectory_templates_npy=TRAJECTORY_TEMPLATES_NPY,
+        trajectory_templates_kmeans_pkl=TRAJECTORY_TEMPLATES_KMEANS_PKL,
+    )
+
+    pose_dataset = CommaVQPoseQuantizedDataset(
+        pose_path,
+        num_frames=NUM_FRAMES,
+        window_length=21 * 2 - 1,
+        polyorder=1,
+        trajectory_encoder=trajectory_encoder,
+    )
+
+    # Iterate over the embeddings in batches and decode the images
+    for i in tqdm(
+        range(0, len(decoded_imgs_list) - NUM_FRAMES, 1),
+        desc="Video",
+    ):
+        if not os.path.isfile(decoded_imgs_list[i]):
+            continue
+        img = cv2.imread(decoded_imgs_list[i])
+
+        trajectory, trajectory_encoded = pose_dataset[i]
+        trajectory_quantized = trajectory_encoder.decode(trajectory_encoded)
+
+        model_trajectory_quantized = model.run(
+            get_drivellava_prompt(trajectory_encoder),
+            [
+                decoded_imgs_list[i],
+            ],
+        )
+        print("Model Trajectory Token: ", model_trajectory_quantized)
+
+        print(
+            "trajectory[0]",
+            (np.min(trajectory[:, 0]), np.max(trajectory[:, 0])),
+        )
+        print(
+            "trajectory[1]",
+            (np.min(trajectory[:, 1]), np.max(trajectory[:, 1])),
+        )
+        print(
+            "trajectory[2]",
+            (np.min(trajectory[:, 2]), np.max(trajectory[:, 2])),
+        )
+        dx = trajectory[1:, 2] - trajectory[:-1, 2]
+        speed = dx / (1.0 / 20.0)
+        # m/s to km/h
+        speed_kmph = speed * 3.6
+        # speed mph
+        speed_mph = speed_kmph * 0.621371
+
+        img = plot_steering_traj(
+            img,
+            trajectory,
+            color=(255, 0, 0),
+        )
+
+        img = plot_steering_traj(
+            img,
+            trajectory_quantized,
+            color=(0, 255, 0),
+        )
+
+        img_bev = plot_bev_trajectory(trajectory, img, color=(255, 0, 0))
+        img_bev = plot_bev_trajectory(
+            trajectory_quantized, img, color=(0, 255, 0)
+        )
+
+        # Write speed on img
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        bottomLeftCornerOfText = (10, 50)
+        fontScale = 0.5
+        fontColor = (255, 255, 255)
+        lineType = 2
+
+        img = cv2.resize(img, (0, 0), fx=2, fy=2)
+
+        cv2.putText(
+            img,
+            f"Speed: {speed_mph.mean():.2f} mph",
+            bottomLeftCornerOfText,
+            font,
+            fontScale,
+            fontColor,
+            lineType,
+        )
+
+        vis = np.concatenate([img, img_bev], axis=1)
+
+        cv2.imwrite("test_media/vis.png", vis)
+
+        exit()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/drivellava/scripts/generate_commavq_images.py b/drivellava/scripts/generate_commavq_images.py
index b511dca..275f0c4 100644
--- a/drivellava/scripts/generate_commavq_images.py
+++ b/drivellava/scripts/generate_commavq_images.py
@@ -19,6 +19,7 @@ def main():

+    SHOW_IMAGES = False
     batch_size = 4

     decoder_onnx = load_model_from_onnx_comma(DECODER_ONNX_PATH, device="cuda")
@@ -37,6 +38,8 @@ def main():
             print(f"Skipping {encoded_video_path}")
             continue

+        print("encoded_video_path", encoded_video_path)
+
         # embeddings: (1200, 8, 16) -> (B, x, y)
         embeddings = np.load(encoded_video_path)
@@ -66,9 +69,12 @@
             os.makedirs(os.path.dirname(frame_path), exist_ok=True)
             # frame = (frame *).astype(np.uint8)
             cv2.imwrite(frame_path, frame)
-            cv2.imshow("frame_path", cv2.resize(frame, (0, 0), fx=2, fy=2))
-            cv2.waitKey(1)
+            if SHOW_IMAGES:
+                cv2.imshow(
+                    "frame_path", cv2.resize(frame, (0, 0), fx=2, fy=2)
+                )
+                cv2.waitKey(1)


 if __name__ == "__main__":
diff --git a/drivellava/scripts/generate_sparse_llava_dataset.py b/drivellava/scripts/generate_sparse_llava_dataset.py
index 4c90e82..1e232e4 100644
--- a/drivellava/scripts/generate_sparse_llava_dataset.py
+++ b/drivellava/scripts/generate_sparse_llava_dataset.py
@@ -4,10 +4,9 @@

 from tqdm import tqdm

-from drivellava.constants import ENCODED_POSE_ALL, DECODER_ONNX_PATH
-from drivellava.sparse_llava_dataset import generate_sparse_dataset
+from drivellava.constants import DECODER_ONNX_PATH, ENCODED_POSE_ALL
 from drivellava.onnx import load_model_from_onnx_comma
-from drivellava.trajectory_encoder import TRAJECTORY_SIZE
+from drivellava.sparse_llava_dataset import generate_sparse_dataset
 from drivellava.trajectory_encoder import (
     NUM_TRAJECTORY_TEMPLATES,
     TRAJECTORY_SIZE,
@@ -16,15 +15,14 @@
     TrajectoryEncoder,
 )

+
 def main():
     NUM_FRAMES = TRAJECTORY_SIZE
     WINDOW_LENGTH = 21 * 2 - 1
     SKIP_FRAMES = 20 * 20

-    decoder_onnx = load_model_from_onnx_comma(
-        DECODER_ONNX_PATH, device="cuda"
-    )
+    decoder_onnx = load_model_from_onnx_comma(DECODER_ONNX_PATH, device="cuda")

     trajectory_encoder = TrajectoryEncoder(
         num_trajectory_templates=NUM_TRAJECTORY_TEMPLATES,
diff --git a/drivellava/sparse_llava_dataset.py b/drivellava/sparse_llava_dataset.py
index 633c871..efef3f7 100644
--- a/drivellava/sparse_llava_dataset.py
+++ b/drivellava/sparse_llava_dataset.py
@@ -4,10 +4,10 @@
 import json
 import os

-import onnxruntime as ort
 import cv2
 import numpy as np
+import onnxruntime as ort
 from tqdm import tqdm

 from drivellava.constants import DECODER_ONNX_PATH, get_image_path, get_json
@@ -100,18 +100,32 @@ def visualize_pose(
     exit()


+def get_drivellava_prompt(trajectory_encoder: TrajectoryEncoder):
+    return (
+        "\nYou are DriveLLaVA, a "
+        + "self-driving car. You will select the "
+        + "appropriate trajectory token given the "
+        + "above image as context.\n"
+        + "You may select one from the "
+        + "following templates: "
+        + ",".join(trajectory_encoder.token2trajectory.keys())
+    )
+
+
 def generate_sparse_dataset(
     pose_path: str,
     pose_index: int,
     NUM_FRAMES: int,
     WINDOW_LENGTH: int,
     SKIP_FRAMES: int,
-    trajectory_encoder: TrajectoryEncoder = None,
-    decoder_onnx: ort.InferenceSession = None,
+    trajectory_encoder: TrajectoryEncoder = None,  # type: ignore
+    decoder_onnx: ort.InferenceSession = None,  # type: ignore
 ):
     batch_size = 1

     if decoder_onnx is None:
-        decoder_onnx = load_model_from_onnx_comma(DECODER_ONNX_PATH, device="cuda")
+        decoder_onnx = load_model_from_onnx_comma(
+            DECODER_ONNX_PATH, device="cuda"
+        )

     encoded_video_path = pose_path.replace("pose_data", "data").replace(
         "pose_val", "val"
     )
@@ -174,17 +188,7 @@ def generate_sparse_dataset(
                 "conversations": [
                     {
                         "from": "human",
-                        "value": (
-                            "\nYou are DriveLLaVA, a "
-                            + "self-driving car. You will select the "
-                            + "appropriate trrajectory token given the "
-                            + "above image as context.\n"
-                            + "You may select one from the "
-                            + "following templates: "
-                            + ",".join(
-                                trajectory_encoder.token2trajectory.keys()
-                            )
-                        ),
+                        "value": get_drivellava_prompt(trajectory_encoder),
                    },
                    {"from": "gpt", "value": trajectory_encoded},
                ],
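
Note for reviewers (not part of the diff above): a minimal sketch of how the new DriveLLaVA wrapper and prompt helper introduced here are meant to be driven. It mirrors the settings in drivellava/scripts/eval.py, assumes a CUDA GPU, a LLaVA checkout at ./LLaVA (which model.py appends to sys.path), and pre-generated trajectory templates; the checkpoint name and image path are illustrative placeholders.

from types import SimpleNamespace

from drivellava.model import DriveLLaVA
from drivellava.sparse_llava_dataset import get_drivellava_prompt
from drivellava.trajectory_encoder import (
    NUM_TRAJECTORY_TEMPLATES,
    TRAJECTORY_SIZE,
    TRAJECTORY_TEMPLATES_KMEANS_PKL,
    TRAJECTORY_TEMPLATES_NPY,
    TrajectoryEncoder,
)

# Same generation settings as eval.py, built with SimpleNamespace instead of
# the anonymous type() class.
args = SimpleNamespace(
    model_path="liuhaotian/llava-v1.5-7b",  # placeholder checkpoint
    model_base=None,
    conv_mode=None,
    sep=",",
    temperature=0,
    top_p=None,
    num_beams=1,
    max_new_tokens=512,
)

model = DriveLLaVA(args)

trajectory_encoder = TrajectoryEncoder(
    num_trajectory_templates=NUM_TRAJECTORY_TEMPLATES,
    trajectory_size=TRAJECTORY_SIZE,
    trajectory_templates_npy=TRAJECTORY_TEMPLATES_NPY,
    trajectory_templates_kmeans_pkl=TRAJECTORY_TEMPLATES_KMEANS_PKL,
)

# run() takes the prompt text and a list of image paths and returns the
# model's trajectory token as a string.
token = model.run(
    get_drivellava_prompt(trajectory_encoder),
    ["test_media/frame_0.png"],  # placeholder image path
)
print("Predicted trajectory token:", token)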