Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TAXPoseD #18

Open
wants to merge 85 commits into
base: main
Choose a base branch
from
Open

Add TAXPoseD #18

wants to merge 85 commits into from

Conversation

oadonca
Copy link
Contributor

@oadonca oadonca commented Dec 15, 2023

Add TAXPoseD with support for base TAXPose decoder and RLBench training.

Only changes that should possibly affect non TAXPoseD code are the changes to scripts/train_residual_flow.py, taxpose/nets/transformer_flow.py, and third_party/dcp/model.

Currently using a separate base config file (configs/train_ndf_multimodal.yaml), let me know if you would prefer a solution that doesn't need this.

@oadonca oadonca requested review from beneisner and himty December 15, 2023 23:46
@oadonca oadonca changed the title Add TAXPoseD with support for base TAXPose decoder and RLBench Add TAXPoseD Dec 18, 2023
Copy link
Collaborator

@beneisner beneisner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. try running the unit tests and formatter

pytest

and

black (forget how this works...)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this file looks kind of weird, no? why is the diff so scattered.... any way to make the diff clearer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, will move these changes into taxpose/nets which should also have a cleaner diff

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved these changes out of third_party/... and into a new file in taxpose/nets/taxposed_dgcnn.py, maybe a better file name exists, could also probably merge the other DGCNN implementations (DGCNN_GC, VN_DGCNN) into a single custom dgcnns file in the future.

@@ -280,18 +281,100 @@ def forward(self, x):


class DGCNN(nn.Module):
def __init__(self, emb_dims=512):
def __init__(self, emb_dims=512, input_dims=3, num_heads=1, conditioning_size=0, last_relu=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

two thoughts:

  1. In general, I'm trying to keep this file as close as possible to the original code we took from the DCP paper. In its current state it's a true third-party import, and changing it kind of contaminates that. If you want to add additional models, put them in taxpose/nets, and mention at the top that it's derived from this file. I know it's a bit clunky, but I really want to make sure we keep track of various low-level modifications to networks or code we borrow.

  2. What is the motivation behind teh changes to DGCNN here? is it to just make it conditional? in my mind that's a pretty distinct (custom) model that should live in taxpose/nets. Are there other logical changes? i.e. last_relu...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Makes complete sense, will move these changes into a new class in taxpose/nets
  2. Motivation is primarily to handle conditioning, either through an extra channels in the input or added to the later conv layers. Additionally, supports having multiple heads when TAXPoseD vae encoder needs to predict a normal distribution.

)
constrained_axix_angle = rot_ratio * axis_angle_random # max angle is rot_var
R = axis_angle_to_matrix(constrained_axix_angle)
def random_se3(N, rot_var=np.pi/180 * 5, trans_var=0.1,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this won't pass black formatting

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverted this change as i realized it was unnecessary

if emb_nn == "dgcnn":
self.emb_nn_action = DGCNN(emb_dims=self.emb_dims)
self.emb_nn_anchor = DGCNN(emb_dims=self.emb_dims)
self.emb_nn_action = DGCNN(emb_dims=self.emb_dims, input_dims=self.input_dims, conditioning_size=self.conditioning_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if i'd prefer a different case here - like c-dgcnn or something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood, will add a new case for this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a case for "cond_dgcnn"

action_points = input[0].permute(0, 2, 1)[:, :self.input_dims] # B,3,num_points
anchor_points = input[1].permute(0, 2, 1)[:, :self.input_dims]

if action_center is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is action_center / anchor_center here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

taxposed vae encoder will output distributions over the action/anchor point clouds which we sample a point from each to encode the demo. Jenny's found some benefit to centering the point clouds about the sampled points instead of mean, so on the taxposed side we pass in either the mean or sampled points for action/anchor point cloud centering as action_center/anchor_center

@@ -606,8 +629,8 @@ def forward(self, *input):
head_action_output = self.head_action(
action_embedding_tf,
anchor_embedding_tf,
action_points,
anchor_points,
action_points[:, :3, :],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what gets reshaped to make this necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

taxposed action_points/anchor_points will be [Batch size, XYZ + Conditioning size, Num. points], we want the full size inputs for the dgcnn embeddings but for the heads it seems we only need XYZ, additionally seems necessary for dimensions to align.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is fine, totally isolated from other code in the repo.

softmax_temperature=cfg.task.phase.softmax_temperature,
flow_supervision=cfg.training.flow_supervision,
)
if cfg.model.name in ["taxposed", "taxposed_mlat_s100", "taxposed_mlat_s256"]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might want to extract this conditional into a function above. just so that it's clear what's actually being set up in the main function. (i.e. a network, a model, a dataset, and training).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

created functions for network creation, model creation, and maybe loading weights so main function is clear

@beneisner
Copy link
Collaborator

Seems pretty cool - can you post links to some training runs?

Curious to see if the mlat stuff worked as well!

also in general, the more you can isolate changes for these codepaths the better (which is why i commented only on files that touch normal taxpose training). I'm a bit paranoid about logic/structural changes to existing code which is known to work w/o running full training + eval on a few tasks. So the less that happens, the easier it is to say that nothing meaningful has changed about the code.

But in general I'm really impressed with how quickly you got this going. The config complexity alone is annoying enough as it is, so good work figuring that out so fast!

@beneisner
Copy link
Collaborator

Also, can you put some example commands for how to train this thing in scripts/README.md?

You might even try to get this running on the cluster o.O

@oadonca
Copy link
Contributor Author

oadonca commented Dec 19, 2023

Updated code with black formatting and ran pytests, only cases that fail are due to missing checkpoint configs, mainly for evaluations and loading pretrained weights.

Ran the first training stage for TAXPoseD on the RLBench tasks with and without mlat: https://wandb.ai/r-pad/taxposed
Will start second training stage for these shortly but the above should give an indication for a kind of upper bound on performance.

Base automatically changed from back_to_rlbench to main May 14, 2024 21:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants