generated from rochacbruno/python-project-template
-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(eval): drivellava inference script
- Loading branch information
Showing
7 changed files
with
348 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -125,3 +125,4 @@ checkpoints_pretrained/ | |
|
||
docker/ | ||
core | ||
test_media/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -132,3 +132,4 @@ dmypy.json | |
.github/templates/* | ||
|
||
checkpoints/* | ||
test_media/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import os | ||
import re | ||
import sys | ||
from io import BytesIO | ||
|
||
import requests | ||
import torch | ||
from PIL import Image | ||
|
||
|
||
def image_parser(args): | ||
out = args.image_file.split(args.sep) | ||
return out | ||
|
||
|
||
def load_image(image_file): | ||
if image_file.startswith("http") or image_file.startswith("https"): | ||
response = requests.get(image_file) | ||
image = Image.open(BytesIO(response.content)).convert("RGB") | ||
else: | ||
image = Image.open(image_file).convert("RGB") | ||
return image | ||
|
||
|
||
def load_images(image_files): | ||
out = [] | ||
for image_file in image_files: | ||
image = load_image(image_file) | ||
out.append(image) | ||
return out | ||
|
||
|
||
class DriveLLaVA: | ||
def __init__(self, args): | ||
|
||
LLAVA_PATH = os.path.abspath("./LLaVA") | ||
|
||
if LLAVA_PATH not in sys.path: | ||
sys.path.append(LLAVA_PATH) | ||
|
||
from llava.mm_utils import get_model_name_from_path | ||
from llava.model.builder import load_pretrained_model | ||
from llava.utils import disable_torch_init | ||
|
||
# Model Initialization | ||
# Assuming this function disables initialization in PyTorch | ||
disable_torch_init() | ||
|
||
self.model_name = get_model_name_from_path(args.model_path) | ||
self.tokenizer, self.model, self.image_processor, self.context_len = ( | ||
load_pretrained_model( | ||
args.model_path, args.model_base, self.model_name | ||
) | ||
) | ||
|
||
# Infer conversation mode based on model name | ||
if "llama-2" in self.model_name.lower(): | ||
self.conv_mode = "llava_llama_2" | ||
elif "mistral" in self.model_name.lower(): | ||
self.conv_mode = "mistral_instruct" | ||
elif "v1.6-34b" in self.model_name.lower(): | ||
self.conv_mode = "chatml_direct" | ||
elif "v1" in self.model_name.lower(): | ||
self.conv_mode = "llava_v1" | ||
elif "mpt" in self.model_name.lower(): | ||
self.conv_mode = "mpt" | ||
else: | ||
self.conv_mode = "llava_v0" | ||
|
||
if args.conv_mode is not None and self.conv_mode != args.conv_mode: | ||
print( | ||
f"[WARNING] the auto inferred conversation mode is " | ||
f"{self.conv_mode}, while `--conv-mode` is {args.conv_mode}, " | ||
f"using {args.conv_mode}" | ||
) | ||
self.conv_mode = args.conv_mode | ||
|
||
self.args = args | ||
|
||
def run(self, query, image_files): | ||
|
||
from llava.constants import ( | ||
DEFAULT_IM_END_TOKEN, | ||
DEFAULT_IM_START_TOKEN, | ||
DEFAULT_IMAGE_TOKEN, | ||
IMAGE_PLACEHOLDER, | ||
IMAGE_TOKEN_INDEX, | ||
) | ||
from llava.conversation import conv_templates | ||
from llava.mm_utils import process_images, tokenizer_image_token | ||
|
||
# Process query | ||
qs = query | ||
image_token_se = ( | ||
DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN | ||
) | ||
if IMAGE_PLACEHOLDER in qs: | ||
if self.model.config.mm_use_im_start_end: | ||
qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) | ||
else: | ||
qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) | ||
else: | ||
if self.model.config.mm_use_im_start_end: | ||
qs = image_token_se + "\n" + qs | ||
else: | ||
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs | ||
|
||
# Prepare conversation | ||
conv = conv_templates[self.conv_mode].copy() | ||
conv.append_message(conv.roles[0], qs) | ||
conv.append_message(conv.roles[1], None) | ||
prompt = conv.get_prompt() | ||
|
||
# Process images | ||
# image_files = image_parser(self.args) | ||
images = load_images(image_files) | ||
image_sizes = [x.size for x in images] | ||
images_tensor = process_images( | ||
images, self.image_processor, self.model.config | ||
).to(self.model.device, dtype=torch.float16) | ||
|
||
# Tokenize prompt | ||
input_ids = ( | ||
tokenizer_image_token( | ||
prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" | ||
) | ||
.unsqueeze(0) | ||
.cuda() | ||
) | ||
|
||
# Inference | ||
with torch.inference_mode(): | ||
output_ids = self.model.generate( | ||
input_ids, | ||
images=images_tensor, | ||
image_sizes=image_sizes, | ||
do_sample=True if self.args.temperature > 0 else False, | ||
temperature=self.args.temperature, | ||
top_p=self.args.top_p, | ||
num_beams=self.args.num_beams, | ||
max_new_tokens=self.args.max_new_tokens, | ||
use_cache=True, | ||
) | ||
|
||
outputs = self.tokenizer.batch_decode( | ||
output_ids, skip_special_tokens=True | ||
)[0].strip() | ||
print(outputs) | ||
|
||
return outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
""" | ||
Evaluated DriveLLaVA on a video sequence | ||
""" | ||
|
||
import os | ||
|
||
import cv2 | ||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
from drivellava.constants import get_image_path | ||
from drivellava.datasets.commavq import CommaVQPoseQuantizedDataset | ||
from drivellava.model import DriveLLaVA | ||
from drivellava.sparse_llava_dataset import get_drivellava_prompt | ||
from drivellava.trajectory_encoder import ( | ||
NUM_TRAJECTORY_TEMPLATES, | ||
TRAJECTORY_SIZE, | ||
TRAJECTORY_TEMPLATES_KMEANS_PKL, | ||
TRAJECTORY_TEMPLATES_NPY, | ||
TrajectoryEncoder, | ||
) | ||
from drivellava.utils import plot_bev_trajectory, plot_steering_traj | ||
|
||
|
||
def main(): | ||
|
||
fine_tuned_model_path = "liuhaotian/llava-v1.5-7b" | ||
|
||
args = type( | ||
"Args", | ||
(), | ||
{ | ||
"model_path": fine_tuned_model_path, | ||
"model_base": None, | ||
# "model_name": get_model_name_from_path(fine_tuned_model_path), | ||
# "query": prompt, | ||
"conv_mode": None, | ||
# "image_file": image_file, | ||
"sep": ",", | ||
"temperature": 0, | ||
"top_p": None, | ||
"num_beams": 1, | ||
"max_new_tokens": 512, | ||
}, | ||
)() | ||
|
||
model = DriveLLaVA(args) | ||
|
||
NUM_FRAMES = 20 * 1 | ||
|
||
encoded_video_path = "/root/Datasets/commavq/data_0_to_2500/000e83c564317de4668c2cb372f89b91_6.npy" # noqa | ||
|
||
assert os.path.isfile(encoded_video_path), encoded_video_path | ||
|
||
pose_path = encoded_video_path.replace("data_", "pose_data_").replace( | ||
"val", "pose_val" | ||
) | ||
assert os.path.isfile(pose_path), pose_path | ||
|
||
decoded_imgs_list = [] | ||
|
||
for frame_index in range(1200): | ||
frame_path = get_image_path(encoded_video_path, frame_index) | ||
if os.path.isfile(frame_path): | ||
decoded_imgs_list.append(frame_path) | ||
|
||
trajectory_encoder = TrajectoryEncoder( | ||
num_trajectory_templates=NUM_TRAJECTORY_TEMPLATES, | ||
trajectory_size=TRAJECTORY_SIZE, | ||
trajectory_templates_npy=TRAJECTORY_TEMPLATES_NPY, | ||
trajectory_templates_kmeans_pkl=TRAJECTORY_TEMPLATES_KMEANS_PKL, | ||
) | ||
|
||
pose_dataset = CommaVQPoseQuantizedDataset( | ||
pose_path, | ||
num_frames=NUM_FRAMES, | ||
window_length=21 * 2 - 1, | ||
polyorder=1, | ||
trajectory_encoder=trajectory_encoder, | ||
) | ||
|
||
# Iterate over the embeddings in batches and decode the images | ||
for i in tqdm( | ||
range(0, len(decoded_imgs_list) - NUM_FRAMES, 1), | ||
desc="Video", | ||
): | ||
if not os.path.isfile(decoded_imgs_list[i]): | ||
continue | ||
img = cv2.imread(decoded_imgs_list[i]) | ||
|
||
trajectory, trajectory_encoded = pose_dataset[i] | ||
trajectory_quantized = trajectory_encoder.decode(trajectory_encoded) | ||
|
||
model_trajectory_quantized = model.run( | ||
get_drivellava_prompt(trajectory_encoder), | ||
[ | ||
decoded_imgs_list[i], | ||
], | ||
) | ||
print("Model Trajectory Token: ", model_trajectory_quantized) | ||
|
||
print( | ||
"trajectory[0]", | ||
(np.min(trajectory[:, 0]), np.max(trajectory[:, 0])), | ||
) | ||
print( | ||
"trajectory[1]", | ||
(np.min(trajectory[:, 1]), np.max(trajectory[:, 1])), | ||
) | ||
print( | ||
"trajectory[2]", | ||
(np.min(trajectory[:, 2]), np.max(trajectory[:, 2])), | ||
) | ||
dx = trajectory[1:, 2] - trajectory[:-1, 2] | ||
speed = dx / (1.0 / 20.0) | ||
# m/s to km/h | ||
speed_kmph = speed * 3.6 | ||
# speed mph | ||
speed_mph = speed_kmph * 0.621371 | ||
|
||
img = plot_steering_traj( | ||
img, | ||
trajectory, | ||
color=(255, 0, 0), | ||
) | ||
|
||
img = plot_steering_traj( | ||
img, | ||
trajectory_quantized, | ||
color=(0, 255, 0), | ||
) | ||
|
||
img_bev = plot_bev_trajectory(trajectory, img, color=(255, 0, 0)) | ||
img_bev = plot_bev_trajectory( | ||
trajectory_quantized, img, color=(0, 255, 0) | ||
) | ||
|
||
# Write speed on img | ||
font = cv2.FONT_HERSHEY_SIMPLEX | ||
bottomLeftCornerOfText = (10, 50) | ||
fontScale = 0.5 | ||
fontColor = (255, 255, 255) | ||
lineType = 2 | ||
|
||
img = cv2.resize(img, (0, 0), fx=2, fy=2) | ||
|
||
cv2.putText( | ||
img, | ||
f"Speed: {speed_mph.mean():.2f} mph", | ||
bottomLeftCornerOfText, | ||
font, | ||
fontScale, | ||
fontColor, | ||
lineType, | ||
) | ||
|
||
vis = np.concatenate([img, img_bev], axis=1) | ||
|
||
cv2.imwrite("test_media/vis.png", vis) | ||
|
||
exit() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.