Add TAXPoseD support for RelDist decoder
oadonca committed Dec 27, 2023
1 parent 3ec7a83 commit a8e0291
Showing 7 changed files with 376 additions and 339 deletions.
@@ -0,0 +1,23 @@
# @package _global_

defaults:
- /commands/rlbench/_train_taxposed@_here_
- override /model: taxposed_mlat_s100
- override /task: insert_onto_square_peg
- override /phase: place
- _self_

break_symmetry: False

training:
init_cond_x: True
freeze_embnn: True
freeze_residual_flow: True
freeze_z_embnn: True

load_from_checkpoint: True
checkpoint_file: null

dm:
train_dset:
num_points: 256
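
This first config appears to be the second training stage for TAXPoseD: init_cond_x turns on the conditional encoder p(z|X), the freeze_* flags hold the pretrained embedding networks and residual-flow decoder fixed, and checkpoint_file (null here) must be pointed at the stage-one weights. A minimal sketch of what the freeze flags amount to; the module names are illustrative, not the repo's exact attributes:

# Hypothetical: the freeze_* flags disable gradients on the stage-one modules
for module in (model.emb_nn, model.tax_pose, model.z_emb_nn):
    for p in module.parameters():
        p.requires_grad = False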
@@ -0,0 +1,14 @@
# @package _global_

defaults:
- /commands/rlbench/_train_taxposed@_here_
- override /model: taxposed_mlat_s100
- override /task: insert_onto_square_peg
- override /phase: place
- _self_

break_symmetry: False

dm:
train_dset:
num_points: 256
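
Relative to the config above, this one has no freezing and no checkpoint, so it reads as the first stage: train p(z|Y) and the decoder from scratch. A hypothetical launch sequence for the two stages (the config names are placeholders; only scripts/train_residual_flow.py is named in this commit, and Hydra's --config-name and key=value override syntax is assumed):

# Stage 1: train p(z|Y) and the decoder from scratch
python scripts/train_residual_flow.py --config-name <this_config>
# Stage 2: freeze them and fit p(z|X), supplying the stage-1 checkpoint
python scripts/train_residual_flow.py --config-name <frozen_config> checkpoint_file=/path/to/stage1.ckpt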
9 changes: 9 additions & 0 deletions configs/model/taxposed_mlat_s100.yaml
@@ -0,0 +1,9 @@
defaults:
- _taxposed

name: taxposed_mlat_s100

multilaterate: True
mlat_sample: True
mlat_nkps: 100
break_symmetry: False
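
As the flag names suggest, multilaterate: True selects the RelDist-style decoder from the commit title: instead of regressing a flow vector per point, the head predicts distances from each point to mlat_nkps sampled keypoints (mlat_sample: True) and recovers the corresponded location by least-squares multilateration. A minimal numpy sketch of that recovery step, not the repo's implementation:

import numpy as np

def multilaterate(keypoints, dists):
    """Recover x from ||x - k_i|| = d_i, given keypoints (K, 3) and dists (K,)."""
    k0, d0 = keypoints[0], dists[0]
    # Subtracting the squared equation for k_0 from the others linearizes the system.
    A = 2.0 * (keypoints[1:] - k0)                                # (K-1, 3)
    b = (d0 ** 2 - dists[1:] ** 2
         + np.sum(keypoints[1:] ** 2, axis=1) - np.sum(k0 ** 2))  # (K-1,)
    x, *_ = np.linalg.lstsq(A, b, rcond=None)
    return x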
2 changes: 1 addition & 1 deletion scripts/train_residual_flow.py
@@ -128,7 +128,7 @@ def main(cfg):

     dm.setup()

-    if cfg.model.name in ["taxposed"]:
+    if cfg.model.name in ["taxposed", "taxposed_mlat_s100", "taxposed_mlat_s256"]:
         TP_input_dims = Multimodal_ResidualFlow_DiffEmbTransformer.TP_INPUT_DIMS[cfg.model.conditioning]

         taxpose_decoder_network = ResidualFlow_DiffEmbTransformer(
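
The gate in main() is a literal allow-list, so every new decoder variant needs another entry (note that taxposed_mlat_s256 is admitted even though only the s100 model config is shown above). A prefix check would cover future variants; a sketch, not the committed code:

# Hypothetical alternative to the allow-list
if cfg.model.name.startswith("taxposed"):
    ...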
88 changes: 22 additions & 66 deletions taxpose/nets/multimodal_transformer_flow.py
@@ -66,7 +66,6 @@ def __init__(self, residualflow_diffembtransformer, gumbel_temp=0.5, freeze_resi
         # assert not freeze_residual_flow and not freeze_z_embnn, "Prob didn't want to freeze residual flow or z embnn when using latent_z_linear"

         self.tax_pose = residualflow_diffembtransformer
-        self.return_flow_component = True

         self.emb_dims = self.EMB_DIMS_BY_CONDITIONING[self.conditioning]
         self.num_emb_heads = self.NUM_HEADS_BY_CONDITIONING[self.conditioning]
@@ -265,32 +264,15 @@ def forward(self, *input, mode="forward"):
         if self.conditioning in ["latent_z_linear", "latent_z_linear_internalcond"]:
             goal_emb = goal_emb[0]

-        if self.return_flow_component:
-            if self.freeze_residual_flow:
-                flow_action['flow_action'] = flow_action['flow_action'].detach()
-                flow_action['flow_anchor'] = flow_action['flow_anchor'].detach()
-            flow_action = {
-                **flow_action,
-                'goal_emb': goal_emb,
-                **for_debug,
-            }
-        else:
-            if self.freeze_residual_flow:
-                flow_action = (flow_action[0].detach(), flow_action[1].detach(), *flow_action[2:])
-
-            if self.conditioning in ["latent_z_linear", "latent_z_linear_internalcond"] and mode == "forward":
-                # These are for the loss
-                heads = {
-                    k: for_debug[k] for k in ['goal_emb_mu', 'goal_emb_logvar']
-                }
-                flow_action = (
-                    *flow_action,
-                    goal_emb,
-                    heads
-                )
-            else:
-                flow_action = (*flow_action, goal_emb)
-
+        if self.freeze_residual_flow:
+            flow_action['flow_action'] = flow_action['flow_action'].detach()
+            flow_action['flow_anchor'] = flow_action['flow_anchor'].detach()
+
+        flow_action = {
+            **flow_action,
+            'goal_emb': goal_emb,
+            **for_debug,
+        }
         return flow_action

     def sample(self, action_points, anchor_points):
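
With the return_flow_component branching removed, forward() now always returns a dictionary, so callers that unpacked the old tuple must switch to key access. A hypothetical call site (the key names come from the diff; argument names are assumed):

out = model(action_points, anchor_points, demo_action_points, demo_anchor_points)
flow = out["flow_action"]   # already detached when freeze_residual_flow is set
goal_emb = out["goal_emb"]  # the p(z|Y) embedding, alongside the for_debug extras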
@@ -459,45 +441,19 @@ def prepare(arr, is_action):
                                           action_center=action_center,
                                           anchor_center=anchor_center)

-        if self.residflow_embnn.return_flow_component:
-            # If the demo is available, run p(z|Y)
-            if input[2] is not None:
-                # Inputs 2 and 3 are the objects in demo positions
-                # If we have access to these, we can run the pzY network
-                pzY_results = self.residflow_embnn(*input)
-                goal_emb = pzY_results['goal_emb']
-            else:
-                goal_emb = None
-
-            flow_action = {
-                **flow_action,
-                'goal_emb': goal_emb,
-                'goal_emb_cond_x': goal_emb_cond_x,
-                **for_debug,
-            }
-        else:
-            # If the demo is available, run p(z|Y)
-            if input[2] is not None:
-                # Inputs 2 and 3 are the objects in demo positions
-                # If we have access to these, we can run the pzY network
-                pzY_results = self.residflow_embnn(*input)
-                goal_emb = pzY_results[2]
-            else:
-                goal_emb = None
-
-            flow_action = (*flow_action, goal_emb, goal_emb_cond_x)
-            # TODO put everything into a dictionary later. Getting outputs from a tuple with a changing length is annoying
-            if self.conditioning in ["latent_z_linear", "latent_z_linear_internalcond"]:
-                flow_action = (
-                    *flow_action,
-
-                    # These are for the loss
-                    {
-                        k: for_debug[k] for k in ['goal_emb_mu', 'goal_emb_logvar']
-                    }
-                )
-            if self.return_debug:
-                flow_action = (*flow_action, for_debug)
-
+        # If the demo is available, run p(z|Y)
+        if input[2] is not None:
+            # Inputs 2 and 3 are the objects in demo positions
+            # If we have access to these, we can run the pzY network
+            pzY_results = self.residflow_embnn(*input)
+            goal_emb = pzY_results['goal_emb']
+        else:
+            goal_emb = None
+
+        flow_action = {
+            **flow_action,
+            'goal_emb': goal_emb,
+            'goal_emb_cond_x': goal_emb_cond_x,
+            **for_debug,
+        }
         return flow_action
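
The same collapse happens in the conditioned model, and the input[2] guard is what makes it usable at inference time, when no demo clouds exist. A hypothetical inference call under that assumption:

out = cond_model(action_points, anchor_points, None, None)
assert out["goal_emb"] is None   # p(z|Y) is skipped without demo point clouds
z = out["goal_emb_cond_x"]       # conditioning comes from p(z|X) alone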
14 changes: 6 additions & 8 deletions taxpose/nets/transformer_flow.py
@@ -207,7 +207,7 @@ class ResidualMLPHead(nn.Module):
         v_i = f(\phi_i) + \tilde{y}_i - x_i
     """

-    def __init__(self, emb_dims=512, output_dims=3, pred_weight=True, residual_on=True):
+    def __init__(self, emb_dims=512, pred_weight=True, residual_on=True):
         super(ResidualMLPHead, self).__init__()

         self.emb_dims = emb_dims
@@ -220,7 +220,7 @@ def __init__(self, emb_dims=512, output_dims=3, pred_weight=True, residual_on=Tr
         else:
             self.proj_flow = nn.Sequential(
                 PointNet([emb_dims, emb_dims // 2, emb_dims // 4, emb_dims // 8]),
-                nn.Conv1d(emb_dims // 8, output_dims, kernel_size=1, bias=False),
+                nn.Conv1d(emb_dims // 8, 3, kernel_size=1, bias=False),
             )
         self.pred_weight = pred_weight
         if self.pred_weight:
@@ -524,13 +524,11 @@ def __init__(
         else:
             self.head_action = ResidualMLPHead(
                 emb_dims=emb_dims,
-                output_dims=input_dims,
                 pred_weight=self.pred_weight,
                 residual_on=self.residual_on,
             )
             self.head_anchor = ResidualMLPHead(
                 emb_dims=emb_dims,
-                output_dims=input_dims,
                 pred_weight=self.pred_weight,
                 residual_on=self.residual_on,
             )
@@ -631,8 +629,8 @@ def forward(self, *input, conditioning_action=None, conditioning_anchor=None, ac
         head_action_output = self.head_action(
             action_embedding_tf,
             anchor_embedding_tf,
-            action_points,
-            anchor_points,
+            action_points[:, :3, :],
+            anchor_points[:, :3, :],
             scores=action_attn,
         )
         flow_action = head_action_output["full_flow"].permute(0, 2, 1)
@@ -657,8 +655,8 @@ def forward(self, *input, conditioning_action=None, conditioning_anchor=None, ac
         head_anchor_output = self.head_anchor(
             anchor_embedding_tf,
             action_embedding_tf,
-            anchor_points,
-            action_points,
+            anchor_points[:, :3, :],
+            action_points[:, :3, :],
             scores=anchor_attn,
         )
         flow_anchor = head_anchor_output["full_flow"].permute(0, 2, 1)
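
Hard-coding the Conv1d output to 3 and slicing [:, :3, :] before the residual heads look like two halves of one fix: under TAXPoseD the point tensors apparently carry xyz plus appended conditioning channels, while the residual v_i = f(\phi_i) + \tilde{y}_i - x_i is purely geometric. A shape sketch under that assumed layout:

# Assumed layout: points arrive as (B, 3 + k, N), xyz plus k conditioning channels
xyz = action_points[:, :3, :]  # the residual head consumes coordinates only
# proj_flow now ends in nn.Conv1d(emb_dims // 8, 3, ...), so the flow is always 3-D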