Skip to content

Commit

Permalink
linting, docstrings and minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Berducci, Luigi committed Jun 14, 2024
1 parent d6f60e0 commit 7528d56
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 46 deletions.
106 changes: 71 additions & 35 deletions f1tenth_gym/envs/track/cubic_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Cubic Spline interpolation using scipy.interpolate
Provides utilities for position, curvature, yaw, and arclength calculation
"""

import math

import numpy as np
Expand All @@ -10,8 +11,9 @@
from numba import njit
from typing import Union


@njit(fastmath=False, cache=True)
def nearest_point_on_trajectory(point, trajectory):
def nearest_point_on_trajectory(point: np.ndarray, trajectory: np.ndarray) -> tuple:
"""
Return the nearest point along the given piecewise linear trajectory.
Expand All @@ -20,9 +22,23 @@ def nearest_point_on_trajectory(point, trajectory):
Order of magnitude: trajectory length: 1000 --> 0.0002 second computation (5000fps)
point: size 2 numpy array
trajectory: Nx2 matrix of (x,y) trajectory waypoints
- these must be unique. If they are not unique, a divide by 0 error will destroy the world
Parameters
----------
point: np.ndarray
The 2d point to project onto the trajectory
trajectory: np.ndarray
The trajectory to project the point onto, shape (N, 2)
The points must be unique. If they are not unique, a divide by 0 error will destroy the world
Returns
-------
nearest_point: np.ndarray
The nearest point on the trajectory
distance: float
The distance from the point to the nearest point on the trajectory
t: float
min_dist_segment: int
The index of the nearest point on the trajectory
"""
diffs = trajectory[1:, :] - trajectory[:-1, :]
l2s = diffs[:, 0] ** 2 + diffs[:, 1] ** 2
Expand All @@ -49,6 +65,7 @@ def nearest_point_on_trajectory(point, trajectory):
min_dist_segment,
)


class CubicSpline2D:
"""
Cubic CubicSpline2D class.
Expand All @@ -66,11 +83,13 @@ class CubicSpline2D:
def __init__(self, x, y):
self.points = np.c_[x, y]
if not np.all(self.points[-1] == self.points[0]):
self.points = np.vstack((self.points, self.points[0])) # Ensure the path is closed
self.points = np.vstack(
(self.points, self.points[0])
) # Ensure the path is closed
self.s = self.__calc_s(self.points[:, 0], self.points[:, 1])
# Use scipy CubicSpline to interpolate the points with periodic boundary conditions
# This is necessary to ensure the path is continuous
self.spline = interpolate.CubicSpline(self.s, self.points, bc_type='periodic')
self.spline = interpolate.CubicSpline(self.s, self.points, bc_type="periodic")

