feat(eval): drivellava inference script
AdityaNG committed Feb 25, 2024
1 parent b069568 commit 3c286af
Showing 7 changed files with 348 additions and 23 deletions.
1 change: 1 addition & 0 deletions .dockerignore
@@ -125,3 +125,4 @@ checkpoints_pretrained/

docker/
core
+test_media/*
1 change: 1 addition & 0 deletions .gitignore
@@ -132,3 +132,4 @@ dmypy.json
.github/templates/*

checkpoints/*
+test_media/*
150 changes: 150 additions & 0 deletions drivellava/model.py
@@ -0,0 +1,150 @@
import os
import re
import sys
from io import BytesIO

import requests
import torch
from PIL import Image


def image_parser(args):
    out = args.image_file.split(args.sep)
    return out


def load_image(image_file):
    # Fetch remote images over HTTP(S); open local files directly
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def load_images(image_files):
    out = []
    for image_file in image_files:
        image = load_image(image_file)
        out.append(image)
    return out
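
# Example usage (hypothetical paths):
#   images = load_images(["frame_0000.png", "https://example.com/frame.jpg"])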


class DriveLLaVA:
    def __init__(self, args):

        LLAVA_PATH = os.path.abspath("./LLaVA")

        if LLAVA_PATH not in sys.path:
            sys.path.append(LLAVA_PATH)

        from llava.mm_utils import get_model_name_from_path
        from llava.model.builder import load_pretrained_model
        from llava.utils import disable_torch_init

        # Model initialization; disable_torch_init skips PyTorch's default
        # weight initialization to speed up checkpoint loading
        disable_torch_init()

        self.model_name = get_model_name_from_path(args.model_path)
        self.tokenizer, self.model, self.image_processor, self.context_len = (
            load_pretrained_model(
                args.model_path, args.model_base, self.model_name
            )
        )

        # Infer the conversation mode from the model name
        if "llama-2" in self.model_name.lower():
            self.conv_mode = "llava_llama_2"
        elif "mistral" in self.model_name.lower():
            self.conv_mode = "mistral_instruct"
        elif "v1.6-34b" in self.model_name.lower():
            self.conv_mode = "chatml_direct"
        elif "v1" in self.model_name.lower():
            self.conv_mode = "llava_v1"
        elif "mpt" in self.model_name.lower():
            self.conv_mode = "mpt"
        else:
            self.conv_mode = "llava_v0"

        if args.conv_mode is not None and self.conv_mode != args.conv_mode:
            print(
                f"[WARNING] the auto-inferred conversation mode is "
                f"{self.conv_mode}, while `--conv-mode` is {args.conv_mode}; "
                f"using {args.conv_mode}"
            )
            self.conv_mode = args.conv_mode

        self.args = args

    def run(self, query, image_files):

        from llava.constants import (
            DEFAULT_IM_END_TOKEN,
            DEFAULT_IM_START_TOKEN,
            DEFAULT_IMAGE_TOKEN,
            IMAGE_PLACEHOLDER,
            IMAGE_TOKEN_INDEX,
        )
        from llava.conversation import conv_templates
        from llava.mm_utils import process_images, tokenizer_image_token

        # Insert the image token(s) the model expects into the query
        qs = query
        image_token_se = (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        )
        if IMAGE_PLACEHOLDER in qs:
            if self.model.config.mm_use_im_start_end:
                qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
            else:
                qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
        else:
            if self.model.config.mm_use_im_start_end:
                qs = image_token_se + "\n" + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

        # Prepare the conversation prompt
        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Load and preprocess the images
        # image_files = image_parser(self.args)
        images = load_images(image_files)
        image_sizes = [x.size for x in images]
        images_tensor = process_images(
            images, self.image_processor, self.model.config
        ).to(self.model.device, dtype=torch.float16)

        # Tokenize the prompt, keeping input_ids on the model's device
        input_ids = (
            tokenizer_image_token(
                prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
            )
            .unsqueeze(0)
            .to(self.model.device)
        )

        # Inference
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images_tensor,
                image_sizes=image_sizes,
                do_sample=self.args.temperature > 0,
                temperature=self.args.temperature,
                top_p=self.args.top_p,
                num_beams=self.args.num_beams,
                max_new_tokens=self.args.max_new_tokens,
                use_cache=True,
            )

        outputs = self.tokenizer.batch_decode(
            output_ids, skip_special_tokens=True
        )[0].strip()
        print(outputs)

        return outputs
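
Note: a minimal usage sketch for DriveLLaVA (not part of this commit; it assumes the LLaVA repo is checked out at ./LLaVA, a CUDA device is available, and the prompt and image path shown are hypothetical):

    from types import SimpleNamespace

    from drivellava.model import DriveLLaVA

    # Mirrors the Args namespace built in drivellava/scripts/eval.py below
    args = SimpleNamespace(
        model_path="liuhaotian/llava-v1.5-7b",
        model_base=None,
        conv_mode=None,
        sep=",",
        temperature=0,
        top_p=None,
        num_beams=1,
        max_new_tokens=512,
    )

    model = DriveLLaVA(args)
    output = model.run(
        "Which trajectory should the car take?",  # hypothetical prompt
        ["test_media/frame_0000.png"],  # hypothetical image path
    )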
165 changes: 165 additions & 0 deletions drivellava/scripts/eval.py
@@ -0,0 +1,165 @@
"""
Evaluated DriveLLaVA on a video sequence
"""

import os

import cv2
import numpy as np
from tqdm import tqdm

from drivellava.constants import get_image_path
from drivellava.datasets.commavq import CommaVQPoseQuantizedDataset
from drivellava.model import DriveLLaVA
from drivellava.sparse_llava_dataset import get_drivellava_prompt
from drivellava.trajectory_encoder import (
NUM_TRAJECTORY_TEMPLATES,
TRAJECTORY_SIZE,
TRAJECTORY_TEMPLATES_KMEANS_PKL,
TRAJECTORY_TEMPLATES_NPY,
TrajectoryEncoder,
)
from drivellava.utils import plot_bev_trajectory, plot_steering_traj


def main():

    fine_tuned_model_path = "liuhaotian/llava-v1.5-7b"

    # Anonymous class standing in for the argparse.Namespace that the
    # upstream LLaVA CLI would provide
    args = type(
        "Args",
        (),
        {
            "model_path": fine_tuned_model_path,
            "model_base": None,
            # "model_name": get_model_name_from_path(fine_tuned_model_path),
            # "query": prompt,
            "conv_mode": None,
            # "image_file": image_file,
            "sep": ",",
            "temperature": 0,
            "top_p": None,
            "num_beams": 1,
            "max_new_tokens": 512,
        },
    )()

    model = DriveLLaVA(args)

    NUM_FRAMES = 20 * 1  # 1 s of context at 20 FPS

    encoded_video_path = "/root/Datasets/commavq/data_0_to_2500/000e83c564317de4668c2cb372f89b91_6.npy"  # noqa

    assert os.path.isfile(encoded_video_path), encoded_video_path

    pose_path = encoded_video_path.replace("data_", "pose_data_").replace(
        "val", "pose_val"
    )
    assert os.path.isfile(pose_path), pose_path

    decoded_imgs_list = []

    for frame_index in range(1200):
        frame_path = get_image_path(encoded_video_path, frame_index)
        if os.path.isfile(frame_path):
            decoded_imgs_list.append(frame_path)

    trajectory_encoder = TrajectoryEncoder(
        num_trajectory_templates=NUM_TRAJECTORY_TEMPLATES,
        trajectory_size=TRAJECTORY_SIZE,
        trajectory_templates_npy=TRAJECTORY_TEMPLATES_NPY,
        trajectory_templates_kmeans_pkl=TRAJECTORY_TEMPLATES_KMEANS_PKL,
    )

    pose_dataset = CommaVQPoseQuantizedDataset(
        pose_path,
        num_frames=NUM_FRAMES,
        window_length=21 * 2 - 1,
        polyorder=1,
        trajectory_encoder=trajectory_encoder,
    )
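    # window_length and polyorder look like Savitzky-Golay smoothing
    # parameters for the pose sequence (an assumption; see
    # CommaVQPoseQuantizedDataset for the actual semantics)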

    # Iterate over the decoded frames and run the model on each
    for i in tqdm(
        range(0, len(decoded_imgs_list) - NUM_FRAMES, 1),
        desc="Video",
    ):
        if not os.path.isfile(decoded_imgs_list[i]):
            continue
        img = cv2.imread(decoded_imgs_list[i])

        trajectory, trajectory_encoded = pose_dataset[i]
        trajectory_quantized = trajectory_encoder.decode(trajectory_encoded)

        model_trajectory_quantized = model.run(
            get_drivellava_prompt(trajectory_encoder),
            [
                decoded_imgs_list[i],
            ],
        )
        print("Model Trajectory Token: ", model_trajectory_quantized)

        print(
            "trajectory[0]",
            (np.min(trajectory[:, 0]), np.max(trajectory[:, 0])),
        )
        print(
            "trajectory[1]",
            (np.min(trajectory[:, 1]), np.max(trajectory[:, 1])),
        )
        print(
            "trajectory[2]",
            (np.min(trajectory[:, 2]), np.max(trajectory[:, 2])),
        )

        # Frame-to-frame displacement along the forward axis at 20 FPS
        dx = trajectory[1:, 2] - trajectory[:-1, 2]
        speed = dx / (1.0 / 20.0)  # m/s
        speed_kmph = speed * 3.6  # m/s -> km/h
        speed_mph = speed_kmph * 0.621371  # km/h -> mph
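        # Sanity check on the conversions (hypothetical numbers): a 0.5 m
        # step between frames at 20 FPS is 10 m/s = 36 km/h ~ 22.4 mph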

        img = plot_steering_traj(
            img,
            trajectory,
            color=(255, 0, 0),
        )

        img = plot_steering_traj(
            img,
            trajectory_quantized,
            color=(0, 255, 0),
        )

        img_bev = plot_bev_trajectory(trajectory, img, color=(255, 0, 0))
        # NOTE: if plot_bev_trajectory draws a fresh canvas from the image it
        # is given, passing img again here discards the red BEV plot above;
        # passing img_bev instead would overlay both trajectories
        img_bev = plot_bev_trajectory(
            trajectory_quantized, img, color=(0, 255, 0)
        )

        # Write the mean speed on the image
        font = cv2.FONT_HERSHEY_SIMPLEX
        bottomLeftCornerOfText = (10, 50)
        fontScale = 0.5
        fontColor = (255, 255, 255)
        thickness = 2  # this positional argument of putText is the thickness

        img = cv2.resize(img, (0, 0), fx=2, fy=2)
        # Resize the BEV image as well so the horizontal concatenation below
        # does not fail on mismatched heights
        img_bev = cv2.resize(img_bev, (0, 0), fx=2, fy=2)

        cv2.putText(
            img,
            f"Speed: {speed_mph.mean():.2f} mph",
            bottomLeftCornerOfText,
            font,
            fontScale,
            fontColor,
            thickness,
        )

        vis = np.concatenate([img, img_bev], axis=1)

        cv2.imwrite("test_media/vis.png", vis)

        # Stop after the first frame for now
        exit()


if __name__ == "__main__":
    main()
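
Note: the script takes no CLI arguments and the dataset paths above are hard-coded, so running it (hypothetically, with the commavq data present at /root/Datasets/commavq) reduces to:

    from drivellava.scripts.eval import main

    main()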
10 changes: 8 additions & 2 deletions drivellava/scripts/generate_commavq_images.py
@@ -19,6 +19,7 @@

 def main():
 
+    SHOW_IMAGES = False
     batch_size = 4
 
     decoder_onnx = load_model_from_onnx_comma(DECODER_ONNX_PATH, device="cuda")
@@ -37,6 +38,8 @@ def main():
print(f"Skipping {encoded_video_path}")
continue

print("encoded_video_path", encoded_video_path)

# embeddings: (1200, 8, 16) -> (B, x, y)
embeddings = np.load(encoded_video_path)

@@ -66,9 +69,12 @@ def main():
             os.makedirs(os.path.dirname(frame_path), exist_ok=True)
             # frame = (frame *).astype(np.uint8)
             cv2.imwrite(frame_path, frame)
-            cv2.imshow("frame_path", cv2.resize(frame, (0, 0), fx=2, fy=2))
 
-            cv2.waitKey(1)
+            if SHOW_IMAGES:
+                cv2.imshow(
+                    "frame_path", cv2.resize(frame, (0, 0), fx=2, fy=2)
+                )
+                cv2.waitKey(1)


if __name__ == "__main__":
10 changes: 4 additions & 6 deletions drivellava/scripts/generate_sparse_llava_dataset.py
@@ -4,10 +4,9 @@

 from tqdm import tqdm
 
-from drivellava.constants import ENCODED_POSE_ALL, DECODER_ONNX_PATH
-from drivellava.sparse_llava_dataset import generate_sparse_dataset
+from drivellava.constants import DECODER_ONNX_PATH, ENCODED_POSE_ALL
 from drivellava.onnx import load_model_from_onnx_comma
-from drivellava.trajectory_encoder import TRAJECTORY_SIZE
+from drivellava.sparse_llava_dataset import generate_sparse_dataset
 from drivellava.trajectory_encoder import (
     NUM_TRAJECTORY_TEMPLATES,
     TRAJECTORY_SIZE,
@@ -16,15 +15,14 @@
     TrajectoryEncoder,
 )
 
 
 def main():
 
     NUM_FRAMES = TRAJECTORY_SIZE
     WINDOW_LENGTH = 21 * 2 - 1
     SKIP_FRAMES = 20 * 20
 
-    decoder_onnx = load_model_from_onnx_comma(
-        DECODER_ONNX_PATH, device="cuda"
-    )
+    decoder_onnx = load_model_from_onnx_comma(DECODER_ONNX_PATH, device="cuda")
 
     trajectory_encoder = TrajectoryEncoder(
         num_trajectory_templates=NUM_TRAJECTORY_TEMPLATES,