diff --git a/configs/commands/rlbench/phone_on_base/train_taxposed_mlat_pzX_place.yaml b/configs/commands/rlbench/phone_on_base/train_taxposed_mlat_pzX_place.yaml
new file mode 100644
index 0000000..24f140f
--- /dev/null
+++ b/configs/commands/rlbench/phone_on_base/train_taxposed_mlat_pzX_place.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+
+defaults:
+  - /commands/rlbench/_train_taxposed@_here_
+  - override /model: taxposed_mlat_s100
+  - override /task: phone_on_base
+  - override /phase: place
+  - _self_
+
+break_symmetry: False
+
+training:
+  init_cond_x: True
+  freeze_embnn: True
+  freeze_residual_flow: True
+  freeze_z_embnn: True
+
+  load_from_checkpoint: True
+  checkpoint_file: null
\ No newline at end of file
diff --git a/configs/commands/rlbench/phone_on_base/train_taxposed_mlat_pzY_place.yaml b/configs/commands/rlbench/phone_on_base/train_taxposed_mlat_pzY_place.yaml
new file mode 100644
index 0000000..3c8c809
--- /dev/null
+++ b/configs/commands/rlbench/phone_on_base/train_taxposed_mlat_pzY_place.yaml
@@ -0,0 +1,10 @@
+# @package _global_
+
+defaults:
+  - /commands/rlbench/_train_taxposed@_here_
+  - override /model: taxposed_mlat_s100
+  - override /task: phone_on_base
+  - override /phase: place
+  - _self_
+
+break_symmetry: False
diff --git a/configs/commands/rlbench/place_hanger_on_rack/train_taxposed_mlat_pzX_place.yaml b/configs/commands/rlbench/place_hanger_on_rack/train_taxposed_mlat_pzX_place.yaml
new file mode 100644
index 0000000..574a3c5
--- /dev/null
+++ b/configs/commands/rlbench/place_hanger_on_rack/train_taxposed_mlat_pzX_place.yaml
@@ -0,0 +1,23 @@
+# @package _global_
+
+defaults:
+  - /commands/rlbench/_train_taxposed@_here_
+  - override /model: taxposed_mlat_s100
+  - override /task: place_hanger_on_rack
+  - override /phase: place
+  - _self_
+
+break_symmetry: False
+
+training:
+  init_cond_x: True
+  freeze_embnn: True
+  freeze_residual_flow: True
+  freeze_z_embnn: True
+
+  load_from_checkpoint: True
+  checkpoint_file: null
+
+dm:
+  train_dset:
+    num_points: 256
\ No newline at end of file
diff --git a/configs/commands/rlbench/place_hanger_on_rack/train_taxposed_mlat_pzY_place.yaml b/configs/commands/rlbench/place_hanger_on_rack/train_taxposed_mlat_pzY_place.yaml
new file mode 100644
index 0000000..639a40f
--- /dev/null
+++ b/configs/commands/rlbench/place_hanger_on_rack/train_taxposed_mlat_pzY_place.yaml
@@ -0,0 +1,14 @@
+# @package _global_
+
+defaults:
+  - /commands/rlbench/_train_taxposed@_here_
+  - override /model: taxposed_mlat_s100
+  - override /task: place_hanger_on_rack
+  - override /phase: place
+  - _self_
+
+break_symmetry: False
+
+dm:
+  train_dset:
+    num_points: 256
\ No newline at end of file
diff --git a/configs/commands/rlbench/put_toilet_roll_on_stand/train_taxposed_mlat_pzX_place.yaml b/configs/commands/rlbench/put_toilet_roll_on_stand/train_taxposed_mlat_pzX_place.yaml
new file mode 100644
index 0000000..2af7fa4
--- /dev/null
+++ b/configs/commands/rlbench/put_toilet_roll_on_stand/train_taxposed_mlat_pzX_place.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+
+defaults:
+  - /commands/rlbench/_train_taxposed@_here_
+  - override /model: taxposed_mlat_s100
+  - override /task: put_toilet_roll_on_stand
+  - override /phase: place
+  - _self_
+
+break_symmetry: False
+
+training:
+  init_cond_x: True
+  freeze_embnn: True
+  freeze_residual_flow: True
+  freeze_z_embnn: True
+
+  load_from_checkpoint: True
+  checkpoint_file: null
\ No newline at end of file
diff --git a/configs/commands/rlbench/put_toilet_roll_on_stand/train_taxposed_mlat_pzY_place.yaml b/configs/commands/rlbench/put_toilet_roll_on_stand/train_taxposed_mlat_pzY_place.yaml
new file mode 100644
index 0000000..6aa4a49
--- /dev/null
+++ b/configs/commands/rlbench/put_toilet_roll_on_stand/train_taxposed_mlat_pzY_place.yaml
@@ -0,0 +1,10 @@
+# @package _global_
+
+defaults:
+  - /commands/rlbench/_train_taxposed@_here_
+  - override /model: taxposed_mlat_s100
+  - override /task: put_toilet_roll_on_stand
+  - override /phase: place
+  - _self_
+
+break_symmetry: False
diff --git a/configs/commands/rlbench/solve_puzzle/train_taxposed_pzX_grasp.yaml b/configs/commands/rlbench/solve_puzzle/train_taxposed_mlat_pzX_place.yaml
similarity index 83%
rename from configs/commands/rlbench/solve_puzzle/train_taxposed_pzX_grasp.yaml
rename to configs/commands/rlbench/solve_puzzle/train_taxposed_mlat_pzX_place.yaml
index 285b197..66af036 100644
--- a/configs/commands/rlbench/solve_puzzle/train_taxposed_pzX_grasp.yaml
+++ b/configs/commands/rlbench/solve_puzzle/train_taxposed_mlat_pzX_place.yaml
@@ -2,9 +2,9 @@
 
 defaults:
   - /commands/rlbench/_train_taxposed@_here_
-  - override /model: taxposed
+  - override /model: taxposed_mlat_s100
   - override /task: solve_puzzle
-  - override /phase: grasp
+  - override /phase: place
   - _self_
 
 break_symmetry: False
diff --git a/configs/commands/rlbench/solve_puzzle/train_taxposed_pzY_grasp.yaml b/configs/commands/rlbench/solve_puzzle/train_taxposed_mlat_pzY_place.yaml
similarity index 73%
rename from configs/commands/rlbench/solve_puzzle/train_taxposed_pzY_grasp.yaml
rename to configs/commands/rlbench/solve_puzzle/train_taxposed_mlat_pzY_place.yaml
index 1e12115..e653e8e 100644
--- a/configs/commands/rlbench/solve_puzzle/train_taxposed_pzY_grasp.yaml
+++ b/configs/commands/rlbench/solve_puzzle/train_taxposed_mlat_pzY_place.yaml
@@ -2,9 +2,9 @@
 
 defaults:
   - /commands/rlbench/_train_taxposed@_here_
-  - override /model: taxposed
+  - override /model: taxposed_mlat_s100
   - override /task: solve_puzzle
-  - override /phase: grasp
+  - override /phase: place
   - _self_
 
 break_symmetry: False
diff --git a/configs/commands/rlbench/stack_wine/train_taxposed_mlat_pzX_place.yaml b/configs/commands/rlbench/stack_wine/train_taxposed_mlat_pzX_place.yaml
new file mode 100644
index 0000000..464aaaa
--- /dev/null
+++ b/configs/commands/rlbench/stack_wine/train_taxposed_mlat_pzX_place.yaml
@@ -0,0 +1,19 @@
+# @package _global_
+
+defaults:
+  - /commands/rlbench/_train_taxposed@_here_
+  - override /model: taxposed_mlat_s100
+  - override /task: stack_wine
+  - override /phase: place
+  - _self_
+
+break_symmetry: False
+
+training:
+  init_cond_x: True
+  freeze_embnn: True
+  freeze_residual_flow: True
+  freeze_z_embnn: True
+
+  load_from_checkpoint: True
+  checkpoint_file: null
\ No newline at end of file
diff --git a/configs/commands/rlbench/stack_wine/train_taxposed_mlat_pzY_place.yaml b/configs/commands/rlbench/stack_wine/train_taxposed_mlat_pzY_place.yaml
new file mode 100644
index 0000000..0534974
--- /dev/null
+++ b/configs/commands/rlbench/stack_wine/train_taxposed_mlat_pzY_place.yaml
@@ -0,0 +1,10 @@
+# @package _global_
+
+defaults:
+  - /commands/rlbench/_train_taxposed@_here_
+  - override /model: taxposed_mlat_s100
+  - override /task: stack_wine
+  - override /phase: place
+  - _self_
+
+break_symmetry: False
diff --git a/configs/model/_taxposed.yaml b/configs/model/_taxposed.yaml
index 025049a..7e2db0d 100644
--- a/configs/model/_taxposed.yaml
+++ b/configs/model/_taxposed.yaml
@@ -3,7 +3,7 @@
 name: ???
 conditioning: pos_delta_l2norm
 emb_dims: 512
-emb_nn: dgcnn
+emb_nn: cond_dgcnn
 return_flow_component: True
 center_feature: True
 inital_sampling_ratio: 1
@@ -13,6 +13,7 @@ mlat_sample: null
 mlat_nkps: null
 break_symmetry: False
 latent_z_linear_size: 40
+num_points: 1024
 gumbel_temp: 1
 division_smooth_factor: 1
 
diff --git a/configs/train_ndf.yaml b/configs/train_ndf.yaml
index 966fa51..9fc6068 100644
--- a/configs/train_ndf.yaml
+++ b/configs/train_ndf.yaml
@@ -45,11 +45,13 @@ dm:
 training:
   batch_size: 8
   max_epochs: 500
+  max_steps: -1
   sigmoid_on: True
 
   # Optimizer Settings
   lr: 1e-4
+  gradient_clipping: null
 
   # Loss Settings
   flow_supervision: both
diff --git a/configs/train_ndf_multimodal.yaml b/configs/train_ndf_multimodal.yaml
index 3d23fd6..560557b 100644
--- a/configs/train_ndf_multimodal.yaml
+++ b/configs/train_ndf_multimodal.yaml
@@ -44,7 +44,7 @@ dm:
 training:
   batch_size: 8
-  max_epochs: 20
+  max_epochs: 500
   max_steps: -1
   sigmoid_on: True
 
@@ -82,7 +82,7 @@ training:
   checkpoint_file: null
 
   # Visualization Settings
-  image_logging_period: 1001
+  image_logging_period: 100
   log_every_n_steps: 100
   check_val_every_n_epoch: 5
 
diff --git a/scripts/README.md b/scripts/README.md
index 1adbba4..92a04e8 100644
--- a/scripts/README.md
+++ b/scripts/README.md
@@ -85,6 +85,63 @@ If you write some scripts which are meant to be run stand-alone, and not importe
 ./launch_autobot.sh 4 python scripts/train_residual_flow.py --config-name commands/rlbench/train_mlat_rlbench_solve_puzzle_place.yaml wandb.group=rlbench_mlat resources.num_workers=0
 
 ./launch_autobot.sh 5 python scripts/train_residual_flow.py --config-name commands/rlbench/train_mlat_rlbench_place_hanger_on_rack_place.yaml wandb.group=rlbench_mlat resources.num_workers=0
 
+### TAXPoseD
+#### TAXPoseD p(z|Y)
+
+./launch.sh ${RPAD_PLATFORM} 0 python scripts/train_residual_flow.py --config-name commands/rlbench/stack_wine/train_taxposed_pzY_place.yaml wandb.group=rlbench_taxposed
+
+./launch.sh ${RPAD_PLATFORM} 1 python scripts/train_residual_flow.py --config-name commands/rlbench/insert_onto_square_peg/train_taxposed_pzY_place.yaml wandb.group=rlbench_taxposed
+
+./launch.sh ${RPAD_PLATFORM} 2 python scripts/train_residual_flow.py --config-name commands/rlbench/phone_on_base/train_taxposed_pzY_place.yaml wandb.group=rlbench_taxposed
+
+./launch.sh ${RPAD_PLATFORM} 3 python scripts/train_residual_flow.py --config-name commands/rlbench/put_toilet_roll_on_stand/train_taxposed_pzY_place.yaml wandb.group=rlbench_taxposed
+
+./launch.sh ${RPAD_PLATFORM} 4 python scripts/train_residual_flow.py --config-name commands/rlbench/place_hanger_on_rack/train_taxposed_pzY_place.yaml wandb.group=rlbench_taxposed
+
+./launch.sh ${RPAD_PLATFORM} 5 python scripts/train_residual_flow.py --config-name commands/rlbench/solve_puzzle/train_taxposed_pzY_place.yaml wandb.group=rlbench_taxposed
+
+#### TAXPoseD p(z|X)
+
+./launch.sh ${RPAD_PLATFORM} 0 python scripts/train_residual_flow.py --config-name commands/rlbench/stack_wine/train_taxposed_pzX_place.yaml wandb.group=rlbench_taxposed training.checkpoint_file=
+
+./launch.sh ${RPAD_PLATFORM} 1 python scripts/train_residual_flow.py --config-name commands/rlbench/insert_onto_square_peg/train_taxposed_pzX_place.yaml wandb.group=rlbench_taxposed training.checkpoint_file=
+
+./launch.sh ${RPAD_PLATFORM} 2 python scripts/train_residual_flow.py --config-name commands/rlbench/phone_on_base/train_taxposed_pzX_place.yaml wandb.group=rlbench_taxposed training.checkpoint_file=
+
+./launch.sh ${RPAD_PLATFORM} 3 python scripts/train_residual_flow.py --config-name commands/rlbench/put_toilet_roll_on_stand/train_taxposed_pzX_place.yaml wandb.group=rlbench_taxposed training.checkpoint_file=
+
+./launch.sh ${RPAD_PLATFORM} 4 python scripts/train_residual_flow.py --config-name commands/rlbench/place_hanger_on_rack/train_taxposed_pzX_place.yaml wandb.group=rlbench_taxposed training.checkpoint_file=
+
+./launch.sh ${RPAD_PLATFORM} 5 python scripts/train_residual_flow.py --config-name commands/rlbench/solve_puzzle/train_taxposed_pzX_place.yaml wandb.group=rlbench_taxposed training.checkpoint_file=
+
+#### TAXPoseD Mlat p(z|Y)
+
+./launch.sh ${RPAD_PLATFORM} 0 python scripts/train_residual_flow.py --config-name commands/rlbench/stack_wine/train_taxposed_mlat_pzY_place.yaml wandb.group=rlbench_taxposed_mlat
+
+./launch.sh ${RPAD_PLATFORM} 1 python scripts/train_residual_flow.py --config-name commands/rlbench/insert_onto_square_peg/train_taxposed_mlat_pzY_place.yaml wandb.group=rlbench_taxposed_mlat
+
+./launch.sh ${RPAD_PLATFORM} 2 python scripts/train_residual_flow.py --config-name commands/rlbench/phone_on_base/train_taxposed_mlat_pzY_place.yaml wandb.group=rlbench_taxposed_mlat
+
+./launch.sh ${RPAD_PLATFORM} 3 python scripts/train_residual_flow.py --config-name commands/rlbench/put_toilet_roll_on_stand/train_taxposed_mlat_pzY_place.yaml wandb.group=rlbench_taxposed_mlat
+
+./launch.sh ${RPAD_PLATFORM} 4 python scripts/train_residual_flow.py --config-name commands/rlbench/place_hanger_on_rack/train_taxposed_mlat_pzY_place.yaml wandb.group=rlbench_taxposed_mlat
+
+./launch.sh ${RPAD_PLATFORM} 5 python scripts/train_residual_flow.py --config-name commands/rlbench/solve_puzzle/train_taxposed_mlat_pzY_place.yaml wandb.group=rlbench_taxposed_mlat
+
+#### TAXPoseD Mlat p(z|X)
+
+./launch.sh ${RPAD_PLATFORM} 0 python scripts/train_residual_flow.py --config-name commands/rlbench/stack_wine/train_taxposed_mlat_pzX_place.yaml wandb.group=rlbench_taxposed_mlat training.checkpoint_file=
+
+./launch.sh ${RPAD_PLATFORM} 1 python scripts/train_residual_flow.py --config-name commands/rlbench/insert_onto_square_peg/train_taxposed_mlat_pzX_place.yaml wandb.group=rlbench_taxposed_mlat training.checkpoint_file=
+
+./launch.sh ${RPAD_PLATFORM} 2 python scripts/train_residual_flow.py --config-name commands/rlbench/phone_on_base/train_taxposed_mlat_pzX_place.yaml wandb.group=rlbench_taxposed_mlat training.checkpoint_file=
+
+./launch.sh ${RPAD_PLATFORM} 3 python scripts/train_residual_flow.py --config-name commands/rlbench/put_toilet_roll_on_stand/train_taxposed_mlat_pzX_place.yaml wandb.group=rlbench_taxposed_mlat training.checkpoint_file=
+
+./launch.sh ${RPAD_PLATFORM} 4 python scripts/train_residual_flow.py --config-name commands/rlbench/place_hanger_on_rack/train_taxposed_mlat_pzX_place.yaml wandb.group=rlbench_taxposed_mlat training.checkpoint_file=
+
+./launch.sh ${RPAD_PLATFORM} 5 python scripts/train_residual_flow.py --config-name commands/rlbench/solve_puzzle/train_taxposed_mlat_pzX_place.yaml wandb.group=rlbench_taxposed_mlat training.checkpoint_file=
+
 ### Ablations
 
diff --git a/scripts/train_residual_flow.py b/scripts/train_residual_flow.py
index be1212b..619f839 100644
--- a/scripts/train_residual_flow.py
+++ b/scripts/train_residual_flow.py
@@ -11,8 +11,8 @@
 from taxpose.datasets.point_cloud_data_module import MultiviewDataModule
 
 from taxpose.nets.multimodal_transformer_flow import (
-    Multimodal_ResidualFlow_DiffEmbTransformer,
-    Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX
+    Multimodal_ResidualFlow_DiffEmbTransformer,
+    Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX,
 )
 from taxpose.nets.transformer_flow import ResidualFlow_DiffEmbTransformer
 from taxpose.training.flow_equivariance_training_module_nocentering import (
@@ -20,9 +20,10 @@
 )
 from taxpose.training.multimodal_flow_equivariance_training_module_nocentering import (
     Multimodal_EquivarianceTrainingModule,
-    Multimodal_EquivarianceTrainingModule_WithPZCondX
+    Multimodal_EquivarianceTrainingModule_WithPZCondX,
 )
+
 
 def write_to_file(file_name, string):
     with open(file_name, "a") as f:
         f.writelines(string)
@@ -66,71 +67,134 @@ def maybe_load_from_wandb(checkpoint_reference, wandb_cfg=None, run=None):
     return ckpt_file
 
 
-@hydra.main(version_base="1.1", config_path="../configs", config_name="train_ndf")
-def main(cfg):
-    print(OmegaConf.to_yaml(cfg, resolve=True))
+def maybe_load_training_model_weights(model, cfg, logger):
+    if cfg.model.name in ["taxposed", "taxposed_mlat_s100", "taxposed_mlat_s256"]:
+        if (
+            not cfg.training.joint_train_prior
+            and cfg.training.init_cond_x
+            and (not cfg.training.freeze_embnn or not cfg.training.freeze_residual_flow)
+        ):
+            raise ValueError("YOU PROBABLY DIDN'T MEAN TO DO JOINT TRAINING")
+        if (
+            not cfg.training.joint_train_prior
+            and cfg.training.init_cond_x
+            and cfg.training.checkpoint_file is None
+        ):
+            raise ValueError(
+                "YOU PROBABLY DIDN'T MEAN TO TRAIN BOTH P(Z|X) AND P(Z|Y) FROM SCRATCH"
+            )
 
-    # torch.set_float32_matmul_precision("medium")
-    pl.seed_everything(cfg.seed)
-    logger = WandbLogger(
-        entity=cfg.wandb.entity,
-        project=cfg.wandb.project,
-        group=cfg.wandb.group,
-        save_dir=cfg.wandb.save_dir,
-        job_type=cfg.job_type,
-        save_code=True,
-        log_model=True,
-        config=omegaconf.OmegaConf.to_container(
-            cfg, resolve=True, throw_on_missing=True
-        ),
-    )
-    # logger.log_hyperparams(cfg)
-    # logger.log_hyperparams({"working_dir": os.getcwd()})
-    trainer = pl.Trainer(
-        logger=logger,
-        accelerator="gpu",
-        devices=[0],
-        log_every_n_steps=cfg.training.log_every_n_steps,
-        check_val_every_n_epoch=cfg.training.check_val_every_n_epoch,
-        # reload_dataloaders_every_n_epochs=1,
-        # callbacks=[SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()],
-        callbacks=[
-            # This checkpoint callback saves the latest model during training, i.e. so we can resume if it crashes.
-            # It saves everything, and you can load by referencing last.ckpt.
-            ModelCheckpoint(
-                dirpath=cfg.lightning.checkpoint_dir,
-                filename="{epoch}-{step}",
-                monitor="step",
-                mode="max",
-                save_weights_only=False,
-                save_last=True,
-                every_n_epochs=1,
-            ),
-            # This checkpoint will get saved to WandB. The Callback mechanism in lightning is poorly designed, so we have to put it last.
-            ModelCheckpoint(
-                dirpath=cfg.lightning.checkpoint_dir,
-                filename="{epoch}-{step}-{train_loss:.2f}-weights-only",
-                monitor="val_loss",
-                mode="min",
-                save_weights_only=True,
-            ),
-        ],
-        max_epochs=cfg.training.max_epochs,
-        max_steps=cfg.training.max_steps,
-        gradient_clip_val=cfg.training.gradient_clipping
-    )
+        # TODO: Add support for loading pretraining
+        if cfg.training.load_from_checkpoint:
+            assert cfg.training.checkpoint_file is not None
+            print(f"--------------- Loading Checkpoint File: ---------------")
+            print(cfg.training.checkpoint_file)
 
-    dm = MultiviewDataModule(
-        batch_size=cfg.training.batch_size,
-        num_workers=cfg.resources.num_workers,
-        cfg=cfg.dm,
-    )
+            # If using p(z|X) training model
+            if cfg.training.init_cond_x:
+                # If loading p(z|X) from checkpoint
+                if cfg.training.load_cond_x:
+                    model.load_state_dict(
+                        torch.load(
+                            hydra.utils.to_absolute_path(
+                                maybe_load_from_wandb(
+                                    cfg.training.checkpoint_file,
+                                    cfg.wandb,
+                                    logger.experiment,
+                                )
+                            )
+                        )["state_dict"]
+                    )
+                    print("--------------- Loaded P(z|X) ---------------")
+                # Otherwise load p(z|Y) from checkpoint
+                else:
+                    model.training_module_no_cond_x.load_state_dict(
+                        torch.load(
+                            hydra.utils.to_absolute_path(
+                                maybe_load_from_wandb(
+                                    cfg.training.checkpoint_file,
+                                    cfg.wandb,
+                                    logger.experiment,
+                                )
+                            )
+                        )["state_dict"]
+                    )
+                    print("--------------- Loaded P(z|Y) for P(z|X) ---------------")
+            # Otherwise load p(z|Y) from checkpoint
+            else:
+                model.load_state_dict(
+                    torch.load(
+                        hydra.utils.to_absolute_path(
+                            maybe_load_from_wandb(
+                                cfg.training.checkpoint_file,
+                                cfg.wandb,
+                                logger.experiment,
+                            )
+                        )
+                    )["state_dict"]
+                )
+                print("--------------- Loaded P(z|Y) ---------------")
+            print("--------------- Done Loading Checkpoint File ---------------\n")
 
-    dm.setup()
+    else:
+        if cfg.training.load_from_checkpoint:
+            print("loaded checkpoint from")
+            print(cfg.training.checkpoint_file)
+            model.load_state_dict(
+                torch.load(hydra.utils.to_absolute_path(cfg.training.checkpoint_file))[
+                    "state_dict"
+                ]
+            )
+        else:
+            # Might be empty and not have those keys defined.
+            # TODO: move this pretraining into the model itself.
+            # TODO: figure out if we can get rid of the dictionary and make it null.
+            if cfg.model.pretraining:
+                if cfg.model.pretraining.checkpoint_file_action is not None:
+                    # # Check to see if it's a wandb checkpoint.
+                    # TODO: need to retrain a few things... checkpoint didn't stick...
+                    emb_nn_action_state_dict = load_emb_weights(
+                        cfg.pretraining.checkpoint_file_action,
+                        cfg.wandb,
+                        logger.experiment,
+                    )
+                    # checkpoint_file_fn = maybe_load_from_wandb(
+                    #     cfg.pretraining.checkpoint_file_action, cfg.wandb, logger.experiment.run
+                    # )
+
+                    model.model.emb_nn_action.load_state_dict(emb_nn_action_state_dict)
+                    print(
+                        "-----------------------Pretrained EmbNN Action Model Loaded!-----------------------"
+                    )
+                    print(
+                        "Loaded Pretrained EmbNN Action: {}".format(
+                            cfg.pretraining.checkpoint_file_action
+                        )
+                    )
+                if cfg.pretraining.checkpoint_file_anchor is not None:
+                    emb_nn_anchor_state_dict = load_emb_weights(
+                        cfg.pretraining.checkpoint_file_anchor,
+                        cfg.wandb,
+                        logger.experiment,
+                    )
+                    model.model.emb_nn_anchor.load_state_dict(emb_nn_anchor_state_dict)
+                    print(
+                        "-----------------------Pretrained EmbNN Anchor Model Loaded!-----------------------"
+                    )
+                    print(
+                        "Loaded Pretrained EmbNN Anchor: {}".format(
+                            cfg.pretraining.checkpoint_file_anchor
+                        )
+                    )
+
+
+def get_training_network(cfg):
     if cfg.model.name in ["taxposed", "taxposed_mlat_s100", "taxposed_mlat_s256"]:
-        TP_input_dims = Multimodal_ResidualFlow_DiffEmbTransformer.TP_INPUT_DIMS[cfg.model.conditioning]
-
+        TP_input_dims = Multimodal_ResidualFlow_DiffEmbTransformer.TP_INPUT_DIMS[
+            cfg.model.conditioning
+        ]
+
         taxpose_decoder_network = ResidualFlow_DiffEmbTransformer(
             emb_dims=cfg.model.emb_dims,
             input_dims=TP_input_dims,
@@ -142,9 +206,11 @@ def main(cfg):
             sample=cfg.model.mlat_sample,
             mlat_nkps=cfg.model.mlat_nkps,
             break_symmetry=cfg.break_symmetry,
-            conditioning_size=cfg.model.latent_z_linear_size if cfg.model.conditioning in ["latent_z_linear_internalcond"] else 0,
+            conditioning_size=cfg.model.latent_z_linear_size
+            if cfg.model.conditioning in ["latent_z_linear_internalcond"]
+            else 0,
         )
-
+        print("--------------- Initializing P(z|Y) ---------------")
         pzY_prediction_network = Multimodal_ResidualFlow_DiffEmbTransformer(
             residualflow_diffembtransformer=taxpose_decoder_network,
             gumbel_temp=cfg.model.gumbel_temp,
@@ -155,11 +221,28 @@
             add_smooth_factor=cfg.model.add_smooth_factor,
             conditioning=cfg.model.conditioning,
             latent_z_linear_size=cfg.model.latent_z_linear_size,
-            taxpose_centering=cfg.model.taxpose_centering
+            taxpose_centering=cfg.model.taxpose_centering,
+        )
+        network = pzY_prediction_network
+    else:
+        network = ResidualFlow_DiffEmbTransformer(
+            emb_dims=cfg.model.emb_dims,
+            emb_nn=cfg.model.emb_nn,
+            return_flow_component=cfg.model.return_flow_component,
+            center_feature=cfg.model.center_feature,
+            pred_weight=cfg.model.pred_weight,
+            multilaterate=cfg.model.multilaterate,
+            sample=cfg.model.mlat_sample,
+            mlat_nkps=cfg.model.mlat_nkps,
+            break_symmetry=cfg.break_symmetry,
         )
-
+    return network
+
+
+def get_training_model(training_network, cfg, logger):
+    if cfg.model.name in ["taxposed", "taxposed_mlat_s100", "taxposed_mlat_s256"]:
         training_model = Multimodal_EquivarianceTrainingModule(
-            pzY_prediction_network,
+            training_network,
             lr=cfg.training.lr,
             image_log_period=cfg.training.image_logging_period,
             point_loss_type=cfg.training.point_loss_type,
@@ -173,22 +256,19 @@
             sigmoid_on=cfg.training.sigmoid_on,
             softmax_temperature=cfg.task.phase.softmax_temperature,
             min_err_across_racks_debug=cfg.training.min_err_across_racks_debug,
-            error_mode_2rack=cfg.training.error_mode_2rack
+            error_mode_2rack=cfg.training.error_mode_2rack,
         )
-
-        if not cfg.training.joint_train_prior and cfg.training.init_cond_x and (not cfg.training.freeze_embnn or not cfg.training.freeze_residual_flow):
-            raise ValueError("YOU PROBABLY DIDN'T MEAN TO DO JOINT TRAINING")
-        if not cfg.training.joint_train_prior and cfg.training.init_cond_x and cfg.training.checkpoint_file is None:
-            raise ValueError("YOU PROBABLY DIDN'T MEAN TO TRAIN BOTH P(Z|X) AND P(Z|Y) FROM SCRATCH")
-
-        if cfg.training.init_cond_x:
-            print(f'--------------- Initializing P(z|X) ---------------')
-            pzX_prediction_network = Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX(
-                residualflow_embnn=pzY_prediction_network,
-                encoder_type=cfg.model.pzcondx_encoder_type,
-                shuffle_for_pzX=cfg.model.shuffle_for_pzX,
+
+        if cfg.training.init_cond_x:
+            print(f"--------------- Initializing P(z|X) ---------------")
+            pzX_prediction_network = (
+                Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX(
+                    residualflow_embnn=training_network,
+                    encoder_type=cfg.model.pzcondx_encoder_type,
+                    shuffle_for_pzX=cfg.model.shuffle_for_pzX,
+                )
             )
-
+
             training_model_cond_x = Multimodal_EquivarianceTrainingModule_WithPZCondX(
                 pzX_prediction_network,
                 training_model,
@@ -197,44 +277,14 @@
                 joint_train_prior_freeze_embnn=cfg.training.joint_train_prior_freeze_embnn,
                 freeze_residual_flow=cfg.training.freeze_residual_flow,
                 freeze_z_embnn=cfg.training.freeze_z_embnn,
-                freeze_embnn=cfg.training.freeze_embnn)
-
-        # TODO: Add support for loading pretraining
-        if cfg.training.load_from_checkpoint:
-            assert cfg.training.checkpoint_file is not None
-            print(f'--------------- Loading Checkpoint File From: ---------------')
-            print(cfg.training.checkpoint_file)
-            if not cfg.training.load_cond_x:
-                training_model.load_state_dict(
-                    torch.load(hydra.utils.to_absolute_path(
-                        maybe_load_from_wandb(cfg.training.checkpoint_file, cfg.wandb, logger.experiment)))["state_dict"])
-                print(f'--------------- Loaded P(z|Y) ---------------')
-            else:
-                training_model_cond_x.load_state_dict(
-                    torch.load(hydra.utils.to_absolute_path(
-                        maybe_load_from_wandb(cfg.training.checkpoint_file, cfg.wandb, logger.experiment)))["state_dict"])
-                print(f'--------------- Loaded P(z|X) ---------------')
-            print(f'--------------- Checkpoint Loaded ---------------')
-
+                freeze_embnn=cfg.training.freeze_embnn,
+            )
+
         model = training_model_cond_x if cfg.training.init_cond_x else training_model
-        model.cuda()
-        model.train()
-
-    else:
-        network = ResidualFlow_DiffEmbTransformer(
-            emb_dims=cfg.model.emb_dims,
-            emb_nn=cfg.model.emb_nn,
-            return_flow_component=cfg.model.return_flow_component,
-            center_feature=cfg.model.center_feature,
-            pred_weight=cfg.model.pred_weight,
-            multilaterate=cfg.model.multilaterate,
-            sample=cfg.model.mlat_sample,
-            mlat_nkps=cfg.model.mlat_nkps,
-            break_symmetry=cfg.break_symmetry,
-        )
+
+    else:
         model = EquivarianceTrainingModule(
-            network,
+            training_network,
             lr=cfg.training.lr,
             image_log_period=cfg.training.image_logging_period,
             displace_loss_weight=cfg.training.displace_loss_weight,
@@ -246,55 +296,78 @@
             flow_supervision=cfg.training.flow_supervision,
         )
-        model.cuda()
-        model.train()
-        if cfg.training.load_from_checkpoint:
-            print("loaded checkpoint from")
-            print(cfg.training.checkpoint_file)
-            model.load_state_dict(
-                torch.load(hydra.utils.to_absolute_path(cfg.training.checkpoint_file))[
-                    "state_dict"
-                ]
-            )
+    model.cuda()
+    model.train()
+    return model
 
-        else:
-            # Might be empty and not have those keys defined.
-            # TODO: move this pretraining into the model itself.
-            # TODO: figure out if we can get rid of the dictionary and make it null.
-            if cfg.model.pretraining:
-                if cfg.model.pretraining.checkpoint_file_action is not None:
-                    # # Check to see if it's a wandb checkpoint.
-                    # TODO: need to retrain a few things... checkpoint didn't stick...
-                    emb_nn_action_state_dict = load_emb_weights(
-                        cfg.pretraining.checkpoint_file_action, cfg.wandb, logger.experiment
-                    )
-                    # checkpoint_file_fn = maybe_load_from_wandb(
-                    #     cfg.pretraining.checkpoint_file_action, cfg.wandb, logger.experiment.run
-                    # )
-                    model.model.emb_nn_action.load_state_dict(emb_nn_action_state_dict)
-                    print(
-                        "-----------------------Pretrained EmbNN Action Model Loaded!-----------------------"
-                    )
-                    print(
-                        "Loaded Pretrained EmbNN Action: {}".format(
-                            cfg.pretraining.checkpoint_file_action
-                        )
-                    )
-                if cfg.pretraining.checkpoint_file_anchor is not None:
-                    emb_nn_anchor_state_dict = load_emb_weights(
-                        cfg.pretraining.checkpoint_file_anchor, cfg.wandb, logger.experiment
-                    )
-                    model.model.emb_nn_anchor.load_state_dict(emb_nn_anchor_state_dict)
-                    print(
-                        "-----------------------Pretrained EmbNN Anchor Model Loaded!-----------------------"
-                    )
-                    print(
-                        "Loaded Pretrained EmbNN Anchor: {}".format(
-                            cfg.pretraining.checkpoint_file_anchor
-                        )
-                    )
-
+
+@hydra.main(version_base="1.1", config_path="../configs", config_name="train_ndf")
+def main(cfg):
+    print(OmegaConf.to_yaml(cfg, resolve=True))
+
+    # torch.set_float32_matmul_precision("medium")
+    pl.seed_everything(cfg.seed)
+    logger = WandbLogger(
+        entity=cfg.wandb.entity,
+        project=cfg.wandb.project,
+        group=cfg.wandb.group,
+        save_dir=cfg.wandb.save_dir,
+        job_type=cfg.job_type,
+        save_code=True,
+        log_model=True,
+        config=omegaconf.OmegaConf.to_container(
+            cfg, resolve=True, throw_on_missing=True
+        ),
+    )
+    # logger.log_hyperparams(cfg)
+    # logger.log_hyperparams({"working_dir": os.getcwd()})
+    trainer = pl.Trainer(
+        logger=logger,
+        accelerator="gpu",
+        devices=[0],
+        log_every_n_steps=cfg.training.log_every_n_steps,
+        check_val_every_n_epoch=cfg.training.check_val_every_n_epoch,
+        # reload_dataloaders_every_n_epochs=1,
+        # callbacks=[SaverCallbackModel(), SaverCallbackEmbnnActionAnchor()],
+        callbacks=[
+            # This checkpoint callback saves the latest model during training, i.e. so we can resume if it crashes.
+            # It saves everything, and you can load by referencing last.ckpt.
+            ModelCheckpoint(
+                dirpath=cfg.lightning.checkpoint_dir,
+                filename="{epoch}-{step}",
+                monitor="step",
+                mode="max",
+                save_weights_only=False,
+                save_last=True,
+                every_n_epochs=1,
+            ),
+            # This checkpoint will get saved to WandB. The Callback mechanism in lightning is poorly designed, so we have to put it last.
+            ModelCheckpoint(
+                dirpath=cfg.lightning.checkpoint_dir,
+                filename="{epoch}-{step}-{train_loss:.2f}-weights-only",
+                monitor="val_loss",
+                mode="min",
+                save_weights_only=True,
+            ),
+        ],
+        max_epochs=cfg.training.max_epochs,
+        max_steps=cfg.training.max_steps,
+        gradient_clip_val=cfg.training.gradient_clipping,
+    )
+
+    dm = MultiviewDataModule(
+        batch_size=cfg.training.batch_size,
+        num_workers=cfg.resources.num_workers,
+        cfg=cfg.dm,
+    )
+    dm.setup()
+
+    network = get_training_network(cfg)
+
+    model = get_training_model(network, cfg, logger)
+
+    maybe_load_training_model_weights(model, cfg, logger)
+
     trainer.fit(model, dm)
diff --git a/taxpose/nets/multimodal_transformer_flow.py b/taxpose/nets/multimodal_transformer_flow.py
index 8706e99..def5570 100644
--- a/taxpose/nets/multimodal_transformer_flow.py
+++ b/taxpose/nets/multimodal_transformer_flow.py
@@ -6,57 +6,70 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from third_party.dcp.model import DGCNN, DGCNNClassification
+from taxpose.nets.taxposed_dgcnn import DGCNN, DGCNNClassification
+
 
 class Multimodal_ResidualFlow_DiffEmbTransformer(nn.Module):
     EMB_DIMS_BY_CONDITIONING = {
-        'pos_delta_l2norm': 1,
+        "pos_delta_l2norm": 1,
         "uniform_prior_pos_delta_l2norm": 1,
         # 'latent_z': 1, # Make the dimensions as close as possible to the ablations we're comparing this against
        # 'latent_z_1pred': 1, # Same
        # 'latent_z_1pred_10d': 10, # Same
-        'latent_z_linear': 512,
-        'latent_z_linear_internalcond': 512,
-        'pos_delta_vec': 1,
-        'pos_onehot': 1,
-        'pos_loc3d': 3,
+        "latent_z_linear": 512,
+        "latent_z_linear_internalcond": 512,
+        "pos_delta_vec": 1,
+        "pos_onehot": 1,
+        "pos_loc3d": 3,
     }
 
     # Number of heads that the DGCNN should output
     NUM_HEADS_BY_CONDITIONING = {
-        'pos_delta_l2norm': 1,
+        "pos_delta_l2norm": 1,
         "uniform_prior_pos_delta_l2norm": 1,
        # 'latent_z': 2, # One for mu and one for var
        # 'latent_z_1pred': 2, # Same
        # 'latent_z_1pred_10d': 2, # Same
-        'latent_z_linear': 2,
-        'latent_z_linear_internalcond': 2,
-        'pos_delta_vec': 1,
-        'pos_onehot': 1,
-        'pos_loc3d': 1,
+        "latent_z_linear": 2,
+        "latent_z_linear_internalcond": 2,
+        "pos_delta_vec": 1,
+        "pos_onehot": 1,
+        "pos_loc3d": 1,
     }
 
     DEPRECATED_CONDITIONINGS = ["latent_z", "latent_z_1pred", "latent_z_1pred_10d"]
 
     TP_INPUT_DIMS = {
-        'pos_delta_l2norm': 3 + 1,
-        'uniform_prior_pos_delta_l2norm': 3 + 1,
+        "pos_delta_l2norm": 3 + 1,
+        "uniform_prior_pos_delta_l2norm": 3 + 1,
         # Not implemented because it's dynamic. Also this isn't used anymore
         # 'latent_z_linear': 3 + cfg.latent_z_linear_size,
-        'latent_z_linear_internalcond': 3,
-        'pos_delta_vec': 3 + 3,
-        'pos_onehot': 3 + 1,
-        'pos_loc3d': 3 + 3,
+        "latent_z_linear_internalcond": 3,
+        "pos_delta_vec": 3 + 3,
+        "pos_onehot": 3 + 1,
+        "pos_loc3d": 3 + 3,
         "latent_3d_z": 3 + 3,
     }
-
-    def __init__(self, residualflow_diffembtransformer, gumbel_temp=0.5, freeze_residual_flow=False, center_feature=False, freeze_z_embnn=False,
-                 division_smooth_factor=1, add_smooth_factor=0.05, conditioning="pos_delta_l2norm", latent_z_linear_size=40,
-                 taxpose_centering="mean"):
+
+    def __init__(
+        self,
+        residualflow_diffembtransformer,
+        gumbel_temp=0.5,
+        freeze_residual_flow=False,
+        center_feature=False,
+        freeze_z_embnn=False,
+        division_smooth_factor=1,
+        add_smooth_factor=0.05,
+        conditioning="pos_delta_l2norm",
+        latent_z_linear_size=40,
+        taxpose_centering="mean",
+    ):
         super(Multimodal_ResidualFlow_DiffEmbTransformer, self).__init__()
 
         assert taxpose_centering in ["mean", "z"]
-        assert conditioning not in self.DEPRECATED_CONDITIONINGS, f"This conditioning {conditioning} is deprecated and should not be used"
+        assert (
+            conditioning not in self.DEPRECATED_CONDITIONINGS
+        ), f"This conditioning {conditioning} is deprecated and should not be used"
         assert conditioning in self.EMB_DIMS_BY_CONDITIONING.keys()
 
         self.latent_z_linear_size = latent_z_linear_size
@@ -71,9 +84,16 @@ def __init__(self, residualflow_diffembtransformer, gumbel_temp=0.5, freeze_resi
         self.num_emb_heads = self.NUM_HEADS_BY_CONDITIONING[self.conditioning]
         # Point cloud with class labels between action and anchor
         if self.conditioning not in ["latent_z_linear", "latent_z_linear_internalcond"]:
-            self.emb_nn_objs_at_goal = DGCNN(emb_dims=self.emb_dims, num_heads=self.num_emb_heads, last_relu=False)
+            self.emb_nn_objs_at_goal = DGCNN(
+                emb_dims=self.emb_dims, num_heads=self.num_emb_heads, last_relu=False
+            )
         else:
-            self.emb_nn_objs_at_goal = DGCNNClassification(emb_dims=self.emb_dims, num_heads=self.num_emb_heads, dropout=0.5, output_channels=self.latent_z_linear_size)
+            self.emb_nn_objs_at_goal = DGCNNClassification(
+                emb_dims=self.emb_dims,
+                num_heads=self.num_emb_heads,
+                dropout=0.5,
+                output_channels=self.latent_z_linear_size,
+            )
         # TODO
         self.freeze_residual_flow = freeze_residual_flow
         self.center_feature = center_feature
@@ -83,64 +103,101 @@ def __init__(self, residualflow_diffembtransformer, gumbel_temp=0.5, freeze_resi
         self.division_smooth_factor = division_smooth_factor
         self.add_smooth_factor = add_smooth_factor
-
+
     def get_dense_translation_point(self, points, ref, conditioning):
         """
-            points- point cloud. (B, 3, num_points)
-            ref- one hot vector (or nearly one-hot) that denotes the reference point
-                (B, num_points)
+        points- point cloud. (B, 3, num_points)
+        ref- one hot vector (or nearly one-hot) that denotes the reference point
+            (B, num_points)
 
-            Returns:
-                dense point cloud. Each point contains the distance to the reference point (B, 3 or 1, num_points)
+        Returns:
+            dense point cloud. Each point contains the distance to the reference point (B, 3 or 1, num_points)
         """
         assert ref.ndim == 2
-        assert torch.allclose(ref.sum(axis=1), torch.full((ref.shape[0], 1), 1, dtype=torch.float, device=ref.device))
+        assert torch.allclose(
+            ref.sum(axis=1),
+            torch.full((ref.shape[0], 1), 1, dtype=torch.float, device=ref.device),
+        )
         num_points = points.shape[2]
 
-        reference = (points*ref[:,None,:]).sum(axis=2)
+        reference = (points * ref[:, None, :]).sum(axis=2)
         if conditioning in ["pos_delta_l2norm", "uniform_prior_pos_delta_l2norm"]:
             dense = torch.norm(reference[:, :, None] - points, dim=1, keepdim=True)
         elif conditioning == "pos_delta_vec":
             dense = reference[:, :, None] - points
         elif conditioning == "pos_loc3d":
-            dense = reference[:,:,None].repeat(1, 1, 1024)
+            dense = reference[:, :, None].repeat(1, 1, 1024)
         elif conditioning == "pos_onehot":
             dense = ref[:, None, :]
         else:
-            raise ValueError(f"Conditioning {conditioning} probably doesn't require a dense representation. This function is for" \
-                + "['pos_delta_l2norm', 'pos_delta_vec', 'pos_loc3d', 'pos_onehot', 'uniform_prior_pos_delta_l2norm']")
+            raise ValueError(
+                f"Conditioning {conditioning} probably doesn't require a dense representation. This function is for"
+                + "['pos_delta_l2norm', 'pos_delta_vec', 'pos_loc3d', 'pos_onehot', 'uniform_prior_pos_delta_l2norm']"
+            )
         return dense, reference
 
-
     def add_conditioning(self, goal_emb, action_points, anchor_points, conditioning):
         for_debug = {}
 
-        if conditioning in ['pos_delta_l2norm', 'pos_delta_vec', 'pos_loc3d', 'pos_onehot', 'uniform_prior_pos_delta_l2norm']:
-
+        if conditioning in [
+            "pos_delta_l2norm",
+            "pos_delta_vec",
+            "pos_loc3d",
+            "pos_onehot",
+            "uniform_prior_pos_delta_l2norm",
+        ]:
             goal_emb = (goal_emb + self.add_smooth_factor) / self.division_smooth_factor
 
             # Only handle the translation case for now
-            goal_emb_translation = goal_emb[:,0,:]
+            goal_emb_translation = goal_emb[:, 0, :]
+
+            goal_emb_translation_action = goal_emb_translation[
+                :, : action_points.shape[2]
+            ]
+            goal_emb_translation_anchor = goal_emb_translation[
+                :, action_points.shape[2] :
+            ]
+
+            translation_sample_action = F.gumbel_softmax(
+                goal_emb_translation_action, self.gumbel_temp, hard=True, dim=-1
+            )
+            translation_sample_anchor = F.gumbel_softmax(
+                goal_emb_translation_anchor, self.gumbel_temp, hard=True, dim=-1
+            )
 
-            goal_emb_translation_action = goal_emb_translation[:, :action_points.shape[2]]
-            goal_emb_translation_anchor = goal_emb_translation[:, action_points.shape[2]:]
-
-            translation_sample_action = F.gumbel_softmax(goal_emb_translation_action, self.gumbel_temp, hard=True, dim=-1)
-            translation_sample_anchor = F.gumbel_softmax(goal_emb_translation_anchor, self.gumbel_temp, hard=True, dim=-1)
-
             # This is the only line that's different among the 3 different conditioning schemes in this category
-            dense_trans_pt_action, ref_action = Multimodal_ResidualFlow_DiffEmbTransformer.get_dense_translation_point(None, action_points, translation_sample_action, conditioning=self.conditioning)
-            dense_trans_pt_anchor, ref_anchor = Multimodal_ResidualFlow_DiffEmbTransformer.get_dense_translation_point(None, anchor_points, translation_sample_anchor, conditioning=self.conditioning)
-
-            action_points_and_cond = torch.cat([action_points] + [dense_trans_pt_action], axis=1)
-            anchor_points_and_cond = torch.cat([anchor_points] + [dense_trans_pt_anchor], axis=1)
+            (
+                dense_trans_pt_action,
+                ref_action,
+            ) = Multimodal_ResidualFlow_DiffEmbTransformer.get_dense_translation_point(
+                None,
+                action_points,
+                translation_sample_action,
+                conditioning=self.conditioning,
+            )
+            (
+                dense_trans_pt_anchor,
+                ref_anchor,
+            ) = Multimodal_ResidualFlow_DiffEmbTransformer.get_dense_translation_point(
+                None,
+                anchor_points,
+                translation_sample_anchor,
+                conditioning=self.conditioning,
+            )
+
+            action_points_and_cond = torch.cat(
+                [action_points] + [dense_trans_pt_action], axis=1
+            )
+            anchor_points_and_cond = torch.cat(
+                [anchor_points] + [dense_trans_pt_anchor], axis=1
+            )
 
             for_debug = {
-                'dense_trans_pt_action': dense_trans_pt_action,
-                'dense_trans_pt_anchor': dense_trans_pt_anchor,
-                'trans_pt_action': ref_action,
-                'trans_pt_anchor': ref_anchor,
-                'trans_sample_action': translation_sample_action,
-                'trans_sample_anchor': translation_sample_anchor,
+                "dense_trans_pt_action": dense_trans_pt_action,
+                "dense_trans_pt_anchor": dense_trans_pt_anchor,
+                "trans_pt_action": ref_action,
+                "trans_pt_anchor": ref_anchor,
+                "trans_sample_action": translation_sample_action,
+                "trans_sample_anchor": translation_sample_anchor,
             }
         elif conditioning in ["latent_z_linear", "latent_z_linear_internalcond"]:
             # Do the reparametrization trick on the predicted mu and var
@@ -153,26 +210,36 @@ def reparametrize(mu, logvar):
                 std = torch.exp(0.5 * logvar)
                 eps = torch.randn_like(std)
                 return eps * std + mu
-
+
             goal_emb = reparametrize(goal_emb_mu, goal_emb_logvar)
 
             for_debug = {
-                'goal_emb_mu': goal_emb_mu,
-                'goal_emb_logvar': goal_emb_logvar,
+                "goal_emb_mu": goal_emb_mu,
+                "goal_emb_logvar": goal_emb_logvar,
             }
 
             if conditioning == "latent_z_linear":
-                action_points_and_cond = torch.cat([action_points] + [torch.tile(goal_emb, (1, 1, action_points.shape[-1]))], axis=1)
-                anchor_points_and_cond = torch.cat([anchor_points] + [torch.tile(goal_emb, (1, 1, anchor_points.shape[-1]))], axis=1)
+                action_points_and_cond = torch.cat(
+                    [action_points]
+                    + [torch.tile(goal_emb, (1, 1, action_points.shape[-1]))],
+                    axis=1,
+                )
+                anchor_points_and_cond = torch.cat(
+                    [anchor_points]
+                    + [torch.tile(goal_emb, (1, 1, anchor_points.shape[-1]))],
+                    axis=1,
+                )
             elif conditioning == "latent_z_linear_internalcond":
                 # The cond will be added in by TAXPose
                 action_points_and_cond = action_points
                 anchor_points_and_cond = anchor_points
-                for_debug['goal_emb'] = goal_emb
+                for_debug["goal_emb"] = goal_emb
             else:
                 raise ValueError("Why is it here?")
         else:
-            raise ValueError(f"Conditioning {conditioning} does not exist. Choose one of: {list(self.EMB_DIMS_BY_CONDITIONING.keys())}")
+            raise ValueError(
+                f"Conditioning {conditioning} does not exist. Choose one of: {list(self.EMB_DIMS_BY_CONDITIONING.keys())}"
+            )
 
         return action_points_and_cond, anchor_points_and_cond, for_debug
 
@@ -180,9 +247,9 @@
     def forward(self, *input, mode="forward"):
         # Forward pass goes through all of the model
         # Inference will use a sample from the prior if there is one
         # - ex: conditioning = latent_z_linear_internalcond
-        assert mode in ['forward', 'inference']
+        assert mode in ["forward", "inference"]
 
-        action_points = input[0].permute(0, 2, 1)[:, :3] # B,3,num_points
+        action_points = input[0].permute(0, 2, 1)[:, :3]  # B,3,num_points
         anchor_points = input[1].permute(0, 2, 1)[:, :3]
 
         if input[2] is None:
@@ -194,69 +261,97 @@
             # mean center point cloud before DGCNN
             if self.center_feature:
-                mean_goal = torch.cat([goal_action_points, goal_anchor_points], axis=-1).mean(dim=2, keepdim=True)
-                goal_action_points_dmean = goal_action_points - \
-                    mean_goal
-                goal_anchor_points_dmean = goal_anchor_points - \
-                    mean_goal
-                action_points_dmean = action_points - \
-                    action_points.mean(dim=2, keepdim=True)
-                anchor_points_dmean = anchor_points - \
-                    anchor_points.mean(dim=2, keepdim=True)
+                mean_goal = torch.cat(
+                    [goal_action_points, goal_anchor_points], axis=-1
+                ).mean(dim=2, keepdim=True)
+                goal_action_points_dmean = goal_action_points - mean_goal
+                goal_anchor_points_dmean = goal_anchor_points - mean_goal
+                action_points_dmean = action_points - action_points.mean(
+                    dim=2, keepdim=True
+                )
+                anchor_points_dmean = anchor_points - anchor_points.mean(
+                    dim=2, keepdim=True
+                )
             else:
                 goal_action_points_dmean = goal_action_points
                 goal_anchor_points_dmean = goal_anchor_points
                 action_points_dmean = action_points
                 anchor_points_dmean = anchor_points
 
-            goal_points_dmean = torch.cat([goal_action_points_dmean, goal_anchor_points_dmean], axis=2)
+            goal_points_dmean = torch.cat(
+                [goal_action_points_dmean, goal_anchor_points_dmean], axis=2
+            )
 
             if self.freeze_z_embnn:
                 with torch.no_grad():
                     if self.num_emb_heads > 1:
-                        goal_emb = [a.detach() for a in self.emb_nn_objs_at_goal(goal_points_dmean)]
+                        goal_emb = [
+                            a.detach()
+                            for a in self.emb_nn_objs_at_goal(goal_points_dmean)
+                        ]
                     else:
                         goal_emb = self.emb_nn_objs_at_goal(goal_points_dmean).detach()
            else:
                 goal_emb = self.emb_nn_objs_at_goal(goal_points_dmean)
-
-            action_points_and_cond, anchor_points_and_cond, for_debug = self.add_conditioning(goal_emb, action_points, anchor_points, self.conditioning)
+
+            (
+                action_points_and_cond,
+                anchor_points_and_cond,
+                for_debug,
+            ) = self.add_conditioning(
+                goal_emb, action_points, anchor_points, self.conditioning
+            )
         elif mode == "inference":
-            action_points_and_cond, anchor_points_and_cond, goal_emb, for_debug = self.sample(action_points, anchor_points)
+            (
+                action_points_and_cond,
+                anchor_points_and_cond,
+                goal_emb,
+                for_debug,
+            ) = self.sample(action_points, anchor_points)
         else:
             raise ValueError(f"Unknown mode {mode}")
 
         tax_pose_conditioning_action = None
         tax_pose_conditioning_anchor = None
         if self.conditioning == "latent_z_linear_internalcond":
-            tax_pose_conditioning_action = torch.tile(for_debug['goal_emb'], (1, 1, action_points.shape[-1]))
-            tax_pose_conditioning_anchor = torch.tile(for_debug['goal_emb'], (1, 1, anchor_points.shape[-1]))
+            tax_pose_conditioning_action = torch.tile(
+                for_debug["goal_emb"], (1, 1, action_points.shape[-1])
+            )
+            tax_pose_conditioning_anchor = torch.tile(
+                for_debug["goal_emb"], (1, 1, anchor_points.shape[-1])
+            )
 
         if self.taxpose_centering == "mean":
             # Use TAX-Pose defaults
             action_center = action_points[:, :3].mean(dim=2, keepdim=True)
             anchor_center = anchor_points[:, :3].mean(dim=2, keepdim=True)
         elif self.taxpose_centering == "z":
-            action_center = for_debug['trans_pt_action'][:,:,None]
-            anchor_center = for_debug['trans_pt_anchor'][:,:,None]
+            action_center = for_debug["trans_pt_action"][:, :, None]
+            anchor_center = for_debug["trans_pt_anchor"][:, :, None]
         else:
-            raise ValueError(f"Unknown self.taxpose_centering: {self.taxpose_centering}")
+            raise ValueError(
+                f"Unknown self.taxpose_centering: {self.taxpose_centering}"
+            )
 
         if self.freeze_residual_flow:
             with torch.no_grad():
-                flow_action = self.tax_pose(action_points_and_cond.permute(0, 2, 1), anchor_points_and_cond.permute(0, 2, 1),
-                                            conditioning_action=tax_pose_conditioning_action,
-                                            conditioning_anchor=tax_pose_conditioning_anchor,
-                                            action_center=action_center,
-                                            anchor_center=anchor_center)
+                flow_action = self.tax_pose(
+                    action_points_and_cond.permute(0, 2, 1),
+                    anchor_points_and_cond.permute(0, 2, 1),
+                    conditioning_action=tax_pose_conditioning_action,
+                    conditioning_anchor=tax_pose_conditioning_anchor,
+                    action_center=action_center,
+                    anchor_center=anchor_center,
+                )
         else:
-            flow_action = self.tax_pose(action_points_and_cond.permute(0, 2, 1), anchor_points_and_cond.permute(0, 2, 1),
-                                        conditioning_action=tax_pose_conditioning_action,
-                                        conditioning_anchor=tax_pose_conditioning_anchor,
-                                        action_center=action_center,
-                                        anchor_center=anchor_center)
-
+            flow_action = self.tax_pose(
+                action_points_and_cond.permute(0, 2, 1),
+                anchor_points_and_cond.permute(0, 2, 1),
+                conditioning_action=tax_pose_conditioning_action,
+                conditioning_anchor=tax_pose_conditioning_anchor,
+                action_center=action_center,
+                anchor_center=anchor_center,
+            )
 
         ########## LOGGING ############
 
@@ -265,16 +360,16 @@ def forward(self, *input, mode="forward"):
             goal_emb = goal_emb[0]
 
         if self.freeze_residual_flow:
-            flow_action['flow_action'] = flow_action['flow_action'].detach()
-            flow_action['flow_anchor'] = flow_action['flow_anchor'].detach()
-
+            flow_action["flow_action"] = flow_action["flow_action"].detach()
+            flow_action["flow_anchor"] = flow_action["flow_anchor"].detach()
+
         flow_action = {
-            **flow_action,
-            'goal_emb': goal_emb,
+            **flow_action,
+            "goal_emb": goal_emb,
             **for_debug,
         }
         return flow_action
-
+
     def sample(self, action_points, anchor_points):
         if self.conditioning in ["latent_z_linear", "latent_z_linear_internalcond"]:
             # Take a SINGLE sample z ~ N(0,1)
@@ -282,45 +377,89 @@
             goal_emb_action = None
             goal_emb_anchor = None
             if self.conditioning == "latent_z_linear":
-                goal_emb = torch.tile(torch.randn((action_points.shape[0], self.emb_dims, 1)).to(action_points.device), (1, 1, action_points.shape[-1]))
+                goal_emb = torch.tile(
+                    torch.randn((action_points.shape[0], self.emb_dims, 1)).to(
+                        action_points.device
+                    ),
+                    (1, 1, action_points.shape[-1]),
+                )
                 action_points_and_cond = torch.cat([action_points, goal_emb], axis=1)
                 anchor_points_and_cond = torch.cat([anchor_points, goal_emb], axis=1)
             elif self.conditioning == "latent_z_linear_internalcond":
-                goal_emb = torch.randn((action_points.shape[0], self.latent_z_linear_size, 1)).to(action_points.device)
+                goal_emb = torch.randn(
+                    (action_points.shape[0], self.latent_z_linear_size, 1)
+                ).to(action_points.device)
                 action_points_and_cond = action_points
                 anchor_points_and_cond = anchor_points
-                for_debug['goal_emb'] = goal_emb
+                for_debug["goal_emb"] = goal_emb
             else:
                 raise ValueError("Why is it here?")
-        elif self.conditioning in ['uniform_prior_pos_delta_l2norm']:
+        elif self.conditioning in ["uniform_prior_pos_delta_l2norm"]:
             # sample from a uniform prior
-            N_action, N_anchor, B = action_points.shape[-1], anchor_points.shape[-1], action_points.shape[0]
-            translation_sample_action = F.one_hot(torch.randint(N_action, (B,)), N_action).float().cuda()
-            translation_sample_anchor = F.one_hot(torch.randint(N_anchor, (B,)), N_anchor).float().cuda()
-
-            dense_trans_pt_action, ref_action = Multimodal_ResidualFlow_DiffEmbTransformer.get_dense_translation_point(None, action_points, translation_sample_action, conditioning=self.conditioning)
-            dense_trans_pt_anchor, ref_anchor = Multimodal_ResidualFlow_DiffEmbTransformer.get_dense_translation_point(None, anchor_points, translation_sample_anchor, conditioning=self.conditioning)
-
-            action_points_and_cond = torch.cat([action_points] + [dense_trans_pt_action], axis=1)
-            anchor_points_and_cond = torch.cat([anchor_points] + [dense_trans_pt_anchor], axis=1)
+            N_action, N_anchor, B = (
+                action_points.shape[-1],
+                anchor_points.shape[-1],
+                action_points.shape[0],
+            )
+            translation_sample_action = (
+                F.one_hot(torch.randint(N_action, (B,)), N_action).float().cuda()
+            )
+            translation_sample_anchor = (
+                F.one_hot(torch.randint(N_anchor, (B,)), N_anchor).float().cuda()
+            )
+
+            (
+                dense_trans_pt_action,
+                ref_action,
+            ) = Multimodal_ResidualFlow_DiffEmbTransformer.get_dense_translation_point(
+                None,
+                action_points,
+                translation_sample_action,
+                conditioning=self.conditioning,
+            )
+            (
+                dense_trans_pt_anchor,
+                ref_anchor,
+            ) = Multimodal_ResidualFlow_DiffEmbTransformer.get_dense_translation_point(
+                None,
+                anchor_points,
+                translation_sample_anchor,
+                conditioning=self.conditioning,
+            )
+
+            action_points_and_cond = torch.cat(
+                [action_points] + [dense_trans_pt_action], axis=1
+            )
+            anchor_points_and_cond = torch.cat(
+                [anchor_points] + [dense_trans_pt_anchor], axis=1
+            )
 
             goal_emb = None
 
             for_debug = {
-                'dense_trans_pt_action': dense_trans_pt_action,
-                'dense_trans_pt_anchor': dense_trans_pt_anchor,
-                'trans_pt_action': ref_action,
-                'trans_pt_anchor': ref_anchor,
-                'trans_sample_action': translation_sample_action,
-                'trans_sample_anchor': translation_sample_anchor,
+                "dense_trans_pt_action": dense_trans_pt_action,
+                "dense_trans_pt_anchor": dense_trans_pt_anchor,
+                "trans_pt_action": ref_action,
+                "trans_pt_anchor": ref_anchor,
+                "trans_sample_action": translation_sample_action,
+                "trans_sample_anchor": translation_sample_anchor,
             }
         else:
-            raise ValueError(f"Sampling not supported for conditioning {self.conditioning}. Pick one of the latent_z_xxx conditionings")
+            raise ValueError(
+                f"Sampling not supported for conditioning {self.conditioning}. Pick one of the latent_z_xxx conditionings"
+            )
 
         return action_points_and_cond, anchor_points_and_cond, goal_emb, for_debug
 
 
 class Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX(nn.Module):
-    def __init__(self, residualflow_embnn, encoder_type="2_dgcnn", sample_z=True, shuffle_for_pzX=False, return_debug=False):
+    def __init__(
+        self,
+        residualflow_embnn,
+        encoder_type="2_dgcnn",
+        sample_z=True,
+        shuffle_for_pzX=False,
+        return_debug=False,
+    ):
         super(Multimodal_ResidualFlow_DiffEmbTransformer_WithPZCondX, self).__init__()
 
         self.residflow_embnn = residualflow_embnn
@@ -336,13 +475,13 @@ def __init__(self, residualflow_embnn, encoder_type="2_dgcnn", sample_z=True, sh
         self.shuffle_for_pzX = shuffle_for_pzX
         self.return_debug = return_debug
 
-        #assert self.conditioning not in ['uniform_prior_pos_delta_l2norm']
+        # assert self.conditioning not in ['uniform_prior_pos_delta_l2norm']
 
         # assert self.conditioning not in ["latent_z_linear", "latent_z", "latent_z_1pred", "latent_z_1pred_10d", "latent_z_linear_internalcond"], "Latent z conditioning does not need a p(z|X) because it's regularized to N(0,1)"
 
         # Note: 1 DGCNN probably loses some of the rotational invariance between objects
         assert encoder_type in ["1_dgcnn", "2_dgcnn"]
-
+
         # disable smoothing
         self.add_smooth_factor = 0.05
         self.division_smooth_factor = 1.0
@@ -353,33 +492,59 @@ def __init__(self, residualflow_embnn, encoder_type="2_dgcnn", sample_z=True, sh
         if self.conditioning not in ["latent_z_linear", "latent_z_linear_internalcond"]:
             if self.encoder_type == "1_dgcnn":
-                self.p_z_cond_x_embnn = DGCNN(emb_dims=self.emb_dims, num_heads=self.num_emb_heads, last_relu=False)
+                self.p_z_cond_x_embnn = DGCNN(
+                    emb_dims=self.emb_dims,
+                    num_heads=self.num_emb_heads,
+                    last_relu=False,
+                )
             elif self.encoder_type == "2_dgcnn":
-                self.p_z_cond_x_embnn_action = DGCNN(emb_dims=self.emb_dims, num_heads=self.num_emb_heads, last_relu=False)
-                self.p_z_cond_x_embnn_anchor = DGCNN(emb_dims=self.emb_dims, num_heads=self.num_emb_heads, last_relu=False)
+                self.p_z_cond_x_embnn_action = DGCNN(
+                    emb_dims=self.emb_dims,
+                    num_heads=self.num_emb_heads,
+                    last_relu=False,
+                )
+                self.p_z_cond_x_embnn_anchor = DGCNN(
+                    emb_dims=self.emb_dims,
+                    num_heads=self.num_emb_heads,
+                    last_relu=False,
+                )
             else:
                 raise ValueError()
         else:
             if self.encoder_type == "1_dgcnn":
-                self.p_z_cond_x_embnn = DGCNNClassification(emb_dims=self.emb_dims, num_heads=self.num_emb_heads)
+                self.p_z_cond_x_embnn = DGCNNClassification(
+                    emb_dims=self.emb_dims, num_heads=self.num_emb_heads
+                )
             elif self.encoder_type == "2_dgcnn":
-                self.p_z_cond_x_embnn_action = DGCNNClassification(emb_dims=self.emb_dims, num_heads=self.num_emb_heads, dropout=0.5, output_channels=self.residflow_embnn.latent_z_linear_size)
-                self.p_z_cond_x_embnn_anchor = DGCNNClassification(emb_dims=self.emb_dims, num_heads=self.num_emb_heads, dropout=0.5, output_channels=self.residflow_embnn.latent_z_linear_size)
+                self.p_z_cond_x_embnn_action = DGCNNClassification(
+                    emb_dims=self.emb_dims,
+                    num_heads=self.num_emb_heads,
+                    dropout=0.5,
+                    output_channels=self.residflow_embnn.latent_z_linear_size,
+                )
+                self.p_z_cond_x_embnn_anchor = DGCNNClassification(
+                    emb_dims=self.emb_dims,
+                    num_heads=self.num_emb_heads,
+                    dropout=0.5,
+                    output_channels=self.residflow_embnn.latent_z_linear_size,
+                )
             else:
                 raise ValueError()
 
         self.center_feature = self.residflow_embnn.center_feature
 
     def forward(self, *input):
-        action_points = input[0].permute(0, 2, 1)[:, :3] # B,3,num_points
+        action_points = input[0].permute(0, 2, 1)[:, :3]  # B,3,num_points
         anchor_points = input[1].permute(0, 2, 1)[:, :3]
 
         # mean center point cloud before DGCNN
         if self.residflow_embnn.center_feature:
-            action_points_dmean = action_points - \
-                action_points.mean(dim=2, keepdim=True)
-            anchor_points_dmean = anchor_points - \
-                anchor_points.mean(dim=2, keepdim=True)
+            action_points_dmean = action_points - action_points.mean(
+                dim=2, keepdim=True
+            )
+            anchor_points_dmean = anchor_points - anchor_points.mean(
+                dim=2, keepdim=True
+            )
         else:
             action_points_dmean = action_points
             anchor_points_dmean = anchor_points
@@ -387,20 +552,26 @@
         if self.shuffle_for_pzX:
             action_shuffle_idxs = torch.randperm(action_points_dmean.size()[2])
             anchor_shuffle_idxs = torch.randperm(anchor_points_dmean.size()[2])
-            action_points_dmean = action_points_dmean[:,:,action_shuffle_idxs]
-            anchor_points_dmean = anchor_points_dmean[:,:,anchor_shuffle_idxs]
+            action_points_dmean = action_points_dmean[:, :, action_shuffle_idxs]
+            anchor_points_dmean = anchor_points_dmean[:, :, anchor_shuffle_idxs]
 
         def prepare(arr, is_action):
             if self.shuffle_for_pzX:
                 shuffle_idxs = action_shuffle_idxs if is_action else anchor_shuffle_idxs
-                return arr[:,:,torch.argsort(shuffle_idxs)]
+                return arr[:, :, torch.argsort(shuffle_idxs)]
            else:
                 return arr
 
         if self.encoder_type == "1_dgcnn":
-            goal_emb_cond_x = self.p_z_cond_x_embnn(torch.cat([action_points_dmean, anchor_points_dmean], dim=-1))
-            goal_emb_cond_x_action = prepare(goal_emb_cond_x[:, :, :action_points_dmean.shape[-1]])
-            goal_emb_cond_x_anchor = prepare(goal_emb_cond_x[:, :, action_points_dmean.shape[-1]:])
+            goal_emb_cond_x = self.p_z_cond_x_embnn(
+                torch.cat([action_points_dmean, anchor_points_dmean], dim=-1)
+            )
+            goal_emb_cond_x_action = prepare(
+                goal_emb_cond_x[:, :, : action_points_dmean.shape[-1]]
+            )
+            goal_emb_cond_x_anchor = prepare(
+                goal_emb_cond_x[:, :, action_points_dmean.shape[-1] :]
+            )
         elif self.encoder_type == "2_dgcnn":
             # Sample a point
             goal_emb_cond_x_action = self.p_z_cond_x_embnn_action(action_points_dmean)
@@ -408,52 +579,80 @@ def prepare(arr, is_action):
 
             if self.num_emb_heads > 1:
                 goal_emb_cond_x = [
-                    torch.cat([prepare(action_head, True), prepare(anchor_head, False)], dim=-1)
-                    for action_head, anchor_head in zip(goal_emb_cond_x_action, goal_emb_cond_x_anchor)
+                    torch.cat(
+                        [prepare(action_head, True), prepare(anchor_head, False)],
+                        dim=-1,
+                    )
+                    for action_head, anchor_head in zip(
+                        goal_emb_cond_x_action, goal_emb_cond_x_anchor
+                    )
                 ]
             else:
-                goal_emb_cond_x = torch.cat([prepare(goal_emb_cond_x_action, True), prepare(goal_emb_cond_x_anchor, False)], dim=-1)
+                goal_emb_cond_x = torch.cat(
+                    [
+                        prepare(goal_emb_cond_x_action, True),
+                        prepare(goal_emb_cond_x_anchor, False),
+                    ],
+                    dim=-1,
+                )
         else:
             raise ValueError()
-
-        action_points_and_cond, anchor_points_and_cond, for_debug = Multimodal_ResidualFlow_DiffEmbTransformer.add_conditioning(self, goal_emb_cond_x, action_points, anchor_points, self.conditioning)
-
+
+        (
+            action_points_and_cond,
+            anchor_points_and_cond,
+            for_debug,
+        ) = Multimodal_ResidualFlow_DiffEmbTransformer.add_conditioning(
+            self, goal_emb_cond_x, action_points, anchor_points, self.conditioning
+        )
+
         tax_pose_conditioning_action = None
         tax_pose_conditioning_anchor = None
         if self.conditioning == "latent_z_linear_internalcond":
-            tax_pose_conditioning_action = torch.tile(for_debug['goal_emb'][:,:,0][:,:,None], (1, 1, action_points.shape[-1]))
-            tax_pose_conditioning_anchor = torch.tile(for_debug['goal_emb'][:,:,1][:,:,None], (1, 1, anchor_points.shape[-1]))
+            tax_pose_conditioning_action = torch.tile(
+                for_debug["goal_emb"][:, :, 0][:, :, None],
+                (1, 1, action_points.shape[-1]),
+            )
+            tax_pose_conditioning_anchor = torch.tile(
+                for_debug["goal_emb"][:, :, 1][:, :, None],
+                (1, 1, anchor_points.shape[-1]),
+            )
 
         if self.taxpose_centering == "mean":
             # Use TAX-Pose defaults
             action_center = action_points[:, :3].mean(dim=2, keepdim=True)
             anchor_center = anchor_points[:, :3].mean(dim=2, keepdim=True)
         elif self.taxpose_centering == "z":
-            action_center = for_debug['trans_pt_action'][:,:,None]
-            anchor_center = for_debug['trans_pt_anchor'][:,:,None]
+            action_center = for_debug["trans_pt_action"][:, :, None]
+            anchor_center = for_debug["trans_pt_anchor"][:, :, None]
         else:
-            raise ValueError(f"Unknown self.taxpose_centering: {self.taxpose_centering}")
+            raise ValueError(
+                f"Unknown self.taxpose_centering: {self.taxpose_centering}"
+            )
 
         # Unpermute the action and anchor point clouds to match how tax pose is written
-        flow_action = self.residflow_embnn.tax_pose(action_points_and_cond.permute(0, 2, 1), anchor_points_and_cond.permute(0, 2, 1),
-                                                    conditioning_action=tax_pose_conditioning_action,
-                                                    conditioning_anchor=tax_pose_conditioning_anchor,
-                                                    action_center=action_center,
-                                                    anchor_center=anchor_center)
+        flow_action = self.residflow_embnn.tax_pose(
+            action_points_and_cond.permute(0, 2, 1),
+            anchor_points_and_cond.permute(0, 2, 1),
+            conditioning_action=tax_pose_conditioning_action,
+            conditioning_anchor=tax_pose_conditioning_anchor,
+            action_center=action_center,
+            anchor_center=anchor_center,
+        )
 
         # If the demo is available, run p(z|Y)
         if input[2] is not None:
             # Inputs 2 and 3 are the objects in demo positions
             # If we have access to these, we can run the pzY network
             pzY_results = self.residflow_embnn(*input)
-            goal_emb = pzY_results['goal_emb']
+            goal_emb = pzY_results["goal_emb"]
         else:
             goal_emb = None
 
         flow_action = {
             **flow_action,
-            'goal_emb': goal_emb,
-            'goal_emb_cond_x': goal_emb_cond_x,
+            "goal_emb": goal_emb,
+            "goal_emb_cond_x": goal_emb_cond_x,
             **for_debug,
         }
         return flow_action
diff --git a/taxpose/nets/taxposed_dgcnn.py b/taxpose/nets/taxposed_dgcnn.py
new file mode 100644
index 0000000..800da54
--- /dev/null
+++ b/taxpose/nets/taxposed_dgcnn.py
@@ -0,0 +1,165 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from third_party.dcp.model import get_graph_feature
+
+
+class DGCNN(nn.Module):
+    """This is a modified version of the DGCNN model from the DCP paper
+    for variable size inputs and conditioning in the later conv layers
+
+    See: https://github.com/WangYueFt/dcp/blob/master/model.py"""
+
+    def __init__(
+        self,
+        emb_dims=512,
+        input_dims=3,
+        num_heads=1,
+        conditioning_size=0,
+        last_relu=True,
+    ):
+        super(DGCNN, self).__init__()
+        self.input_dims = input_dims
+        self.num_heads = num_heads
+        self.conditioning_size = conditioning_size
+        self.last_relu = last_relu
+
+        self.conv1 = nn.Conv2d(2 * input_dims, 64, kernel_size=1, bias=False)
+        self.conv2 = nn.Conv2d(64, 64, kernel_size=1, bias=False)
+        self.conv3 = nn.Conv2d(64, 128, kernel_size=1, bias=False)
+        self.conv4 = nn.Conv2d(128, 256, kernel_size=1, bias=False)
+
+        if self.num_heads == 1:
+            self.conv5 = nn.Conv2d(
+                512 + self.conditioning_size, emb_dims, kernel_size=1, bias=False
+            )
+            self.bn5 = nn.BatchNorm2d(emb_dims)
+        else:
+            if self.conditioning_size > 0:
+                raise NotImplementedError(
+                    "Conditioning not implemented for multi-head DGCNN"
+                )
+            self.conv5s = nn.ModuleList(
+                [
+                    nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
+                    for _ in
range(self.num_heads) + ] + ) + self.bn5s = nn.ModuleList( + [nn.BatchNorm2d(emb_dims) for _ in range(self.num_heads)] + ) + + self.bn1 = nn.BatchNorm2d(64) + self.bn2 = nn.BatchNorm2d(64) + self.bn3 = nn.BatchNorm2d(128) + self.bn4 = nn.BatchNorm2d(256) + + def forward(self, x, conditioning=None): + batch_size, num_dims, num_points = x.size() + x = get_graph_feature(x) + x = F.relu(self.bn1(self.conv1(x))) + x1 = x.max(dim=-1, keepdim=True)[0] + + x = F.relu(self.bn2(self.conv2(x))) + x2 = x.max(dim=-1, keepdim=True)[0] + + x = F.relu(self.bn3(self.conv3(x))) + x3 = x.max(dim=-1, keepdim=True)[0] + + x = F.relu(self.bn4(self.conv4(x))) + x4 = x.max(dim=-1, keepdim=True)[0] + + if self.conditioning_size == 0: + assert conditioning is None + x = torch.cat((x1, x2, x3, x4), dim=1) + else: + assert conditioning is not None + x = torch.cat((x1, x2, x3, x4, conditioning[:, :, :, None]), dim=1) + + if self.num_heads == 1: + x = self.bn5(self.conv5(x)).view(batch_size, -1, num_points) + else: + x = [ + bn5(conv5(x)).view(batch_size, -1, num_points) + for bn5, conv5 in zip(self.bn5s, self.conv5s) + ] + + if self.last_relu: + if self.num_heads == 1: + x = F.relu(x) + else: + x = [F.relu(head) for head in x] + return x + + +class DGCNNClassification(nn.Module): + # Reference: https://github.com/WangYueFt/dgcnn/blob/master/pytorch/model.py#L88-L153 + + def __init__( + self, emb_dims=512, input_dims=3, num_heads=1, dropout=0.5, output_channels=40 + ): + super(DGCNNClassification, self).__init__() + self.emb_dims = emb_dims + self.input_dims = input_dims + self.num_heads = num_heads + self.dropout = dropout + self.output_channels = output_channels + self.conv1 = nn.Conv2d(self.input_dims * 2, 64, kernel_size=1, bias=False) + self.conv2 = nn.Conv2d(64, 64, kernel_size=1, bias=False) + self.conv3 = nn.Conv2d(64, 128, kernel_size=1, bias=False) + self.conv4 = nn.Conv2d(128, 256, kernel_size=1, bias=False) + self.conv5 = nn.Conv2d(512, self.emb_dims, kernel_size=1, bias=False) + + self.bn1 = nn.BatchNorm2d(64) + self.bn2 = nn.BatchNorm2d(64) + self.bn3 = nn.BatchNorm2d(128) + self.bn4 = nn.BatchNorm2d(256) + self.bn5 = nn.BatchNorm2d(self.emb_dims) + + self.linear1 = nn.Linear(self.emb_dims * 2, 512, bias=False) + self.bn6 = nn.BatchNorm1d(512) + self.dp1 = nn.Dropout(p=self.dropout) + self.linear2 = nn.Linear(512, 256) + self.bn7 = nn.BatchNorm1d(256) + self.dp2 = nn.Dropout(p=self.dropout) + + if self.num_heads == 1: + self.linear3 = nn.Linear(256, self.output_channels) + else: + self.linear3s = nn.ModuleList( + [nn.Linear(256, self.output_channels) for _ in range(self.num_heads)] + ) + + def forward(self, x): + batch_size, num_dims, num_points = x.size() + x = get_graph_feature(x) + x = F.relu(self.bn1(self.conv1(x))) + x1 = x.max(dim=-1, keepdim=True)[0] + + x = F.relu(self.bn2(self.conv2(x))) + x2 = x.max(dim=-1, keepdim=True)[0] + + x = F.relu(self.bn3(self.conv3(x))) + x3 = x.max(dim=-1, keepdim=True)[0] + + x = F.relu(self.bn4(self.conv4(x))) + x4 = x.max(dim=-1, keepdim=True)[0] + + x = torch.cat((x1, x2, x3, x4), dim=1) + + x = self.conv5(x).squeeze() + x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) + x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1) + x = torch.cat((x1, x2), 1) + + x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2) + x = self.dp1(x) + x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) + x = self.dp2(x) + + if self.num_heads == 1: + x = self.linear3(x)[:, :, None] + else: + x = [linear3(x)[:, :, None] for linear3 in self.linear3s] + return x 
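The new `taxpose/nets/taxposed_dgcnn.py` above feeds an optional per-point conditioning tensor into DGCNN by concatenating it, as extra channels, with the multi-scale point features right before `conv5`. Below is a minimal usage sketch of that interface, assuming the `third_party.dcp` dependency is importable; `B`, `N`, and `cond_size` are illustrative values, not taken from this patch:

    import torch

    from taxpose.nets.taxposed_dgcnn import DGCNN as CondDGCNN

    B, N = 2, 256   # batch size and points per cloud (illustrative)
    cond_size = 16  # conditioning channels appended before conv5 (illustrative)

    net = CondDGCNN(emb_dims=512, input_dims=3, conditioning_size=cond_size)

    points = torch.randn(B, 3, N)        # (B, input_dims, num_points)
    cond = torch.randn(B, cond_size, N)  # one conditioning vector per point

    emb = net(points, conditioning=cond)  # per-point embeddings, shape (B, 512, N)

With the default `conditioning_size=0`, `forward` asserts that `conditioning is None` and reduces to the original DCP-style DGCNN, which is why the `transformer_flow.py` changes below can route `"dgcnn"` to the unmodified `DGCNN` and `"cond_dgcnn"` to this class.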
diff --git a/taxpose/nets/transformer_flow.py b/taxpose/nets/transformer_flow.py index aebb21c..5e73d04 100644 --- a/taxpose/nets/transformer_flow.py +++ b/taxpose/nets/transformer_flow.py @@ -10,6 +10,7 @@ import torch.nn.functional as F from taxpose.nets.pointnet import PointNet +from taxpose.nets.taxposed_dgcnn import DGCNN as CondDGCNN from taxpose.nets.transformer_flow_pm import CustomTransformer from taxpose.nets.tv_mlp import MLP as TVMLP from taxpose.nets.vn_dgcnn import VN_DGCNN, VNArgs @@ -487,13 +488,26 @@ def __init__( self.cycle = cycle self.break_symmetry = break_symmetry self.conditioning_size = conditioning_size + self.emb_nn = emb_nn if emb_nn == "dgcnn": - self.emb_nn_action = DGCNN(emb_dims=self.emb_dims, input_dims=self.input_dims, conditioning_size=self.conditioning_size) - self.emb_nn_anchor = DGCNN(emb_dims=self.emb_dims, input_dims=self.input_dims, conditioning_size=self.conditioning_size) + self.emb_nn_action = DGCNN(emb_dims=self.emb_dims) + self.emb_nn_anchor = DGCNN(emb_dims=self.emb_dims) + elif emb_nn == "cond_dgcnn": + self.emb_nn_action = CondDGCNN( + emb_dims=self.emb_dims, + input_dims=self.input_dims, + conditioning_size=self.conditioning_size, + ) + self.emb_nn_anchor = CondDGCNN( + emb_dims=self.emb_dims, + input_dims=self.input_dims, + conditioning_size=self.conditioning_size, + ) elif emb_nn == "vn_dgcnn": args = VNArgs() - self.emb_nn_action = VN_DGCNN(args, num_part=self.emb_dims, gc=False) # TODO: add input_dims and conditioning_size - self.emb_nn_anchor = VN_DGCNN(args, num_part=self.emb_dims, gc=False) # TODO: add input_dims and conditioning_size + # TODO: add variable input and conditioning + self.emb_nn_action = VN_DGCNN(args, num_part=self.emb_dims, gc=False) + self.emb_nn_anchor = VN_DGCNN(args, num_part=self.emb_dims, gc=False) else: raise Exception("Not implemented") self.center_feature = center_feature @@ -508,7 +522,7 @@ def __init__( self.transformer_anchor = CustomTransformer( emb_dims=emb_dims, return_attn=self.return_attn, bidirectional=False ) - if multilaterate: # TODO: add input_dims + if multilaterate: self.head_action = MultilaterationHead( emb_dims=emb_dims, pred_weight=self.pred_weight, @@ -545,9 +559,17 @@ def __init__( nn.Conv1d(emb_dims_sym * 4, self.emb_dims, kernel_size=1, bias=False), ) - def forward(self, *input, conditioning_action=None, conditioning_anchor=None, action_center=None, anchor_center=None): - action_points = input[0].permute(0, 2, 1)[:, :self.input_dims] # B,3,num_points - anchor_points = input[1].permute(0, 2, 1)[:, :self.input_dims] + def forward( + self, + *input, + conditioning_action=None, + conditioning_anchor=None, + action_center=None, + anchor_center=None, + ): + # B,input_dims,num_points + action_points = input[0].permute(0, 2, 1)[:, : self.input_dims] + anchor_points = input[1].permute(0, 2, 1)[:, : self.input_dims] if action_center is None: action_center = action_points[:, :3].mean(dim=2, keepdim=True) @@ -556,19 +578,17 @@ def forward(self, *input, conditioning_action=None, conditioning_anchor=None, ac action_points_dmean = torch.cat( [ - action_points[:,:3,:] - \ - action_center, - action_points[:,3:,:], + action_points[:, :3, :] - action_center, + action_points[:, 3:, :], ], - dim=1 + dim=1, ) anchor_points_dmean = torch.cat( [ - anchor_points[:,:3,:] - \ - anchor_center, - anchor_points[:,3:,:], + anchor_points[:, :3, :] - anchor_center, + anchor_points[:, 3:, :], ], - dim=1 + dim=1, ) # mean center point cloud before DGCNN @@ -576,8 +596,16 @@ def forward(self, *input, 
conditioning_action=None, conditioning_anchor=None, ac action_points_dmean = action_points anchor_points_dmean = anchor_points - action_embedding = self.emb_nn_action(action_points_dmean) - anchor_embedding = self.emb_nn_anchor(anchor_points_dmean) + if self.emb_nn == "cond_dgcnn": + action_embedding = self.emb_nn_action( + action_points_dmean, conditioning=conditioning_action + ) + anchor_embedding = self.emb_nn_anchor( + anchor_points_dmean, conditioning=conditioning_anchor + ) + else: + action_embedding = self.emb_nn_action(action_points_dmean) + anchor_embedding = self.emb_nn_anchor(anchor_points_dmean) if self.freeze_embnn: action_embedding = action_embedding.detach() diff --git a/taxpose/training/multimodal_flow_equivariance_training_module_nocentering.py b/taxpose/training/multimodal_flow_equivariance_training_module_nocentering.py index 7b6f2b6..cd504fe 100644 --- a/taxpose/training/multimodal_flow_equivariance_training_module_nocentering.py +++ b/taxpose/training/multimodal_flow_equivariance_training_module_nocentering.py @@ -6,38 +6,48 @@ from taxpose.training.point_cloud_training_module import PointCloudTrainingModule from taxpose.utils.color_utils import get_color, color_gradient from taxpose.utils.error_metrics import get_2rack_errors -from taxpose.utils.se3 import dense_flow_loss, dualflow2pose, get_degree_angle, get_translation +from taxpose.utils.se3 import ( + dense_flow_loss, + dualflow2pose, + get_degree_angle, + get_translation, +) import torch.nn.functional as F import wandb -mse_criterion = nn.MSELoss(reduction='sum') +mse_criterion = nn.MSELoss(reduction="sum") to_tensor = ToTensor() -class Multimodal_EquivarianceTrainingModule(PointCloudTrainingModule): - def __init__(self, - model=None, - lr=1e-3, - image_log_period=500, - action_weight=1, - anchor_weight=1, - smoothness_weight=0.1, - consistency_weight=1, - latent_weight=0.1, - vae_reg_loss_weight=0.01, - rotation_weight=0, - chamfer_weight=10000, - point_loss_type=0, - return_flow_component=False, - weight_normalize='l1', - sigmoid_on=False, - softmax_temperature=None, - min_err_across_racks_debug=False, - error_mode_2rack="batch_min_rack"): - super().__init__(model=model, lr=lr, - image_log_period=image_log_period,) +class Multimodal_EquivarianceTrainingModule(PointCloudTrainingModule): + def __init__( + self, + model=None, + lr=1e-3, + image_log_period=500, + action_weight=1, + anchor_weight=1, + smoothness_weight=0.1, + consistency_weight=1, + latent_weight=0.1, + vae_reg_loss_weight=0.01, + rotation_weight=0, + chamfer_weight=10000, + point_loss_type=0, + return_flow_component=False, + weight_normalize="l1", + sigmoid_on=False, + softmax_temperature=None, + min_err_across_racks_debug=False, + error_mode_2rack="batch_min_rack", + ): + super().__init__( + model=model, + lr=lr, + image_log_period=image_log_period, + ) self.model = model self.lr = lr self.image_log_period = image_log_period @@ -59,8 +69,8 @@ def __init__(self, self.softmax_temperature = softmax_temperature self.min_err_across_racks_debug = min_err_across_racks_debug self.error_mode_2rack = error_mode_2rack - if self.weight_normalize == 'l1': - assert (self.sigmoid_on), "l1 weight normalization need sigmoid on" + if self.weight_normalize == "l1": + assert self.sigmoid_on, "l1 weight normalization need sigmoid on" def action_centered(self, points_action, points_anchor): """ @@ -71,12 +81,16 @@ def action_centered(self, points_action, points_anchor): points_action_mean_centered = points_action - points_action_mean points_anchor_mean_centered = 
points_anchor - points_action_mean - return points_action_mean_centered, points_anchor_mean_centered, points_action_mean + return ( + points_action_mean_centered, + points_anchor_mean_centered, + points_action_mean, + ) def extract_flow_and_weight(self, x): # x: Batch, num_points, 4 pred_flow = x[:, :, :3] - if(x.shape[2] > 3): + if x.shape[2] > 3: if self.sigmoid_on: pred_w = torch.sigmoid(x[:, :, 3]) else: @@ -86,9 +100,9 @@ def extract_flow_and_weight(self, x): return pred_flow, pred_w def predict(self, model_output, points_trans_action, points_trans_anchor): - x_action = model_output['flow_action'] - x_anchor = model_output['flow_anchor'] - + x_action = model_output["flow_action"] + x_anchor = model_output["flow_anchor"] + # If we've applied some sampling, we need to extract the predictions too... if "sampled_ixs_action" in model_output: ixs_action = model_output["sampled_ixs_action"].unsqueeze(-1) @@ -105,51 +119,68 @@ def predict(self, model_output, points_trans_action, points_trans_anchor): ) else: sampled_points_trans_anchor - + pred_flow_action, pred_w_action = self.extract_flow_and_weight(x_action) pred_flow_anchor, pred_w_anchor = self.extract_flow_and_weight(x_anchor) - pred_T_action = dualflow2pose(xyz_src=sampled_points_trans_action, - xyz_tgt=sampled_points_trans_anchor, - flow_src=pred_flow_action, - flow_tgt=pred_flow_anchor, - weights_src=pred_w_action, - weights_tgt=pred_w_anchor, - return_transform3d=True, - normalization_scehme=self.weight_normalize, - temperature=self.softmax_temperature) - - pred_points_action = pred_T_action.transform_points(points_trans_action) + pred_T_action = dualflow2pose( + xyz_src=sampled_points_trans_action, + xyz_tgt=sampled_points_trans_anchor, + flow_src=pred_flow_action, + flow_tgt=pred_flow_anchor, + weights_src=pred_w_action, + weights_tgt=pred_w_anchor, + return_transform3d=True, + normalization_scehme=self.weight_normalize, + temperature=self.softmax_temperature, + ) - return {"pred_T_action": pred_T_action, - "pred_points_action": pred_points_action} + pred_points_action = pred_T_action.transform_points(points_trans_action) - def get_transform(self, points_trans_action, points_trans_anchor, points_onetrans_action=None, points_onetrans_anchor=None, mode="forward"): - model_output = self.model(points_trans_action, - points_trans_anchor, - points_onetrans_action, - points_onetrans_anchor, - mode=mode) + return { + "pred_T_action": pred_T_action, + "pred_points_action": pred_points_action, + } + + def get_transform( + self, + points_trans_action, + points_trans_anchor, + points_onetrans_action=None, + points_onetrans_anchor=None, + mode="forward", + ): + model_output = self.model( + points_trans_action, + points_trans_anchor, + points_onetrans_action, + points_onetrans_anchor, + mode=mode, + ) points_trans_action = points_trans_action[:, :, :3] points_trans_anchor = points_trans_anchor[:, :, :3] - - ans_dict = self.predict(model_output, - points_trans_action=points_trans_action, - points_trans_anchor=points_trans_anchor) - - ans_dict['flow_components'] = model_output + + ans_dict = self.predict( + model_output, + points_trans_action=points_trans_action, + points_trans_anchor=points_trans_anchor, + ) + + ans_dict["flow_components"] = model_output return ans_dict - def compute_loss(self, model_output, batch, log_values={}, loss_prefix='', heads=None): - x_action = model_output['flow_action'] - x_anchor = model_output['flow_anchor'] - goal_emb = model_output['goal_emb'] - - points_action = batch['points_action'][:, :, :3] - points_anchor = 
batch['points_anchor'][:, :, :3] - points_trans_action = batch['points_action_trans'][:, :, :3] - points_trans_anchor = batch['points_anchor_trans'][:, :, :3] + def compute_loss( + self, model_output, batch, log_values={}, loss_prefix="", heads=None + ): + x_action = model_output["flow_action"] + x_anchor = model_output["flow_anchor"] + goal_emb = model_output["goal_emb"] + + points_action = batch["points_action"][:, :, :3] + points_anchor = batch["points_anchor"][:, :, :3] + points_trans_action = batch["points_action_trans"][:, :, :3] + points_trans_anchor = batch["points_anchor_trans"][:, :, :3] # If we've applied some sampling, we need to extract the predictions too... if "sampled_ixs_action" in model_output: @@ -166,8 +197,8 @@ def compute_loss(self, model_output, batch, log_values={}, loss_prefix='', heads points_trans_anchor, ixs_anchor, dim=1 ) - T0 = Transform3d(matrix=batch['T0']) - T1 = Transform3d(matrix=batch['T1']) + T0 = Transform3d(matrix=batch["T0"]) + T1 = Transform3d(matrix=batch["T1"]) R0_max, R0_min, R0_mean = get_degree_angle(T0) R1_max, R1_min, R1_mean = get_degree_angle(T1) @@ -177,30 +208,36 @@ def compute_loss(self, model_output, batch, log_values={}, loss_prefix='', heads pred_flow_action, pred_w_action = self.extract_flow_and_weight(x_action) pred_flow_anchor, pred_w_anchor = self.extract_flow_and_weight(x_anchor) - pred_T_action = dualflow2pose(xyz_src=points_trans_action, - xyz_tgt=points_trans_anchor, - flow_src=pred_flow_action, - flow_tgt=pred_flow_anchor, - weights_src=pred_w_action, - weights_tgt=pred_w_anchor, - return_transform3d=True, - normalization_scehme=self.weight_normalize, - temperature=self.softmax_temperature) - - induced_flow_action = (pred_T_action.transform_points( - points_trans_action) - points_trans_action).detach() - pred_points_action = pred_T_action.transform_points( - points_trans_action) + pred_T_action = dualflow2pose( + xyz_src=points_trans_action, + xyz_tgt=points_trans_anchor, + flow_src=pred_flow_action, + flow_tgt=pred_flow_anchor, + weights_src=pred_w_action, + weights_tgt=pred_w_anchor, + return_transform3d=True, + normalization_scehme=self.weight_normalize, + temperature=self.softmax_temperature, + ) + + induced_flow_action = ( + pred_T_action.transform_points(points_trans_action) - points_trans_action + ).detach() + pred_points_action = pred_T_action.transform_points(points_trans_action) # pred_T_action=T1T0^-1 gt_T_action = T0.inverse().compose(T1) points_action_target = T1.transform_points(points_action) if self.min_err_across_racks_debug: - error_R_mean, error_t_mean = get_2rack_errors(pred_T_action, T0, T1, mode=self.error_mode_2rack) - log_values[loss_prefix+'error_R_mean'] = error_R_mean - log_values[loss_prefix+'error_t_mean'] = error_t_mean - log_values[loss_prefix+'rotation_loss'] = self.rotation_weight * error_R_mean + error_R_mean, error_t_mean = get_2rack_errors( + pred_T_action, T0, T1, mode=self.error_mode_2rack + ) + log_values[loss_prefix + "error_R_mean"] = error_R_mean + log_values[loss_prefix + "error_t_mean"] = error_t_mean + log_values[loss_prefix + "rotation_loss"] = ( + self.rotation_weight * error_R_mean + ) # Loss associated with ground truth transform point_loss_action = mse_criterion( @@ -210,9 +247,9 @@ def compute_loss(self, model_output, batch, log_values={}, loss_prefix='', heads point_loss = self.action_weight * point_loss_action - dense_loss = dense_flow_loss(points=points_trans_action, - flow_pred=pred_flow_action, - trans_gt=gt_T_action) + dense_loss = dense_flow_loss( + 
points=points_trans_action, flow_pred=pred_flow_action, trans_gt=gt_T_action + ) # Loss associated flow vectors matching a consistent rigid transform smoothness_loss_action = mse_criterion( @@ -222,106 +259,175 @@ def compute_loss(self, model_output, batch, log_values={}, loss_prefix='', heads smoothness_loss = self.action_weight * smoothness_loss_action - loss = point_loss + self.smoothness_weight * \ - smoothness_loss + self.consistency_weight * dense_loss #+ latent_loss + loss = ( + point_loss + + self.smoothness_weight * smoothness_loss + + self.consistency_weight * dense_loss + ) # + latent_loss - log_values[loss_prefix+'point_loss'] = point_loss + log_values[loss_prefix + "point_loss"] = point_loss - log_values[loss_prefix + - 'smoothness_loss'] = self.smoothness_weight * smoothness_loss - log_values[loss_prefix + - 'dense_loss'] = self.consistency_weight * dense_loss - #log_values[loss_prefix + + log_values[loss_prefix + "smoothness_loss"] = ( + self.smoothness_weight * smoothness_loss + ) + log_values[loss_prefix + "dense_loss"] = self.consistency_weight * dense_loss + # log_values[loss_prefix + # 'latent_loss'] = self.latent_weight * latent_loss if self.model.conditioning in ["uniform_prior_pos_delta_l2norm"]: N = x_action.shape[1] - uniform = (torch.ones((goal_emb.shape[0], goal_emb.shape[1], N)) / goal_emb.shape[-1]).cuda().detach() - action_kl = F.kl_div(F.log_softmax(uniform, dim=-1), - F.log_softmax(goal_emb[:, :, :N], dim=-1), log_target=True, - reduction='batchmean') - anchor_kl = F.kl_div(F.log_softmax(uniform, dim=-1), - F.log_softmax(goal_emb[:, :, N:], dim=-1), log_target=True, - reduction='batchmean') + uniform = ( + ( + torch.ones((goal_emb.shape[0], goal_emb.shape[1], N)) + / goal_emb.shape[-1] + ) + .cuda() + .detach() + ) + action_kl = F.kl_div( + F.log_softmax(uniform, dim=-1), + F.log_softmax(goal_emb[:, :, :N], dim=-1), + log_target=True, + reduction="batchmean", + ) + anchor_kl = F.kl_div( + F.log_softmax(uniform, dim=-1), + F.log_softmax(goal_emb[:, :, N:], dim=-1), + log_target=True, + reduction="batchmean", + ) vae_reg_loss = action_kl + anchor_kl loss += self.vae_reg_loss_weight * vae_reg_loss - log_values[loss_prefix+'vae_reg_loss'] = self.vae_reg_loss_weight * vae_reg_loss + log_values[loss_prefix + "vae_reg_loss"] = ( + self.vae_reg_loss_weight * vae_reg_loss + ) if heads is not None: + def vae_regularization_loss(mu, log_var): # From https://github.com/AntixK/PyTorch-VAE/blob/a6896b944c918dd7030e7d795a8c13e5c6345ec7/models/cvae.py#LL144C9-L144C105 - return torch.mean(-0.5 * (1 + log_var - mu ** 2 - log_var.exp()).sum(dim = -1).sum(dim = -1), dim = 0) - - if self.model.conditioning in ["latent_z", "latent_z_1pred", "latent_z_1pred_10d", "latent_z_linear", "latent_z_linear_internalcond"]: - vae_reg_loss = vae_regularization_loss(heads['goal_emb_mu'], heads['goal_emb_logvar']) + return torch.mean( + -0.5 + * (1 + log_var - mu**2 - log_var.exp()).sum(dim=-1).sum(dim=-1), + dim=0, + ) + + if self.model.conditioning in [ + "latent_z", + "latent_z_1pred", + "latent_z_1pred_10d", + "latent_z_linear", + "latent_z_linear_internalcond", + ]: + vae_reg_loss = vae_regularization_loss( + heads["goal_emb_mu"], heads["goal_emb_logvar"] + ) vae_reg_loss = torch.nan_to_num(vae_reg_loss) - + loss += self.vae_reg_loss_weight * vae_reg_loss - log_values[loss_prefix+'vae_reg_loss'] = self.vae_reg_loss_weight * vae_reg_loss + log_values[loss_prefix + "vae_reg_loss"] = ( + self.vae_reg_loss_weight * vae_reg_loss + ) else: - raise ValueError("ERROR: Why is there a non-None 
heads variable passed in when the model isn't even a latent_z model?") - - log_values[loss_prefix+'R0_mean'] = R0_mean - log_values[loss_prefix+'R0_max'] = R0_max - log_values[loss_prefix+'R0_min'] = R0_min - log_values[loss_prefix+'R1_mean'] = R1_mean - log_values[loss_prefix+'R1_max'] = R1_max - log_values[loss_prefix+'R1_min'] = R1_min - - log_values[loss_prefix+'t0_mean'] = t0_mean - log_values[loss_prefix+'t0_max'] = t0_max - log_values[loss_prefix+'t0_min'] = t0_min - log_values[loss_prefix+'t1_mean'] = t1_mean - log_values[loss_prefix+'t1_max'] = t1_max - log_values[loss_prefix+'t1_min'] = t1_min + raise ValueError( + "ERROR: Why is there a non-None heads variable passed in when the model isn't even a latent_z model?" + ) + + log_values[loss_prefix + "R0_mean"] = R0_mean + log_values[loss_prefix + "R0_max"] = R0_max + log_values[loss_prefix + "R0_min"] = R0_min + log_values[loss_prefix + "R1_mean"] = R1_mean + log_values[loss_prefix + "R1_max"] = R1_max + log_values[loss_prefix + "R1_min"] = R1_min + + log_values[loss_prefix + "t0_mean"] = t0_mean + log_values[loss_prefix + "t0_max"] = t0_max + log_values[loss_prefix + "t0_min"] = t0_min + log_values[loss_prefix + "t1_mean"] = t1_mean + log_values[loss_prefix + "t1_max"] = t1_max + log_values[loss_prefix + "t1_min"] = t1_min return loss, log_values - def module_step(self, batch, batch_idx, log_prefix=''): - points_trans_action = batch['points_action_trans'] - points_trans_anchor = batch['points_anchor_trans'] - points_onetrans_action = batch['points_action_onetrans'] if 'points_action_onetrans' in batch else batch['points_action'] - points_onetrans_anchor = batch['points_anchor_onetrans'] if 'points_anchor_onetrans' in batch else batch['points_anchor'] + def module_step(self, batch, batch_idx, log_prefix=""): + points_trans_action = batch["points_action_trans"] + points_trans_anchor = batch["points_anchor_trans"] + points_onetrans_action = ( + batch["points_action_onetrans"] + if "points_action_onetrans" in batch + else batch["points_action"] + ) + points_onetrans_anchor = ( + batch["points_anchor_onetrans"] + if "points_anchor_onetrans" in batch + else batch["points_anchor"] + ) # points_action = batch['points_action'] # points_anchor = batch['points_anchor'] # TODO only pass in points_anchor and points_action if the model is training - model_output = self.model(points_trans_action, - points_trans_anchor, - points_onetrans_action, - points_onetrans_anchor) - - if self.model.conditioning not in ["latent_z", "latent_z_1pred", "latent_z_1pred_10d", "latent_z_linear", "latent_z_linear_internalcond"]: + model_output = self.model( + points_trans_action, + points_trans_anchor, + points_onetrans_action, + points_onetrans_anchor, + ) + + if self.model.conditioning not in [ + "latent_z", + "latent_z_1pred", + "latent_z_1pred_10d", + "latent_z_linear", + "latent_z_linear_internalcond", + ]: heads = None else: - heads = {'goal_emb_mu': model_output['goal_emb_mu'], 'goal_emb_logvar': model_output['goal_emb_logvar']} + heads = { + "goal_emb_mu": model_output["goal_emb_mu"], + "goal_emb_logvar": model_output["goal_emb_logvar"], + } log_values = {} loss, log_values = self.compute_loss( - model_output, batch, log_values=log_values, loss_prefix=log_prefix, heads=heads) - + model_output, + batch, + log_values=log_values, + loss_prefix=log_prefix, + heads=heads, + ) + torch.cuda.empty_cache() - + with torch.no_grad(): - def get_inference_error(log_values, batch, loss_prefix): - T0 = Transform3d(matrix=batch['T0']) - T1 = 
Transform3d(matrix=batch['T1']) - if self.model.conditioning not in ["uniform_prior_pos_delta_l2norm", "latent_z", "latent_z_1pred", "latent_z_1pred_10d", "latent_z_linear", "latent_z_linear_internalcond"]: - inference_mode='forward' + def get_inference_error(log_values, batch, loss_prefix): + T0 = Transform3d(matrix=batch["T0"]) + T1 = Transform3d(matrix=batch["T1"]) + + if self.model.conditioning not in [ + "uniform_prior_pos_delta_l2norm", + "latent_z", + "latent_z_1pred", + "latent_z_1pred_10d", + "latent_z_linear", + "latent_z_linear_internalcond", + ]: + inference_mode = "forward" else: - inference_mode='inference' - - model_output = self.model(points_trans_action, - points_trans_anchor, - points_onetrans_action, - points_onetrans_anchor, - mode=inference_mode) - - x_action = model_output['flow_action'] - x_anchor = model_output['flow_anchor'] - goal_emb = model_output['goal_emb'] - + inference_mode = "inference" + + model_output = self.model( + points_trans_action, + points_trans_anchor, + points_onetrans_action, + points_onetrans_anchor, + mode=inference_mode, + ) + + x_action = model_output["flow_action"] + x_anchor = model_output["flow_anchor"] + goal_emb = model_output["goal_emb"] + # If we've applied some sampling, we need to extract the predictions too... if "sampled_ixs_action" in model_output: ixs_action = model_output["sampled_ixs_action"].unsqueeze(-1) @@ -338,117 +444,145 @@ def get_inference_error(log_values, batch, loss_prefix): ) else: sampled_points_trans_anchor = points_trans_anchor - + pred_flow_action, pred_w_action = self.extract_flow_and_weight(x_action) pred_flow_anchor, pred_w_anchor = self.extract_flow_and_weight(x_anchor) - + del x_action, x_anchor, goal_emb - pred_T_action = dualflow2pose(xyz_src=sampled_points_trans_action, - xyz_tgt=sampled_points_trans_anchor, - flow_src=pred_flow_action, - flow_tgt=pred_flow_anchor, - weights_src=pred_w_action, - weights_tgt=pred_w_anchor, - return_transform3d=True, - normalization_scehme=self.weight_normalize, - temperature=self.softmax_temperature) + pred_T_action = dualflow2pose( + xyz_src=sampled_points_trans_action, + xyz_tgt=sampled_points_trans_anchor, + flow_src=pred_flow_action, + flow_tgt=pred_flow_anchor, + weights_src=pred_w_action, + weights_tgt=pred_w_anchor, + return_transform3d=True, + normalization_scehme=self.weight_normalize, + temperature=self.softmax_temperature, + ) if self.min_err_across_racks_debug: - error_R_mean, error_t_mean = get_2rack_errors(pred_T_action, T0, T1, mode=self.error_mode_2rack) - log_values[loss_prefix+'sample_error_R_mean'] = error_R_mean - log_values[loss_prefix+'sample_error_t_mean'] = error_t_mean - + error_R_mean, error_t_mean = get_2rack_errors( + pred_T_action, T0, T1, mode=self.error_mode_2rack + ) + log_values[loss_prefix + "sample_error_R_mean"] = error_R_mean + log_values[loss_prefix + "sample_error_t_mean"] = error_t_mean + get_inference_error(log_values, batch, loss_prefix=log_prefix) torch.cuda.empty_cache() return loss, log_values - def visualize_results(self, batch, batch_idx, log_prefix=''): + def visualize_results(self, batch, batch_idx, log_prefix=""): # classes = batch['classes'] # points = batch['points'] # points_trans = batch['points_trans'] - points_trans_action = batch['points_action_trans'] - points_trans_anchor = batch['points_anchor_trans'] - points_action = batch['points_action'] - points_anchor = batch['points_anchor'] - points_onetrans_action = batch['points_action_onetrans'] if 'points_action_onetrans' in batch else batch['points_action'] - 
points_onetrans_anchor = batch['points_anchor_onetrans'] if 'points_anchor_onetrans' in batch else batch['points_anchor'] - - T0 = Transform3d(matrix=batch['T0']) - T1 = Transform3d(matrix=batch['T1']) - - model_output = self.model(points_trans_action, - points_trans_anchor, - points_onetrans_action, - points_onetrans_anchor) - - x_action = model_output['flow_action'] - x_anchor = model_output['flow_anchor'] - goal_emb = model_output['goal_emb'] - residual_flow_action = model_output['residual_flow_action'] - residual_flow_anchor = model_output['residual_flow_anchor'] - corr_flow_action = model_output['corr_flow_action'] - corr_flow_anchor = model_output['corr_flow_anchor'] + points_trans_action = batch["points_action_trans"] + points_trans_anchor = batch["points_anchor_trans"] + points_action = batch["points_action"] + points_anchor = batch["points_anchor"] + points_onetrans_action = ( + batch["points_action_onetrans"] + if "points_action_onetrans" in batch + else batch["points_action"] + ) + points_onetrans_anchor = ( + batch["points_anchor_onetrans"] + if "points_anchor_onetrans" in batch + else batch["points_anchor"] + ) + + T0 = Transform3d(matrix=batch["T0"]) + T1 = Transform3d(matrix=batch["T1"]) + + model_output = self.model( + points_trans_action, + points_trans_anchor, + points_onetrans_action, + points_onetrans_anchor, + ) + + x_action = model_output["flow_action"] + x_anchor = model_output["flow_anchor"] + goal_emb = model_output["goal_emb"] + residual_flow_action = model_output["residual_flow_action"] + residual_flow_anchor = model_output["residual_flow_anchor"] + corr_flow_action = model_output["corr_flow_action"] + corr_flow_anchor = model_output["corr_flow_anchor"] points_action = points_action[:, :, :3] points_anchor = points_anchor[:, :, :3] points_trans_action = points_trans_action[:, :, :3] points_trans_anchor = points_trans_anchor[:, :, :3] - + # If we've applied some sampling, we need to extract the predictions too... 
if "sampled_ixs_action" in model_output: ixs_action = model_output["sampled_ixs_action"].unsqueeze(-1) - sampled_points_action = torch.take_along_dim(points_action, ixs_action, dim=1) + sampled_points_action = torch.take_along_dim( + points_action, ixs_action, dim=1 + ) sampled_points_trans_action = torch.take_along_dim( points_trans_action, ixs_action, dim=1 ) else: + sampled_points_action = points_action sampled_points_trans_action = points_trans_action if "sampled_ixs_anchor" in model_output: ixs_anchor = model_output["sampled_ixs_anchor"].unsqueeze(-1) - sampled_points_anchor = torch.take_along_dim(points_anchor, ixs_anchor, dim=1) + sampled_points_anchor = torch.take_along_dim( + points_anchor, ixs_anchor, dim=1 + ) sampled_points_trans_anchor = torch.take_along_dim( points_trans_anchor, ixs_anchor, dim=1 ) else: + sampled_points_anchor = points_anchor sampled_points_trans_anchor = points_trans_anchor pred_flow_action, pred_w_action = self.extract_flow_and_weight(x_action) pred_flow_anchor, pred_w_anchor = self.extract_flow_and_weight(x_anchor) - pred_T_action = dualflow2pose(xyz_src=sampled_points_trans_action, - xyz_tgt=sampled_points_trans_anchor, - flow_src=pred_flow_action, - flow_tgt=pred_flow_anchor, - weights_src=pred_w_action, - weights_tgt=pred_w_anchor, - return_transform3d=True, - normalization_scehme=self.weight_normalize, - temperature=self.softmax_temperature) - - pred_points_action = pred_T_action.transform_points( - points_trans_action) + pred_T_action = dualflow2pose( + xyz_src=sampled_points_trans_action, + xyz_tgt=sampled_points_trans_anchor, + flow_src=pred_flow_action, + flow_tgt=pred_flow_anchor, + weights_src=pred_w_action, + weights_tgt=pred_w_anchor, + return_transform3d=True, + normalization_scehme=self.weight_normalize, + temperature=self.softmax_temperature, + ) + + pred_points_action = pred_T_action.transform_points(points_trans_action) points_action_target = T1.transform_points(points_action) res_images = {} demo_points = get_color( - tensor_list=[points_onetrans_action[0], points_onetrans_anchor[0]], color_list=['blue', 'red']) - res_images[log_prefix+'demo_points'] = wandb.Object3D( - demo_points) + tensor_list=[points_onetrans_action[0], points_onetrans_anchor[0]], + color_list=["blue", "red"], + ) + res_images[log_prefix + "demo_points"] = wandb.Object3D(demo_points) action_transformed_action = get_color( - tensor_list=[points_action[0], points_trans_action[0]], color_list=['blue', 'red']) - res_images[log_prefix+'action_transformed_action'] = wandb.Object3D( - action_transformed_action) + tensor_list=[points_action[0], points_trans_action[0]], + color_list=["blue", "red"], + ) + res_images[log_prefix + "action_transformed_action"] = wandb.Object3D( + action_transformed_action + ) anchor_transformed_anchor = get_color( - tensor_list=[points_anchor[0], points_trans_anchor[0]], color_list=['blue', 'red']) - res_images[log_prefix+'anchor_transformed_anchor'] = wandb.Object3D( - anchor_transformed_anchor) + tensor_list=[points_anchor[0], points_trans_anchor[0]], + color_list=["blue", "red"], + ) + res_images[log_prefix + "anchor_transformed_anchor"] = wandb.Object3D( + anchor_transformed_anchor + ) # transformed_input_points = get_color(tensor_list=[ # points_trans_action[0], points_trans_anchor[0]], color_list=['blue', 'red']) @@ -456,24 +590,55 @@ def visualize_results(self, batch, batch_idx, log_prefix=''): # transformed_input_points) demo_points_apply_action_transform = get_color( - tensor_list=[pred_points_action[0], points_trans_anchor[0]], 
color_list=['blue', 'red']) - res_images[log_prefix+'demo_points_apply_action_transform'] = wandb.Object3D( - demo_points_apply_action_transform) + tensor_list=[pred_points_action[0], points_trans_anchor[0]], + color_list=["blue", "red"], + ) + res_images[log_prefix + "demo_points_apply_action_transform"] = wandb.Object3D( + demo_points_apply_action_transform + ) apply_action_transform_demo_comparable = get_color( - tensor_list=[T1.inverse().transform_points(pred_points_action)[0], T1.inverse().transform_points(points_trans_anchor)[0]], color_list=['blue', 'red']) - res_images[log_prefix+'apply_action_transform_demo_comparable'] = wandb.Object3D( - apply_action_transform_demo_comparable) + tensor_list=[ + T1.inverse().transform_points(pred_points_action)[0], + T1.inverse().transform_points(points_trans_anchor)[0], + ], + color_list=["blue", "red"], + ) + res_images[ + log_prefix + "apply_action_transform_demo_comparable" + ] = wandb.Object3D(apply_action_transform_demo_comparable) predicted_vs_gt_transform_applied = get_color( - tensor_list=[T1.inverse().transform_points(pred_points_action)[0], points_action[0], T1.inverse().transform_points(points_trans_anchor)[0]], color_list=['blue', 'green', 'red', ]) - res_images[log_prefix+'predicted_vs_gt_transform_applied'] = wandb.Object3D( - predicted_vs_gt_transform_applied) + tensor_list=[ + T1.inverse().transform_points(pred_points_action)[0], + points_action[0], + T1.inverse().transform_points(points_trans_anchor)[0], + ], + color_list=[ + "blue", + "green", + "red", + ], + ) + res_images[log_prefix + "predicted_vs_gt_transform_applied"] = wandb.Object3D( + predicted_vs_gt_transform_applied + ) apply_predicted_transform = get_color( - tensor_list=[T1.inverse().transform_points(pred_points_action)[0], T1.inverse().transform_points(points_trans_action)[0], T1.inverse().transform_points(points_trans_anchor)[0]], color_list=['blue', 'orange', 'red', ]) - res_images[log_prefix+'apply_predicted_transform'] = wandb.Object3D( - apply_predicted_transform) + tensor_list=[ + T1.inverse().transform_points(pred_points_action)[0], + T1.inverse().transform_points(points_trans_action)[0], + T1.inverse().transform_points(points_trans_anchor)[0], + ], + color_list=[ + "blue", + "orange", + "red", + ], + ) + res_images[log_prefix + "apply_predicted_transform"] = wandb.Object3D( + apply_predicted_transform + ) # loss_points_action = get_color( # tensor_list=[points_action_target[0], pred_points_action[0]], color_list=['green', 'red']) @@ -482,59 +647,90 @@ def visualize_results(self, batch, batch_idx, log_prefix=''): colors_pred_w_action = color_gradient(pred_w_action[0]) colors_pred_w_anchor = color_gradient(pred_w_anchor[0]) - pred_w_points = torch.cat([sampled_points_action[0].detach(), sampled_points_anchor[0].detach()], dim=0).cpu().numpy() - pred_w_on_objects = np.concatenate([ - pred_w_points, - np.concatenate([colors_pred_w_action, colors_pred_w_anchor], axis=0)], - axis=-1) - - res_images[log_prefix+'pred_w'] = wandb.Object3D( - pred_w_on_objects, markerSize=1000) - - if self.model.conditioning not in ["latent_z_linear", "latent_z_linear_internalcond"]: - goal_emb_norm_action = F.softmax(goal_emb[0, :, :points_action.shape[1]], dim=-1).detach().cpu() - goal_emb_norm_anchor = F.softmax(goal_emb[0, :, points_action.shape[1]:], dim=-1).detach().cpu() + pred_w_points = ( + torch.cat( + [sampled_points_action[0].detach(), sampled_points_anchor[0].detach()], + dim=0, + ) + .cpu() + .numpy() + ) + pred_w_on_objects = np.concatenate( + [ + pred_w_points, + 
np.concatenate([colors_pred_w_action, colors_pred_w_anchor], axis=0), + ], + axis=-1, + ) + + res_images[log_prefix + "pred_w"] = wandb.Object3D( + pred_w_on_objects, markerSize=1000 + ) + + if self.model.conditioning not in [ + "latent_z_linear", + "latent_z_linear_internalcond", + ]: + goal_emb_norm_action = ( + F.softmax(goal_emb[0, :, : points_action.shape[1]], dim=-1) + .detach() + .cpu() + ) + goal_emb_norm_anchor = ( + F.softmax(goal_emb[0, :, points_action.shape[1] :], dim=-1) + .detach() + .cpu() + ) colors_action = color_gradient(goal_emb_norm_action[0]) colors_anchor = color_gradient(goal_emb_norm_anchor[0]) - goal_emb_on_objects = np.concatenate([ - torch.cat([points_action[0].detach(), points_anchor[0].detach()], dim=0).cpu().numpy(), - np.concatenate([colors_action, colors_anchor], axis=0)], - axis=-1) - res_images[log_prefix+'goal_emb'] = wandb.Object3D( - goal_emb_on_objects) + goal_emb_on_objects = np.concatenate( + [ + torch.cat( + [points_action[0].detach(), points_anchor[0].detach()], dim=0 + ) + .cpu() + .numpy(), + np.concatenate([colors_action, colors_anchor], axis=0), + ], + axis=-1, + ) + res_images[log_prefix + "goal_emb"] = wandb.Object3D(goal_emb_on_objects) return res_images class Multimodal_EquivarianceTrainingModule_WithPZCondX(PointCloudTrainingModule): - - def __init__(self, - model_with_cond_x, - training_module_no_cond_x, - goal_emb_cond_x_loss_weight=1, - pzy_pzx_loss_type="reverse_kl", - joint_train_prior=False, - joint_train_prior_freeze_embnn=False, - freeze_residual_flow=False, - freeze_z_embnn=False, - freeze_embnn=False): + def __init__( + self, + model_with_cond_x, + training_module_no_cond_x, + goal_emb_cond_x_loss_weight=1, + pzy_pzx_loss_type="reverse_kl", + joint_train_prior=False, + joint_train_prior_freeze_embnn=False, + freeze_residual_flow=False, + freeze_z_embnn=False, + freeze_embnn=False, + ): # TODO add this in assert pzy_pzx_loss_type in ["reverse_kl", "forward_kl", "mse"] - super().__init__(model=model_with_cond_x, lr=training_module_no_cond_x.lr, - image_log_period=training_module_no_cond_x.image_log_period) + super().__init__( + model=model_with_cond_x, + lr=training_module_no_cond_x.lr, + image_log_period=training_module_no_cond_x.image_log_period, + ) self.model_with_cond_x = model_with_cond_x self.model = self.model_with_cond_x.residflow_embnn self.training_module_no_cond_x = training_module_no_cond_x self.goal_emb_cond_x_loss_weight = goal_emb_cond_x_loss_weight - + self.joint_train_prior = joint_train_prior self.joint_train_prior_freeze_embnn = joint_train_prior_freeze_embnn self.cfg_freeze_residual_flow = freeze_residual_flow self.cfg_freeze_z_embnn = freeze_z_embnn self.cfg_freeze_embnn = freeze_embnn - def action_centered(self, points_action, points_anchor): """ @@ -545,15 +741,19 @@ def action_centered(self, points_action, points_anchor): points_action_mean_centered = points_action - points_action_mean points_anchor_mean_centered = points_anchor - points_action_mean - return points_action_mean_centered, points_anchor_mean_centered, points_action_mean + return ( + points_action_mean_centered, + points_anchor_mean_centered, + points_action_mean, + ) def extract_flow_and_weight(self, *args, **kwargs): return self.training_module_no_cond_x.extract_flow_and_weight(*args, **kwargs) def predict(self, model_output, points_trans_action, points_trans_anchor): - x_action = model_output['flow_action'] - x_anchor = model_output['flow_anchor'] - + x_action = model_output["flow_action"] + x_anchor = model_output["flow_anchor"] + # If 
we've applied some sampling, we need to extract the predictions too... if "sampled_ixs_action" in model_output: ixs_action = model_output["sampled_ixs_action"].unsqueeze(-1) @@ -570,54 +770,65 @@ def predict(self, model_output, points_trans_action, points_trans_anchor): ) else: sampled_points_trans_anchor - + pred_flow_action, pred_w_action = self.extract_flow_and_weight(x_action) pred_flow_anchor, pred_w_anchor = self.extract_flow_and_weight(x_anchor) - pred_T_action = dualflow2pose(xyz_src=sampled_points_trans_action, - xyz_tgt=sampled_points_trans_anchor, - flow_src=pred_flow_action, - flow_tgt=pred_flow_anchor, - weights_src=pred_w_action, - weights_tgt=pred_w_anchor, - return_transform3d=True, - normalization_scehme=self.training_module_no_cond_x.weight_normalize, - temperature=self.training_module_no_cond_x.softmax_temperature) - - pred_points_action = pred_T_action.transform_points(points_trans_action) + pred_T_action = dualflow2pose( + xyz_src=sampled_points_trans_action, + xyz_tgt=sampled_points_trans_anchor, + flow_src=pred_flow_action, + flow_tgt=pred_flow_anchor, + weights_src=pred_w_action, + weights_tgt=pred_w_anchor, + return_transform3d=True, + normalization_scehme=self.training_module_no_cond_x.weight_normalize, + temperature=self.training_module_no_cond_x.softmax_temperature, + ) - return {"pred_T_action": pred_T_action, - "pred_points_action": pred_points_action} + pred_points_action = pred_T_action.transform_points(points_trans_action) - def get_transform(self, points_trans_action, points_trans_anchor, points_action=None, points_anchor=None, mode="forward"): + return { + "pred_T_action": pred_T_action, + "pred_points_action": pred_points_action, + } + + def get_transform( + self, + points_trans_action, + points_trans_anchor, + points_action=None, + points_anchor=None, + mode="forward", + ): # mode is unused - model_output = self.model_with_cond_x(points_trans_action, - points_trans_anchor, - points_action, - points_anchor) + model_output = self.model_with_cond_x( + points_trans_action, points_trans_anchor, points_action, points_anchor + ) points_trans_action = points_trans_action[:, :, :3] points_trans_anchor = points_trans_anchor[:, :, :3] - - ans_dict = self.predict(model_output, - points_trans_action=points_trans_action, - points_trans_anchor=points_trans_anchor) - - ans_dict['flow_components'] = model_output + + ans_dict = self.predict( + model_output, + points_trans_action=points_trans_action, + points_trans_anchor=points_trans_anchor, + ) + + ans_dict["flow_components"] = model_output return ans_dict - def compute_loss(self, model_output, batch, log_values={}, loss_prefix=''): - x_action = model_output['flow_action'] - x_anchor = model_output['flow_anchor'] - goal_emb = model_output['goal_emb'] - goal_emb_cond_x = model_output['goal_emb_cond_x'] - + def compute_loss(self, model_output, batch, log_values={}, loss_prefix=""): + x_action = model_output["flow_action"] + x_anchor = model_output["flow_anchor"] + goal_emb = model_output["goal_emb"] + goal_emb_cond_x = model_output["goal_emb_cond_x"] + # Compute pzY losses using the pzX predictions (except goal_emb which is from pzY) - loss, log_values = self.training_module_no_cond_x.compute_loss(model_output, - batch, - log_values, - loss_prefix) + loss, log_values = self.training_module_no_cond_x.compute_loss( + model_output, batch, log_values, loss_prefix + ) # aka "if it is training time and not val time" if goal_emb is not None: @@ -625,43 +836,69 @@ def compute_loss(self, model_output, batch, log_values={}, 
loss_prefix=''): if self.model_with_cond_x.conditioning != "latent_z_linear_internalcond": N = x_action.shape[1] - action_kl = F.kl_div(F.log_softmax(goal_emb_cond_x[:, :, :N], dim=-1), - F.log_softmax(goal_emb[:, :, :N], dim=-1), log_target=True, - reduction='batchmean') - anchor_kl = F.kl_div(F.log_softmax(goal_emb_cond_x[:, :, N:], dim=-1), - F.log_softmax(goal_emb[:, :, N:], dim=-1), log_target=True, - reduction='batchmean') + action_kl = F.kl_div( + F.log_softmax(goal_emb_cond_x[:, :, :N], dim=-1), + F.log_softmax(goal_emb[:, :, :N], dim=-1), + log_target=True, + reduction="batchmean", + ) + anchor_kl = F.kl_div( + F.log_softmax(goal_emb_cond_x[:, :, N:], dim=-1), + F.log_softmax(goal_emb[:, :, N:], dim=-1), + log_target=True, + reduction="batchmean", + ) else: - goal_emb_cond_x = goal_emb_cond_x[0] # just take the mean - action_kl = F.kl_div(F.log_softmax(goal_emb_cond_x[:, :, 0], dim=-1), - F.log_softmax(goal_emb[:, :, 0], dim=-1), log_target=True, - reduction='batchmean') - anchor_kl = F.kl_div(F.log_softmax(goal_emb_cond_x[:, :, 1], dim=-1), - F.log_softmax(goal_emb[:, :, 0], dim=-1), log_target=True, - reduction='batchmean') - + goal_emb_cond_x = goal_emb_cond_x[0] # just take the mean + action_kl = F.kl_div( + F.log_softmax(goal_emb_cond_x[:, :, 0], dim=-1), + F.log_softmax(goal_emb[:, :, 0], dim=-1), + log_target=True, + reduction="batchmean", + ) + anchor_kl = F.kl_div( + F.log_softmax(goal_emb_cond_x[:, :, 1], dim=-1), + F.log_softmax(goal_emb[:, :, 0], dim=-1), + log_target=True, + reduction="batchmean", + ) + goal_emb_loss = action_kl + anchor_kl - if self.model_with_cond_x.freeze_residual_flow and self.model_with_cond_x.freeze_z_embnn and self.model_with_cond_x.freeze_embnn: + if ( + self.model_with_cond_x.freeze_residual_flow + and self.model_with_cond_x.freeze_z_embnn + and self.model_with_cond_x.freeze_embnn + ): # Overwrite the loss because the other losses are not used loss = self.goal_emb_cond_x_loss_weight * goal_emb_loss else: # DON'T overwrite the loss loss += self.goal_emb_cond_x_loss_weight * goal_emb_loss - log_values[loss_prefix+'goal_emb_cond_x_loss'] = self.goal_emb_cond_x_loss_weight * goal_emb_loss - log_values[loss_prefix+'action_kl'] = action_kl - log_values[loss_prefix+'anchor_kl'] = anchor_kl + log_values[loss_prefix + "goal_emb_cond_x_loss"] = ( + self.goal_emb_cond_x_loss_weight * goal_emb_loss + ) + log_values[loss_prefix + "action_kl"] = action_kl + log_values[loss_prefix + "anchor_kl"] = anchor_kl return loss, log_values def module_step(self, batch, batch_idx): - points_trans_action = batch['points_action_trans'] - points_trans_anchor = batch['points_anchor_trans'] - points_action = batch['points_action'] - points_anchor = batch['points_anchor'] - points_onetrans_action = batch['points_action_onetrans'] if 'points_action_onetrans' in batch else batch['points_action'] - points_onetrans_anchor = batch['points_anchor_onetrans'] if 'points_anchor_onetrans' in batch else batch['points_anchor'] + points_trans_action = batch["points_action_trans"] + points_trans_anchor = batch["points_anchor_trans"] + points_action = batch["points_action"] + points_anchor = batch["points_anchor"] + points_onetrans_action = ( + batch["points_action_onetrans"] + if "points_action_onetrans" in batch + else batch["points_action"] + ) + points_onetrans_anchor = ( + batch["points_anchor_onetrans"] + if "points_anchor_onetrans" in batch + else batch["points_anchor"] + ) # If joint training prior if self.joint_train_prior: @@ -670,111 +907,147 @@ def module_step(self, batch, 
batch_idx): self.training_module_no_cond_x.model.freeze_z_embnn = False self.training_module_no_cond_x.model.freeze_embnn = False self.training_module_no_cond_x.model.tax_pose.freeze_embnn = False - + # p(z|Y) pass - pzY_loss, pzY_log_values = self.training_module_no_cond_x.module_step(batch, batch_idx) - + pzY_loss, pzY_log_values = self.training_module_no_cond_x.module_step( + batch, batch_idx + ) + # Potentially freeze components for p(z|X) pass - self.training_module_no_cond_x.model.freeze_residual_flow = self.cfg_freeze_residual_flow - self.training_module_no_cond_x.model.freeze_z_embnn = self.cfg_freeze_z_embnn + self.training_module_no_cond_x.model.freeze_residual_flow = ( + self.cfg_freeze_residual_flow + ) + self.training_module_no_cond_x.model.freeze_z_embnn = ( + self.cfg_freeze_z_embnn + ) self.training_module_no_cond_x.model.freeze_embnn = self.cfg_freeze_embnn - self.training_module_no_cond_x.model.tax_pose.freeze_embnn = self.cfg_freeze_embnn - - model_output = self.model_with_cond_x(points_trans_action, - points_trans_anchor, - points_onetrans_action, - points_onetrans_anchor) + self.training_module_no_cond_x.model.tax_pose.freeze_embnn = ( + self.cfg_freeze_embnn + ) + + model_output = self.model_with_cond_x( + points_trans_action, + points_trans_anchor, + points_onetrans_action, + points_onetrans_anchor, + ) log_values = {} - log_prefix = 'pzX_' if self.joint_train_prior else '' - loss, log_values = self.compute_loss(model_output, - batch, - log_values=log_values, - loss_prefix=log_prefix) - + log_prefix = "pzX_" if self.joint_train_prior else "" + loss, log_values = self.compute_loss( + model_output, batch, log_values=log_values, loss_prefix=log_prefix + ) + if self.joint_train_prior: loss = pzY_loss + loss log_values = {**pzY_log_values, **log_values} - + return loss, log_values - - def visualize_results(self, batch, batch_idx, log_prefix=''): - res_images = self.training_module_no_cond_x.visualize_results(batch, batch_idx, log_prefix='pzY_') - - points_trans_action = batch['points_action_trans'] - points_trans_anchor = batch['points_anchor_trans'] - points_action = batch['points_action'] - points_anchor = batch['points_anchor'] - points_onetrans_action = batch['points_action_onetrans'] if 'points_action_onetrans' in batch else batch['points_action'] - points_onetrans_anchor = batch['points_anchor_onetrans'] if 'points_anchor_onetrans' in batch else batch['points_anchor'] - - T0 = Transform3d(matrix=batch['T0']) - T1 = Transform3d(matrix=batch['T1']) - - model_output = self.model_with_cond_x(points_trans_action, - points_trans_anchor, - points_onetrans_action, - points_onetrans_anchor) - - x_action = model_output['flow_action'] - x_anchor = model_output['flow_anchor'] - goal_emb = model_output['goal_emb'] - goal_emb_cond_x = model_output['goal_emb_cond_x'] + + def visualize_results(self, batch, batch_idx, log_prefix=""): + res_images = self.training_module_no_cond_x.visualize_results( + batch, batch_idx, log_prefix="pzY_" + ) + + points_trans_action = batch["points_action_trans"] + points_trans_anchor = batch["points_anchor_trans"] + points_action = batch["points_action"] + points_anchor = batch["points_anchor"] + points_onetrans_action = ( + batch["points_action_onetrans"] + if "points_action_onetrans" in batch + else batch["points_action"] + ) + points_onetrans_anchor = ( + batch["points_anchor_onetrans"] + if "points_anchor_onetrans" in batch + else batch["points_anchor"] + ) + + T0 = Transform3d(matrix=batch["T0"]) + T1 = Transform3d(matrix=batch["T1"]) + + 
-
-    def visualize_results(self, batch, batch_idx, log_prefix=''):
-        res_images = self.training_module_no_cond_x.visualize_results(batch, batch_idx, log_prefix='pzY_')
-
-        points_trans_action = batch['points_action_trans']
-        points_trans_anchor = batch['points_anchor_trans']
-        points_action = batch['points_action']
-        points_anchor = batch['points_anchor']
-        points_onetrans_action = batch['points_action_onetrans'] if 'points_action_onetrans' in batch else batch['points_action']
-        points_onetrans_anchor = batch['points_anchor_onetrans'] if 'points_anchor_onetrans' in batch else batch['points_anchor']
-
-        T0 = Transform3d(matrix=batch['T0'])
-        T1 = Transform3d(matrix=batch['T1'])
-
-        model_output = self.model_with_cond_x(points_trans_action,
-                                              points_trans_anchor,
-                                              points_onetrans_action,
-                                              points_onetrans_anchor)
-
-        x_action = model_output['flow_action']
-        x_anchor = model_output['flow_anchor']
-        goal_emb = model_output['goal_emb']
-        goal_emb_cond_x = model_output['goal_emb_cond_x']
+
+    def visualize_results(self, batch, batch_idx, log_prefix=""):
+        res_images = self.training_module_no_cond_x.visualize_results(
+            batch, batch_idx, log_prefix="pzY_"
+        )
+
+        points_trans_action = batch["points_action_trans"]
+        points_trans_anchor = batch["points_anchor_trans"]
+        points_action = batch["points_action"]
+        points_anchor = batch["points_anchor"]
+        points_onetrans_action = (
+            batch["points_action_onetrans"]
+            if "points_action_onetrans" in batch
+            else batch["points_action"]
+        )
+        points_onetrans_anchor = (
+            batch["points_anchor_onetrans"]
+            if "points_anchor_onetrans" in batch
+            else batch["points_anchor"]
+        )
+
+        T0 = Transform3d(matrix=batch["T0"])
+        T1 = Transform3d(matrix=batch["T1"])
+
+        model_output = self.model_with_cond_x(
+            points_trans_action,
+            points_trans_anchor,
+            points_onetrans_action,
+            points_onetrans_anchor,
+        )
+
+        x_action = model_output["flow_action"]
+        x_anchor = model_output["flow_anchor"]
+        goal_emb = model_output["goal_emb"]
+        goal_emb_cond_x = model_output["goal_emb_cond_x"]
 
         points_action = points_action[:, :, :3]
         points_anchor = points_anchor[:, :, :3]
         points_trans_action = points_trans_action[:, :, :3]
         points_trans_anchor = points_trans_anchor[:, :, :3]
-
+
         # If we've applied some sampling, we need to extract the predictions too...
         if "sampled_ixs_action" in model_output:
             ixs_action = model_output["sampled_ixs_action"].unsqueeze(-1)
-            sampled_points_action = torch.take_along_dim(points_action, ixs_action, dim=1)
+            sampled_points_action = torch.take_along_dim(
+                points_action, ixs_action, dim=1
+            )
             sampled_points_trans_action = torch.take_along_dim(
                 points_trans_action, ixs_action, dim=1
             )
         else:
+            sampled_points_action = points_action
             sampled_points_trans_action = points_trans_action
 
         if "sampled_ixs_anchor" in model_output:
             ixs_anchor = model_output["sampled_ixs_anchor"].unsqueeze(-1)
-            sampled_points_anchor = torch.take_along_dim(points_anchor, ixs_anchor, dim=1)
+            sampled_points_anchor = torch.take_along_dim(
+                points_anchor, ixs_anchor, dim=1
+            )
             sampled_points_trans_anchor = torch.take_along_dim(
                 points_trans_anchor, ixs_anchor, dim=1
             )
         else:
+            sampled_points_anchor = points_anchor
             sampled_points_trans_anchor = points_trans_anchor
 
         pred_flow_action, pred_w_action = self.extract_flow_and_weight(x_action)
         pred_flow_anchor, pred_w_anchor = self.extract_flow_and_weight(x_anchor)
 
-        pred_T_action = dualflow2pose(xyz_src=sampled_points_trans_action,
-                                      xyz_tgt=sampled_points_trans_anchor,
-                                      flow_src=pred_flow_action,
-                                      flow_tgt=pred_flow_anchor,
-                                      weights_src=pred_w_action,
-                                      weights_tgt=pred_w_anchor,
-                                      return_transform3d=True,
-                                      normalization_scehme=self.training_module_no_cond_x.weight_normalize,
-                                      temperature=self.training_module_no_cond_x.softmax_temperature)
-
+        pred_T_action = dualflow2pose(
+            xyz_src=sampled_points_trans_action,
+            xyz_tgt=sampled_points_trans_anchor,
+            flow_src=pred_flow_action,
+            flow_tgt=pred_flow_anchor,
+            weights_src=pred_w_action,
+            weights_tgt=pred_w_anchor,
+            return_transform3d=True,
+            normalization_scehme=self.training_module_no_cond_x.weight_normalize,
+            temperature=self.training_module_no_cond_x.softmax_temperature,
+        )
+
         pred_points_action = pred_T_action.transform_points(points_trans_action)
         points_action_target = T1.transform_points(points_action)
-
+
         demo_points = get_color(
-            tensor_list=[points_onetrans_action[0], points_onetrans_anchor[0]], color_list=['blue', 'red'])
-        res_images[log_prefix+'demo_points'] = wandb.Object3D(
-            demo_points)
+            tensor_list=[points_onetrans_action[0], points_onetrans_anchor[0]],
+            color_list=["blue", "red"],
+        )
+        res_images[log_prefix + "demo_points"] = wandb.Object3D(demo_points)
 
         action_transformed_action = get_color(
-            tensor_list=[points_action[0], points_trans_action[0]], color_list=['blue', 'red'])
-        res_images[log_prefix+'action_transformed_action'] = wandb.Object3D(
-            action_transformed_action)
+            tensor_list=[points_action[0], points_trans_action[0]],
+            color_list=["blue", "red"],
+        )
+        res_images[log_prefix + "action_transformed_action"] = wandb.Object3D(
+            action_transformed_action
+        )
 
         anchor_transformed_anchor = get_color(
-            tensor_list=[points_anchor[0], points_trans_anchor[0]], color_list=['blue', 'red'])
-        res_images[log_prefix+'anchor_transformed_anchor'] = wandb.Object3D(
-            anchor_transformed_anchor)
+            tensor_list=[points_anchor[0], points_trans_anchor[0]],
+            color_list=["blue", "red"],
+        )
+        res_images[log_prefix + "anchor_transformed_anchor"] = wandb.Object3D(
+            anchor_transformed_anchor
+        )
 
         # transformed_input_points = get_color(tensor_list=[
         #     points_trans_action[0], points_trans_anchor[0]], color_list=['blue', 'red'])
@@ -782,24 +1055,55 @@ def visualize_results(self, batch, batch_idx, log_prefix=''):
         #     transformed_input_points)
 
         demo_points_apply_action_transform = get_color(
-            tensor_list=[pred_points_action[0], points_trans_anchor[0]], color_list=['blue', 'red'])
-        res_images[log_prefix+'demo_points_apply_action_transform'] = wandb.Object3D(
-            demo_points_apply_action_transform)
+            tensor_list=[pred_points_action[0], points_trans_anchor[0]],
+            color_list=["blue", "red"],
+        )
+        res_images[log_prefix + "demo_points_apply_action_transform"] = wandb.Object3D(
+            demo_points_apply_action_transform
+        )
 
         apply_action_transform_demo_comparable = get_color(
-            tensor_list=[T1.inverse().transform_points(pred_points_action)[0], T1.inverse().transform_points(points_trans_anchor)[0]], color_list=['blue', 'red'])
-        res_images[log_prefix+'apply_action_transform_demo_comparable'] = wandb.Object3D(
-            apply_action_transform_demo_comparable)
+            tensor_list=[
+                T1.inverse().transform_points(pred_points_action)[0],
+                T1.inverse().transform_points(points_trans_anchor)[0],
+            ],
+            color_list=["blue", "red"],
+        )
+        res_images[
+            log_prefix + "apply_action_transform_demo_comparable"
+        ] = wandb.Object3D(apply_action_transform_demo_comparable)
 
         predicted_vs_gt_transform_applied = get_color(
-            tensor_list=[T1.inverse().transform_points(pred_points_action)[0], points_action[0], T1.inverse().transform_points(points_trans_anchor)[0]], color_list=['blue', 'green', 'red', ])
-        res_images[log_prefix+'predicted_vs_gt_transform_applied'] = wandb.Object3D(
-            predicted_vs_gt_transform_applied)
+            tensor_list=[
+                T1.inverse().transform_points(pred_points_action)[0],
+                points_action[0],
+                T1.inverse().transform_points(points_trans_anchor)[0],
+            ],
+            color_list=[
+                "blue",
+                "green",
+                "red",
+            ],
+        )
+        res_images[log_prefix + "predicted_vs_gt_transform_applied"] = wandb.Object3D(
+            predicted_vs_gt_transform_applied
+        )
 
         apply_predicted_transform = get_color(
-            tensor_list=[T1.inverse().transform_points(pred_points_action)[0], T1.inverse().transform_points(points_trans_action)[0], T1.inverse().transform_points(points_trans_anchor)[0]], color_list=['blue', 'orange', 'red', ])
-        res_images[log_prefix+'apply_predicted_transform'] = wandb.Object3D(
-            apply_predicted_transform)
+            tensor_list=[
+                T1.inverse().transform_points(pred_points_action)[0],
+                T1.inverse().transform_points(points_trans_action)[0],
+                T1.inverse().transform_points(points_trans_anchor)[0],
+            ],
+            color_list=[
+                "blue",
+                "orange",
+                "red",
+            ],
+        )
+        res_images[log_prefix + "apply_predicted_transform"] = wandb.Object3D(
+            apply_predicted_transform
+        )
 
         # loss_points_action = get_color(
         #     tensor_list=[points_action_target[0], pred_points_action[0]], color_list=['green', 'red'])
@@ -808,40 +1112,71 @@ def visualize_results(self, batch, batch_idx, log_prefix=''):
 
         colors_pred_w_action = color_gradient(pred_w_action[0])
         colors_pred_w_anchor = color_gradient(pred_w_anchor[0])
-        pred_w_points = torch.cat([sampled_points_action[0].detach(), sampled_points_anchor[0].detach()], dim=0).cpu().numpy()
-        pred_w_on_objects = np.concatenate([
-            pred_w_points,
-            np.concatenate([colors_pred_w_action, colors_pred_w_anchor], axis=0)],
-            axis=-1)
-
-        res_images[log_prefix+'pred_w'] = wandb.Object3D(
-            pred_w_on_objects, markerSize=1000)
+        pred_w_points = (
+            torch.cat(
+                [sampled_points_action[0].detach(), sampled_points_anchor[0].detach()],
+                dim=0,
+            )
+            .cpu()
+            .numpy()
+        )
+        pred_w_on_objects = np.concatenate(
+            [
+                pred_w_points,
+                np.concatenate([colors_pred_w_action, colors_pred_w_anchor], axis=0),
+            ],
+            axis=-1,
+        )
 
-        # This goal_emb_cond_x visualization only applies to methods that have a per-point latent space
-        if self.model.conditioning not in ["latent_z_linear", "latent_z_linear_internalcond"]:
+        res_images[log_prefix + "pred_w"] = wandb.Object3D(
+            pred_w_on_objects, markerSize=1000
+        )
 
-            goal_emb_norm_action = F.softmax(goal_emb_cond_x[0, :, :points_action.shape[1]], dim=-1).detach().cpu()
-            goal_emb_norm_anchor = F.softmax(goal_emb_cond_x[0, :, points_action.shape[1]:], dim=-1).detach().cpu()
+        # This goal_emb_cond_x visualization only applies to methods that have a per-point latent space
+        if self.model.conditioning not in [
+            "latent_z_linear",
+            "latent_z_linear_internalcond",
+        ]:
+            goal_emb_norm_action = (
+                F.softmax(goal_emb_cond_x[0, :, : points_action.shape[1]], dim=-1)
+                .detach()
+                .cpu()
+            )
+            goal_emb_norm_anchor = (
+                F.softmax(goal_emb_cond_x[0, :, points_action.shape[1] :], dim=-1)
+                .detach()
+                .cpu()
+            )
 
             # TODO CHANGE THIS. temporary only mug
             only_mug = False
             if only_mug:
-                colors_action = color_gradient(F.softmax(goal_emb_norm_action[0], dim=-1))
+                colors_action = color_gradient(
+                    F.softmax(goal_emb_norm_action[0], dim=-1)
+                )
                 points = points_action[0].detach().cpu().numpy()
-                goal_emb_on_objects = np.concatenate([
-                    points,
-                    colors_action],
-                    axis=-1)
+                goal_emb_on_objects = np.concatenate([points, colors_action], axis=-1)
             else:
-                colors_action = color_gradient(F.softmax(goal_emb_norm_action[0], dim=-1))
-                colors_anchor = color_gradient(F.softmax(goal_emb_norm_anchor[0], dim=-1))
-                points = torch.cat([points_action[0].detach(), points_anchor[0].detach()], dim=0).cpu().numpy()
-                goal_emb_on_objects = np.concatenate([
-                    points,
-                    np.concatenate([colors_action, colors_anchor], axis=0)],
-                    axis=-1)
-
-            res_images['goal_emb_cond_x'] = wandb.Object3D(
-                goal_emb_on_objects, markerSize=1000) #marker_scale * range_size)
+                colors_action = color_gradient(
+                    F.softmax(goal_emb_norm_action[0], dim=-1)
+                )
+                colors_anchor = color_gradient(
+                    F.softmax(goal_emb_norm_anchor[0], dim=-1)
+                )
+                points = (
+                    torch.cat(
+                        [points_action[0].detach(), points_anchor[0].detach()], dim=0
+                    )
+                    .cpu()
+                    .numpy()
+                )
+                goal_emb_on_objects = np.concatenate(
+                    [points, np.concatenate([colors_action, colors_anchor], axis=0)],
+                    axis=-1,
+                )
+
+            res_images["goal_emb_cond_x"] = wandb.Object3D(
+                goal_emb_on_objects, markerSize=1000
+            )  # marker_scale * range_size)
 
         return res_images
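Two details in the visualization hunks above are easy to misread. First, the normalization_scehme keyword passed to dualflow2pose is spelled that way in the callee's own signature, so the reformatter correctly leaves it untouched. Second, every wandb.Object3D call here receives an (N, 6) numpy array whose columns are xyz followed by rgb. A small self-contained sketch of that packing, with random tensors and a toy ramp standing in for the repo's color_gradient helper:

# pack_cloud.py -- illustrative sketch; `points` and `weights` are synthetic stand-ins
import numpy as np
import torch

points = torch.rand(1024, 3)  # xyz, e.g. sampled_points_action[0]
weights = torch.rand(1024)    # per-point weights, e.g. pred_w_action[0]

# Toy red-to-blue ramp in place of color_gradient(pred_w_action[0]).
w = weights.numpy()
colors = np.stack([255 * w, np.zeros_like(w), 255 * (1 - w)], axis=-1)

cloud = np.concatenate([points.numpy(), colors], axis=-1)  # (N, 6): xyz + rgb
assert cloud.shape == (1024, 6)
# wandb.log({"pred_w": wandb.Object3D(cloud, markerSize=1000)})  # as in the diff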
diff --git a/third_party/dcp/model.py b/third_party/dcp/model.py
index f5a1109..05db8d3 100644
--- a/third_party/dcp/model.py
+++ b/third_party/dcp/model.py
@@ -3,7 +3,6 @@
 # Only changes:
 # - Change `from util import quat2mat` to `from .util import quat2mat`.
 # - Add this comment.
-# - Modify DGCNN and add DGCNNClassification
 
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
@@ -281,100 +280,18 @@ def forward(self, x):
 
 
 class DGCNN(nn.Module):
-    def __init__(self, emb_dims=512, input_dims=3, num_heads=1, conditioning_size=0, last_relu=True):
+    def __init__(self, emb_dims=512):
         super(DGCNN, self).__init__()
-        self.input_dims = input_dims
-        self.num_heads = num_heads
-        self.conditioning_size = conditioning_size
-        self.last_relu = last_relu
-
-        self.conv1 = nn.Conv2d(2*input_dims, 64, kernel_size=1, bias=False)
-        self.conv2 = nn.Conv2d(64, 64, kernel_size=1, bias=False)
-        self.conv3 = nn.Conv2d(64, 128, kernel_size=1, bias=False)
-        self.conv4 = nn.Conv2d(128, 256, kernel_size=1, bias=False)
-
-        if self.num_heads == 1:
-            self.conv5 = nn.Conv2d(512 + self.conditioning_size, emb_dims, kernel_size=1, bias=False)
-            self.bn5 = nn.BatchNorm2d(emb_dims)
-        else:
-            if self.conditioning_size > 0:
-                raise NotImplementedError("Conditioning not implemented for multi-head DGCNN")
-            self.conv5s = nn.ModuleList([nn.Conv2d(512, emb_dims, kernel_size=1, bias=False) for _ in range(self.num_heads)])
-            self.bn5s = nn.ModuleList([nn.BatchNorm2d(emb_dims) for _ in range(self.num_heads)])
-
-        self.bn1 = nn.BatchNorm2d(64)
-        self.bn2 = nn.BatchNorm2d(64)
-        self.bn3 = nn.BatchNorm2d(128)
-        self.bn4 = nn.BatchNorm2d(256)
-
-    def forward(self, x, conditioning=None):
-        batch_size, num_dims, num_points = x.size()
-        x = get_graph_feature(x)
-        x = F.relu(self.bn1(self.conv1(x)))
-        x1 = x.max(dim=-1, keepdim=True)[0]
-
-        x = F.relu(self.bn2(self.conv2(x)))
-        x2 = x.max(dim=-1, keepdim=True)[0]
-
-        x = F.relu(self.bn3(self.conv3(x)))
-        x3 = x.max(dim=-1, keepdim=True)[0]
-
-        x = F.relu(self.bn4(self.conv4(x)))
-        x4 = x.max(dim=-1, keepdim=True)[0]
-
-        if self.conditioning_size == 0:
-            assert conditioning is None
-            x = torch.cat((x1, x2, x3, x4), dim=1)
-        else:
-            assert conditioning is not None
-            x = torch.cat((x1, x2, x3, x4, conditioning[:,:,:,None]), dim=1)
-
-        if self.num_heads == 1:
-            x = self.bn5(self.conv5(x)).view(batch_size, -1, num_points)
-        else:
-            x = [bn5(conv5(x)).view(batch_size, -1, num_points) for bn5, conv5 in zip(self.bn5s, self.conv5s)]
-
-        if self.last_relu:
-            if self.num_heads == 1:
-                x = F.relu(x)
-            else:
-                x = [F.relu(head) for head in x]
-        return x
-
-
-class DGCNNClassification(nn.Module):
-    # Reference: https://github.com/WangYueFt/dgcnn/blob/master/pytorch/model.py#L88-L153
-
-    def __init__(self, emb_dims=512, input_dims=3, num_heads=1, dropout=0.5, output_channels=40):
-        super(DGCNNClassification, self).__init__()
-        self.emb_dims = emb_dims
-        self.input_dims = input_dims
-        self.num_heads = num_heads
-        self.dropout=dropout
-        self.output_channels = output_channels
-        self.conv1 = nn.Conv2d(self.input_dims*2, 64, kernel_size=1, bias=False)
+        self.conv1 = nn.Conv2d(6, 64, kernel_size=1, bias=False)
         self.conv2 = nn.Conv2d(64, 64, kernel_size=1, bias=False)
         self.conv3 = nn.Conv2d(64, 128, kernel_size=1, bias=False)
         self.conv4 = nn.Conv2d(128, 256, kernel_size=1, bias=False)
-        self.conv5 = nn.Conv2d(512, self.emb_dims, kernel_size=1, bias=False)
-
+        self.conv5 = nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
         self.bn1 = nn.BatchNorm2d(64)
         self.bn2 = nn.BatchNorm2d(64)
         self.bn3 = nn.BatchNorm2d(128)
         self.bn4 = nn.BatchNorm2d(256)
-        self.bn5 = nn.BatchNorm2d(self.emb_dims)
-
-        self.linear1 = nn.Linear(self.emb_dims*2, 512, bias=False)
-        self.bn6 = nn.BatchNorm1d(512)
-        self.dp1 = nn.Dropout(p=self.dropout)
-        self.linear2 = nn.Linear(512, 256)
-        self.bn7 = nn.BatchNorm1d(256)
-        self.dp2 = nn.Dropout(p=self.dropout)
-
-        if self.num_heads == 1:
-            self.linear3 = nn.Linear(256, self.output_channels)
-        else:
-            self.linear3s = nn.ModuleList([nn.Linear(256, self.output_channels) for _ in range(self.num_heads)])
+        self.bn5 = nn.BatchNorm2d(emb_dims)
 
     def forward(self, x):
         batch_size, num_dims, num_points = x.size()
@@ -393,20 +310,7 @@ def forward(self, x):
 
         x = torch.cat((x1, x2, x3, x4), dim=1)
 
-        x = self.conv5(x).squeeze()
-        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
-        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
-        x = torch.cat((x1, x2), 1)
-
-        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
-        x = self.dp1(x)
-        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
-        x = self.dp2(x)
-
-        if self.num_heads == 1:
-            x = self.linear3(x)[:,:,None]
-        else:
-            x = [linear3(x)[:,:,None] for linear3 in self.linear3s]
+        x = F.relu(self.bn5(self.conv5(x))).view(batch_size, -1, num_points)
 
         return x
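After this revert, DGCNN is once again the upstream per-point encoder: six input channels built by get_graph_feature from each point and its k nearest neighbors, four edge-conv stages whose max-pooled features are concatenated to 512 channels, and a final 1x1 conv projecting to emb_dims per point. A quick smoke test of the reverted interface, assuming third_party/dcp is importable as a package from the repo root:

# dgcnn_smoke_test.py -- sketch; assumes the import path below resolves
import torch
from third_party.dcp.model import DGCNN

net = DGCNN(emb_dims=512)
x = torch.rand(2, 3, 1024)  # (batch, xyz channels, num_points)
feats = net(x)              # per-point embeddings, unlike the removed multi-head variant
assert feats.shape == (2, 512, 1024)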