def __calc_s(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""
Expand Down Expand Up @@ -114,7 +133,7 @@ def calc_position(self, s: float) -> tuple[Union[float, None], Union[float, None
"""
return self.spline(s)

def calc_curvature(self, s: float) -> Union[float , None]:
def calc_curvature(self, s: float) -> Union[float, None]:
"""
Calc curvature at the given s.
Expand Down Expand Up @@ -152,18 +171,24 @@ def calc_yaw(self, s: float) -> Union[float, None]:
yaw = math.atan2(dy, dx)
# Convert yaw to [0, 2pi]
yaw = yaw % (2 * math.pi)

return yaw

def calc_arclength(self, x, y, s_guess=0.0):
def calc_arclength(
self, x: float, y: float, s_guess: float = 0.0
) -> tuple[float, float]:
"""
calc arclength
Calculate arclength for a given point (x, y) on the trajectory.
Parameters
----------
x : float
x position.
y : float
y position.
s_guess : float
initial guess for s.
Returns
-------
s : float
Expand All @@ -174,67 +199,78 @@ def calc_arclength(self, x, y, s_guess=0.0):

def distance_to_spline(s):
x_eval, y_eval = self.spline(s)[0]
return np.sqrt((x - x_eval)**2 + (y - y_eval)**2)
return np.sqrt((x - x_eval) ** 2 + (y - y_eval) ** 2)

output = so.fmin(distance_to_spline, s_guess, full_output=True, disp=False)
closest_s = output[0][0]
closest_s = float(output[0][0])
absolute_distance = output[1]
return closest_s, absolute_distance
def calc_arclength_inaccurate(self, x, y, s_guess=0.0):

def calc_arclength_inaccurate(self, x: float, y: float) -> tuple[float, float]:
"""
calc arclength, use nearest_point_on_trajectory
Less accuarate and less smooth than calc_arclength but
much faster - suitable for lap counting
Fast calculation of arclength for a given point (x, y) on the trajectory.
Less accuarate and less smooth than calc_arclength but much faster.
Suitable for lap counting.
Parameters
----------
x : float
x position.
y : float
y position.
Returns
-------
s : float
distance from the start point for given x, y.
ey : float
lateral deviation for given x, y.
"""
_, ey, t, min_dist_segment = nearest_point_on_trajectory(np.array([x, y]), self.points)
_, ey, t, min_dist_segment = nearest_point_on_trajectory(
np.array([x, y]), self.points
)
# s = s at closest_point + t
s = self.s[min_dist_segment] + t * (self.s[min_dist_segment + 1] - self.s[min_dist_segment])
s = float(
self.s[min_dist_segment]
+ t * (self.s[min_dist_segment + 1] - self.s[min_dist_segment])
)

return s, 0
return s, 0.0

def _calc_tangent(self, s: float) -> np.ndarray:
"""
Calculates the tangent to the curve at a given point.
def _calc_tangent(self, s):
'''
calculates the tangent to the curve at a given point
Parameters
----------
s : float
distance from the start point. if `s` is outside the data point's
range, return None.
distance from the start point.
If `s` is outside the data point's range, return None.
Returns
-------
tangent : float
tangent vector for given s.
'''
"""
dx, dy = self.spline(s, 1)
tangent = np.array([dx, dy])
return tangent

def _calc_normal(self, s):
'''
calculates the normal to the curve at a given point

def _calc_normal(self, s: float) -> np.ndarray:
"""
Calculate the normal to the curve at a given point.
Parameters
----------
s : float
distance from the start point. if `s` is outside the data point's
range, return None.
distance from the start point.
If `s` is outside the data point's range, return None.
Returns
-------
normal : float
normal vector for given s.
'''
"""
dx, dy = self.spline(s, 1)
normal = np.array([-dy, dx])
return normal
return normal
43 changes: 32 additions & 11 deletions tests/test_cubic_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from f1tenth_gym.envs.track import cubic_spline


class TestCubicSpline(unittest.TestCase):
def test_calc_curvature(self):
circle_x = np.cos(np.linspace(0, 2 * np.pi, 100))[:-1]
Expand All @@ -25,34 +26,54 @@ def test_calc_yaw(self):
self.assertAlmostEqual(track.calc_yaw(np.pi / 2), np.pi, places=2)
self.assertAlmostEqual(track.calc_yaw(np.pi), 3 * np.pi / 2, places=2)
self.assertAlmostEqual(track.calc_yaw(3 * np.pi / 2), 0, places=2)

def test_calc_position(self):
circle_x = np.cos(np.linspace(0, 2 * np.pi, 100))[:-1]
circle_y = np.sin(np.linspace(0, 2 * np.pi, 100))[:-1]
track = cubic_spline.CubicSpline2D(circle_x, circle_y)
# Test the position at the four corners of the circle
# The position of a circle is (x, y) = (cos(s), sin(s))
self.assertTrue(np.allclose(track.calc_position(0), np.array([1, 0]), atol=1e-3))
self.assertTrue(np.allclose(track.calc_position(np.pi / 2), np.array([0, 1]), atol=1e-3))
self.assertTrue(np.allclose(track.calc_position(np.pi), np.array([-1, 0]), atol=1e-3))
self.assertTrue(np.allclose(track.calc_position(3 * np.pi / 2), np.array([0, -1]), atol=1e-3))
self.assertTrue(
np.allclose(track.calc_position(0), np.array([1, 0]), atol=1e-3)
)
self.assertTrue(
np.allclose(track.calc_position(np.pi / 2), np.array([0, 1]), atol=1e-3)
)
self.assertTrue(
np.allclose(track.calc_position(np.pi), np.array([-1, 0]), atol=1e-3)
)
self.assertTrue(
np.allclose(
track.calc_position(3 * np.pi / 2), np.array([0, -1]), atol=1e-3
)
)

def test_calc_arclength(self):
circle_x = np.cos(np.linspace(0, 2 * np.pi, 100))[:-1]
circle_y = np.sin(np.linspace(0, 2 * np.pi, 100))[:-1]
track = cubic_spline.CubicSpline2D(circle_x, circle_y)
# Test the arclength at the four corners of the circle
self.assertAlmostEqual(track.calc_arclength(1, 0, 0)[0], 0, places=2)
self.assertAlmostEqual(track.calc_arclength(0, 1, 0)[0], np.pi/2, places=2)
self.assertAlmostEqual(track.calc_arclength(-1, 0, np.pi/2)[0], np.pi, places=2)
self.assertAlmostEqual(track.calc_arclength(0, -1, np.pi)[0], 3*np.pi/2, places=2)
self.assertAlmostEqual(track.calc_arclength(0, 1, 0)[0], np.pi / 2, places=2)
self.assertAlmostEqual(
track.calc_arclength(-1, 0, np.pi / 2)[0], np.pi, places=2
)
self.assertAlmostEqual(
track.calc_arclength(0, -1, np.pi)[0], 3 * np.pi / 2, places=2
)

def test_calc_arclength_inaccurate(self):
circle_x = np.cos(np.linspace(0, 2 * np.pi, 100))[:-1]
circle_y = np.sin(np.linspace(0, 2 * np.pi, 100))[:-1]
track = cubic_spline.CubicSpline2D(circle_x, circle_y)
# Test the arclength at the four corners of the circle
self.assertAlmostEqual(track.calc_arclength_inaccurate(1, 0)[0], 0, places=2)
self.assertAlmostEqual(track.calc_arclength_inaccurate(0, 1)[0], np.pi/2, places=2)
self.assertAlmostEqual(track.calc_arclength_inaccurate(-1, 0)[0], np.pi, places=2)
self.assertAlmostEqual(track.calc_arclength_inaccurate(0, -1)[0], 3*np.pi/2, places=2)
self.assertAlmostEqual(
track.calc_arclength_inaccurate(0, 1)[0], np.pi / 2, places=2
)
self.assertAlmostEqual(
track.calc_arclength_inaccurate(-1, 0)[0], np.pi, places=2
)
self.assertAlmostEqual(
track.calc_arclength_inaccurate(0, -1)[0], 3 * np.pi / 2, places=2
)

0 comments on commit 7528d56

Please sign in to comment.