Skip to content

Commit

Permalink
update docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
hzheng40 committed Mar 4, 2024
1 parent 9beb5fa commit 97e6a24
Show file tree
Hide file tree
Showing 8 changed files with 991 additions and 442 deletions.
4 changes: 3 additions & 1 deletion docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,6 @@
# "**": ["search-field.html", "sidebar-nav-bs.html", "sidebar-ethical-ads.html"]
# }
html_last_updated_fmt = "%b %d, %Y"
html_show_sourcelink = True
html_show_sourcelink = True

copybutton_prompt_text = ">>> "
188 changes: 111 additions & 77 deletions gym/f110_gym/envs/collision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@

@njit(cache=True)
def perpendicular(pt):
"""
Return a 2-vector's perpendicular vector
"""Return a 2-vector's perpendicular vector
Args:
pt (np.ndarray, (2,)): input vector
Parameters
----------
pt : np.ndarray
input vector
Returns:
pt (np.ndarray, (2,)): perpendicular vector
Returns
-------
np.ndarray
perpendicular vector
"""
temp = pt[0]
pt[0] = pt[1]
Expand All @@ -21,14 +24,21 @@ def perpendicular(pt):

@njit(cache=True)
def tripleProduct(a, b, c):
"""
Return triple product of three vectors
Args:
a, b, c (np.ndarray, (2,)): input vectors
Returns:
(np.ndarray, (2,)): triple product
"""Return triple product of three vectors
Parameters
----------
a : np.ndarray
input vector
b : np.ndarray
input vector
c : np.ndarray
input vector
Returns
-------
np.ndarray
triple product
"""
ac = a.dot(c)
bc = b.dot(c)
Expand All @@ -37,44 +47,57 @@ def tripleProduct(a, b, c):

@njit(cache=True)
def avgPoint(vertices):
"""
Return the average point of multiple vertices
"""Return the average point of multiple vertices
Args:
vertices (np.ndarray, (n, 2)): the vertices we want to find avg on
Parameters
----------
vertices : np.ndarray
the vertices we want to find avg on
Returns:
avg (np.ndarray, (2,)): average point of the vertices
Returns
-------
np.ndarray
average point of the vertices
"""
return np.sum(vertices, axis=0) / vertices.shape[0]


@njit(cache=True)
def indexOfFurthestPoint(vertices, d):
"""
Return the index of the vertex furthest away along a direction in the list of vertices
Args:
vertices (np.ndarray, (n, 2)): the vertices we want to find avg on
Returns:
idx (int): index of the furthest point
"""Return the index of the vertex furthest away along a direction in the list of vertices
Parameters
----------
vertices : np.ndarray
the vertices we want to find index on
d : np.ndarray
direction
Returns
-------
int
index of the furthest point
"""
return np.argmax(vertices.dot(d))


@njit(cache=True)
def support(vertices1, vertices2, d):
"""
Minkowski sum support function for GJK
Args:
vertices1 (np.ndarray, (n, 2)): vertices of the first body
vertices2 (np.ndarray, (n, 2)): vertices of the second body
d (np.ndarray, (2, )): direction to find the support along
Returns:
support (np.ndarray, (n, 2)): Minkowski sum
"""Minkowski sum support function for GJK
Parameters
----------
vertices1 : np.ndarray
vertices of the first body
vertices2 : np.ndarray
vertices of the second body
d : np.ndarray
direction to find the support along
Returns
-------
np.ndarray
Minkowski sum
"""
i = indexOfFurthestPoint(vertices1, d)
j = indexOfFurthestPoint(vertices2, -d)
Expand All @@ -83,15 +106,19 @@ def support(vertices1, vertices2, d):

@njit(cache=True)
def collision(vertices1, vertices2):
"""
GJK test to see whether two bodies overlap
Args:
vertices1 (np.ndarray, (n, 2)): vertices of the first body
vertices2 (np.ndarray, (n, 2)): vertices of the second body
Returns:
overlap (boolean): True if two bodies collide
"""GJK test to see whether two bodies overlap
Parameters
----------
vertices1 : np.ndarray
vertices of the first body
vertices2 : np.ndarray
vertices of the second body
Returns
-------
boolean
True if two bodies collide
"""
index = 0
simplex = np.empty((3, 2))
Expand Down Expand Up @@ -155,15 +182,19 @@ def collision(vertices1, vertices2):

@njit(cache=True)
def collision_multiple(vertices):
"""
Check pair-wise collisions for all provided vertices
Args:
vertices (np.ndarray (num_bodies, 4, 2)): all vertices for checking pair-wise collision
Returns:
collisions (np.ndarray (num_vertices, )): whether each body is in collision
collision_idx (np.ndarray (num_vertices, )): which index of other body is each index's body is in collision, -1 if not in collision
"""Check pair-wise collisions for all provided vertices
Parameters
----------
vertices : np.ndarray
all vertices for checking pair-wise collision
Returns
-------
collisions : np.ndarray
whether each body is in collision
collision_idx : np.ndarray
which index of other body is each index's body is in collision, -1 if not in collision
"""
collisions = np.zeros((vertices.shape[0],))
collision_idx = -1 * np.ones((vertices.shape[0],))
Expand All @@ -184,21 +215,19 @@ def collision_multiple(vertices):
return collisions, collision_idx


"""
Utility functions for getting vertices by pose and shape
"""


@njit(cache=True)
def get_trmtx(pose):
"""
Get transformation matrix of vehicle frame -> global frame
"""Get transformation matrix of vehicle frame -> global frame
Args:
pose (np.ndarray (3, )): current pose of the vehicle
Parameters
----------
pose : np.ndarray
current pose of the vehicle
return:
H (np.ndarray (4, 4)): transformation matrix
Returns
-------
np.ndarray
transformation matrix
"""
x = pose[0]
y = pose[1]
Expand All @@ -218,16 +247,21 @@ def get_trmtx(pose):

@njit(cache=True)
def get_vertices(pose, length, width):
"""
Utility function to return vertices of the car body given pose and size
Args:
pose (np.ndarray, (3, )): current world coordinate pose of the vehicle
length (float): car length
width (float): car width
Returns:
vertices (np.ndarray, (4, 2)): corner vertices of the vehicle body
"""Utility function to return vertices of the car body given pose and size
Parameters
----------
pose : np.ndarray
current world coordinate pose of the vehicle
length : float
car length
width : float
car width
Returns
-------
np.ndarray
corner vertices of the vehicle body
"""
H = get_trmtx(pose)
rl = H.dot(np.asarray([[-length / 2], [width / 2], [0.0], [1.0]])).flatten()
Expand Down
Loading

0 comments on commit 97e6a24

Please sign in to comment.