Skip to content

Commit

Permalink
feat(gpt_vision): ui added to show prompt and state
Browse files Browse the repository at this point in the history
  • Loading branch information
AdityaNG committed Apr 26, 2024
1 parent cb72b7a commit bccf7b3
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 27 deletions.
13 changes: 5 additions & 8 deletions drivellava/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,15 @@
KMPH_2_MPS = 1 / 3.6
DEG_2_RAD = np.pi / 180

DEFAULT_LLM_PROVIDER = "openai"
DEFAULT_VISION_MODEL = "gpt-4-vision-preview"
DEFAULT_CONTROLS_MODEL = "gpt-3.5-turbo"
DEFAULT_LLM_PROVIDER = "openai"

# DEFAULT_LLM_PROVIDER = "ollama"
# DEFAULT_VISION_MODEL = "llava"
# DEFAULT_CONTROLS_MODEL = "llama3"

DEFAULT_MISSION = dedent(
"""As DriveLLaVA, the autonomous vehicle, your task is to analyze the \
given image and determine the optimal driving path. Choose the most \
suitable trajectory option from the list provided based on the \
visual information. Make sure to stay centered in your lane. \
If you deviate from the lane make sure to make course corrections"""
)
DEFAULT_MISSION = dedent("""Explore the town while obeying traffic laws""")


def encode_opencv_image(img):
Expand Down
22 changes: 14 additions & 8 deletions drivellava/gpt/gpt_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def llm_client_factory(llm_provider: str):


class GPTVision:
def __init__(self, max_history=1):
def __init__(self, max_history=0):
self.client = llm_client_factory(settings.system.SYSTEM_LLM_PROVIDER)

self.previous_messages = GPTState()
Expand All @@ -74,6 +74,14 @@ def step(
gpt_controls = DroneControls(trajectory_index=0, speed_index=0)
return gpt_controls

# If the number of messages beyond the two pinned ones exceeds
# max_history, drop every message after them. The -2 accounts for the
# pinned [system message, setup message] pair, which is never removed.
if len(self.previous_messages) - 2 > self.max_history:
EXTRA_SIZE = len(self.previous_messages)
for _ in range(2, EXTRA_SIZE):
self.previous_messages.pop(-1)

trajectory_templates, colors = (
self.trajectory_encoder.get_colors_left_to_right()
)
Expand Down Expand Up @@ -172,13 +180,7 @@ def step(
],
)

# If number of messages is greater than max_history,
# remove the oldest message. Do not remove the system message
# The -2 is to account for the [system message, setup message]
if len(self.previous_messages) - 2 > self.max_history:
self.previous_messages.pop(1)

print(desc)
# print(desc)

gpt_controls = self.client.chat.completions.create(
model=CONTROLS_MODEL,
Expand Down Expand Up @@ -209,6 +211,10 @@ def step(

gpt_controls.trajectory_index += offset

print("=" * 10)
print(self.previous_messages.to_str(0, ["system", "user"]))
print("=" * 10)

print("gpt:", gpt_controls)

return gpt_controls
11 changes: 6 additions & 5 deletions drivellava/gpt/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
"""

GPT_SYSTEM = """
You are an autonomous vehicle. You are given the vehicle's environment state \
and you must control the vehicle to complete the mission.
You are DriveLLaVA, an autonomous vehicle. You analyze the situation and \
make decisions on the control signals to drive the vehicle. You are given \
the vehicle's environment state and you must control the vehicle to \
complete the mission.
"""

GPT_PROMPT_SETUP = """
Expand All @@ -21,12 +23,11 @@
Make use of the trajectory's color to identify it
What are your next actions? Be short and brief with your thoughts
What are your next actions? Let us think step by step
"""

GPT_PROMPT_UPDATE = """MISSION: {mission}
Updated visual is provided
What are your next actions? Be short and brief with your thoughts
Updated visual is provided. What are your next actions?
"""

GPT_PROMPT_CONTROLS = """
Expand Down
8 changes: 5 additions & 3 deletions drivellava/schema/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ def pop(self, index=-1):
def to_prompt(self):
return [m.to_dict() for m in self.messages]

def to_str(self) -> str:
def to_str(self, start_index: int = 1, roles: list = ["system"]) -> str:
"""
Returns a human readable prompt
"""
result = ""
for message in self.to_prompt():
for index, message in enumerate(self.to_prompt()):
content = ""
for cont in message["content"]:
if cont["type"] == "image_url":
Expand All @@ -136,6 +136,8 @@ def to_str(self) -> str:
content += cont["text"]
else:
assert False
result += f"{message['role']}: {content}\n"

if index > start_index and message["role"] in roles:
result += f"{message['role']}: {content}\n"

return result
12 changes: 9 additions & 3 deletions drivellava/scripts/carla_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,31 @@ def print_text_image(
width=50,
font_size=0.5,
font_thickness=2,
text_color=(255, 255, 255),
text_color_bg=(0, 0, 0),
):
font = cv2.FONT_HERSHEY_SIMPLEX
wrapped_text = textwrap.wrap(text, width=width)

for i, line in enumerate(wrapped_text):
textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0]
text_w, text_h = textsize

gap = textsize[1] + 10

y = (i + 1) * gap
x = 10

cv2.rectangle(
img, (x, y - text_h), (x + text_w, y + text_h), text_color_bg, -1
)
cv2.putText(
img,
line,
(x, y),
font,
font_size,
(255, 255, 255),
text_color,
font_thickness,
lineType=cv2.LINE_AA,
)
Expand Down Expand Up @@ -157,8 +163,8 @@ def main(): # pragma: no cover
image, template_trajectory_3d, color=color, track=False
)

print_text_image(image, "Prompt")
visual[0:128, 0:256] = image
print_text_image(visual[0:128, 0:256], "Prompt")

image_vis = image_raw.copy()

Expand Down Expand Up @@ -193,8 +199,8 @@ def main(): # pragma: no cover
color=(255, 0, 0),
track=True,
)
print_text_image(image_vis, "DriveLLaVA Controls")
visual[0:128, 256:512] = image_vis
print_text_image(visual[0:128, 256:512], "DriveLLaVA Controls")

text_visual = np.zeros((512 - 128, 512, 3), dtype=np.uint8)
print_text_image(
Expand Down

0 comments on commit bccf7b3

Please sign in to comment.