diff --git a/examples/quadrotor/plant.py b/examples/quadrotor/plant.py index 9c1ae4d..60e5e93 100644 --- a/examples/quadrotor/plant.py +++ b/examples/quadrotor/plant.py @@ -1,9 +1,17 @@ -from typing import Tuple, Union +from typing import Optional, Tuple, Union +from typing_extensions import Self import numpy as np import jax import jax.numpy as jnp import pydrake.symbolic as sym +from pydrake.examples import QuadrotorPlant +import pydrake.systems.framework +import pydrake.geometry +import pydrake.multibody.plant +import pydrake.multibody.parsing +import pydrake.math +from pydrake.common.value import Value def quat2rotmat(quat: Union[np.ndarray, jnp.ndarray]) -> Union[np.ndarray, jnp.ndarray]: @@ -33,7 +41,7 @@ def set_val(i, j, val): return R -class QuadrotorPolyPlant: +class QuadrotorPolyPlant(pydrake.systems.framework.LeafSystem): """ The state is x=(quat - [1, 0, 0, 0], pos, pos_dot, omega_WB_B) Note that the equilibrium is at x=0. @@ -48,14 +56,29 @@ class QuadrotorPolyPlant: kM: float def __init__(self): - self.m = 0.775 - self.I = np.array([[0.0015, 0, 0], [0, 0.0025, 0], [0, 0, 0.0035]]) - self.g = 9.81 - self.l = 0.15 - self.kF = 1.0 - self.kM = 0.0245 + super().__init__() + self.DeclareVectorInputPort("thrust", 4) + state_index = self.DeclareContinuousState(13) + self.DeclareStateOutputPort("x", state_index) + drake_plant = QuadrotorPlant() + self.m = drake_plant.m() + self.I = drake_plant.inertia() + self.g = drake_plant.g() + self.l = drake_plant.length() + self.kF = drake_plant.force_constant() + self.kM = drake_plant.moment_constant() self.I_inv = np.linalg.inv(self.I) + def DoCalcTimeDerivatives( + self, + context: pydrake.systems.framework.Context, + derivatives: pydrake.systems.framework.ContinuousState, + ): + x = context.get_continuous_state_vector().CopyToVector() + u = self.EvalVectorInput(context, 0).CopyToVector() + xdot: np.ndarray = self.dynamics(x, u) + derivatives.SetFromVector(xdot) + def dynamics( self, x: Union[np.ndarray, jnp.ndarray], u: Union[np.ndarray, jnp.ndarray] ) -> Union[np.ndarray, jnp.ndarray]: @@ -186,3 +209,59 @@ def f(xu): AB_jnp = jax.jacfwd(f)(xu_jnp) AB = np.array(AB_jnp) return AB[:, :13], AB[:, 13:] + + +class QuadrotorPolyGeometry(pydrake.systems.framework.LeafSystem): + def __init__( + self, scene_graph: pydrake.geometry.SceneGraph, name: Optional[str] = None + ): + super().__init__() + mbp = pydrake.multibody.plant.MultibodyPlant(0.0) + parser = pydrake.multibody.parsing.Parser(mbp, scene_graph) + model_instance_indices = parser.AddModelsFromUrl( + "package://drake_models/skydio_2/quadrotor.urdf" + ) + mbp.Finalize() + + body_indices = mbp.GetBodyIndices(model_instance_indices[0]) + body_index = body_indices[0] + self.source_id = mbp.get_source_id() + self.frame_id = mbp.GetBodyFrameIdOrThrow(body_index) + + self.DeclareVectorInputPort("state", 13) + self.DeclareAbstractOutputPort( + "geometry_pose", + lambda: Value[pydrake.geometry.FramePoseVector](), + self.output_geometry_pose, + ) + + def output_geometry_pose( + self, + context: pydrake.systems.framework.Context, + poses, + ): + state = self.get_input_port(0).Eval(context) + + pose = pydrake.math.RigidTransform( + pydrake.common.eigen_geometry.Quaternion( + state[0] + 1, state[1], state[2], state[3] + ), + state[4:7], + ) + poses_value = pydrake.geometry.FramePoseVector() + poses_value.set_value(self.frame_id, pose) + poses.set_value(poses_value) + + @staticmethod + def AddToBuilder( + builder: pydrake.systems.framework.DiagramBuilder, + quadrotor_state_port: pydrake.systems.framework.OutputPort, + scene_graph: pydrake.geometry.SceneGraph, + ) -> Self: + quadrotor_geometry = builder.AddSystem(QuadrotorPolyGeometry(scene_graph)) + builder.Connect(quadrotor_state_port, quadrotor_geometry.get_input_port(0)) + builder.Connect( + quadrotor_geometry.get_output_port(0), + scene_graph.get_source_pose_port(quadrotor_geometry.source_id), + ) + return quadrotor_geometry