Skip to content

Commit

Permalink
adds new module high_level
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-bachhuber committed Feb 21, 2023
1 parent 4dc17c0 commit 270fefb
Show file tree
Hide file tree
Showing 7 changed files with 760 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cc/utils/high_level/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .extra_sources import build_extra_sources, loop_observer_configs
from .masterplot_siso import ExtraSource, LoopObserverConfig, masterplot_siso
3 changes: 3 additions & 0 deletions cc/utils/high_level/baselines/controller_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import jax

# Presumably baseline hyper-parameters for controller training, keyed by
# environment id (inferred from the module path -- TODO confirm).
# NOTE(review): this dict literal is a bare expression statement; it is not
# bound to any name, so importing this module evaluates and then discards it.
{"two_segments_v2": {"f_depth": 0, "state_dim": 50}}
6 changes: 6 additions & 0 deletions cc/utils/high_level/baselines/model_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import jax

# Presumably baseline hyper-parameters for model training, keyed by
# environment id (inferred from the module path -- TODO confirm).
# The "rover" entry is empty.
# NOTE(review): this dict literal is a bare expression statement; it is not
# bound to any name, so importing this module evaluates and then discards it.
{
    "two_segments_v2": {"state_dim": 75, "u_transform": jax.numpy.arctan, "f_depth": 0},
    "rover": {},
}
13 changes: 13 additions & 0 deletions cc/utils/high_level/baselines/pid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Best PID-Gains for various environments."""

from cc.env import make_env

# PD gains per environment. Each value has the form {"PD": {"P": ..., "D": ...}}.
# NOTE(review): the keys are the objects returned by make_env(...), which are
# constructed when this module is imported; looking an entry up therefore
# depends on how those objects hash and compare -- confirm against callers.
best_pd = {
    make_env("rover"): {"PD": {"P": 1.0, "D": 1.836274}},
    # The gains below (presumably both muscle_asymmetric entries) were found
    # before the switch from degrees to radians; thus they should now be
    # scaled with np.rad2deg.
    make_env("muscle_asymmetric"): {"PD": {"D": 0.000273, "P": 0.004213}},
    make_env("muscle_asymmetric", physics_kwargs={"corner": 0.03}): {
        "PD": {"P": 0.026084, "D": 0.005709}
    },
}
245 changes: 245 additions & 0 deletions cc/utils/high_level/end_to_end_learning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
import jax
import jax.numpy as jnp
import jax.random as jrand
import numpy as np
import optax

from cc.env.collect import random_steps_source, sample_feedforward_and_collect
from cc.env.wrappers import AddRefSignalRewardFnWrapper
from cc.examples.neural_ode_controller_compact_example import make_neural_ode_controller
from cc.examples.neural_ode_model_compact_example import make_neural_ode_model
from cc.train import (
DictLogger,
EvaluationMetrices,
ModelControllerTrainer,
Regularisation,
SupervisedDataset,
Tracker,
TrainingOptionsController,
TrainingOptionsModel,
UnsupervisedDataset,
make_dataloader,
)
from cc.utils import l1_norm, l2_norm, rmse
from cc.utils.high_level import (
build_extra_sources,
loop_observer_configs,
masterplot_siso,
)


def make_model(
    env,
    sample_train,
    sample_val,
    model_kwargs: dict,
    seed_model: int,
    n_steps: int,
    lambda_l1_norm: float = 0.0,
    lambda_l2_norm: float = 0.0,
    optimizer=optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)),
):
    """Create a neural-ODE model for `env`, train it and return the trainer.

    Args:
        env: Environment; its action spec, observation spec and control
            timestep are forwarded to `make_neural_ode_model`.
        sample_train: Training sample; `.action` is the input, `.obs` the target.
        sample_val: Validation sample with the same attributes; it is used
            for the "val_rmse" metric.
        model_kwargs: Extra keyword arguments for `make_neural_ode_model`.
        seed_model: Seed of the PRNG key used to build the model.
        n_steps: Number of steps passed to the trainer's `run`.
        lambda_l1_norm: Prefactor of the "l1_norm" regulariser.
        lambda_l2_norm: Prefactor of the "l2_norm" regulariser.
        optimizer: Optax optimizer handed to `TrainingOptionsModel`.

    Returns:
        The `ModelControllerTrainer` after `run(n_steps)` has been called.
        Its single tracker follows "val_rmse".
    """

    def _l1_penalty(vector_of_params):
        return {"l1_norm": l1_norm(vector_of_params)}

    def _l2_penalty(vector_of_params):
        return {"l2_norm": l2_norm(vector_of_params)}

    def _validation_rmse(y, yhat):
        return {"val_rmse": rmse(y, yhat)}

    model = make_neural_ode_model(
        env.action_spec(),
        env.observation_spec(),
        env.control_timestep(),
        key=jrand.PRNGKey(seed_model),
        **model_kwargs,
    )

    # Supervised pairs are (X, y) = (actions, observations).
    train_set = SupervisedDataset(sample_train.action, sample_train.obs)
    dataloader = make_dataloader(train_set, n_minibatches=4, do_bootstrapping=True)

    penalties = (
        Regularisation(prefactor=lambda_l1_norm, reduce_weights=_l1_penalty),
        Regularisation(prefactor=lambda_l2_norm, reduce_weights=_l2_penalty),
    )

    validation = (
        EvaluationMetrices(
            data=(sample_val.action, sample_val.obs),  # (X, y)
            metrices=(_validation_rmse,),
        ),
    )

    options = TrainingOptionsModel(
        dataloader, optimizer, regularisers=penalties, metrices=validation
    )

    trainer = ModelControllerTrainer(
        model,
        model_train_options=options,
        loggers=[DictLogger()],
        trackers=[Tracker("val_rmse")],
    )
    trainer.run(n_steps)
    return trainer


def tree_transform(bound: float = 3.0):
    """Build a transform that replaces batched references by random steps.

    Args:
        bound: Step heights are drawn uniformly between ``-bound`` and
            ``bound``.

    Returns:
        A function ``(key, ref, bs) -> ref`` that maps every leaf of the
        pytree `ref` to an array of the same shape and dtype in which each
        entry along the leading (batch) axis is a single constant value.
        `bs` must equal the size of that leading axis.
    """
    upper_bound = bound
    lower_bound = -bound

    @jax.vmap
    def _random_step(ref, key):
        # One scalar per batch entry, broadcast over that entry's whole shape.
        return jnp.ones_like(ref) * jrand.uniform(
            key, (), minval=lower_bound, maxval=upper_bound
        )

    def _tree_transform(key, ref, bs):
        # The same `bs` keys are reused for every leaf, so all leaves receive
        # the same step height at a given batch index.
        keys = jrand.split(key, bs)
        # `jax.tree_map` is a deprecated alias that newer JAX releases no
        # longer provide; `jax.tree_util.tree_map` is the stable spelling.
        return jax.tree_util.tree_map(lambda leaf: _random_step(leaf, keys), ref)

    return _tree_transform


def make_controller(
    env,
    env_w_source,
    model,
    seed_controller: int,
    training_step_source_amplitude: float,
    controller_kwargs: dict,
    n_steps: int,
    noise_scale=None,
):
    """Create a neural-ODE controller, train it against `model`, return the trainer.

    Args:
        env: Environment used to build the training references via
            `random_steps_source` with seeds 0..29.
        env_w_source: Environment whose observation spec, action spec and
            control timestep define the controller.
        model: Model handed to the trainer together with the controller.
        seed_controller: Seed of the PRNG key used to build the controller.
        training_step_source_amplitude: Bound passed to `tree_transform`,
            which is installed as the dataloader's `tree_transform`.
        controller_kwargs: Extra keyword arguments for
            `make_neural_ode_controller`.
        n_steps: Number of steps passed to the trainer's `run`.
        noise_scale: Forwarded to `TrainingOptionsController`.

    Returns:
        The `ModelControllerTrainer` after `run(n_steps)` has been called.
        Its single tracker follows "loss".
    """

    def _mean_squared_weights(w):
        return {"MSW": jnp.mean(w**2)}

    controller = make_neural_ode_controller(
        env_w_source.observation_spec(),
        env_w_source.action_spec(),
        env_w_source.control_timestep(),
        key=jrand.PRNGKey(seed_controller),
        **controller_kwargs,
    )

    step_source = random_steps_source(env, list(range(30)))
    references = UnsupervisedDataset(step_source.get_references_for_optimisation())
    dataloader = make_dataloader(
        references,
        n_minibatches=5,
        tree_transform=tree_transform(training_step_source_amplitude),
    )

    # Gradient clipping followed by Adam with a fixed learning rate.
    optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))

    options = TrainingOptionsController(
        dataloader,
        optimizer,
        regularisers=[Regularisation(1.0, _mean_squared_weights)],
        noise_scale=noise_scale,
    )

    trainer = ModelControllerTrainer(
        model,
        controller,
        controller_train_options=options,
        trackers=[Tracker("loss")],
        loggers=[DictLogger()],
    )
    trainer.run(n_steps)
    return trainer


# Per-environment data configuration keyed by environment id; the only entry
# is empty. NOTE(review): nothing in the code visible here reads this dict.
data_configs = {"two_segments_v2": {}}


def make_masterplot(
    env_id,
    env,
    filename,
    record_video,
    experiment_id,
    model_kwargs,
    seed_model,
    n_steps_model,
    controller_kwargs,
    seed_controller,
    n_steps_controller,
    noise_scale_controller=None,
    model_optimizer=optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)),
) -> float:
    """Train a model and a controller end to end and produce the master plot.

    Steps: collect training/validation samples from `env`, train a model
    with `make_model`, train a controller on that model with
    `make_controller`, then hand everything to `masterplot_siso`.

    Args:
        env_id: Environment id. It selects the step amplitude, the extra
            sources and the loop-observer config.
        env: Environment instance used for data collection and plotting.
        filename: Forwarded to `masterplot_siso`.
        record_video: Forwarded to `build_extra_sources`.
        experiment_id: Forwarded to `masterplot_siso`.
        model_kwargs: Keyword arguments for the model. An optional
            "lambda_l2_norm" entry is used as the L2 regularisation
            prefactor instead of being forwarded to the model factory.
            The dict passed in is not modified.
        seed_model: Seed for the model.
        n_steps_model: Number of model training steps.
        controller_kwargs: Keyword arguments for the controller.
        seed_controller: Seed for the controller.
        n_steps_controller: Number of controller training steps.
        noise_scale_controller: Forwarded to `make_controller` as `noise_scale`.
        model_optimizer: Optax optimizer used to train the model.

    Returns:
        Whatever `masterplot_siso` returns.
    """
    training_step_source_amplitude = 3.0
    if env_id == "muscle_asymmetric":
        training_step_source_amplitude = np.deg2rad(60.0)

    # Seeds handed to `sample_feedforward_and_collect`.
    train_gp = list(range(1))
    train_cos = [1, 1.5, 2, 3, 3.5, 4, 5, 5.5, 6, 7, 8, 9, 10, 12, 14, 11, 13, 15]
    val_gp = [15, 16]
    val_cos = [2.5, 7.5, 10.5, 16]
    # Alternative seed sets that were kept for reference:
    # train_gp = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    # train_cos = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14]
    # val_gp = [15, 16, 17, 18]
    # val_cos = [2.5, 5.0, 7.5, 10.0]

    train_sample = sample_feedforward_and_collect(env, train_gp, train_cos)
    val_sample = sample_feedforward_and_collect(env, val_gp, val_cos)

    test_source = random_steps_source(
        env, list(range(6)), training_step_source_amplitude
    )
    env_w_source = AddRefSignalRewardFnWrapper(env, test_source)

    # Work on a copy so that popping "lambda_l2_norm" does not alter the
    # caller's dict (a second call with the same dict would otherwise
    # silently fall back to 0.0).
    model_kwargs = dict(model_kwargs)
    lambda_l2_norm = model_kwargs.pop("lambda_l2_norm", 0.0)

    # Pass the regularisation prefactor and the optimizer by keyword:
    # positionally they would land in `lambda_l1_norm` and `lambda_l2_norm`.
    model_trainer = make_model(
        env,
        train_sample,
        val_sample,
        model_kwargs,
        seed_model,
        n_steps_model,
        lambda_l2_norm=lambda_l2_norm,
        optimizer=model_optimizer,
    )
    model = model_trainer.trackers[0].best_model_or_controller()

    controller_trainer = make_controller(
        env,
        env_w_source,
        model,
        seed_controller,
        training_step_source_amplitude,
        controller_kwargs,
        n_steps_controller,
        noise_scale_controller,
    )
    controller = controller_trainer.trackers[0].best_model_or_controller()

    results = masterplot_siso(
        env,
        test_source,
        controller,
        filename,
        build_extra_sources(env_id, record_video),
        controller_trainer.get_logs()[0],
        controller_trainer.get_tracker_logs()[0],
        model,
        model_trainer.get_logs()[0],
        model_trainer.get_tracker_logs()[0],
        train_sample,
        [0, 1, 4, 5],  # presumably indices into train_sample -- TODO confirm
        val_sample,
        [0, 2, 3],  # presumably indices into val_sample -- TODO confirm
        experiment_id=experiment_id,
        loop_observer_config=loop_observer_configs[env_id],
    )

    return results
60 changes: 60 additions & 0 deletions cc/utils/high_level/extra_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import numpy as np

from cc.env import make_env
from cc.env.collect import (
constant_after_transform_source,
double_step_source,
high_steps_source,
sample_feedforward_collect_and_make_source,
)
from cc.env.loop_observer import AnglesEnvLoopObserver

from .masterplot_siso import ExtraSource, LoopObserverConfig

def _sum_of_hinge_angles(lr, idx):
    """Add the "hinge_1 [deg]" and "hinge_2 [deg]" entries of `lr` at `idx`."""
    return lr["hinge_1 [deg]"][idx] + lr["hinge_2 [deg]"][idx]


# Loop-observer configuration per environment id; `None` means no config is
# defined for that environment.
loop_observer_configs = {
    "two_segments_v2": LoopObserverConfig(
        AnglesEnvLoopObserver(),
        "sum of angles [deg]",
        _sum_of_hinge_angles,
    ),
    "rover": None,
    "muscle_asymmetric": None,
}


def build_extra_sources(env_id: str, record_video):
    """Build the list of extra reference sources for an environment.

    Args:
        env_id: One of "two_segments_v2", "rover" or "muscle_asymmetric".
            It selects the camera and the step amplitudes.
        record_video: Forwarded unchanged to every `ExtraSource`.

    Returns:
        A list of four `ExtraSource` objects named "high_amplitude",
        "double_steps", "smooth_refs" and "smooth_to_constant_refs".

    Raises:
        NotImplementedError: If `env_id` is not one of the supported ids.
    """
    if env_id == "two_segments_v2":
        camera_id = "skyview"
        high_amp = 6.0
        step_amp = 2.0
    elif env_id == "rover":
        camera_id = "target"
        high_amp = 5.0
        step_amp = 2.0
    elif env_id == "muscle_asymmetric":
        camera_id = "upfront"
        high_amp = np.deg2rad(110)
        step_amp = np.deg2rad(45)
    else:
        # Name the offending id so a failure is actionable.
        raise NotImplementedError(
            f"No extra sources are defined for env_id={env_id!r}; supported "
            "ids are 'two_segments_v2', 'rover' and 'muscle_asymmetric'."
        )

    # A fresh environment is created here rather than taken from the caller.
    env = make_env(env_id)

    smooth_source, _, _ = sample_feedforward_collect_and_make_source(env, seeds=[1, 2])
    smooth_source_constant = constant_after_transform_source(smooth_source, 5.0)

    extras = [
        ExtraSource(
            high_steps_source(env, high_amp), "high_amplitude", camera_id, record_video
        ),
        ExtraSource(
            double_step_source(env, step_amp), "double_steps", camera_id, record_video
        ),
        ExtraSource(smooth_source, "smooth_refs", camera_id, record_video),
        ExtraSource(
            smooth_source_constant, "smooth_to_constant_refs", camera_id, record_video
        ),
    ]

    return extras
Loading

0 comments on commit 270fefb

Please sign in to comment.